pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268) hide show
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -0,0 +1,451 @@
1
+ import os
2
+ import pickle
3
+ import random
4
+ from collections import defaultdict
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from torch_geometric.data import (
12
+ Data,
13
+ InMemoryDataset,
14
+ download_url,
15
+ extract_tar,
16
+ )
17
+
18
+
19
+ class ProteinMPNNDataset(InMemoryDataset):
20
+ r"""The ProteinMPNN dataset from the `"Robust deep learning based protein
21
+ sequence design using ProteinMPNN"
22
+ <https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1>`_ paper.
23
+
24
+ Args:
25
+ root (str): Root directory where the dataset should be saved.
26
+ size (str): Size of the PDB information to train the model.
27
+ If :obj:`"small"`, loads the small dataset (229.4 MB).
28
+ If :obj:`"large"`, loads the large dataset (64.1 GB).
29
+ (default: :obj:`"small"`)
30
+ split (str, optional): If :obj:`"train"`, loads the training dataset.
31
+ If :obj:`"valid"`, loads the validation dataset.
32
+ If :obj:`"test"`, loads the test dataset.
33
+ (default: :obj:`"train"`)
34
+ datacut (str, optional): Date cutoff to filter the dataset.
35
+ (default: :obj:`"2030-01-01"`)
36
+ rescut (float, optional): PDB resolution cutoff.
37
+ (default: :obj:`3.5`)
38
+ homo (float, optional): Homology cutoff.
39
+ (default: :obj:`0.70`)
40
+ max_length (int, optional): Maximum length of the protein complex.
41
+ (default: :obj:`10000`)
42
+ num_units (int, optional): Maximum number of protein complexes to keep in the selected split.
43
+ (default: :obj:`150`)
44
+ transform (callable, optional): A function/transform that takes in an
45
+ :obj:`torch_geometric.data.Data` object and returns a transformed
46
+ version. The data object will be transformed before every access.
47
+ (default: :obj:`None`)
48
+ pre_transform (callable, optional): A function/transform that takes in
49
+ an :obj:`torch_geometric.data.Data` object and returns a
50
+ transformed version. The data object will be transformed before
51
+ being saved to disk. (default: :obj:`None`)
52
+ pre_filter (callable, optional): A function that takes in an
53
+ :obj:`torch_geometric.data.Data` object and returns a boolean
54
+ value, indicating whether the data object should be included in the
55
+ final dataset. (default: :obj:`None`)
56
+ force_reload (bool, optional): Whether to re-process the dataset.
57
+ (default: :obj:`False`)
58
+ """
59
+
60
+ raw_url = {
61
+ 'small':
62
+ 'https://files.ipd.uw.edu/pub/training_sets/'
63
+ 'pdb_2021aug02_sample.tar.gz',
64
+ 'large':
65
+ 'https://files.ipd.uw.edu/pub/training_sets/'
66
+ 'pdb_2021aug02.tar.gz',
67
+ }
68
+
69
+ splits = {
70
+ 'train': 1,
71
+ 'valid': 2,
72
+ 'test': 3,
73
+ }
74
+
75
def __init__(
    self,
    root: str,
    size: str = 'small',
    split: str = 'train',
    datacut: str = '2030-01-01',
    rescut: float = 3.5,
    homo: float = 0.70,
    max_length: int = 10000,
    num_units: int = 150,
    transform: Optional[Callable] = None,
    pre_transform: Optional[Callable] = None,
    pre_filter: Optional[Callable] = None,
    force_reload: bool = False,
) -> None:
    # Which archive to use and which split to expose.
    self.size = size
    self.split = split

    # Filtering options consumed while processing the raw files.
    self.datacut = datacut
    self.rescut = rescut
    self.homo = homo
    self.max_length = max_length
    self.num_units = num_units

    # The archive unpacks into a folder named like the file without its
    # extensions, e.g. `pdb_2021aug02_sample.tar.gz` -> `pdb_2021aug02_sample`.
    archive_name = self.raw_url[self.size].rsplit('/', 1)[-1]
    self.sub_folder = archive_name.split('.', 1)[0]

    super().__init__(root, transform, pre_transform, pre_filter,
                     force_reload=force_reload)

    # `processed_paths[0]` is the split cache; 1/2/3 are train/valid/test.
    split_index = self.splits[self.split]
    self.load(self.processed_paths[split_index])
103
+
104
@property
def raw_file_names(self) -> List[str]:
    r"""Paths of the split definition files, relative to the folder the
    archive is extracted into (:obj:`self.sub_folder`).
    """
    folder = self.sub_folder
    names: List[str] = []
    for file_name in ('list.csv', 'valid_clusters.txt', 'test_clusters.txt'):
        names.append(f'{folder}/{file_name}')
    return names
110
+
111
@property
def processed_file_names(self) -> List[str]:
    r"""File names of the split cache followed by the three split files.

    The position of each split file matches the indices in
    :obj:`self.splits` (train=1, valid=2, test=3).
    """
    split_cache = 'splits.pkl'
    split_files = ('train.pt', 'valid.pt', 'test.pt')
    return [split_cache, *split_files]
114
+
115
def download(self) -> None:
    r"""Downloads the archive selected by :obj:`self.size`, extracts it into
    the raw directory and deletes the downloaded archive afterwards.
    """
    url = self.raw_url[self.size]
    archive_path = download_url(url, self.raw_dir)
    extract_tar(archive_path, self.raw_dir)
    # The extracted folder is all we need; drop the tarball.
    os.remove(archive_path)
119
+
120
def process(self) -> None:
    r"""Builds the :class:`Data` objects of the selected split and saves
    them to the processed file of that split.

    Every chain returned by :meth:`_process_split` is run through
    :meth:`_process_pdb1`, :meth:`_process_pdb2` and :meth:`_process_pdb3`.
    A chain is skipped if

    * it could not be loaded (no :obj:`'label'` entry),
    * its assembly has more chains than can be labeled,
    * its concatenated sequence is empty or longer than
      :obj:`self.max_length`,
    * its sequence contains letters outside of the residue alphabet, or
    * it is rejected by :obj:`self.pre_filter`.

    Processing stops as soon as :obj:`self.num_units` examples have been
    collected.
    """
    alphabet_set = set('ACDEFGHIKLMNPQRSTVWYX')
    # `_process_pdb2` can label at most 52 letters + 300 numbers = 352
    # chains, so larger assemblies cannot be represented.
    max_num_chains = 352

    def build(chain_id: str) -> Optional[Data]:
        # Returns the example for `chain_id` or `None` if it is skipped.
        item = self._process_pdb1(chain_id)
        if 'label' not in item:
            return None
        if len(np.unique(item['idx'])) >= max_num_chains:
            return None

        my_dict = self._process_pdb2(item)
        seq = my_dict['seq']

        # An empty sequence (all chains shorter than 4 residues) has
        # nothing to concatenate in `_process_pdb3` and would crash there.
        if len(seq) == 0 or len(seq) > self.max_length:
            return None
        if not set(seq) <= alphabet_set:
            return None

        (x_chain_all, chain_seq_label_all, mask, chain_mask_all,
         residue_idx, chain_encoding_all) = self._process_pdb3(my_dict)

        data = Data(
            x=x_chain_all,  # [seq_len, 4, 3]
            chain_seq_label=chain_seq_label_all,  # [seq_len]
            mask=mask,  # [seq_len]
            chain_mask_all=chain_mask_all,  # [seq_len]
            residue_idx=residue_idx,  # [seq_len]
            chain_encoding_all=chain_encoding_all,  # [seq_len]
        )

        if self.pre_filter is not None and not self.pre_filter(data):
            return None
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        return data

    cluster_ids = self._process_split()
    total_items = sum(len(items) for items in cluster_ids.values())
    chain_ids = (chain_id for items in cluster_ids.values()
                 for chain_id, _ in items)
    data_list = []

    with tqdm(total=total_items, desc="Processing") as pbar:
        for chain_id in chain_ids:
            data = build(chain_id)
            if data is not None:
                data_list.append(data)
                if len(data_list) >= self.num_units:
                    # Enough examples: fill up the bar and stop early.
                    pbar.update(total_items - pbar.n)
                    break
            # Count skipped chains as well (including `pre_filter`
            # rejections) so that the bar reflects the visited chains.
            pbar.update(1)

    self.save(data_list, self.processed_paths[self.splits[self.split]])
177
+
178
def _process_split(self) -> Dict[int, List[Tuple[str, int]]]:
    r"""Returns the chains of the selected split, grouped by cluster.

    The assignment of all three splits is cached as a pickle file at
    :obj:`self.processed_paths[0]`. If that file exists, it is loaded and
    reused as-is, i.e. independently of the current values of
    :obj:`self.rescut` and :obj:`self.datacut`. Otherwise, it is computed
    from the raw files and written to disk.

    Returns:
        A mapping from cluster ID to a list of
        :obj:`(chain_id, hash_id)` tuples for :obj:`self.split`.
    """
    save_path = self.processed_paths[0]

    if os.path.exists(save_path):
        print('Load split')
        # NOTE: The cache is written by this method. `pickle` executes
        # arbitrary code on load, so never point it at untrusted files.
        with open(save_path, 'rb') as f:
            data = pickle.load(f)
        return data[self.split]

    # Only needed when the cache has to be (re-)built.
    import pandas as pd

    # CHAINID, DEPOSITION, RESOLUTION, HASH, CLUSTER, SEQUENCE
    df = pd.read_csv(self.raw_paths[0])
    df = df[(df['RESOLUTION'] <= self.rescut)
            & (df['DEPOSITION'] <= self.datacut)]

    # Sets give O(1) membership tests inside the row loop below.
    val_ids = set(pd.read_csv(self.raw_paths[1], header=None)[0].tolist())
    test_ids = set(pd.read_csv(self.raw_paths[2], header=None)[0].tolist())

    # compile training, validation and test sets
    data = {
        'train': defaultdict(list),
        'valid': defaultdict(list),
        'test': defaultdict(list),
    }

    rows = zip(df['CLUSTER'], df['HASH'], df['CHAINID'])
    for cluster_id, hash_id, chain_id in tqdm(rows, desc='Processing split',
                                              total=len(df)):
        # Validation takes precedence over test, everything else is train.
        if cluster_id in val_ids:
            target = 'valid'
        elif cluster_id in test_ids:
            target = 'test'
        else:
            target = 'train'
        data[target][cluster_id].append((chain_id, hash_id))

    with open(save_path, 'wb') as f:
        pickle.dump(data, f)

    return data[self.split]
218
+
219
def _process_pdb1(self, chain_id: str) -> Dict[str, Any]:
    r"""Loads the biological assembly that contains the given chain.

    Args:
        chain_id (str): Identifier of the form :obj:`"{pdbid}_{chid}"`.

    Returns:
        A dictionary with the entries :obj:`'seq'` (concatenated sequence),
        :obj:`'xyz'` (stacked coordinates), :obj:`'idx'` (per-residue index
        of the chain copy), :obj:`'masked'` (indices of chain copies that
        are homologous to the requested chain) and :obj:`'label'`
        (:obj:`chain_id`). If the metadata file is missing or a required
        chain was not loaded, a dictionary holding only a dummy
        :obj:`'seq'` (and thus no :obj:`'label'`) is returned.
    """
    pdbid, chid = chain_id.split('_')
    prefix = f'{self.raw_dir}/{self.sub_folder}/pdb/{pdbid[1:3]}/{pdbid}'

    # load metadata
    if not os.path.isfile(f'{prefix}.pt'):
        return {'seq': np.zeros(5)}
    meta = torch.load(f'{prefix}.pt')
    asmb_ids = meta['asmb_ids']
    asmb_chains = meta['asmb_chains']
    chids = np.array(meta['chains'])

    # find candidate assemblies which contain chid chain
    asmb_candidates = {
        a
        for a, b in zip(asmb_ids, asmb_chains) if chid in b.split(',')
    }

    # if the chain is missing from all the assemblies
    # then return this chain alone
    if len(asmb_candidates) < 1:
        chain = torch.load(f'{prefix}_{chid}.pt')
        L = len(chain['seq'])
        return {
            'seq': chain['seq'],
            'xyz': chain['xyz'],
            'idx': torch.zeros(L).int(),
            'masked': torch.Tensor([0]).int(),
            'label': chain_id,
        }

    # randomly pick one assembly from candidates
    asmb_i = random.sample(list(asmb_candidates), 1)[0]

    # indices of selected transforms
    idx = np.where(np.array(asmb_ids) == asmb_i)[0]

    # load relevant chains
    # NOTE(review): this iterates over the *characters* of the
    # comma-separated `asmb_chains[i]` string, so chain IDs with more than
    # one character are never loaded here and such assemblies are dropped
    # via the `KeyError` branch below. Kept as-is to not change the
    # resulting dataset - confirm whether `.split(',')` is intended.
    chains = {
        c: torch.load(f'{prefix}_{c}.pt')
        for i in idx
        for c in asmb_chains[i] if c in meta['chains']
    }

    # generate assembly
    available_chains = set(meta['chains'])
    asmb = {}
    for k in idx:

        # pick k-th xform
        xform = meta[f'asmb_xform{k}']
        u = xform[:, :3, :3]
        r = xform[:, :3, 3]

        # select chains which k-th xform should be applied to
        chains_k = available_chains & set(asmb_chains[k].split(','))

        # transform selected chains
        for c in chains_k:
            try:
                xyz = chains[c]['xyz']
            except KeyError:
                # A required chain was not loaded above.
                return {'seq': np.zeros(5)}
            xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None,
                                                               None, :]
            asmb.update({
                (c, k, i): xyz_i
                for i, xyz_i in enumerate(xyz_ru)
            })

    # select chains which share considerable similarity to chid
    seqid = meta['tm'][chids == chid][0, :, 1]
    homo = {
        ch_j
        for seqid_j, ch_j in zip(seqid, chids) if seqid_j > self.homo
    }

    # stack all chains in the assembly together
    seq_parts: List[str] = []
    xyz_all: List[torch.Tensor] = []
    idx_all: List[torch.Tensor] = []
    masked: List[int] = []
    for counter, (key, xyz_copy) in enumerate(asmb.items()):
        chain_letter = key[0]
        seq_parts.append(chains[chain_letter]['seq'])
        xyz_all.append(xyz_copy)
        idx_all.append(torch.full((xyz_copy.shape[0], ), counter))
        if chain_letter in homo:
            masked.append(counter)

    return {
        'seq': ''.join(seq_parts),
        'xyz': torch.cat(xyz_all, dim=0),
        'idx': torch.cat(idx_all, dim=0),
        'masked': torch.Tensor(masked).int(),
        'label': chain_id,
    }
316
+
317
def _process_pdb2(self, t: Dict[str, Any]) -> Dict[str, Any]:
    r"""Splits an assembly returned by :meth:`_process_pdb1` into labeled
    chains.

    Every chain copy (distinct value in :obj:`t['idx']`) is assigned a
    label from a fixed chain alphabet (52 letters followed by the numbers
    0 to 299), stripped from poly-histidine tags near its termini, and
    dropped if fewer than 4 residues remain.

    Args:
        t (Dict[str, Any]): Dictionary with the entries :obj:`'seq'`,
            :obj:`'xyz'`, :obj:`'idx'`, :obj:`'masked'` and :obj:`'label'`.

    Returns:
        A dictionary holding :obj:`seq_chain_{letter}` and
        :obj:`coords_chain_{letter}` (N, CA, C and O coordinates) for each
        kept chain, as well as :obj:`'name'`, :obj:`'masked_list'`,
        :obj:`'visible_list'`, :obj:`'num_of_chains'` and the concatenated
        :obj:`'seq'`.
    """
    his_tag = 'HHHHHH'
    # Each rule is (window inspected on the *untrimmed* chain sequence,
    # slice applied to the residue positions if the window is a His-tag).
    # The rules are applied cumulatively and in this order.
    trim_rules = (
        (slice(-6, None), slice(None, -6)),
        (slice(0, 6), slice(6, None)),
        (slice(-7, -1), slice(None, -7)),
        (slice(-8, -2), slice(None, -8)),
        (slice(-9, -3), slice(None, -9)),
        (slice(-10, -4), slice(None, -10)),
        (slice(1, 7), slice(7, None)),
        (slice(2, 8), slice(8, None)),
        (slice(3, 9), slice(9, None)),
        (slice(4, 10), slice(10, None)),
    )

    init_alphabet = list(
        'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')
    extra_alphabet = [str(i) for i in range(300)]
    chain_alphabet = init_alphabet + extra_alphabet

    # Per-residue letters; built once instead of twice per chain.
    seq_array = np.array(list(t['seq']))

    my_dict: Dict[str, Union[str, int, Dict[str, Any], List[Any]]] = {}
    concat_seq = ''
    mask_list = []
    visible_list = []
    for idx in np.unique(t['idx']):
        letter = chain_alphabet[idx]
        # The code below indexes `res` as `[1, num_residues]`
        # (`res[:, ...]`, `res.shape[1]`) - presumably the layout
        # `np.argwhere` yields for a `torch.Tensor` mask; TODO confirm.
        res = np.argwhere(t['idx'] == idx)
        initial_sequence = "".join(seq_array[res][0])

        for window, keep in trim_rules:
            if initial_sequence[window] == his_tag:
                res = res[:, keep]

        if res.shape[1] >= 4:
            chain_seq = "".join(seq_array[res][0])
            my_dict[f'seq_chain_{letter}'] = chain_seq
            concat_seq += chain_seq
            if idx in t['masked']:
                mask_list.append(letter)
            else:
                visible_list.append(letter)
            coords_dict_chain = {}
            all_atoms = np.array(t['xyz'][res])[0]  # [L, 14, 3]
            # Only the first four atoms per residue are kept.
            for i, c in enumerate(['N', 'CA', 'C', 'O']):
                coords_dict_chain[
                    f'{c}_chain_{letter}'] = all_atoms[:, i, :].tolist()
            my_dict[f'coords_chain_{letter}'] = coords_dict_chain

    my_dict['name'] = t['label']
    my_dict['masked_list'] = mask_list
    my_dict['visible_list'] = visible_list
    my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
    my_dict['seq'] = concat_seq
    return my_dict
373
+
374
def _process_pdb3(
    self, b: Dict[str, Any]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor, torch.Tensor]:
    r"""Converts the chain dictionary of :meth:`_process_pdb2` into tensors.

    Visible chains whose sequence equals the sequence of a masked chain are
    treated as masked as well. The chains are then concatenated in random
    order. The input dictionary is not modified.

    Args:
        b (Dict[str, Any]): Output of :meth:`_process_pdb2`.

    Returns:
        The tuple :obj:`(x, chain_seq_label, mask, chain_mask, residue_idx,
        chain_encoding)`, where :obj:`x` has shape :obj:`[L, 4, 3]` and all
        other tensors have shape :obj:`[L]`. :obj:`mask` is :obj:`0` for
        residues with non-finite coordinates, :obj:`chain_mask` is
        :obj:`1` for residues of masked chains, :obj:`residue_idx` jumps by
        :obj:`100` between consecutive chains, and :obj:`chain_encoding`
        holds the one-based position of the chain.
    """
    L = len(b['seq'])
    # residue idx with jumps across chains
    residue_idx = -100 * np.ones([L], dtype=np.int32)

    # get the list of masked / visible chains; copies are used so that the
    # duplicate handling below does not alter the caller's lists
    masked_chains = list(b['masked_list'])
    visible_chains = list(b['visible_list'])

    visible_temp_dict, masked_temp_dict = {}, {}
    for letter in masked_chains + visible_chains:
        chain_seq = b[f'seq_chain_{letter}']
        if letter in visible_chains:
            visible_temp_dict[letter] = chain_seq
        elif letter in masked_chains:
            masked_temp_dict[letter] = chain_seq

    # check for duplicate chains (same sequence but different identity)
    for masked_seq in masked_temp_dict.values():
        for visible_letter, visible_seq in visible_temp_dict.items():
            if masked_seq != visible_seq:
                continue
            if visible_letter not in masked_chains:
                masked_chains.append(visible_letter)
            if visible_letter in visible_chains:
                visible_chains.remove(visible_letter)

    # build protein data structures
    all_chains = masked_chains + visible_chains
    np.random.shuffle(all_chains)

    atom_names = ('N', 'CA', 'C', 'O')
    x_chain_list = []
    chain_mask_list = []
    chain_seq_list = []
    chain_encoding_list = []
    start = 0
    for chain_no, letter in enumerate(all_chains, start=1):
        chain_seq = b[f'seq_chain_{letter}']
        chain_length = len(chain_seq)
        chain_coords = b[f'coords_chain_{letter}']
        x_chain = np.stack(
            [chain_coords[f'{atom}_chain_{letter}'] for atom in atom_names],
            1)  # [chain_length, 4, 3]
        x_chain_list.append(x_chain)
        chain_seq_list.append(chain_seq)

        if letter in visible_chains:
            chain_mask = np.zeros(chain_length)  # 0 for visible chains
        elif letter in masked_chains:
            chain_mask = np.ones(chain_length)  # 1 for masked chains
        chain_mask_list.append(chain_mask)
        chain_encoding_list.append(chain_no * np.ones(chain_length))

        stop = start + chain_length
        residue_idx[start:stop] = (100 * (chain_no - 1) +
                                   np.arange(start, stop))
        start = stop

    x_chain_all = np.concatenate(x_chain_list, 0)  # [L, 4, 3]
    chain_seq_all = "".join(chain_seq_list)
    # [L,] 1.0 for places that need to be predicted
    chain_mask_all = np.concatenate(chain_mask_list, 0)
    chain_encoding_all = np.concatenate(chain_encoding_list, 0)

    # Convert to labels
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    chain_seq_label_all = np.asarray(
        [alphabet.index(a) for a in chain_seq_all], dtype=np.int32)

    # Residues with missing coordinates are flagged in `mask` and their
    # NaN entries are zeroed out afterwards.
    isnan = np.isnan(x_chain_all)
    mask = np.isfinite(np.sum(x_chain_all, (1, 2))).astype(np.float32)
    x_chain_all[isnan] = 0.

    # Conversion
    return (
        torch.from_numpy(x_chain_all).to(dtype=torch.float32),
        torch.from_numpy(chain_seq_label_all).to(dtype=torch.long),
        torch.from_numpy(mask).to(dtype=torch.float32),
        torch.from_numpy(chain_mask_all).to(dtype=torch.float32),
        torch.from_numpy(residue_idx).to(dtype=torch.long),
        torch.from_numpy(chain_encoding_all).to(dtype=torch.long),
    )
@@ -84,7 +84,7 @@ class QM7b(InMemoryDataset):
84
84
  edge_attr = coulomb_matrix[i, edge_index[0], edge_index[1]]
85
85
  y = target[i].view(1, -1)
86
86
  data = Data(edge_index=edge_index, edge_attr=edge_attr, y=y)
87
- data.num_nodes = edge_index.max().item() + 1
87
+ data.num_nodes = int(edge_index.max()) + 1
88
88
  data_list.append(data)
89
89
 
90
90
  if self.pre_filter is not None:
@@ -13,6 +13,7 @@ from torch_geometric.data import (
13
13
  download_url,
14
14
  extract_zip,
15
15
  )
16
+ from torch_geometric.io import fs
16
17
  from torch_geometric.utils import one_hot, scatter
17
18
 
18
19
  HAR2EV = 27.211386246
@@ -201,7 +202,7 @@ class QM9(InMemoryDataset):
201
202
  from rdkit import Chem, RDLogger
202
203
  from rdkit.Chem.rdchem import BondType as BT
203
204
  from rdkit.Chem.rdchem import HybridizationType
204
- RDLogger.DisableLog('rdApp.*') # type: ignore
205
+ RDLogger.DisableLog('rdApp.*') # type: ignore[attr-defined]
205
206
  WITH_RDKIT = True
206
207
 
207
208
  except ImportError:
@@ -212,7 +213,7 @@ class QM9(InMemoryDataset):
212
213
  "install 'rdkit' to alternatively process the raw data."),
213
214
  file=sys.stderr)
214
215
 
215
- data_list = torch.load(self.raw_paths[0])
216
+ data_list = fs.torch_load(self.raw_paths[0])
216
217
  data_list = [Data(**data_dict) for data_dict in data_list]
217
218
 
218
219
  if self.pre_filter is not None:
@@ -6,7 +6,7 @@ from typing import Callable, List, Optional
6
6
  import torch
7
7
 
8
8
  from torch_geometric.data import InMemoryDataset, download_url, extract_zip
9
- from torch_geometric.io import read_off, read_txt_array
9
+ from torch_geometric.io import fs, read_off, read_txt_array
10
10
 
11
11
 
12
12
  class SHREC2016(InMemoryDataset):
@@ -79,7 +79,7 @@ class SHREC2016(InMemoryDataset):
79
79
  self.cat = category.lower()
80
80
  super().__init__(root, transform, pre_transform, pre_filter,
81
81
  force_reload=force_reload)
82
- self.__ref__ = torch.load(self.processed_paths[0])
82
+ self.__ref__ = fs.torch_load(self.processed_paths[0])
83
83
  path = self.processed_paths[1] if train else self.processed_paths[2]
84
84
  self.load(path)
85
85
 
@@ -109,7 +109,7 @@ def read_ego(files: List[str], name: str) -> List[EgoData]:
109
109
  row = torch.cat([row, row_ego, col_ego], dim=0)
110
110
  col = torch.cat([col, col_ego, row_ego], dim=0)
111
111
  edge_index = torch.stack([row, col], dim=0)
112
- edge_index = coalesce(edge_index, num_nodes=N)
112
+ edge_index = coalesce(edge_index, num_nodes=int(N))
113
113
 
114
114
  data = EgoData(x=x, edge_index=edge_index, circle=circle,
115
115
  circle_batch=circle_batch)
@@ -129,7 +129,7 @@ def read_soc(files: List[str], name: str) -> List[Data]:
129
129
  edge_index = pd.read_csv(files[0], sep='\t', header=None,
130
130
  skiprows=skiprows, dtype=np.int64)
131
131
  edge_index = torch.from_numpy(edge_index.values).t()
132
- num_nodes = edge_index.max().item() + 1
132
+ num_nodes = int(edge_index.max()) + 1
133
133
  edge_index = coalesce(edge_index, num_nodes=num_nodes)
134
134
 
135
135
  return [Data(edge_index=edge_index, num_nodes=num_nodes)]
@@ -143,11 +143,15 @@ def read_wiki(files: List[str], name: str) -> List[Data]:
143
143
  edge_index = torch.from_numpy(edge_index.values).t()
144
144
 
145
145
  idx = torch.unique(edge_index.flatten())
146
- idx_assoc = torch.full((edge_index.max() + 1, ), -1, dtype=torch.long)
146
+ idx_assoc = torch.full(
147
+ (edge_index.max() + 1, ), # type: ignore
148
+ -1,
149
+ dtype=torch.long,
150
+ )
147
151
  idx_assoc[idx] = torch.arange(idx.size(0))
148
152
 
149
153
  edge_index = idx_assoc[edge_index]
150
- num_nodes = edge_index.max().item() + 1
154
+ num_nodes = int(edge_index.max()) + 1
151
155
  edge_index = coalesce(edge_index, num_nodes=num_nodes)
152
156
 
153
157
  return [Data(edge_index=edge_index, num_nodes=num_nodes)]