pyg-nightly 2.6.0.dev20240511__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Files changed (205)
  1. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +30 -31
  2. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +205 -181
  3. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +26 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +16 -14
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/data.py +13 -8
  12. torch_geometric/data/database.py +15 -7
  13. torch_geometric/data/dataset.py +14 -6
  14. torch_geometric/data/feature_store.py +13 -22
  15. torch_geometric/data/graph_store.py +0 -4
  16. torch_geometric/data/hetero_data.py +4 -4
  17. torch_geometric/data/in_memory_dataset.py +2 -4
  18. torch_geometric/data/large_graph_indexer.py +677 -0
  19. torch_geometric/data/lightning/datamodule.py +4 -4
  20. torch_geometric/data/storage.py +15 -5
  21. torch_geometric/data/summary.py +14 -4
  22. torch_geometric/data/temporal.py +1 -2
  23. torch_geometric/datasets/__init__.py +11 -1
  24. torch_geometric/datasets/actor.py +9 -11
  25. torch_geometric/datasets/airfrans.py +15 -18
  26. torch_geometric/datasets/airports.py +10 -12
  27. torch_geometric/datasets/amazon.py +8 -11
  28. torch_geometric/datasets/amazon_book.py +9 -10
  29. torch_geometric/datasets/amazon_products.py +9 -10
  30. torch_geometric/datasets/aminer.py +8 -9
  31. torch_geometric/datasets/aqsol.py +10 -13
  32. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  33. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  34. torch_geometric/datasets/ba_shapes.py +5 -6
  35. torch_geometric/datasets/bitcoin_otc.py +1 -1
  36. torch_geometric/datasets/brca_tgca.py +1 -1
  37. torch_geometric/datasets/dblp.py +2 -1
  38. torch_geometric/datasets/dbp15k.py +2 -2
  39. torch_geometric/datasets/fake.py +1 -3
  40. torch_geometric/datasets/flickr.py +2 -1
  41. torch_geometric/datasets/freebase.py +1 -1
  42. torch_geometric/datasets/gdelt_lite.py +3 -2
  43. torch_geometric/datasets/ged_dataset.py +3 -2
  44. torch_geometric/datasets/git_mol_dataset.py +263 -0
  45. torch_geometric/datasets/gnn_benchmark_dataset.py +6 -5
  46. torch_geometric/datasets/hgb_dataset.py +8 -8
  47. torch_geometric/datasets/imdb.py +2 -1
  48. torch_geometric/datasets/last_fm.py +2 -1
  49. torch_geometric/datasets/linkx_dataset.py +4 -3
  50. torch_geometric/datasets/lrgb.py +3 -5
  51. torch_geometric/datasets/malnet_tiny.py +4 -3
  52. torch_geometric/datasets/mnist_superpixels.py +2 -3
  53. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  54. torch_geometric/datasets/molecule_net.py +7 -1
  55. torch_geometric/datasets/motif_generator/base.py +0 -1
  56. torch_geometric/datasets/neurograph.py +1 -3
  57. torch_geometric/datasets/ogb_mag.py +1 -1
  58. torch_geometric/datasets/opf.py +239 -0
  59. torch_geometric/datasets/ose_gvcs.py +1 -1
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  62. torch_geometric/datasets/pcqm4m.py +2 -1
  63. torch_geometric/datasets/ppi.py +1 -1
  64. torch_geometric/datasets/qm9.py +4 -3
  65. torch_geometric/datasets/reddit.py +2 -1
  66. torch_geometric/datasets/reddit2.py +2 -1
  67. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  68. torch_geometric/datasets/s3dis.py +2 -2
  69. torch_geometric/datasets/shapenet.py +3 -3
  70. torch_geometric/datasets/shrec2016.py +2 -2
  71. torch_geometric/datasets/tag_dataset.py +350 -0
  72. torch_geometric/datasets/upfd.py +2 -1
  73. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  74. torch_geometric/datasets/webkb.py +2 -2
  75. torch_geometric/datasets/wikics.py +1 -1
  76. torch_geometric/datasets/wikidata.py +3 -2
  77. torch_geometric/datasets/wikipedia_network.py +2 -2
  78. torch_geometric/datasets/word_net.py +2 -2
  79. torch_geometric/datasets/yelp.py +2 -1
  80. torch_geometric/datasets/zinc.py +1 -1
  81. torch_geometric/device.py +42 -0
  82. torch_geometric/distributed/local_feature_store.py +3 -2
  83. torch_geometric/distributed/local_graph_store.py +2 -1
  84. torch_geometric/distributed/partition.py +9 -8
  85. torch_geometric/edge_index.py +17 -8
  86. torch_geometric/explain/algorithm/base.py +0 -1
  87. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  88. torch_geometric/explain/explanation.py +2 -2
  89. torch_geometric/graphgym/checkpoint.py +2 -1
  90. torch_geometric/graphgym/logger.py +4 -4
  91. torch_geometric/graphgym/loss.py +1 -1
  92. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  93. torch_geometric/index.py +20 -7
  94. torch_geometric/inspector.py +6 -2
  95. torch_geometric/io/fs.py +28 -2
  96. torch_geometric/io/npz.py +2 -1
  97. torch_geometric/io/off.py +2 -2
  98. torch_geometric/io/sdf.py +2 -2
  99. torch_geometric/io/tu.py +2 -3
  100. torch_geometric/loader/__init__.py +4 -0
  101. torch_geometric/loader/cluster.py +9 -3
  102. torch_geometric/loader/graph_saint.py +2 -1
  103. torch_geometric/loader/ibmb_loader.py +12 -4
  104. torch_geometric/loader/mixin.py +1 -1
  105. torch_geometric/loader/neighbor_loader.py +1 -1
  106. torch_geometric/loader/neighbor_sampler.py +2 -2
  107. torch_geometric/loader/prefetch.py +1 -1
  108. torch_geometric/loader/rag_loader.py +107 -0
  109. torch_geometric/loader/zip_loader.py +10 -0
  110. torch_geometric/metrics/__init__.py +11 -2
  111. torch_geometric/metrics/link_pred.py +159 -34
  112. torch_geometric/nn/aggr/__init__.py +2 -0
  113. torch_geometric/nn/aggr/attention.py +0 -2
  114. torch_geometric/nn/aggr/base.py +2 -4
  115. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  116. torch_geometric/nn/aggr/set_transformer.py +1 -1
  117. torch_geometric/nn/attention/__init__.py +5 -1
  118. torch_geometric/nn/attention/qformer.py +71 -0
  119. torch_geometric/nn/conv/collect.jinja +6 -3
  120. torch_geometric/nn/conv/cugraph/base.py +0 -1
  121. torch_geometric/nn/conv/edge_conv.py +3 -2
  122. torch_geometric/nn/conv/gat_conv.py +35 -7
  123. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  124. torch_geometric/nn/conv/general_conv.py +1 -1
  125. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  126. torch_geometric/nn/conv/hetero_conv.py +3 -3
  127. torch_geometric/nn/conv/hgt_conv.py +1 -1
  128. torch_geometric/nn/conv/message_passing.py +100 -82
  129. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  130. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  131. torch_geometric/nn/conv/spline_conv.py +4 -4
  132. torch_geometric/nn/conv/x_conv.py +3 -2
  133. torch_geometric/nn/dense/linear.py +5 -4
  134. torch_geometric/nn/fx.py +3 -3
  135. torch_geometric/nn/model_hub.py +3 -1
  136. torch_geometric/nn/models/__init__.py +10 -2
  137. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  138. torch_geometric/nn/models/dimenet_utils.py +5 -7
  139. torch_geometric/nn/models/g_retriever.py +230 -0
  140. torch_geometric/nn/models/git_mol.py +336 -0
  141. torch_geometric/nn/models/glem.py +385 -0
  142. torch_geometric/nn/models/gnnff.py +0 -1
  143. torch_geometric/nn/models/graph_unet.py +12 -3
  144. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  145. torch_geometric/nn/models/lightgcn.py +1 -1
  146. torch_geometric/nn/models/metapath2vec.py +3 -4
  147. torch_geometric/nn/models/molecule_gpt.py +222 -0
  148. torch_geometric/nn/models/node2vec.py +1 -2
  149. torch_geometric/nn/models/schnet.py +2 -1
  150. torch_geometric/nn/models/signed_gcn.py +3 -3
  151. torch_geometric/nn/module_dict.py +2 -2
  152. torch_geometric/nn/nlp/__init__.py +9 -0
  153. torch_geometric/nn/nlp/llm.py +322 -0
  154. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  155. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  156. torch_geometric/nn/norm/batch_norm.py +1 -1
  157. torch_geometric/nn/parameter_dict.py +2 -2
  158. torch_geometric/nn/pool/__init__.py +7 -5
  159. torch_geometric/nn/pool/cluster_pool.py +145 -0
  160. torch_geometric/nn/pool/connect/base.py +0 -1
  161. torch_geometric/nn/pool/edge_pool.py +1 -1
  162. torch_geometric/nn/pool/graclus.py +4 -2
  163. torch_geometric/nn/pool/select/base.py +0 -1
  164. torch_geometric/nn/pool/voxel_grid.py +3 -2
  165. torch_geometric/nn/resolver.py +1 -1
  166. torch_geometric/nn/sequential.jinja +10 -23
  167. torch_geometric/nn/sequential.py +203 -77
  168. torch_geometric/nn/summary.py +1 -1
  169. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  170. torch_geometric/profile/__init__.py +2 -0
  171. torch_geometric/profile/nvtx.py +66 -0
  172. torch_geometric/profile/profiler.py +24 -15
  173. torch_geometric/resolver.py +1 -1
  174. torch_geometric/sampler/base.py +34 -13
  175. torch_geometric/sampler/neighbor_sampler.py +11 -10
  176. torch_geometric/testing/decorators.py +17 -22
  177. torch_geometric/transforms/__init__.py +2 -0
  178. torch_geometric/transforms/add_metapaths.py +4 -4
  179. torch_geometric/transforms/add_positional_encoding.py +1 -1
  180. torch_geometric/transforms/delaunay.py +65 -14
  181. torch_geometric/transforms/face_to_edge.py +32 -3
  182. torch_geometric/transforms/gdc.py +7 -6
  183. torch_geometric/transforms/laplacian_lambda_max.py +2 -2
  184. torch_geometric/transforms/mask.py +5 -1
  185. torch_geometric/transforms/node_property_split.py +1 -2
  186. torch_geometric/transforms/pad.py +7 -6
  187. torch_geometric/transforms/random_link_split.py +1 -1
  188. torch_geometric/transforms/remove_self_loops.py +36 -0
  189. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  190. torch_geometric/transforms/virtual_node.py +2 -1
  191. torch_geometric/typing.py +31 -5
  192. torch_geometric/utils/__init__.py +5 -1
  193. torch_geometric/utils/_negative_sampling.py +1 -1
  194. torch_geometric/utils/_normalize_edge_index.py +46 -0
  195. torch_geometric/utils/_scatter.py +37 -12
  196. torch_geometric/utils/_subgraph.py +4 -0
  197. torch_geometric/utils/_tree_decomposition.py +2 -2
  198. torch_geometric/utils/augmentation.py +1 -1
  199. torch_geometric/utils/convert.py +5 -5
  200. torch_geometric/utils/geodesic.py +24 -22
  201. torch_geometric/utils/hetero.py +1 -1
  202. torch_geometric/utils/map.py +1 -1
  203. torch_geometric/utils/smiles.py +66 -28
  204. torch_geometric/utils/sparse.py +25 -10
  205. torch_geometric/visualization/graph.py +3 -4
@@ -0,0 +1,239 @@
+import json
+import os
+import os.path as osp
+from typing import Callable, Dict, List, Literal, Optional
+
+import torch
+import tqdm
+from torch import Tensor
+
+from torch_geometric.data import (
+    HeteroData,
+    InMemoryDataset,
+    download_url,
+    extract_tar,
+)
+
+
+class OPFDataset(InMemoryDataset):
+    r"""The heterogeneous OPF data from the `"Large-scale Datasets for AC
+    Optimal Power Flow with Topological Perturbations"
+    <https://arxiv.org/abs/2406.07234>`_ paper.
+
+    :class:`OPFDataset` is a large-scale dataset of solved optimal power flow
+    problems, derived from the
+    `pglib-opf <https://github.com/power-grid-lib/pglib-opf>`_ dataset.
+
+    The physical topology of the grid is represented by the :obj:`"bus"` node
+    type, and the connecting AC lines and transformers. Additionally,
+    :obj:`"generator"`, :obj:`"load"`, and :obj:`"shunt"` nodes are connected
+    to :obj:`"bus"` nodes using a dedicated edge type each, *e.g.*,
+    :obj:`"generator_link"`.
+
+    Edge direction corresponds to the properties of the line, *e.g.*,
+    :obj:`b_fr` is the line charging susceptance at the :obj:`from`
+    (source/sender) bus.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        split (str, optional): If :obj:`"train"`, loads the training dataset.
+            If :obj:`"val"`, loads the validation dataset.
+            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
+        case_name (str, optional): The name of the original pglib-opf case.
+            (default: :obj:`"pglib_opf_case14_ieee"`)
+        num_groups (int, optional): The dataset is divided into 20 groups with
+            each group containing 15,000 samples.
+            For large networks, this amount of data can be overwhelming.
+            The :obj:`num_groups` parameter controls the amount of data being
+            downloaded. Allowed values are :obj:`[1, 20]`.
+            (default: :obj:`20`)
+        topological_perturbations (bool, optional): Whether to use the dataset
+            with added topological perturbations. (default: :obj:`False`)
+        transform (callable, optional): A function/transform that takes in
+            a :obj:`torch_geometric.data.HeteroData` object and returns a
+            transformed version. The data object will be transformed before
+            every access. (default: :obj:`None`)
+        pre_transform (callable, optional): A function/transform that takes
+            in a :obj:`torch_geometric.data.HeteroData` object and returns
+            a transformed version. The data object will be transformed before
+            being saved to disk. (default: :obj:`None`)
+        pre_filter (callable, optional): A function that takes in a
+            :obj:`torch_geometric.data.HeteroData` object and returns a boolean
+            value, indicating whether the data object should be included in the
+            final dataset. (default: :obj:`None`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+    """
+    url = 'https://storage.googleapis.com/gridopt-dataset'
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal['train', 'val', 'test'] = 'train',
+        case_name: Literal[
+            'pglib_opf_case14_ieee',
+            'pglib_opf_case30_ieee',
+            'pglib_opf_case57_ieee',
+            'pglib_opf_case118_ieee',
+            'pglib_opf_case500_goc',
+            'pglib_opf_case2000_goc',
+            'pglib_opf_case6470_rte',
+            'pglib_opf_case4661_sdet',
+            'pglib_opf_case10000_goc',
+            'pglib_opf_case13659_pegase',
+        ] = 'pglib_opf_case14_ieee',
+        num_groups: int = 20,
+        topological_perturbations: bool = False,
+        transform: Optional[Callable] = None,
+        pre_transform: Optional[Callable] = None,
+        pre_filter: Optional[Callable] = None,
+        force_reload: bool = False,
+    ) -> None:
+
+        self.split = split
+        self.case_name = case_name
+        self.num_groups = num_groups
+        self.topological_perturbations = topological_perturbations
+
+        self._release = 'dataset_release_1'
+        if topological_perturbations:
+            self._release += '_nminusone'
+
+        super().__init__(root, transform, pre_transform, pre_filter,
+                         force_reload=force_reload)
+
+        idx = self.processed_file_names.index(f'{split}.pt')
+        self.load(self.processed_paths[idx])
+
+    @property
+    def raw_dir(self) -> str:
+        return osp.join(self.root, self._release, self.case_name, 'raw')
+
+    @property
+    def processed_dir(self) -> str:
+        return osp.join(self.root, self._release, self.case_name,
+                        f'processed_{self.num_groups}')
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        return [f'{self.case_name}_{i}.tar.gz' for i in range(self.num_groups)]
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        return ['train.pt', 'val.pt', 'test.pt']
+
+    def download(self) -> None:
+        for name in self.raw_file_names:
+            url = f'{self.url}/{self._release}/{name}'
+            path = download_url(url, self.raw_dir)
+            extract_tar(path, self.raw_dir)
+
+    def process(self) -> None:
+        train_data_list = []
+        val_data_list = []
+        test_data_list = []
+
+        for group in tqdm.tqdm(range(self.num_groups)):
+            tmp_dir = osp.join(
+                self.raw_dir,
+                'gridopt-dataset-tmp',
+                self._release,
+                self.case_name,
+                f'group_{group}',
+            )
+
+            for name in os.listdir(tmp_dir):
+                with open(osp.join(tmp_dir, name)) as f:
+                    obj = json.load(f)
+
+                grid = obj['grid']
+                solution = obj['solution']
+                metadata = obj['metadata']
+
+                # Graph-level properties:
+                data = HeteroData()
+                data.x = torch.tensor(grid['context']).view(-1)
+
+                data.objective = torch.tensor(metadata['objective'])
+
+                # Nodes (only some have a target):
+                data['bus'].x = torch.tensor(grid['nodes']['bus'])
+                data['bus'].y = torch.tensor(solution['nodes']['bus'])
+
+                data['generator'].x = torch.tensor(grid['nodes']['generator'])
+                data['generator'].y = torch.tensor(
+                    solution['nodes']['generator'])
+
+                data['load'].x = torch.tensor(grid['nodes']['load'])
+
+                data['shunt'].x = torch.tensor(grid['nodes']['shunt'])
+
+                # Edges (only ac lines and transformers have features):
+                data['bus', 'ac_line', 'bus'].edge_index = (  #
+                    extract_edge_index(obj, 'ac_line'))
+                data['bus', 'ac_line', 'bus'].edge_attr = torch.tensor(
+                    grid['edges']['ac_line']['features'])
+                data['bus', 'ac_line', 'bus'].edge_label = torch.tensor(
+                    solution['edges']['ac_line']['features'])
+
+                data['bus', 'transformer', 'bus'].edge_index = (  #
+                    extract_edge_index(obj, 'transformer'))
+                data['bus', 'transformer', 'bus'].edge_attr = torch.tensor(
+                    grid['edges']['transformer']['features'])
+                data['bus', 'transformer', 'bus'].edge_label = torch.tensor(
+                    solution['edges']['transformer']['features'])
+
+                data['generator', 'generator_link', 'bus'].edge_index = (  #
+                    extract_edge_index(obj, 'generator_link'))
+                data['bus', 'generator_link', 'generator'].edge_index = (  #
+                    extract_edge_index_rev(obj, 'generator_link'))
+
+                data['load', 'load_link', 'bus'].edge_index = (  #
+                    extract_edge_index(obj, 'load_link'))
+                data['bus', 'load_link', 'load'].edge_index = (  #
+                    extract_edge_index_rev(obj, 'load_link'))
+
+                data['shunt', 'shunt_link', 'bus'].edge_index = (  #
+                    extract_edge_index(obj, 'shunt_link'))
+                data['bus', 'shunt_link', 'shunt'].edge_index = (  #
+                    extract_edge_index_rev(obj, 'shunt_link'))
+
+                if self.pre_filter is not None and not self.pre_filter(data):
+                    continue
+
+                if self.pre_transform is not None:
+                    data = self.pre_transform(data)
+
+                i = int(name.split('.')[0].split('_')[1])
+                train_limit = int(15_000 * self.num_groups * 0.9)
+                val_limit = train_limit + int(15_000 * self.num_groups * 0.05)
+                if i < train_limit:
+                    train_data_list.append(data)
+                elif i < val_limit:
+                    val_data_list.append(data)
+                else:
+                    test_data_list.append(data)
+
+        self.save(train_data_list, self.processed_paths[0])
+        self.save(val_data_list, self.processed_paths[1])
+        self.save(test_data_list, self.processed_paths[2])
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({len(self)}, '
+                f'split={self.split}, '
+                f'case_name={self.case_name}, '
+                f'topological_perturbations={self.topological_perturbations})')
+
+
+def extract_edge_index(obj: Dict, edge_name: str) -> Tensor:
+    return torch.tensor([
+        obj['grid']['edges'][edge_name]['senders'],
+        obj['grid']['edges'][edge_name]['receivers'],
+    ])
+
+
+def extract_edge_index_rev(obj: Dict, edge_name: str) -> Tensor:
+    return torch.tensor([
+        obj['grid']['edges'][edge_name]['receivers'],
+        obj['grid']['edges'][edge_name]['senders'],
+    ])
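
A short usage sketch may help readers of the hunk above. It assumes that OPFDataset is re-exported from torch_geometric.datasets (the updated torch_geometric/datasets/__init__.py in the file list suggests this) and that the gridopt-dataset storage bucket is reachable; the root path and num_groups value are purely illustrative. Note that process() assigns samples to train/val/test in a 90%/5%/5% ratio.

from torch_geometric.datasets import OPFDataset

# Download and process a single group (15,000 samples) of the 14-bus IEEE
# case.  The 90%/5%/5% train/val/test assignment happens in `process()` above.
train_dataset = OPFDataset(
    root='data/OPF',                      # illustrative root directory
    split='train',
    case_name='pglib_opf_case14_ieee',
    num_groups=1,
)

data = train_dataset[0]                   # a single HeteroData sample
print(data['bus'].x.shape)                # bus features
print(data['bus', 'ac_line', 'bus'].edge_index.shape)
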
@@ -97,7 +97,7 @@ class OSE_GVCS(InMemoryDataset):
         edges = defaultdict(list)

         for path in self.raw_paths:
-            with open(path, 'r') as f:
+            with open(path) as f:
                 product = json.load(f)
             categories.append(self.categories.index(product['category']))
             for interaction in product['ecology']:
@@ -66,7 +66,7 @@ class PascalPF(InMemoryDataset):
         super().__init__(root, transform, pre_transform, pre_filter,
                          force_reload=force_reload)
         self.load(self.processed_paths[0])
-        self.pairs = torch.load(self.processed_paths[1])
+        self.pairs = fs.torch_load(self.processed_paths[1])

     @property
     def raw_file_names(self) -> List[str]:
@@ -121,7 +121,7 @@ class PCPNetDataset(InMemoryDataset):

     def process(self) -> None:
         path_file = self.raw_paths
-        with open(path_file[0], "r") as f:
+        with open(path_file[0]) as f:
             filenames = f.read().split('\n')[:-1]
         data_list = []
         for filename in filenames:
@@ -7,6 +7,7 @@ from tqdm import tqdm

 from torch_geometric.data import Data, OnDiskDataset, download_url, extract_zip
 from torch_geometric.data.data import BaseData
+from torch_geometric.io import fs
 from torch_geometric.utils import from_smiles as _from_smiles

@@ -72,7 +73,7 @@ class PCQM4Mv2(OnDiskDataset):
         self.from_smiles = from_smiles or _from_smiles
         super().__init__(root, transform, backend=backend, schema=schema)

-        split_idx = torch.load(self.raw_paths[1])
+        split_idx = fs.torch_load(self.raw_paths[1])
         self._indices = split_idx[self.split_mapping[split]].tolist()

     @property
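
This hunk, the PascalPF hunk above, and the QM9 and SHREC2016 hunks below all replace bare torch.load calls with fs.torch_load from torch_geometric.io. A minimal sketch of the calling pattern, assuming only what the diff shows; the path is illustrative, and any extra keyword arguments fs.torch_load may forward to torch.load are not shown here.

from torch_geometric.io import fs

# `fs` is torch_geometric's fsspec-backed I/O helper, so the same call is
# expected to work for local paths as well as remote ones (e.g. 's3://...').
# The path below is illustrative only.
split_idx = fs.torch_load('data/PCQM4Mv2/raw/split_dict.pt')
print(type(split_idx))
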
@@ -106,7 +106,7 @@ class PPI(InMemoryDataset):

         for s, split in enumerate(['train', 'valid', 'test']):
             path = osp.join(self.raw_dir, f'{split}_graph.json')
-            with open(path, 'r') as f:
+            with open(path) as f:
                 G = nx.DiGraph(json_graph.node_link_graph(json.load(f)))

             x = np.load(osp.join(self.raw_dir, f'{split}_feats.npy'))
@@ -13,6 +13,7 @@ from torch_geometric.data import (
     download_url,
     extract_zip,
 )
+from torch_geometric.io import fs
 from torch_geometric.utils import one_hot, scatter

 HAR2EV = 27.211386246
@@ -212,7 +213,7 @@ class QM9(InMemoryDataset):
                    "install 'rdkit' to alternatively process the raw data."),
                   file=sys.stderr)

-            data_list = torch.load(self.raw_paths[0])
+            data_list = fs.torch_load(self.raw_paths[0])
             data_list = [Data(**data_dict) for data_dict in data_list]

             if self.pre_filter is not None:
@@ -227,14 +228,14 @@ class QM9(InMemoryDataset):
         types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
         bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

-        with open(self.raw_paths[1], 'r') as f:
+        with open(self.raw_paths[1]) as f:
             target = [[float(x) for x in line.split(',')[1:20]]
                       for line in f.read().split('\n')[1:-1]]
             y = torch.tensor(target, dtype=torch.float)
             y = torch.cat([y[:, 3:], y[:, :3]], dim=-1)
             y = y * conversion.view(1, -1)

-        with open(self.raw_paths[2], 'r') as f:
+        with open(self.raw_paths[2]) as f:
             skip = [int(x.split()[0]) - 1 for x in f.read().split('\n')[9:-2]]

         suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False,
@@ -3,7 +3,6 @@ import os.path as osp
 from typing import Callable, List, Optional

 import numpy as np
-import scipy.sparse as sp
 import torch

 from torch_geometric.data import (
@@ -76,6 +75,8 @@ class Reddit(InMemoryDataset):
         os.unlink(path)

     def process(self) -> None:
+        import scipy.sparse as sp
+
         data = np.load(osp.join(self.raw_dir, 'reddit_data.npz'))
         x = torch.from_numpy(data['feature']).to(torch.float)
         y = torch.from_numpy(data['label']).to(torch.long)
@@ -3,7 +3,6 @@ import os.path as osp
 from typing import Callable, List, Optional

 import numpy as np
-import scipy.sparse as sp
 import torch

 from torch_geometric.data import Data, InMemoryDataset, download_google_url
@@ -81,6 +80,8 @@ class Reddit2(InMemoryDataset):
         download_google_url(self.role_id, self.raw_dir, 'role.json')

     def process(self) -> None:
+        import scipy.sparse as sp
+
         f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
         adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
         adj = adj.tocoo()
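
Both Reddit hunks move import scipy.sparse as sp from module scope into process(), so importing the dataset module no longer requires SciPy unless the raw files are actually processed. A sketch of the same deferred-import pattern under that assumption; the class and file names here are illustrative and not part of the diff.

import numpy as np
import torch

from torch_geometric.data import Data, InMemoryDataset


class NpzGraphDataset(InMemoryDataset):  # illustrative, not in this release
    @property
    def raw_file_names(self) -> str:
        return 'adj_full.npz'  # illustrative raw file name

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def process(self) -> None:
        # Deferred import: SciPy is only needed while converting raw files,
        # so the module itself imports without it.
        import scipy.sparse as sp

        f = np.load(self.raw_paths[0])
        adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']),
                            f['shape']).tocoo()
        edge_index = torch.stack([
            torch.from_numpy(adj.row).to(torch.long),
            torch.from_numpy(adj.col).to(torch.long),
        ], dim=0)
        self.save([Data(edge_index=edge_index)], self.processed_paths[0])
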
@@ -89,17 +89,17 @@ class RelLinkPredDataset(InMemoryDataset):
             download_url(f'{self.urls[self.name]}/{file_name}', self.raw_dir)

     def process(self) -> None:
-        with open(osp.join(self.raw_dir, 'entities.dict'), 'r') as f:
+        with open(osp.join(self.raw_dir, 'entities.dict')) as f:
             lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
             entities_dict = {key: int(value) for value, key in lines}

-        with open(osp.join(self.raw_dir, 'relations.dict'), 'r') as f:
+        with open(osp.join(self.raw_dir, 'relations.dict')) as f:
             lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
             relations_dict = {key: int(value) for value, key in lines}

         kwargs = {}
         for split in ['train', 'valid', 'test']:
-            with open(osp.join(self.raw_dir, f'{split}.txt'), 'r') as f:
+            with open(osp.join(self.raw_dir, f'{split}.txt')) as f:
                 lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
                 src = [entities_dict[row[0]] for row in lines]
                 rel = [relations_dict[row[1]] for row in lines]
@@ -86,10 +86,10 @@ class S3DIS(InMemoryDataset):
     def process(self) -> None:
         import h5py

-        with open(self.raw_paths[0], 'r') as f:
+        with open(self.raw_paths[0]) as f:
             filenames = [x.split('/')[-1] for x in f.read().split('\n')[:-1]]

-        with open(self.raw_paths[1], 'r') as f:
+        with open(self.raw_paths[1]) as f:
             rooms = f.read().split('\n')[:-1]

         xs: List[Tensor] = []
@@ -148,8 +148,8 @@ class ShapeNet(InMemoryDataset):
         elif split == 'trainval':
             path = self.processed_paths[3]
         else:
-            raise ValueError((f'Split {split} found, but expected either '
-                              'train, val, trainval or test'))
+            raise ValueError(f'Split {split} found, but expected either '
+                             'train, val, trainval or test')

         self.load(path)

@@ -213,7 +213,7 @@ class ShapeNet(InMemoryDataset):
         for i, split in enumerate(['train', 'val', 'test']):
             path = osp.join(self.raw_dir, 'train_test_split',
                             f'shuffled_{split}_file_list.json')
-            with open(path, 'r') as f:
+            with open(path) as f:
                 filenames = [
                     osp.sep.join(name.split('/')[1:]) + '.txt'
                     for name in json.load(f)
@@ -6,7 +6,7 @@ from typing import Callable, List, Optional
 import torch

 from torch_geometric.data import InMemoryDataset, download_url, extract_zip
-from torch_geometric.io import read_off, read_txt_array
+from torch_geometric.io import fs, read_off, read_txt_array


 class SHREC2016(InMemoryDataset):
@@ -79,7 +79,7 @@ class SHREC2016(InMemoryDataset):
         self.cat = category.lower()
         super().__init__(root, transform, pre_transform, pre_filter,
                          force_reload=force_reload)
-        self.__ref__ = torch.load(self.processed_paths[0])
+        self.__ref__ = fs.torch_load(self.processed_paths[0])
         path = self.processed_paths[1] if train else self.processed_paths[2]
         self.load(path)