pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +8 -3
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +159 -34
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +2 -4
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +322 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +53 -20
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/inspector.py
CHANGED
@@ -305,7 +305,7 @@ class Inspector:
|
|
305
305
|
according to its function signature from a data blob.
|
306
306
|
|
307
307
|
Args:
|
308
|
-
func (
|
308
|
+
func (callable or str): The function.
|
309
309
|
kwargs (dict[str, Any]): The data blob which may serve as inputs.
|
310
310
|
"""
|
311
311
|
out_dict: Dict[str, Any] = {}
|
@@ -346,7 +346,7 @@ class Inspector:
|
|
346
346
|
type annotations are not found.
|
347
347
|
|
348
348
|
Args:
|
349
|
-
func (
|
349
|
+
func (callable or str): The function.
|
350
350
|
exclude (list[int or str]): A list of parameters to exclude, either
|
351
351
|
given by their name or index. (default: :obj:`None`)
|
352
352
|
"""
|
@@ -401,7 +401,8 @@ class Inspector:
|
|
401
401
|
match = find_parenthesis_content(source, f'self.{func_name}')
|
402
402
|
if match is not None:
|
403
403
|
for i, kwarg in enumerate(split(match, sep=',')):
|
404
|
-
if
|
404
|
+
if ('=' not in kwarg and exclude is not None
|
405
|
+
and i in exclude):
|
405
406
|
continue
|
406
407
|
|
407
408
|
name_and_content = re.split(r'\s*=\s*', kwarg)
|
@@ -447,6 +448,10 @@ def type_repr(obj: Any, _globals: Dict[str, Any]) -> str:
|
|
447
448
|
return '...'
|
448
449
|
|
449
450
|
if obj.__module__ == 'typing': # Special logic for `typing.*` types:
|
451
|
+
|
452
|
+
if not hasattr(obj, '_name'):
|
453
|
+
return repr(obj)
|
454
|
+
|
450
455
|
name = obj._name
|
451
456
|
if name is None: # In some cases, `_name` is not populated.
|
452
457
|
name = str(obj.__origin__).split('.')[-1]
|
torch_geometric/io/fs.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1
1
|
import io
|
2
2
|
import os.path as osp
|
3
|
+
import pickle
|
4
|
+
import re
|
3
5
|
import sys
|
6
|
+
import warnings
|
4
7
|
from typing import Any, Dict, List, Literal, Optional, Union, overload
|
5
8
|
from uuid import uuid4
|
6
9
|
|
@@ -186,11 +189,11 @@ def rm(path: str, recursive: bool = True) -> None:
|
|
186
189
|
get_fs(path).rm(path, recursive)
|
187
190
|
|
188
191
|
|
189
|
-
def mv(path1: str, path2: str
|
192
|
+
def mv(path1: str, path2: str) -> None:
|
190
193
|
fs1 = get_fs(path1)
|
191
194
|
fs2 = get_fs(path2)
|
192
195
|
assert fs1.protocol == fs2.protocol
|
193
|
-
fs1.mv(path1, path2
|
196
|
+
fs1.mv(path1, path2)
|
194
197
|
|
195
198
|
|
196
199
|
def glob(path: str) -> List[str]:
|
@@ -211,5 +214,28 @@ def torch_save(data: Any, path: str) -> None:
|
|
211
214
|
|
212
215
|
|
213
216
|
def torch_load(path: str, map_location: Any = None) -> Any:
|
217
|
+
if torch_geometric.typing.WITH_PT24:
|
218
|
+
try:
|
219
|
+
with fsspec.open(path, 'rb') as f:
|
220
|
+
return torch.load(f, map_location, weights_only=True)
|
221
|
+
except pickle.UnpicklingError as e:
|
222
|
+
error_msg = str(e)
|
223
|
+
if "add_safe_globals" in error_msg:
|
224
|
+
warn_msg = ("Weights only load failed. Please file an issue "
|
225
|
+
"to make `torch.load(weights_only=True)` "
|
226
|
+
"compatible in your case.")
|
227
|
+
match = re.search(r'add_safe_globals\(.*?\)', error_msg)
|
228
|
+
if match is not None:
|
229
|
+
warnings.warn(f"{warn_msg} Please use "
|
230
|
+
f"`torch.serialization.{match.group()}` to "
|
231
|
+
f"allowlist this global.")
|
232
|
+
else:
|
233
|
+
warnings.warn(warn_msg)
|
234
|
+
|
235
|
+
with fsspec.open(path, 'rb') as f:
|
236
|
+
return torch.load(f, map_location, weights_only=False)
|
237
|
+
else:
|
238
|
+
raise e
|
239
|
+
|
214
240
|
with fsspec.open(path, 'rb') as f:
|
215
241
|
return torch.load(f, map_location)
|
torch_geometric/io/npz.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
from typing import Any, Dict
|
2
2
|
|
3
3
|
import numpy as np
|
4
|
-
import scipy.sparse as sp
|
5
4
|
import torch
|
6
5
|
|
7
6
|
from torch_geometric.data import Data
|
@@ -15,6 +14,8 @@ def read_npz(path: str, to_undirected: bool = True) -> Data:
|
|
15
14
|
|
16
15
|
|
17
16
|
def parse_npz(f: Dict[str, Any], to_undirected: bool = True) -> Data:
|
17
|
+
import scipy.sparse as sp
|
18
|
+
|
18
19
|
x = sp.csr_matrix((f['attr_data'], f['attr_indices'], f['attr_indptr']),
|
19
20
|
f['attr_shape']).todense()
|
20
21
|
x = torch.from_numpy(x).to(torch.float)
|
torch_geometric/io/off.py
CHANGED
@@ -16,7 +16,7 @@ def parse_off(src: List[str]) -> Data:
|
|
16
16
|
else:
|
17
17
|
src[0] = src[0][3:]
|
18
18
|
|
19
|
-
num_nodes, num_faces =
|
19
|
+
num_nodes, num_faces = (int(item) for item in src[0].split()[:2])
|
20
20
|
|
21
21
|
pos = parse_txt_array(src[1:1 + num_nodes])
|
22
22
|
|
@@ -52,7 +52,7 @@ def read_off(path: str) -> Data:
|
|
52
52
|
Args:
|
53
53
|
path (str): The path to the file.
|
54
54
|
"""
|
55
|
-
with open(path
|
55
|
+
with open(path) as f:
|
56
56
|
src = f.read().split('\n')[:-1]
|
57
57
|
return parse_off(src)
|
58
58
|
|
torch_geometric/io/sdf.py
CHANGED
@@ -9,7 +9,7 @@ elems = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
|
|
9
9
|
|
10
10
|
def parse_sdf(src: str) -> Data:
|
11
11
|
lines = src.split('\n')[3:]
|
12
|
-
num_atoms, num_bonds =
|
12
|
+
num_atoms, num_bonds = (int(item) for item in lines[0].split()[:2])
|
13
13
|
|
14
14
|
atom_block = lines[1:num_atoms + 1]
|
15
15
|
pos = parse_txt_array(atom_block, end=3)
|
@@ -28,5 +28,5 @@ def parse_sdf(src: str) -> Data:
|
|
28
28
|
|
29
29
|
|
30
30
|
def read_sdf(path: str) -> Data:
|
31
|
-
with open(path
|
31
|
+
with open(path) as f:
|
32
32
|
return parse_sdf(f.read())
|
torch_geometric/io/tu.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
import os.path as osp
|
2
2
|
from typing import Dict, List, Optional, Tuple
|
3
3
|
|
4
|
-
import numpy as np
|
5
4
|
import torch
|
6
5
|
from torch import Tensor
|
7
6
|
|
@@ -37,7 +36,7 @@ def read_tu_data(
|
|
37
36
|
if node_label.dim() == 1:
|
38
37
|
node_label = node_label.unsqueeze(-1)
|
39
38
|
node_label = node_label - node_label.min(dim=0)[0]
|
40
|
-
node_labels = node_label.unbind(dim=-1)
|
39
|
+
node_labels = list(node_label.unbind(dim=-1))
|
41
40
|
node_labels = [one_hot(x) for x in node_labels]
|
42
41
|
if len(node_labels) == 1:
|
43
42
|
node_label = node_labels[0]
|
@@ -56,7 +55,7 @@ def read_tu_data(
|
|
56
55
|
if edge_label.dim() == 1:
|
57
56
|
edge_label = edge_label.unsqueeze(-1)
|
58
57
|
edge_label = edge_label - edge_label.min(dim=0)[0]
|
59
|
-
edge_labels = edge_label.unbind(dim=-1)
|
58
|
+
edge_labels = list(edge_label.unbind(dim=-1))
|
60
59
|
edge_labels = [one_hot(e) for e in edge_labels]
|
61
60
|
if len(edge_labels) == 1:
|
62
61
|
edge_label = edge_labels[0]
|
@@ -108,11 +107,11 @@ def cat(seq: List[Optional[Tensor]]) -> Optional[Tensor]:
|
|
108
107
|
|
109
108
|
|
110
109
|
def split(data: Data, batch: Tensor) -> Tuple[Data, Dict[str, Tensor]]:
|
111
|
-
node_slice = cumsum(torch.
|
110
|
+
node_slice = cumsum(torch.bincount(batch))
|
112
111
|
|
113
112
|
assert data.edge_index is not None
|
114
113
|
row, _ = data.edge_index
|
115
|
-
edge_slice = cumsum(torch.
|
114
|
+
edge_slice = cumsum(torch.bincount(batch[row]))
|
116
115
|
|
117
116
|
# Edge indices should start at zero for every graph.
|
118
117
|
data.edge_index -= node_slice[batch[row]].unsqueeze(0)
|
@@ -22,6 +22,7 @@ from .dynamic_batch_sampler import DynamicBatchSampler
|
|
22
22
|
from .prefetch import PrefetchLoader
|
23
23
|
from .cache import CachedLoader
|
24
24
|
from .mixin import AffinityMixin
|
25
|
+
from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore
|
25
26
|
|
26
27
|
__all__ = classes = [
|
27
28
|
'DataLoader',
|
@@ -50,6 +51,9 @@ __all__ = classes = [
|
|
50
51
|
'PrefetchLoader',
|
51
52
|
'CachedLoader',
|
52
53
|
'AffinityMixin',
|
54
|
+
'RAGQueryLoader',
|
55
|
+
'RAGFeatureStore',
|
56
|
+
'RAGGraphStore'
|
53
57
|
]
|
54
58
|
|
55
59
|
RandomNodeSampler = deprecated(
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import copy
|
2
|
+
import os
|
2
3
|
import os.path as osp
|
3
4
|
import sys
|
4
5
|
from dataclasses import dataclass
|
@@ -10,10 +11,11 @@ from torch import Tensor
|
|
10
11
|
|
11
12
|
import torch_geometric.typing
|
12
13
|
from torch_geometric.data import Data
|
14
|
+
from torch_geometric.index import index2ptr, ptr2index
|
15
|
+
from torch_geometric.io import fs
|
13
16
|
from torch_geometric.typing import pyg_lib
|
14
17
|
from torch_geometric.utils import index_sort, narrow, select, sort_edge_index
|
15
18
|
from torch_geometric.utils.map import map_index
|
16
|
-
from torch_geometric.utils.sparse import index2ptr, ptr2index
|
17
19
|
|
18
20
|
|
19
21
|
@dataclass
|
@@ -43,6 +45,8 @@ class ClusterData(torch.utils.data.Dataset):
|
|
43
45
|
(default: :obj:`False`)
|
44
46
|
save_dir (str, optional): If set, will save the partitioned data to the
|
45
47
|
:obj:`save_dir` directory for faster re-use. (default: :obj:`None`)
|
48
|
+
filename (str, optional): Name of the stored partitioned file.
|
49
|
+
(default: :obj:`None`)
|
46
50
|
log (bool, optional): If set to :obj:`False`, will not log any
|
47
51
|
progress. (default: :obj:`True`)
|
48
52
|
keep_inter_cluster_edges (bool, optional): If set to :obj:`True`,
|
@@ -56,6 +60,7 @@ class ClusterData(torch.utils.data.Dataset):
|
|
56
60
|
num_parts: int,
|
57
61
|
recursive: bool = False,
|
58
62
|
save_dir: Optional[str] = None,
|
63
|
+
filename: Optional[str] = None,
|
59
64
|
log: bool = True,
|
60
65
|
keep_inter_cluster_edges: bool = False,
|
61
66
|
sparse_format: Literal['csr', 'csc'] = 'csr',
|
@@ -69,11 +74,11 @@ class ClusterData(torch.utils.data.Dataset):
|
|
69
74
|
self.sparse_format = sparse_format
|
70
75
|
|
71
76
|
recursive_str = '_recursive' if recursive else ''
|
72
|
-
|
73
|
-
path = osp.join(
|
77
|
+
root_dir = osp.join(save_dir or '', f'part_{num_parts}{recursive_str}')
|
78
|
+
path = osp.join(root_dir, filename or 'metis.pt')
|
74
79
|
|
75
80
|
if save_dir is not None and osp.exists(path):
|
76
|
-
self.partition =
|
81
|
+
self.partition = fs.torch_load(path)
|
77
82
|
else:
|
78
83
|
if log: # pragma: no cover
|
79
84
|
print('Computing METIS partitioning...', file=sys.stderr)
|
@@ -82,6 +87,7 @@ class ClusterData(torch.utils.data.Dataset):
|
|
82
87
|
self.partition = self._partition(data.edge_index, cluster)
|
83
88
|
|
84
89
|
if save_dir is not None:
|
90
|
+
os.makedirs(root_dir, exist_ok=True)
|
85
91
|
torch.save(self.partition, path)
|
86
92
|
|
87
93
|
if log: # pragma: no cover
|
@@ -4,6 +4,7 @@ from typing import Optional
|
|
4
4
|
import torch
|
5
5
|
from tqdm import tqdm
|
6
6
|
|
7
|
+
from torch_geometric.io import fs
|
7
8
|
from torch_geometric.typing import SparseTensor
|
8
9
|
|
9
10
|
|
@@ -77,7 +78,7 @@ class GraphSAINTSampler(torch.utils.data.DataLoader):
|
|
77
78
|
if self.sample_coverage > 0:
|
78
79
|
path = osp.join(save_dir or '', self._filename)
|
79
80
|
if save_dir is not None and osp.exists(path): # pragma: no cover
|
80
|
-
self.node_norm, self.edge_norm =
|
81
|
+
self.node_norm, self.edge_norm = fs.torch_load(path)
|
81
82
|
else:
|
82
83
|
self.node_norm, self.edge_norm = self._compute_norm()
|
83
84
|
if save_dir is not None: # pragma: no cover
|
@@ -1,9 +1,17 @@
|
|
1
1
|
import logging
|
2
2
|
import math
|
3
|
-
from typing import
|
3
|
+
from typing import (
|
4
|
+
Any,
|
5
|
+
Callable,
|
6
|
+
Iterator,
|
7
|
+
List,
|
8
|
+
NamedTuple,
|
9
|
+
Optional,
|
10
|
+
Tuple,
|
11
|
+
Union,
|
12
|
+
)
|
4
13
|
|
5
14
|
import numpy as np
|
6
|
-
import scipy.sparse
|
7
15
|
import torch
|
8
16
|
from torch import Tensor
|
9
17
|
from tqdm import tqdm
|
@@ -281,7 +289,7 @@ def create_batchwise_out_aux_pairs(
|
|
281
289
|
return loader
|
282
290
|
|
283
291
|
|
284
|
-
def get_pairs(ppr_mat:
|
292
|
+
def get_pairs(ppr_mat: Any) -> np.ndarray:
|
285
293
|
ppr_mat = ppr_mat + ppr_mat.transpose()
|
286
294
|
|
287
295
|
ppr_mat = ppr_mat.tocoo()
|
@@ -387,7 +395,7 @@ def topk_ppr_matrix(
|
|
387
395
|
output_node_indices: Union[np.ndarray, torch.LongTensor],
|
388
396
|
topk: int,
|
389
397
|
normalization='row',
|
390
|
-
) -> Tuple[
|
398
|
+
) -> Tuple[Any, List[np.ndarray]]:
|
391
399
|
neighbors, weights = get_ppr(edge_index, alpha, eps, output_node_indices,
|
392
400
|
num_nodes)
|
393
401
|
|
torch_geometric/loader/mixin.py
CHANGED
@@ -56,7 +56,7 @@ def get_numa_nodes_cores() -> Dict[str, Any]:
|
|
56
56
|
nodes[numa_node_id] = sorted([(k, sorted(v))
|
57
57
|
for k, v in thread_siblings.items()])
|
58
58
|
|
59
|
-
except (OSError, ValueError, IndexError
|
59
|
+
except (OSError, ValueError, IndexError):
|
60
60
|
Warning('Failed to read NUMA info')
|
61
61
|
return {}
|
62
62
|
|
@@ -14,7 +14,7 @@ class NeighborLoader(NodeLoader):
|
|
14
14
|
This loader allows for mini-batch training of GNNs on large-scale graphs
|
15
15
|
where full-batch training is not feasible.
|
16
16
|
|
17
|
-
More specifically, :obj:`num_neighbors` denotes how
|
17
|
+
More specifically, :obj:`num_neighbors` denotes how many neighbors are
|
18
18
|
sampled for each node in each iteration.
|
19
19
|
:class:`~torch_geometric.loader.NeighborLoader` takes in this list of
|
20
20
|
:obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for
|
@@ -72,9 +72,9 @@ class NeighborSampler(torch.utils.data.DataLoader):
|
|
72
72
|
`examples/reddit.py
|
73
73
|
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
|
74
74
|
reddit.py>`_ or
|
75
|
-
`examples/
|
75
|
+
`examples/ogbn_train.py
|
76
76
|
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
|
77
|
-
|
77
|
+
ogbn_train.py>`_.
|
78
78
|
|
79
79
|
Args:
|
80
80
|
edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
|
@@ -73,7 +73,7 @@ class PrefetchLoader:
|
|
73
73
|
if isinstance(batch, dict):
|
74
74
|
return {k: self.non_blocking_transfer(v) for k, v in batch.items()}
|
75
75
|
|
76
|
-
batch = batch.pin_memory(
|
76
|
+
batch = batch.pin_memory()
|
77
77
|
return batch.to(self.device_helper.device, non_blocking=True)
|
78
78
|
|
79
79
|
def __iter__(self) -> Any:
|
@@ -0,0 +1,107 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
|
3
|
+
|
4
|
+
from torch_geometric.data import Data, FeatureStore, HeteroData
|
5
|
+
from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
|
6
|
+
from torch_geometric.typing import InputEdges, InputNodes
|
7
|
+
|
8
|
+
|
9
|
+
class RAGFeatureStore(Protocol):
|
10
|
+
"""Feature store template for remote GNN RAG backend."""
|
11
|
+
@abstractmethod
|
12
|
+
def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
|
13
|
+
"""Makes a comparison between the query and all the nodes to get all
|
14
|
+
the closest nodes. Return the indices of the nodes that are to be seeds
|
15
|
+
for the RAG Sampler.
|
16
|
+
"""
|
17
|
+
...
|
18
|
+
|
19
|
+
@abstractmethod
|
20
|
+
def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
|
21
|
+
"""Makes a comparison between the query and all the edges to get all
|
22
|
+
the closest nodes. Returns the edge indices that are to be the seeds
|
23
|
+
for the RAG Sampler.
|
24
|
+
"""
|
25
|
+
...
|
26
|
+
|
27
|
+
@abstractmethod
|
28
|
+
def load_subgraph(
|
29
|
+
self, sample: Union[SamplerOutput, HeteroSamplerOutput]
|
30
|
+
) -> Union[Data, HeteroData]:
|
31
|
+
"""Combines sampled subgraph output with features in a Data object."""
|
32
|
+
...
|
33
|
+
|
34
|
+
|
35
|
+
class RAGGraphStore(Protocol):
|
36
|
+
"""Graph store template for remote GNN RAG backend."""
|
37
|
+
@abstractmethod
|
38
|
+
def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
|
39
|
+
**kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
|
40
|
+
"""Sample a subgraph using the seeded nodes and edges."""
|
41
|
+
...
|
42
|
+
|
43
|
+
@abstractmethod
|
44
|
+
def register_feature_store(self, feature_store: FeatureStore):
|
45
|
+
"""Register a feature store to be used with the sampler. Samplers need
|
46
|
+
info from the feature store in order to work properly on HeteroGraphs.
|
47
|
+
"""
|
48
|
+
...
|
49
|
+
|
50
|
+
|
51
|
+
# TODO: Make compatible with Heterographs
|
52
|
+
|
53
|
+
|
54
|
+
class RAGQueryLoader:
|
55
|
+
"""Loader meant for making RAG queries from a remote backend."""
|
56
|
+
def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
|
57
|
+
local_filter: Optional[Callable[[Data, Any], Data]] = None,
|
58
|
+
seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
|
59
|
+
seed_edges_kwargs: Optional[Dict[str, Any]] = None,
|
60
|
+
sampler_kwargs: Optional[Dict[str, Any]] = None,
|
61
|
+
loader_kwargs: Optional[Dict[str, Any]] = None):
|
62
|
+
"""Loader meant for making queries from a remote backend.
|
63
|
+
|
64
|
+
Args:
|
65
|
+
data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore
|
66
|
+
and GraphStore to load from. Assumed to conform to the
|
67
|
+
protocols listed above.
|
68
|
+
local_filter (Optional[Callable[[Data, Any], Data]], optional):
|
69
|
+
Optional local transform to apply to data after retrieval.
|
70
|
+
Defaults to None.
|
71
|
+
seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Paramaters
|
72
|
+
to pass into process for fetching seed nodes. Defaults to None.
|
73
|
+
seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters
|
74
|
+
to pass into process for fetching seed edges. Defaults to None.
|
75
|
+
sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to
|
76
|
+
pass into process for sampling graph. Defaults to None.
|
77
|
+
loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to
|
78
|
+
pass into process for loading graph features. Defaults to None.
|
79
|
+
"""
|
80
|
+
fstore, gstore = data
|
81
|
+
self.feature_store = fstore
|
82
|
+
self.graph_store = gstore
|
83
|
+
self.graph_store.register_feature_store(self.feature_store)
|
84
|
+
self.local_filter = local_filter
|
85
|
+
self.seed_nodes_kwargs = seed_nodes_kwargs or {}
|
86
|
+
self.seed_edges_kwargs = seed_edges_kwargs or {}
|
87
|
+
self.sampler_kwargs = sampler_kwargs or {}
|
88
|
+
self.loader_kwargs = loader_kwargs or {}
|
89
|
+
|
90
|
+
def query(self, query: Any) -> Data:
|
91
|
+
"""Retrieve a subgraph associated with the query with all its feature
|
92
|
+
attributes.
|
93
|
+
"""
|
94
|
+
seed_nodes = self.feature_store.retrieve_seed_nodes(
|
95
|
+
query, **self.seed_nodes_kwargs)
|
96
|
+
seed_edges = self.feature_store.retrieve_seed_edges(
|
97
|
+
query, **self.seed_edges_kwargs)
|
98
|
+
|
99
|
+
subgraph_sample = self.graph_store.sample_subgraph(
|
100
|
+
seed_nodes, seed_edges, **self.sampler_kwargs)
|
101
|
+
|
102
|
+
data = self.feature_store.load_subgraph(sample=subgraph_sample,
|
103
|
+
**self.loader_kwargs)
|
104
|
+
|
105
|
+
if self.local_filter:
|
106
|
+
data = self.local_filter(data, query)
|
107
|
+
return data
|
torch_geometric/loader/utils.py
CHANGED
@@ -8,7 +8,6 @@ import torch
|
|
8
8
|
from torch import Tensor
|
9
9
|
|
10
10
|
import torch_geometric.typing
|
11
|
-
from torch_geometric import EdgeIndex
|
12
11
|
from torch_geometric.data import (
|
13
12
|
Data,
|
14
13
|
FeatureStore,
|
@@ -105,13 +104,15 @@ def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
|
|
105
104
|
# which represents the new graph as denoted by `(row, col)`:
|
106
105
|
for key, value in store.items():
|
107
106
|
if key == 'edge_index':
|
107
|
+
edge_index = torch.stack([row, col], dim=0).to(value.device)
|
108
108
|
# TODO Integrate `EdgeIndex` into `custom_store`.
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
)
|
109
|
+
# edge_index = EdgeIndex(
|
110
|
+
# torch.stack([row, col], dim=0).to(value.device),
|
111
|
+
# sparse_size=out_store.size(),
|
112
|
+
# sort_order='col',
|
113
|
+
# # TODO Support `is_undirected`.
|
114
|
+
# )
|
115
|
+
out_store.edge_index = edge_index
|
115
116
|
|
116
117
|
elif key == 'adj_t':
|
117
118
|
# NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).
|
@@ -59,6 +59,16 @@ class ZipLoader(torch.utils.data.DataLoader):
|
|
59
59
|
self.loaders = loaders
|
60
60
|
self.filter_per_worker = filter_per_worker
|
61
61
|
|
62
|
+
def __call__(
|
63
|
+
self,
|
64
|
+
index: Union[Tensor, List[int]],
|
65
|
+
) -> Union[Tuple[Data, ...], Tuple[HeteroData, ...]]:
|
66
|
+
r"""Samples subgraphs from a batch of input IDs."""
|
67
|
+
out = self.collate_fn(index)
|
68
|
+
if not self.filter_per_worker:
|
69
|
+
out = self.filter_fn(out)
|
70
|
+
return out
|
71
|
+
|
62
72
|
def collate_fn(self, index: List[int]) -> Tuple[Any, ...]:
|
63
73
|
if not isinstance(index, Tensor):
|
64
74
|
index = torch.tensor(index, dtype=torch.long)
|
@@ -1,14 +1,23 @@
|
|
1
1
|
# flake8: noqa
|
2
2
|
|
3
|
-
from .link_pred import (
|
4
|
-
|
3
|
+
from .link_pred import (
|
4
|
+
LinkPredMetricCollection,
|
5
|
+
LinkPredPrecision,
|
6
|
+
LinkPredRecall,
|
7
|
+
LinkPredF1,
|
8
|
+
LinkPredMAP,
|
9
|
+
LinkPredNDCG,
|
10
|
+
LinkPredMRR,
|
11
|
+
)
|
5
12
|
|
6
13
|
link_pred_metrics = [
|
14
|
+
'LinkPredMetricCollection',
|
7
15
|
'LinkPredPrecision',
|
8
16
|
'LinkPredRecall',
|
9
17
|
'LinkPredF1',
|
10
18
|
'LinkPredMAP',
|
11
19
|
'LinkPredNDCG',
|
20
|
+
'LinkPredMRR',
|
12
21
|
]
|
13
22
|
|
14
23
|
__all__ = link_pred_metrics
|