pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -1,4 +1,3 @@
1
- import typing
2
1
  import warnings
3
2
  from typing import Any, List, Optional, Tuple, Union
4
3
 
@@ -71,8 +70,9 @@ def dense_to_sparse(
71
70
  f"three-dimensional (got {adj.dim()} dimensions)")
72
71
 
73
72
  if mask is not None and adj.dim() == 2:
74
- warnings.warn("Mask should not be provided in case the dense "
75
- "adjacency matrix is two-dimensional")
73
+ warnings.warn(
74
+ "Mask should not be provided in case the dense "
75
+ "adjacency matrix is two-dimensional", stacklevel=2)
76
76
  mask = None
77
77
 
78
78
  if mask is not None and mask.dim() != 2:
@@ -124,8 +124,7 @@ def is_torch_sparse_tensor(src: Any) -> bool:
124
124
  return True
125
125
  if src.layout == torch.sparse_csr:
126
126
  return True
127
- if (torch_geometric.typing.WITH_PT112
128
- and src.layout == torch.sparse_csc):
127
+ if src.layout == torch.sparse_csc:
129
128
  return True
130
129
  return False
131
130
 
@@ -198,15 +197,23 @@ def to_torch_coo_tensor(
198
197
  # edge_attr = edge_attr.expand(edge_index.size(1))
199
198
  edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)
200
199
 
201
- adj = torch.sparse_coo_tensor(
200
+ if not torch_geometric.typing.WITH_PT21:
201
+ adj = torch.sparse_coo_tensor(
202
+ indices=edge_index,
203
+ values=edge_attr,
204
+ size=tuple(size) + edge_attr.size()[1:],
205
+ device=edge_index.device,
206
+ )
207
+ adj = adj._coalesced_(True)
208
+ return adj
209
+
210
+ return torch.sparse_coo_tensor(
202
211
  indices=edge_index,
203
212
  values=edge_attr,
204
213
  size=tuple(size) + edge_attr.size()[1:],
205
214
  device=edge_index.device,
215
+ is_coalesced=True,
206
216
  )
207
- adj = adj._coalesced_(True)
208
-
209
- return adj
210
217
 
211
218
 
212
219
  def to_torch_csr_tensor(
@@ -312,12 +319,6 @@ def to_torch_csc_tensor(
312
319
  size=(4, 4), nnz=6, layout=torch.sparse_csc)
313
320
 
314
321
  """
315
- if not torch_geometric.typing.WITH_PT112:
316
- if typing.TYPE_CHECKING:
317
- raise NotImplementedError
318
- return torch_geometric.typing.MockTorchCSCTensor(
319
- edge_index, edge_attr, size)
320
-
321
322
  if size is None:
322
323
  size = int(edge_index.max()) + 1
323
324
 
@@ -384,7 +385,7 @@ def to_torch_sparse_tensor(
384
385
  return to_torch_coo_tensor(edge_index, edge_attr, size, is_coalesced)
385
386
  if layout == torch.sparse_csr:
386
387
  return to_torch_csr_tensor(edge_index, edge_attr, size, is_coalesced)
387
- if torch_geometric.typing.WITH_PT112 and layout == torch.sparse_csc:
388
+ if layout == torch.sparse_csc:
388
389
  return to_torch_csc_tensor(edge_index, edge_attr, size, is_coalesced)
389
390
 
390
391
  raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')")
@@ -423,7 +424,7 @@ def to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]:
423
424
  col = adj.col_indices().detach()
424
425
  return torch.stack([row, col], dim=0).long(), adj.values()
425
426
 
426
- if torch_geometric.typing.WITH_PT112 and adj.layout == torch.sparse_csc:
427
+ if adj.layout == torch.sparse_csc:
427
428
  col = ptr2index(adj.ccol_indices().detach())
428
429
  row = adj.row_indices().detach()
429
430
  return torch.stack([row, col], dim=0).long(), adj.values()
@@ -472,7 +473,7 @@ def set_sparse_value(adj: Tensor, value: Tensor) -> Tensor:
472
473
  device=value.device,
473
474
  )
474
475
 
475
- if torch_geometric.typing.WITH_PT112 and adj.layout == torch.sparse_csc:
476
+ if adj.layout == torch.sparse_csc:
476
477
  return torch.sparse_csc_tensor(
477
478
  ccol_indices=adj.ccol_indices(),
478
479
  row_indices=adj.row_indices(),
@@ -531,18 +532,25 @@ def cat_coo(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:
531
532
  if not tensor.is_coalesced():
532
533
  is_coalesced = False
533
534
 
534
- out = torch.sparse_coo_tensor(
535
+ if not torch_geometric.typing.WITH_PT21:
536
+ out = torch.sparse_coo_tensor(
537
+ indices=torch.cat(indices, dim=-1),
538
+ values=torch.cat(values),
539
+ size=(num_rows, num_cols) + values[-1].size()[1:],
540
+ device=tensor.device,
541
+ )
542
+ if is_coalesced:
543
+ out = out._coalesced_(True)
544
+ return out
545
+
546
+ return torch.sparse_coo_tensor(
535
547
  indices=torch.cat(indices, dim=-1),
536
548
  values=torch.cat(values),
537
549
  size=(num_rows, num_cols) + values[-1].size()[1:],
538
550
  device=tensor.device,
551
+ is_coalesced=True if is_coalesced else None,
539
552
  )
540
553
 
541
- if is_coalesced:
542
- out = out._coalesced_(True)
543
-
544
- return out
545
-
546
554
 
547
555
  def cat_csr(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:
548
556
  assert dim in {0, 1, (0, 1)}
@@ -1,9 +1,10 @@
1
1
  r"""Visualization package."""
2
2
 
3
- from .graph import visualize_graph
3
+ from .graph import visualize_graph, visualize_hetero_graph
4
4
  from .influence import influence
5
5
 
6
6
  __all__ = [
7
7
  'visualize_graph',
8
+ 'visualize_hetero_graph',
8
9
  'influence',
9
10
  ]
@@ -1,5 +1,5 @@
1
1
  from math import sqrt
2
- from typing import Any, List, Optional
2
+ from typing import Any, Dict, List, Optional, Set, Tuple
3
3
 
4
4
  import torch
5
5
  from torch import Tensor
@@ -132,7 +132,7 @@ def _visualize_graph_via_networkx(
132
132
  xy=pos[src],
133
133
  xytext=pos[dst],
134
134
  arrowprops=dict(
135
- arrowstyle="->",
135
+ arrowstyle="<-",
136
136
  alpha=data['alpha'],
137
137
  shrinkA=sqrt(node_size) / 2.0,
138
138
  shrinkB=sqrt(node_size) / 2.0,
@@ -140,9 +140,8 @@ def _visualize_graph_via_networkx(
140
140
  ),
141
141
  )
142
142
 
143
- nodes = nx.draw_networkx_nodes(g, pos, node_size=node_size,
144
- node_color='white', margins=0.1)
145
- nodes.set_edgecolor('black')
143
+ nx.draw_networkx_nodes(g, pos, node_size=node_size, node_color='white',
144
+ margins=0.1, edgecolors='black')
146
145
  nx.draw_networkx_labels(g, pos, font_size=10)
147
146
 
148
147
  if path is not None:
@@ -151,3 +150,249 @@ def _visualize_graph_via_networkx(
151
150
  plt.show()
152
151
 
153
152
  plt.close()
153
+
154
+
155
+ def visualize_hetero_graph(
156
+ edge_index_dict: Dict[Tuple[str, str, str], Tensor],
157
+ edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
158
+ path: Optional[str] = None,
159
+ backend: Optional[str] = None,
160
+ node_labels_dict: Optional[Dict[str, List[str]]] = None,
161
+ node_weight_dict: Optional[Dict[str, Tensor]] = None,
162
+ node_size_range: Tuple[float, float] = (50, 500),
163
+ node_opacity_range: Tuple[float, float] = (1.0, 1.0),
164
+ edge_width_range: Tuple[float, float] = (0.1, 2.0),
165
+ edge_opacity_range: Tuple[float, float] = (1.0, 1.0),
166
+ ) -> Any:
167
+ """Visualizes a heterogeneous graph using networkx."""
168
+ if backend is not None and backend != "networkx":
169
+ raise ValueError("Only 'networkx' backend is supported")
170
+
171
+ # Filter out edges with 0 weight
172
+ filtered_edge_index_dict = {}
173
+ filtered_edge_weight_dict = {}
174
+ for edge_type in edge_index_dict.keys():
175
+ mask = edge_weight_dict[edge_type] > 0
176
+ if mask.sum() > 0:
177
+ filtered_edge_index_dict[edge_type] = edge_index_dict[
178
+ edge_type][:, mask]
179
+ filtered_edge_weight_dict[edge_type] = edge_weight_dict[edge_type][
180
+ mask]
181
+
182
+ # Get all unique nodes that are still in the filtered edges
183
+ remaining_nodes: Dict[str, Set[int]] = {}
184
+ for edge_type, edge_index in filtered_edge_index_dict.items():
185
+ src_type, _, dst_type = edge_type
186
+ if src_type not in remaining_nodes:
187
+ remaining_nodes[src_type] = set()
188
+ if dst_type not in remaining_nodes:
189
+ remaining_nodes[dst_type] = set()
190
+ remaining_nodes[src_type].update(edge_index[0].tolist())
191
+ remaining_nodes[dst_type].update(edge_index[1].tolist())
192
+
193
+ # Filter node weights to only include remaining nodes
194
+ if node_weight_dict is not None:
195
+ filtered_node_weight_dict = {}
196
+ for node_type, weights in node_weight_dict.items():
197
+ if node_type in remaining_nodes:
198
+ mask = torch.zeros(len(weights), dtype=torch.bool)
199
+ mask[list(remaining_nodes[node_type])] = True
200
+ filtered_node_weight_dict[node_type] = weights[mask]
201
+ node_weight_dict = filtered_node_weight_dict
202
+
203
+ # Filter node labels to only include remaining nodes
204
+ if node_labels_dict is not None:
205
+ filtered_node_labels_dict = {}
206
+ for node_type, labels in node_labels_dict.items():
207
+ if node_type in remaining_nodes:
208
+ filtered_node_labels_dict[node_type] = [
209
+ label for i, label in enumerate(labels)
210
+ if i in remaining_nodes[node_type]
211
+ ]
212
+ node_labels_dict = filtered_node_labels_dict
213
+
214
+ return _visualize_hetero_graph_via_networkx(
215
+ filtered_edge_index_dict,
216
+ filtered_edge_weight_dict,
217
+ path,
218
+ node_labels_dict,
219
+ node_weight_dict,
220
+ node_size_range,
221
+ node_opacity_range,
222
+ edge_width_range,
223
+ edge_opacity_range,
224
+ )
225
+
226
+
227
+ def _visualize_hetero_graph_via_networkx(
228
+ edge_index_dict: Dict[Tuple[str, str, str], Tensor],
229
+ edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
230
+ path: Optional[str] = None,
231
+ node_labels_dict: Optional[Dict[str, List[str]]] = None,
232
+ node_weight_dict: Optional[Dict[str, Tensor]] = None,
233
+ node_size_range: Tuple[float, float] = (50, 500),
234
+ node_opacity_range: Tuple[float, float] = (1.0, 1.0),
235
+ edge_width_range: Tuple[float, float] = (0.1, 2.0),
236
+ edge_opacity_range: Tuple[float, float] = (1.0, 1.0),
237
+ ) -> Any:
238
+ import matplotlib.pyplot as plt
239
+ import networkx as nx
240
+
241
+ g = nx.DiGraph()
242
+ node_offsets: Dict[str, int] = {}
243
+ current_offset = 0
244
+
245
+ # First, collect all unique node types and their counts
246
+ node_types = set()
247
+ node_counts: Dict[str, int] = {}
248
+ remaining_nodes: Dict[str, Set[int]] = {
249
+ } # Track which nodes are actually present in edges
250
+
251
+ # Get all unique nodes that are in the edges
252
+ for edge_type in edge_index_dict.keys():
253
+ src_type, _, dst_type = edge_type
254
+ node_types.add(src_type)
255
+ node_types.add(dst_type)
256
+
257
+ if src_type not in remaining_nodes:
258
+ remaining_nodes[src_type] = set()
259
+ if dst_type not in remaining_nodes:
260
+ remaining_nodes[dst_type] = set()
261
+
262
+ remaining_nodes[src_type].update(
263
+ edge_index_dict[edge_type][0].tolist())
264
+ remaining_nodes[dst_type].update(
265
+ edge_index_dict[edge_type][1].tolist())
266
+
267
+ # Set node counts based on remaining nodes
268
+ for node_type in node_types:
269
+ node_counts[node_type] = len(remaining_nodes[node_type])
270
+
271
+ # Add nodes for each node type
272
+ for node_type in node_types:
273
+ num_nodes = node_counts[node_type]
274
+ node_offsets[node_type] = current_offset
275
+
276
+ # Get node weights if provided
277
+ weights = None
278
+ if node_weight_dict is not None and node_type in node_weight_dict:
279
+ weights = node_weight_dict[node_type]
280
+ if len(weights) != num_nodes:
281
+ raise ValueError(f"Number of weights for node type "
282
+ f"{node_type} ({len(weights)}) does not "
283
+ f"match number of nodes ({num_nodes})")
284
+
285
+ for i in range(num_nodes):
286
+ node_id = current_offset + i
287
+ label = (node_labels_dict[node_type][i]
288
+ if node_labels_dict is not None
289
+ and node_type in node_labels_dict else "")
290
+
291
+ # Calculate node size and opacity if weights provided
292
+ size = node_size_range[1]
293
+ opacity = node_opacity_range[1]
294
+ if weights is not None:
295
+ w = weights[i].item()
296
+ size = node_size_range[0] + w * \
297
+ (node_size_range[1] - node_size_range[0])
298
+ opacity = node_opacity_range[0] + w * \
299
+ (node_opacity_range[1] - node_opacity_range[0])
300
+
301
+ g.add_node(node_id, label=label, type=node_type, size=size,
302
+ alpha=opacity)
303
+
304
+ current_offset += num_nodes
305
+
306
+ # Add edges with remapped node indices
307
+ for edge_type, edge_index in edge_index_dict.items():
308
+ src_type, _, dst_type = edge_type
309
+ edge_weight = edge_weight_dict[edge_type]
310
+ src_offset = node_offsets[src_type]
311
+ dst_offset = node_offsets[dst_type]
312
+
313
+ # Create mappings for source and target nodes
314
+ src_mapping = {
315
+ old_idx: new_idx
316
+ for new_idx, old_idx in enumerate(sorted(
317
+ remaining_nodes[src_type]))
318
+ }
319
+ dst_mapping = {
320
+ old_idx: new_idx
321
+ for new_idx, old_idx in enumerate(sorted(
322
+ remaining_nodes[dst_type]))
323
+ }
324
+
325
+ for (src, dst), w in zip(edge_index.t().tolist(),
326
+ edge_weight.tolist()):
327
+ # Remap node indices
328
+ new_src = src_mapping[src] + src_offset
329
+ new_dst = dst_mapping[dst] + dst_offset
330
+
331
+ # Calculate edge width and opacity based on weight
332
+ width = edge_width_range[0] + w * \
333
+ (edge_width_range[1] - edge_width_range[0])
334
+ opacity = edge_opacity_range[0] + w * \
335
+ (edge_opacity_range[1] - edge_opacity_range[0])
336
+ g.add_edge(new_src, new_dst, width=width, alpha=opacity)
337
+
338
+ # Draw the graph
339
+ ax = plt.gca()
340
+ pos = nx.arf_layout(g)
341
+
342
+ # Draw edges with arrows
343
+ for src, dst, data in g.edges(data=True):
344
+ ax.annotate(
345
+ '',
346
+ xy=pos[src],
347
+ xytext=pos[dst],
348
+ arrowprops=dict(
349
+ arrowstyle="<-",
350
+ alpha=data['alpha'],
351
+ linewidth=data['width'],
352
+ shrinkA=sqrt(g.nodes[src]['size']) / 2.0,
353
+ shrinkB=sqrt(g.nodes[dst]['size']) / 2.0,
354
+ connectionstyle="arc3,rad=0.1",
355
+ ),
356
+ )
357
+
358
+ # Draw nodes colored by type
359
+ node_colors = []
360
+ node_sizes = []
361
+ node_alphas = []
362
+
363
+ # Use matplotlib tab10 colormap for consistent coloring
364
+ tab10_cmap = plt.cm.tab10 # type: ignore[attr-defined]
365
+ node_type_colors: Dict[str, Any] = {} # Store color for each node type
366
+ for node in g.nodes():
367
+ node_type = g.nodes[node]['type']
368
+ # Assign a consistent color for each node type
369
+ if node_type not in node_type_colors:
370
+ color_idx = len(node_type_colors) % 10 # Cycle through colors
371
+ node_type_colors[node_type] = tab10_cmap(color_idx)
372
+ node_colors.append(node_type_colors[node_type])
373
+ node_sizes.append(g.nodes[node]['size'])
374
+ node_alphas.append(g.nodes[node]['alpha'])
375
+
376
+ nx.draw_networkx_nodes(g, pos, node_size=node_sizes,
377
+ node_color=node_colors, margins=0.1,
378
+ alpha=node_alphas)
379
+
380
+ # Draw labels
381
+ labels = nx.get_node_attributes(g, 'label')
382
+ nx.draw_networkx_labels(g, pos, labels, font_size=10)
383
+
384
+ # Add legend
385
+ legend_elements = []
386
+ for node_type, color in node_type_colors.items():
387
+ legend_elements.append(
388
+ plt.Line2D([0], [0], marker='o', color='w', label=node_type,
389
+ markerfacecolor=color, markersize=10))
390
+ ax.legend(handles=legend_elements, loc='upper right',
391
+ bbox_to_anchor=(0.9, 1))
392
+
393
+ if path is not None:
394
+ plt.savefig(path, bbox_inches='tight')
395
+ else:
396
+ plt.show()
397
+
398
+ plt.close()
@@ -4,11 +4,11 @@ from typing import Literal
4
4
  import torch_geometric
5
5
 
6
6
 
7
- def warn(message: str) -> None:
7
+ def warn(message: str, stacklevel: int = 5) -> None:
8
8
  if torch_geometric.is_compiling():
9
9
  return
10
10
 
11
- warnings.warn(message)
11
+ warnings.warn(message, stacklevel=stacklevel)
12
12
 
13
13
 
14
14
  def filterwarnings(
@@ -19,3 +19,12 @@ def filterwarnings(
19
19
  return
20
20
 
21
21
  warnings.filterwarnings(action, message)
22
+
23
+
24
+ class WarningCache(set):
25
+ """Cache for warnings."""
26
+ def warn(self, message: str, stacklevel: int = 5) -> None:
27
+ """Trigger warning message."""
28
+ if message not in self:
29
+ self.add(message)
30
+ warn(message, stacklevel=stacklevel)
@@ -1,7 +0,0 @@
1
- from .sentence_transformer import SentenceTransformer
2
- from .llm import LLM
3
-
4
- __all__ = classes = [
5
- 'SentenceTransformer',
6
- 'LLM',
7
- ]