pyg-nightly 2.6.0.dev20240511__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +30 -31
- {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +205 -181
- {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +26 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +16 -14
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/data.py +13 -8
- torch_geometric/data/database.py +15 -7
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +13 -22
- torch_geometric/data/graph_store.py +0 -4
- torch_geometric/data/hetero_data.py +4 -4
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/storage.py +15 -5
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +11 -1
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +6 -5
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +7 -1
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +4 -3
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +2 -2
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +17 -8
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +20 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +2 -3
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +9 -3
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +159 -34
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +2 -4
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/base.py +0 -1
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +100 -82
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +5 -4
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +3 -4
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +1 -2
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +322 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +7 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +203 -77
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +24 -15
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/testing/decorators.py +17 -22
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +4 -4
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +2 -2
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +31 -5
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +37 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +5 -5
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +1 -1
- torch_geometric/utils/smiles.py +66 -28
- torch_geometric/utils/sparse.py +25 -10
- torch_geometric/visualization/graph.py +3 -4
@@ -67,7 +67,7 @@ class WordNet18(InMemoryDataset):
|
|
67
67
|
def process(self) -> None:
|
68
68
|
srcs, dsts, edge_types = [], [], []
|
69
69
|
for path in self.raw_paths:
|
70
|
-
with open(path
|
70
|
+
with open(path) as f:
|
71
71
|
edges = [int(x) for x in f.read().split()[1:]]
|
72
72
|
edge = torch.tensor(edges, dtype=torch.long)
|
73
73
|
srcs.append(edge[::3])
|
@@ -173,7 +173,7 @@ class WordNet18RR(InMemoryDataset):
|
|
173
173
|
|
174
174
|
srcs, dsts, edge_types = [], [], []
|
175
175
|
for path in self.raw_paths:
|
176
|
-
with open(path
|
176
|
+
with open(path) as f:
|
177
177
|
edges = f.read().split()
|
178
178
|
|
179
179
|
_src = edges[::3]
|
torch_geometric/datasets/yelp.py
CHANGED
@@ -3,7 +3,6 @@ import os.path as osp
|
|
3
3
|
from typing import Callable, List, Optional
|
4
4
|
|
5
5
|
import numpy as np
|
6
|
-
import scipy.sparse as sp
|
7
6
|
import torch
|
8
7
|
|
9
8
|
from torch_geometric.data import Data, InMemoryDataset, download_google_url
|
@@ -73,6 +72,8 @@ class Yelp(InMemoryDataset):
|
|
73
72
|
download_google_url(self.role_id, self.raw_dir, 'role.json')
|
74
73
|
|
75
74
|
def process(self) -> None:
|
75
|
+
import scipy.sparse as sp
|
76
|
+
|
76
77
|
f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
|
77
78
|
adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
|
78
79
|
adj = adj.tocoo()
|
torch_geometric/datasets/zinc.py
CHANGED
@@ -139,7 +139,7 @@ class ZINC(InMemoryDataset):
|
|
139
139
|
indices = list(range(len(mols)))
|
140
140
|
|
141
141
|
if self.subset:
|
142
|
-
with open(osp.join(self.raw_dir, f'{split}.index')
|
142
|
+
with open(osp.join(self.raw_dir, f'{split}.index')) as f:
|
143
143
|
indices = [int(x) for x in f.read()[:-1].split(',')]
|
144
144
|
|
145
145
|
pbar = tqdm(total=len(indices))
|
@@ -0,0 +1,42 @@
|
|
1
|
+
from typing import Any
|
2
|
+
|
3
|
+
import torch
|
4
|
+
|
5
|
+
|
6
|
+
def is_mps_available() -> bool:
|
7
|
+
r"""Returns a bool indicating if MPS is currently available."""
|
8
|
+
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
9
|
+
try: # Github CI may not have access to MPS hardware. Confirm:
|
10
|
+
torch.empty(1, device='mps')
|
11
|
+
return True
|
12
|
+
except Exception:
|
13
|
+
return False
|
14
|
+
return False
|
15
|
+
|
16
|
+
|
17
|
+
def is_xpu_available() -> bool:
|
18
|
+
r"""Returns a bool indicating if XPU is currently available."""
|
19
|
+
if hasattr(torch, 'xpu') and torch.xpu.is_available():
|
20
|
+
return True
|
21
|
+
try:
|
22
|
+
import intel_extension_for_pytorch as ipex
|
23
|
+
return ipex.xpu.is_available()
|
24
|
+
except ImportError:
|
25
|
+
return False
|
26
|
+
|
27
|
+
|
28
|
+
def device(device: Any) -> torch.device:
|
29
|
+
r"""Returns a :class:`torch.device`.
|
30
|
+
|
31
|
+
If :obj:`"auto"` is specified, returns the optimal device depending on
|
32
|
+
available hardware.
|
33
|
+
"""
|
34
|
+
if device != 'auto':
|
35
|
+
return torch.device(device)
|
36
|
+
if torch.cuda.is_available():
|
37
|
+
return torch.device('cuda')
|
38
|
+
if is_mps_available():
|
39
|
+
return torch.device('mps')
|
40
|
+
if is_xpu_available():
|
41
|
+
return torch.device('xpu')
|
42
|
+
return torch.device('cpu')
|
@@ -15,6 +15,7 @@ from torch_geometric.distributed.rpc import (
|
|
15
15
|
rpc_async,
|
16
16
|
rpc_register,
|
17
17
|
)
|
18
|
+
from torch_geometric.io import fs
|
18
19
|
from torch_geometric.typing import EdgeType, NodeOrEdgeType, NodeType
|
19
20
|
|
20
21
|
|
@@ -415,11 +416,11 @@ class LocalFeatureStore(FeatureStore):
|
|
415
416
|
|
416
417
|
node_feats: Optional[Dict[str, Any]] = None
|
417
418
|
if osp.exists(osp.join(part_dir, 'node_feats.pt')):
|
418
|
-
node_feats =
|
419
|
+
node_feats = fs.torch_load(osp.join(part_dir, 'node_feats.pt'))
|
419
420
|
|
420
421
|
edge_feats: Optional[Dict[str, Any]] = None
|
421
422
|
if osp.exists(osp.join(part_dir, 'edge_feats.pt')):
|
422
|
-
edge_feats =
|
423
|
+
edge_feats = fs.torch_load(osp.join(part_dir, 'edge_feats.pt'))
|
423
424
|
|
424
425
|
if not meta['is_hetero'] and node_feats is not None:
|
425
426
|
feat_store.put_global_id(node_feats['global_id'], group_name=None)
|
@@ -6,6 +6,7 @@ from torch import Tensor
|
|
6
6
|
|
7
7
|
from torch_geometric.data import EdgeAttr, GraphStore
|
8
8
|
from torch_geometric.distributed.partition import load_partition_info
|
9
|
+
from torch_geometric.io import fs
|
9
10
|
from torch_geometric.typing import EdgeTensorType, EdgeType, NodeType
|
10
11
|
from torch_geometric.utils import sort_edge_index
|
11
12
|
|
@@ -185,7 +186,7 @@ class LocalGraphStore(GraphStore):
|
|
185
186
|
graph_store.edge_pb = edge_pb
|
186
187
|
graph_store.meta = meta
|
187
188
|
|
188
|
-
graph_data =
|
189
|
+
graph_data = fs.torch_load(osp.join(part_dir, 'graph.pt'))
|
189
190
|
graph_store.is_sorted = meta['is_sorted']
|
190
191
|
|
191
192
|
if not meta['is_hetero']:
|
@@ -3,15 +3,16 @@ import logging
|
|
3
3
|
import os
|
4
4
|
import os.path as osp
|
5
5
|
from collections import defaultdict
|
6
|
-
from typing import List, Optional, Union
|
6
|
+
from typing import Dict, List, Optional, Tuple, Union
|
7
7
|
|
8
8
|
import torch
|
9
9
|
|
10
10
|
import torch_geometric.distributed as pyg_dist
|
11
11
|
from torch_geometric.data import Data, HeteroData
|
12
|
+
from torch_geometric.io import fs
|
12
13
|
from torch_geometric.loader.cluster import ClusterData
|
13
14
|
from torch_geometric.sampler.utils import sort_csc
|
14
|
-
from torch_geometric.typing import
|
15
|
+
from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType
|
15
16
|
|
16
17
|
|
17
18
|
class Partitioner:
|
@@ -23,7 +24,7 @@ class Partitioner:
|
|
23
24
|
|
24
25
|
**Homogeneous graphs:**
|
25
26
|
|
26
|
-
.. code-block::
|
27
|
+
.. code-block:: none
|
27
28
|
|
28
29
|
root/
|
29
30
|
|-- META.json
|
@@ -40,7 +41,7 @@ class Partitioner:
|
|
40
41
|
|
41
42
|
**Heterogeneous graphs:**
|
42
43
|
|
43
|
-
.. code-block::
|
44
|
+
.. code-block:: none
|
44
45
|
|
45
46
|
root/
|
46
47
|
|-- META.json
|
@@ -380,21 +381,21 @@ def load_partition_info(
|
|
380
381
|
assert osp.exists(partition_dir)
|
381
382
|
|
382
383
|
if meta['is_hetero'] is False:
|
383
|
-
node_pb =
|
384
|
-
edge_pb =
|
384
|
+
node_pb = fs.torch_load(osp.join(root_dir, 'node_map.pt'))
|
385
|
+
edge_pb = fs.torch_load(osp.join(root_dir, 'edge_map.pt'))
|
385
386
|
|
386
387
|
return (meta, num_partitions, partition_idx, node_pb, edge_pb)
|
387
388
|
else:
|
388
389
|
node_pb_dict = {}
|
389
390
|
node_pb_dir = osp.join(root_dir, 'node_map')
|
390
391
|
for ntype in meta['node_types']:
|
391
|
-
node_pb_dict[ntype] =
|
392
|
+
node_pb_dict[ntype] = fs.torch_load(
|
392
393
|
osp.join(node_pb_dir, f'{pyg_dist.utils.as_str(ntype)}.pt'))
|
393
394
|
|
394
395
|
edge_pb_dict = {}
|
395
396
|
edge_pb_dir = osp.join(root_dir, 'edge_map')
|
396
397
|
for etype in meta['edge_types']:
|
397
|
-
edge_pb_dict[tuple(etype)] =
|
398
|
+
edge_pb_dict[tuple(etype)] = fs.torch_load(
|
398
399
|
osp.join(edge_pb_dir, f'{pyg_dist.utils.as_str(etype)}.pt'))
|
399
400
|
|
400
401
|
return (meta, num_partitions, partition_idx, node_pb_dict,
|
torch_geometric/edge_index.py
CHANGED
@@ -173,7 +173,7 @@ class EdgeIndex(Tensor):
|
|
173
173
|
:meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its
|
174
174
|
lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).
|
175
175
|
|
176
|
-
This representation ensures
|
176
|
+
This representation ensures optimal computation in GNN message passing
|
177
177
|
schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
|
178
178
|
workflows.
|
179
179
|
|
@@ -820,19 +820,28 @@ class EdgeIndex(Tensor):
|
|
820
820
|
:obj:`1.0`. (default: :obj:`None`)
|
821
821
|
"""
|
822
822
|
value = self._get_value() if value is None else value
|
823
|
-
|
823
|
+
|
824
|
+
if not torch_geometric.typing.WITH_PT21:
|
825
|
+
out = torch.sparse_coo_tensor(
|
826
|
+
indices=self._data,
|
827
|
+
values=value,
|
828
|
+
size=self.get_sparse_size(),
|
829
|
+
device=self.device,
|
830
|
+
requires_grad=value.requires_grad,
|
831
|
+
)
|
832
|
+
if self.is_sorted_by_row:
|
833
|
+
out = out._coalesced_(True)
|
834
|
+
return out
|
835
|
+
|
836
|
+
return torch.sparse_coo_tensor(
|
824
837
|
indices=self._data,
|
825
838
|
values=value,
|
826
839
|
size=self.get_sparse_size(),
|
827
840
|
device=self.device,
|
828
841
|
requires_grad=value.requires_grad,
|
842
|
+
is_coalesced=True if self.is_sorted_by_row else None,
|
829
843
|
)
|
830
844
|
|
831
|
-
if self.is_sorted_by_row:
|
832
|
-
out = out._coalesced_(True)
|
833
|
-
|
834
|
-
return out
|
835
|
-
|
836
845
|
def to_sparse_csr( # type: ignore
|
837
846
|
self,
|
838
847
|
value: Optional[Tensor] = None,
|
@@ -1928,7 +1937,7 @@ def _spmm(
|
|
1928
1937
|
if transpose and not input.is_sorted_by_col:
|
1929
1938
|
cls_name = input.__class__.__name__
|
1930
1939
|
raise ValueError(f"'matmul(..., transpose=True)' requires "
|
1931
|
-
f"'{cls_name}' to be sorted by
|
1940
|
+
f"'{cls_name}' to be sorted by columns")
|
1932
1941
|
|
1933
1942
|
if (torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling()
|
1934
1943
|
and other.is_cuda): # pragma: no cover
|
@@ -50,7 +50,6 @@ class ExplainerAlgorithm(torch.nn.Module):
|
|
50
50
|
r"""Checks if the explainer supports the user-defined settings provided
|
51
51
|
in :obj:`self.explainer_config`, :obj:`self.model_config`.
|
52
52
|
"""
|
53
|
-
pass
|
54
53
|
|
55
54
|
###########################################################################
|
56
55
|
|
@@ -340,10 +340,10 @@ class HeteroExplanation(HeteroData, ExplanationMixin):
|
|
340
340
|
"""
|
341
341
|
node_mask_dict = self.node_mask_dict
|
342
342
|
for node_mask in node_mask_dict.values():
|
343
|
-
if node_mask.dim() != 2
|
343
|
+
if node_mask.dim() != 2:
|
344
344
|
raise ValueError(f"Cannot compute feature importance for "
|
345
345
|
f"object-level 'node_mask' "
|
346
|
-
f"(got shape {
|
346
|
+
f"(got shape {node_mask.size()})")
|
347
347
|
|
348
348
|
if feat_labels is None:
|
349
349
|
feat_labels = {}
|
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union
|
|
6
6
|
import torch
|
7
7
|
|
8
8
|
from torch_geometric.graphgym.config import cfg
|
9
|
+
from torch_geometric.io import fs
|
9
10
|
|
10
11
|
MODEL_STATE = 'model_state'
|
11
12
|
OPTIMIZER_STATE = 'optimizer_state'
|
@@ -25,7 +26,7 @@ def load_ckpt(
|
|
25
26
|
if not osp.exists(path):
|
26
27
|
return 0
|
27
28
|
|
28
|
-
ckpt =
|
29
|
+
ckpt = fs.torch_load(path)
|
29
30
|
model.load_state_dict(ckpt[MODEL_STATE])
|
30
31
|
if optimizer is not None and OPTIMIZER_STATE in ckpt:
|
31
32
|
optimizer.load_state_dict(ckpt[OPTIMIZER_STATE])
|
@@ -19,7 +19,7 @@ def set_printing():
|
|
19
19
|
logging.root.handlers = []
|
20
20
|
logging_cfg = {'level': logging.INFO, 'format': '%(message)s'}
|
21
21
|
os.makedirs(cfg.run_dir, exist_ok=True)
|
22
|
-
h_file = logging.FileHandler('{}/logging.log'
|
22
|
+
h_file = logging.FileHandler(f'{cfg.run_dir}/logging.log')
|
23
23
|
h_stdout = logging.StreamHandler(sys.stdout)
|
24
24
|
if cfg.print == 'file':
|
25
25
|
logging_cfg['handlers'] = [h_file]
|
@@ -40,7 +40,7 @@ class Logger:
|
|
40
40
|
self._epoch_total = cfg.optim.max_epoch
|
41
41
|
self._time_total = 0 # won't be reset
|
42
42
|
|
43
|
-
self.out_dir = '{}/{}'
|
43
|
+
self.out_dir = f'{cfg.run_dir}/{name}'
|
44
44
|
os.makedirs(self.out_dir, exist_ok=True)
|
45
45
|
if cfg.tensorboard_each_run:
|
46
46
|
from tensorboardX import SummaryWriter
|
@@ -210,9 +210,9 @@ class Logger:
|
|
210
210
|
}
|
211
211
|
|
212
212
|
# print
|
213
|
-
logging.info('{}: {}'
|
213
|
+
logging.info(f'{self.name}: {stats}')
|
214
214
|
# json
|
215
|
-
dict_to_json(stats, '{}/stats.json'
|
215
|
+
dict_to_json(stats, f'{self.out_dir}/stats.json')
|
216
216
|
# tensorboard
|
217
217
|
if cfg.tensorboard_each_run:
|
218
218
|
dict_to_tb(stats, self.tb_writer, cur_epoch)
|
torch_geometric/graphgym/loss.py
CHANGED
@@ -54,7 +54,7 @@ def agg_dict_list(dict_list):
|
|
54
54
|
if key != 'epoch':
|
55
55
|
value = np.array([dict[key] for dict in dict_list])
|
56
56
|
dict_agg[key] = np.mean(value).round(cfg.round)
|
57
|
-
dict_agg['{}_std'
|
57
|
+
dict_agg[f'{key}_std'] = np.std(value).round(cfg.round)
|
58
58
|
return dict_agg
|
59
59
|
|
60
60
|
|
@@ -107,7 +107,7 @@ def agg_runs(dir, metric_best='auto'):
|
|
107
107
|
[stats[metric] for stats in stats_list])
|
108
108
|
best_epoch = \
|
109
109
|
stats_list[
|
110
|
-
eval("performance_np.{}()"
|
110
|
+
eval(f"performance_np.{cfg.metric_agg}()")][
|
111
111
|
'epoch']
|
112
112
|
print(best_epoch)
|
113
113
|
|
@@ -190,7 +190,7 @@ def agg_batch(dir, metric_best='auto'):
|
|
190
190
|
results[key] = pd.DataFrame(results[key])
|
191
191
|
results[key] = results[key].sort_values(
|
192
192
|
list(dict_name.keys()), ascending=[True] * len(dict_name))
|
193
|
-
fname = osp.join(dir_out, '{}_best.csv'
|
193
|
+
fname = osp.join(dir_out, f'{key}_best.csv')
|
194
194
|
results[key].to_csv(fname, index=False)
|
195
195
|
|
196
196
|
results = {'train': [], 'val': [], 'test': []}
|
@@ -213,7 +213,7 @@ def agg_batch(dir, metric_best='auto'):
|
|
213
213
|
results[key] = pd.DataFrame(results[key])
|
214
214
|
results[key] = results[key].sort_values(
|
215
215
|
list(dict_name.keys()), ascending=[True] * len(dict_name))
|
216
|
-
fname = osp.join(dir_out, '{}.csv'
|
216
|
+
fname = osp.join(dir_out, f'{key}.csv')
|
217
217
|
results[key].to_csv(fname, index=False)
|
218
218
|
|
219
219
|
results = {'train': [], 'val': [], 'test': []}
|
@@ -245,7 +245,7 @@ def agg_batch(dir, metric_best='auto'):
|
|
245
245
|
results[key] = pd.DataFrame(results[key])
|
246
246
|
results[key] = results[key].sort_values(
|
247
247
|
list(dict_name.keys()), ascending=[True] * len(dict_name))
|
248
|
-
fname = osp.join(dir_out, '{}_bestepoch.csv'
|
248
|
+
fname = osp.join(dir_out, f'{key}_bestepoch.csv')
|
249
249
|
results[key].to_csv(fname, index=False)
|
250
250
|
|
251
|
-
print('Results aggregated across models saved in {}'
|
251
|
+
print(f'Results aggregated across models saved in {dir_out}')
|
torch_geometric/index.py
CHANGED
@@ -106,7 +106,7 @@ class Index(Tensor):
|
|
106
106
|
:meth:`Index.fill_cache_`, and are maintained and adjusted over its
|
107
107
|
lifespan.
|
108
108
|
|
109
|
-
This representation ensures
|
109
|
+
This representation ensures optimal computation in GNN message passing
|
110
110
|
schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
|
111
111
|
workflows.
|
112
112
|
|
@@ -120,7 +120,7 @@ class Index(Tensor):
|
|
120
120
|
assert index.is_sorted
|
121
121
|
|
122
122
|
# Flipping order:
|
123
|
-
|
123
|
+
index.flip(0)
|
124
124
|
>>> Index([[2, 1, 1, 0], dim_size=3)
|
125
125
|
assert not index.is_sorted
|
126
126
|
|
@@ -685,14 +685,14 @@ def _index(
|
|
685
685
|
|
686
686
|
@implements(aten.add.Tensor)
|
687
687
|
def _add(
|
688
|
-
input: Index,
|
688
|
+
input: Union[int, Tensor, Index],
|
689
689
|
other: Union[int, Tensor, Index],
|
690
690
|
*,
|
691
691
|
alpha: int = 1,
|
692
692
|
) -> Union[Index, Tensor]:
|
693
693
|
|
694
694
|
data = aten.add.Tensor(
|
695
|
-
input._data,
|
695
|
+
input._data if isinstance(input, Index) else input,
|
696
696
|
other._data if isinstance(other, Index) else other,
|
697
697
|
alpha=alpha,
|
698
698
|
)
|
@@ -704,15 +704,25 @@ def _add(
|
|
704
704
|
|
705
705
|
out = Index(data)
|
706
706
|
|
707
|
+
if isinstance(input, Tensor) and input.numel() <= 1:
|
708
|
+
input = int(input)
|
709
|
+
|
707
710
|
if isinstance(other, Tensor) and other.numel() <= 1:
|
708
711
|
other = int(other)
|
709
712
|
|
710
713
|
if isinstance(other, int):
|
714
|
+
assert isinstance(input, Index)
|
711
715
|
if input.dim_size is not None:
|
712
716
|
out._dim_size = input.dim_size + alpha * other
|
713
717
|
out._is_sorted = input.is_sorted
|
714
718
|
|
715
|
-
elif isinstance(
|
719
|
+
elif isinstance(input, int):
|
720
|
+
assert isinstance(other, Index)
|
721
|
+
if other.dim_size is not None:
|
722
|
+
out._dim_size = input + alpha * other.dim_size
|
723
|
+
out._is_sorted = other.is_sorted
|
724
|
+
|
725
|
+
elif isinstance(input, Index) and isinstance(other, Index):
|
716
726
|
if input.dim_size is not None and other.dim_size is not None:
|
717
727
|
out._dim_size = input.dim_size + alpha * other.dim_size
|
718
728
|
|
@@ -754,14 +764,14 @@ def add_(
|
|
754
764
|
|
755
765
|
@implements(aten.sub.Tensor)
|
756
766
|
def _sub(
|
757
|
-
input: Index,
|
767
|
+
input: Union[int, Tensor, Index],
|
758
768
|
other: Union[int, Tensor, Index],
|
759
769
|
*,
|
760
770
|
alpha: int = 1,
|
761
771
|
) -> Union[Index, Tensor]:
|
762
772
|
|
763
773
|
data = aten.sub.Tensor(
|
764
|
-
input._data,
|
774
|
+
input._data if isinstance(input, Index) else input,
|
765
775
|
other._data if isinstance(other, Index) else other,
|
766
776
|
alpha=alpha,
|
767
777
|
)
|
@@ -773,6 +783,9 @@ def _sub(
|
|
773
783
|
|
774
784
|
out = Index(data)
|
775
785
|
|
786
|
+
if not isinstance(input, Index):
|
787
|
+
return out
|
788
|
+
|
776
789
|
if isinstance(other, Tensor) and other.numel() <= 1:
|
777
790
|
other = int(other)
|
778
791
|
|
torch_geometric/inspector.py
CHANGED
@@ -305,7 +305,7 @@ class Inspector:
|
|
305
305
|
according to its function signature from a data blob.
|
306
306
|
|
307
307
|
Args:
|
308
|
-
func (
|
308
|
+
func (callable or str): The function.
|
309
309
|
kwargs (dict[str, Any]): The data blob which may serve as inputs.
|
310
310
|
"""
|
311
311
|
out_dict: Dict[str, Any] = {}
|
@@ -346,7 +346,7 @@ class Inspector:
|
|
346
346
|
type annotations are not found.
|
347
347
|
|
348
348
|
Args:
|
349
|
-
func (
|
349
|
+
func (callable or str): The function.
|
350
350
|
exclude (list[int or str]): A list of parameters to exclude, either
|
351
351
|
given by their name or index. (default: :obj:`None`)
|
352
352
|
"""
|
@@ -448,6 +448,10 @@ def type_repr(obj: Any, _globals: Dict[str, Any]) -> str:
|
|
448
448
|
return '...'
|
449
449
|
|
450
450
|
if obj.__module__ == 'typing': # Special logic for `typing.*` types:
|
451
|
+
|
452
|
+
if not hasattr(obj, '_name'):
|
453
|
+
return repr(obj)
|
454
|
+
|
451
455
|
name = obj._name
|
452
456
|
if name is None: # In some cases, `_name` is not populated.
|
453
457
|
name = str(obj.__origin__).split('.')[-1]
|
torch_geometric/io/fs.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1
1
|
import io
|
2
2
|
import os.path as osp
|
3
|
+
import pickle
|
4
|
+
import re
|
3
5
|
import sys
|
6
|
+
import warnings
|
4
7
|
from typing import Any, Dict, List, Literal, Optional, Union, overload
|
5
8
|
from uuid import uuid4
|
6
9
|
|
@@ -186,11 +189,11 @@ def rm(path: str, recursive: bool = True) -> None:
|
|
186
189
|
get_fs(path).rm(path, recursive)
|
187
190
|
|
188
191
|
|
189
|
-
def mv(path1: str, path2: str
|
192
|
+
def mv(path1: str, path2: str) -> None:
|
190
193
|
fs1 = get_fs(path1)
|
191
194
|
fs2 = get_fs(path2)
|
192
195
|
assert fs1.protocol == fs2.protocol
|
193
|
-
fs1.mv(path1, path2
|
196
|
+
fs1.mv(path1, path2)
|
194
197
|
|
195
198
|
|
196
199
|
def glob(path: str) -> List[str]:
|
@@ -211,5 +214,28 @@ def torch_save(data: Any, path: str) -> None:
|
|
211
214
|
|
212
215
|
|
213
216
|
def torch_load(path: str, map_location: Any = None) -> Any:
|
217
|
+
if torch_geometric.typing.WITH_PT24:
|
218
|
+
try:
|
219
|
+
with fsspec.open(path, 'rb') as f:
|
220
|
+
return torch.load(f, map_location, weights_only=True)
|
221
|
+
except pickle.UnpicklingError as e:
|
222
|
+
error_msg = str(e)
|
223
|
+
if "add_safe_globals" in error_msg:
|
224
|
+
warn_msg = ("Weights only load failed. Please file an issue "
|
225
|
+
"to make `torch.load(weights_only=True)` "
|
226
|
+
"compatible in your case.")
|
227
|
+
match = re.search(r'add_safe_globals\(.*?\)', error_msg)
|
228
|
+
if match is not None:
|
229
|
+
warnings.warn(f"{warn_msg} Please use "
|
230
|
+
f"`torch.serialization.{match.group()}` to "
|
231
|
+
f"allowlist this global.")
|
232
|
+
else:
|
233
|
+
warnings.warn(warn_msg)
|
234
|
+
|
235
|
+
with fsspec.open(path, 'rb') as f:
|
236
|
+
return torch.load(f, map_location, weights_only=False)
|
237
|
+
else:
|
238
|
+
raise e
|
239
|
+
|
214
240
|
with fsspec.open(path, 'rb') as f:
|
215
241
|
return torch.load(f, map_location)
|
torch_geometric/io/npz.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
from typing import Any, Dict
|
2
2
|
|
3
3
|
import numpy as np
|
4
|
-
import scipy.sparse as sp
|
5
4
|
import torch
|
6
5
|
|
7
6
|
from torch_geometric.data import Data
|
@@ -15,6 +14,8 @@ def read_npz(path: str, to_undirected: bool = True) -> Data:
|
|
15
14
|
|
16
15
|
|
17
16
|
def parse_npz(f: Dict[str, Any], to_undirected: bool = True) -> Data:
|
17
|
+
import scipy.sparse as sp
|
18
|
+
|
18
19
|
x = sp.csr_matrix((f['attr_data'], f['attr_indices'], f['attr_indptr']),
|
19
20
|
f['attr_shape']).todense()
|
20
21
|
x = torch.from_numpy(x).to(torch.float)
|
torch_geometric/io/off.py
CHANGED
@@ -16,7 +16,7 @@ def parse_off(src: List[str]) -> Data:
|
|
16
16
|
else:
|
17
17
|
src[0] = src[0][3:]
|
18
18
|
|
19
|
-
num_nodes, num_faces =
|
19
|
+
num_nodes, num_faces = (int(item) for item in src[0].split()[:2])
|
20
20
|
|
21
21
|
pos = parse_txt_array(src[1:1 + num_nodes])
|
22
22
|
|
@@ -52,7 +52,7 @@ def read_off(path: str) -> Data:
|
|
52
52
|
Args:
|
53
53
|
path (str): The path to the file.
|
54
54
|
"""
|
55
|
-
with open(path
|
55
|
+
with open(path) as f:
|
56
56
|
src = f.read().split('\n')[:-1]
|
57
57
|
return parse_off(src)
|
58
58
|
|
torch_geometric/io/sdf.py
CHANGED
@@ -9,7 +9,7 @@ elems = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
|
|
9
9
|
|
10
10
|
def parse_sdf(src: str) -> Data:
|
11
11
|
lines = src.split('\n')[3:]
|
12
|
-
num_atoms, num_bonds =
|
12
|
+
num_atoms, num_bonds = (int(item) for item in lines[0].split()[:2])
|
13
13
|
|
14
14
|
atom_block = lines[1:num_atoms + 1]
|
15
15
|
pos = parse_txt_array(atom_block, end=3)
|
@@ -28,5 +28,5 @@ def parse_sdf(src: str) -> Data:
|
|
28
28
|
|
29
29
|
|
30
30
|
def read_sdf(path: str) -> Data:
|
31
|
-
with open(path
|
31
|
+
with open(path) as f:
|
32
32
|
return parse_sdf(f.read())
|
torch_geometric/io/tu.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
import os.path as osp
|
2
2
|
from typing import Dict, List, Optional, Tuple
|
3
3
|
|
4
|
-
import numpy as np
|
5
4
|
import torch
|
6
5
|
from torch import Tensor
|
7
6
|
|
@@ -108,11 +107,11 @@ def cat(seq: List[Optional[Tensor]]) -> Optional[Tensor]:
|
|
108
107
|
|
109
108
|
|
110
109
|
def split(data: Data, batch: Tensor) -> Tuple[Data, Dict[str, Tensor]]:
|
111
|
-
node_slice = cumsum(torch.
|
110
|
+
node_slice = cumsum(torch.bincount(batch))
|
112
111
|
|
113
112
|
assert data.edge_index is not None
|
114
113
|
row, _ = data.edge_index
|
115
|
-
edge_slice = cumsum(torch.
|
114
|
+
edge_slice = cumsum(torch.bincount(batch[row]))
|
116
115
|
|
117
116
|
# Edge indices should start at zero for every graph.
|
118
117
|
data.edge_index -= node_slice[batch[row]].unsqueeze(0)
|
@@ -22,6 +22,7 @@ from .dynamic_batch_sampler import DynamicBatchSampler
|
|
22
22
|
from .prefetch import PrefetchLoader
|
23
23
|
from .cache import CachedLoader
|
24
24
|
from .mixin import AffinityMixin
|
25
|
+
from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore
|
25
26
|
|
26
27
|
__all__ = classes = [
|
27
28
|
'DataLoader',
|
@@ -50,6 +51,9 @@ __all__ = classes = [
|
|
50
51
|
'PrefetchLoader',
|
51
52
|
'CachedLoader',
|
52
53
|
'AffinityMixin',
|
54
|
+
'RAGQueryLoader',
|
55
|
+
'RAGFeatureStore',
|
56
|
+
'RAGGraphStore'
|
53
57
|
]
|
54
58
|
|
55
59
|
RandomNodeSampler = deprecated(
|