pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic; see the registry's release page for more details.

Files changed (228):
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_trim_to_layer.py +2 -2
  215. torch_geometric/utils/convert.py +17 -10
  216. torch_geometric/utils/cross_entropy.py +34 -13
  217. torch_geometric/utils/embedding.py +91 -2
  218. torch_geometric/utils/geodesic.py +4 -3
  219. torch_geometric/utils/influence.py +279 -0
  220. torch_geometric/utils/map.py +13 -9
  221. torch_geometric/utils/nested.py +1 -1
  222. torch_geometric/utils/smiles.py +3 -3
  223. torch_geometric/utils/sparse.py +7 -14
  224. torch_geometric/visualization/__init__.py +2 -1
  225. torch_geometric/visualization/graph.py +250 -5
  226. torch_geometric/warnings.py +11 -2
  227. torch_geometric/nn/nlp/__init__.py +0 -7
  228. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -15,19 +15,16 @@ class Amazon(InMemoryDataset):
15
15
  map goods to their respective product category.
16
16
 
17
17
  Args:
18
- root (str): Root directory where the dataset should be saved.
19
- name (str): The name of the dataset (:obj:`"Computers"`,
20
- :obj:`"Photo"`).
21
- transform (callable, optional): A function/transform that takes in an
22
- :obj:`torch_geometric.data.Data` object and returns a transformed
18
+ root: Root directory where the dataset should be saved.
19
+ name: The name of the dataset (:obj:`"Computers"`, :obj:`"Photo"`).
20
+ transform: A function/transform that takes in a
21
+ :class:`torch_geometric.data.Data` object and returns a transformed
23
22
  version. The data object will be transformed before every access.
24
- (default: :obj:`None`)
25
- pre_transform (callable, optional): A function/transform that takes in
26
- an :obj:`torch_geometric.data.Data` object and returns a
23
+ pre_transform: A function/transform that takes in an
24
+ :class:`torch_geometric.data.Data` object and returns a
27
25
  transformed version. The data object will be transformed before
28
- being saved to disk. (default: :obj:`None`)
29
- force_reload (bool, optional): Whether to re-process the dataset.
30
- (default: :obj:`False`)
26
+ being saved to disk.
27
+ force_reload: Whether to re-process the dataset.
31
28
 
32
29
  **STATS:**
33
30
 
@@ -14,17 +14,16 @@ class AmazonBook(InMemoryDataset):
14
14
  No labels or features are provided.
15
15
 
16
16
  Args:
17
- root (str): Root directory where the dataset should be saved.
18
- transform (callable, optional): A function/transform that takes in an
19
- :obj:`torch_geometric.data.HeteroData` object and returns a
17
+ root: Root directory where the dataset should be saved.
18
+ transform: A function/transform that takes in an
19
+ :class:`torch_geometric.data.HeteroData` object and returns a
20
20
  transformed version. The data object will be transformed before
21
- every access. (default: :obj:`None`)
22
- pre_transform (callable, optional): A function/transform that takes in
23
- an :obj:`torch_geometric.data.HeteroData` object and returns a
21
+ every access.
22
+ pre_transform: A function/transform that takes in an
23
+ :class:`torch_geometric.data.HeteroData` object and returns a
24
24
  transformed version. The data object will be transformed before
25
- being saved to disk. (default: :obj:`None`)
26
- force_reload (bool, optional): Whether to re-process the dataset.
27
- (default: :obj:`False`)
25
+ being saved to disk.
26
+ force_reload: Whether to re-process the dataset.
28
27
  """
29
28
  url = ('https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/'
30
29
  'master/data/amazon-book')
@@ -14,17 +14,15 @@ class AmazonProducts(InMemoryDataset):
14
14
  containing products and its categories.
15
15
 
16
16
  Args:
17
- root (str): Root directory where the dataset should be saved.
18
- transform (callable, optional): A function/transform that takes in an
19
- :obj:`torch_geometric.data.Data` object and returns a transformed
17
+ root: Root directory where the dataset should be saved.
18
+ transform: A function/transform that takes in an
19
+ :class:`torch_geometric.data.Data` object and returns a transformed
20
20
  version. The data object will be transformed before every access.
21
- (default: :obj:`None`)
22
- pre_transform (callable, optional): A function/transform that takes in
23
- an :obj:`torch_geometric.data.Data` object and returns a
21
+ pre_transform: A function/transform that takes in a
22
+ :class:`torch_geometric.data.Data` object and returns a
24
23
  transformed version. The data object will be transformed before
25
- being saved to disk. (default: :obj:`None`)
26
- force_reload (bool, optional): Whether to re-process the dataset.
27
- (default: :obj:`False`)
24
+ being saved to disk.
25
+ force_reload: Whether to re-process the dataset.
28
26
 
29
27
  **STATS:**
30
28
 
@@ -24,17 +24,16 @@ class AMiner(InMemoryDataset):
24
24
  truth labels for a subset of nodes.
25
25
 
26
26
  Args:
27
- root (str): Root directory where the dataset should be saved.
28
- transform (callable, optional): A function/transform that takes in an
29
- :obj:`torch_geometric.data.HeteroData` object and returns a
27
+ root: Root directory where the dataset should be saved.
28
+ transform: A function/transform that takes in a
29
+ :class:`torch_geometric.data.HeteroData` object and returns a
30
30
  transformed version. The data object will be transformed before
31
- every access. (default: :obj:`None`)
32
- pre_transform (callable, optional): A function/transform that takes in
33
- an :obj:`torch_geometric.data.HeteroData` object and returns a
31
+ every access.
32
+ pre_transform: A function/transform that takes in a
33
+ :class:`torch_geometric.data.HeteroData` object and returns a
34
34
  transformed version. The data object will be transformed before
35
- being saved to disk. (default: :obj:`None`)
36
- force_reload (bool, optional): Whether to re-process the dataset.
37
- (default: :obj:`False`)
35
+ being saved to disk.
36
+ force_reload: Whether to re-process the dataset.
38
37
  """
39
38
 
40
39
  url = 'https://www.dropbox.com/s/1bnz8r7mofx0osf/net_aminer.zip?dl=1'
@@ -30,25 +30,22 @@ class AQSOL(InMemoryDataset):
30
30
  the :class:`~torch_geometric.datasets.ZINC` dataset.
31
31
 
32
32
  Args:
33
- root (str): Root directory where the dataset should be saved.
34
- split (str, optional): If :obj:`"train"`, loads the training dataset.
33
+ root: Root directory where the dataset should be saved.
34
+ split: If :obj:`"train"`, loads the training dataset.
35
35
  If :obj:`"val"`, loads the validation dataset.
36
36
  If :obj:`"test"`, loads the test dataset.
37
- (default: :obj:`"train"`)
38
- transform (callable, optional): A function/transform that takes in an
39
- :obj:`torch_geometric.data.Data` object and returns a transformed
37
+ transform: A function/transform that takes in a
38
+ :class:`torch_geometric.data.Data` object and returns a transformed
40
39
  version. The data object will be transformed before every access.
41
- (default: :obj:`None`)
42
- pre_transform (callable, optional): A function/transform that takes in
43
- an :obj:`torch_geometric.data.Data` object and returns a
40
+ pre_transform: A function/transform that takes in a
41
+ :class:`torch_geometric.data.Data` object and returns a
44
42
  transformed version. The data object will be transformed before
45
- being saved to disk. (default: :obj:`None`)
43
+ being saved to disk.
46
44
  pre_filter (callable, optional): A function that takes in an
47
- :obj:`torch_geometric.data.Data` object and returns a boolean
45
+ :class:`torch_geometric.data.Data` object and returns a boolean
48
46
  value, indicating whether the data object should be included in
49
- the final dataset. (default: :obj:`None`)
50
- force_reload (bool, optional): Whether to re-process the dataset.
51
- (default: :obj:`False`)
47
+ the final dataset.
48
+ force_reload: Whether to re-process the dataset.
52
49
 
53
50
  **STATS:**
54
51
 
@@ -19,21 +19,19 @@ class AttributedGraphDataset(InMemoryDataset):
19
19
  <https://arxiv.org/abs/2009.00826>`_ paper.
20
20
 
21
21
  Args:
22
- root (str): Root directory where the dataset should be saved.
23
- name (str): The name of the dataset (:obj:`"Wiki"`, :obj:`"Cora"`
22
+ root: Root directory where the dataset should be saved.
23
+ name: The name of the dataset (:obj:`"Wiki"`, :obj:`"Cora"`,
24
24
  :obj:`"CiteSeer"`, :obj:`"PubMed"`, :obj:`"BlogCatalog"`,
25
25
  :obj:`"PPI"`, :obj:`"Flickr"`, :obj:`"Facebook"`, :obj:`"Twitter"`,
26
26
  :obj:`"TWeibo"`, :obj:`"MAG"`).
27
- transform (callable, optional): A function/transform that takes in an
28
- :obj:`torch_geometric.data.Data` object and returns a transformed
27
+ transform: A function/transform that takes in a
28
+ :class:`torch_geometric.data.Data` object and returns a transformed
29
29
  version. The data object will be transformed before every access.
30
- (default: :obj:`None`)
31
- pre_transform (callable, optional): A function/transform that takes in
32
- an :obj:`torch_geometric.data.Data` object and returns a
30
+ pre_transform: A function/transform that takes in a
31
+ :class:`torch_geometric.data.Data` object and returns a
33
32
  transformed version. The data object will be transformed before
34
- being saved to disk. (default: :obj:`None`)
35
- force_reload (bool, optional): Whether to re-process the dataset.
36
- (default: :obj:`False`)
33
+ being saved to disk.
34
+ force_reload: Whether to re-process the dataset.
37
35
 
38
36
  **STATS:**
39
37
 
@@ -25,21 +25,19 @@ class BAMultiShapesDataset(InMemoryDataset):
25
25
  This dataset is pre-computed from the official implementation.
26
26
 
27
27
  Args:
28
- root (str): Root directory where the dataset should be saved.
29
- transform (callable, optional): A function/transform that takes in an
30
- :obj:`torch_geometric.data.Data` object and returns a transformed
28
+ root: Root directory where the dataset should be saved.
29
+ transform: A function/transform that takes in a
30
+ :class:`torch_geometric.data.Data` object and returns a transformed
31
31
  version. The data object will be transformed before every access.
32
- (default: :obj:`None`)
33
- pre_transform (callable, optional): A function/transform that takes in
34
- an :obj:`torch_geometric.data.Data` object and returns a
32
+ pre_transform: A function/transform that takes in a
33
+ :class:`torch_geometric.data.Data` object and returns a
35
34
  transformed version. The data object will be transformed before
36
- being saved to disk. (default: :obj:`None`)
37
- pre_filter (callable, optional): A function that takes in an
38
- :obj:`torch_geometric.data.Data` object and returns a boolean
35
+ being saved to disk.
36
+ pre_filter: A function that takes in a
37
+ :class:`torch_geometric.data.Data` object and returns a boolean
39
38
  value, indicating whether the data object should be included in the
40
- final dataset. (default: :obj:`None`)
41
- force_reload (bool, optional): Whether to re-process the dataset.
42
- (default: :obj:`False`)
39
+ final dataset.
40
+ force_reload: Whether to re-process the dataset.
43
41
 
44
42
  **STATS:**
45
43
 
@@ -30,15 +30,14 @@ class BAShapes(InMemoryDataset):
30
30
  :class:`torch_geometric.datasets.graph_generator.BAGraph` instead.
31
31
 
32
32
  Args:
33
- connection_distribution (str, optional): Specifies how the houses
34
- and the BA graph get connected. Valid inputs are :obj:`"random"`
33
+ connection_distribution: Specifies how the houses and the BA graph get
34
+ connected. Valid inputs are :obj:`"random"`
35
35
  (random BA graph nodes are selected for connection to the houses),
36
36
  and :obj:`"uniform"` (uniformly distributed BA graph nodes are
37
- selected for connection to the houses). (default: :obj:`"random"`)
38
- transform (callable, optional): A function/transform that takes in an
39
- :obj:`torch_geometric.data.Data` object and returns a transformed
37
+ selected for connection to the houses).
38
+ transform: A function/transform that takes in a
39
+ :class:`torch_geometric.data.Data` object and returns a transformed
40
40
  version. The data object will be transformed before every access.
41
- (default: :obj:`None`)
42
41
  """
43
42
  def __init__(
44
43
  self,
@@ -0,0 +1,157 @@
1
+ import os.path as osp
2
+ from typing import Callable, Optional
3
+
4
+ from torch_geometric.data import (
5
+ Data,
6
+ InMemoryDataset,
7
+ download_url,
8
+ extract_tar,
9
+ )
10
+ from torch_geometric.io import fs
11
+
12
+
13
class CityNetwork(InMemoryDataset):
    r"""The City-Networks are introduced in
    `"Towards Quantifying Long-Range Interactions in Graph Machine Learning:
    a Large Graph Dataset and a Measurement"
    <https://arxiv.org/abs/2503.09008>`_ paper.
    The dataset contains four city networks: `paris`, `shanghai`, `la`,
    and `london`, where nodes represent junctions and edges represent
    undirected road segments. The task is to predict each node's eccentricity
    score, which is approximated based on its 16-hop neighborhood and naturally
    requires long-range information. The score indicates how accessible one
    node is in the network, and is mapped to 10 quantiles for transductive
    classification. See the original
    `source code <https://github.com/LeonResearch/City-Networks>`_ for more
    details on the individual networks.

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str): The name of the dataset (``"paris"``, ``"shanghai"``,
            ``"la"``, ``"london"``). Matching is case-insensitive.
        augmented (bool, optional): Whether to use the augmented node features
            from edge features. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in a
            :class:`~torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            a :class:`~torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        delete_raw (bool, optional): Whether to delete the extracted raw
            files after processing. The downloaded archive itself is kept.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10 10
        :header-rows: 1

        * - Name
          - #nodes
          - #edges
          - #features
          - #classes
        * - paris
          - 114,127
          - 182,511
          - 37
          - 10
        * - shanghai
          - 183,917
          - 262,092
          - 37
          - 10
        * - la
          - 240,587
          - 341,523
          - 37
          - 10
        * - london
          - 568,795
          - 756,502
          - 37
          - 10
    """
    url = "https://github.com/LeonResearch/City-Networks/raw/refs/heads/main/data/"  # noqa: E501

    def __init__(
        self,
        root: str,
        name: str,
        augmented: bool = True,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
        delete_raw: bool = False,
    ) -> None:
        self.name = name.lower()
        assert self.name in ["paris", "shanghai", "la", "london"], (
            f"Unknown city network '{name}'")
        self.augmented = augmented
        self.delete_raw = delete_raw
        super().__init__(
            root,
            transform,
            pre_transform,
            force_reload=force_reload,
        )
        self.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, self.name, "raw")

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, self.name, "processed")

    @property
    def raw_file_names(self) -> str:
        # Declare the archive that `download()` actually fetches. Declaring a
        # file that is never produced makes the raw-file check always fail,
        # so `download()` would be re-run on every instantiation.
        return f"{self.name}.tar.gz"

    @property
    def processed_file_names(self) -> str:
        # The node features differ between the augmented and plain variants,
        # and both share `processed_dir`. Key the cache file on `augmented` so
        # one variant never silently loads the other's processed data.
        return "data_augmented.pt" if self.augmented else "data_plain.pt"

    def download(self) -> None:
        self.download_path = download_url(
            self.url + f"{self.name}.tar.gz",
            self.raw_dir,
        )

    def process(self) -> None:
        # `download()` is skipped when the archive already exists, so do not
        # rely on `self.download_path`. Use the declared raw path instead.
        extract_tar(self.raw_paths[0], self.raw_dir)
        data_path = osp.join(self.raw_dir, self.name)

        def _load(file_name: str):
            # Load one tensor file from the extracted archive folder.
            return fs.torch_load(osp.join(data_path, file_name))

        suffix = '_augmented' if self.augmented else ''
        data = Data(
            x=_load(f"node_features{suffix}.pt"),
            edge_index=_load("edge_indices.pt"),
            y=_load("10-chunk_16-hop_node_labels.pt"),
            train_mask=_load("train_mask.pt"),
            val_mask=_load("valid_mask.pt"),
            test_mask=_load("test_mask.pt"),
        )
        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])

        if self.delete_raw:
            # Only the extracted folder is removed; the archive stays in
            # `raw_dir`, so later instantiations need not download it again.
            fs.rm(data_path)

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}("
                f"root='{self.root}', "
                f"name='{self.name}', "
                f"augmented={self.augmented})")
@@ -73,7 +73,7 @@ class DBP15K(InMemoryDataset):
73
73
  def process(self) -> None:
74
74
  embs = {}
75
75
  with open(osp.join(self.raw_dir, 'sub.glove.300d')) as f:
76
- for i, line in enumerate(f):
76
+ for line in f:
77
77
  info = line.strip().split(' ')
78
78
  if len(info) > 300:
79
79
  embs[info[0]] = torch.tensor([float(x) for x in info[1:]])
@@ -0,0 +1,263 @@
1
+ import sys
2
+ from typing import Any, Callable, Dict, List, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+ from torch_geometric.data import (
9
+ Data,
10
+ InMemoryDataset,
11
+ download_google_url,
12
+ extract_zip,
13
+ )
14
+ from torch_geometric.io import fs
15
+
16
+
17
+ def safe_index(lst: List[Any], e: int) -> int:
18
+ return lst.index(e) if e in lst else len(lst) - 1
19
+
20
+
21
+ class GitMolDataset(InMemoryDataset):
22
+ r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model
23
+ for Molecular Science with Graph, Image, and Text"
24
+ <https://arxiv.org/pdf/2308.06911>`_ paper.
25
+
26
+ Args:
27
+ root (str): Root directory where the dataset should be saved.
28
+ transform (callable, optional): A function/transform that takes in an
29
+ :obj:`torch_geometric.data.Data` object and returns a transformed
30
+ version. The data object will be transformed before every access.
31
+ (default: :obj:`None`)
32
+ pre_transform (callable, optional): A function/transform that takes in
33
+ an :obj:`torch_geometric.data.Data` object and returns a
34
+ transformed version. The data object will be transformed before
35
+ being saved to disk. (default: :obj:`None`)
36
+ pre_filter (callable, optional): A function that takes in an
37
+ :obj:`torch_geometric.data.Data` object and returns a boolean
38
+ value, indicating whether the data object should be included in the
39
+ final dataset. (default: :obj:`None`)
40
+ force_reload (bool, optional): Whether to re-process the dataset.
41
+ (default: :obj:`False`)
42
+ split (int, optional): Datasets split, train/valid/test=0/1/2.
43
+ (default: :obj:`0`)
44
+ """
45
+
46
+ raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'
47
+
48
+ def __init__(
49
+ self,
50
+ root: str,
51
+ transform: Optional[Callable] = None,
52
+ pre_transform: Optional[Callable] = None,
53
+ pre_filter: Optional[Callable] = None,
54
+ force_reload: bool = False,
55
+ split: int = 0,
56
+ ):
57
+ from torchvision import transforms
58
+
59
+ self.split = split
60
+
61
+ if self.split == 0:
62
+ self.img_transform = transforms.Compose([
63
+ transforms.Resize((224, 224)),
64
+ transforms.RandomRotation(15),
65
+ transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
66
+ transforms.ToTensor(),
67
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
68
+ std=[0.229, 0.224, 0.225])
69
+ ])
70
+ else:
71
+ self.img_transform = transforms.Compose([
72
+ transforms.Resize((224, 224)),
73
+ transforms.ToTensor(),
74
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
75
+ std=[0.229, 0.224, 0.225])
76
+ ])
77
+
78
+ super().__init__(root, transform, pre_transform, pre_filter,
79
+ force_reload=force_reload)
80
+
81
+ self.load(self.processed_paths[0])
82
+
83
+ @property
84
+ def raw_file_names(self) -> List[str]:
85
+ return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']
86
+
87
+ @property
88
+ def processed_file_names(self) -> str:
89
+ return ['train.pt', 'valid.pt', 'test.pt'][self.split]
90
+
91
+ def download(self) -> None:
92
+ file_path = download_google_url(
93
+ self.raw_url_id,
94
+ self.raw_dir,
95
+ 'gitmol.zip',
96
+ )
97
+ extract_zip(file_path, self.raw_dir)
98
+
99
+ def process(self) -> None:
100
+ import pandas as pd
101
+ from PIL import Image
102
+
103
+ try:
104
+ from rdkit import Chem, RDLogger
105
+ RDLogger.DisableLog('rdApp.*') # type: ignore[attr-defined]
106
+ WITH_RDKIT = True
107
+
108
+ except ImportError:
109
+ WITH_RDKIT = False
110
+
111
+ if not WITH_RDKIT:
112
+ print(("Using a pre-processed version of the dataset. Please "
113
+ "install 'rdkit' to alternatively process the raw data."),
114
+ file=sys.stderr)
115
+
116
+ data_list = fs.torch_load(self.raw_paths[0])
117
+ data_list = [Data(**data_dict) for data_dict in data_list]
118
+
119
+ if self.pre_filter is not None:
120
+ data_list = [d for d in data_list if self.pre_filter(d)]
121
+
122
+ if self.pre_transform is not None:
123
+ data_list = [self.pre_transform(d) for d in data_list]
124
+
125
+ self.save(data_list, self.processed_paths[0])
126
+ return
127
+
128
+ allowable_features: Dict[str, List[Any]] = {
129
+ 'possible_atomic_num_list':
130
+ list(range(1, 119)) + ['misc'],
131
+ 'possible_formal_charge_list':
132
+ [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
133
+ 'possible_chirality_list': [
134
+ Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
135
+ Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
136
+ Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
137
+ Chem.rdchem.ChiralType.CHI_OTHER
138
+ ],
139
+ 'possible_hybridization_list': [
140
+ Chem.rdchem.HybridizationType.SP,
141
+ Chem.rdchem.HybridizationType.SP2,
142
+ Chem.rdchem.HybridizationType.SP3,
143
+ Chem.rdchem.HybridizationType.SP3D,
144
+ Chem.rdchem.HybridizationType.SP3D2,
145
+ Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'
146
+ ],
147
+ 'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
148
+ 'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
149
+ 'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
150
+ 'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
151
+ 'possible_is_aromatic_list': [False, True],
152
+ 'possible_is_in_ring_list': [False, True],
153
+ 'possible_bond_type_list': [
154
+ Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
155
+ Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
156
+ Chem.rdchem.BondType.ZERO
157
+ ],
158
+ 'possible_bond_dirs': [ # only for double bond stereo information
159
+ Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
160
+ Chem.rdchem.BondDir.ENDDOWNRIGHT
161
+ ],
162
+ 'possible_bond_stereo_list': [
163
+ Chem.rdchem.BondStereo.STEREONONE,
164
+ Chem.rdchem.BondStereo.STEREOZ,
165
+ Chem.rdchem.BondStereo.STEREOE,
166
+ Chem.rdchem.BondStereo.STEREOCIS,
167
+ Chem.rdchem.BondStereo.STEREOTRANS,
168
+ Chem.rdchem.BondStereo.STEREOANY,
169
+ ],
170
+ 'possible_is_conjugated_list': [False, True]
171
+ }
172
+
173
+ data = pd.read_pickle(
174
+ f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')
175
+
176
+ data_list = []
177
+ for _, r in tqdm(data.iterrows(), total=data.shape[0]):
178
+ smiles = r['isosmiles']
179
+ mol = Chem.MolFromSmiles(smiles.strip('\n'))
180
+ if mol is not None:
181
+ # text
182
+ summary = r['summary']
183
+ # image
184
+ cid = r['cid']
185
+ img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'
186
+ img = Image.open(img_file).convert('RGB')
187
+ img = self.img_transform(img).unsqueeze(0)
188
+ # graph
189
+ atom_features_list = []
190
+ for atom in mol.GetAtoms():
191
+ atom_feature = [
192
+ safe_index(
193
+ allowable_features['possible_atomic_num_list'],
194
+ atom.GetAtomicNum()),
195
+ allowable_features['possible_chirality_list'].index(
196
+ atom.GetChiralTag()),
197
+ safe_index(allowable_features['possible_degree_list'],
198
+ atom.GetTotalDegree()),
199
+ safe_index(
200
+ allowable_features['possible_formal_charge_list'],
201
+ atom.GetFormalCharge()),
202
+ safe_index(allowable_features['possible_numH_list'],
203
+ atom.GetTotalNumHs()),
204
+ safe_index(
205
+ allowable_features[
206
+ 'possible_number_radical_e_list'],
207
+ atom.GetNumRadicalElectrons()),
208
+ safe_index(
209
+ allowable_features['possible_hybridization_list'],
210
+ atom.GetHybridization()),
211
+ allowable_features['possible_is_aromatic_list'].index(
212
+ atom.GetIsAromatic()),
213
+ allowable_features['possible_is_in_ring_list'].index(
214
+ atom.IsInRing()),
215
+ ]
216
+ atom_features_list.append(atom_feature)
217
+ x = torch.tensor(np.array(atom_features_list),
218
+ dtype=torch.long)
219
+
220
+ edges_list = []
221
+ edge_features_list = []
222
+ for bond in mol.GetBonds():
223
+ i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
224
+ edge_feature = [
225
+ safe_index(
226
+ allowable_features['possible_bond_type_list'],
227
+ bond.GetBondType()),
228
+ allowable_features['possible_bond_stereo_list'].index(
229
+ bond.GetStereo()),
230
+ allowable_features['possible_is_conjugated_list'].
231
+ index(bond.GetIsConjugated()),
232
+ ]
233
+ edges_list.append((i, j))
234
+ edge_features_list.append(edge_feature)
235
+ edges_list.append((j, i))
236
+ edge_features_list.append(edge_feature)
237
+
238
+ edge_index = torch.tensor(
239
+ np.array(edges_list).T,
240
+ dtype=torch.long,
241
+ )
242
+ edge_attr = torch.tensor(
243
+ np.array(edge_features_list),
244
+ dtype=torch.long,
245
+ )
246
+
247
+ data = Data(
248
+ x=x,
249
+ edge_index=edge_index,
250
+ smiles=smiles,
251
+ edge_attr=edge_attr,
252
+ image=img,
253
+ caption=summary,
254
+ )
255
+
256
+ if self.pre_filter is not None and not self.pre_filter(data):
257
+ continue
258
+ if self.pre_transform is not None:
259
+ data = self.pre_transform(data)
260
+
261
+ data_list.append(data)
262
+
263
+ self.save(data_list, self.processed_paths[0])