pyg-nightly 2.6.0.dev20240511__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (205)
  1. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +30 -31
  2. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +205 -181
  3. {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +26 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +16 -14
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/data.py +13 -8
  12. torch_geometric/data/database.py +15 -7
  13. torch_geometric/data/dataset.py +14 -6
  14. torch_geometric/data/feature_store.py +13 -22
  15. torch_geometric/data/graph_store.py +0 -4
  16. torch_geometric/data/hetero_data.py +4 -4
  17. torch_geometric/data/in_memory_dataset.py +2 -4
  18. torch_geometric/data/large_graph_indexer.py +677 -0
  19. torch_geometric/data/lightning/datamodule.py +4 -4
  20. torch_geometric/data/storage.py +15 -5
  21. torch_geometric/data/summary.py +14 -4
  22. torch_geometric/data/temporal.py +1 -2
  23. torch_geometric/datasets/__init__.py +11 -1
  24. torch_geometric/datasets/actor.py +9 -11
  25. torch_geometric/datasets/airfrans.py +15 -18
  26. torch_geometric/datasets/airports.py +10 -12
  27. torch_geometric/datasets/amazon.py +8 -11
  28. torch_geometric/datasets/amazon_book.py +9 -10
  29. torch_geometric/datasets/amazon_products.py +9 -10
  30. torch_geometric/datasets/aminer.py +8 -9
  31. torch_geometric/datasets/aqsol.py +10 -13
  32. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  33. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  34. torch_geometric/datasets/ba_shapes.py +5 -6
  35. torch_geometric/datasets/bitcoin_otc.py +1 -1
  36. torch_geometric/datasets/brca_tgca.py +1 -1
  37. torch_geometric/datasets/dblp.py +2 -1
  38. torch_geometric/datasets/dbp15k.py +2 -2
  39. torch_geometric/datasets/fake.py +1 -3
  40. torch_geometric/datasets/flickr.py +2 -1
  41. torch_geometric/datasets/freebase.py +1 -1
  42. torch_geometric/datasets/gdelt_lite.py +3 -2
  43. torch_geometric/datasets/ged_dataset.py +3 -2
  44. torch_geometric/datasets/git_mol_dataset.py +263 -0
  45. torch_geometric/datasets/gnn_benchmark_dataset.py +6 -5
  46. torch_geometric/datasets/hgb_dataset.py +8 -8
  47. torch_geometric/datasets/imdb.py +2 -1
  48. torch_geometric/datasets/last_fm.py +2 -1
  49. torch_geometric/datasets/linkx_dataset.py +4 -3
  50. torch_geometric/datasets/lrgb.py +3 -5
  51. torch_geometric/datasets/malnet_tiny.py +4 -3
  52. torch_geometric/datasets/mnist_superpixels.py +2 -3
  53. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  54. torch_geometric/datasets/molecule_net.py +7 -1
  55. torch_geometric/datasets/motif_generator/base.py +0 -1
  56. torch_geometric/datasets/neurograph.py +1 -3
  57. torch_geometric/datasets/ogb_mag.py +1 -1
  58. torch_geometric/datasets/opf.py +239 -0
  59. torch_geometric/datasets/ose_gvcs.py +1 -1
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  62. torch_geometric/datasets/pcqm4m.py +2 -1
  63. torch_geometric/datasets/ppi.py +1 -1
  64. torch_geometric/datasets/qm9.py +4 -3
  65. torch_geometric/datasets/reddit.py +2 -1
  66. torch_geometric/datasets/reddit2.py +2 -1
  67. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  68. torch_geometric/datasets/s3dis.py +2 -2
  69. torch_geometric/datasets/shapenet.py +3 -3
  70. torch_geometric/datasets/shrec2016.py +2 -2
  71. torch_geometric/datasets/tag_dataset.py +350 -0
  72. torch_geometric/datasets/upfd.py +2 -1
  73. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  74. torch_geometric/datasets/webkb.py +2 -2
  75. torch_geometric/datasets/wikics.py +1 -1
  76. torch_geometric/datasets/wikidata.py +3 -2
  77. torch_geometric/datasets/wikipedia_network.py +2 -2
  78. torch_geometric/datasets/word_net.py +2 -2
  79. torch_geometric/datasets/yelp.py +2 -1
  80. torch_geometric/datasets/zinc.py +1 -1
  81. torch_geometric/device.py +42 -0
  82. torch_geometric/distributed/local_feature_store.py +3 -2
  83. torch_geometric/distributed/local_graph_store.py +2 -1
  84. torch_geometric/distributed/partition.py +9 -8
  85. torch_geometric/edge_index.py +17 -8
  86. torch_geometric/explain/algorithm/base.py +0 -1
  87. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  88. torch_geometric/explain/explanation.py +2 -2
  89. torch_geometric/graphgym/checkpoint.py +2 -1
  90. torch_geometric/graphgym/logger.py +4 -4
  91. torch_geometric/graphgym/loss.py +1 -1
  92. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  93. torch_geometric/index.py +20 -7
  94. torch_geometric/inspector.py +6 -2
  95. torch_geometric/io/fs.py +28 -2
  96. torch_geometric/io/npz.py +2 -1
  97. torch_geometric/io/off.py +2 -2
  98. torch_geometric/io/sdf.py +2 -2
  99. torch_geometric/io/tu.py +2 -3
  100. torch_geometric/loader/__init__.py +4 -0
  101. torch_geometric/loader/cluster.py +9 -3
  102. torch_geometric/loader/graph_saint.py +2 -1
  103. torch_geometric/loader/ibmb_loader.py +12 -4
  104. torch_geometric/loader/mixin.py +1 -1
  105. torch_geometric/loader/neighbor_loader.py +1 -1
  106. torch_geometric/loader/neighbor_sampler.py +2 -2
  107. torch_geometric/loader/prefetch.py +1 -1
  108. torch_geometric/loader/rag_loader.py +107 -0
  109. torch_geometric/loader/zip_loader.py +10 -0
  110. torch_geometric/metrics/__init__.py +11 -2
  111. torch_geometric/metrics/link_pred.py +159 -34
  112. torch_geometric/nn/aggr/__init__.py +2 -0
  113. torch_geometric/nn/aggr/attention.py +0 -2
  114. torch_geometric/nn/aggr/base.py +2 -4
  115. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  116. torch_geometric/nn/aggr/set_transformer.py +1 -1
  117. torch_geometric/nn/attention/__init__.py +5 -1
  118. torch_geometric/nn/attention/qformer.py +71 -0
  119. torch_geometric/nn/conv/collect.jinja +6 -3
  120. torch_geometric/nn/conv/cugraph/base.py +0 -1
  121. torch_geometric/nn/conv/edge_conv.py +3 -2
  122. torch_geometric/nn/conv/gat_conv.py +35 -7
  123. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  124. torch_geometric/nn/conv/general_conv.py +1 -1
  125. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  126. torch_geometric/nn/conv/hetero_conv.py +3 -3
  127. torch_geometric/nn/conv/hgt_conv.py +1 -1
  128. torch_geometric/nn/conv/message_passing.py +100 -82
  129. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  130. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  131. torch_geometric/nn/conv/spline_conv.py +4 -4
  132. torch_geometric/nn/conv/x_conv.py +3 -2
  133. torch_geometric/nn/dense/linear.py +5 -4
  134. torch_geometric/nn/fx.py +3 -3
  135. torch_geometric/nn/model_hub.py +3 -1
  136. torch_geometric/nn/models/__init__.py +10 -2
  137. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  138. torch_geometric/nn/models/dimenet_utils.py +5 -7
  139. torch_geometric/nn/models/g_retriever.py +230 -0
  140. torch_geometric/nn/models/git_mol.py +336 -0
  141. torch_geometric/nn/models/glem.py +385 -0
  142. torch_geometric/nn/models/gnnff.py +0 -1
  143. torch_geometric/nn/models/graph_unet.py +12 -3
  144. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  145. torch_geometric/nn/models/lightgcn.py +1 -1
  146. torch_geometric/nn/models/metapath2vec.py +3 -4
  147. torch_geometric/nn/models/molecule_gpt.py +222 -0
  148. torch_geometric/nn/models/node2vec.py +1 -2
  149. torch_geometric/nn/models/schnet.py +2 -1
  150. torch_geometric/nn/models/signed_gcn.py +3 -3
  151. torch_geometric/nn/module_dict.py +2 -2
  152. torch_geometric/nn/nlp/__init__.py +9 -0
  153. torch_geometric/nn/nlp/llm.py +322 -0
  154. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  155. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  156. torch_geometric/nn/norm/batch_norm.py +1 -1
  157. torch_geometric/nn/parameter_dict.py +2 -2
  158. torch_geometric/nn/pool/__init__.py +7 -5
  159. torch_geometric/nn/pool/cluster_pool.py +145 -0
  160. torch_geometric/nn/pool/connect/base.py +0 -1
  161. torch_geometric/nn/pool/edge_pool.py +1 -1
  162. torch_geometric/nn/pool/graclus.py +4 -2
  163. torch_geometric/nn/pool/select/base.py +0 -1
  164. torch_geometric/nn/pool/voxel_grid.py +3 -2
  165. torch_geometric/nn/resolver.py +1 -1
  166. torch_geometric/nn/sequential.jinja +10 -23
  167. torch_geometric/nn/sequential.py +203 -77
  168. torch_geometric/nn/summary.py +1 -1
  169. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  170. torch_geometric/profile/__init__.py +2 -0
  171. torch_geometric/profile/nvtx.py +66 -0
  172. torch_geometric/profile/profiler.py +24 -15
  173. torch_geometric/resolver.py +1 -1
  174. torch_geometric/sampler/base.py +34 -13
  175. torch_geometric/sampler/neighbor_sampler.py +11 -10
  176. torch_geometric/testing/decorators.py +17 -22
  177. torch_geometric/transforms/__init__.py +2 -0
  178. torch_geometric/transforms/add_metapaths.py +4 -4
  179. torch_geometric/transforms/add_positional_encoding.py +1 -1
  180. torch_geometric/transforms/delaunay.py +65 -14
  181. torch_geometric/transforms/face_to_edge.py +32 -3
  182. torch_geometric/transforms/gdc.py +7 -6
  183. torch_geometric/transforms/laplacian_lambda_max.py +2 -2
  184. torch_geometric/transforms/mask.py +5 -1
  185. torch_geometric/transforms/node_property_split.py +1 -2
  186. torch_geometric/transforms/pad.py +7 -6
  187. torch_geometric/transforms/random_link_split.py +1 -1
  188. torch_geometric/transforms/remove_self_loops.py +36 -0
  189. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  190. torch_geometric/transforms/virtual_node.py +2 -1
  191. torch_geometric/typing.py +31 -5
  192. torch_geometric/utils/__init__.py +5 -1
  193. torch_geometric/utils/_negative_sampling.py +1 -1
  194. torch_geometric/utils/_normalize_edge_index.py +46 -0
  195. torch_geometric/utils/_scatter.py +37 -12
  196. torch_geometric/utils/_subgraph.py +4 -0
  197. torch_geometric/utils/_tree_decomposition.py +2 -2
  198. torch_geometric/utils/augmentation.py +1 -1
  199. torch_geometric/utils/convert.py +5 -5
  200. torch_geometric/utils/geodesic.py +24 -22
  201. torch_geometric/utils/hetero.py +1 -1
  202. torch_geometric/utils/map.py +1 -1
  203. torch_geometric/utils/smiles.py +66 -28
  204. torch_geometric/utils/sparse.py +25 -10
  205. torch_geometric/visualization/graph.py +3 -4
@@ -4,7 +4,6 @@ from itertools import product
4
4
  from typing import Callable, List, Optional
5
5
 
6
6
  import numpy as np
7
- import scipy.sparse as sp
8
7
  import torch
9
8
 
10
9
  from torch_geometric.data import (
@@ -110,6 +109,8 @@ class DBLP(InMemoryDataset):
110
109
  os.remove(path)
111
110
 
112
111
  def process(self) -> None:
112
+ import scipy.sparse as sp
113
+
113
114
  data = HeteroData()
114
115
 
115
116
  node_types = ['author', 'paper', 'term', 'conference']
@@ -72,7 +72,7 @@ class DBP15K(InMemoryDataset):
72
72
 
73
73
  def process(self) -> None:
74
74
  embs = {}
75
- with open(osp.join(self.raw_dir, 'sub.glove.300d'), 'r') as f:
75
+ with open(osp.join(self.raw_dir, 'sub.glove.300d')) as f:
76
76
  for i, line in enumerate(f):
77
77
  info = line.strip().split(' ')
78
78
  if len(info) > 300:
@@ -112,7 +112,7 @@ class DBP15K(InMemoryDataset):
112
112
  subj, rel, obj = g1.t()
113
113
 
114
114
  x_dict = {}
115
- with open(feature_path, 'r') as f:
115
+ with open(feature_path) as f:
116
116
  for line in f:
117
117
  info = line.strip().split('\t')
118
118
  info = info if len(info) == 2 else info + ['**UNK**']
@@ -170,7 +170,7 @@ class FakeHeteroDataset(InMemoryDataset):
170
170
  random.shuffle(edge_types)
171
171
 
172
172
  self.edge_types: List[Tuple[str, str, str]] = []
173
- count: Dict[Tuple[str, str], int] = defaultdict(lambda: 0)
173
+ count: Dict[Tuple[str, str], int] = defaultdict(int)
174
174
  for edge_type in edge_types[:max(num_edge_types, 1)]:
175
175
  rel = f'e{count[edge_type]}'
176
176
  count[edge_type] += 1
@@ -222,8 +222,6 @@ class FakeHeteroDataset(InMemoryDataset):
222
222
  elif self.edge_dim == 1:
223
223
  store.edge_weight = torch.rand(store.num_edges)
224
224
 
225
- pass
226
-
227
225
  if self._num_classes > 0 and self.task == 'graph':
228
226
  data.y = torch.tensor([random.randint(0, self._num_classes - 1)])
229
227
 
@@ -3,7 +3,6 @@ import os.path as osp
3
3
  from typing import Callable, List, Optional
4
4
 
5
5
  import numpy as np
6
- import scipy.sparse as sp
7
6
  import torch
8
7
 
9
8
  from torch_geometric.data import Data, InMemoryDataset, download_google_url
@@ -73,6 +72,8 @@ class Flickr(InMemoryDataset):
73
72
  download_google_url(self.role_id, self.raw_dir, 'role.json')
74
73
 
75
74
  def process(self) -> None:
75
+ import scipy.sparse as sp
76
+
76
77
  f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
77
78
  adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
78
79
  adj = adj.tocoo()
@@ -75,7 +75,7 @@ class FB15k_237(InMemoryDataset):
75
75
  rel_dict: Dict[str, int] = {}
76
76
 
77
77
  for path in self.raw_paths:
78
- with open(path, 'r') as f:
78
+ with open(path) as f:
79
79
  lines = [x.split('\t') for x in f.read().split('\n')[:-1]]
80
80
 
81
81
  edge_index = torch.empty((2, len(lines)), dtype=torch.long)
@@ -9,6 +9,7 @@ from torch_geometric.data import (
9
9
  download_url,
10
10
  extract_zip,
11
11
  )
12
+ from torch_geometric.io import fs
12
13
 
13
14
 
14
15
  class GDELTLite(InMemoryDataset):
@@ -80,9 +81,9 @@ class GDELTLite(InMemoryDataset):
80
81
  def process(self) -> None:
81
82
  import pandas as pd
82
83
 
83
- x = torch.load(self.raw_paths[0])
84
+ x = fs.torch_load(self.raw_paths[0])
84
85
  df = pd.read_csv(self.raw_paths[1])
85
- edge_attr = torch.load(self.raw_paths[2])
86
+ edge_attr = fs.torch_load(self.raw_paths[2])
86
87
 
87
88
  row = torch.from_numpy(df['src'].values)
88
89
  col = torch.from_numpy(df['dst'].values)
@@ -13,6 +13,7 @@ from torch_geometric.data import (
13
13
  extract_tar,
14
14
  extract_zip,
15
15
  )
16
+ from torch_geometric.io import fs
16
17
  from torch_geometric.utils import one_hot, to_undirected
17
18
 
18
19
 
@@ -145,9 +146,9 @@ class GEDDataset(InMemoryDataset):
145
146
  path = self.processed_paths[0] if train else self.processed_paths[1]
146
147
  self.load(path)
147
148
  path = osp.join(self.processed_dir, f'{self.name}_ged.pt')
148
- self.ged = torch.load(path)
149
+ self.ged = fs.torch_load(path)
149
150
  path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt')
150
- self.norm_ged = torch.load(path)
151
+ self.norm_ged = fs.torch_load(path)
151
152
 
152
153
  @property
153
154
  def raw_file_names(self) -> List[str]:
@@ -0,0 +1,263 @@
1
+ import sys
2
+ from typing import Any, Callable, Dict, List, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+ from torch_geometric.data import (
9
+ Data,
10
+ InMemoryDataset,
11
+ download_google_url,
12
+ extract_zip,
13
+ )
14
+ from torch_geometric.io import fs
15
+
16
+
17
+ def safe_index(lst: List[Any], e: int) -> int:
18
+ return lst.index(e) if e in lst else len(lst) - 1
19
+
20
+
21
+ class GitMolDataset(InMemoryDataset):
22
+ r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model
23
+ for Molecular Science with Graph, Image, and Text"
24
+ <https://arxiv.org/pdf/2308.06911>`_ paper.
25
+
26
+ Args:
27
+ root (str): Root directory where the dataset should be saved.
28
+ transform (callable, optional): A function/transform that takes in an
29
+ :obj:`torch_geometric.data.Data` object and returns a transformed
30
+ version. The data object will be transformed before every access.
31
+ (default: :obj:`None`)
32
+ pre_transform (callable, optional): A function/transform that takes in
33
+ an :obj:`torch_geometric.data.Data` object and returns a
34
+ transformed version. The data object will be transformed before
35
+ being saved to disk. (default: :obj:`None`)
36
+ pre_filter (callable, optional): A function that takes in an
37
+ :obj:`torch_geometric.data.Data` object and returns a boolean
38
+ value, indicating whether the data object should be included in the
39
+ final dataset. (default: :obj:`None`)
40
+ force_reload (bool, optional): Whether to re-process the dataset.
41
+ (default: :obj:`False`)
42
+ split (int, optional): Datasets split, train/valid/test=0/1/2.
43
+ (default: :obj:`0`)
44
+ """
45
+
46
+ raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'
47
+
48
+ def __init__(
49
+ self,
50
+ root: str,
51
+ transform: Optional[Callable] = None,
52
+ pre_transform: Optional[Callable] = None,
53
+ pre_filter: Optional[Callable] = None,
54
+ force_reload: bool = False,
55
+ split: int = 0,
56
+ ):
57
+ from torchvision import transforms
58
+
59
+ self.split = split
60
+
61
+ if self.split == 0:
62
+ self.img_transform = transforms.Compose([
63
+ transforms.Resize((224, 224)),
64
+ transforms.RandomRotation(15),
65
+ transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
66
+ transforms.ToTensor(),
67
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
68
+ std=[0.229, 0.224, 0.225])
69
+ ])
70
+ else:
71
+ self.img_transform = transforms.Compose([
72
+ transforms.Resize((224, 224)),
73
+ transforms.ToTensor(),
74
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
75
+ std=[0.229, 0.224, 0.225])
76
+ ])
77
+
78
+ super().__init__(root, transform, pre_transform, pre_filter,
79
+ force_reload=force_reload)
80
+
81
+ self.load(self.processed_paths[0])
82
+
83
+ @property
84
+ def raw_file_names(self) -> List[str]:
85
+ return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']
86
+
87
+ @property
88
+ def processed_file_names(self) -> str:
89
+ return ['train.pt', 'valid.pt', 'test.pt'][self.split]
90
+
91
+ def download(self) -> None:
92
+ file_path = download_google_url(
93
+ self.raw_url_id,
94
+ self.raw_dir,
95
+ 'gitmol.zip',
96
+ )
97
+ extract_zip(file_path, self.raw_dir)
98
+
99
+ def process(self) -> None:
100
+ import pandas as pd
101
+ from PIL import Image
102
+
103
+ try:
104
+ from rdkit import Chem, RDLogger
105
+ RDLogger.DisableLog('rdApp.*') # type: ignore
106
+ WITH_RDKIT = True
107
+
108
+ except ImportError:
109
+ WITH_RDKIT = False
110
+
111
+ if not WITH_RDKIT:
112
+ print(("Using a pre-processed version of the dataset. Please "
113
+ "install 'rdkit' to alternatively process the raw data."),
114
+ file=sys.stderr)
115
+
116
+ data_list = fs.torch_load(self.raw_paths[0])
117
+ data_list = [Data(**data_dict) for data_dict in data_list]
118
+
119
+ if self.pre_filter is not None:
120
+ data_list = [d for d in data_list if self.pre_filter(d)]
121
+
122
+ if self.pre_transform is not None:
123
+ data_list = [self.pre_transform(d) for d in data_list]
124
+
125
+ self.save(data_list, self.processed_paths[0])
126
+ return
127
+
128
+ allowable_features: Dict[str, List[Any]] = {
129
+ 'possible_atomic_num_list':
130
+ list(range(1, 119)) + ['misc'],
131
+ 'possible_formal_charge_list':
132
+ [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
133
+ 'possible_chirality_list': [
134
+ Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
135
+ Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
136
+ Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
137
+ Chem.rdchem.ChiralType.CHI_OTHER
138
+ ],
139
+ 'possible_hybridization_list': [
140
+ Chem.rdchem.HybridizationType.SP,
141
+ Chem.rdchem.HybridizationType.SP2,
142
+ Chem.rdchem.HybridizationType.SP3,
143
+ Chem.rdchem.HybridizationType.SP3D,
144
+ Chem.rdchem.HybridizationType.SP3D2,
145
+ Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'
146
+ ],
147
+ 'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
148
+ 'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
149
+ 'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
150
+ 'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
151
+ 'possible_is_aromatic_list': [False, True],
152
+ 'possible_is_in_ring_list': [False, True],
153
+ 'possible_bond_type_list': [
154
+ Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
155
+ Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
156
+ Chem.rdchem.BondType.ZERO
157
+ ],
158
+ 'possible_bond_dirs': [ # only for double bond stereo information
159
+ Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
160
+ Chem.rdchem.BondDir.ENDDOWNRIGHT
161
+ ],
162
+ 'possible_bond_stereo_list': [
163
+ Chem.rdchem.BondStereo.STEREONONE,
164
+ Chem.rdchem.BondStereo.STEREOZ,
165
+ Chem.rdchem.BondStereo.STEREOE,
166
+ Chem.rdchem.BondStereo.STEREOCIS,
167
+ Chem.rdchem.BondStereo.STEREOTRANS,
168
+ Chem.rdchem.BondStereo.STEREOANY,
169
+ ],
170
+ 'possible_is_conjugated_list': [False, True]
171
+ }
172
+
173
+ data = pd.read_pickle(
174
+ f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')
175
+
176
+ data_list = []
177
+ for _, r in tqdm(data.iterrows(), total=data.shape[0]):
178
+ smiles = r['isosmiles']
179
+ mol = Chem.MolFromSmiles(smiles.strip('\n'))
180
+ if mol is not None:
181
+ # text
182
+ summary = r['summary']
183
+ # image
184
+ cid = r['cid']
185
+ img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'
186
+ img = Image.open(img_file).convert('RGB')
187
+ img = self.img_transform(img).unsqueeze(0)
188
+ # graph
189
+ atom_features_list = []
190
+ for atom in mol.GetAtoms():
191
+ atom_feature = [
192
+ safe_index(
193
+ allowable_features['possible_atomic_num_list'],
194
+ atom.GetAtomicNum()),
195
+ allowable_features['possible_chirality_list'].index(
196
+ atom.GetChiralTag()),
197
+ safe_index(allowable_features['possible_degree_list'],
198
+ atom.GetTotalDegree()),
199
+ safe_index(
200
+ allowable_features['possible_formal_charge_list'],
201
+ atom.GetFormalCharge()),
202
+ safe_index(allowable_features['possible_numH_list'],
203
+ atom.GetTotalNumHs()),
204
+ safe_index(
205
+ allowable_features[
206
+ 'possible_number_radical_e_list'],
207
+ atom.GetNumRadicalElectrons()),
208
+ safe_index(
209
+ allowable_features['possible_hybridization_list'],
210
+ atom.GetHybridization()),
211
+ allowable_features['possible_is_aromatic_list'].index(
212
+ atom.GetIsAromatic()),
213
+ allowable_features['possible_is_in_ring_list'].index(
214
+ atom.IsInRing()),
215
+ ]
216
+ atom_features_list.append(atom_feature)
217
+ x = torch.tensor(np.array(atom_features_list),
218
+ dtype=torch.long)
219
+
220
+ edges_list = []
221
+ edge_features_list = []
222
+ for bond in mol.GetBonds():
223
+ i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
224
+ edge_feature = [
225
+ safe_index(
226
+ allowable_features['possible_bond_type_list'],
227
+ bond.GetBondType()),
228
+ allowable_features['possible_bond_stereo_list'].index(
229
+ bond.GetStereo()),
230
+ allowable_features['possible_is_conjugated_list'].
231
+ index(bond.GetIsConjugated()),
232
+ ]
233
+ edges_list.append((i, j))
234
+ edge_features_list.append(edge_feature)
235
+ edges_list.append((j, i))
236
+ edge_features_list.append(edge_feature)
237
+
238
+ edge_index = torch.tensor(
239
+ np.array(edges_list).T,
240
+ dtype=torch.long,
241
+ )
242
+ edge_attr = torch.tensor(
243
+ np.array(edge_features_list),
244
+ dtype=torch.long,
245
+ )
246
+
247
+ data = Data(
248
+ x=x,
249
+ edge_index=edge_index,
250
+ smiles=smiles,
251
+ edge_attr=edge_attr,
252
+ image=img,
253
+ caption=summary,
254
+ )
255
+
256
+ if self.pre_filter is not None and not self.pre_filter(data):
257
+ continue
258
+ if self.pre_transform is not None:
259
+ data = self.pre_transform(data)
260
+
261
+ data_list.append(data)
262
+
263
+ self.save(data_list, self.processed_paths[0])
@@ -12,6 +12,7 @@ from torch_geometric.data import (
12
12
  download_url,
13
13
  extract_zip,
14
14
  )
15
+ from torch_geometric.io import fs
15
16
  from torch_geometric.utils import remove_self_loops
16
17
 
17
18
 
@@ -126,9 +127,9 @@ class GNNBenchmarkDataset(InMemoryDataset):
126
127
  if self.name == 'CSL' and split != 'train':
127
128
  split = 'train'
128
129
  logging.warning(
129
- ("Dataset 'CSL' does not provide a standardized splitting. "
130
- "Instead, it is recommended to perform 5-fold cross "
131
- "validation with stratifed sampling"))
130
+ "Dataset 'CSL' does not provide a standardized splitting. "
131
+ "Instead, it is recommended to perform 5-fold cross "
132
+ "validation with stratifed sampling")
132
133
 
133
134
  super().__init__(root, transform, pre_transform, pre_filter,
134
135
  force_reload=force_reload)
@@ -181,7 +182,7 @@ class GNNBenchmarkDataset(InMemoryDataset):
181
182
  data_list = self.process_CSL()
182
183
  self.save(data_list, self.processed_paths[0])
183
184
  else:
184
- inputs = torch.load(self.raw_paths[0])
185
+ inputs = fs.torch_load(self.raw_paths[0])
185
186
  for i in range(len(inputs)):
186
187
  data_list = [Data(**data_dict) for data_dict in inputs[i]]
187
188
 
@@ -197,7 +198,7 @@ class GNNBenchmarkDataset(InMemoryDataset):
197
198
  with open(self.raw_paths[0], 'rb') as f:
198
199
  adjs = pickle.load(f)
199
200
 
200
- ys = torch.load(self.raw_paths[1]).tolist()
201
+ ys = fs.torch_load(self.raw_paths[1]).tolist()
201
202
 
202
203
  data_list = []
203
204
  for adj, y in zip(adjs, ys):
@@ -99,7 +99,7 @@ class HGBDataset(InMemoryDataset):
99
99
  # node_types = {0: 'paper', 1, 'author', ...}
100
100
  # edge_types = {0: ('paper', 'cite', 'paper'), ...}
101
101
  if self.name in ['acm', 'dblp', 'imdb']:
102
- with open(self.raw_paths[0], 'r') as f: # `info.dat`
102
+ with open(self.raw_paths[0]) as f: # `info.dat`
103
103
  info = json.load(f)
104
104
  n_types = info['node.dat']['node type']
105
105
  n_types = {int(k): v for k, v in n_types.items()}
@@ -112,7 +112,7 @@ class HGBDataset(InMemoryDataset):
112
112
  e_types[key] = (src, rel, dst)
113
113
  num_classes = len(info['label.dat']['node type']['0'])
114
114
  elif self.name in ['freebase']:
115
- with open(self.raw_paths[0], 'r') as f: # `info.dat`
115
+ with open(self.raw_paths[0]) as f: # `info.dat`
116
116
  info = f.read().split('\n')
117
117
  start = info.index('TYPE\tMEANING') + 1
118
118
  end = info[start:].index('')
@@ -124,7 +124,7 @@ class HGBDataset(InMemoryDataset):
124
124
  end = info[start:].index('')
125
125
  for key, row in enumerate(info[start:start + end]):
126
126
  row = row.split('\t')[1:]
127
- src, dst, rel = [v for v in row if v != '']
127
+ src, dst, rel = (v for v in row if v != '')
128
128
  src, dst = n_types[int(src)], n_types[int(dst)]
129
129
  rel = rel.split('-')[1]
130
130
  e_types[key] = (src, rel, dst)
@@ -134,8 +134,8 @@ class HGBDataset(InMemoryDataset):
134
134
  # Extract node information:
135
135
  mapping_dict = {} # Maps global node indices to local ones.
136
136
  x_dict = defaultdict(list)
137
- num_nodes_dict: Dict[str, int] = defaultdict(lambda: 0)
138
- with open(self.raw_paths[1], 'r') as f: # `node.dat`
137
+ num_nodes_dict: Dict[str, int] = defaultdict(int)
138
+ with open(self.raw_paths[1]) as f: # `node.dat`
139
139
  xs = [v.split('\t') for v in f.read().split('\n')[:-1]]
140
140
  for x in xs:
141
141
  n_id, n_type = int(x[0]), n_types[int(x[2])]
@@ -151,7 +151,7 @@ class HGBDataset(InMemoryDataset):
151
151
 
152
152
  edge_index_dict = defaultdict(list)
153
153
  edge_weight_dict = defaultdict(list)
154
- with open(self.raw_paths[2], 'r') as f: # `link.dat`
154
+ with open(self.raw_paths[2]) as f: # `link.dat`
155
155
  edges = [v.split('\t') for v in f.read().split('\n')[:-1]]
156
156
  for src, dst, rel, weight in edges:
157
157
  e_type = e_types[int(rel)]
@@ -168,9 +168,9 @@ class HGBDataset(InMemoryDataset):
168
168
 
169
169
  # Node classification:
170
170
  if self.name in ['acm', 'dblp', 'freebase', 'imdb']:
171
- with open(self.raw_paths[3], 'r') as f: # `label.dat`
171
+ with open(self.raw_paths[3]) as f: # `label.dat`
172
172
  train_ys = [v.split('\t') for v in f.read().split('\n')[:-1]]
173
- with open(self.raw_paths[4], 'r') as f: # `label.dat.test`
173
+ with open(self.raw_paths[4]) as f: # `label.dat.test`
174
174
  test_ys = [v.split('\t') for v in f.read().split('\n')[:-1]]
175
175
  for y in train_ys:
176
176
  n_id, n_type = mapping_dict[int(y[0])], n_types[int(y[2])]
@@ -4,7 +4,6 @@ from itertools import product
4
4
  from typing import Callable, List, Optional
5
5
 
6
6
  import numpy as np
7
- import scipy.sparse as sp
8
7
  import torch
9
8
 
10
9
  from torch_geometric.data import (
@@ -69,6 +68,8 @@ class IMDB(InMemoryDataset):
69
68
  os.remove(path)
70
69
 
71
70
  def process(self) -> None:
71
+ import scipy.sparse as sp
72
+
72
73
  data = HeteroData()
73
74
 
74
75
  node_types = ['movie', 'director', 'actor']
@@ -4,7 +4,6 @@ from itertools import product
4
4
  from typing import Callable, List, Optional
5
5
 
6
6
  import numpy as np
7
- import scipy.sparse as sp
8
7
  import torch
9
8
 
10
9
  from torch_geometric.data import (
@@ -68,6 +67,8 @@ class LastFM(InMemoryDataset):
68
67
  os.remove(path)
69
68
 
70
69
  def process(self) -> None:
70
+ import scipy.sparse as sp
71
+
71
72
  data = HeteroData()
72
73
 
73
74
  node_type_idx = np.load(osp.join(self.raw_dir, 'node_types.npy'))
@@ -5,6 +5,7 @@ import numpy as np
5
5
  import torch
6
6
 
7
7
  from torch_geometric.data import Data, InMemoryDataset, download_url
8
+ from torch_geometric.io import fs
8
9
  from torch_geometric.utils import one_hot
9
10
 
10
11
 
@@ -115,9 +116,9 @@ class LINKXDataset(InMemoryDataset):
115
116
 
116
117
  def _process_wiki(self) -> Data:
117
118
  paths = {x.split('/')[-1]: x for x in self.raw_paths}
118
- x = torch.load(paths['wiki_features2M.pt'])
119
- edge_index = torch.load(paths['wiki_edges2M.pt']).t().contiguous()
120
- y = torch.load(paths['wiki_views2M.pt'])
119
+ x = fs.torch_load(paths['wiki_features2M.pt'])
120
+ edge_index = fs.torch_load(paths['wiki_edges2M.pt']).t().contiguous()
121
+ y = fs.torch_load(paths['wiki_views2M.pt'])
121
122
 
122
123
  return Data(x=x, edge_index=edge_index, y=y)
123
124
 
@@ -188,9 +188,8 @@ class LRGBDataset(InMemoryDataset):
188
188
  graphs = pickle.load(f)
189
189
  elif self.name.split('-')[0] == 'peptides':
190
190
  # Peptides-func and Peptides-struct
191
- with open(osp.join(self.raw_dir, f'{split}.pt'),
192
- 'rb') as f:
193
- graphs = torch.load(f)
191
+ graphs = fs.torch_load(
192
+ osp.join(self.raw_dir, f'{split}.pt'))
194
193
 
195
194
  data_list = []
196
195
  for graph in tqdm(graphs, desc=f'Processing {split} dataset'):
@@ -260,8 +259,7 @@ class LRGBDataset(InMemoryDataset):
260
259
 
261
260
  def process_pcqm_contact(self) -> None:
262
261
  for split in ['train', 'val', 'test']:
263
- with open(osp.join(self.raw_dir, f'{split}.pt'), 'rb') as f:
264
- graphs = torch.load(f)
262
+ graphs = fs.torch_load(osp.join(self.raw_dir, f'{split}.pt'))
265
263
 
266
264
  data_list = []
267
265
  for graph in tqdm(graphs, desc=f'Processing {split} dataset'):
@@ -11,6 +11,7 @@ from torch_geometric.data import (
11
11
  extract_tar,
12
12
  extract_zip,
13
13
  )
14
+ from torch_geometric.io import fs
14
15
 
15
16
 
16
17
  class MalNetTiny(InMemoryDataset):
@@ -65,7 +66,7 @@ class MalNetTiny(InMemoryDataset):
65
66
  self.load(self.processed_paths[0])
66
67
 
67
68
  if split is not None:
68
- split_slices = torch.load(self.processed_paths[1])
69
+ split_slices = fs.torch_load(self.processed_paths[1])
69
70
  if split == 'train':
70
71
  self._indices = range(split_slices[0], split_slices[1])
71
72
  elif split == 'val':
@@ -98,7 +99,7 @@ class MalNetTiny(InMemoryDataset):
98
99
  split_slices = [0]
99
100
 
100
101
  for split in ['train', 'val', 'test']:
101
- with open(osp.join(self.raw_paths[1], f'{split}.txt'), 'r') as f:
102
+ with open(osp.join(self.raw_paths[1], f'{split}.txt')) as f:
102
103
  filenames = f.read().split('\n')[:-1]
103
104
  split_slices.append(split_slices[-1] + len(filenames))
104
105
 
@@ -107,7 +108,7 @@ class MalNetTiny(InMemoryDataset):
107
108
  malware_type = filename.split('/')[0]
108
109
  y = y_map.setdefault(malware_type, len(y_map))
109
110
 
110
- with open(path, 'r') as f:
111
+ with open(path) as f:
111
112
  edges = f.read().split('\n')[5:-1]
112
113
 
113
114
  edge_indices = [[int(s) for s in e.split()] for e in edges]
@@ -1,14 +1,13 @@
1
1
  import os
2
2
  from typing import Callable, List, Optional
3
3
 
4
- import torch
5
-
6
4
  from torch_geometric.data import (
7
5
  Data,
8
6
  InMemoryDataset,
9
7
  download_url,
10
8
  extract_zip,
11
9
  )
10
+ from torch_geometric.io import fs
12
11
 
13
12
 
14
13
  class MNISTSuperpixels(InMemoryDataset):
@@ -85,7 +84,7 @@ class MNISTSuperpixels(InMemoryDataset):
85
84
  os.unlink(path)
86
85
 
87
86
  def process(self) -> None:
88
- inputs = torch.load(self.raw_paths[0])
87
+ inputs = fs.torch_load(self.raw_paths[0])
89
88
  for i in range(len(inputs)):
90
89
  data_list = [Data(**data_dict) for data_dict in inputs[i]]
91
90