pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/datasets/protein_mpnn_dataset.py
ADDED
@@ -0,0 +1,451 @@
+import os
+import pickle
+import random
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from torch_geometric.data import (
+    Data,
+    InMemoryDataset,
+    download_url,
+    extract_tar,
+)
+
+
+class ProteinMPNNDataset(InMemoryDataset):
+    r"""The ProteinMPNN dataset from the `"Robust deep learning based protein
+    sequence design using ProteinMPNN"
+    <https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1>`_ paper.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        size (str): Size of the PDB information to train the model.
+            If :obj:`"small"`, loads the small dataset (229.4 MB).
+            If :obj:`"large"`, loads the large dataset (64.1 GB).
+            (default: :obj:`"small"`)
+        split (str, optional): If :obj:`"train"`, loads the training dataset.
+            If :obj:`"valid"`, loads the validation dataset.
+            If :obj:`"test"`, loads the test dataset.
+            (default: :obj:`"train"`)
+        datacut (str, optional): Date cutoff to filter the dataset.
+            (default: :obj:`"2030-01-01"`)
+        rescut (float, optional): PDB resolution cutoff.
+            (default: :obj:`3.5`)
+        homo (float, optional): Homology cutoff.
+            (default: :obj:`0.70`)
+        max_length (int, optional): Maximum length of the protein complex.
+            (default: :obj:`10000`)
+        num_units (int, optional): Number of units of the protein complex.
+            (default: :obj:`150`)
+        transform (callable, optional): A function/transform that takes in an
+            :obj:`torch_geometric.data.Data` object and returns a transformed
+            version. The data object will be transformed before every access.
+            (default: :obj:`None`)
+        pre_transform (callable, optional): A function/transform that takes in
+            an :obj:`torch_geometric.data.Data` object and returns a
+            transformed version. The data object will be transformed before
+            being saved to disk. (default: :obj:`None`)
+        pre_filter (callable, optional): A function that takes in an
+            :obj:`torch_geometric.data.Data` object and returns a boolean
+            value, indicating whether the data object should be included in the
+            final dataset. (default: :obj:`None`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+    """
+
+    raw_url = {
+        'small':
+        'https://files.ipd.uw.edu/pub/training_sets/'
+        'pdb_2021aug02_sample.tar.gz',
+        'large':
+        'https://files.ipd.uw.edu/pub/training_sets/'
+        'pdb_2021aug02.tar.gz',
+    }
+
+    splits = {
+        'train': 1,
+        'valid': 2,
+        'test': 3,
+    }
+
+    def __init__(
+        self,
+        root: str,
+        size: str = 'small',
+        split: str = 'train',
+        datacut: str = '2030-01-01',
+        rescut: float = 3.5,
+        homo: float = 0.70,
+        max_length: int = 10000,
+        num_units: int = 150,
+        transform: Optional[Callable] = None,
+        pre_transform: Optional[Callable] = None,
+        pre_filter: Optional[Callable] = None,
+        force_reload: bool = False,
+    ) -> None:
+        self.size = size
+        self.split = split
+        self.datacut = datacut
+        self.rescut = rescut
+        self.homo = homo
+        self.max_length = max_length
+        self.num_units = num_units
+
+        self.sub_folder = self.raw_url[self.size].split('/')[-1].split('.')[0]
+
+        super().__init__(root, transform, pre_transform, pre_filter,
+                         force_reload=force_reload)
+        self.load(self.processed_paths[self.splits[self.split]])
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        return [
+            f'{self.sub_folder}/{f}'
+            for f in ['list.csv', 'valid_clusters.txt', 'test_clusters.txt']
+        ]
+
+    @property
+    def processed_file_names(self) -> List[str]:
+        return ['splits.pkl', 'train.pt', 'valid.pt', 'test.pt']
+
+    def download(self) -> None:
+        file_path = download_url(self.raw_url[self.size], self.raw_dir)
+        extract_tar(file_path, self.raw_dir)
+        os.unlink(file_path)
+
+    def process(self) -> None:
+        alphabet_set = set(list('ACDEFGHIKLMNPQRSTVWYX'))
+        cluster_ids = self._process_split()
+        total_items = sum(len(items) for items in cluster_ids.values())
+        data_list = []
+
+        with tqdm(total=total_items, desc="Processing") as pbar:
+            for _, items in cluster_ids.items():
+                for chain_id, _ in items:
+                    item = self._process_pdb1(chain_id)
+
+                    if 'label' not in item:
+                        pbar.update(1)
+                        continue
+                    if len(list(np.unique(item['idx']))) >= 352:
+                        pbar.update(1)
+                        continue
+
+                    my_dict = self._process_pdb2(item)
+
+                    if len(my_dict['seq']) > self.max_length:
+                        pbar.update(1)
+                        continue
+                    bad_chars = set(list(
+                        my_dict['seq'])).difference(alphabet_set)
+                    if len(bad_chars) > 0:
+                        pbar.update(1)
+                        continue
+
+                    x_chain_all, chain_seq_label_all, mask, chain_mask_all, residue_idx, chain_encoding_all = self._process_pdb3(  # noqa: E501
+                        my_dict)
+
+                    data = Data(
+                        x=x_chain_all,  # [seq_len, 4, 3]
+                        chain_seq_label=chain_seq_label_all,  # [seq_len]
+                        mask=mask,  # [seq_len]
+                        chain_mask_all=chain_mask_all,  # [seq_len]
+                        residue_idx=residue_idx,  # [seq_len]
+                        chain_encoding_all=chain_encoding_all,  # [seq_len]
+                    )
+
+                    if self.pre_filter is not None and not self.pre_filter(
+                            data):
+                        continue
+                    if self.pre_transform is not None:
+                        data = self.pre_transform(data)
+
+                    data_list.append(data)
+
+                    if len(data_list) >= self.num_units:
+                        pbar.update(total_items - pbar.n)
+                        break
+                    pbar.update(1)
+                else:
+                    continue
+                break
+        self.save(data_list, self.processed_paths[self.splits[self.split]])
+
+    def _process_split(self) -> Dict[int, List[Tuple[str, int]]]:
+        import pandas as pd
+        save_path = self.processed_paths[0]
+
+        if os.path.exists(save_path):
+            print('Load split')
+            with open(save_path, 'rb') as f:
+                data = pickle.load(f)
+        else:
+            # CHAINID, DEPOSITION, RESOLUTION, HASH, CLUSTER, SEQUENCE
+            df = pd.read_csv(self.raw_paths[0])
+            df = df[(df['RESOLUTION'] <= self.rescut)
+                    & (df['DEPOSITION'] <= self.datacut)]
+
+            val_ids = pd.read_csv(self.raw_paths[1], header=None)[0].tolist()
+            test_ids = pd.read_csv(self.raw_paths[2], header=None)[0].tolist()
+
+            # compile training and validation sets
+            data = {
+                'train': defaultdict(list),
+                'valid': defaultdict(list),
+                'test': defaultdict(list),
+            }
+
+            for _, r in tqdm(df.iterrows(), desc='Processing split',
+                             total=len(df)):
+                cluster_id = r['CLUSTER']
+                hash_id = r['HASH']
+                chain_id = r['CHAINID']
+                if cluster_id in val_ids:
+                    data['valid'][cluster_id].append((chain_id, hash_id))
+                elif cluster_id in test_ids:
+                    data['test'][cluster_id].append((chain_id, hash_id))
+                else:
+                    data['train'][cluster_id].append((chain_id, hash_id))
+
+            with open(save_path, 'wb') as f:
+                pickle.dump(data, f)
+
+        return data[self.split]
+
+    def _process_pdb1(self, chain_id: str) -> Dict[str, Any]:
+        pdbid, chid = chain_id.split('_')
+        prefix = f'{self.raw_dir}/{self.sub_folder}/pdb/{pdbid[1:3]}/{pdbid}'
+        # load metadata
+        if not os.path.isfile(f'{prefix}.pt'):
+            return {'seq': np.zeros(5)}
+        meta = torch.load(f'{prefix}.pt')
+        asmb_ids = meta['asmb_ids']
+        asmb_chains = meta['asmb_chains']
+        chids = np.array(meta['chains'])
+
+        # find candidate assemblies which contain chid chain
+        asmb_candidates = {
+            a
+            for a, b in zip(asmb_ids, asmb_chains) if chid in b.split(',')
+        }
+
+        # if the chain is missing from all the assemblies,
+        # then return this chain alone
+        if len(asmb_candidates) < 1:
+            chain = torch.load(f'{prefix}_{chid}.pt')
+            L = len(chain['seq'])
+            return {
+                'seq': chain['seq'],
+                'xyz': chain['xyz'],
+                'idx': torch.zeros(L).int(),
+                'masked': torch.Tensor([0]).int(),
+                'label': chain_id,
+            }
+
+        # randomly pick one assembly from candidates
+        asmb_i = random.sample(list(asmb_candidates), 1)
+
+        # indices of selected transforms
+        idx = np.where(np.array(asmb_ids) == asmb_i)[0]
+
+        # load relevant chains
+        chains = {
+            c: torch.load(f'{prefix}_{c}.pt')
+            for i in idx
+            for c in asmb_chains[i] if c in meta['chains']
+        }
+
+        # generate assembly
+        asmb = {}
+        for k in idx:
+
+            # pick k-th xform
+            xform = meta[f'asmb_xform{k}']
+            u = xform[:, :3, :3]
+            r = xform[:, :3, 3]
+
+            # select chains which k-th xform should be applied to
+            s1 = set(meta['chains'])
+            s2 = set(asmb_chains[k].split(','))
+            chains_k = s1 & s2
+
+            # transform selected chains
+            for c in chains_k:
+                try:
+                    xyz = chains[c]['xyz']
+                    xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None,
+                                                                       None, :]
+                    asmb.update({
+                        (c, k, i): xyz_i
+                        for i, xyz_i in enumerate(xyz_ru)
+                    })
+                except KeyError:
+                    return {'seq': np.zeros(5)}
+
+        # select chains which share considerable similarity to chid
+        seqid = meta['tm'][chids == chid][0, :, 1]
+        homo = {
+            ch_j
+            for seqid_j, ch_j in zip(seqid, chids) if seqid_j > self.homo
+        }
+        # stack all chains in the assembly together
+        seq: str = ''
+        xyz_all: List[torch.Tensor] = []
+        idx_all: List[torch.Tensor] = []
+        masked: List[int] = []
+        seq_list = []
+        for counter, (k, v) in enumerate(asmb.items()):
+            seq += chains[k[0]]['seq']
+            seq_list.append(chains[k[0]]['seq'])
+            xyz_all.append(v)
+            idx_all.append(torch.full((v.shape[0], ), counter))
+            if k[0] in homo:
+                masked.append(counter)
+
+        return {
+            'seq': seq,
+            'xyz': torch.cat(xyz_all, dim=0),
+            'idx': torch.cat(idx_all, dim=0),
+            'masked': torch.Tensor(masked).int(),
+            'label': chain_id,
+        }
+
+    def _process_pdb2(self, t: Dict[str, Any]) -> Dict[str, Any]:
+        init_alphabet = list(
+            'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')
+        extra_alphabet = [str(item) for item in list(np.arange(300))]
+        chain_alphabet = init_alphabet + extra_alphabet
+        my_dict: Dict[str, Union[str, int, Dict[str, Any], List[Any]]] = {}
+        concat_seq = ''
+        mask_list = []
+        visible_list = []
+        for idx in list(np.unique(t['idx'])):
+            letter = chain_alphabet[idx]
+            res = np.argwhere(t['idx'] == idx)
+            initial_sequence = "".join(list(
+                np.array(list(t['seq']))[res][
+                    0,
+                ]))
+            if initial_sequence[-6:] == "HHHHHH":
+                res = res[:, :-6]
+            if initial_sequence[0:6] == "HHHHHH":
+                res = res[:, 6:]
+            if initial_sequence[-7:-1] == "HHHHHH":
+                res = res[:, :-7]
+            if initial_sequence[-8:-2] == "HHHHHH":
+                res = res[:, :-8]
+            if initial_sequence[-9:-3] == "HHHHHH":
+                res = res[:, :-9]
+            if initial_sequence[-10:-4] == "HHHHHH":
+                res = res[:, :-10]
+            if initial_sequence[1:7] == "HHHHHH":
+                res = res[:, 7:]
+            if initial_sequence[2:8] == "HHHHHH":
+                res = res[:, 8:]
+            if initial_sequence[3:9] == "HHHHHH":
+                res = res[:, 9:]
+            if initial_sequence[4:10] == "HHHHHH":
+                res = res[:, 10:]
+            if res.shape[1] >= 4:
+                chain_seq = "".join(list(np.array(list(t['seq']))[res][0]))
+                my_dict[f'seq_chain_{letter}'] = chain_seq
+                concat_seq += chain_seq
+                if idx in t['masked']:
+                    mask_list.append(letter)
+                else:
+                    visible_list.append(letter)
+                coords_dict_chain = {}
+                all_atoms = np.array(t['xyz'][res])[0]  # [L, 14, 3]
+                for i, c in enumerate(['N', 'CA', 'C', 'O']):
+                    coords_dict_chain[
+                        f'{c}_chain_{letter}'] = all_atoms[:, i, :].tolist()
+                my_dict[f'coords_chain_{letter}'] = coords_dict_chain
+        my_dict['name'] = t['label']
+        my_dict['masked_list'] = mask_list
+        my_dict['visible_list'] = visible_list
+        my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
+        my_dict['seq'] = concat_seq
+        return my_dict
+
+    def _process_pdb3(
+        self, b: Dict[str, Any]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
+               torch.Tensor, torch.Tensor]:
+        L = len(b['seq'])
+        # residue idx with jumps across chains
+        residue_idx = -100 * np.ones([L], dtype=np.int32)
+        # get the list of masked / visible chains
+        masked_chains, visible_chains = b['masked_list'], b['visible_list']
+        visible_temp_dict, masked_temp_dict = {}, {}
+        for letter in masked_chains + visible_chains:
+            chain_seq = b[f'seq_chain_{letter}']
+            if letter in visible_chains:
+                visible_temp_dict[letter] = chain_seq
+            elif letter in masked_chains:
+                masked_temp_dict[letter] = chain_seq
+        # check for duplicate chains (same sequence but different identity)
+        for _, vm in masked_temp_dict.items():
+            for kv, vv in visible_temp_dict.items():
+                if vm == vv:
+                    if kv not in masked_chains:
+                        masked_chains.append(kv)
+                    if kv in visible_chains:
+                        visible_chains.remove(kv)
+        # build protein data structures
+        all_chains = masked_chains + visible_chains
+        np.random.shuffle(all_chains)
+        x_chain_list = []
+        chain_mask_list = []
+        chain_seq_list = []
+        chain_encoding_list = []
+        c, l0, l1 = 1, 0, 0
+        for letter in all_chains:
+            chain_seq = b[f'seq_chain_{letter}']
+            chain_length = len(chain_seq)
+            chain_coords = b[f'coords_chain_{letter}']
+            x_chain = np.stack([
+                chain_coords[c] for c in [
+                    f'N_chain_{letter}', f'CA_chain_{letter}',
+                    f'C_chain_{letter}', f'O_chain_{letter}'
+                ]
+            ], 1)  # [chain_length, 4, 3]
+            x_chain_list.append(x_chain)
+            chain_seq_list.append(chain_seq)
+            if letter in visible_chains:
+                chain_mask = np.zeros(chain_length)  # 0 for visible chains
+            elif letter in masked_chains:
+                chain_mask = np.ones(chain_length)  # 1 for masked chains
+            chain_mask_list.append(chain_mask)
+            chain_encoding_list.append(c * np.ones(chain_length))
+            l1 += chain_length
+            residue_idx[l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
+            l0 += chain_length
+            c += 1
+        x_chain_all = np.concatenate(x_chain_list, 0)  # [L, 4, 3]
+        chain_seq_all = "".join(chain_seq_list)
+        # [L,] 1.0 for places that need to be predicted
+        chain_mask_all = np.concatenate(chain_mask_list, 0)
+        chain_encoding_all = np.concatenate(chain_encoding_list, 0)
+
+        # Convert to labels
+        alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
+        chain_seq_label_all = np.asarray(
+            [alphabet.index(a) for a in chain_seq_all], dtype=np.int32)
+
+        isnan = np.isnan(x_chain_all)
+        mask = np.isfinite(np.sum(x_chain_all, (1, 2))).astype(np.float32)
+        x_chain_all[isnan] = 0.
+
+        # Conversion
+        return (
+            torch.from_numpy(x_chain_all).to(dtype=torch.float32),
+            torch.from_numpy(chain_seq_label_all).to(dtype=torch.long),
+            torch.from_numpy(mask).to(dtype=torch.float32),
+            torch.from_numpy(chain_mask_all).to(dtype=torch.float32),
+            torch.from_numpy(residue_idx).to(dtype=torch.long),
+            torch.from_numpy(chain_encoding_all).to(dtype=torch.long),
+        )
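
For orientation, here is a minimal usage sketch of the new class, derived only from the constructor signature and docstring above; the root path is illustrative, and it assumes the class is exported from torch_geometric.datasets, as the updated datasets/__init__.py in the file list suggests:

from torch_geometric.datasets import ProteinMPNNDataset

# First use downloads the 229.4 MB 'small' archive and caches the
# processed splits under <root>/processed/.
train_dataset = ProteinMPNNDataset(root='data/ProteinMPNN', size='small',
                                   split='train')
data = train_dataset[0]
# data.x:                backbone N/CA/C/O coordinates, [seq_len, 4, 3]
# data.chain_seq_label:  per-residue indices into 'ACDEFGHIKLMNPQRSTVWYX'
# data.mask:             1.0 where all backbone coordinates are finite
# data.residue_idx:      residue indices with +100 jumps between chains
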
torch_geometric/datasets/qm7.py
CHANGED
@@ -84,7 +84,7 @@ class QM7b(InMemoryDataset):
             edge_attr = coulomb_matrix[i, edge_index[0], edge_index[1]]
             y = target[i].view(1, -1)
             data = Data(edge_index=edge_index, edge_attr=edge_attr, y=y)
-            data.num_nodes = edge_index.max()
+            data.num_nodes = int(edge_index.max()) + 1
             data_list.append(data)

         if self.pre_filter is not None:
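
The qm7.py change fixes an off-by-one: node indices are zero-based, so the node count is the maximum index plus one, and the int(...) cast stores a plain Python integer rather than a zero-dimensional tensor. A standalone illustration of the distinction:

import torch

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
max_idx = edge_index.max()             # tensor(2): 0-dim tensor, one short
num_nodes = int(edge_index.max()) + 1  # 3: the graph has nodes 0, 1, and 2
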
torch_geometric/datasets/qm9.py
CHANGED
@@ -13,6 +13,7 @@ from torch_geometric.data import (
     download_url,
     extract_zip,
 )
+from torch_geometric.io import fs
 from torch_geometric.utils import one_hot, scatter

 HAR2EV = 27.211386246
@@ -201,7 +202,7 @@ class QM9(InMemoryDataset):
             from rdkit import Chem, RDLogger
             from rdkit.Chem.rdchem import BondType as BT
             from rdkit.Chem.rdchem import HybridizationType
-            RDLogger.DisableLog('rdApp.*')  # type: ignore
+            RDLogger.DisableLog('rdApp.*')  # type: ignore[attr-defined]
             WITH_RDKIT = True

         except ImportError:
@@ -212,7 +213,7 @@ class QM9(InMemoryDataset):
                    "install 'rdkit' to alternatively process the raw data."),
                   file=sys.stderr)

-            data_list = torch.load(self.raw_paths[0])
+            data_list = fs.torch_load(self.raw_paths[0])
             data_list = [Data(**data_dict) for data_dict in data_list]

             if self.pre_filter is not None:
torch_geometric/datasets/shrec2016.py
CHANGED
@@ -6,7 +6,7 @@ from typing import Callable, List, Optional
 import torch

 from torch_geometric.data import InMemoryDataset, download_url, extract_zip
-from torch_geometric.io import read_off, read_txt_array
+from torch_geometric.io import fs, read_off, read_txt_array


 class SHREC2016(InMemoryDataset):
@@ -79,7 +79,7 @@ class SHREC2016(InMemoryDataset):
         self.cat = category.lower()
         super().__init__(root, transform, pre_transform, pre_filter,
                          force_reload=force_reload)
-        self.__ref__ = torch.load(self.processed_paths[0])
+        self.__ref__ = fs.torch_load(self.processed_paths[0])
         path = self.processed_paths[1] if train else self.processed_paths[2]
         self.load(path)

torch_geometric/datasets/snap_dataset.py
CHANGED
@@ -109,7 +109,7 @@ def read_ego(files: List[str], name: str) -> List[EgoData]:
     row = torch.cat([row, row_ego, col_ego], dim=0)
     col = torch.cat([col, col_ego, row_ego], dim=0)
     edge_index = torch.stack([row, col], dim=0)
-    edge_index = coalesce(edge_index, num_nodes=N)
+    edge_index = coalesce(edge_index, num_nodes=int(N))

     data = EgoData(x=x, edge_index=edge_index, circle=circle,
                    circle_batch=circle_batch)
@@ -129,7 +129,7 @@ def read_soc(files: List[str], name: str) -> List[Data]:
     edge_index = pd.read_csv(files[0], sep='\t', header=None,
                              skiprows=skiprows, dtype=np.int64)
     edge_index = torch.from_numpy(edge_index.values).t()
-    num_nodes = edge_index.max()
+    num_nodes = int(edge_index.max()) + 1
     edge_index = coalesce(edge_index, num_nodes=num_nodes)

     return [Data(edge_index=edge_index, num_nodes=num_nodes)]
@@ -143,11 +143,15 @@ def read_wiki(files: List[str], name: str) -> List[Data]:
     edge_index = torch.from_numpy(edge_index.values).t()

     idx = torch.unique(edge_index.flatten())
-    idx_assoc = torch.full((edge_index.max() + 1, ), -1, dtype=torch.long)
+    idx_assoc = torch.full(
+        (edge_index.max() + 1, ),  # type: ignore
+        -1,
+        dtype=torch.long,
+    )
     idx_assoc[idx] = torch.arange(idx.size(0))

     edge_index = idx_assoc[edge_index]
-    num_nodes = edge_index.max()
+    num_nodes = int(edge_index.max()) + 1
     edge_index = coalesce(edge_index, num_nodes=num_nodes)

     return [Data(edge_index=edge_index, num_nodes=num_nodes)]