pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import gc
|
|
2
|
+
from collections.abc import Iterable, Iterator
|
|
3
|
+
from typing import Any, Dict, List, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
|
|
8
|
+
from torch_geometric.data import Data, HeteroData
|
|
9
|
+
from torch_geometric.distributed import LocalFeatureStore
|
|
10
|
+
from torch_geometric.llm.utils.backend_utils import batch_knn
|
|
11
|
+
from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
|
|
12
|
+
from torch_geometric.typing import InputNodes
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# NOTE: Only compatible with Homogeneous graphs for now
|
|
16
|
+
class KNNRAGFeatureStore(LocalFeatureStore):
    """A feature store that uses a KNN-based retrieval.

    Seed nodes are retrieved by encoding a query with :obj:`encoder_model`
    and running :func:`batch_knn` against the stored node features ``x``.
    Only homogeneous graphs are supported for now.

    Both :obj:`k_nodes` and :obj:`encoder_model` must be supplied through
    the :attr:`config` setter before any retrieval is performed.
    """
    def __init__(self) -> None:
        """Initializes the feature store."""
        # Populated later by the ``config`` setter.
        self.encoder_model = None
        self.k_nodes = None
        self._config: Dict[str, Any] = {}
        super().__init__()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the config for the feature store."""
        return self._config

    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:
        """Set an attribute from the config.

        Args:
            config (Dict[str, Any]): Config dictionary
            attr_name (str): Name of attribute to set

        Raises:
            ValueError: If required attribute not found in config
        """
        if attr_name not in config:
            raise ValueError(
                f"Required config parameter '{attr_name}' not found")
        setattr(self, attr_name, config[attr_name])

    @config.setter  # type: ignore
    def config(self, config: Dict[str, Any]) -> None:
        """Set the config for the feature store.

        Args:
            config (Dict[str, Any]):
                Config dictionary containing the required ``k_nodes`` and
                ``encoder_model`` entries.

        Raises:
            ValueError: If required parameters missing from config
        """
        self._set_from_config(config, "k_nodes")
        self._set_from_config(config, "encoder_model")
        assert self.encoder_model is not None, \
            "Need to define encoder model from config"
        # The encoder is only used for inference here; presumably a
        # torch.nn.Module (it must expose ``eval`` and ``encode``).
        self.encoder_model.eval()

        self._config = config

    @staticmethod
    def _as_feature_tensor(value: Any) -> Tensor:
        """Return a feature fetched from the store as a :class:`Tensor`.

        Tensors are returned unchanged. The legacy ``torch.Tensor(t)``
        constructor only accepts float CPU tensors, so it would raise for
        CUDA-resident or integer-typed features. Non-tensor values
        (e.g. NumPy arrays) are still converted with ``Tensor(...)`` as
        before.
        """
        if isinstance(value, Tensor):
            return value
        return Tensor(value)

    @property
    def x(self) -> Tensor:
        """Returns the node features."""
        return self._as_feature_tensor(
            self.get_tensor(group_name=None, attr_name='x'))

    @property
    def edge_attr(self) -> Tensor:
        """Returns the edge attributes."""
        return self._as_feature_tensor(
            self.get_tensor(group_name=(None, None), attr_name='edge_attr'))

    def retrieve_seed_nodes(  # noqa: D417
        self, query: Union[str, List[str], Tuple[str]]
    ) -> Union[Tuple[InputNodes, Tensor], Dict[Any, Tuple[InputNodes,
                                                          Tensor]]]:
        """Retrieves the k_nodes most similar nodes to the given query.

        Args:
            query (Union[str, List[str], Tuple[str]]): The query
                or list of queries to search for.

        Returns:
            For a single query, the tuple ``(node_indices, query_encoding)``.
            For several queries, a dict mapping each query to such a tuple
            (duplicate queries collapse into a single entry).
        """
        if not isinstance(query, (list, tuple)):
            query = [query]
        assert self.k_nodes is not None, "please set k_nodes via config"
        if len(query) == 1:
            result, query_enc = next(
                self._retrieve_seed_nodes_batch(query, self.k_nodes))
            # Release intermediate buffers created during the KNN search.
            gc.collect()
            torch.cuda.empty_cache()
            return result, query_enc

        out_dict = {}
        batches = self._retrieve_seed_nodes_batch(query, self.k_nodes)
        for q, out in zip(query, batches):
            out_dict[q] = out
        gc.collect()
        torch.cuda.empty_cache()
        return out_dict

    def _retrieve_seed_nodes_batch(  # noqa: D417
            self, query: Iterable[Any],
            k_nodes: int) -> Iterator[Tuple[InputNodes, Tensor]]:
        """Retrieves the k_nodes most similar nodes to each query in the batch.

        Args:
            - query (Iterable[Any]): The batch of queries to search for.
            - k_nodes (int): The number of nodes to retrieve.

        Yields:
            - The indices of the most similar nodes for each query, together
              with the query encoding (as produced by :func:`batch_knn`).

        Raises:
            NotImplementedError: If the store is marked as heterogeneous.
        """
        if isinstance(self.meta, dict) and self.meta.get("is_hetero", False):
            raise NotImplementedError
        assert self.encoder_model is not None, \
            "Need to define encoder model from config"
        query_enc = self.encoder_model.encode(query)
        return batch_knn(query_enc, self.x, k_nodes)

    def load_subgraph(  # noqa
        self,
        sample: Union[SamplerOutput, HeteroSamplerOutput],
        induced: bool = True,
    ) -> Union[Data, HeteroData]:
        """Loads a subgraph from the given sample.

        Args:
            sample: The sample to load the subgraph from.
            induced: Whether to return the induced subgraph.
                Resets node and edge ids.

        Returns:
            The loaded subgraph, with ``node_idx``/``edge_idx`` recording
            which nodes and edges of the full graph were sampled.

        Raises:
            NotImplementedError: For heterogeneous samples.
        """
        if isinstance(sample, HeteroSamplerOutput):
            raise NotImplementedError

        # NOTE: torch_geometric.loader.utils.filter_custom_store
        # can be used here if it supported edge features.
        edge_id = sample.edge
        x = self.x[sample.node]
        edge_attr = self.edge_attr[edge_id]

        # Local (relabelled) indices for the induced subgraph, otherwise
        # the original global node ids.
        if induced:
            edge_idx = torch.stack([sample.row, sample.col], dim=0)
        else:
            edge_idx = torch.stack([sample.global_row, sample.global_col],
                                   dim=0)
        result = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)

        # useful for tracking what subset of the graph was sampled
        result.node_idx = sample.node
        result.edge_idx = edge_id

        return result
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
"""
|
|
163
|
+
TODO: make class CuVSKNNRAGFeatureStore(KNNRAGFeatureStore)
|
|
164
|
+
include an approximate KNN flag for CuVS.
|
|
165
|
+
Connect this with a CuGraphGraphStore
|
|
166
|
+
for enabling an accelerated boolean flag for RAGQueryLoader.
|
|
167
|
+
On by default if CuGraph+CuVS avail.
|
|
168
|
+
If not available, raise a note mentioning its speedup.
|
|
169
|
+
"""
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torch_geometric.data import FeatureStore
|
|
7
|
+
from torch_geometric.distributed import LocalGraphStore
|
|
8
|
+
from torch_geometric.sampler import (
|
|
9
|
+
BidirectionalNeighborSampler,
|
|
10
|
+
NodeSamplerInput,
|
|
11
|
+
SamplerOutput,
|
|
12
|
+
)
|
|
13
|
+
from torch_geometric.utils import index_sort
|
|
14
|
+
|
|
15
|
+
# A representation of an edge index, following the possible formats:
#   * default: Tensor, size = [2, num_edges]
#       Tensor[0, :] == row, Tensor[1, :] == col
#   * COO: (row, col)
#   * CSC: (row, colptr)
#   * CSR: (rowptr, col)
# The alias therefore accepts either a single stacked tensor or a 2-tuple of
# tensors; which tuple layout applies is not encoded in the type itself.
_EdgeTensorType = Union[Tensor, Tuple[Tensor, Tensor]]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class NeighborSamplingRAGGraphStore(LocalGraphStore):
    """Neighbor sampling based graph-store to store & retrieve graph data."""
    def __init__(  # type: ignore[no-untyped-def]
        self,
        feature_store: Optional[FeatureStore] = None,
        **kwargs,
    ):
        """Initializes the graph store.
        Optional feature store and neighbor sampling settings.

        Args:
            feature_store (optional): The feature store to use.
                None if not yet registered.
            **kwargs (optional):
                Additional keyword arguments for neighbor sampling.
        """
        self.feature_store = feature_store
        self.sample_kwargs = kwargs
        self._sampler_is_initialized = False
        self._config: Dict[str, Any] = {}

        # to be set by the config
        self.num_neighbors = None
        super().__init__()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the config for the graph store."""
        return self._config

    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:
        """Set an attribute from the config.

        Args:
            config (Dict[str, Any]): Config dictionary
            attr_name (str): Name of attribute to set

        Raises:
            ValueError: If required attribute not found in config
        """
        if attr_name not in config:
            raise ValueError(
                f"Required config parameter '{attr_name}' not found")
        setattr(self, attr_name, config[attr_name])

    @config.setter  # type: ignore
    def config(self, config: Dict[str, Any]) -> None:
        """Set the config for the graph store.

        Args:
            config (Dict[str, Any]):
                Config dictionary containing the required ``num_neighbors``
                entry.

        Raises:
            ValueError: If required parameters missing from config
        """
        self._set_from_config(config, "num_neighbors")
        # Keep an already-built sampler in sync with the new setting.
        if hasattr(self, 'sampler'):
            self.sampler.num_neighbors = (  # type: ignore[has-type]
                self.num_neighbors)

        self._config = config

    def _init_sampler(self) -> None:
        """Initializes neighbor sampler with the registered feature store.

        Raises:
            AttributeError: If no feature store has been registered.
        """
        if self.feature_store is None:
            raise AttributeError("Feature store not registered yet.")
        assert self.num_neighbors is not None, \
            "Please set num_neighbors through config"
        self.sampler = BidirectionalNeighborSampler(
            data=(self.feature_store, self), num_neighbors=self.num_neighbors,
            **self.sample_kwargs)
        self._sampler_is_initialized = True

    def register_feature_store(self, feature_store: FeatureStore) -> None:
        """Registers a feature store with the graph store.

        The sampler is rebuilt lazily on the next :meth:`sample_subgraph`.

        :param feature_store: The feature store to register.
        """
        self.feature_store = feature_store
        self._sampler_is_initialized = False

    def put_edge_id(  # type: ignore[no-untyped-def]
            self, edge_id: Tensor, *args, **kwargs) -> bool:
        """Stores an edge ID in the graph store.

        :param edge_id: The edge ID to store.
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)
        self._sampler_is_initialized = False
        return ret

    @property
    def edge_index(self) -> _EdgeTensorType:
        """Gets the edge index of the graph.

        Uses the arguments recorded by the last :meth:`put_edge_index` call,
        so an edge index must have been stored first.

        :return: The edge index as a tensor.
        """
        return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)

    def put_edge_index(  # type: ignore[no-untyped-def]
            self, edge_index: _EdgeTensorType, *args, **kwargs) -> bool:
        """Stores an edge index in the graph store.

        :param edge_index: The edge index to store.
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_index(edge_index, *args, **kwargs)
        # HACK: remember the lookup arguments so that the ``edge_index``
        # property can fetch the same edge index back.
        self.edge_idx_args = args
        self.edge_idx_kwargs = kwargs
        self._sampler_is_initialized = False
        return ret

    # HACKY
    @edge_index.setter  # type: ignore
    def edge_index(self, edge_index: _EdgeTensorType) -> None:
        """Sets the edge index of the graph.

        The edges are stored sorted by destination (column) node, and the
        sorting permutation is kept in ``self.perm`` so sampled edge ids can
        be mapped back to the caller's original edge order.

        :param edge_index: The edge index to set, either a ``[2, num_edges]``
            tensor or a ``(row, col)`` tuple of tensors.
        """
        if not isinstance(edge_index, Tensor):
            assert isinstance(edge_index, tuple) \
                and isinstance(edge_index[0], Tensor) \
                and isinstance(edge_index[1], Tensor), \
                ("edge_index must be a Tensor of [2, num_edges] "
                 "or a tuple of Tensors, (row, col).")
        row, col = edge_index[0], edge_index[1]

        # The size must cover both endpoints: a destination index larger
        # than every source index would otherwise yield a too-small size and
        # an invalid ``max_value`` for ``index_sort`` below.
        num_nodes = max(int(row.max()), int(col.max())) + 1
        attr = dict(
            edge_type=None,
            layout='coo',
            size=(num_nodes, num_nodes),
            is_sorted=False,
        )
        # edge index needs to be sorted here and the perm saved for later
        col_sorted, self.perm = index_sort(col, num_nodes, stable=True)
        row_sorted = row[self.perm]
        edge_index_sorted = torch.stack([row_sorted, col_sorted], dim=0)
        self.put_edge_index(edge_index_sorted, **attr)

    def sample_subgraph(
        self,
        seed_nodes: Tensor,
    ) -> SamplerOutput:
        """Sample the graph starting from the given nodes using the
        in-built neighbor sampler.

        The number of hops and neighbors per hop come from the
        ``num_neighbors`` config entry.

        Args:
            seed_nodes (Tensor): Seed nodes to start sampling from.
                Duplicates are removed before sampling.

        Returns:
            SamplerOutput: The sampler output for the input, with ``edge``
            ids remapped to the original edge order.
        """
        # TODO add support for Hetero
        if not self._sampler_is_initialized:
            self._init_sampler()

        seed_nodes = seed_nodes.unique().contiguous()
        node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)
        out = self.sampler.sample_from_nodes(  # type: ignore[has-type]
            node_sample_input)

        # edge ids need to be remapped to the original indices
        out.edge = self.perm[out.edge]

        return out
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# mypy: ignore-errors
|
|
2
|
+
import os
|
|
3
|
+
from abc import abstractmethod
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Protocol, Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from torch_geometric.data import Data
|
|
10
|
+
from torch_geometric.llm.models import SentenceTransformer
|
|
11
|
+
from torch_geometric.llm.utils.backend_utils import batch_knn
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class VectorRetriever(Protocol):
    """Protocol for VectorRAG retrievers.

    Implementations accept a query and return the context retrieved for it.
    """
    @abstractmethod
    def query(self, query: Any, **kwargs: Any) -> Data:
        """Retrieve a context for a given query.

        Args:
            query: The query to retrieve context for.
            **kwargs: Implementation-specific retrieval options.
        """
        ...
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DocumentRetriever(VectorRetriever):
    """Retrieve documents from a vector database."""
    def __init__(self, raw_docs: List[str],
                 embedded_docs: Optional[Tensor] = None, k_for_docs: int = 2,
                 model: Optional[Union[SentenceTransformer, torch.nn.Module,
                                       Callable]] = None,
                 model_kwargs: Optional[Dict[str, Any]] = None):
        """Retrieve documents from a vector database.

        Args:
            raw_docs: List[str]: List of raw documents.
            embedded_docs: Optional[Tensor]: Embedded documents. If omitted,
                they are computed by encoding ``raw_docs`` with ``model``.
            k_for_docs: int: Number of documents to retrieve.
            model: Optional[Union[SentenceTransformer, torch.nn.Module]]:
                Model to use for encoding. Required when ``embedded_docs``
                is not given, or when querying with raw strings.
            model_kwargs: Optional[Dict[str, Any]]:
                Keyword arguments to pass to the model. The dict is copied;
                the caller's dict is never modified.
        """
        self.raw_docs = raw_docs
        self.embedded_docs = embedded_docs
        self.k_for_docs = k_for_docs
        self.model = model
        # Always keep a private dict: ``query`` unpacks it with ``**`` (so
        # ``None`` would crash), and popping keys must not leak back into
        # the caller's dict.
        self.model_kwargs: Dict[str, Any] = dict(model_kwargs or {})

        if self.model is not None:
            self.encoder = self.model

        if self.embedded_docs is None:
            assert self.model is not None, \
                "Model must be provided if embedded_docs is not provided"
            self.embedded_docs = self.encoder(self.raw_docs,
                                              **self.model_kwargs)
        # we don't want to print the verbose output in `query`
        self.model_kwargs.pop("verbose", None)

    def query(self, query: Union[str, Tensor]) -> List[str]:
        """Retrieve documents from the vector database.

        Args:
            query: Union[str, Tensor]: Query to retrieve documents for.
                Strings are encoded with the model; tensors are treated as
                an already-encoded query.

        Returns:
            List[str]: Documents retrieved from the vector database.
        """
        if isinstance(query, str):
            with torch.no_grad():
                query_enc = self.encoder(query, **self.model_kwargs)
        else:
            query_enc = query

        selected_doc_idxs, _ = next(
            batch_knn(query_enc, self.embedded_docs, self.k_for_docs))
        return [self.raw_docs[i] for i in selected_doc_idxs]

    def save(self, path: str) -> None:
        """Save the DocumentRetriever instance to disk.

        Only the documents, their embeddings and ``k_for_docs`` are saved;
        the encoder model is not serialized.

        Args:
            path: str: Path where to save the retriever.
        """
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

        # Prepare data to save
        save_dict = {
            'raw_docs': self.raw_docs,
            'embedded_docs': self.embedded_docs,
            'k_for_docs': self.k_for_docs,
        }

        # We do not serialize the model
        torch.save(save_dict, path)

    @classmethod
    def load(cls, path: str, model: Union[SentenceTransformer, torch.nn.Module,
                                          Callable],
             model_kwargs: Optional[Dict[str, Any]] = None) -> VectorRetriever:
        """Load a DocumentRetriever instance from disk.

        Args:
            path: str: Path to the saved retriever.
            model: Union[SentenceTransformer, torch.nn.Module, Callable]:
                Model to use for encoding queries. The model is not stored
                by :meth:`save`, so it must be supplied here.
            model_kwargs: Optional[Dict[str, Any]]
                Key word args to be passed to model. Not modified.

        Returns:
            DocumentRetriever: The loaded retriever.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"No saved document retriever found at {path}")

        # SECURITY NOTE: ``weights_only=False`` unpickles arbitrary objects;
        # only load retriever files from trusted sources.
        save_dict = torch.load(path, weights_only=False)
        if isinstance(save_dict['embedded_docs'], Tensor) \
                and model_kwargs is not None:
            # Embeddings are already computed, so drop ``verbose`` for
            # querying -- on a copy, leaving the caller's dict untouched.
            model_kwargs = {
                key: value
                for key, value in model_kwargs.items() if key != "verbose"
            }
        # Create a new DocumentRetriever with the loaded data
        return cls(raw_docs=save_dict['raw_docs'],
                   embedded_docs=save_dict['embedded_docs'],
                   k_for_docs=save_dict['k_for_docs'], model=model,
                   model_kwargs=model_kwargs)
|
|
@@ -12,6 +12,7 @@ from torch import Tensor
|
|
|
12
12
|
import torch_geometric.typing
|
|
13
13
|
from torch_geometric.data import Data
|
|
14
14
|
from torch_geometric.index import index2ptr, ptr2index
|
|
15
|
+
from torch_geometric.io import fs
|
|
15
16
|
from torch_geometric.typing import pyg_lib
|
|
16
17
|
from torch_geometric.utils import index_sort, narrow, select, sort_edge_index
|
|
17
18
|
from torch_geometric.utils.map import map_index
|
|
@@ -77,7 +78,7 @@ class ClusterData(torch.utils.data.Dataset):
|
|
|
77
78
|
path = osp.join(root_dir, filename or 'metis.pt')
|
|
78
79
|
|
|
79
80
|
if save_dir is not None and osp.exists(path):
|
|
80
|
-
self.partition =
|
|
81
|
+
self.partition = fs.torch_load(path)
|
|
81
82
|
else:
|
|
82
83
|
if log: # pragma: no cover
|
|
83
84
|
print('Computing METIS partitioning...', file=sys.stderr)
|
|
@@ -234,9 +235,9 @@ class ClusterData(torch.utils.data.Dataset):
|
|
|
234
235
|
class ClusterLoader(torch.utils.data.DataLoader):
|
|
235
236
|
r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm
|
|
236
237
|
for Training Deep and Large Graph Convolutional Networks"
|
|
237
|
-
<https://arxiv.org/abs/1905.07953>`_ paper which merges
|
|
238
|
-
and their between-cluster links from a large-scale graph data
|
|
239
|
-
form a mini-batch.
|
|
238
|
+
<https://arxiv.org/abs/1905.07953>`_ paper which merges partitioned
|
|
239
|
+
subgraphs and their between-cluster links from a large-scale graph data
|
|
240
|
+
object to form a mini-batch.
|
|
240
241
|
|
|
241
242
|
.. note::
|
|
242
243
|
|
|
@@ -251,7 +252,7 @@ class ClusterLoader(torch.utils.data.DataLoader):
|
|
|
251
252
|
|
|
252
253
|
Args:
|
|
253
254
|
cluster_data (torch_geometric.loader.ClusterData): The already
|
|
254
|
-
|
|
255
|
+
partitioned data object.
|
|
255
256
|
**kwargs (optional): Additional arguments of
|
|
256
257
|
:class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
|
|
257
258
|
:obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
|
|
@@ -4,6 +4,7 @@ from typing import Optional
|
|
|
4
4
|
import torch
|
|
5
5
|
from tqdm import tqdm
|
|
6
6
|
|
|
7
|
+
from torch_geometric.io import fs
|
|
7
8
|
from torch_geometric.typing import SparseTensor
|
|
8
9
|
|
|
9
10
|
|
|
@@ -77,7 +78,7 @@ class GraphSAINTSampler(torch.utils.data.DataLoader):
|
|
|
77
78
|
if self.sample_coverage > 0:
|
|
78
79
|
path = osp.join(save_dir or '', self._filename)
|
|
79
80
|
if save_dir is not None and osp.exists(path): # pragma: no cover
|
|
80
|
-
self.node_norm, self.edge_norm =
|
|
81
|
+
self.node_norm, self.edge_norm = fs.torch_load(path)
|
|
81
82
|
else:
|
|
82
83
|
self.node_norm, self.edge_norm = self._compute_norm()
|
|
83
84
|
if save_dir is not None: # pragma: no cover
|
|
@@ -148,7 +148,7 @@ def indices_complete_check(
|
|
|
148
148
|
if isinstance(aux, Tensor):
|
|
149
149
|
aux = aux.cpu().numpy()
|
|
150
150
|
|
|
151
|
-
assert np.all(np.
|
|
151
|
+
assert np.all(np.isin(out,
|
|
152
152
|
aux)), "Not all output nodes are in aux nodes!"
|
|
153
153
|
outs.append(out)
|
|
154
154
|
|
|
@@ -236,7 +236,7 @@ def create_batchwise_out_aux_pairs(
|
|
|
236
236
|
logits[tele_set, i] = 1. / len(tele_set)
|
|
237
237
|
|
|
238
238
|
new_logits = logits.clone()
|
|
239
|
-
for
|
|
239
|
+
for _ in range(num_iter):
|
|
240
240
|
new_logits = adj @ new_logits * (1 - alpha) + alpha * logits
|
|
241
241
|
|
|
242
242
|
inds = new_logits.argsort(0)
|
|
@@ -498,7 +498,7 @@ class IBMBBaseLoader(torch.utils.data.DataLoader):
|
|
|
498
498
|
assert adj is not None
|
|
499
499
|
|
|
500
500
|
for out, aux in pbar:
|
|
501
|
-
mask = torch.from_numpy(np.
|
|
501
|
+
mask = torch.from_numpy(np.isin(aux, out))
|
|
502
502
|
if isinstance(aux, np.ndarray):
|
|
503
503
|
aux = torch.from_numpy(aux)
|
|
504
504
|
subg = get_subgraph(aux, graph, return_edge_index_type, adj,
|
|
@@ -541,7 +541,7 @@ class IBMBBaseLoader(torch.utils.data.DataLoader):
|
|
|
541
541
|
out, aux = zip(*data_list)
|
|
542
542
|
out = np.concatenate(out)
|
|
543
543
|
aux = np.unique(np.concatenate(aux))
|
|
544
|
-
mask = torch.from_numpy(np.
|
|
544
|
+
mask = torch.from_numpy(np.isin(aux, out))
|
|
545
545
|
aux = torch.from_numpy(aux)
|
|
546
546
|
|
|
547
547
|
subg = get_subgraph(aux, self.graph, self.return_edge_index_type,
|
|
@@ -70,7 +70,7 @@ class LinkLoader(
|
|
|
70
70
|
:obj:`edge_label_index`. If set, temporal sampling will be
|
|
71
71
|
used such that neighbors are guaranteed to fulfill temporal
|
|
72
72
|
constraints, *i.e.*, neighbors have an earlier timestamp than
|
|
73
|
-
the
|
|
73
|
+
the output edge. The :obj:`time_attr` needs to be set for this
|
|
74
74
|
to work. (default: :obj:`None`)
|
|
75
75
|
neg_sampling (NegativeSampling, optional): The negative sampling
|
|
76
76
|
configuration.
|
|
@@ -117,7 +117,7 @@ class LinkNeighborLoader(LinkLoader):
|
|
|
117
117
|
:obj:`edge_label_index`. If set, temporal sampling will be
|
|
118
118
|
used such that neighbors are guaranteed to fulfill temporal
|
|
119
119
|
constraints, *i.e.*, neighbors have an earlier timestamp than
|
|
120
|
-
the
|
|
120
|
+
the output edge. The :obj:`time_attr` needs to be set for this
|
|
121
121
|
to work. (default: :obj:`None`)
|
|
122
122
|
replace (bool, optional): If set to :obj:`True`, will sample with
|
|
123
123
|
replacement. (default: :obj:`False`)
|
|
@@ -170,6 +170,7 @@ class LinkNeighborLoader(LinkLoader):
|
|
|
170
170
|
negative sampling mode.
|
|
171
171
|
If set to :obj:`None`, no negative sampling strategy is applied.
|
|
172
172
|
(default: :obj:`None`)
|
|
173
|
+
For example use obj:`neg_sampling=dict(mode= 'binary', amount=0.5)`
|
|
173
174
|
neg_sampling_ratio (int or float, optional): The ratio of sampled
|
|
174
175
|
negative edges to the number of positive edges.
|
|
175
176
|
Deprecated in favor of the :obj:`neg_sampling` argument.
|
torch_geometric/loader/mixin.py
CHANGED
|
@@ -106,9 +106,9 @@ class MultithreadingMixin:
|
|
|
106
106
|
def _mt_init_fn(self, worker_id: int) -> None:
|
|
107
107
|
try:
|
|
108
108
|
torch.set_num_threads(int(self._worker_threads))
|
|
109
|
-
except IndexError:
|
|
109
|
+
except IndexError as e:
|
|
110
110
|
raise ValueError(f"Cannot set {self.worker_threads} threads "
|
|
111
|
-
f"in worker {worker_id}")
|
|
111
|
+
f"in worker {worker_id}") from e
|
|
112
112
|
|
|
113
113
|
# Chain worker init functions:
|
|
114
114
|
self._old_worker_init_fn(worker_id)
|
|
@@ -213,9 +213,9 @@ class AffinityMixin:
|
|
|
213
213
|
|
|
214
214
|
psutil.Process().cpu_affinity(worker_cores)
|
|
215
215
|
|
|
216
|
-
except IndexError:
|
|
216
|
+
except IndexError as e:
|
|
217
217
|
raise ValueError(f"Cannot use CPU affinity for worker ID "
|
|
218
|
-
f"{worker_id} on CPU {self.loader_cores}")
|
|
218
|
+
f"{worker_id} on CPU {self.loader_cores}") from e
|
|
219
219
|
|
|
220
220
|
# Chain worker init functions:
|
|
221
221
|
self._old_worker_init_fn(worker_id)
|
|
@@ -248,7 +248,8 @@ class AffinityMixin:
|
|
|
248
248
|
warnings.warn(
|
|
249
249
|
"Due to conflicting parallelization methods it is not advised "
|
|
250
250
|
"to use affinitization with 'HeteroData' datasets. "
|
|
251
|
-
"Use `enable_multithreading` for better performance."
|
|
251
|
+
"Use `enable_multithreading` for better performance.",
|
|
252
|
+
stacklevel=2)
|
|
252
253
|
|
|
253
254
|
self.loader_cores = loader_cores[:] if loader_cores else None
|
|
254
255
|
if self.loader_cores is None:
|
|
@@ -14,7 +14,7 @@ class NeighborLoader(NodeLoader):
|
|
|
14
14
|
This loader allows for mini-batch training of GNNs on large-scale graphs
|
|
15
15
|
where full-batch training is not feasible.
|
|
16
16
|
|
|
17
|
-
More specifically, :obj:`num_neighbors` denotes how
|
|
17
|
+
More specifically, :obj:`num_neighbors` denotes how many neighbors are
|
|
18
18
|
sampled for each node in each iteration.
|
|
19
19
|
:class:`~torch_geometric.loader.NeighborLoader` takes in this list of
|
|
20
20
|
:obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for
|
|
@@ -72,9 +72,9 @@ class NeighborSampler(torch.utils.data.DataLoader):
|
|
|
72
72
|
`examples/reddit.py
|
|
73
73
|
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
|
|
74
74
|
reddit.py>`_ or
|
|
75
|
-
`examples/
|
|
75
|
+
`examples/ogbn_train.py
|
|
76
76
|
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
|
|
77
|
-
|
|
77
|
+
ogbn_train.py>`_.
|
|
78
78
|
|
|
79
79
|
Args:
|
|
80
80
|
edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
|