pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. See the registry's advisory page for this release for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/data/database.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import
|
|
1
|
+
import io
|
|
2
2
|
import warnings
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
4
|
from dataclasses import dataclass
|
|
@@ -111,13 +111,17 @@ class Database(ABC):
|
|
|
111
111
|
for key, value in schema_dict.items()
|
|
112
112
|
}
|
|
113
113
|
|
|
114
|
+
@abstractmethod
|
|
114
115
|
def connect(self) -> None:
|
|
115
116
|
r"""Connects to the database.
|
|
116
117
|
Databases will automatically connect on instantiation.
|
|
117
118
|
"""
|
|
119
|
+
raise NotImplementedError
|
|
118
120
|
|
|
121
|
+
@abstractmethod
|
|
119
122
|
def close(self) -> None:
|
|
120
123
|
r"""Closes the connection to the database."""
|
|
124
|
+
raise NotImplementedError
|
|
121
125
|
|
|
122
126
|
@abstractmethod
|
|
123
127
|
def insert(self, index: int, data: Any) -> None:
|
|
@@ -496,7 +500,9 @@ class SQLiteDatabase(Database):
|
|
|
496
500
|
out.append(col)
|
|
497
501
|
|
|
498
502
|
else:
|
|
499
|
-
|
|
503
|
+
buffer = io.BytesIO()
|
|
504
|
+
torch.save(col, buffer)
|
|
505
|
+
out.append(buffer.getvalue())
|
|
500
506
|
|
|
501
507
|
return out
|
|
502
508
|
|
|
@@ -559,7 +565,10 @@ class SQLiteDatabase(Database):
|
|
|
559
565
|
out_dict[key] = value
|
|
560
566
|
|
|
561
567
|
else:
|
|
562
|
-
out_dict[key] =
|
|
568
|
+
out_dict[key] = torch.load(
|
|
569
|
+
io.BytesIO(value),
|
|
570
|
+
weights_only=False,
|
|
571
|
+
)
|
|
563
572
|
|
|
564
573
|
# In case `0` exists as integer in the schema, this means that the
|
|
565
574
|
# schema was passed as either a single entry or a tuple:
|
|
@@ -644,7 +653,12 @@ class RocksDatabase(Database):
|
|
|
644
653
|
# Ensure that data is not a view of a larger tensor:
|
|
645
654
|
if isinstance(row, Tensor):
|
|
646
655
|
row = row.clone()
|
|
647
|
-
|
|
656
|
+
buffer = io.BytesIO()
|
|
657
|
+
torch.save(row, buffer)
|
|
658
|
+
return buffer.getvalue()
|
|
648
659
|
|
|
649
660
|
def _deserialize(self, row: bytes) -> Any:
|
|
650
|
-
return
|
|
661
|
+
return torch.load(
|
|
662
|
+
io.BytesIO(row),
|
|
663
|
+
weights_only=False,
|
|
664
|
+
)
|
torch_geometric/data/dataset.py
CHANGED
|
@@ -166,10 +166,11 @@ class Dataset(torch.utils.data.Dataset):
|
|
|
166
166
|
elif y.numel() == y.size(0) and torch.is_floating_point(y):
|
|
167
167
|
num_classes = torch.unique(y).numel()
|
|
168
168
|
if num_classes > 2:
|
|
169
|
-
warnings.warn(
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
169
|
+
warnings.warn(
|
|
170
|
+
"Found floating-point labels while calling "
|
|
171
|
+
"`dataset.num_classes`. Returning the number of "
|
|
172
|
+
"unique elements. Please make sure that this "
|
|
173
|
+
"is expected before proceeding.", stacklevel=2)
|
|
173
174
|
return num_classes
|
|
174
175
|
else:
|
|
175
176
|
return y.size(-1)
|
|
@@ -235,20 +236,24 @@ class Dataset(torch.utils.data.Dataset):
|
|
|
235
236
|
|
|
236
237
|
def _process(self):
|
|
237
238
|
f = osp.join(self.processed_dir, 'pre_transform.pt')
|
|
238
|
-
if osp.exists(f) and torch.load(
|
|
239
|
+
if not self.force_reload and osp.exists(f) and torch.load(
|
|
240
|
+
f, weights_only=False) != _repr(self.pre_transform):
|
|
239
241
|
warnings.warn(
|
|
240
242
|
"The `pre_transform` argument differs from the one used in "
|
|
241
243
|
"the pre-processed version of this dataset. If you want to "
|
|
242
244
|
"make use of another pre-processing technique, pass "
|
|
243
|
-
"`force_reload=True` explicitly to reload the dataset."
|
|
245
|
+
"`force_reload=True` explicitly to reload the dataset.",
|
|
246
|
+
stacklevel=2)
|
|
244
247
|
|
|
245
248
|
f = osp.join(self.processed_dir, 'pre_filter.pt')
|
|
246
|
-
if osp.exists(f) and torch.load(
|
|
249
|
+
if not self.force_reload and osp.exists(f) and torch.load(
|
|
250
|
+
f, weights_only=False) != _repr(self.pre_filter):
|
|
247
251
|
warnings.warn(
|
|
248
252
|
"The `pre_filter` argument differs from the one used in "
|
|
249
253
|
"the pre-processed version of this dataset. If you want to "
|
|
250
254
|
"make use of another pre-fitering technique, pass "
|
|
251
|
-
"`force_reload=True` explicitly to reload the dataset."
|
|
255
|
+
"`force_reload=True` explicitly to reload the dataset.",
|
|
256
|
+
stacklevel=2)
|
|
252
257
|
|
|
253
258
|
if not self.force_reload and files_exist(self.processed_paths):
|
|
254
259
|
return
|
|
@@ -381,7 +386,7 @@ class Dataset(torch.utils.data.Dataset):
|
|
|
381
386
|
r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`.
|
|
382
387
|
|
|
383
388
|
The returned instance can then be used with :pyg:`PyG's` built-in
|
|
384
|
-
:class:`DataPipes` for
|
|
389
|
+
:class:`DataPipes` for batching graphs as follows:
|
|
385
390
|
|
|
386
391
|
.. code-block:: python
|
|
387
392
|
|
torch_geometric/data/extract.py
CHANGED
|
@@ -11,7 +11,7 @@ This particular feature store abstraction makes a few key assumptions:
|
|
|
11
11
|
* A feature can be uniquely identified from any associated attributes specified
|
|
12
12
|
in `TensorAttr`.
|
|
13
13
|
|
|
14
|
-
It is the job of a feature store
|
|
14
|
+
It is the job of a feature store implementer class to handle these assumptions
|
|
15
15
|
properly. For example, a simple in-memory feature store implementation may
|
|
16
16
|
concatenate all metadata values with a feature index and use this as a unique
|
|
17
17
|
index in a KV store. More complicated implementations may choose to partition
|
|
@@ -74,13 +74,6 @@ class TensorAttr(CastMixin):
|
|
|
74
74
|
r"""Whether the :obj:`TensorAttr` has no unset fields."""
|
|
75
75
|
return all([self.is_set(key) for key in self.__dataclass_fields__])
|
|
76
76
|
|
|
77
|
-
def fully_specify(self) -> 'TensorAttr':
|
|
78
|
-
r"""Sets all :obj:`UNSET` fields to :obj:`None`."""
|
|
79
|
-
for key in self.__dataclass_fields__:
|
|
80
|
-
if not self.is_set(key):
|
|
81
|
-
setattr(self, key, None)
|
|
82
|
-
return self
|
|
83
|
-
|
|
84
77
|
def update(self, attr: 'TensorAttr') -> 'TensorAttr':
|
|
85
78
|
r"""Updates an :class:`TensorAttr` with set attributes from another
|
|
86
79
|
:class:`TensorAttr`.
|
|
@@ -230,10 +223,11 @@ class AttrView(CastMixin):
|
|
|
230
223
|
|
|
231
224
|
store[group_name, attr_name]()
|
|
232
225
|
"""
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
226
|
+
attr = copy.copy(self._attr)
|
|
227
|
+
for key in attr.__dataclass_fields__: # Set all UNSET values to None.
|
|
228
|
+
if not attr.is_set(key):
|
|
229
|
+
setattr(attr, key, None)
|
|
230
|
+
return self._store.get_tensor(attr)
|
|
237
231
|
|
|
238
232
|
def __copy__(self) -> 'AttrView':
|
|
239
233
|
out = self.__class__.__new__(self.__class__)
|
|
@@ -358,7 +352,7 @@ class FeatureStore(ABC):
|
|
|
358
352
|
|
|
359
353
|
.. note::
|
|
360
354
|
The default implementation simply iterates over all calls to
|
|
361
|
-
:meth:`get_tensor`.
|
|
355
|
+
:meth:`get_tensor`. Implementer classes that can provide
|
|
362
356
|
additional, more performant functionality are recommended to
|
|
363
357
|
to override this method.
|
|
364
358
|
|
|
@@ -415,10 +409,10 @@ class FeatureStore(ABC):
|
|
|
415
409
|
def update_tensor(self, tensor: FeatureTensorType, *args,
|
|
416
410
|
**kwargs) -> bool:
|
|
417
411
|
r"""Updates a :obj:`tensor` in the :class:`FeatureStore` with a new
|
|
418
|
-
value. Returns whether the update was
|
|
412
|
+
value. Returns whether the update was successful.
|
|
419
413
|
|
|
420
414
|
.. note::
|
|
421
|
-
|
|
415
|
+
Implementer classes can choose to define more efficient update
|
|
422
416
|
methods; the default performs a removal and insertion.
|
|
423
417
|
|
|
424
418
|
Args:
|
|
@@ -479,9 +473,7 @@ class FeatureStore(ABC):
|
|
|
479
473
|
# CastMixin will handle the case of key being a tuple or TensorAttr
|
|
480
474
|
# object:
|
|
481
475
|
key = self._tensor_attr_cls.cast(key)
|
|
482
|
-
|
|
483
|
-
# sense to work with a view here:
|
|
484
|
-
key.fully_specify()
|
|
476
|
+
assert key.is_fully_specified()
|
|
485
477
|
self.put_tensor(value, key)
|
|
486
478
|
|
|
487
479
|
def __getitem__(self, key: TensorAttr) -> Any:
|
|
@@ -503,13 +495,16 @@ class FeatureStore(ABC):
|
|
|
503
495
|
# If the view is not fully-specified, return a :class:`AttrView`:
|
|
504
496
|
return self.view(attr)
|
|
505
497
|
|
|
506
|
-
def __delitem__(self,
|
|
498
|
+
def __delitem__(self, attr: TensorAttr):
|
|
507
499
|
r"""Supports :obj:`del store[tensor_attr]`."""
|
|
508
500
|
# CastMixin will handle the case of key being a tuple or TensorAttr
|
|
509
501
|
# object:
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
502
|
+
attr = self._tensor_attr_cls.cast(attr)
|
|
503
|
+
attr = copy.copy(attr)
|
|
504
|
+
for key in attr.__dataclass_fields__: # Set all UNSET values to None.
|
|
505
|
+
if not attr.is_set(key):
|
|
506
|
+
setattr(attr, key, None)
|
|
507
|
+
self.remove_tensor(attr)
|
|
513
508
|
|
|
514
509
|
def __iter__(self):
|
|
515
510
|
raise NotImplementedError
|
|
@@ -10,7 +10,7 @@ This particular graph store abstraction makes a few key assumptions:
|
|
|
10
10
|
support dynamic modification of edge indices once they have been inserted
|
|
11
11
|
into the graph store.
|
|
12
12
|
|
|
13
|
-
It is the job of a graph store
|
|
13
|
+
It is the job of a graph store implementer class to handle these assumptions
|
|
14
14
|
properly. For example, a simple in-memory graph store implementation may
|
|
15
15
|
concatenate all metadata values with an edge index and use this as a unique
|
|
16
16
|
index in a KV store. More complicated implementations may choose to partition
|
|
@@ -261,7 +261,8 @@ class GraphStore(ABC):
|
|
|
261
261
|
col = ptr2index(col)
|
|
262
262
|
|
|
263
263
|
if attr.layout != EdgeLayout.CSR: # COO->CSR
|
|
264
|
-
num_rows = attr.size[0] if attr.size else int(
|
|
264
|
+
num_rows = attr.size[0] if attr.size is not None else int(
|
|
265
|
+
row.max()) + 1
|
|
265
266
|
row, perm = index_sort(row, max_value=num_rows)
|
|
266
267
|
col = col[perm]
|
|
267
268
|
row = index2ptr(row, num_rows)
|
|
@@ -282,6 +282,21 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
|
282
282
|
r"""Returns a list of edge type and edge storage pairs."""
|
|
283
283
|
return list(self._edge_store_dict.items())
|
|
284
284
|
|
|
285
|
+
@property
|
|
286
|
+
def input_type(self) -> Optional[Union[NodeType, EdgeType]]:
|
|
287
|
+
r"""Returns the seed/input node/edge type of the graph in case it
|
|
288
|
+
refers to a sampled subgraph, *e.g.*, obtained via
|
|
289
|
+
:class:`~torch_geometric.loader.NeighborLoader` or
|
|
290
|
+
:class:`~torch_geometric.loader.LinkNeighborLoader`.
|
|
291
|
+
"""
|
|
292
|
+
for node_type, store in self.node_items():
|
|
293
|
+
if hasattr(store, 'input_id'):
|
|
294
|
+
return node_type
|
|
295
|
+
for edge_type, store in self.edge_items():
|
|
296
|
+
if hasattr(store, 'input_id'):
|
|
297
|
+
return edge_type
|
|
298
|
+
return None
|
|
299
|
+
|
|
285
300
|
def to_dict(self) -> Dict[str, Any]:
|
|
286
301
|
out_dict: Dict[str, Any] = {}
|
|
287
302
|
out_dict['_global_store'] = self._global_store.to_dict()
|
|
@@ -472,6 +487,77 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
|
472
487
|
|
|
473
488
|
return status
|
|
474
489
|
|
|
490
|
+
def connected_components(self) -> List[Self]:
|
|
491
|
+
r"""Extracts connected components of the heterogeneous graph using
|
|
492
|
+
a union-find algorithm. The components are returned as a list of
|
|
493
|
+
:class:`~torch_geometric.data.HeteroData` objects.
|
|
494
|
+
|
|
495
|
+
.. code-block::
|
|
496
|
+
|
|
497
|
+
data = HeteroData()
|
|
498
|
+
data["red"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
|
|
499
|
+
data["blue"].x = torch.tensor([[5.0], [6.0]])
|
|
500
|
+
data["red", "to", "red"].edge_index = torch.tensor(
|
|
501
|
+
[[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
|
|
502
|
+
)
|
|
503
|
+
|
|
504
|
+
components = data.connected_components()
|
|
505
|
+
print(len(components))
|
|
506
|
+
>>> 4
|
|
507
|
+
|
|
508
|
+
print(components[0])
|
|
509
|
+
>>> HeteroData(
|
|
510
|
+
red={x: tensor([[1.], [2.]])},
|
|
511
|
+
blue={x: tensor([[]])},
|
|
512
|
+
red, to, red={edge_index: tensor([[0, 1], [1, 0]])}
|
|
513
|
+
)
|
|
514
|
+
|
|
515
|
+
Returns:
|
|
516
|
+
List[HeteroData]: A list of connected components.
|
|
517
|
+
"""
|
|
518
|
+
# Initialize union-find structures
|
|
519
|
+
self._parents: Dict[Tuple[str, int], Tuple[str, int]] = {}
|
|
520
|
+
self._ranks: Dict[Tuple[str, int], int] = {}
|
|
521
|
+
|
|
522
|
+
# Union-Find algorithm to find connected components
|
|
523
|
+
for edge_type in self.edge_types:
|
|
524
|
+
src, _, dst = edge_type
|
|
525
|
+
edge_index = self[edge_type].edge_index
|
|
526
|
+
for src_node, dst_node in edge_index.t().tolist():
|
|
527
|
+
self._union((src, src_node), (dst, dst_node))
|
|
528
|
+
|
|
529
|
+
# Rerun _find_parent to ensure all nodes are covered correctly
|
|
530
|
+
for node_type in self.node_types:
|
|
531
|
+
for node_index in range(self[node_type].num_nodes):
|
|
532
|
+
self._find_parent((node_type, node_index))
|
|
533
|
+
|
|
534
|
+
# Group nodes by their representative parent
|
|
535
|
+
components_map = defaultdict(list)
|
|
536
|
+
for node, parent in self._parents.items():
|
|
537
|
+
components_map[parent].append(node)
|
|
538
|
+
del self._parents
|
|
539
|
+
del self._ranks
|
|
540
|
+
|
|
541
|
+
components: List[Self] = []
|
|
542
|
+
for nodes in components_map.values():
|
|
543
|
+
# Prefill subset_dict with all node types to ensure all are present
|
|
544
|
+
subset_dict = {node_type: [] for node_type in self.node_types}
|
|
545
|
+
|
|
546
|
+
# Convert the list of (node_type, node_id) tuples to a subset_dict
|
|
547
|
+
for node_type, node_id in nodes:
|
|
548
|
+
subset_dict[node_type].append(node_id)
|
|
549
|
+
|
|
550
|
+
# Convert lists to tensors
|
|
551
|
+
for node_type, node_ids in subset_dict.items():
|
|
552
|
+
subset_dict[node_type] = torch.tensor(node_ids,
|
|
553
|
+
dtype=torch.long)
|
|
554
|
+
|
|
555
|
+
# Use the existing subgraph function to do all the heavy lifting
|
|
556
|
+
component_data = self.subgraph(subset_dict)
|
|
557
|
+
components.append(component_data)
|
|
558
|
+
|
|
559
|
+
return components
|
|
560
|
+
|
|
475
561
|
def debug(self):
|
|
476
562
|
pass # TODO
|
|
477
563
|
|
|
@@ -551,7 +637,7 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
|
551
637
|
This is equivalent to writing :obj:`data.x_dict`.
|
|
552
638
|
|
|
553
639
|
Args:
|
|
554
|
-
key (str): The attribute to collect from all node and
|
|
640
|
+
key (str): The attribute to collect from all node and edge types.
|
|
555
641
|
allow_empty (bool, optional): If set to :obj:`True`, will not raise
|
|
556
642
|
an error in case the attribute does not exit in any node or
|
|
557
643
|
edge type. (default: :obj:`False`)
|
|
@@ -570,12 +656,13 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
|
570
656
|
global _DISPLAYED_TYPE_NAME_WARNING
|
|
571
657
|
if not _DISPLAYED_TYPE_NAME_WARNING and '__' in name:
|
|
572
658
|
_DISPLAYED_TYPE_NAME_WARNING = True
|
|
573
|
-
warnings.warn(
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
659
|
+
warnings.warn(
|
|
660
|
+
f"There exist type names in the "
|
|
661
|
+
f"'{self.__class__.__name__}' object that contain "
|
|
662
|
+
f"double underscores '__' (e.g., '{name}'). This "
|
|
663
|
+
f"may lead to unexpected behavior. To avoid any "
|
|
664
|
+
f"issues, ensure that your type names only contain "
|
|
665
|
+
f"single underscores.", stacklevel=2)
|
|
579
666
|
|
|
580
667
|
def get_node_store(self, key: NodeType) -> NodeStorage:
|
|
581
668
|
r"""Gets the :class:`~torch_geometric.data.storage.NodeStorage` object
|
|
@@ -1132,6 +1219,51 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
|
1132
1219
|
|
|
1133
1220
|
return list(edge_attrs.values())
|
|
1134
1221
|
|
|
1222
|
+
# Connected Components Helper Functions ###################################
|
|
1223
|
+
|
|
1224
|
+
def _find_parent(self, node: Tuple[str, int]) -> Tuple[str, int]:
|
|
1225
|
+
r"""Finds and returns the representative parent of the given node in a
|
|
1226
|
+
disjoint-set (union-find) data structure. Implements path compression
|
|
1227
|
+
to optimize future queries.
|
|
1228
|
+
|
|
1229
|
+
Args:
|
|
1230
|
+
node (tuple[str, int]): The node for which to find the parent.
|
|
1231
|
+
First element is the node type, second is the node index.
|
|
1232
|
+
|
|
1233
|
+
Returns:
|
|
1234
|
+
tuple[str, int]: The representative parent of the node.
|
|
1235
|
+
"""
|
|
1236
|
+
if node not in self._parents:
|
|
1237
|
+
self._parents[node] = node
|
|
1238
|
+
self._ranks[node] = 0
|
|
1239
|
+
if self._parents[node] != node:
|
|
1240
|
+
self._parents[node] = self._find_parent(self._parents[node])
|
|
1241
|
+
return self._parents[node]
|
|
1242
|
+
|
|
1243
|
+
def _union(self, node1: Tuple[str, int], node2: Tuple[str, int]):
|
|
1244
|
+
r"""Merges the node1 and node2 in the disjoint-set data structure.
|
|
1245
|
+
|
|
1246
|
+
Finds the root parents of node1 and node2 using the _find_parent
|
|
1247
|
+
method. If they belong to different sets, updates the parent of
|
|
1248
|
+
root2 to be root1, effectively merging the two sets.
|
|
1249
|
+
|
|
1250
|
+
Args:
|
|
1251
|
+
node1 (Tuple[str, int]): The first node to union. First element is
|
|
1252
|
+
the node type, second is the node index.
|
|
1253
|
+
node2 (Tuple[str, int]): The second node to union. First element is
|
|
1254
|
+
the node type, second is the node index.
|
|
1255
|
+
"""
|
|
1256
|
+
root1 = self._find_parent(node1)
|
|
1257
|
+
root2 = self._find_parent(node2)
|
|
1258
|
+
if root1 != root2:
|
|
1259
|
+
if self._ranks[root1] < self._ranks[root2]:
|
|
1260
|
+
self._parents[root1] = root2
|
|
1261
|
+
elif self._ranks[root1] > self._ranks[root2]:
|
|
1262
|
+
self._parents[root2] = root1
|
|
1263
|
+
else:
|
|
1264
|
+
self._parents[root2] = root1
|
|
1265
|
+
self._ranks[root1] += 1
|
|
1266
|
+
|
|
1135
1267
|
|
|
1136
1268
|
# Helper functions ############################################################
|
|
1137
1269
|
|
|
@@ -39,7 +39,7 @@ class HyperGraphData(Data):
|
|
|
39
39
|
edge_index (LongTensor, optional): Hyperedge tensor
|
|
40
40
|
with shape :obj:`[2, num_edges*num_nodes_per_edge]`.
|
|
41
41
|
Where `edge_index[1]` denotes the hyperedge index and
|
|
42
|
-
`edge_index[0]` denotes the node
|
|
42
|
+
`edge_index[0]` denotes the node indices that are connected
|
|
43
43
|
by the hyperedge. (default: :obj:`None`)
|
|
44
44
|
(default: :obj:`None`)
|
|
45
45
|
edge_attr (torch.Tensor, optional): Edge feature matrix with shape
|
|
@@ -223,4 +223,4 @@ def warn_or_raise(msg: str, raise_on_error: bool = True) -> None:
|
|
|
223
223
|
if raise_on_error:
|
|
224
224
|
raise ValueError(msg)
|
|
225
225
|
else:
|
|
226
|
-
warnings.warn(msg)
|
|
226
|
+
warnings.warn(msg, stacklevel=2)
|
|
@@ -297,7 +297,7 @@ class InMemoryDataset(Dataset):
|
|
|
297
297
|
self._data_list = None
|
|
298
298
|
msg += f' {msg4}'
|
|
299
299
|
|
|
300
|
-
warnings.warn(msg)
|
|
300
|
+
warnings.warn(msg, stacklevel=2)
|
|
301
301
|
|
|
302
302
|
return self._data
|
|
303
303
|
|
|
@@ -346,7 +346,7 @@ class InMemoryDataset(Dataset):
|
|
|
346
346
|
|
|
347
347
|
def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
|
|
348
348
|
if isinstance(node, Mapping):
|
|
349
|
-
for
|
|
349
|
+
for value in node.values():
|
|
350
350
|
yield from nested_iter(value)
|
|
351
351
|
elif isinstance(node, Sequence):
|
|
352
352
|
yield from enumerate(node)
|
|
@@ -11,21 +11,27 @@ from torch_geometric.sampler import BaseSampler, NeighborSampler
|
|
|
11
11
|
from torch_geometric.typing import InputEdges, InputNodes, OptTensor
|
|
12
12
|
|
|
13
13
|
try:
|
|
14
|
-
from
|
|
15
|
-
|
|
14
|
+
from lightning.pytorch import LightningDataModule as _LightningDataModule
|
|
15
|
+
_pl_is_available = True
|
|
16
16
|
except ImportError:
|
|
17
|
-
|
|
18
|
-
|
|
17
|
+
try:
|
|
18
|
+
from pytorch_lightning import \
|
|
19
|
+
LightningDataModule as _LightningDataModule
|
|
20
|
+
_pl_is_available = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
_pl_is_available = False
|
|
23
|
+
_LightningDataModule = object
|
|
19
24
|
|
|
20
25
|
|
|
21
|
-
class LightningDataModule(
|
|
26
|
+
class LightningDataModule(_LightningDataModule):
|
|
22
27
|
def __init__(self, has_val: bool, has_test: bool, **kwargs: Any) -> None:
|
|
23
28
|
super().__init__()
|
|
24
29
|
|
|
25
|
-
if
|
|
30
|
+
if not _pl_is_available:
|
|
26
31
|
raise ModuleNotFoundError(
|
|
27
|
-
"No module named 'pytorch_lightning'
|
|
28
|
-
"Run 'pip install
|
|
32
|
+
"No module named 'pytorch_lightning' (or 'lightning') found "
|
|
33
|
+
"in your Python environment. Run 'pip install "
|
|
34
|
+
"pytorch_lightning' or 'pip install lightning'")
|
|
29
35
|
|
|
30
36
|
if not has_val:
|
|
31
37
|
self.val_dataloader = None # type: ignore
|
|
@@ -40,9 +46,11 @@ class LightningDataModule(PLLightningDataModule):
|
|
|
40
46
|
kwargs.get('num_workers', 0) > 0)
|
|
41
47
|
|
|
42
48
|
if 'shuffle' in kwargs:
|
|
43
|
-
warnings.warn(
|
|
44
|
-
|
|
45
|
-
|
|
49
|
+
warnings.warn(
|
|
50
|
+
f"The 'shuffle={kwargs['shuffle']}' option is "
|
|
51
|
+
f"ignored in '{self.__class__.__name__}'. Remove it "
|
|
52
|
+
f"from the argument list to disable this warning",
|
|
53
|
+
stacklevel=2)
|
|
46
54
|
del kwargs['shuffle']
|
|
47
55
|
|
|
48
56
|
self.kwargs = kwargs
|
|
@@ -74,34 +82,39 @@ class LightningData(LightningDataModule):
|
|
|
74
82
|
raise ValueError(f"Undefined 'loader' option (got '{loader}')")
|
|
75
83
|
|
|
76
84
|
if loader == 'full' and kwargs['batch_size'] != 1:
|
|
77
|
-
warnings.warn(
|
|
78
|
-
|
|
79
|
-
|
|
85
|
+
warnings.warn(
|
|
86
|
+
f"Re-setting 'batch_size' to 1 in "
|
|
87
|
+
f"'{self.__class__.__name__}' for loader='full' "
|
|
88
|
+
f"(got '{kwargs['batch_size']}')", stacklevel=2)
|
|
80
89
|
kwargs['batch_size'] = 1
|
|
81
90
|
|
|
82
91
|
if loader == 'full' and kwargs['num_workers'] != 0:
|
|
83
|
-
warnings.warn(
|
|
84
|
-
|
|
85
|
-
|
|
92
|
+
warnings.warn(
|
|
93
|
+
f"Re-setting 'num_workers' to 0 in "
|
|
94
|
+
f"'{self.__class__.__name__}' for loader='full' "
|
|
95
|
+
f"(got '{kwargs['num_workers']}')", stacklevel=2)
|
|
86
96
|
kwargs['num_workers'] = 0
|
|
87
97
|
|
|
88
98
|
if loader == 'full' and kwargs.get('sampler') is not None:
|
|
89
|
-
warnings.warn(
|
|
90
|
-
|
|
99
|
+
warnings.warn(
|
|
100
|
+
"'sampler' option is not supported for "
|
|
101
|
+
"loader='full'", stacklevel=2)
|
|
91
102
|
kwargs.pop('sampler', None)
|
|
92
103
|
|
|
93
104
|
if loader == 'full' and kwargs.get('batch_sampler') is not None:
|
|
94
|
-
warnings.warn(
|
|
95
|
-
|
|
105
|
+
warnings.warn(
|
|
106
|
+
"'batch_sampler' option is not supported for "
|
|
107
|
+
"loader='full'", stacklevel=2)
|
|
96
108
|
kwargs.pop('batch_sampler', None)
|
|
97
109
|
|
|
98
110
|
super().__init__(has_val, has_test, **kwargs)
|
|
99
111
|
|
|
100
112
|
if loader == 'full':
|
|
101
113
|
if kwargs.get('pin_memory', False):
|
|
102
|
-
warnings.warn(
|
|
103
|
-
|
|
104
|
-
|
|
114
|
+
warnings.warn(
|
|
115
|
+
f"Re-setting 'pin_memory' to 'False' in "
|
|
116
|
+
f"'{self.__class__.__name__}' for loader='full' "
|
|
117
|
+
f"(got 'True')", stacklevel=2)
|
|
105
118
|
self.kwargs['pin_memory'] = False
|
|
106
119
|
|
|
107
120
|
self.data = data
|
|
@@ -127,10 +140,11 @@ class LightningData(LightningDataModule):
|
|
|
127
140
|
graph_sampler.__class__,
|
|
128
141
|
)
|
|
129
142
|
if len(sampler_kwargs) > 0:
|
|
130
|
-
warnings.warn(
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
143
|
+
warnings.warn(
|
|
144
|
+
f"Ignoring the arguments "
|
|
145
|
+
f"{list(sampler_kwargs.keys())} in "
|
|
146
|
+
f"'{self.__class__.__name__}' since a custom "
|
|
147
|
+
f"'graph_sampler' was passed", stacklevel=2)
|
|
134
148
|
self.graph_sampler = graph_sampler
|
|
135
149
|
|
|
136
150
|
else:
|
torch_geometric/data/storage.py
CHANGED
|
@@ -454,7 +454,7 @@ class NodeStorage(BaseStorage):
|
|
|
454
454
|
f"'{set(self.keys())}'. Please explicitly set 'num_nodes' as an "
|
|
455
455
|
f"attribute of " +
|
|
456
456
|
("'data'" if self._key is None else f"'data[{self._key}]'") +
|
|
457
|
-
" to suppress this warning")
|
|
457
|
+
" to suppress this warning", stacklevel=2)
|
|
458
458
|
if 'edge_index' in self and isinstance(self.edge_index, Tensor):
|
|
459
459
|
if self.edge_index.numel() > 0:
|
|
460
460
|
return int(self.edge_index.max()) + 1
|
|
@@ -806,6 +806,10 @@ class GlobalStorage(NodeStorage, EdgeStorage):
|
|
|
806
806
|
return False
|
|
807
807
|
|
|
808
808
|
cat_dim = self._parent().__cat_dim__(key, value, self)
|
|
809
|
+
|
|
810
|
+
if not isinstance(cat_dim, int):
|
|
811
|
+
return False
|
|
812
|
+
|
|
809
813
|
num_nodes, num_edges = self.num_nodes, self.num_edges
|
|
810
814
|
|
|
811
815
|
if value.shape[cat_dim] != num_nodes:
|
|
@@ -852,6 +856,10 @@ class GlobalStorage(NodeStorage, EdgeStorage):
|
|
|
852
856
|
return False
|
|
853
857
|
|
|
854
858
|
cat_dim = self._parent().__cat_dim__(key, value, self)
|
|
859
|
+
|
|
860
|
+
if not isinstance(cat_dim, int):
|
|
861
|
+
return False
|
|
862
|
+
|
|
855
863
|
num_nodes, num_edges = self.num_nodes, self.num_edges
|
|
856
864
|
|
|
857
865
|
if value.shape[cat_dim] != num_edges:
|
|
@@ -30,6 +30,7 @@ from .faust import FAUST
|
|
|
30
30
|
from .dynamic_faust import DynamicFAUST
|
|
31
31
|
from .shapenet import ShapeNet
|
|
32
32
|
from .modelnet import ModelNet
|
|
33
|
+
from .medshapenet import MedShapeNet
|
|
33
34
|
from .coma import CoMA
|
|
34
35
|
from .shrec2016 import SHREC2016
|
|
35
36
|
from .tosca import TOSCA
|
|
@@ -61,7 +62,6 @@ from .gemsec import GemsecDeezer
|
|
|
61
62
|
from .twitch import Twitch
|
|
62
63
|
from .airports import Airports
|
|
63
64
|
from .lrgb import LRGBDataset
|
|
64
|
-
from .neurograph import NeuroGraphDataset
|
|
65
65
|
from .malnet_tiny import MalNetTiny
|
|
66
66
|
from .omdb import OMDB
|
|
67
67
|
from .polblogs import PolBlogs
|
|
@@ -76,6 +76,15 @@ from .jodie import JODIEDataset
|
|
|
76
76
|
from .wikidata import Wikidata5M
|
|
77
77
|
from .myket import MyketDataset
|
|
78
78
|
from .brca_tgca import BrcaTcga
|
|
79
|
+
from .neurograph import NeuroGraphDataset
|
|
80
|
+
from .web_qsp_dataset import WebQSPDataset, CWQDataset
|
|
81
|
+
from .git_mol_dataset import GitMolDataset
|
|
82
|
+
from .molecule_gpt_dataset import MoleculeGPTDataset
|
|
83
|
+
from .instruct_mol_dataset import InstructMolDataset
|
|
84
|
+
from .protein_mpnn_dataset import ProteinMPNNDataset
|
|
85
|
+
from .tag_dataset import TAGDataset
|
|
86
|
+
from .city import CityNetwork
|
|
87
|
+
from .teeth3ds import Teeth3DS
|
|
79
88
|
|
|
80
89
|
from .dbp15k import DBP15K
|
|
81
90
|
from .aminer import AMiner
|
|
@@ -141,6 +150,7 @@ homo_datasets = [
|
|
|
141
150
|
'DynamicFAUST',
|
|
142
151
|
'ShapeNet',
|
|
143
152
|
'ModelNet',
|
|
153
|
+
'MedShapeNet',
|
|
144
154
|
'CoMA',
|
|
145
155
|
'SHREC2016',
|
|
146
156
|
'TOSCA',
|
|
@@ -188,6 +198,15 @@ homo_datasets = [
|
|
|
188
198
|
'MyketDataset',
|
|
189
199
|
'BrcaTcga',
|
|
190
200
|
'NeuroGraphDataset',
|
|
201
|
+
'WebQSPDataset',
|
|
202
|
+
'CWQDataset',
|
|
203
|
+
'GitMolDataset',
|
|
204
|
+
'MoleculeGPTDataset',
|
|
205
|
+
'InstructMolDataset',
|
|
206
|
+
'ProteinMPNNDataset',
|
|
207
|
+
'TAGDataset',
|
|
208
|
+
'CityNetwork',
|
|
209
|
+
'Teeth3DS',
|
|
191
210
|
]
|
|
192
211
|
|
|
193
212
|
hetero_datasets = [
|