pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -30,15 +30,14 @@ class BAShapes(InMemoryDataset):
30
30
  :class:`torch_geometric.datasets.graph_generator.BAGraph` instead.
31
31
 
32
32
  Args:
33
- connection_distribution (str, optional): Specifies how the houses
34
- and the BA graph get connected. Valid inputs are :obj:`"random"`
33
+ connection_distribution: Specifies how the houses and the BA graph get
34
+ connected. Valid inputs are :obj:`"random"`
35
35
  (random BA graph nodes are selected for connection to the houses),
36
36
  and :obj:`"uniform"` (uniformly distributed BA graph nodes are
37
- selected for connection to the houses). (default: :obj:`"random"`)
38
- transform (callable, optional): A function/transform that takes in an
39
- :obj:`torch_geometric.data.Data` object and returns a transformed
37
+ selected for connection to the houses).
38
+ transform: A function/transform that takes in a
39
+ :class:`torch_geometric.data.Data` object and returns a transformed
40
40
  version. The data object will be transformed before every access.
41
- (default: :obj:`None`)
42
41
  """
43
42
  def __init__(
44
43
  self,
@@ -87,7 +87,7 @@ class BitcoinOTC(InMemoryDataset):
87
87
  os.unlink(path)
88
88
 
89
89
  def process(self) -> None:
90
- with open(self.raw_paths[0], 'r') as f:
90
+ with open(self.raw_paths[0]) as f:
91
91
  lines = [[x for x in line.split(',')]
92
92
  for line in f.read().split('\n')[:-1]]
93
93
 
@@ -94,7 +94,7 @@ class BrcaTcga(InMemoryDataset):
94
94
  graph_feat = torch.from_numpy(graph_feat).to(torch.float)
95
95
  graph_labels = np.loadtxt(self.raw_paths[1], delimiter=',')
96
96
  graph_label = torch.from_numpy(graph_labels).to(torch.float)
97
- edge_index = torch.load(self.raw_paths[2])
97
+ edge_index = fs.torch_load(self.raw_paths[2])
98
98
 
99
99
  data_list = []
100
100
  for x, y in zip(graph_feat, graph_label):
@@ -0,0 +1,145 @@
1
+ import os.path as osp
2
+ from typing import Callable, List, Optional
3
+
4
+ import torch
5
+
6
+ from torch_geometric.data import InMemoryDataset, download_url
7
+ from torch_geometric.data.hypergraph_data import HyperGraphData
8
+
9
+
10
class CornellTemporalHyperGraphDataset(InMemoryDataset):
    r"""A collection of temporal higher-order network datasets from the
    `"Simplicial Closure and higher-order link prediction"
    <https://arxiv.org/abs/1802.06916>`_ paper.
    Each of the datasets is a timestamped sequence of simplices, where a
    simplex is a set of :math:`k` nodes.

    See the original `datasets page
    <https://www.cs.cornell.edu/~arb/data/>`_ for more details about
    individual datasets.

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str): The name of the dataset. Must be one of
            :attr:`CornellTemporalHyperGraphDataset.names`.
        split (str, optional): If :obj:`"train"`, loads the training dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset.
            (default: :obj:`"train"`)
        setting (str, optional): If :obj:`"transductive"`, loads the dataset
            for transductive training.
            If :obj:`"inductive"`, loads the dataset for inductive training.
            (default: :obj:`"transductive"`)
        transform (callable, optional): A function/transform that takes in a
            :class:`~torch_geometric.data.hypergraph_data.HyperGraphData`
            object and returns a transformed version. The data object will be
            transformed before every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            a :class:`~torch_geometric.data.hypergraph_data.HyperGraphData`
            object and returns a transformed version. The data object will be
            transformed before being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in a
            :class:`~torch_geometric.data.hypergraph_data.HyperGraphData`
            object and returns a boolean value, indicating whether the data
            object should be included in the final dataset.
            (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    Raises:
        ValueError: If :obj:`split` is not one of :obj:`"train"`,
            :obj:`"val"` or :obj:`"test"`.
    """
    names = [
        'email-Eu',
        'email-Enron',
        'NDC-classes',
        'tags-math-sx',
        'email-Eu-25',
        'NDC-substances',
        'congress-bills',
        'tags-ask-ubuntu',
        'email-Enron-25',
        'NDC-classes-25',
        'threads-ask-ubuntu',
        'contact-high-school',
        'NDC-substances-25',
        'congress-bills-25',
        'contact-primary-school',
    ]
    # Formatted with (dataset name, setting, raw file name).
    url = ('https://huggingface.co/datasets/SauravMaheshkar/{}/raw/main/'
           'processed/{}/{}')

    def __init__(
        self,
        root: str,
        name: str,
        split: str = 'train',
        setting: str = 'transductive',
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        assert name in self.names
        assert setting in ['transductive', 'inductive']

        # Same order as `processed_file_names`. The split is validated
        # *before* `super().__init__()`, because that call may download and
        # process all three splits. An invalid split should fail fast
        # instead of failing after that work has been done.
        splits = ['train', 'val', 'test']
        if split not in splits:
            raise ValueError(f"Split '{split}' not found")

        self.name = name
        self.setting = setting

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload)

        self.load(self.processed_paths[splits.index(split)])

    @property
    def raw_dir(self) -> str:
        # Raw files are kept per dataset name and setting.
        return osp.join(self.root, self.name, self.setting, 'raw')

    @property
    def raw_file_names(self) -> List[str]:
        return ['train_df.csv', 'val_df.csv', 'test_df.csv']

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, self.name, self.setting, 'processed')

    @property
    def processed_file_names(self) -> List[str]:
        return ['train_data.pt', 'val_data.pt', 'test_data.pt']

    def download(self) -> None:
        for filename in self.raw_file_names:
            url = self.url.format(self.name, self.setting, filename)
            download_url(url, self.raw_dir)

    def process(self) -> None:
        import ast

        import pandas as pd

        # Raw CSV files and processed files are paired positionally
        # (train/val/test).
        for raw_path, path in zip(self.raw_paths, self.processed_paths):
            df = pd.read_csv(raw_path)

            data_list = []
            for i, row in df.iterrows():
                # The `nodes` column holds a stringified Python list. The
                # file is downloaded from the internet, so it is parsed with
                # `ast.literal_eval`, which only accepts literals, instead of
                # `eval`, which would execute arbitrary code.
                nodes = ast.literal_eval(row['nodes'])

                edge_indices: List[List[int]] = [[], []]
                for node in nodes:
                    edge_indices[0].append(node)
                    # Use `i` as hyper-edge index. NOTE(review): `i` is the
                    # row label of the CSV, not 0. Presumably intentional;
                    # confirm against HyperGraphData's batching semantics.
                    edge_indices[1].append(i)

                # The simplex timestamp is stored as a single [1, 1] float
                # feature.
                x = torch.tensor([[row['timestamp']]], dtype=torch.float)
                edge_index = torch.tensor(edge_indices)

                data = HyperGraphData(x=x, edge_index=edge_index)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

            self.save(data_list, path)
@@ -4,7 +4,6 @@ from itertools import product
4
4
  from typing import Callable, List, Optional
5
5
 
6
6
  import numpy as np
7
- import scipy.sparse as sp
8
7
  import torch
9
8
 
10
9
  from torch_geometric.data import (
@@ -110,6 +109,8 @@ class DBLP(InMemoryDataset):
110
109
  os.remove(path)
111
110
 
112
111
  def process(self) -> None:
112
+ import scipy.sparse as sp
113
+
113
114
  data = HeteroData()
114
115
 
115
116
  node_types = ['author', 'paper', 'term', 'conference']
@@ -72,7 +72,7 @@ class DBP15K(InMemoryDataset):
72
72
 
73
73
  def process(self) -> None:
74
74
  embs = {}
75
- with open(osp.join(self.raw_dir, 'sub.glove.300d'), 'r') as f:
75
+ with open(osp.join(self.raw_dir, 'sub.glove.300d')) as f:
76
76
  for i, line in enumerate(f):
77
77
  info = line.strip().split(' ')
78
78
  if len(info) > 300:
@@ -112,7 +112,7 @@ class DBP15K(InMemoryDataset):
112
112
  subj, rel, obj = g1.t()
113
113
 
114
114
  x_dict = {}
115
- with open(feature_path, 'r') as f:
115
+ with open(feature_path) as f:
116
116
  for line in f:
117
117
  info = line.strip().split('\t')
118
118
  info = info if len(info) == 2 else info + ['**UNK**']
@@ -170,7 +170,7 @@ class FakeHeteroDataset(InMemoryDataset):
170
170
  random.shuffle(edge_types)
171
171
 
172
172
  self.edge_types: List[Tuple[str, str, str]] = []
173
- count: Dict[Tuple[str, str], int] = defaultdict(lambda: 0)
173
+ count: Dict[Tuple[str, str], int] = defaultdict(int)
174
174
  for edge_type in edge_types[:max(num_edge_types, 1)]:
175
175
  rel = f'e{count[edge_type]}'
176
176
  count[edge_type] += 1
@@ -222,8 +222,6 @@ class FakeHeteroDataset(InMemoryDataset):
222
222
  elif self.edge_dim == 1:
223
223
  store.edge_weight = torch.rand(store.num_edges)
224
224
 
225
- pass
226
-
227
225
  if self._num_classes > 0 and self.task == 'graph':
228
226
  data.y = torch.tensor([random.randint(0, self._num_classes - 1)])
229
227
 
@@ -3,7 +3,6 @@ import os.path as osp
3
3
  from typing import Callable, List, Optional
4
4
 
5
5
  import numpy as np
6
- import scipy.sparse as sp
7
6
  import torch
8
7
 
9
8
  from torch_geometric.data import Data, InMemoryDataset, download_google_url
@@ -73,6 +72,8 @@ class Flickr(InMemoryDataset):
73
72
  download_google_url(self.role_id, self.raw_dir, 'role.json')
74
73
 
75
74
  def process(self) -> None:
75
+ import scipy.sparse as sp
76
+
76
77
  f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
77
78
  adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
78
79
  adj = adj.tocoo()
@@ -75,7 +75,7 @@ class FB15k_237(InMemoryDataset):
75
75
  rel_dict: Dict[str, int] = {}
76
76
 
77
77
  for path in self.raw_paths:
78
- with open(path, 'r') as f:
78
+ with open(path) as f:
79
79
  lines = [x.split('\t') for x in f.read().split('\n')[:-1]]
80
80
 
81
81
  edge_index = torch.empty((2, len(lines)), dtype=torch.long)
@@ -9,6 +9,7 @@ from torch_geometric.data import (
9
9
  download_url,
10
10
  extract_zip,
11
11
  )
12
+ from torch_geometric.io import fs
12
13
 
13
14
 
14
15
  class GDELTLite(InMemoryDataset):
@@ -80,9 +81,9 @@ class GDELTLite(InMemoryDataset):
80
81
  def process(self) -> None:
81
82
  import pandas as pd
82
83
 
83
- x = torch.load(self.raw_paths[0])
84
+ x = fs.torch_load(self.raw_paths[0])
84
85
  df = pd.read_csv(self.raw_paths[1])
85
- edge_attr = torch.load(self.raw_paths[2])
86
+ edge_attr = fs.torch_load(self.raw_paths[2])
86
87
 
87
88
  row = torch.from_numpy(df['src'].values)
88
89
  col = torch.from_numpy(df['dst'].values)
@@ -13,6 +13,7 @@ from torch_geometric.data import (
13
13
  extract_tar,
14
14
  extract_zip,
15
15
  )
16
+ from torch_geometric.io import fs
16
17
  from torch_geometric.utils import one_hot, to_undirected
17
18
 
18
19
 
@@ -145,9 +146,9 @@ class GEDDataset(InMemoryDataset):
145
146
  path = self.processed_paths[0] if train else self.processed_paths[1]
146
147
  self.load(path)
147
148
  path = osp.join(self.processed_dir, f'{self.name}_ged.pt')
148
- self.ged = torch.load(path)
149
+ self.ged = fs.torch_load(path)
149
150
  path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt')
150
- self.norm_ged = torch.load(path)
151
+ self.norm_ged = fs.torch_load(path)
151
152
 
152
153
  @property
153
154
  def raw_file_names(self) -> List[str]:
@@ -0,0 +1,263 @@
1
+ import sys
2
+ from typing import Any, Callable, Dict, List, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+ from torch_geometric.data import (
9
+ Data,
10
+ InMemoryDataset,
11
+ download_google_url,
12
+ extract_zip,
13
+ )
14
+ from torch_geometric.io import fs
15
+
16
+
17
def safe_index(lst: List[Any], e: int) -> int:
    """Return the position of ``e`` in ``lst``.

    Values that are not present map to the last position of ``lst``, which
    by convention is the catch-all (e.g. ``'misc'``) entry. An empty list
    yields ``-1``.
    """
    try:
        return lst.index(e)
    except ValueError:
        # Unknown value: fall back to the final (catch-all) slot.
        return len(lst) - 1
19
+
20
+
21
class GitMolDataset(InMemoryDataset):
    r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model
    for Molecular Science with Graph, Image, and Text"
    <https://arxiv.org/pdf/2308.06911>`_ paper.

    When processed with RDKit installed, every example is a
    :class:`~torch_geometric.data.Data` object that contains:

    - the molecular graph (:obj:`x`, :obj:`edge_index`, :obj:`edge_attr`),
    - the :obj:`smiles` string,
    - an :obj:`image` depiction,
    - a text :obj:`caption`.

    Without RDKit, a pre-processed version of the selected split is loaded
    instead.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        split (int, optional): Datasets split, train/valid/test=0/1/2.
            (default: :obj:`0`)
    """

    raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        split: int = 0,
    ):
        from torchvision import transforms

        self.split = split

        # NOTE: `img_transform` is applied once inside `process()`, and the
        # result is saved to disk. The random augmentation of the training
        # split is therefore sampled once, not on every access.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if self.split == 0:
            self.img_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            self.img_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                normalize,
            ])

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)

        # Only the selected split's file is listed in `processed_file_names`.
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        # Ordered like `split`: 0 = train, 1 = valid, 2 = test.
        return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']

    @property
    def processed_file_names(self) -> str:
        return ['train.pt', 'valid.pt', 'test.pt'][self.split]

    def download(self) -> None:
        file_path = download_google_url(
            self.raw_url_id,
            self.raw_dir,
            'gitmol.zip',
        )
        extract_zip(file_path, self.raw_dir)

    def process(self) -> None:
        import pandas as pd
        from PIL import Image

        try:
            from rdkit import Chem, RDLogger
            RDLogger.DisableLog('rdApp.*')  # type: ignore
            WITH_RDKIT = True

        except ImportError:
            WITH_RDKIT = False

        if not WITH_RDKIT:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            # Load the pre-processed file of the *selected* split, matching
            # how the RDKit branch below selects its input. Always using
            # index 0 would fill the valid/test splits with training data.
            data_list = fs.torch_load(self.raw_paths[self.split])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        # Vocabularies for categorical atom/bond features. Lists that end in
        # 'misc' are used with `safe_index`, which maps unknown values to
        # that last slot.
        allowable_features: Dict[str, List[Any]] = {
            'possible_atomic_num_list':
            list(range(1, 119)) + ['misc'],
            'possible_formal_charge_list':
            [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
            'possible_chirality_list': [
                Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                Chem.rdchem.ChiralType.CHI_OTHER
            ],
            'possible_hybridization_list': [
                Chem.rdchem.HybridizationType.SP,
                Chem.rdchem.HybridizationType.SP2,
                Chem.rdchem.HybridizationType.SP3,
                Chem.rdchem.HybridizationType.SP3D,
                Chem.rdchem.HybridizationType.SP3D2,
                Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'
            ],
            'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
            'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
            'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
            'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
            'possible_is_aromatic_list': [False, True],
            'possible_is_in_ring_list': [False, True],
            'possible_bond_type_list': [
                Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
                Chem.rdchem.BondType.ZERO
            ],
            'possible_bond_dirs': [  # only for double bond stereo information
                Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
                Chem.rdchem.BondDir.ENDDOWNRIGHT
            ],
            'possible_bond_stereo_list': [
                Chem.rdchem.BondStereo.STEREONONE,
                Chem.rdchem.BondStereo.STEREOZ,
                Chem.rdchem.BondStereo.STEREOE,
                Chem.rdchem.BondStereo.STEREOCIS,
                Chem.rdchem.BondStereo.STEREOTRANS,
                Chem.rdchem.BondStereo.STEREOANY,
            ],
            'possible_is_conjugated_list': [False, True]
        }

        # Named `df` (not `data`) so that it is not shadowed by the per-row
        # `Data` objects built inside the loop below.
        df = pd.read_pickle(
            f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')

        data_list = []
        for _, row in tqdm(df.iterrows(), total=df.shape[0]):
            smiles = row['isosmiles']
            mol = Chem.MolFromSmiles(smiles.strip('\n'))
            if mol is None:  # Skip SMILES strings RDKit cannot parse.
                continue

            # Text:
            summary = row['summary']

            # Image: depiction stored as `imgs/CID_<cid>.png`.
            cid = row['cid']
            img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'
            with Image.open(img_file) as raw_img:
                img = raw_img.convert('RGB')
            img = self.img_transform(img).unsqueeze(0)

            # Graph: 9 categorical features per atom.
            atom_features_list = []
            for atom in mol.GetAtoms():
                atom_feature = [
                    safe_index(
                        allowable_features['possible_atomic_num_list'],
                        atom.GetAtomicNum()),
                    allowable_features['possible_chirality_list'].index(
                        atom.GetChiralTag()),
                    safe_index(allowable_features['possible_degree_list'],
                               atom.GetTotalDegree()),
                    safe_index(
                        allowable_features['possible_formal_charge_list'],
                        atom.GetFormalCharge()),
                    safe_index(allowable_features['possible_numH_list'],
                               atom.GetTotalNumHs()),
                    safe_index(
                        allowable_features[
                            'possible_number_radical_e_list'],
                        atom.GetNumRadicalElectrons()),
                    safe_index(
                        allowable_features['possible_hybridization_list'],
                        atom.GetHybridization()),
                    allowable_features['possible_is_aromatic_list'].index(
                        atom.GetIsAromatic()),
                    allowable_features['possible_is_in_ring_list'].index(
                        atom.IsInRing()),
                ]
                atom_features_list.append(atom_feature)
            # `reshape` keeps a 2-D `[num_atoms, 9]` shape even when the
            # list is empty.
            x = torch.tensor(
                np.array(atom_features_list, dtype=np.int64).reshape(-1, 9),
                dtype=torch.long,
            )

            # Each bond is added in both directions, with 3 features per edge.
            edges_list = []
            edge_features_list = []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_feature = [
                    safe_index(
                        allowable_features['possible_bond_type_list'],
                        bond.GetBondType()),
                    allowable_features['possible_bond_stereo_list'].index(
                        bond.GetStereo()),
                    allowable_features['possible_is_conjugated_list'].
                    index(bond.GetIsConjugated()),
                ]
                edges_list.append((i, j))
                edge_features_list.append(edge_feature)
                edges_list.append((j, i))
                edge_features_list.append(edge_feature)

            # `reshape` is needed for bond-less molecules (e.g. a single
            # atom). There, `np.array([]).T` would give 1-D tensors instead
            # of the expected `[2, 0]` / `[0, 3]` shapes.
            edge_index = torch.tensor(
                np.array(edges_list, dtype=np.int64).reshape(-1, 2).T,
                dtype=torch.long,
            )
            edge_attr = torch.tensor(
                np.array(edge_features_list, dtype=np.int64).reshape(-1, 3),
                dtype=torch.long,
            )

            data = Data(
                x=x,
                edge_index=edge_index,
                smiles=smiles,
                edge_attr=edge_attr,
                image=img,
                caption=summary,
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        self.save(data_list, self.processed_paths[0])
@@ -12,6 +12,7 @@ from torch_geometric.data import (
12
12
  download_url,
13
13
  extract_zip,
14
14
  )
15
+ from torch_geometric.io import fs
15
16
  from torch_geometric.utils import remove_self_loops
16
17
 
17
18
 
@@ -61,31 +62,31 @@ class GNNBenchmarkDataset(InMemoryDataset):
61
62
  - #features
62
63
  - #classes
63
64
  * - PATTERN
64
- - 10,000
65
+ - 14,000
65
66
  - ~118.9
66
67
  - ~6,098.9
67
68
  - 3
68
69
  - 2
69
70
  * - CLUSTER
70
- - 10,000
71
+ - 12,000
71
72
  - ~117.2
72
73
  - ~4,303.9
73
74
  - 7
74
75
  - 6
75
76
  * - MNIST
76
- - 55,000
77
+ - 70,000
77
78
  - ~70.6
78
79
  - ~564.5
79
80
  - 3
80
81
  - 10
81
82
  * - CIFAR10
82
- - 45,000
83
+ - 60,000
83
84
  - ~117.6
84
85
  - ~941.2
85
86
  - 5
86
87
  - 10
87
88
  * - TSP
88
- - 10,000
89
+ - 12,000
89
90
  - ~275.4
90
91
  - ~6,885.0
91
92
  - 2
@@ -126,9 +127,9 @@ class GNNBenchmarkDataset(InMemoryDataset):
126
127
  if self.name == 'CSL' and split != 'train':
127
128
  split = 'train'
128
129
  logging.warning(
129
- ("Dataset 'CSL' does not provide a standardized splitting. "
130
- "Instead, it is recommended to perform 5-fold cross "
131
- "validation with stratifed sampling"))
130
+ "Dataset 'CSL' does not provide a standardized splitting. "
131
+ "Instead, it is recommended to perform 5-fold cross "
132
+ "validation with stratifed sampling")
132
133
 
133
134
  super().__init__(root, transform, pre_transform, pre_filter,
134
135
  force_reload=force_reload)
@@ -181,7 +182,7 @@ class GNNBenchmarkDataset(InMemoryDataset):
181
182
  data_list = self.process_CSL()
182
183
  self.save(data_list, self.processed_paths[0])
183
184
  else:
184
- inputs = torch.load(self.raw_paths[0])
185
+ inputs = fs.torch_load(self.raw_paths[0])
185
186
  for i in range(len(inputs)):
186
187
  data_list = [Data(**data_dict) for data_dict in inputs[i]]
187
188
 
@@ -197,7 +198,7 @@ class GNNBenchmarkDataset(InMemoryDataset):
197
198
  with open(self.raw_paths[0], 'rb') as f:
198
199
  adjs = pickle.load(f)
199
200
 
200
- ys = torch.load(self.raw_paths[1]).tolist()
201
+ ys = fs.torch_load(self.raw_paths[1]).tolist()
201
202
 
202
203
  data_list = []
203
204
  for adj, y in zip(adjs, ys):