pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_to_dense_batch.py +2 -2
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
torch_geometric/utils/sparse.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import typing
|
|
2
1
|
import warnings
|
|
3
2
|
from typing import Any, List, Optional, Tuple, Union
|
|
4
3
|
|
|
@@ -71,8 +70,9 @@ def dense_to_sparse(
|
|
|
71
70
|
f"three-dimensional (got {adj.dim()} dimensions)")
|
|
72
71
|
|
|
73
72
|
if mask is not None and adj.dim() == 2:
|
|
74
|
-
warnings.warn(
|
|
75
|
-
|
|
73
|
+
warnings.warn(
|
|
74
|
+
"Mask should not be provided in case the dense "
|
|
75
|
+
"adjacency matrix is two-dimensional", stacklevel=2)
|
|
76
76
|
mask = None
|
|
77
77
|
|
|
78
78
|
if mask is not None and mask.dim() != 2:
|
|
@@ -124,8 +124,7 @@ def is_torch_sparse_tensor(src: Any) -> bool:
|
|
|
124
124
|
return True
|
|
125
125
|
if src.layout == torch.sparse_csr:
|
|
126
126
|
return True
|
|
127
|
-
if
|
|
128
|
-
and src.layout == torch.sparse_csc):
|
|
127
|
+
if src.layout == torch.sparse_csc:
|
|
129
128
|
return True
|
|
130
129
|
return False
|
|
131
130
|
|
|
@@ -320,12 +319,6 @@ def to_torch_csc_tensor(
|
|
|
320
319
|
size=(4, 4), nnz=6, layout=torch.sparse_csc)
|
|
321
320
|
|
|
322
321
|
"""
|
|
323
|
-
if not torch_geometric.typing.WITH_PT112:
|
|
324
|
-
if typing.TYPE_CHECKING:
|
|
325
|
-
raise NotImplementedError
|
|
326
|
-
return torch_geometric.typing.MockTorchCSCTensor(
|
|
327
|
-
edge_index, edge_attr, size)
|
|
328
|
-
|
|
329
322
|
if size is None:
|
|
330
323
|
size = int(edge_index.max()) + 1
|
|
331
324
|
|
|
@@ -392,7 +385,7 @@ def to_torch_sparse_tensor(
|
|
|
392
385
|
return to_torch_coo_tensor(edge_index, edge_attr, size, is_coalesced)
|
|
393
386
|
if layout == torch.sparse_csr:
|
|
394
387
|
return to_torch_csr_tensor(edge_index, edge_attr, size, is_coalesced)
|
|
395
|
-
if
|
|
388
|
+
if layout == torch.sparse_csc:
|
|
396
389
|
return to_torch_csc_tensor(edge_index, edge_attr, size, is_coalesced)
|
|
397
390
|
|
|
398
391
|
raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')")
|
|
@@ -431,7 +424,7 @@ def to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]:
|
|
|
431
424
|
col = adj.col_indices().detach()
|
|
432
425
|
return torch.stack([row, col], dim=0).long(), adj.values()
|
|
433
426
|
|
|
434
|
-
if
|
|
427
|
+
if adj.layout == torch.sparse_csc:
|
|
435
428
|
col = ptr2index(adj.ccol_indices().detach())
|
|
436
429
|
row = adj.row_indices().detach()
|
|
437
430
|
return torch.stack([row, col], dim=0).long(), adj.values()
|
|
@@ -480,7 +473,7 @@ def set_sparse_value(adj: Tensor, value: Tensor) -> Tensor:
|
|
|
480
473
|
device=value.device,
|
|
481
474
|
)
|
|
482
475
|
|
|
483
|
-
if
|
|
476
|
+
if adj.layout == torch.sparse_csc:
|
|
484
477
|
return torch.sparse_csc_tensor(
|
|
485
478
|
ccol_indices=adj.ccol_indices(),
|
|
486
479
|
row_indices=adj.row_indices(),
|
|
torch_geometric/visualization/graph.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from math import sqrt
|
|
2
|
-
from typing import Any, List, Optional
|
|
2
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
from torch import Tensor
|
|
@@ -132,7 +132,7 @@ def _visualize_graph_via_networkx(
|
|
|
132
132
|
xy=pos[src],
|
|
133
133
|
xytext=pos[dst],
|
|
134
134
|
arrowprops=dict(
|
|
135
|
-
arrowstyle="
|
|
135
|
+
arrowstyle="<-",
|
|
136
136
|
alpha=data['alpha'],
|
|
137
137
|
shrinkA=sqrt(node_size) / 2.0,
|
|
138
138
|
shrinkB=sqrt(node_size) / 2.0,
|
|
@@ -140,9 +140,8 @@ def _visualize_graph_via_networkx(
|
|
|
140
140
|
),
|
|
141
141
|
)
|
|
142
142
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
nodes.set_edgecolor('black')
|
|
143
|
+
nx.draw_networkx_nodes(g, pos, node_size=node_size, node_color='white',
|
|
144
|
+
margins=0.1, edgecolors='black')
|
|
146
145
|
nx.draw_networkx_labels(g, pos, font_size=10)
|
|
147
146
|
|
|
148
147
|
if path is not None:
|
|
@@ -151,3 +150,249 @@ def _visualize_graph_via_networkx(
|
|
|
151
150
|
plt.show()
|
|
152
151
|
|
|
153
152
|
plt.close()
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def visualize_hetero_graph(
    edge_index_dict: Dict[Tuple[str, str, str], Tensor],
    edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
    path: Optional[str] = None,
    backend: Optional[str] = None,
    node_labels_dict: Optional[Dict[str, List[str]]] = None,
    node_weight_dict: Optional[Dict[str, Tensor]] = None,
    node_size_range: Tuple[float, float] = (50, 500),
    node_opacity_range: Tuple[float, float] = (1.0, 1.0),
    edge_width_range: Tuple[float, float] = (0.1, 2.0),
    edge_opacity_range: Tuple[float, float] = (1.0, 1.0),
) -> Any:
    r"""Visualizes a heterogeneous graph using :obj:`networkx`.

    Edges whose weight is not strictly positive are removed before drawing.
    Nodes that no longer appear in any remaining edge are removed as well.
    Node labels and node weights are filtered the same way, and the
    surviving entries are kept in ascending node-id order.

    Args:
        edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): Edge
            indices per edge type :obj:`(src_type, relation, dst_type)`.
            Row :obj:`0` holds source ids and row :obj:`1` holds
            destination ids.
        edge_weight_dict (Dict[Tuple[str, str, str], torch.Tensor]): One
            weight per edge (i.e., per column of the edge index), for every
            edge type in :obj:`edge_index_dict`.
        path (str, optional): If set, the figure is saved to this path
            instead of being shown. (default: :obj:`None`)
        backend (str, optional): Only :obj:`None` and :obj:`"networkx"` are
            supported. (default: :obj:`None`)
        node_labels_dict (Dict[str, List[str]], optional): Labels per node
            type, indexed by node id. (default: :obj:`None`)
        node_weight_dict (Dict[str, torch.Tensor], optional): Weights per
            node type, indexed by node id. (default: :obj:`None`)
        node_size_range (Tuple[float, float]): :obj:`(min, max)` node size.
            Node weights are mapped linearly onto it.
        node_opacity_range (Tuple[float, float]): :obj:`(min, max)` node
            opacity. Node weights are mapped linearly onto it.
        edge_width_range (Tuple[float, float]): :obj:`(min, max)` edge
            width. Edge weights are mapped linearly onto it.
        edge_opacity_range (Tuple[float, float]): :obj:`(min, max)` edge
            opacity. Edge weights are mapped linearly onto it.

    Raises:
        ValueError: If :obj:`backend` is neither :obj:`None` nor
            :obj:`"networkx"`.
    """
    if backend is not None and backend != "networkx":
        raise ValueError("Only 'networkx' backend is supported")

    filtered_edge_index_dict, filtered_edge_weight_dict = \
        _drop_nonpositive_edges(edge_index_dict, edge_weight_dict)
    remaining_nodes = _nodes_in_edges(filtered_edge_index_dict)

    # Keep only the weights of nodes that are still drawn.
    if node_weight_dict is not None:
        filtered_node_weight_dict: Dict[str, Tensor] = {}
        for node_type, weights in node_weight_dict.items():
            if node_type not in remaining_nodes:
                continue
            # A boolean mask preserves ascending node-id order. Build it on
            # the weights' device so that indexing stays on one device.
            mask = torch.zeros(len(weights), dtype=torch.bool,
                               device=weights.device)
            mask[list(remaining_nodes[node_type])] = True
            filtered_node_weight_dict[node_type] = weights[mask]
        node_weight_dict = filtered_node_weight_dict

    # Keep only the labels of nodes that are still drawn (same order).
    if node_labels_dict is not None:
        node_labels_dict = {
            node_type: [
                label for i, label in enumerate(labels)
                if i in remaining_nodes[node_type]
            ]
            for node_type, labels in node_labels_dict.items()
            if node_type in remaining_nodes
        }

    return _visualize_hetero_graph_via_networkx(
        filtered_edge_index_dict,
        filtered_edge_weight_dict,
        path,
        node_labels_dict,
        node_weight_dict,
        node_size_range,
        node_opacity_range,
        edge_width_range,
        edge_opacity_range,
    )


_HeteroEdgeTensorDict = Dict[Tuple[str, str, str], Tensor]


def _drop_nonpositive_edges(
    edge_index_dict: _HeteroEdgeTensorDict,
    edge_weight_dict: _HeteroEdgeTensorDict,
) -> Tuple[_HeteroEdgeTensorDict, _HeteroEdgeTensorDict]:
    r"""Keeps edges with positive weight and drops edge types left empty."""
    out_index: _HeteroEdgeTensorDict = {}
    out_weight: _HeteroEdgeTensorDict = {}
    for edge_type, edge_index in edge_index_dict.items():
        weight = edge_weight_dict[edge_type]
        mask = weight > 0
        if bool(mask.any()):
            out_index[edge_type] = edge_index[:, mask]
            out_weight[edge_type] = weight[mask]
    return out_index, out_weight


def _nodes_in_edges(
    edge_index_dict: _HeteroEdgeTensorDict, ) -> Dict[str, Set[int]]:
    r"""Collects, per node type, the node ids used by at least one edge."""
    nodes: Dict[str, Set[int]] = {}
    for (src_type, _, dst_type), edge_index in edge_index_dict.items():
        nodes.setdefault(src_type, set()).update(edge_index[0].tolist())
        nodes.setdefault(dst_type, set()).update(edge_index[1].tolist())
    return nodes
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _visualize_hetero_graph_via_networkx(
    edge_index_dict: Dict[Tuple[str, str, str], Tensor],
    edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
    path: Optional[str] = None,
    node_labels_dict: Optional[Dict[str, List[str]]] = None,
    node_weight_dict: Optional[Dict[str, Tensor]] = None,
    node_size_range: Tuple[float, float] = (50, 500),
    node_opacity_range: Tuple[float, float] = (1.0, 1.0),
    edge_width_range: Tuple[float, float] = (0.1, 2.0),
    edge_opacity_range: Tuple[float, float] = (1.0, 1.0),
) -> Any:
    r"""Draws a heterogeneous graph with :obj:`networkx` and
    :obj:`matplotlib`.

    Only nodes that occur in :obj:`edge_index_dict` are drawn. For every
    node type, the occurring node ids are sorted and renumbered
    consecutively. Each type then occupies a contiguous block of global
    node ids. Entries of :obj:`node_labels_dict` and
    :obj:`node_weight_dict` are consumed in that sorted, compacted order.

    Sizes, widths and opacities are linearly interpolated from the
    :obj:`(min, max)` ranges: a weight of :obj:`0` maps to :obj:`min` and
    :obj:`1` maps to :obj:`max`. Nodes without weights use :obj:`max`.

    The graph is a :class:`networkx.DiGraph`, so parallel edges between the
    same pair of drawn nodes collapse into one. The attributes of the last
    such edge added are kept.

    If :obj:`path` is given, the figure is saved there. Otherwise it is
    shown.

    Raises:
        ValueError: If the number of node weights or node labels given for a
            node type differs from the number of drawn nodes of that type.
    """
    import matplotlib.pyplot as plt
    import networkx as nx

    def lerp(bounds: Tuple[float, float], w: float) -> float:
        # Linear map: w == 0 -> bounds[0], w == 1 -> bounds[1].
        return bounds[0] + w * (bounds[1] - bounds[0])

    # Node ids (per type) that occur in at least one edge.
    remaining_nodes: Dict[str, Set[int]] = {}
    for (src_type, _, dst_type), edge_index in edge_index_dict.items():
        remaining_nodes.setdefault(src_type, set()).update(
            edge_index[0].tolist())
        remaining_nodes.setdefault(dst_type, set()).update(
            edge_index[1].tolist())

    # Iterating a set of strings depends on the per-process hash seed.
    # Sorting keeps node numbering, colors and legend order stable.
    node_types = sorted(remaining_nodes)

    # Per type, map each original node id to its rank among the sorted ids.
    # Built once here rather than once per edge type.
    id_maps: Dict[str, Dict[int, int]] = {
        node_type: {old: new
                    for new, old in enumerate(sorted(ids))}
        for node_type, ids in remaining_nodes.items()
    }

    g = nx.DiGraph()
    node_offsets: Dict[str, int] = {}
    current_offset = 0
    for node_type in node_types:
        num_nodes = len(remaining_nodes[node_type])
        node_offsets[node_type] = current_offset

        weights = None
        if node_weight_dict is not None and node_type in node_weight_dict:
            weights = node_weight_dict[node_type]
            if len(weights) != num_nodes:
                raise ValueError(f"Number of weights for node type "
                                 f"{node_type} ({len(weights)}) does not "
                                 f"match number of nodes ({num_nodes})")

        labels = None
        if node_labels_dict is not None and node_type in node_labels_dict:
            labels = node_labels_dict[node_type]
            if len(labels) != num_nodes:
                raise ValueError(f"Number of labels for node type "
                                 f"{node_type} ({len(labels)}) does not "
                                 f"match number of nodes ({num_nodes})")

        for i in range(num_nodes):
            size = node_size_range[1]
            opacity = node_opacity_range[1]
            if weights is not None:
                w = weights[i].item()
                size = lerp(node_size_range, w)
                opacity = lerp(node_opacity_range, w)
            g.add_node(current_offset + i,
                       label=labels[i] if labels is not None else "",
                       type=node_type, size=size, alpha=opacity)

        current_offset += num_nodes

    # Add edges, translating ids into the global numbering.
    for edge_type, edge_index in edge_index_dict.items():
        src_type, _, dst_type = edge_type
        edge_weight = edge_weight_dict[edge_type]
        src_map, dst_map = id_maps[src_type], id_maps[dst_type]
        src_offset = node_offsets[src_type]
        dst_offset = node_offsets[dst_type]
        for (src, dst), w in zip(edge_index.t().tolist(),
                                 edge_weight.tolist()):
            g.add_edge(src_offset + src_map[src], dst_offset + dst_map[dst],
                       width=lerp(edge_width_range, w),
                       alpha=lerp(edge_opacity_range, w))

    ax = plt.gca()
    # With no edges there are no nodes, so there is nothing to lay out.
    pos = nx.arf_layout(g) if g.number_of_nodes() > 0 else {}

    # Draw edges as curved arrows, shrunk at both ends by sqrt(size) / 2.
    for src, dst, data in g.edges(data=True):
        ax.annotate(
            '',
            xy=pos[src],
            xytext=pos[dst],
            arrowprops=dict(
                arrowstyle="<-",
                alpha=data['alpha'],
                linewidth=data['width'],
                shrinkA=sqrt(g.nodes[src]['size']) / 2.0,
                shrinkB=sqrt(g.nodes[dst]['size']) / 2.0,
                connectionstyle="arc3,rad=0.1",
            ),
        )

    # Color nodes by type with the tab10 colormap (cycling after 10 types).
    tab10_cmap = plt.cm.tab10  # type: ignore[attr-defined]
    node_type_colors: Dict[str, Any] = {}
    node_colors: List[Any] = []
    node_sizes: List[float] = []
    node_alphas: List[float] = []
    for _, attrs in g.nodes(data=True):
        node_type = attrs['type']
        if node_type not in node_type_colors:
            node_type_colors[node_type] = tab10_cmap(
                len(node_type_colors) % 10)
        node_colors.append(node_type_colors[node_type])
        node_sizes.append(attrs['size'])
        node_alphas.append(attrs['alpha'])

    nx.draw_networkx_nodes(g, pos, node_size=node_sizes,
                           node_color=node_colors, margins=0.1,
                           alpha=node_alphas)

    labels_by_node = nx.get_node_attributes(g, 'label')
    nx.draw_networkx_labels(g, pos, labels_by_node, font_size=10)

    # One legend entry per node type, in the order colors were assigned.
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', label=node_type,
                   markerfacecolor=color, markersize=10)
        for node_type, color in node_type_colors.items()
    ]
    ax.legend(handles=legend_elements, loc='upper right',
              bbox_to_anchor=(0.9, 1))

    if path is not None:
        plt.savefig(path, bbox_inches='tight')
    else:
        plt.show()

    plt.close()
|
torch_geometric/warnings.py
CHANGED
|
@@ -4,11 +4,11 @@ from typing import Literal
|
|
|
4
4
|
import torch_geometric
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
def warn(message: str, stacklevel: int = 5) -> None:
    r"""Emits :obj:`message` through :func:`warnings.warn`.

    Nothing is emitted while :func:`torch_geometric.is_compiling` reports
    an active compilation.

    Args:
        message (str): The warning text.
        stacklevel (int, optional): Forwarded unchanged to
            :func:`warnings.warn`. (default: :obj:`5`)
    """
    if not torch_geometric.is_compiling():
        warnings.warn(message, stacklevel=stacklevel)
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def filterwarnings(
|
|
@@ -19,3 +19,12 @@ def filterwarnings(
|
|
|
19
19
|
return
|
|
20
20
|
|
|
21
21
|
warnings.filterwarnings(action, message)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class WarningCache(set):
    r"""A :class:`set` of warning messages that were already emitted.

    :meth:`warn` emits each distinct message at most once per cache
    instance.
    """
    def warn(self, message: str, stacklevel: int = 5) -> None:
        r"""Emits :obj:`message` unless this cache has already emitted it.

        A message is only recorded once it is actually emitted. While
        :func:`torch_geometric.is_compiling` is active, the module-level
        :func:`warn` suppresses the warning. Recording the message at that
        point would silence it permanently.

        Args:
            message (str): The warning text.
            stacklevel (int, optional): Forwarded to :func:`warn`.
                (default: :obj:`5`)
        """
        if message in self or torch_geometric.is_compiling():
            return
        self.add(message)
        warn(message, stacklevel=stacklevel)
|
|
torch_geometric/nn/nlp/sentence_transformer.py
DELETED
|
@@ -1,101 +0,0 @@
|
|
|
1
|
-
from enum import Enum
|
|
2
|
-
from typing import List, Optional, Union
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
import torch.nn.functional as F
|
|
6
|
-
from torch import Tensor
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class PoolingStrategy(Enum):
    r"""How token embeddings are reduced to one embedding per sequence.

    * :obj:`MEAN`: mask-aware average over all tokens.
    * :obj:`LAST`: embedding of the last non-padded token.
    * :obj:`CLS`: embedding of the first token.
    """
    MEAN = 'mean'
    LAST = 'last'
    CLS = 'cls'
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class SentenceTransformer(torch.nn.Module):
    r"""Embeds text with a Hugging Face model followed by a pooling step.

    Args:
        model_name (str): Model identifier passed to both
            :meth:`transformers.AutoTokenizer.from_pretrained` and
            :meth:`transformers.AutoModel.from_pretrained`.
        pooling_strategy (PoolingStrategy or str, optional): How token
            embeddings are reduced per input: :obj:`"mean"`, :obj:`"last"`
            or :obj:`"cls"`. (default: :obj:`"mean"`)
    """
    def __init__(
        self,
        model_name: str,
        pooling_strategy: Union[PoolingStrategy, str] = 'mean',
    ) -> None:
        super().__init__()

        self.model_name = model_name
        self.pooling_strategy = PoolingStrategy(pooling_strategy)

        # Imported lazily so that `transformers` is only needed when used.
        from transformers import AutoModel, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # `encode` pads batches, which requires a pad token. Fall back to
        # the EOS token for tokenizers that define no pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        r"""Returns one L2-normalized embedding per input sequence."""
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)

        emb = out[0]  # First element contains all token embeddings.
        if self.pooling_strategy == PoolingStrategy.MEAN:
            emb = mean_pooling(emb, attention_mask)
        elif self.pooling_strategy == PoolingStrategy.LAST:
            emb = last_pooling(emb, attention_mask)
        else:
            assert self.pooling_strategy == PoolingStrategy.CLS
            emb = emb[:, 0, :]

        emb = F.normalize(emb, p=2, dim=1)
        return emb

    @property
    def device(self) -> torch.device:
        r"""The device of the wrapped model's first parameter."""
        return next(iter(self.model.parameters())).device

    @torch.no_grad()
    def encode(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tensor:
        r"""Embeds a list of strings in mini-batches.

        Args:
            text (List[str]): The strings to embed. An empty list yields an
                empty tensor with the model's embedding width and dtype.
            batch_size (int, optional): Number of strings tokenized and
                embedded per forward pass. If :obj:`None`, all strings are
                processed at once. (default: :obj:`None`)
            output_device (torch.device or str, optional): Device that each
                batch of embeddings is moved to. (default: :obj:`None`)

        Raises:
            ValueError: If :obj:`batch_size` is not positive.
        """
        # Without this check, 0 fails inside `range()` and negative values
        # fail later with an IndexError on an empty list.
        if batch_size is not None and batch_size <= 0:
            raise ValueError(f"'batch_size' must be a positive integer "
                             f"(got {batch_size})")

        is_empty = len(text) == 0
        # Embed a dummy input so an empty request still returns a tensor
        # with the correct embedding width and dtype (sliced to 0 rows).
        text = ['dummy'] if is_empty else text

        batch_size = len(text) if batch_size is None else batch_size

        embs: List[Tensor] = []
        for start in range(0, len(text), batch_size):
            token = self.tokenizer(
                text[start:start + batch_size],
                padding=True,
                truncation=True,
                return_tensors='pt',
            )

            emb = self(
                input_ids=token.input_ids.to(self.device),
                attention_mask=token.attention_mask.to(self.device),
            ).to(output_device)

            embs.append(emb)

        out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
        out = out[:0] if is_empty else out
        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(model_name={self.model_name})'
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
def mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    r"""Averages token embeddings over the unmasked positions.

    Args:
        emb (torch.Tensor): Token embeddings. Positions are reduced over
            dimension :obj:`1`.
        attention_mask (torch.Tensor): Per-position weights. They are
            broadcast over the last dimension of :obj:`emb`, and positions
            with weight :obj:`0` do not contribute.
    """
    weights = attention_mask.to(emb.dtype).unsqueeze(-1).expand(emb.size())
    token_sum = (emb * weights).sum(dim=1)
    # The clamp makes an all-masked row divide by a tiny number, not by 0.
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return token_sum / token_count
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    r"""Selects each sequence's last non-padded token embedding.

    If the final position is unmasked in every row (left padding, as used
    for decoder LLMs), the last column is returned directly. Otherwise the
    position is derived from the per-row count of unmasked tokens, which
    assumes right padding.
    """
    num_rows = attention_mask.size(0)
    if attention_mask[:, -1].sum() == num_rows:
        return emb[:, -1]

    last_positions = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(emb.size(0), device=emb.device)
    return emb[row_ids, last_positions]
|