pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torch_geometric.data import FeatureStore
|
|
7
|
+
from torch_geometric.distributed import LocalGraphStore
|
|
8
|
+
from torch_geometric.sampler import (
|
|
9
|
+
BidirectionalNeighborSampler,
|
|
10
|
+
NodeSamplerInput,
|
|
11
|
+
SamplerOutput,
|
|
12
|
+
)
|
|
13
|
+
from torch_geometric.utils import index_sort
|
|
14
|
+
|
|
15
|
+
# A representation of an edge index, following the possible formats:
#    * default: Tensor, size = [2, num_edges]
#        (Tensor[0, :] == row, Tensor[1, :] == col)
#    * COO: (row, col)
#    * CSC: (row, colptr)
#    * CSR: (rowptr, col)
_EdgeTensorType = Union[Tensor, Tuple[Tensor, Tensor]]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class NeighborSamplingRAGGraphStore(LocalGraphStore):
    """Neighbor sampling based graph-store to store & retrieve graph data."""
    def __init__(  # type: ignore[no-untyped-def]
        self,
        feature_store: Optional[FeatureStore] = None,
        **kwargs,
    ):
        """Initializes the graph store.

        Optional feature store and neighbor sampling settings.

        Args:
            feature_store (optional): The feature store to use.
                None if not yet registered.
            **kwargs (optional):
                Additional keyword arguments for neighbor sampling; they are
                forwarded to :class:`BidirectionalNeighborSampler` when the
                sampler is (re-)initialized.
        """
        self.feature_store = feature_store
        self.sample_kwargs = kwargs
        self._sampler_is_initialized = False
        self._config: Dict[str, Any] = {}

        # to be set by the config
        self.num_neighbors = None
        super().__init__()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the config for the graph store."""
        return self._config

    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:
        """Set an attribute from the config.

        Args:
            config (Dict[str, Any]): Config dictionary
            attr_name (str): Name of attribute to set

        Raises:
            ValueError: If required attribute not found in config
        """
        if attr_name not in config:
            raise ValueError(
                f"Required config parameter '{attr_name}' not found")
        setattr(self, attr_name, config[attr_name])

    @config.setter  # type: ignore
    def config(self, config: Dict[str, Any]) -> None:
        """Set the config for the graph store.

        Args:
            config (Dict[str, Any]):
                Config dictionary containing required parameters
                (currently ``num_neighbors``).

        Raises:
            ValueError: If required parameters missing from config
        """
        self._set_from_config(config, "num_neighbors")
        # Propagate the new setting to an already-created sampler so that it
        # takes effect without a full re-initialization.
        if hasattr(self, 'sampler'):
            self.sampler.num_neighbors = (  # type: ignore[has-type]
                self.num_neighbors)

        self._config = config

    def _init_sampler(self) -> None:
        """Initializes neighbor sampler with the registered feature store.

        Raises:
            AttributeError: If no feature store has been registered.
        """
        if self.feature_store is None:
            raise AttributeError("Feature store not registered yet.")
        assert self.num_neighbors is not None, \
            "Please set num_neighbors through config"
        self.sampler = BidirectionalNeighborSampler(
            data=(self.feature_store, self), num_neighbors=self.num_neighbors,
            **self.sample_kwargs)
        self._sampler_is_initialized = True

    def register_feature_store(self, feature_store: FeatureStore) -> None:
        """Registers a feature store with the graph store.

        The sampler is rebuilt lazily on the next :meth:`sample_subgraph`.

        :param feature_store: The feature store to register.
        """
        self.feature_store = feature_store
        self._sampler_is_initialized = False

    def put_edge_id(  # type: ignore[no-untyped-def]
            self, edge_id: Tensor, *args, **kwargs) -> bool:
        """Stores an edge ID in the graph store.

        :param edge_id: The edge ID to store.
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)
        self._sampler_is_initialized = False
        return ret

    @property
    def edge_index(self) -> _EdgeTensorType:
        """Gets the edge index of the graph.

        Only valid after :meth:`put_edge_index` has been called, since the
        lookup reuses the arguments recorded there.

        :return: The edge index as a tensor.
        """
        return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)

    def put_edge_index(  # type: ignore[no-untyped-def]
            self, edge_index: _EdgeTensorType, *args, **kwargs) -> bool:
        """Stores an edge index in the graph store.

        :param edge_index: The edge index to store.
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_index(edge_index, *args, **kwargs)
        # Remember the storage arguments so that the `edge_index` property
        # can look the same edge index up again later:
        self.edge_idx_args = args
        self.edge_idx_kwargs = kwargs
        self._sampler_is_initialized = False
        return ret

    @edge_index.setter  # type: ignore
    def edge_index(self, edge_index: _EdgeTensorType) -> None:
        """Sets the edge index of the graph.

        Edges are stored sorted by destination node (``col``); the applied
        permutation is kept in ``self.perm`` so that edge IDs returned by
        :meth:`sample_subgraph` can be mapped back to the order given here.

        :param edge_index: The edge index to set, either a ``[2, num_edges]``
            tensor or a ``(row, col)`` tuple of tensors.
        """
        if not isinstance(edge_index, Tensor):
            assert isinstance(edge_index, tuple) \
                and isinstance(edge_index[0], Tensor) \
                and isinstance(edge_index[1], Tensor), \
                ("edge_index must be a Tensor of [2, num_edges] "
                 "or a tuple of Tensors, (row, col).")

        row, col = edge_index[0], edge_index[1]
        # A node may occur only as a source or only as a destination, so the
        # node count has to cover both endpoints (for the tensor format this
        # equals the global maximum, for the tuple format it did not):
        num_nodes = max(int(row.max()), int(col.max())) + 1
        attr = dict(
            edge_type=None,
            layout='coo',
            size=(num_nodes, num_nodes),
            is_sorted=False,
        )
        # edge index needs to be sorted here and the perm saved for later
        col_sorted, self.perm = index_sort(col, num_nodes, stable=True)
        row_sorted = row[self.perm]
        edge_index_sorted = torch.stack([row_sorted, col_sorted], dim=0)
        self.put_edge_index(edge_index_sorted, **attr)

    def sample_subgraph(
        self,
        seed_nodes: Tensor,
    ) -> SamplerOutput:
        """Sample the graph starting from the given nodes using the
        in-built :class:`BidirectionalNeighborSampler`.

        The number of hops and neighbors per hop is taken from
        ``self.num_neighbors`` (set through :attr:`config`). The sampler is
        (re-)initialized lazily if the stored graph or feature store changed.

        Args:
            seed_nodes (Tensor): Seed nodes to start sampling from.
                Duplicates are removed before sampling.

        Returns:
            SamplerOutput: The sampled subgraph. Its ``edge`` IDs are mapped
            through ``self.perm``, i.e., back to the order in which edges
            were passed to the ``edge_index`` setter (which must have been
            used to store the edges).
        """
        # TODO add support for Hetero
        if not self._sampler_is_initialized:
            self._init_sampler()

        seed_nodes = seed_nodes.unique().contiguous()
        node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)
        out = self.sampler.sample_from_nodes(  # type: ignore[has-type]
            node_sample_input)

        # edge ids need to be remapped to the original indices
        out.edge = self.perm[out.edge]

        return out
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# mypy: ignore-errors
|
|
2
|
+
import os
|
|
3
|
+
from abc import abstractmethod
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Protocol, Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from torch_geometric.data import Data
|
|
10
|
+
from torch_geometric.llm.models import SentenceTransformer
|
|
11
|
+
from torch_geometric.llm.utils.backend_utils import batch_knn
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class VectorRetriever(Protocol):
    """Protocol for VectorRAG retrievers."""
    @abstractmethod
    def query(self, query: Any, **kwargs: Optional[Dict[str, Any]]) -> Data:
        """Retrieve a context for a given query."""
        ...


class DocumentRetriever(VectorRetriever):
    """Retrieve documents from a vector database."""
    def __init__(self, raw_docs: List[str],
                 embedded_docs: Optional[Tensor] = None, k_for_docs: int = 2,
                 model: Optional[Union[SentenceTransformer, torch.nn.Module,
                                       Callable]] = None,
                 model_kwargs: Optional[Dict[str, Any]] = None):
        """Retrieve documents from a vector database.

        Args:
            raw_docs: List[str]: List of raw documents.
            embedded_docs: Optional[Tensor]: Embedded documents. If not given,
                they are computed by calling ``model(raw_docs,
                **model_kwargs)``.
            k_for_docs: int: Number of documents to retrieve.
            model: Optional[Union[SentenceTransformer, torch.nn.Module,
                Callable]]: Model to use for encoding. Required if
                ``embedded_docs`` is not given, or to query with strings.
            model_kwargs: Optional[Dict[str, Any]]:
                Keyword arguments to pass to the model. The dict is copied,
                so the caller's object is never modified.
        """
        self.raw_docs = raw_docs
        self.embedded_docs = embedded_docs
        self.k_for_docs = k_for_docs
        self.model = model
        # Always define `encoder` (possibly `None`) so `query` can report a
        # missing model with a clear message.
        self.encoder = model
        # Copy so that popping `verbose` below never mutates the caller's
        # dict, and default to `{}` so `**self.model_kwargs` is always valid.
        self.model_kwargs: Dict[str, Any] = dict(model_kwargs or {})

        if self.embedded_docs is None:
            assert self.model is not None, \
                "Model must be provided if embedded_docs is not provided"
            self.embedded_docs = self.encoder(self.raw_docs,
                                              **self.model_kwargs)
            # we don't want to print the verbose output in `query`
            self.model_kwargs.pop("verbose", None)

    def query(self, query: Union[str, Tensor]) -> List[str]:
        """Retrieve documents from the vector database.

        Args:
            query: Union[str, Tensor]: Query to retrieve documents for.
                Strings are encoded with the model; anything else is used
                as an already-encoded query.

        Returns:
            List[str]: Documents retrieved from the vector database.

        Raises:
            AttributeError: If ``query`` is a string but no model was given.
        """
        if isinstance(query, str):
            if self.encoder is None:
                raise AttributeError(
                    "A model is required to encode string queries; pass "
                    "`model` to DocumentRetriever or query with an embedding.")
            with torch.no_grad():
                query_enc = self.encoder(query, **self.model_kwargs)
        else:
            query_enc = query

        selected_doc_idxs, _ = next(
            batch_knn(query_enc, self.embedded_docs, self.k_for_docs))
        return [self.raw_docs[i] for i in selected_doc_idxs]

    def save(self, path: str) -> None:
        """Save the DocumentRetriever instance to disk.

        Only the raw documents, their embeddings and ``k_for_docs`` are
        stored; the model is not serialized.

        Args:
            path: str: Path where to save the retriever.
        """
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

        # Prepare data to save
        save_dict = {
            'raw_docs': self.raw_docs,
            'embedded_docs': self.embedded_docs,
            'k_for_docs': self.k_for_docs,
        }

        # We do not serialize the model
        torch.save(save_dict, path)

    @classmethod
    def load(cls, path: str, model: Union[SentenceTransformer, torch.nn.Module,
                                          Callable],
             model_kwargs: Optional[Dict[str, Any]] = None) -> VectorRetriever:
        """Load a DocumentRetriever instance from disk.

        Args:
            path: str: Path to the saved retriever.
            model: Union[SentenceTransformer, torch.nn.Module, Callable]:
                Model to use for encoding queries. The model is not part of
                the saved file, so it must be supplied here.
            model_kwargs: Optional[Dict[str, Any]]
                Key word args to be passed to model. The dict is copied, so
                the caller's object is never modified.

        Returns:
            DocumentRetriever: The loaded retriever.

        Raises:
            FileNotFoundError: If nothing exists at ``path``.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"No saved document retriever found at {path}")

        # NOTE: `weights_only=False` unpickles arbitrary objects; only load
        # files from trusted sources.
        save_dict = torch.load(path, weights_only=False)
        if model_kwargs is not None:
            model_kwargs = dict(model_kwargs)
            # Documents are already embedded, so `verbose` would only make
            # `query` noisy:
            if isinstance(save_dict['embedded_docs'], Tensor):
                model_kwargs.pop("verbose", None)
        # Create a new DocumentRetriever with the loaded data
        return cls(raw_docs=save_dict['raw_docs'],
                   embedded_docs=save_dict['embedded_docs'],
                   k_for_docs=save_dict['k_for_docs'], model=model,
                   model_kwargs=model_kwargs)
|
|
@@ -235,9 +235,9 @@ class ClusterData(torch.utils.data.Dataset):
|
|
|
235
235
|
class ClusterLoader(torch.utils.data.DataLoader):
|
|
236
236
|
r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm
|
|
237
237
|
for Training Deep and Large Graph Convolutional Networks"
|
|
238
|
-
<https://arxiv.org/abs/1905.07953>`_ paper which merges
|
|
239
|
-
and their between-cluster links from a large-scale graph data
|
|
240
|
-
form a mini-batch.
|
|
238
|
+
<https://arxiv.org/abs/1905.07953>`_ paper which merges partitioned
|
|
239
|
+
subgraphs and their between-cluster links from a large-scale graph data
|
|
240
|
+
object to form a mini-batch.
|
|
241
241
|
|
|
242
242
|
.. note::
|
|
243
243
|
|
|
@@ -252,7 +252,7 @@ class ClusterLoader(torch.utils.data.DataLoader):
|
|
|
252
252
|
|
|
253
253
|
Args:
|
|
254
254
|
cluster_data (torch_geometric.loader.ClusterData): The already
|
|
255
|
-
|
|
255
|
+
partitioned data object.
|
|
256
256
|
**kwargs (optional): Additional arguments of
|
|
257
257
|
:class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
|
|
258
258
|
:obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
|
|
@@ -148,7 +148,7 @@ def indices_complete_check(
|
|
|
148
148
|
if isinstance(aux, Tensor):
|
|
149
149
|
aux = aux.cpu().numpy()
|
|
150
150
|
|
|
151
|
-
assert np.all(np.
|
|
151
|
+
assert np.all(np.isin(out,
|
|
152
152
|
aux)), "Not all output nodes are in aux nodes!"
|
|
153
153
|
outs.append(out)
|
|
154
154
|
|
|
@@ -236,7 +236,7 @@ def create_batchwise_out_aux_pairs(
|
|
|
236
236
|
logits[tele_set, i] = 1. / len(tele_set)
|
|
237
237
|
|
|
238
238
|
new_logits = logits.clone()
|
|
239
|
-
for
|
|
239
|
+
for _ in range(num_iter):
|
|
240
240
|
new_logits = adj @ new_logits * (1 - alpha) + alpha * logits
|
|
241
241
|
|
|
242
242
|
inds = new_logits.argsort(0)
|
|
@@ -498,7 +498,7 @@ class IBMBBaseLoader(torch.utils.data.DataLoader):
|
|
|
498
498
|
assert adj is not None
|
|
499
499
|
|
|
500
500
|
for out, aux in pbar:
|
|
501
|
-
mask = torch.from_numpy(np.
|
|
501
|
+
mask = torch.from_numpy(np.isin(aux, out))
|
|
502
502
|
if isinstance(aux, np.ndarray):
|
|
503
503
|
aux = torch.from_numpy(aux)
|
|
504
504
|
subg = get_subgraph(aux, graph, return_edge_index_type, adj,
|
|
@@ -541,7 +541,7 @@ class IBMBBaseLoader(torch.utils.data.DataLoader):
|
|
|
541
541
|
out, aux = zip(*data_list)
|
|
542
542
|
out = np.concatenate(out)
|
|
543
543
|
aux = np.unique(np.concatenate(aux))
|
|
544
|
-
mask = torch.from_numpy(np.
|
|
544
|
+
mask = torch.from_numpy(np.isin(aux, out))
|
|
545
545
|
aux = torch.from_numpy(aux)
|
|
546
546
|
|
|
547
547
|
subg = get_subgraph(aux, self.graph, self.return_edge_index_type,
|
|
@@ -70,7 +70,7 @@ class LinkLoader(
|
|
|
70
70
|
:obj:`edge_label_index`. If set, temporal sampling will be
|
|
71
71
|
used such that neighbors are guaranteed to fulfill temporal
|
|
72
72
|
constraints, *i.e.*, neighbors have an earlier timestamp than
|
|
73
|
-
the
|
|
73
|
+
the output edge. The :obj:`time_attr` needs to be set for this
|
|
74
74
|
to work. (default: :obj:`None`)
|
|
75
75
|
neg_sampling (NegativeSampling, optional): The negative sampling
|
|
76
76
|
configuration.
|
|
@@ -117,7 +117,7 @@ class LinkNeighborLoader(LinkLoader):
|
|
|
117
117
|
:obj:`edge_label_index`. If set, temporal sampling will be
|
|
118
118
|
used such that neighbors are guaranteed to fulfill temporal
|
|
119
119
|
constraints, *i.e.*, neighbors have an earlier timestamp than
|
|
120
|
-
the
|
|
120
|
+
the output edge. The :obj:`time_attr` needs to be set for this
|
|
121
121
|
to work. (default: :obj:`None`)
|
|
122
122
|
replace (bool, optional): If set to :obj:`True`, will sample with
|
|
123
123
|
replacement. (default: :obj:`False`)
|
|
@@ -170,6 +170,7 @@ class LinkNeighborLoader(LinkLoader):
|
|
|
170
170
|
negative sampling mode.
|
|
171
171
|
If set to :obj:`None`, no negative sampling strategy is applied.
|
|
172
172
|
(default: :obj:`None`)
|
|
173
|
+
For example, use :obj:`neg_sampling=dict(mode='binary', amount=0.5)`.
|
|
173
174
|
neg_sampling_ratio (int or float, optional): The ratio of sampled
|
|
174
175
|
negative edges to the number of positive edges.
|
|
175
176
|
Deprecated in favor of the :obj:`neg_sampling` argument.
|
torch_geometric/loader/mixin.py
CHANGED
|
@@ -106,9 +106,9 @@ class MultithreadingMixin:
|
|
|
106
106
|
def _mt_init_fn(self, worker_id: int) -> None:
|
|
107
107
|
try:
|
|
108
108
|
torch.set_num_threads(int(self._worker_threads))
|
|
109
|
-
except IndexError:
|
|
109
|
+
except IndexError as e:
|
|
110
110
|
raise ValueError(f"Cannot set {self.worker_threads} threads "
|
|
111
|
-
f"in worker {worker_id}")
|
|
111
|
+
f"in worker {worker_id}") from e
|
|
112
112
|
|
|
113
113
|
# Chain worker init functions:
|
|
114
114
|
self._old_worker_init_fn(worker_id)
|
|
@@ -213,9 +213,9 @@ class AffinityMixin:
|
|
|
213
213
|
|
|
214
214
|
psutil.Process().cpu_affinity(worker_cores)
|
|
215
215
|
|
|
216
|
-
except IndexError:
|
|
216
|
+
except IndexError as e:
|
|
217
217
|
raise ValueError(f"Cannot use CPU affinity for worker ID "
|
|
218
|
-
f"{worker_id} on CPU {self.loader_cores}")
|
|
218
|
+
f"{worker_id} on CPU {self.loader_cores}") from e
|
|
219
219
|
|
|
220
220
|
# Chain worker init functions:
|
|
221
221
|
self._old_worker_init_fn(worker_id)
|
|
@@ -248,7 +248,8 @@ class AffinityMixin:
|
|
|
248
248
|
warnings.warn(
|
|
249
249
|
"Due to conflicting parallelization methods it is not advised "
|
|
250
250
|
"to use affinitization with 'HeteroData' datasets. "
|
|
251
|
-
"Use `enable_multithreading` for better performance."
|
|
251
|
+
"Use `enable_multithreading` for better performance.",
|
|
252
|
+
stacklevel=2)
|
|
252
253
|
|
|
253
254
|
self.loader_cores = loader_cores[:] if loader_cores else None
|
|
254
255
|
if self.loader_cores is None:
|
|
@@ -14,7 +14,7 @@ class NeighborLoader(NodeLoader):
|
|
|
14
14
|
This loader allows for mini-batch training of GNNs on large-scale graphs
|
|
15
15
|
where full-batch training is not feasible.
|
|
16
16
|
|
|
17
|
-
More specifically, :obj:`num_neighbors` denotes how
|
|
17
|
+
More specifically, :obj:`num_neighbors` denotes how many neighbors are
|
|
18
18
|
sampled for each node in each iteration.
|
|
19
19
|
:class:`~torch_geometric.loader.NeighborLoader` takes in this list of
|
|
20
20
|
:obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for
|
|
@@ -72,9 +72,9 @@ class NeighborSampler(torch.utils.data.DataLoader):
|
|
|
72
72
|
`examples/reddit.py
|
|
73
73
|
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
|
|
74
74
|
reddit.py>`_ or
|
|
75
|
-
`examples/
|
|
75
|
+
`examples/ogbn_train.py
|
|
76
76
|
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
|
|
77
|
-
|
|
77
|
+
ogbn_train.py>`_.
|
|
78
78
|
|
|
79
79
|
Args:
|
|
80
80
|
edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
|
|
@@ -27,8 +27,9 @@ class DeviceHelper:
|
|
|
27
27
|
|
|
28
28
|
if ((self.device.type == 'cuda' and not with_cuda)
|
|
29
29
|
or (self.device.type == 'xpu' and not with_xpu)):
|
|
30
|
-
warnings.warn(
|
|
31
|
-
|
|
30
|
+
warnings.warn(
|
|
31
|
+
f"Requested device '{self.device.type}' is not "
|
|
32
|
+
f"available, falling back to CPU", stacklevel=2)
|
|
32
33
|
self.device = torch.device('cpu')
|
|
33
34
|
|
|
34
35
|
self.stream = None
|
|
@@ -6,7 +6,7 @@ from torch_geometric.data import TemporalData
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class TemporalDataLoader(torch.utils.data.DataLoader):
|
|
9
|
-
r"""A data loader which merges
|
|
9
|
+
r"""A data loader which merges successive events of a
|
|
10
10
|
:class:`torch_geometric.data.TemporalData` to a mini-batch.
|
|
11
11
|
|
|
12
12
|
Args:
|
|
@@ -15,7 +15,7 @@ class TemporalDataLoader(torch.utils.data.DataLoader):
|
|
|
15
15
|
batch_size (int, optional): How many samples per batch to load.
|
|
16
16
|
(default: :obj:`1`)
|
|
17
17
|
neg_sampling_ratio (float, optional): The ratio of sampled negative
|
|
18
|
-
destination nodes to the number of
|
|
18
|
+
destination nodes to the number of positive destination nodes.
|
|
19
19
|
(default: :obj:`0.0`)
|
|
20
20
|
**kwargs (optional): Additional arguments of
|
|
21
21
|
:class:`torch.utils.data.DataLoader`.
|
torch_geometric/loader/utils.py
CHANGED
|
@@ -178,7 +178,7 @@ def filter_hetero_data(
|
|
|
178
178
|
out = copy.copy(data)
|
|
179
179
|
|
|
180
180
|
for node_type in out.node_types:
|
|
181
|
-
# Handle the case of
|
|
181
|
+
# Handle the case of disconnected graph sampling:
|
|
182
182
|
if node_type not in node_dict:
|
|
183
183
|
node_dict[node_type] = torch.empty(0, dtype=torch.long)
|
|
184
184
|
|
|
@@ -186,7 +186,7 @@ def filter_hetero_data(
|
|
|
186
186
|
node_dict[node_type])
|
|
187
187
|
|
|
188
188
|
for edge_type in out.edge_types:
|
|
189
|
-
# Handle the case of
|
|
189
|
+
# Handle the case of disconnected graph sampling:
|
|
190
190
|
if edge_type not in row_dict:
|
|
191
191
|
row_dict[edge_type] = torch.empty(0, dtype=torch.long)
|
|
192
192
|
if edge_type not in col_dict:
|
|
@@ -256,14 +256,6 @@ def filter_custom_hetero_store(
|
|
|
256
256
|
# Construct a new `HeteroData` object:
|
|
257
257
|
data = custom_cls() if custom_cls is not None else HeteroData()
|
|
258
258
|
|
|
259
|
-
# Filter edge storage:
|
|
260
|
-
# TODO support edge attributes
|
|
261
|
-
for attr in graph_store.get_all_edge_attrs():
|
|
262
|
-
key = attr.edge_type
|
|
263
|
-
if key in row_dict and key in col_dict:
|
|
264
|
-
edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
|
|
265
|
-
data[attr.edge_type].edge_index = edge_index
|
|
266
|
-
|
|
267
259
|
# Filter node storage:
|
|
268
260
|
required_attrs = []
|
|
269
261
|
for attr in feature_store.get_all_tensor_attrs():
|
|
@@ -280,6 +272,14 @@ def filter_custom_hetero_store(
|
|
|
280
272
|
for i, attr in enumerate(required_attrs):
|
|
281
273
|
data[attr.group_name][attr.attr_name] = tensors[i]
|
|
282
274
|
|
|
275
|
+
# Filter edge storage:
|
|
276
|
+
# TODO support edge attributes
|
|
277
|
+
for attr in graph_store.get_all_edge_attrs():
|
|
278
|
+
key = attr.edge_type
|
|
279
|
+
if key in row_dict and key in col_dict:
|
|
280
|
+
edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
|
|
281
|
+
data[attr.edge_type].edge_index = edge_index
|
|
282
|
+
|
|
283
283
|
return data
|
|
284
284
|
|
|
285
285
|
|
|
@@ -1,21 +1,35 @@
|
|
|
1
1
|
# flake8: noqa
|
|
2
2
|
|
|
3
3
|
from .link_pred import (
|
|
4
|
+
LinkPredMetric,
|
|
5
|
+
LinkPredMetricCollection,
|
|
4
6
|
LinkPredPrecision,
|
|
5
7
|
LinkPredRecall,
|
|
6
8
|
LinkPredF1,
|
|
7
9
|
LinkPredMAP,
|
|
8
10
|
LinkPredNDCG,
|
|
9
11
|
LinkPredMRR,
|
|
12
|
+
LinkPredHitRatio,
|
|
13
|
+
LinkPredCoverage,
|
|
14
|
+
LinkPredDiversity,
|
|
15
|
+
LinkPredPersonalization,
|
|
16
|
+
LinkPredAveragePopularity,
|
|
10
17
|
)
|
|
11
18
|
|
|
12
19
|
link_pred_metrics = [
|
|
20
|
+
'LinkPredMetric',
|
|
21
|
+
'LinkPredMetricCollection',
|
|
13
22
|
'LinkPredPrecision',
|
|
14
23
|
'LinkPredRecall',
|
|
15
24
|
'LinkPredF1',
|
|
16
25
|
'LinkPredMAP',
|
|
17
26
|
'LinkPredNDCG',
|
|
18
27
|
'LinkPredMRR',
|
|
28
|
+
'LinkPredHitRatio',
|
|
29
|
+
'LinkPredCoverage',
|
|
30
|
+
'LinkPredDiversity',
|
|
31
|
+
'LinkPredPersonalization',
|
|
32
|
+
'LinkPredAveragePopularity',
|
|
19
33
|
]
|
|
20
34
|
|
|
21
35
|
__all__ = link_pred_metrics
|