pyg-nightly 2.7.0.dev20250905__py3-none-any.whl → 2.7.0.dev20250907__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/METADATA +2 -1
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/RECORD +32 -25
- torch_geometric/__init__.py +1 -1
- torch_geometric/data/__init__.py +0 -5
- torch_geometric/data/lightning/datamodule.py +2 -2
- torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
- torch_geometric/datasets/web_qsp_dataset.py +262 -210
- torch_geometric/graphgym/imports.py +2 -2
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
- torch_geometric/{nn → llm}/models/git_mol.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
- torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/backend_utils.py +442 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +124 -0
- torch_geometric/loader/__init__.py +0 -4
- torch_geometric/nn/__init__.py +0 -1
- torch_geometric/nn/models/__init__.py +0 -10
- torch_geometric/nn/models/sgformer.py +2 -0
- torch_geometric/loader/rag_loader.py +0 -107
- torch_geometric/nn/nlp/__init__.py +0 -9
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/licenses/LICENSE +0 -0
- /torch_geometric/{nn → llm}/models/glem.py +0 -0
- /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
- /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
@@ -0,0 +1,199 @@
|
|
1
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from torch import Tensor
|
5
|
+
|
6
|
+
from torch_geometric.data import FeatureStore
|
7
|
+
from torch_geometric.distributed import LocalGraphStore
|
8
|
+
from torch_geometric.sampler import (
|
9
|
+
BidirectionalNeighborSampler,
|
10
|
+
NodeSamplerInput,
|
11
|
+
SamplerOutput,
|
12
|
+
)
|
13
|
+
from torch_geometric.utils import index_sort
|
14
|
+
|
15
|
+
# A representation of an edge index, following the possible formats:
|
16
|
+
# * default: Tensor, size = [2, num_edges]
|
17
|
+
# * Tensor[0, :] == row, Tensor[1, :] == col
|
18
|
+
# * COO: (row, col)
|
19
|
+
# * CSC: (row, colptr)
|
20
|
+
# * CSR: (rowptr, col)
|
21
|
+
_EdgeTensorType = Union[Tensor, Tuple[Tensor, Tensor]]
|
22
|
+
|
23
|
+
|
24
|
+
class NeighborSamplingRAGGraphStore(LocalGraphStore):
    """Neighbor sampling based graph-store to store & retrieve graph data.

    When edges are assigned through the ``edge_index`` setter they are sorted
    by column, and the permutation used for that sort is kept in
    ``self.perm`` so that sampled edge ids can be mapped back to the
    caller's original edge ordering in :meth:`sample_subgraph`.
    """
    def __init__(  # type: ignore[no-untyped-def]
        self,
        feature_store: Optional[FeatureStore] = None,
        **kwargs,
    ):
        """Initializes the graph store.

        Args:
            feature_store (optional): The feature store to use.
                None if not yet registered.
            **kwargs (optional):
                Additional keyword arguments forwarded to
                :class:`BidirectionalNeighborSampler` when it is built.
        """
        self.feature_store = feature_store
        self.sample_kwargs = kwargs
        self._sampler_is_initialized = False
        self._config: Dict[str, Any] = {}

        # to be set by the config
        self.num_neighbors = None
        super().__init__()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the config for the graph store."""
        return self._config

    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:
        """Set an attribute from the config.

        Args:
            config (Dict[str, Any]): Config dictionary
            attr_name (str): Name of attribute to set

        Raises:
            ValueError: If required attribute not found in config
        """
        if attr_name not in config:
            raise ValueError(
                f"Required config parameter '{attr_name}' not found")
        setattr(self, attr_name, config[attr_name])

    @config.setter  # type: ignore
    def config(self, config: Dict[str, Any]) -> None:
        """Set the config for the graph store.

        Reads ``num_neighbors`` from ``config`` and, if a sampler was already
        built, pushes the new value into it.

        Args:
            config (Dict[str, Any]):
                Config dictionary containing required parameters

        Raises:
            ValueError: If required parameters missing from config
        """
        self._set_from_config(config, "num_neighbors")
        if hasattr(self, 'sampler'):
            self.sampler.num_neighbors = (  # type: ignore[has-type]
                self.num_neighbors)

        self._config = config

    def _init_sampler(self) -> None:
        """Initializes neighbor sampler with the registered feature store.

        Raises:
            AttributeError: If no feature store has been registered.
            AssertionError: If ``num_neighbors`` was never set via ``config``.
        """
        if self.feature_store is None:
            raise AttributeError("Feature store not registered yet.")
        assert self.num_neighbors is not None, \
            "Please set num_neighbors through config"
        self.sampler = BidirectionalNeighborSampler(
            data=(self.feature_store, self), num_neighbors=self.num_neighbors,
            **self.sample_kwargs)
        self._sampler_is_initialized = True

    def register_feature_store(self, feature_store: FeatureStore) -> None:
        """Registers a feature store with the graph store.

        The sampler is rebuilt lazily on the next :meth:`sample_subgraph`.

        :param feature_store: The feature store to register.
        """
        self.feature_store = feature_store
        self._sampler_is_initialized = False

    def put_edge_id(  # type: ignore[no-untyped-def]
            self, edge_id: Tensor, *args, **kwargs) -> bool:
        """Stores an edge ID in the graph store.

        :param edge_id: The edge ID to store (made contiguous first).
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)
        self._sampler_is_initialized = False
        return ret

    @property
    def edge_index(self) -> _EdgeTensorType:
        """Gets the edge index of the graph.

        Uses the attribute arguments recorded by the most recent
        :meth:`put_edge_index` call, so that method must have been called
        before this property is read.

        :return: The edge index as a tensor.
        """
        return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)

    def put_edge_index(  # type: ignore[no-untyped-def]
            self, edge_index: _EdgeTensorType, *args, **kwargs) -> bool:
        """Stores an edge index in the graph store.

        :param edge_index: The edge index to store.
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_index(edge_index, *args, **kwargs)
        # HACK: remember the attribute arguments so the ``edge_index``
        # property can fetch the same edges back later.
        self.edge_idx_args = args
        self.edge_idx_kwargs = kwargs
        self._sampler_is_initialized = False
        return ret

    # HACKY
    @edge_index.setter  # type: ignore
    def edge_index(self, edge_index: _EdgeTensorType) -> None:
        """Sets the edge index of the graph.

        The edges are stored sorted by column (stable sort), and the sort
        permutation is saved in ``self.perm`` for later edge-id remapping.

        :param edge_index: The edge index to set, either a ``[2, num_edges]``
            Tensor or a ``(row, col)`` tuple of Tensors.
        """
        if not isinstance(edge_index, Tensor):
            assert isinstance(edge_index, tuple) \
                and isinstance(edge_index[0], Tensor) \
                and isinstance(edge_index[1], Tensor), \
                ("edge_index must be a Tensor of [2, num_edges] "
                 "or a tuple of Tensors, (row, col).")
        row, col = edge_index[0], edge_index[1]

        # Node ids come straight from the triples, so the largest id seen at
        # *either* endpoint fixes the node count. Looking only at the row
        # would undercount whenever a column id is larger, producing an
        # invalid ``size``. An empty edge set yields zero nodes instead of
        # crashing on ``.max()`` of an empty tensor.
        maxima = [int(t.max()) for t in (row, col) if t.numel() > 0]
        num_nodes = max(maxima) + 1 if maxima else 0

        attr = dict(
            edge_type=None,
            layout='coo',
            size=(num_nodes, num_nodes),
            is_sorted=False,
        )
        # edge index needs to be sorted here and the perm saved for later
        col_sorted, self.perm = index_sort(col, num_nodes, stable=True)
        row_sorted = row[self.perm]
        edge_index_sorted = torch.stack([row_sorted, col_sorted], dim=0)
        self.put_edge_index(edge_index_sorted, **attr)

    def sample_subgraph(
        self,
        seed_nodes: Tensor,
    ) -> SamplerOutput:
        """Sample the graph starting from the given nodes using the
        in-built neighbor sampler.

        The sampler is (re)built first if the store changed since the last
        call. Hop/fan-out settings come from ``num_neighbors`` in ``config``.

        Args:
            seed_nodes (Tensor): Seed node ids to start sampling from.
                Duplicates are removed before sampling.

        Returns:
            SamplerOutput: The sampled subgraph, with ``edge`` remapped
            through ``self.perm`` to the original edge ordering.
        """
        # TODO add support for Hetero
        if not self._sampler_is_initialized:
            self._init_sampler()

        seed_nodes = seed_nodes.unique().contiguous()
        node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)
        out = self.sampler.sample_from_nodes(  # type: ignore[has-type]
            node_sample_input)

        # edge ids need to be remapped to the original indices
        # NOTE(review): ``self.perm`` is only set by the ``edge_index``
        # setter; stores filled solely via ``put_edge_index`` lack it.
        out.edge = self.perm[out.edge]

        return out
|
@@ -0,0 +1,124 @@
|
|
1
|
+
# mypy: ignore-errors
|
2
|
+
import os
|
3
|
+
from abc import abstractmethod
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Protocol, Union
|
5
|
+
|
6
|
+
import torch
|
7
|
+
from torch import Tensor
|
8
|
+
|
9
|
+
from torch_geometric.data import Data
|
10
|
+
from torch_geometric.llm.models import SentenceTransformer
|
11
|
+
from torch_geometric.llm.utils.backend_utils import batch_knn
|
12
|
+
|
13
|
+
|
14
|
+
class VectorRetriever(Protocol):
    """Structural (static-typing) interface for VectorRAG retrievers.

    Implementations only need to expose a compatible ``query`` method.
    """
    @abstractmethod
    def query(self, query: Any, **kwargs: Optional[Dict[str, Any]]) -> Data:
        """Look up and return the context associated with ``query``."""
|
20
|
+
|
21
|
+
|
22
|
+
class DocumentRetriever(VectorRetriever):
    """Retrieve documents from a vector database."""
    def __init__(self, raw_docs: List[str],
                 embedded_docs: Optional[Tensor] = None, k_for_docs: int = 2,
                 model: Optional[Union[SentenceTransformer, torch.nn.Module,
                                       Callable]] = None,
                 model_kwargs: Optional[Dict[str, Any]] = None):
        """Retrieve documents from a vector database.

        Args:
            raw_docs: List[str]: List of raw documents.
            embedded_docs: Optional[Tensor]: Embedded documents. If None,
                they are computed by calling ``model`` on ``raw_docs``.
            k_for_docs: int: Number of documents to retrieve.
            model: Optional[Union[SentenceTransformer, torch.nn.Module,
                Callable]]: Model to use for encoding. Required when
                ``embedded_docs`` is None or when string queries are used.
            model_kwargs: Optional[Dict[str, Any]]:
                Keyword arguments to pass to the model. The dict is copied,
                so the caller's object is never modified.
        """
        self.raw_docs = raw_docs
        self.embedded_docs = embedded_docs
        self.k_for_docs = k_for_docs
        self.model = model
        # Always a private dict: ``**self.model_kwargs`` must be valid in
        # ``query`` even when no kwargs were given, and popping keys below
        # must not mutate the caller's dict.
        self.model_kwargs: Dict[str, Any] = dict(model_kwargs or {})

        if self.model is not None:
            self.encoder = self.model

        if self.embedded_docs is None:
            assert self.model is not None, \
                "Model must be provided if embedded_docs is not provided"
            self.embedded_docs = self.encoder(self.raw_docs,
                                              **self.model_kwargs)
            # we don't want to print the verbose output in `query`
            self.model_kwargs.pop("verbose", None)

    def query(self, query: Union[str, Tensor]) -> List[str]:
        """Retrieve documents from the vector database.

        Args:
            query: Union[str, Tensor]: Query to retrieve documents for. A
                string is encoded with the model; a Tensor is assumed to be
                an already-encoded query.

        Returns:
            List[str]: Documents retrieved from the vector database.
        """
        if isinstance(query, str):
            query_enc = self.encoder(query, **self.model_kwargs)
        else:
            query_enc = query

        # NOTE(review): assumes ``batch_knn`` yields (indices, distances)
        # per query batch; only the first batch is consumed here.
        selected_doc_idxs, _ = next(
            batch_knn(query_enc, self.embedded_docs, self.k_for_docs))
        return [self.raw_docs[i] for i in selected_doc_idxs]

    def save(self, path: str) -> None:
        """Save the DocumentRetriever instance to disk.

        The model is not serialized; pass it again to :meth:`load`.

        Args:
            path: str: Path where to save the retriever.
        """
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

        # Prepare data to save
        save_dict = {
            'raw_docs': self.raw_docs,
            'embedded_docs': self.embedded_docs,
            'k_for_docs': self.k_for_docs,
        }

        # We do not serialize the model
        torch.save(save_dict, path)

    @classmethod
    def load(cls, path: str, model: Union[SentenceTransformer, torch.nn.Module,
                                          Callable],
             model_kwargs: Optional[Dict[str, Any]] = None) -> VectorRetriever:
        """Load a DocumentRetriever instance from disk.

        Args:
            path: str: Path to the saved retriever.
            model: Union[SentenceTransformer, torch.nn.Module, Callable]:
                Model to use for encoding queries. The model is not stored
                by :meth:`save`, so it must be supplied here.
            model_kwargs: Optional[Dict[str, Any]]
                Key word args to be passed to model. Not modified in place.

        Returns:
            DocumentRetriever: The loaded retriever.

        Raises:
            FileNotFoundError: If nothing exists at ``path``.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"No saved document retriever found at {path}")

        # NOTE(review): ``weights_only=False`` unpickles arbitrary objects;
        # only load files from trusted sources.
        save_dict = torch.load(path, weights_only=False)
        if isinstance(save_dict['embedded_docs'], Tensor) \
                and model_kwargs is not None:
            # Embeddings already exist, so the model is only used for
            # queries; drop ``verbose`` on a copy, not the caller's dict.
            model_kwargs = {
                key: value
                for key, value in model_kwargs.items() if key != "verbose"
            }
        # Create a new DocumentRetriever with the loaded data
        return cls(raw_docs=save_dict['raw_docs'],
                   embedded_docs=save_dict['embedded_docs'],
                   k_for_docs=save_dict['k_for_docs'], model=model,
                   model_kwargs=model_kwargs)
|
@@ -22,7 +22,6 @@ from .dynamic_batch_sampler import DynamicBatchSampler
|
|
22
22
|
from .prefetch import PrefetchLoader
|
23
23
|
from .cache import CachedLoader
|
24
24
|
from .mixin import AffinityMixin
|
25
|
-
from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore
|
26
25
|
|
27
26
|
__all__ = classes = [
|
28
27
|
'DataLoader',
|
@@ -51,9 +50,6 @@ __all__ = classes = [
|
|
51
50
|
'PrefetchLoader',
|
52
51
|
'CachedLoader',
|
53
52
|
'AffinityMixin',
|
54
|
-
'RAGQueryLoader',
|
55
|
-
'RAGFeatureStore',
|
56
|
-
'RAGGraphStore'
|
57
53
|
]
|
58
54
|
|
59
55
|
RandomNodeSampler = deprecated(
|
torch_geometric/nn/__init__.py
CHANGED
@@ -29,11 +29,6 @@ from .gnnff import GNNFF
|
|
29
29
|
from .pmlp import PMLP
|
30
30
|
from .neural_fingerprint import NeuralFingerprint
|
31
31
|
from .visnet import ViSNet
|
32
|
-
from .g_retriever import GRetriever
|
33
|
-
from .git_mol import GITMol
|
34
|
-
from .molecule_gpt import MoleculeGPT
|
35
|
-
from .protein_mpnn import ProteinMPNN
|
36
|
-
from .glem import GLEM
|
37
32
|
from .lpformer import LPFormer
|
38
33
|
from .sgformer import SGFormer
|
39
34
|
|
@@ -87,11 +82,6 @@ __all__ = classes = [
|
|
87
82
|
'PMLP',
|
88
83
|
'NeuralFingerprint',
|
89
84
|
'ViSNet',
|
90
|
-
'GRetriever',
|
91
|
-
'GITMol',
|
92
|
-
'MoleculeGPT',
|
93
|
-
'ProteinMPNN',
|
94
|
-
'GLEM',
|
95
85
|
'LPFormer',
|
96
86
|
'SGFormer',
|
97
87
|
'Polynormer',
|
@@ -187,6 +187,8 @@ class SGFormer(torch.nn.Module):
|
|
187
187
|
self.params2 = list(self.graph_conv.parameters())
|
188
188
|
self.params2.extend(list(self.fc.parameters()))
|
189
189
|
|
190
|
+
self.out_channels = out_channels
|
191
|
+
|
190
192
|
def reset_parameters(self) -> None:
|
191
193
|
self.trans_conv.reset_parameters()
|
192
194
|
self.graph_conv.reset_parameters()
|
@@ -1,107 +0,0 @@
|
|
1
|
-
from abc import abstractmethod
|
2
|
-
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
|
3
|
-
|
4
|
-
from torch_geometric.data import Data, FeatureStore, HeteroData
|
5
|
-
from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
|
6
|
-
from torch_geometric.typing import InputEdges, InputNodes
|
7
|
-
|
8
|
-
|
9
|
-
class RAGFeatureStore(Protocol):
    """Interface a remote GNN-RAG feature-store backend has to provide."""
    @abstractmethod
    def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
        """Compare ``query`` against every node and return the indices of
        the best-matching nodes, which the RAG sampler uses as seeds.
        """
        ...

    @abstractmethod
    def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
        """Compare ``query`` against every edge and return the indices of
        the best-matching edges, which the RAG sampler uses as seeds.
        """
        ...

    @abstractmethod
    def load_subgraph(
        self, sample: Union[SamplerOutput, HeteroSamplerOutput]
    ) -> Union[Data, HeteroData]:
        """Attach stored features to a sampled subgraph, returning a Data
        object.
        """
        ...
|
33
|
-
|
34
|
-
|
35
|
-
class RAGGraphStore(Protocol):
    """Interface a remote GNN-RAG graph-store backend has to provide."""
    @abstractmethod
    def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
                        **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Draw a subgraph grown from the given seed nodes and seed edges."""
        ...

    @abstractmethod
    def register_feature_store(self, feature_store: FeatureStore):
        """Attach the feature store the sampler should consult; samplers
        rely on it to operate correctly on heterogeneous graphs.
        """
        ...
|
49
|
-
|
50
|
-
|
51
|
-
# TODO: Make compatible with Heterographs
|
52
|
-
|
53
|
-
|
54
|
-
class RAGQueryLoader:
    """Loader meant for making RAG queries from a remote backend."""
    def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
                 local_filter: Optional[Callable[[Data, Any], Data]] = None,
                 seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
                 seed_edges_kwargs: Optional[Dict[str, Any]] = None,
                 sampler_kwargs: Optional[Dict[str, Any]] = None,
                 loader_kwargs: Optional[Dict[str, Any]] = None):
        """Build a loader that answers queries against a remote backend.

        Args:
            data (Tuple[RAGFeatureStore, RAGGraphStore]): The remote feature
                store and graph store, expected to follow the
                ``RAGFeatureStore`` / ``RAGGraphStore`` protocols. The
                feature store is registered with the graph store here.
            local_filter (Optional[Callable[[Data, Any], Data]], optional):
                Transform applied locally to each retrieved subgraph,
                called as ``local_filter(data, query)``. Defaults to None.
            seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Extra
                arguments for seed-node retrieval. Defaults to None.
            seed_edges_kwargs (Optional[Dict[str, Any]], optional): Extra
                arguments for seed-edge retrieval. Defaults to None.
            sampler_kwargs (Optional[Dict[str, Any]], optional): Extra
                arguments for subgraph sampling. Defaults to None.
            loader_kwargs (Optional[Dict[str, Any]], optional): Extra
                arguments for loading subgraph features. Defaults to None.
        """
        self.feature_store, self.graph_store = data
        self.graph_store.register_feature_store(self.feature_store)
        self.local_filter = local_filter
        (self.seed_nodes_kwargs, self.seed_edges_kwargs, self.sampler_kwargs,
         self.loader_kwargs) = (
             options or {} for options in (seed_nodes_kwargs,
                                           seed_edges_kwargs, sampler_kwargs,
                                           loader_kwargs))

    def query(self, query: Any) -> Data:
        """Fetch the subgraph relevant to ``query`` together with all of its
        feature attributes, optionally post-processed by ``local_filter``.
        """
        fstore = self.feature_store
        nodes = fstore.retrieve_seed_nodes(query, **self.seed_nodes_kwargs)
        edges = fstore.retrieve_seed_edges(query, **self.seed_edges_kwargs)

        sample = self.graph_store.sample_subgraph(nodes, edges,
                                                  **self.sampler_kwargs)

        subgraph = fstore.load_subgraph(sample=sample, **self.loader_kwargs)

        if not self.local_filter:
            return subgraph
        return self.local_filter(subgraph, query)
|
File without changes
|
{pyg_nightly-2.7.0.dev20250905.dist-info → pyg_nightly-2.7.0.dev20250907.dist-info}/licenses/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|