pyg-nightly 2.7.0.dev20250904__py3-none-any.whl → 2.7.0.dev20250906__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/METADATA +2 -1
  2. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/RECORD +34 -27
  3. torch_geometric/__init__.py +1 -1
  4. torch_geometric/data/__init__.py +0 -5
  5. torch_geometric/data/lightning/datamodule.py +2 -2
  6. torch_geometric/datasets/molecule_gpt_dataset.py +1 -1
  7. torch_geometric/datasets/web_qsp_dataset.py +262 -210
  8. torch_geometric/graphgym/imports.py +2 -2
  9. torch_geometric/llm/__init__.py +9 -0
  10. torch_geometric/{data → llm}/large_graph_indexer.py +124 -61
  11. torch_geometric/llm/models/__init__.py +23 -0
  12. torch_geometric/{nn → llm}/models/g_retriever.py +68 -49
  13. torch_geometric/{nn → llm}/models/git_mol.py +1 -1
  14. torch_geometric/{nn/nlp → llm/models}/llm.py +167 -33
  15. torch_geometric/llm/models/llm_judge.py +158 -0
  16. torch_geometric/{nn → llm}/models/molecule_gpt.py +1 -1
  17. torch_geometric/{nn/nlp → llm/models}/sentence_transformer.py +42 -8
  18. torch_geometric/llm/models/txt2kg.py +353 -0
  19. torch_geometric/llm/rag_loader.py +154 -0
  20. torch_geometric/llm/utils/backend_utils.py +442 -0
  21. torch_geometric/llm/utils/feature_store.py +169 -0
  22. torch_geometric/llm/utils/graph_store.py +199 -0
  23. torch_geometric/llm/utils/vectorrag.py +124 -0
  24. torch_geometric/loader/__init__.py +0 -4
  25. torch_geometric/metrics/link_pred.py +13 -2
  26. torch_geometric/nn/__init__.py +0 -1
  27. torch_geometric/nn/models/__init__.py +0 -10
  28. torch_geometric/nn/models/sgformer.py +2 -0
  29. torch_geometric/utils/cross_entropy.py +34 -13
  30. torch_geometric/loader/rag_loader.py +0 -107
  31. torch_geometric/nn/nlp/__init__.py +0 -9
  32. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/WHEEL +0 -0
  33. {pyg_nightly-2.7.0.dev20250904.dist-info → pyg_nightly-2.7.0.dev20250906.dist-info}/licenses/LICENSE +0 -0
  34. /torch_geometric/{nn → llm}/models/glem.py +0 -0
  35. /torch_geometric/{nn → llm}/models/protein_mpnn.py +0 -0
  36. /torch_geometric/{nn/nlp → llm/models}/vision_transformer.py +0 -0
@@ -0,0 +1,199 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from torch_geometric.data import FeatureStore
7
+ from torch_geometric.distributed import LocalGraphStore
8
+ from torch_geometric.sampler import (
9
+ BidirectionalNeighborSampler,
10
+ NodeSamplerInput,
11
+ SamplerOutput,
12
+ )
13
+ from torch_geometric.utils import index_sort
14
+
15
+ # A representation of an edge index, following the possible formats:
16
+ # * default: Tensor, size = [2, num_edges]
17
+ # * Tensor[0, :] == row, Tensor[1, :] == col
18
+ # * COO: (row, col)
19
+ # * CSC: (row, colptr)
20
+ # * CSR: (rowptr, col)
21
+ _EdgeTensorType = Union[Tensor, Tuple[Tensor, Tensor]]
22
+
23
+
24
class NeighborSamplingRAGGraphStore(LocalGraphStore):
    """Neighbor sampling based graph-store to store & retrieve graph data.

    Edges assigned through the :obj:`edge_index` setter are stored sorted by
    their column (destination) index. The permutation of that sort is kept in
    :obj:`self.perm` and is used by :meth:`sample_subgraph` to map sampled
    edge IDs back to the original edge order.
    """
    def __init__(  # type: ignore[no-untyped-def]
        self,
        feature_store: Optional[FeatureStore] = None,
        **kwargs,
    ):
        """Initializes the graph store.

        Args:
            feature_store (FeatureStore, optional): The feature store to use.
                :obj:`None` if not yet registered (see
                :meth:`register_feature_store`).
            **kwargs (optional): Additional keyword arguments forwarded to
                :class:`BidirectionalNeighborSampler` when the sampler is
                created.
        """
        self.feature_store = feature_store
        self.sample_kwargs = kwargs
        # The sampler is created lazily in `sample_subgraph` and invalidated
        # whenever the feature store or the stored edges change.
        self._sampler_is_initialized = False
        self._config: Dict[str, Any] = {}

        # To be set through `config`:
        self.num_neighbors = None
        super().__init__()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the config for the graph store."""
        return self._config

    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:
        """Set an attribute from the config.

        Args:
            config (Dict[str, Any]): Config dictionary
            attr_name (str): Name of attribute to set

        Raises:
            ValueError: If required attribute not found in config
        """
        if attr_name not in config:
            raise ValueError(
                f"Required config parameter '{attr_name}' not found")
        setattr(self, attr_name, config[attr_name])

    @config.setter  # type: ignore
    def config(self, config: Dict[str, Any]) -> None:
        """Set the config for the graph store.

        Reads ``num_neighbors`` from the config and, if a sampler already
        exists, updates it in place.

        Args:
            config (Dict[str, Any]):
                Config dictionary containing required parameters

        Raises:
            ValueError: If required parameters missing from config
        """
        self._set_from_config(config, "num_neighbors")
        if hasattr(self, 'sampler'):
            self.sampler.num_neighbors = (  # type: ignore[has-type]
                self.num_neighbors)

        self._config = config

    def _init_sampler(self) -> None:
        """Initializes neighbor sampler with the registered feature store.

        Raises:
            AttributeError: If no feature store has been registered.
        """
        if self.feature_store is None:
            raise AttributeError("Feature store not registered yet.")
        assert self.num_neighbors is not None, \
            "Please set num_neighbors through config"
        self.sampler = BidirectionalNeighborSampler(
            data=(self.feature_store, self), num_neighbors=self.num_neighbors,
            **self.sample_kwargs)
        self._sampler_is_initialized = True

    def register_feature_store(self, feature_store: FeatureStore) -> None:
        """Registers a feature store with the graph store.

        The sampler is rebuilt on the next call to :meth:`sample_subgraph`.

        :param feature_store: The feature store to register.
        """
        self.feature_store = feature_store
        self._sampler_is_initialized = False

    def put_edge_id(  # type: ignore[no-untyped-def]
            self, edge_id: Tensor, *args, **kwargs) -> bool:
        """Stores an edge ID in the graph store.

        :param edge_id: The edge ID to store.
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)
        self._sampler_is_initialized = False
        return ret

    @property
    def edge_index(self) -> _EdgeTensorType:
        """Gets the edge index of the graph.

        Uses the edge attributes from the most recent
        :meth:`put_edge_index` call, so an edge index must have been stored
        first.

        :return: The edge index as a tensor.
        """
        return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)

    def put_edge_index(  # type: ignore[no-untyped-def]
            self, edge_index: _EdgeTensorType, *args, **kwargs) -> bool:
        """Stores an edge index in the graph store.

        :param edge_index: The edge index to store.
        :return: Whether the operation was successful.
        """
        ret = super().put_edge_index(edge_index, *args, **kwargs)
        # HACK: remember the edge attributes so the `edge_index` property can
        # read the same edges back.
        self.edge_idx_args = args
        self.edge_idx_kwargs = kwargs
        self._sampler_is_initialized = False
        return ret

    # HACK: assigning `edge_index` sorts the edges and stores them through
    # `put_edge_index` with a fixed homogeneous COO edge attribute.
    @edge_index.setter  # type: ignore
    def edge_index(self, edge_index: _EdgeTensorType) -> None:
        """Sets the edge index of the graph.

        The edges are stably sorted by their column (destination) index
        before being stored. The sorting permutation is saved in
        :obj:`self.perm`.

        :param edge_index: Either a :obj:`[2, num_edges]` tensor or a
            :obj:`(row, col)` tuple of tensors.
        """
        if not isinstance(edge_index, Tensor):
            assert isinstance(edge_index, tuple) \
                and isinstance(edge_index[0], Tensor) \
                and isinstance(edge_index[1], Tensor), \
                "edge_index must be a Tensor of [2, num_edges] \
                or a tuple of Tensors, (row, col)."
        row, col = edge_index[0], edge_index[1]

        # Node IDs are assumed to be contiguous and start at 0 (the node list
        # is made from the triples), so `num_nodes` must cover the largest ID
        # on *either* endpoint, including for `(row, col)` tuple input.
        if row.numel() == 0:
            num_nodes = 0
        else:
            num_nodes = max(int(row.max()), int(col.max())) + 1
        attr = dict(
            edge_type=None,
            layout='coo',
            size=(num_nodes, num_nodes),
            is_sorted=False,
        )
        # Edge index needs to be sorted here and the perm saved for later:
        col_sorted, self.perm = index_sort(col, num_nodes, stable=True)
        row_sorted = row[self.perm]
        edge_index_sorted = torch.stack([row_sorted, col_sorted], dim=0)
        self.put_edge_index(edge_index_sorted, **attr)

    def sample_subgraph(
        self,
        seed_nodes: Tensor,
    ) -> SamplerOutput:
        """Sample the graph starting from the given nodes using the
        in-built :class:`BidirectionalNeighborSampler`.

        The number of hops and neighbors per hop are taken from
        ``num_neighbors`` (set through :obj:`config`).

        Args:
            seed_nodes (Tensor): Indices of the seed nodes to start sampling
                from. Duplicates are removed before sampling.

        Returns:
            SamplerOutput: The sampler output for the input, with
            :obj:`edge` remapped to the original edge order.
        """
        # TODO add support for Hetero
        if not self._sampler_is_initialized:
            self._init_sampler()

        seed_nodes = seed_nodes.unique().contiguous()
        node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)
        out = self.sampler.sample_from_nodes(  # type: ignore[has-type]
            node_sample_input)

        # Edge IDs refer to the column-sorted order built in the `edge_index`
        # setter and need to be remapped to the original indices:
        out.edge = self.perm[out.edge]

        return out
@@ -0,0 +1,124 @@
1
+ # mypy: ignore-errors
2
+ import os
3
+ from abc import abstractmethod
4
+ from typing import Any, Callable, Dict, List, Optional, Protocol, Union
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+ from torch_geometric.data import Data
10
+ from torch_geometric.llm.models import SentenceTransformer
11
+ from torch_geometric.llm.utils.backend_utils import batch_knn
12
+
13
+
14
class VectorRetriever(Protocol):
    """Protocol for vector-similarity based retrievers (VectorRAG).

    Implementations take a query and return the retrieved context. The type
    of that context is implementation specific: :class:`DocumentRetriever`,
    for instance, returns a list of document strings.
    """
    @abstractmethod
    def query(self, query: Any, **kwargs: Any) -> Any:
        """Retrieve a context for a given query.

        Args:
            query: The query to retrieve context for.
            **kwargs: Implementation-specific retrieval options.

        Returns:
            The retrieved context.
        """
        ...
+
21
+
22
class DocumentRetriever(VectorRetriever):
    """Retrieve documents from a vector database."""
    def __init__(self, raw_docs: List[str],
                 embedded_docs: Optional[Tensor] = None, k_for_docs: int = 2,
                 model: Optional[Union[SentenceTransformer, torch.nn.Module,
                                       Callable]] = None,
                 model_kwargs: Optional[Dict[str, Any]] = None):
        """Retrieve documents from a vector database.

        Args:
            raw_docs: List[str]: List of raw documents.
            embedded_docs: Optional[Tensor]: Embedded documents. If
                :obj:`None`, they are computed with :obj:`model`.
            k_for_docs: int: Number of documents to retrieve.
            model: Optional[Union[SentenceTransformer, torch.nn.Module,
                Callable]]: Model to use for encoding. Required if
                :obj:`embedded_docs` is not given or if string queries are
                passed to :meth:`query`.
            model_kwargs: Optional[Dict[str, Any]]:
                Keyword arguments to pass to the model. The dict is copied,
                so the caller's object is never modified.
        """
        self.raw_docs = raw_docs
        self.embedded_docs = embedded_docs
        self.k_for_docs = k_for_docs
        self.model = model
        # `encoder` is the callable used for all encodings (alias of `model`).
        self.encoder = model
        # Private copy: never mutate the caller's dict, and never keep `None`
        # (which would break `**self.model_kwargs` in `query`).
        self.model_kwargs: Dict[str, Any] = dict(model_kwargs or {})

        if self.embedded_docs is None:
            assert self.model is not None, \
                "Model must be provided if embedded_docs is not provided"
            self.embedded_docs = self.encoder(self.raw_docs,
                                              **self.model_kwargs)
        # we don't want to print the verbose output in `query`
        self.model_kwargs.pop("verbose", None)

    def query(self, query: Union[str, Tensor]) -> List[str]:
        """Retrieve documents from the vector database.

        Args:
            query: Union[str, Tensor]: Query to retrieve documents for. A
                string is encoded with the model; a tensor is used as the
                query embedding directly.

        Returns:
            List[str]: Documents retrieved from the vector database.

        Raises:
            ValueError: If :obj:`query` is a string but no model was given.
        """
        if isinstance(query, str):
            if self.encoder is None:
                raise ValueError(
                    "A model must be provided to encode string queries")
            query_enc = self.encoder(query, **self.model_kwargs)
        else:
            query_enc = query

        selected_doc_idxs, _ = next(
            batch_knn(query_enc, self.embedded_docs, self.k_for_docs))
        return [self.raw_docs[i] for i in selected_doc_idxs]

    def save(self, path: str) -> None:
        """Save the DocumentRetriever instance to disk.

        The model is not serialized; it must be passed again to
        :meth:`load`.

        Args:
            path: str: Path where to save the retriever.
        """
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

        # Prepare data to save
        save_dict = {
            'raw_docs': self.raw_docs,
            'embedded_docs': self.embedded_docs,
            'k_for_docs': self.k_for_docs,
        }

        # We do not serialize the model
        torch.save(save_dict, path)

    @classmethod
    def load(cls, path: str, model: Union[SentenceTransformer, torch.nn.Module,
                                          Callable],
             model_kwargs: Optional[Dict[str, Any]] = None) -> VectorRetriever:
        """Load a DocumentRetriever instance from disk.

        Args:
            path: str: Path to the saved retriever.
            model: Union[SentenceTransformer, torch.nn.Module, Callable]:
                Model to use for encoding. :meth:`save` does not store the
                model, so it has to be supplied here.
            model_kwargs: Optional[Dict[str, Any]]
                Key word args to be passed to model. The dict is not
                modified.

        Returns:
            DocumentRetriever: The loaded retriever.

        Raises:
            FileNotFoundError: If :obj:`path` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"No saved document retriever found at {path}")

        # NOTE(security): `weights_only=False` unpickles arbitrary objects;
        # only load files from trusted sources.
        save_dict = torch.load(path, weights_only=False)
        # Create a new DocumentRetriever with the loaded data. The constructor
        # copies `model_kwargs` and strips `verbose` for query-time encoding.
        return cls(raw_docs=save_dict['raw_docs'],
                   embedded_docs=save_dict['embedded_docs'],
                   k_for_docs=save_dict['k_for_docs'], model=model,
                   model_kwargs=model_kwargs)
@@ -22,7 +22,6 @@ from .dynamic_batch_sampler import DynamicBatchSampler
22
22
  from .prefetch import PrefetchLoader
23
23
  from .cache import CachedLoader
24
24
  from .mixin import AffinityMixin
25
- from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore
26
25
 
27
26
  __all__ = classes = [
28
27
  'DataLoader',
@@ -51,9 +50,6 @@ __all__ = classes = [
51
50
  'PrefetchLoader',
52
51
  'CachedLoader',
53
52
  'AffinityMixin',
54
- 'RAGQueryLoader',
55
- 'RAGFeatureStore',
56
- 'RAGGraphStore'
57
53
  ]
58
54
 
59
55
  RandomNodeSampler = deprecated(
@@ -21,6 +21,19 @@ class LinkPredMetricData:
21
21
  edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]
22
22
  edge_label_weight: Optional[Tensor] = None
23
23
 
24
+ def __post_init__(self) -> None:
25
+ # Filter all negative weights - they should not be used as ground-truth
26
+ if self.edge_label_weight is not None:
27
+ pos_mask = self.edge_label_weight > 0
28
+ self.edge_label_weight = self.edge_label_weight[pos_mask]
29
+ if isinstance(self.edge_label_index, Tensor):
30
+ self.edge_label_index = self.edge_label_index[:, pos_mask]
31
+ else:
32
+ self.edge_label_index = (
33
+ self.edge_label_index[0][pos_mask],
34
+ self.edge_label_index[1][pos_mask],
35
+ )
36
+
24
37
  @property
25
38
  def pred_rel_mat(self) -> Tensor:
26
39
  r"""Returns a matrix indicating the relevance of the `k`-th prediction.
@@ -374,8 +387,6 @@ class LinkPredMetricCollection(torch.nn.ModuleDict):
374
387
  if self.weighted and edge_label_weight is None:
375
388
  raise ValueError(f"'edge_label_weight' is a required argument for "
376
389
  f"weighted '{self.__class__.__name__}' metrics")
377
- if not self.weighted:
378
- edge_label_weight = None
379
390
 
380
391
  data = LinkPredMetricData( # Share metric data across metrics.
381
392
  pred_index_mat=pred_index_mat,
@@ -17,7 +17,6 @@ from .dense import * # noqa
17
17
  from .kge import * # noqa
18
18
  from .models import * # noqa
19
19
  from .functional import * # noqa
20
- from .nlp import * # noqa
21
20
 
22
21
  __all__ = [
23
22
  'Reshape',
@@ -29,11 +29,6 @@ from .gnnff import GNNFF
29
29
  from .pmlp import PMLP
30
30
  from .neural_fingerprint import NeuralFingerprint
31
31
  from .visnet import ViSNet
32
- from .g_retriever import GRetriever
33
- from .git_mol import GITMol
34
- from .molecule_gpt import MoleculeGPT
35
- from .protein_mpnn import ProteinMPNN
36
- from .glem import GLEM
37
32
  from .lpformer import LPFormer
38
33
  from .sgformer import SGFormer
39
34
 
@@ -87,11 +82,6 @@ __all__ = classes = [
87
82
  'PMLP',
88
83
  'NeuralFingerprint',
89
84
  'ViSNet',
90
- 'GRetriever',
91
- 'GITMol',
92
- 'MoleculeGPT',
93
- 'ProteinMPNN',
94
- 'GLEM',
95
85
  'LPFormer',
96
86
  'SGFormer',
97
87
  'Polynormer',
@@ -187,6 +187,8 @@ class SGFormer(torch.nn.Module):
187
187
  self.params2 = list(self.graph_conv.parameters())
188
188
  self.params2.extend(list(self.fc.parameters()))
189
189
 
190
+ self.out_channels = out_channels
191
+
190
192
  def reset_parameters(self) -> None:
191
193
  self.trans_conv.reset_parameters()
192
194
  self.graph_conv.reset_parameters()
@@ -18,30 +18,51 @@ class SparseCrossEntropy(torch.autograd.Function):
18
18
  ) -> Tensor:
19
19
  assert inputs.dim() == 2
20
20
 
21
- logsumexp = inputs.logsumexp(dim=-1)
22
- ctx.save_for_backward(inputs, edge_label_index, edge_label_weight,
23
- logsumexp)
21
+ # Support for both positive and negative weights:
22
+ # Positive weights scale the logits *after* softmax.
23
+ # Negative weights scale the denominator *before* softmax:
24
+ pos_y = edge_label_index
25
+ neg_y = pos_weight = neg_weight = None
24
26
 
25
- out = inputs[edge_label_index[0], edge_label_index[1]]
26
- out.neg_().add_(logsumexp[edge_label_index[0]])
27
27
  if edge_label_weight is not None:
28
- out *= edge_label_weight
28
+ pos_mask = edge_label_weight >= 0
29
+ pos_y = edge_label_index[:, pos_mask]
30
+ pos_weight = edge_label_weight[pos_mask]
31
+
32
+ if pos_y.size(1) < edge_label_index.size(1):
33
+ neg_mask = ~pos_mask
34
+ neg_y = edge_label_index[:, neg_mask]
35
+ neg_weight = edge_label_weight[neg_mask]
36
+
37
+ if neg_y is not None and neg_weight is not None:
38
+ inputs = inputs.clone()
39
+ inputs[
40
+ neg_y[0],
41
+ neg_y[1],
42
+ ] += neg_weight.abs().log().clamp(min=1e-12)
43
+
44
+ logsumexp = inputs.logsumexp(dim=-1)
45
+ ctx.save_for_backward(inputs, pos_y, pos_weight, logsumexp)
46
+
47
+ out = inputs[pos_y[0], pos_y[1]]
48
+ out.neg_().add_(logsumexp[pos_y[0]])
49
+ if pos_weight is not None:
50
+ out *= pos_weight
29
51
 
30
52
  return out.sum() / inputs.size(0)
31
53
 
32
54
  @staticmethod
33
55
  @torch.autograd.function.once_differentiable
34
56
  def backward(ctx: Any, grad_out: Tensor) -> Tuple[Tensor, None, None]:
35
- inputs, edge_label_index, edge_label_weight, logsumexp = (
36
- ctx.saved_tensors)
57
+ inputs, pos_y, pos_weight, logsumexp = ctx.saved_tensors
37
58
 
38
59
  grad_out = grad_out / inputs.size(0)
39
- grad_out = grad_out.expand(edge_label_index.size(1))
60
+ grad_out = grad_out.expand(pos_y.size(1))
40
61
 
41
- if edge_label_weight is not None:
42
- grad_out = grad_out * edge_label_weight
62
+ if pos_weight is not None:
63
+ grad_out = grad_out * pos_weight
43
64
 
44
- grad_logsumexp = scatter(grad_out, edge_label_index[0], dim=0,
65
+ grad_logsumexp = scatter(grad_out, pos_y[0], dim=0,
45
66
  dim_size=inputs.size(0), reduce='sum')
46
67
 
47
68
  # Gradient computation of `logsumexp`: `grad * (self - result).exp()`
@@ -49,7 +70,7 @@ class SparseCrossEntropy(torch.autograd.Function):
49
70
  grad_input.exp_()
50
71
  grad_input.mul_(grad_logsumexp.view(-1, 1))
51
72
 
52
- grad_input[edge_label_index[0], edge_label_index[1]] -= grad_out
73
+ grad_input[pos_y[0], pos_y[1]] -= grad_out
53
74
 
54
75
  return grad_input, None, None
55
76
 
@@ -1,107 +0,0 @@
1
- from abc import abstractmethod
2
- from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
3
-
4
- from torch_geometric.data import Data, FeatureStore, HeteroData
5
- from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
6
- from torch_geometric.typing import InputEdges, InputNodes
7
-
8
-
9
- class RAGFeatureStore(Protocol):
10
- """Feature store template for remote GNN RAG backend."""
11
- @abstractmethod
12
- def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
13
- """Makes a comparison between the query and all the nodes to get all
14
- the closest nodes. Return the indices of the nodes that are to be seeds
15
- for the RAG Sampler.
16
- """
17
- ...
18
-
19
- @abstractmethod
20
- def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
21
- """Makes a comparison between the query and all the edges to get all
22
- the closest nodes. Returns the edge indices that are to be the seeds
23
- for the RAG Sampler.
24
- """
25
- ...
26
-
27
- @abstractmethod
28
- def load_subgraph(
29
- self, sample: Union[SamplerOutput, HeteroSamplerOutput]
30
- ) -> Union[Data, HeteroData]:
31
- """Combines sampled subgraph output with features in a Data object."""
32
- ...
33
-
34
-
35
- class RAGGraphStore(Protocol):
36
- """Graph store template for remote GNN RAG backend."""
37
- @abstractmethod
38
- def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
39
- **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
40
- """Sample a subgraph using the seeded nodes and edges."""
41
- ...
42
-
43
- @abstractmethod
44
- def register_feature_store(self, feature_store: FeatureStore):
45
- """Register a feature store to be used with the sampler. Samplers need
46
- info from the feature store in order to work properly on HeteroGraphs.
47
- """
48
- ...
49
-
50
-
51
- # TODO: Make compatible with Heterographs
52
-
53
-
54
- class RAGQueryLoader:
55
- """Loader meant for making RAG queries from a remote backend."""
56
- def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
57
- local_filter: Optional[Callable[[Data, Any], Data]] = None,
58
- seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
59
- seed_edges_kwargs: Optional[Dict[str, Any]] = None,
60
- sampler_kwargs: Optional[Dict[str, Any]] = None,
61
- loader_kwargs: Optional[Dict[str, Any]] = None):
62
- """Loader meant for making queries from a remote backend.
63
-
64
- Args:
65
- data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore
66
- and GraphStore to load from. Assumed to conform to the
67
- protocols listed above.
68
- local_filter (Optional[Callable[[Data, Any], Data]], optional):
69
- Optional local transform to apply to data after retrieval.
70
- Defaults to None.
71
- seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Parameters
72
- to pass into process for fetching seed nodes. Defaults to None.
73
- seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters
74
- to pass into process for fetching seed edges. Defaults to None.
75
- sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to
76
- pass into process for sampling graph. Defaults to None.
77
- loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to
78
- pass into process for loading graph features. Defaults to None.
79
- """
80
- fstore, gstore = data
81
- self.feature_store = fstore
82
- self.graph_store = gstore
83
- self.graph_store.register_feature_store(self.feature_store)
84
- self.local_filter = local_filter
85
- self.seed_nodes_kwargs = seed_nodes_kwargs or {}
86
- self.seed_edges_kwargs = seed_edges_kwargs or {}
87
- self.sampler_kwargs = sampler_kwargs or {}
88
- self.loader_kwargs = loader_kwargs or {}
89
-
90
- def query(self, query: Any) -> Data:
91
- """Retrieve a subgraph associated with the query with all its feature
92
- attributes.
93
- """
94
- seed_nodes = self.feature_store.retrieve_seed_nodes(
95
- query, **self.seed_nodes_kwargs)
96
- seed_edges = self.feature_store.retrieve_seed_edges(
97
- query, **self.seed_edges_kwargs)
98
-
99
- subgraph_sample = self.graph_store.sample_subgraph(
100
- seed_nodes, seed_edges, **self.sampler_kwargs)
101
-
102
- data = self.feature_store.load_subgraph(sample=subgraph_sample,
103
- **self.loader_kwargs)
104
-
105
- if self.local_filter:
106
- data = self.local_filter(data, query)
107
- return data
@@ -1,9 +0,0 @@
1
- from .sentence_transformer import SentenceTransformer
2
- from .vision_transformer import VisionTransformer
3
- from .llm import LLM
4
-
5
- __all__ = classes = [
6
- 'SentenceTransformer',
7
- 'VisionTransformer',
8
- 'LLM',
9
- ]
File without changes
File without changes