pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/utils/sparse.py
CHANGED

@@ -1,4 +1,3 @@
-import typing
 import warnings
 from typing import Any, List, Optional, Tuple, Union
 
@@ -71,8 +70,9 @@ def dense_to_sparse(
             f"three-dimensional (got {adj.dim()} dimensions)")
 
     if mask is not None and adj.dim() == 2:
-        warnings.warn(
-
+        warnings.warn(
+            "Mask should not be provided in case the dense "
+            "adjacency matrix is two-dimensional", stacklevel=2)
         mask = None
 
     if mask is not None and mask.dim() != 2:
@@ -124,8 +124,7 @@ def is_torch_sparse_tensor(src: Any) -> bool:
             return True
         if src.layout == torch.sparse_csr:
            return True
-        if
-                and src.layout == torch.sparse_csc):
+        if src.layout == torch.sparse_csc:
            return True
    return False
 
@@ -198,15 +197,23 @@ def to_torch_coo_tensor(
         # edge_attr = edge_attr.expand(edge_index.size(1))
         edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)
 
-    adj = torch.sparse_coo_tensor(
+    if not torch_geometric.typing.WITH_PT21:
+        adj = torch.sparse_coo_tensor(
+            indices=edge_index,
+            values=edge_attr,
+            size=tuple(size) + edge_attr.size()[1:],
+            device=edge_index.device,
+        )
+        adj = adj._coalesced_(True)
+        return adj
+
+    return torch.sparse_coo_tensor(
         indices=edge_index,
         values=edge_attr,
         size=tuple(size) + edge_attr.size()[1:],
         device=edge_index.device,
+        is_coalesced=True,
     )
-    adj = adj._coalesced_(True)
-
-    return adj
 
 
 def to_torch_csr_tensor(
@@ -312,12 +319,6 @@ def to_torch_csc_tensor(
               size=(4, 4), nnz=6, layout=torch.sparse_csc)
 
     """
-    if not torch_geometric.typing.WITH_PT112:
-        if typing.TYPE_CHECKING:
-            raise NotImplementedError
-        return torch_geometric.typing.MockTorchCSCTensor(
-            edge_index, edge_attr, size)
-
     if size is None:
         size = int(edge_index.max()) + 1
 
@@ -384,7 +385,7 @@ def to_torch_sparse_tensor(
         return to_torch_coo_tensor(edge_index, edge_attr, size, is_coalesced)
     if layout == torch.sparse_csr:
         return to_torch_csr_tensor(edge_index, edge_attr, size, is_coalesced)
-    if
+    if layout == torch.sparse_csc:
         return to_torch_csc_tensor(edge_index, edge_attr, size, is_coalesced)
 
     raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')")
@@ -423,7 +424,7 @@ def to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]:
         col = adj.col_indices().detach()
         return torch.stack([row, col], dim=0).long(), adj.values()
 
-    if
+    if adj.layout == torch.sparse_csc:
         col = ptr2index(adj.ccol_indices().detach())
         row = adj.row_indices().detach()
         return torch.stack([row, col], dim=0).long(), adj.values()
@@ -472,7 +473,7 @@ def set_sparse_value(adj: Tensor, value: Tensor) -> Tensor:
             device=value.device,
         )
 
-    if
+    if adj.layout == torch.sparse_csc:
         return torch.sparse_csc_tensor(
             ccol_indices=adj.ccol_indices(),
             row_indices=adj.row_indices(),
@@ -531,18 +532,25 @@ def cat_coo(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:
         if not tensor.is_coalesced():
             is_coalesced = False
 
-    out = torch.sparse_coo_tensor(
+    if not torch_geometric.typing.WITH_PT21:
+        out = torch.sparse_coo_tensor(
+            indices=torch.cat(indices, dim=-1),
+            values=torch.cat(values),
+            size=(num_rows, num_cols) + values[-1].size()[1:],
+            device=tensor.device,
+        )
+        if is_coalesced:
+            out = out._coalesced_(True)
+        return out
+
+    return torch.sparse_coo_tensor(
         indices=torch.cat(indices, dim=-1),
         values=torch.cat(values),
         size=(num_rows, num_cols) + values[-1].size()[1:],
         device=tensor.device,
+        is_coalesced=True if is_coalesced else None,
     )
 
-    if is_coalesced:
-        out = out._coalesced_(True)
-
-    return out
-
 
 def cat_csr(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:
     assert dim in {0, 1, (0, 1)}
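The recurring change in `torch_geometric/utils/sparse.py` is a version gate: on PyTorch 2.1 and later (`torch_geometric.typing.WITH_PT21`), `torch.sparse_coo_tensor` accepts an `is_coalesced` flag at construction time, while older releases fall back to marking the tensor as coalesced after the fact. A minimal sketch of the same pattern outside PyG (the tensors and the inline version check are illustrative, not part of the diff):

```python
import torch

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # already sorted, no duplicates
edge_attr = torch.ones(edge_index.size(1))
size = (3, 3)

if torch.__version__ >= "2.1":  # stand-in for torch_geometric.typing.WITH_PT21
    # PyTorch 2.1+ can be told at construction time that the indices are coalesced.
    adj = torch.sparse_coo_tensor(edge_index, edge_attr, size, is_coalesced=True)
else:
    # Older releases: build first, then flag the tensor as coalesced in place.
    adj = torch.sparse_coo_tensor(edge_index, edge_attr, size)
    adj = adj._coalesced_(True)

print(adj.is_coalesced())  # True on both paths
```

Either branch avoids a redundant `coalesce()` pass later on; the new keyword simply lets PyTorch 2.1+ skip the private `_coalesced_` call.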
torch_geometric/visualization/graph.py
CHANGED

@@ -1,5 +1,5 @@
 from math import sqrt
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 import torch
 from torch import Tensor
@@ -132,7 +132,7 @@ def _visualize_graph_via_networkx(
             xy=pos[src],
             xytext=pos[dst],
             arrowprops=dict(
-                arrowstyle="
+                arrowstyle="<-",
                 alpha=data['alpha'],
                 shrinkA=sqrt(node_size) / 2.0,
                 shrinkB=sqrt(node_size) / 2.0,
@@ -140,9 +140,8 @@ def _visualize_graph_via_networkx(
             ),
         )
 
-
-
-    nodes.set_edgecolor('black')
+    nx.draw_networkx_nodes(g, pos, node_size=node_size, node_color='white',
+                           margins=0.1, edgecolors='black')
     nx.draw_networkx_labels(g, pos, font_size=10)
 
     if path is not None:
@@ -151,3 +150,249 @@ def _visualize_graph_via_networkx(
         plt.show()
 
     plt.close()
+
+
+def visualize_hetero_graph(
+    edge_index_dict: Dict[Tuple[str, str, str], Tensor],
+    edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
+    path: Optional[str] = None,
+    backend: Optional[str] = None,
+    node_labels_dict: Optional[Dict[str, List[str]]] = None,
+    node_weight_dict: Optional[Dict[str, Tensor]] = None,
+    node_size_range: Tuple[float, float] = (50, 500),
+    node_opacity_range: Tuple[float, float] = (1.0, 1.0),
+    edge_width_range: Tuple[float, float] = (0.1, 2.0),
+    edge_opacity_range: Tuple[float, float] = (1.0, 1.0),
+) -> Any:
+    """Visualizes a heterogeneous graph using networkx."""
+    if backend is not None and backend != "networkx":
+        raise ValueError("Only 'networkx' backend is supported")
+
+    # Filter out edges with 0 weight
+    filtered_edge_index_dict = {}
+    filtered_edge_weight_dict = {}
+    for edge_type in edge_index_dict.keys():
+        mask = edge_weight_dict[edge_type] > 0
+        if mask.sum() > 0:
+            filtered_edge_index_dict[edge_type] = edge_index_dict[
+                edge_type][:, mask]
+            filtered_edge_weight_dict[edge_type] = edge_weight_dict[edge_type][
+                mask]
+
+    # Get all unique nodes that are still in the filtered edges
+    remaining_nodes: Dict[str, Set[int]] = {}
+    for edge_type, edge_index in filtered_edge_index_dict.items():
+        src_type, _, dst_type = edge_type
+        if src_type not in remaining_nodes:
+            remaining_nodes[src_type] = set()
+        if dst_type not in remaining_nodes:
+            remaining_nodes[dst_type] = set()
+        remaining_nodes[src_type].update(edge_index[0].tolist())
+        remaining_nodes[dst_type].update(edge_index[1].tolist())
+
+    # Filter node weights to only include remaining nodes
+    if node_weight_dict is not None:
+        filtered_node_weight_dict = {}
+        for node_type, weights in node_weight_dict.items():
+            if node_type in remaining_nodes:
+                mask = torch.zeros(len(weights), dtype=torch.bool)
+                mask[list(remaining_nodes[node_type])] = True
+                filtered_node_weight_dict[node_type] = weights[mask]
+        node_weight_dict = filtered_node_weight_dict
+
+    # Filter node labels to only include remaining nodes
+    if node_labels_dict is not None:
+        filtered_node_labels_dict = {}
+        for node_type, labels in node_labels_dict.items():
+            if node_type in remaining_nodes:
+                filtered_node_labels_dict[node_type] = [
+                    label for i, label in enumerate(labels)
+                    if i in remaining_nodes[node_type]
+                ]
+        node_labels_dict = filtered_node_labels_dict
+
+    return _visualize_hetero_graph_via_networkx(
+        filtered_edge_index_dict,
+        filtered_edge_weight_dict,
+        path,
+        node_labels_dict,
+        node_weight_dict,
+        node_size_range,
+        node_opacity_range,
+        edge_width_range,
+        edge_opacity_range,
+    )
+
+
+def _visualize_hetero_graph_via_networkx(
+    edge_index_dict: Dict[Tuple[str, str, str], Tensor],
+    edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
+    path: Optional[str] = None,
+    node_labels_dict: Optional[Dict[str, List[str]]] = None,
+    node_weight_dict: Optional[Dict[str, Tensor]] = None,
+    node_size_range: Tuple[float, float] = (50, 500),
+    node_opacity_range: Tuple[float, float] = (1.0, 1.0),
+    edge_width_range: Tuple[float, float] = (0.1, 2.0),
+    edge_opacity_range: Tuple[float, float] = (1.0, 1.0),
+) -> Any:
+    import matplotlib.pyplot as plt
+    import networkx as nx
+
+    g = nx.DiGraph()
+    node_offsets: Dict[str, int] = {}
+    current_offset = 0
+
+    # First, collect all unique node types and their counts
+    node_types = set()
+    node_counts: Dict[str, int] = {}
+    remaining_nodes: Dict[str, Set[int]] = {
+    }  # Track which nodes are actually present in edges
+
+    # Get all unique nodes that are in the edges
+    for edge_type in edge_index_dict.keys():
+        src_type, _, dst_type = edge_type
+        node_types.add(src_type)
+        node_types.add(dst_type)
+
+        if src_type not in remaining_nodes:
+            remaining_nodes[src_type] = set()
+        if dst_type not in remaining_nodes:
+            remaining_nodes[dst_type] = set()
+
+        remaining_nodes[src_type].update(
+            edge_index_dict[edge_type][0].tolist())
+        remaining_nodes[dst_type].update(
+            edge_index_dict[edge_type][1].tolist())
+
+    # Set node counts based on remaining nodes
+    for node_type in node_types:
+        node_counts[node_type] = len(remaining_nodes[node_type])
+
+    # Add nodes for each node type
+    for node_type in node_types:
+        num_nodes = node_counts[node_type]
+        node_offsets[node_type] = current_offset
+
+        # Get node weights if provided
+        weights = None
+        if node_weight_dict is not None and node_type in node_weight_dict:
+            weights = node_weight_dict[node_type]
+            if len(weights) != num_nodes:
+                raise ValueError(f"Number of weights for node type "
+                                 f"{node_type} ({len(weights)}) does not "
+                                 f"match number of nodes ({num_nodes})")
+
+        for i in range(num_nodes):
+            node_id = current_offset + i
+            label = (node_labels_dict[node_type][i]
+                     if node_labels_dict is not None
+                     and node_type in node_labels_dict else "")
+
+            # Calculate node size and opacity if weights provided
+            size = node_size_range[1]
+            opacity = node_opacity_range[1]
+            if weights is not None:
+                w = weights[i].item()
+                size = node_size_range[0] + w * \
+                    (node_size_range[1] - node_size_range[0])
+                opacity = node_opacity_range[0] + w * \
+                    (node_opacity_range[1] - node_opacity_range[0])
+
+            g.add_node(node_id, label=label, type=node_type, size=size,
+                       alpha=opacity)
+
+        current_offset += num_nodes
+
+    # Add edges with remapped node indices
+    for edge_type, edge_index in edge_index_dict.items():
+        src_type, _, dst_type = edge_type
+        edge_weight = edge_weight_dict[edge_type]
+        src_offset = node_offsets[src_type]
+        dst_offset = node_offsets[dst_type]
+
+        # Create mappings for source and target nodes
+        src_mapping = {
+            old_idx: new_idx
+            for new_idx, old_idx in enumerate(sorted(
+                remaining_nodes[src_type]))
+        }
+        dst_mapping = {
+            old_idx: new_idx
+            for new_idx, old_idx in enumerate(sorted(
+                remaining_nodes[dst_type]))
+        }
+
+        for (src, dst), w in zip(edge_index.t().tolist(),
+                                 edge_weight.tolist()):
+            # Remap node indices
+            new_src = src_mapping[src] + src_offset
+            new_dst = dst_mapping[dst] + dst_offset
+
+            # Calculate edge width and opacity based on weight
+            width = edge_width_range[0] + w * \
+                (edge_width_range[1] - edge_width_range[0])
+            opacity = edge_opacity_range[0] + w * \
+                (edge_opacity_range[1] - edge_opacity_range[0])
+            g.add_edge(new_src, new_dst, width=width, alpha=opacity)
+
+    # Draw the graph
+    ax = plt.gca()
+    pos = nx.arf_layout(g)
+
+    # Draw edges with arrows
+    for src, dst, data in g.edges(data=True):
+        ax.annotate(
+            '',
+            xy=pos[src],
+            xytext=pos[dst],
+            arrowprops=dict(
+                arrowstyle="<-",
+                alpha=data['alpha'],
+                linewidth=data['width'],
+                shrinkA=sqrt(g.nodes[src]['size']) / 2.0,
+                shrinkB=sqrt(g.nodes[dst]['size']) / 2.0,
+                connectionstyle="arc3,rad=0.1",
+            ),
+        )
+
+    # Draw nodes colored by type
+    node_colors = []
+    node_sizes = []
+    node_alphas = []
+
+    # Use matplotlib tab20 colormap for consistent coloring
+    tab10_cmap = plt.cm.tab10  # type: ignore[attr-defined]
+    node_type_colors: Dict[str, Any] = {}  # Store color for each node type
+    for node in g.nodes():
+        node_type = g.nodes[node]['type']
+        # Assign a consistent color for each node type
+        if node_type not in node_type_colors:
+            color_idx = len(node_type_colors) % 10  # Cycle through colors
+            node_type_colors[node_type] = tab10_cmap(color_idx)
+        node_colors.append(node_type_colors[node_type])
+        node_sizes.append(g.nodes[node]['size'])
+        node_alphas.append(g.nodes[node]['alpha'])
+
+    nx.draw_networkx_nodes(g, pos, node_size=node_sizes,
+                           node_color=node_colors, margins=0.1,
+                           alpha=node_alphas)
+
+    # Draw labels
+    labels = nx.get_node_attributes(g, 'label')
+    nx.draw_networkx_labels(g, pos, labels, font_size=10)
+
+    # Add legend
+    legend_elements = []
+    for node_type, color in node_type_colors.items():
+        legend_elements.append(
+            plt.Line2D([0], [0], marker='o', color='w', label=node_type,
+                       markerfacecolor=color, markersize=10))
+    ax.legend(handles=legend_elements, loc='upper right',
+              bbox_to_anchor=(0.9, 1))
+
+    if path is not None:
+        plt.savefig(path, bbox_inches='tight')
+    else:
+        plt.show()
+
+    plt.close()
torch_geometric/warnings.py
CHANGED

@@ -4,11 +4,11 @@ from typing import Literal
 import torch_geometric
 
 
-def warn(message: str) -> None:
+def warn(message: str, stacklevel: int = 5) -> None:
     if torch_geometric.is_compiling():
         return
 
-    warnings.warn(message)
+    warnings.warn(message, stacklevel=stacklevel)
 
 
 def filterwarnings(
@@ -19,3 +19,12 @@ def filterwarnings(
         return
 
     warnings.filterwarnings(action, message)
+
+
+class WarningCache(set):
+    """Cache for warnings."""
+    def warn(self, message: str, stacklevel: int = 5) -> None:
+        """Trigger warning message."""
+        if message not in self:
+            self.add(message)
+            warn(message, stacklevel=stacklevel)