gcn_norm for bigraph (ex: user-item for recsys) #10506
winiciuspontes asked this question in Q&A
Hey everyone, I'm implementing the LightGCN model from scratch and wanted to compare my results with the layer implemented by the team. However, I noticed that the normalization computed by gcn_norm does not take the degrees of the nodes in edge_index[0] (the users, in my case) into account. In other words, a user with many interactions is not down-weighted relative to the others.
Is there any implementation or parameter I can use in gcn_norm for bipartite graphs, or would I have to implement it manually?
Example:
import torch
from torch_geometric.utils import scatter
from torch_geometric.nn.conv.gcn_conv import gcn_norm

def norm_bipartite(edge_index, num_nodes):
    row, col = edge_index
    # Degree of each user (source) node and each item (target) node:
    deg_u = scatter(torch.ones_like(row, dtype=torch.float32), row, dim=0, reduce="sum", dim_size=num_nodes)
    deg_i = scatter(torch.ones_like(col, dtype=torch.float32), col, dim=0, reduce="sum", dim_size=num_nodes)
    deg_u_inv_sqrt = torch.where(deg_u > 0, torch.rsqrt(deg_u), torch.zeros_like(deg_u))
    deg_i_inv_sqrt = torch.where(deg_i > 0, torch.rsqrt(deg_i), torch.zeros_like(deg_i))
    # Symmetric normalization 1 / sqrt(deg_u * deg_i) per edge:
    edge_weight = deg_u_inv_sqrt[row] * deg_i_inv_sqrt[col]
    return edge_index, edge_weight

edge_index = torch.tensor([
    [0, 0, 1, 1, 2, 2, 2, 2],  # users
    [3, 4, 3, 5, 4, 5, 3, 4],  # items
], dtype=torch.long)
num_nodes = edge_index.max().item() + 1

edge_index_bigraph, edge_weight_bigraph = norm_bipartite(edge_index, num_nodes)
print("norm for bigraph:")
print(edge_weight_bigraph.numpy())

edge_index_gcn, edge_weight_gcn = gcn_norm(edge_index, edge_weight=None, num_nodes=num_nodes)
print("norm using gcn_norm for bigraph:")
print(edge_weight_gcn)
norm for bigraph:
[0.40824828 0.40824828 0.40824828 0.49999997 0.28867513 0.35355338
0.28867513 0.28867513]
norm using gcn_norm for bigraph:
tensor([0.5000, 0.5000, 0.5000, 0.5774, 0.5000, 0.5774, 0.5000, 0.5000, 1.0000,
1.0000, 1.0000, 0.2500, 0.2500, 0.3333])
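(As a side note, gcn_norm returns 14 weights above because it adds a self-loop per node by default; and since all edges point user → item, the only thing counted on the user side is that self-loop.)

One workaround that seems to reproduce the bipartite normalization with the built-in gcn_norm is to symmetrize the graph first (LightGCN's propagation assumes the reverse item → user edges exist anyway) and to disable the default self-loops. A minimal sketch on the same toy graph; the first 8 weights (the user → item direction) should then match norm_bipartite:

import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

edge_index = torch.tensor([
    [0, 0, 1, 1, 2, 2, 2, 2],  # users
    [3, 4, 3, 5, 4, 5, 3, 4],  # items
], dtype=torch.long)
num_nodes = edge_index.max().item() + 1

# Append the reverse item->user edges so both endpoint degrees are counted.
# (to_undirected would also work, but it coalesces duplicate interactions.)
edge_index_sym = torch.cat([edge_index, edge_index.flip(0)], dim=1)

# Without the default self-loops this is the plain symmetric normalization
# 1 / sqrt(deg(u) * deg(i)) per edge, as in the LightGCN paper:
_, edge_weight_sym = gcn_norm(edge_index_sym, edge_weight=None,
                              num_nodes=num_nodes, add_self_loops=False)
print(edge_weight_sym[:edge_index.size(1)])  # user->item direction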
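On the library side, if I read the PyG source right, LGConv (the layer behind torch_geometric.nn.models.LightGCN) already calls gcn_norm with add_self_loops=False internally, so passing it the symmetrized edge_index should give exactly this normalization; with normalize=False you can also supply precomputed weights yourself. A sketch, assuming user and item embeddings stacked in one table:

import torch
from torch_geometric.nn import LGConv
from torch_geometric.nn.conv.gcn_conv import gcn_norm

num_users, num_items = 3, 3
num_nodes = num_users + num_items
emb = torch.nn.Embedding(num_nodes, 16)  # rows 0-2 are users, rows 3-5 are items

edge_index = torch.tensor([
    [0, 0, 1, 1, 2, 2, 2, 2],  # users
    [3, 4, 3, 5, 4, 5, 3, 4],  # items
], dtype=torch.long)
edge_index_sym = torch.cat([edge_index, edge_index.flip(0)], dim=1)

# Built-in path: LGConv normalizes with gcn_norm(..., add_self_loops=False):
out = LGConv()(emb.weight, edge_index_sym)

# Manual path: disable normalization and pass the weights explicitly:
_, w = gcn_norm(edge_index_sym, None, num_nodes, add_self_loops=False)
out_manual = LGConv(normalize=False)(emb.weight, edge_index_sym, w)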