pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pyg-nightly might be problematic.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
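The listing above also records the relocation of the NLP helpers out of torch_geometric.nn: g_retriever.py, llm.py and sentence_transformer.py now live under the new torch_geometric.llm package, while torch_geometric/nn/nlp/ shrinks to nothing. A hedged sketch of what that move looks like for downstream imports; the exact public re-exports of llm/models/__init__.py are assumed here, not verified:

# Old import path in the 2.7.0 nightlies (files now moved per the listing above):
#   from torch_geometric.nn.nlp import LLM, SentenceTransformer
# New location, assuming the moved modules are re-exported from llm/models/__init__.py:
from torch_geometric.llm.models import LLM, SentenceTransformer  # assumed re-export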
torch_geometric/datasets/amazon.py
@@ -15,19 +15,16 @@ class Amazon(InMemoryDataset):
     map goods to their respective product category.

     Args:
-        root (str): Root directory where the dataset should be saved.
-        name (str): The name of the dataset (:obj:`"Computers"`,
-            :obj:`"Photo"`).
-        transform (callable, optional): A function/transform that takes in an
-            :obj:`torch_geometric.data.Data` object and returns a transformed
+        root: Root directory where the dataset should be saved.
+        name: The name of the dataset (:obj:`"Computers"`, :obj:`"Photo"`).
+        transform: A function/transform that takes in a
+            :class:`torch_geometric.data.Data` object and returns a transformed
             version. The data object will be transformed before every access.
-            (default: :obj:`None`)
-        pre_transform (callable, optional): A function/transform that takes in
-            an :obj:`torch_geometric.data.Data` object and returns a
+        pre_transform: A function/transform that takes in an
+            :class:`torch_geometric.data.Data` object and returns a
             transformed version. The data object will be transformed before
-            being saved to disk. (default: :obj:`None`)
-        force_reload (bool, optional): Whether to re-process the dataset.
-            (default: :obj:`False`)
+            being saved to disk.
+        force_reload: Whether to re-process the dataset.

     **STATS:**

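The Amazon hunk above (like the other docstring hunks in this diff) only rewords documentation; the constructor itself is untouched. For reference, a minimal usage sketch of the arguments being documented, with an illustrative root path and transform choice:

import torch_geometric.transforms as T
from torch_geometric.datasets import Amazon

# "name" selects one of the two co-purchase graphs documented above.
dataset = Amazon(root='/tmp/Amazon', name='Computers',
                 transform=T.NormalizeFeatures())
data = dataset[0]  # a single Data object with x, edge_index and y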
torch_geometric/datasets/amazon_book.py
@@ -14,17 +14,16 @@ class AmazonBook(InMemoryDataset):
     No labels or features are provided.

     Args:
-        root (str): Root directory where the dataset should be saved.
-        transform (callable, optional): A function/transform that takes in an
-            :obj:`torch_geometric.data.HeteroData` object and returns a
+        root: Root directory where the dataset should be saved.
+        transform: A function/transform that takes in an
+            :class:`torch_geometric.data.HeteroData` object and returns a
             transformed version. The data object will be transformed before
-            every access. (default: :obj:`None`)
-        pre_transform (callable, optional): A function/transform that takes in
-            an :obj:`torch_geometric.data.HeteroData` object and returns a
+            every access.
+        pre_transform: A function/transform that takes in an
+            :class:`torch_geometric.data.HeteroData` object and returns a
             transformed version. The data object will be transformed before
-            being saved to disk. (default: :obj:`None`)
-        force_reload (bool, optional): Whether to re-process the dataset.
-            (default: :obj:`False`)
+            being saved to disk.
+        force_reload: Whether to re-process the dataset.
     """
     url = ('https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/'
            'master/data/amazon-book')
torch_geometric/datasets/amazon_products.py
@@ -14,17 +14,15 @@ class AmazonProducts(InMemoryDataset):
     containing products and its categories.

     Args:
-        root (str): Root directory where the dataset should be saved.
-        transform (callable, optional): A function/transform that takes in an
-            :obj:`torch_geometric.data.Data` object and returns a transformed
+        root: Root directory where the dataset should be saved.
+        transform: A function/transform that takes in an
+            :class:`torch_geometric.data.Data` object and returns a transformed
             version. The data object will be transformed before every access.
-            (default: :obj:`None`)
-        pre_transform (callable, optional): A function/transform that takes in
-            an :obj:`torch_geometric.data.Data` object and returns a
+        pre_transform: A function/transform that takes in a
+            :class:`torch_geometric.data.Data` object and returns a
             transformed version. The data object will be transformed before
-            being saved to disk. (default: :obj:`None`)
-        force_reload (bool, optional): Whether to re-process the dataset.
-            (default: :obj:`False`)
+            being saved to disk.
+        force_reload: Whether to re-process the dataset.

     **STATS:**

torch_geometric/datasets/aminer.py
@@ -24,17 +24,16 @@ class AMiner(InMemoryDataset):
     truth labels for a subset of nodes.

     Args:
-        root (str): Root directory where the dataset should be saved.
-        transform (callable, optional): A function/transform that takes in an
-            :obj:`torch_geometric.data.HeteroData` object and returns a
+        root: Root directory where the dataset should be saved.
+        transform: A function/transform that takes in a
+            :class:`torch_geometric.data.HeteroData` object and returns a
             transformed version. The data object will be transformed before
-            every access. (default: :obj:`None`)
-        pre_transform (callable, optional): A function/transform that takes in
-            an :obj:`torch_geometric.data.HeteroData` object and returns a
+            every access.
+        pre_transform: A function/transform that takes in a
+            :class:`torch_geometric.data.HeteroData` object and returns a
             transformed version. The data object will be transformed before
-            being saved to disk. (default: :obj:`None`)
-        force_reload (bool, optional): Whether to re-process the dataset.
-            (default: :obj:`False`)
+            being saved to disk.
+        force_reload: Whether to re-process the dataset.
     """

     url = 'https://www.dropbox.com/s/1bnz8r7mofx0osf/net_aminer.zip?dl=1'
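AmazonBook and AMiner both yield heterogeneous graphs, which is why their docstrings now reference :class:`torch_geometric.data.HeteroData`. A small sketch of inspecting the result; the exact node and edge type names are not asserted here:

from torch_geometric.datasets import AMiner

dataset = AMiner(root='/tmp/AMiner')
data = dataset[0]          # a HeteroData object
print(data.node_types)     # node types such as author/venue/paper
print(data.edge_types)     # relations between those node types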
torch_geometric/datasets/aqsol.py
@@ -30,25 +30,22 @@ class AQSOL(InMemoryDataset):
     the :class:`~torch_geometric.datasets.ZINC` dataset.

     Args:
-        root (str): Root directory where the dataset should be saved.
-        split (str, optional): If :obj:`"train"`, loads the training dataset.
+        root: Root directory where the dataset should be saved.
+        split: If :obj:`"train"`, loads the training dataset.
             If :obj:`"val"`, loads the validation dataset.
             If :obj:`"test"`, loads the test dataset.
-            (default: :obj:`"train"`)
-        transform (callable, optional): A function/transform that takes in an
-            :obj:`torch_geometric.data.Data` object and returns a transformed
+        transform: A function/transform that takes in a
+            :class:`torch_geometric.data.Data` object and returns a transformed
             version. The data object will be transformed before every access.
-            (default: :obj:`None`)
-        pre_transform (callable, optional): A function/transform that takes in
-            an :obj:`torch_geometric.data.Data` object and returns a
+        pre_transform: A function/transform that takes in a
+            :class:`torch_geometric.data.Data` object and returns a
             transformed version. The data object will be transformed before
-            being saved to disk. (default: :obj:`None`)
+            being saved to disk.
         pre_filter (callable, optional): A function that takes in an
-            :obj:`torch_geometric.data.Data` object and returns a boolean
+            :class:`torch_geometric.data.Data` object and returns a boolean
             value, indicating whether the data object should be included in
-            the final dataset. (default: :obj:`None`)
-        force_reload (bool, optional): Whether to re-process the dataset.
-            (default: :obj:`False`)
+            the final dataset.
+        force_reload: Whether to re-process the dataset.

     **STATS:**

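AQSOL is the only dataset in this batch whose docstring documents both a split argument and a pre_filter. A short sketch of how the two combine; the size threshold is purely illustrative:

from torch_geometric.datasets import AQSOL

def keep_small(data) -> bool:  # pre_filter: drop molecules with many atoms
    return data.num_nodes <= 30

train_dataset = AQSOL(root='/tmp/AQSOL', split='train', pre_filter=keep_small)
val_dataset = AQSOL(root='/tmp/AQSOL', split='val', pre_filter=keep_small)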
torch_geometric/datasets/attributed_graph_dataset.py
@@ -19,21 +19,19 @@ class AttributedGraphDataset(InMemoryDataset):
     <https://arxiv.org/abs/2009.00826>`_ paper.

     Args:
-        root (str): Root directory where the dataset should be saved.
-        name (str): The name of the dataset (:obj:`"Wiki"`, :obj:`"Cora"`,
+        root: Root directory where the dataset should be saved.
+        name: The name of the dataset (:obj:`"Wiki"`, :obj:`"Cora"`,
             :obj:`"CiteSeer"`, :obj:`"PubMed"`, :obj:`"BlogCatalog"`,
             :obj:`"PPI"`, :obj:`"Flickr"`, :obj:`"Facebook"`, :obj:`"Twitter"`,
             :obj:`"TWeibo"`, :obj:`"MAG"`).
-        transform (callable, optional): A function/transform that takes in an
-            :obj:`torch_geometric.data.Data` object and returns a transformed
+        transform: A function/transform that takes in a
+            :class:`torch_geometric.data.Data` object and returns a transformed
             version. The data object will be transformed before every access.
-            (default: :obj:`None`)
-        pre_transform (callable, optional): A function/transform that takes in
-            an :obj:`torch_geometric.data.Data` object and returns a
+        pre_transform: A function/transform that takes in a
+            :class:`torch_geometric.data.Data` object and returns a
             transformed version. The data object will be transformed before
-            being saved to disk. (default: :obj:`None`)
-        force_reload (bool, optional): Whether to re-process the dataset.
-            (default: :obj:`False`)
+            being saved to disk.
+        force_reload: Whether to re-process the dataset.

     **STATS:**

torch_geometric/datasets/ba_multi_shapes.py
@@ -25,21 +25,19 @@ class BAMultiShapesDataset(InMemoryDataset):
     This dataset is pre-computed from the official implementation.

     Args:
-        root (str): Root directory where the dataset should be saved.
-        transform (callable, optional): A function/transform that takes in an
-            :obj:`torch_geometric.data.Data` object and returns a transformed
+        root: Root directory where the dataset should be saved.
+        transform: A function/transform that takes in a
+            :class:`torch_geometric.data.Data` object and returns a transformed
             version. The data object will be transformed before every access.
-            (default: :obj:`None`)
-        pre_transform (callable, optional): A function/transform that takes in
-            an :obj:`torch_geometric.data.Data` object and returns a
+        pre_transform: A function/transform that takes in a
+            :class:`torch_geometric.data.Data` object and returns a
             transformed version. The data object will be transformed before
-            being saved to disk. (default: :obj:`None`)
-        pre_filter (callable, optional): A function that takes in an
-            :obj:`torch_geometric.data.Data` object and returns a boolean
+            being saved to disk.
+        pre_filter: A function that takes in a
+            :class:`torch_geometric.data.Data` object and returns a boolean
             value, indicating whether the data object should be included in the
-            final dataset. (default: :obj:`None`)
-        force_reload (bool, optional): Whether to re-process the dataset.
-            (default: :obj:`False`)
+            final dataset.
+        force_reload: Whether to re-process the dataset.

     **STATS:**

torch_geometric/datasets/ba_shapes.py
@@ -30,15 +30,14 @@ class BAShapes(InMemoryDataset):
     :class:`torch_geometric.datasets.graph_generator.BAGraph` instead.

     Args:
-        connection_distribution (str, optional): Specifies how the houses
-            and the BA graph get connected. Valid inputs are :obj:`"random"`
+        connection_distribution: Specifies how the houses and the BA graph get
+            connected. Valid inputs are :obj:`"random"`
             (random BA graph nodes are selected for connection to the houses),
             and :obj:`"uniform"` (uniformly distributed BA graph nodes are
-            selected for connection to the houses). (default: :obj:`"random"`)
-        transform (callable, optional): A function/transform that takes in an
-            :obj:`torch_geometric.data.Data` object and returns a transformed
+            selected for connection to the houses).
+        transform: A function/transform that takes in a
+            :class:`torch_geometric.data.Data` object and returns a transformed
             version. The data object will be transformed before every access.
-            (default: :obj:`None`)
     """
     def __init__(
         self,
torch_geometric/datasets/city.py (new file)
@@ -0,0 +1,157 @@
+import os.path as osp
+from typing import Callable, Optional
+
+from torch_geometric.data import (
+    Data,
+    InMemoryDataset,
+    download_url,
+    extract_tar,
+)
+from torch_geometric.io import fs
+
+
+class CityNetwork(InMemoryDataset):
+    r"""The City-Networks are introduced in
+    `"Towards Quantifying Long-Range Interactions in Graph Machine Learning:
+    a Large Graph Dataset and a Measurement"
+    <https://arxiv.org/abs/2503.09008>`_ paper.
+    The dataset contains four city networks: `paris`, `shanghai`, `la`,
+    and `london`, where nodes represent junctions and edges represent
+    undirected road segments. The task is to predict each node's eccentricity
+    score, which is approximated based on its 16-hop neighborhood and naturally
+    requires long-range information. The score indicates how accessible one
+    node is in the network, and is mapped to 10 quantiles for transductive
+    classification. See the original
+    `source code <https://github.com/LeonResearch/City-Networks>`_ for more
+    details on the individual networks.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        name (str): The name of the dataset (``"paris"``, ``"shanghai"``,
+            ``"la"``, ``"london"``).
+        augmented (bool, optional): Whether to use the augmented node features
+            from edge features.(default: :obj:`True`)
+        transform (callable, optional): A function/transform that takes in an
+            :class:`~torch_geometric.data.Data` object and returns a
+            transformed version. The data object will be transformed before
+            every access. (default: :obj:`None`)
+        pre_transform (callable, optional): A function/transform that takes in
+            an :class:`~torch_geometric.data.Data` object and returns a
+            transformed version. The data object will be transformed before
+            being saved to disk. (default: :obj:`None`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+
+    **STATS:**
+
+    .. list-table::
+        :widths: 10 10 10 10 10
+        :header-rows: 1
+
+        * - Name
+          - #nodes
+          - #edges
+          - #features
+          - #classes
+        * - paris
+          - 114,127
+          - 182,511
+          - 37
+          - 10
+        * - shanghai
+          - 183,917
+          - 262,092
+          - 37
+          - 10
+        * - la
+          - 240,587
+          - 341,523
+          - 37
+          - 10
+        * - london
+          - 568,795
+          - 756,502
+          - 37
+          - 10
+    """
+    url = "https://github.com/LeonResearch/City-Networks/raw/refs/heads/main/data/"  # noqa: E501
+
+    def __init__(
+        self,
+        root: str,
+        name: str,
+        augmented: bool = True,
+        transform: Optional[Callable] = None,
+        pre_transform: Optional[Callable] = None,
+        force_reload: bool = False,
+        delete_raw: bool = False,
+    ) -> None:
+        self.name = name.lower()
+        assert self.name in ["paris", "shanghai", "la", "london"]
+        self.augmented = augmented
+        self.delete_raw = delete_raw
+        super().__init__(
+            root,
+            transform,
+            pre_transform,
+            force_reload=force_reload,
+        )
+        self.load(self.processed_paths[0])
+
+    @property
+    def raw_dir(self) -> str:
+        return osp.join(self.root, self.name, "raw")
+
+    @property
+    def processed_dir(self) -> str:
+        return osp.join(self.root, self.name, "processed")
+
+    @property
+    def raw_file_names(self) -> str:
+        return f"{self.name}.json"
+
+    @property
+    def processed_file_names(self) -> str:
+        return "data.pt"
+
+    def download(self) -> None:
+        self.download_path = download_url(
+            self.url + f"{self.name}.tar.gz",
+            self.raw_dir,
+        )
+
+    def process(self) -> None:
+        extract_tar(self.download_path, self.raw_dir)
+        data_path = osp.join(self.raw_dir, self.name)
+        node_feat = fs.torch_load(
+            osp.join(
+                data_path,
+                f"node_features{'_augmented' if self.augmented else ''}.pt",
+            ))
+        edge_index = fs.torch_load(osp.join(data_path, "edge_indices.pt"))
+        label = fs.torch_load(
+            osp.join(data_path, "10-chunk_16-hop_node_labels.pt"))
+        train_mask = fs.torch_load(osp.join(data_path, "train_mask.pt"))
+        val_mask = fs.torch_load(osp.join(data_path, "valid_mask.pt"))
+        test_mask = fs.torch_load(osp.join(data_path, "test_mask.pt"))
+        data = Data(
+            x=node_feat,
+            edge_index=edge_index,
+            y=label,
+            train_mask=train_mask,
+            val_mask=val_mask,
+            test_mask=test_mask,
+        )
+        if self.pre_transform is not None:
+            data = self.pre_transform(data)
+
+        self.save([data], self.processed_paths[0])
+
+        if self.delete_raw:
+            fs.rm(data_path)
+
+    def __repr__(self) -> str:
+        return (f"{self.__class__.__name__}("
+                f"root='{self.root}', "
+                f"name='{self.name}', "
+                f"augmented={self.augmented})")
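Assuming the new class is re-exported from torch_geometric.datasets (its datasets/__init__.py is touched in this release), a minimal usage sketch of the City-Networks loader added above; the root path is illustrative, and the first access downloads and processes the chosen network:

from torch_geometric.datasets import CityNetwork  # export path assumed

dataset = CityNetwork(root='/tmp/CityNetworks', name='paris', augmented=True)
data = dataset[0]

# Transductive node classification into 10 eccentricity quantiles,
# using the masks stored on the data object by process():
print(data.num_nodes, int(data.y.max()) + 1)
print(data.train_mask.sum(), data.val_mask.sum(), data.test_mask.sum())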
torch_geometric/datasets/dbp15k.py
@@ -73,7 +73,7 @@ class DBP15K(InMemoryDataset):
     def process(self) -> None:
         embs = {}
         with open(osp.join(self.raw_dir, 'sub.glove.300d')) as f:
-            for
+            for line in f:
                 info = line.strip().split(' ')
                 if len(info) > 300:
                     embs[info[0]] = torch.tensor([float(x) for x in info[1:]])
torch_geometric/datasets/git_mol_dataset.py (new file)
@@ -0,0 +1,263 @@
+import sys
+from typing import Any, Callable, Dict, List, Optional
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from torch_geometric.data import (
+    Data,
+    InMemoryDataset,
+    download_google_url,
+    extract_zip,
+)
+from torch_geometric.io import fs
+
+
+def safe_index(lst: List[Any], e: int) -> int:
+    return lst.index(e) if e in lst else len(lst) - 1
+
+
+class GitMolDataset(InMemoryDataset):
+    r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model
+    for Molecular Science with Graph, Image, and Text"
+    <https://arxiv.org/pdf/2308.06911>`_ paper.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        transform (callable, optional): A function/transform that takes in an
+            :obj:`torch_geometric.data.Data` object and returns a transformed
+            version. The data object will be transformed before every access.
+            (default: :obj:`None`)
+        pre_transform (callable, optional): A function/transform that takes in
+            an :obj:`torch_geometric.data.Data` object and returns a
+            transformed version. The data object will be transformed before
+            being saved to disk. (default: :obj:`None`)
+        pre_filter (callable, optional): A function that takes in an
+            :obj:`torch_geometric.data.Data` object and returns a boolean
+            value, indicating whether the data object should be included in the
+            final dataset. (default: :obj:`None`)
+        force_reload (bool, optional): Whether to re-process the dataset.
+            (default: :obj:`False`)
+        split (int, optional): Datasets split, train/valid/test=0/1/2.
+            (default: :obj:`0`)
+    """
+
+    raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        pre_transform: Optional[Callable] = None,
+        pre_filter: Optional[Callable] = None,
+        force_reload: bool = False,
+        split: int = 0,
+    ):
+        from torchvision import transforms
+
+        self.split = split
+
+        if self.split == 0:
+            self.img_transform = transforms.Compose([
+                transforms.Resize((224, 224)),
+                transforms.RandomRotation(15),
+                transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+            ])
+        else:
+            self.img_transform = transforms.Compose([
+                transforms.Resize((224, 224)),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+            ])
+
+        super().__init__(root, transform, pre_transform, pre_filter,
+                         force_reload=force_reload)
+
+        self.load(self.processed_paths[0])
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']
+
+    @property
+    def processed_file_names(self) -> str:
+        return ['train.pt', 'valid.pt', 'test.pt'][self.split]
+
+    def download(self) -> None:
+        file_path = download_google_url(
+            self.raw_url_id,
+            self.raw_dir,
+            'gitmol.zip',
+        )
+        extract_zip(file_path, self.raw_dir)
+
+    def process(self) -> None:
+        import pandas as pd
+        from PIL import Image
+
+        try:
+            from rdkit import Chem, RDLogger
+            RDLogger.DisableLog('rdApp.*')  # type: ignore[attr-defined]
+            WITH_RDKIT = True
+
+        except ImportError:
+            WITH_RDKIT = False
+
+        if not WITH_RDKIT:
+            print(("Using a pre-processed version of the dataset. Please "
+                   "install 'rdkit' to alternatively process the raw data."),
+                  file=sys.stderr)
+
+            data_list = fs.torch_load(self.raw_paths[0])
+            data_list = [Data(**data_dict) for data_dict in data_list]
+
+            if self.pre_filter is not None:
+                data_list = [d for d in data_list if self.pre_filter(d)]
+
+            if self.pre_transform is not None:
+                data_list = [self.pre_transform(d) for d in data_list]
+
+            self.save(data_list, self.processed_paths[0])
+            return
+
+        allowable_features: Dict[str, List[Any]] = {
+            'possible_atomic_num_list':
+            list(range(1, 119)) + ['misc'],
+            'possible_formal_charge_list':
+            [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
+            'possible_chirality_list': [
+                Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
+                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
+                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
+                Chem.rdchem.ChiralType.CHI_OTHER
+            ],
+            'possible_hybridization_list': [
+                Chem.rdchem.HybridizationType.SP,
+                Chem.rdchem.HybridizationType.SP2,
+                Chem.rdchem.HybridizationType.SP3,
+                Chem.rdchem.HybridizationType.SP3D,
+                Chem.rdchem.HybridizationType.SP3D2,
+                Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'
+            ],
+            'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
+            'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
+            'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
+            'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
+            'possible_is_aromatic_list': [False, True],
+            'possible_is_in_ring_list': [False, True],
+            'possible_bond_type_list': [
+                Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
+                Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,
+                Chem.rdchem.BondType.ZERO
+            ],
+            'possible_bond_dirs': [  # only for double bond stereo information
+                Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
+                Chem.rdchem.BondDir.ENDDOWNRIGHT
+            ],
+            'possible_bond_stereo_list': [
+                Chem.rdchem.BondStereo.STEREONONE,
+                Chem.rdchem.BondStereo.STEREOZ,
+                Chem.rdchem.BondStereo.STEREOE,
+                Chem.rdchem.BondStereo.STEREOCIS,
+                Chem.rdchem.BondStereo.STEREOTRANS,
+                Chem.rdchem.BondStereo.STEREOANY,
+            ],
+            'possible_is_conjugated_list': [False, True]
+        }
+
+        data = pd.read_pickle(
+            f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')
+
+        data_list = []
+        for _, r in tqdm(data.iterrows(), total=data.shape[0]):
+            smiles = r['isosmiles']
+            mol = Chem.MolFromSmiles(smiles.strip('\n'))
+            if mol is not None:
+                # text
+                summary = r['summary']
+                # image
+                cid = r['cid']
+                img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'
+                img = Image.open(img_file).convert('RGB')
+                img = self.img_transform(img).unsqueeze(0)
+                # graph
+                atom_features_list = []
+                for atom in mol.GetAtoms():
+                    atom_feature = [
+                        safe_index(
+                            allowable_features['possible_atomic_num_list'],
+                            atom.GetAtomicNum()),
+                        allowable_features['possible_chirality_list'].index(
+                            atom.GetChiralTag()),
+                        safe_index(allowable_features['possible_degree_list'],
+                                   atom.GetTotalDegree()),
+                        safe_index(
+                            allowable_features['possible_formal_charge_list'],
+                            atom.GetFormalCharge()),
+                        safe_index(allowable_features['possible_numH_list'],
+                                   atom.GetTotalNumHs()),
+                        safe_index(
+                            allowable_features[
+                                'possible_number_radical_e_list'],
+                            atom.GetNumRadicalElectrons()),
+                        safe_index(
+                            allowable_features['possible_hybridization_list'],
+                            atom.GetHybridization()),
+                        allowable_features['possible_is_aromatic_list'].index(
+                            atom.GetIsAromatic()),
+                        allowable_features['possible_is_in_ring_list'].index(
+                            atom.IsInRing()),
+                    ]
+                    atom_features_list.append(atom_feature)
+                x = torch.tensor(np.array(atom_features_list),
+                                 dtype=torch.long)
+
+                edges_list = []
+                edge_features_list = []
+                for bond in mol.GetBonds():
+                    i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
+                    edge_feature = [
+                        safe_index(
+                            allowable_features['possible_bond_type_list'],
+                            bond.GetBondType()),
+                        allowable_features['possible_bond_stereo_list'].index(
+                            bond.GetStereo()),
+                        allowable_features['possible_is_conjugated_list'].
+                        index(bond.GetIsConjugated()),
+                    ]
+                    edges_list.append((i, j))
+                    edge_features_list.append(edge_feature)
+                    edges_list.append((j, i))
+                    edge_features_list.append(edge_feature)
+
+                edge_index = torch.tensor(
+                    np.array(edges_list).T,
+                    dtype=torch.long,
+                )
+                edge_attr = torch.tensor(
+                    np.array(edge_features_list),
+                    dtype=torch.long,
+                )
+
+                data = Data(
+                    x=x,
+                    edge_index=edge_index,
+                    smiles=smiles,
+                    edge_attr=edge_attr,
+                    image=img,
+                    caption=summary,
+                )
+
+                if self.pre_filter is not None and not self.pre_filter(data):
+                    continue
+                if self.pre_transform is not None:
+                    data = self.pre_transform(data)
+
+                data_list.append(data)
+
+        self.save(data_list, self.processed_paths[0])
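A usage sketch for the new GitMolDataset; the export from torch_geometric.datasets is assumed, and processing requires torchvision (and optionally rdkit) as shown in the code above:

from torch_geometric.datasets import GitMolDataset  # export path assumed

# split follows the convention documented above: train/valid/test = 0/1/2.
train_dataset = GitMolDataset(root='/tmp/GITMol', split=0)

sample = train_dataset[0]
sample.x, sample.edge_index, sample.edge_attr   # molecular graph tensors
sample.image                                    # 1 x 3 x 224 x 224 image tensor
sample.smiles, sample.caption                   # paired text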