pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Note: this version of pyg-nightly has been flagged as a potentially problematic release; see the registry page for details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/llm/rag_loader.py (new file)
@@ -0,0 +1,154 @@
+from abc import abstractmethod
+from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
+
+from torch_geometric.data import Data, FeatureStore, HeteroData
+from torch_geometric.llm.utils.vectorrag import VectorRetriever
+from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
+from torch_geometric.typing import InputEdges, InputNodes
+
+
+class RAGFeatureStore(Protocol):
+    """Feature store template for remote GNN RAG backend."""
+    @abstractmethod
+    def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
+        """Makes a comparison between the query and all the nodes to get all
+        the closest nodes. Returns the indices of the nodes that are to be
+        seeds for the RAG Sampler.
+        """
+        ...
+
+    @property
+    @abstractmethod
+    def config(self) -> Dict[str, Any]:
+        """Get the config for the RAGFeatureStore."""
+        ...
+
+    @config.setter
+    @abstractmethod
+    def config(self, config: Dict[str, Any]):
+        """Set the config for the RAGFeatureStore."""
+        ...
+
+    @abstractmethod
+    def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
+        """Makes a comparison between the query and all the edges to get all
+        the closest edges. Returns the edge indices that are to be the seeds
+        for the RAG Sampler.
+        """
+        ...
+
+    @abstractmethod
+    def load_subgraph(
+        self, sample: Union[SamplerOutput, HeteroSamplerOutput]
+    ) -> Union[Data, HeteroData]:
+        """Combines sampled subgraph output with features in a Data object."""
+        ...
+
+
+class RAGGraphStore(Protocol):
+    """Graph store template for remote GNN RAG backend."""
+    @abstractmethod
+    def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
+                        **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
+        """Sample a subgraph using the seeded nodes and edges."""
+        ...
+
+    @property
+    @abstractmethod
+    def config(self) -> Dict[str, Any]:
+        """Get the config for the RAGGraphStore."""
+        ...
+
+    @config.setter
+    @abstractmethod
+    def config(self, config: Dict[str, Any]):
+        """Set the config for the RAGGraphStore."""
+        ...
+
+    @abstractmethod
+    def register_feature_store(self, feature_store: FeatureStore):
+        """Register a feature store to be used with the sampler. Samplers need
+        info from the feature store in order to work properly on HeteroGraphs.
+        """
+        ...
+
+
+# TODO: Make compatible with HeteroGraphs
+
+
+class RAGQueryLoader:
+    """Loader meant for making RAG queries from a remote backend."""
+    def __init__(self, graph_data: Tuple[RAGFeatureStore, RAGGraphStore],
+                 subgraph_filter: Optional[Callable[[Data, Any], Data]] = None,
+                 augment_query: bool = False,
+                 vector_retriever: Optional[VectorRetriever] = None,
+                 config: Optional[Dict[str, Any]] = None):
+        """Loader meant for making queries from a remote backend.
+
+        Args:
+            graph_data (Tuple[RAGFeatureStore, RAGGraphStore]):
+                Remote FeatureStore and GraphStore to load from.
+                Assumed to conform to the protocols listed above.
+            subgraph_filter (Optional[Callable[[Data, Any], Data]], optional):
+                Optional local transform to apply to data after retrieval.
+                Defaults to None.
+            augment_query (bool, optional): Whether to augment the query with
+                retrieved documents. Defaults to False.
+            vector_retriever (Optional[VectorRetriever], optional):
+                VectorRetriever to use for retrieving documents.
+                Defaults to None.
+            config (Optional[Dict[str, Any]], optional): Config to pass into
+                the RAGQueryLoader. Defaults to None.
+        """
+        fstore, gstore = graph_data
+        self.vector_retriever = vector_retriever
+        self.augment_query = augment_query
+        self.feature_store = fstore
+        self.graph_store = gstore
+        self.graph_store.edge_index = self.graph_store.edge_index.contiguous()
+        self.graph_store.register_feature_store(self.feature_store)
+        self.subgraph_filter = subgraph_filter
+        self.config = config
+
+    def _propagate_config(self, config: Dict[str, Any]):
+        """Propagate the config to the relevant components."""
+        self.feature_store.config = config
+        self.graph_store.config = config
+
+    @property
+    def config(self):
+        """Get the config for the RAGQueryLoader."""
+        return self._config
+
+    @config.setter
+    def config(self, config: Dict[str, Any]):
+        """Set the config for the RAGQueryLoader.
+
+        Args:
+            config (Dict[str, Any]): The config to set.
+        """
+        self._propagate_config(config)
+        self._config = config
+
+    def query(self, query: Any) -> Data:
+        """Retrieve a subgraph associated with the query with all its feature
+        attributes.
+        """
+        if self.vector_retriever:
+            retrieved_docs = self.vector_retriever.query(query)
+
+        if self.augment_query:
+            query = [query] + retrieved_docs
+
+        seed_nodes, query_enc = self.feature_store.retrieve_seed_nodes(query)
+
+        subgraph_sample = self.graph_store.sample_subgraph(seed_nodes)
+
+        data = self.feature_store.load_subgraph(sample=subgraph_sample)
+
+        # Apply the local filter:
+        if self.subgraph_filter:
+            data = self.subgraph_filter(data, query)
+        if self.vector_retriever:
+            data.text_context = retrieved_docs
+        return data
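The two protocols and the loader above define the whole retrieval loop: encode the query, pick seed nodes, sample a subgraph around them, and attach features. The following is a minimal, self-contained sketch of how the pieces plug together; the toy stores are illustrative stand-ins written for this note (not part of the package), and the tensors they return are placeholders:

    import torch
    from torch_geometric.data import Data
    from torch_geometric.llm.rag_loader import RAGQueryLoader

    class ToyFeatureStore:
        config = None  # plain attribute satisfies the `config` protocol

        def retrieve_seed_nodes(self, query, **kwargs):
            # Pretend nodes 0 and 1 are closest to the query; also return a
            # placeholder query encoding, matching how the loader unpacks it:
            return torch.tensor([0, 1]), torch.randn(1, 8)

        def retrieve_seed_edges(self, query, **kwargs):
            return torch.tensor([0])

        def load_subgraph(self, sample):
            # Attach (random) node features to the sampled edge index:
            return Data(x=torch.randn(2, 8), edge_index=sample)

    class ToyGraphStore:
        config = None
        edge_index = torch.tensor([[0], [1]])  # a single edge 0 -> 1

        def register_feature_store(self, feature_store):
            self.feature_store = feature_store

        def sample_subgraph(self, seed_nodes, seed_edges=None, **kwargs):
            return self.edge_index  # toy: the "sample" is the whole graph

    loader = RAGQueryLoader(graph_data=(ToyFeatureStore(), ToyGraphStore()))
    print(loader.query("which node is related to node 0?"))
    # roughly: Data(x=[2, 8], edge_index=[2, 1])

Since `RAGFeatureStore`/`RAGGraphStore` are `Protocol` classes, any duck-typed object with these methods works; no subclassing is required.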
torch_geometric/llm/utils/__init__.py (new file)
@@ -0,0 +1,10 @@
+from .backend_utils import *  # noqa
+from .feature_store import KNNRAGFeatureStore
+from .graph_store import NeighborSamplingRAGGraphStore
+from .vectorrag import DocumentRetriever
+
+__all__ = classes = [
+    'KNNRAGFeatureStore',
+    'NeighborSamplingRAGGraphStore',
+    'DocumentRetriever',
+]
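This `__init__` fixes the public import surface of `torch_geometric.llm.utils`: the three named classes plus everything re-exported by the star-import from `backend_utils` (shown in the next hunk). A consumer only needs, e.g.:

    # Names taken from the __all__ list above:
    from torch_geometric.llm.utils import (
        KNNRAGFeatureStore,
        NeighborSamplingRAGGraphStore,
    )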
torch_geometric/llm/utils/backend_utils.py (new file)
@@ -0,0 +1,443 @@
+import os
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Protocol,
+    Tuple,
+    Type,
+    Union,
+    no_type_check,
+    runtime_checkable,
+)
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.nn import Module
+
+from torch_geometric.data import Data, FeatureStore, GraphStore
+from torch_geometric.distributed import (
+    LocalFeatureStore,
+    LocalGraphStore,
+    Partitioner,
+)
+from torch_geometric.llm.large_graph_indexer import (
+    EDGE_RELATION,
+    LargeGraphIndexer,
+    TripletLike,
+)
+from torch_geometric.llm.models import SentenceTransformer
+from torch_geometric.typing import EdgeType, NodeType
+
+try:
+    from pandas import DataFrame
+except ImportError:
+    DataFrame = None
+RemoteGraphBackend = Tuple[FeatureStore, GraphStore]
+
+# TODO: Make everything compatible with hetero graphs as well
+
+
+def preprocess_triplet(triplet: TripletLike) -> TripletLike:
+    h, r, t = triplet
+    return str(h).lower(), str(r).lower(), str(t).lower()
+
+
+@no_type_check
+def retrieval_via_pcst(
+    data: Data,
+    q_emb: Tensor,
+    textual_nodes: Any,
+    textual_edges: Any,
+    topk: int = 3,
+    topk_e: int = 5,
+    cost_e: float = 0.5,
+    num_clusters: int = 1,
+) -> Tuple[Data, str]:
+
+    # Skip PCST for bad graphs:
+    booly = data.edge_attr is None or data.edge_attr.numel() == 0
+    booly = booly or data.x is None or data.x.numel() == 0
+    booly = booly or data.edge_index is None or data.edge_index.numel() == 0
+    if not booly:
+        c = 0.01
+
+        from pcst_fast import pcst_fast
+
+        root = -1
+        pruning = 'gw'
+        verbosity_level = 0
+        if topk > 0:
+            n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
+            topk = min(topk, data.num_nodes)
+            _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
+
+            n_prizes = torch.zeros_like(n_prizes)
+            n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
+        else:
+            n_prizes = torch.zeros(data.num_nodes)
+
+        if topk_e > 0:
+            e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
+            topk_e = min(topk_e, e_prizes.unique().size(0))
+
+            topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e,
+                                          largest=True)
+            e_prizes[e_prizes < topk_e_values[-1]] = 0.0
+            last_topk_e_value = topk_e
+            for k in range(topk_e):
+                indices = e_prizes == topk_e_values[k]
+                value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
+                e_prizes[indices] = value
+                last_topk_e_value = value * (1 - c)
+            # Reduce the cost of the edges so that at least one edge is chosen:
+            cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
+        else:
+            e_prizes = torch.zeros(data.num_edges)
+
+        costs = []
+        edges = []
+        virtual_n_prizes = []
+        virtual_edges = []
+        virtual_costs = []
+        mapping_n = {}
+        mapping_e = {}
+        for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
+            prize_e = e_prizes[i]
+            if prize_e <= cost_e:
+                mapping_e[len(edges)] = i
+                edges.append((src, dst))
+                costs.append(cost_e - prize_e)
+            else:
+                virtual_node_id = data.num_nodes + len(virtual_n_prizes)
+                mapping_n[virtual_node_id] = i
+                virtual_edges.append((src, virtual_node_id))
+                virtual_edges.append((virtual_node_id, dst))
+                virtual_costs.append(0)
+                virtual_costs.append(0)
+                virtual_n_prizes.append(prize_e - cost_e)
+
+        prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
+        num_edges = len(edges)
+        if len(virtual_costs) > 0:
+            costs = np.array(costs + virtual_costs)
+            edges = np.array(edges + virtual_edges)
+
+        vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
+                                    pruning, verbosity_level)
+
+        selected_nodes = vertices[vertices < data.num_nodes]
+        selected_edges = [mapping_e[e] for e in edges if e < num_edges]
+        virtual_vertices = vertices[vertices >= data.num_nodes]
+        if len(virtual_vertices) > 0:
+            virtual_vertices = vertices[vertices >= data.num_nodes]
+            virtual_edges = [mapping_n[i] for i in virtual_vertices]
+            selected_edges = np.array(selected_edges + virtual_edges)
+
+        edge_index = data.edge_index[:, selected_edges]
+        selected_nodes = np.unique(
+            np.concatenate(
+                [selected_nodes, edge_index[0].numpy(),
+                 edge_index[1].numpy()]))
+
+        n = textual_nodes.iloc[selected_nodes]
+        e = textual_edges.iloc[selected_edges]
+    else:
+        n = textual_nodes
+        e = textual_edges
+
+    desc = n.to_csv(index=False) + '\n' + e.to_csv(
+        index=False, columns=['src', 'edge_attr', 'dst'])
+
+    if booly:
+        return data, desc
+
+    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
+    src = [mapping[i] for i in edge_index[0].tolist()]
+    dst = [mapping[i] for i in edge_index[1].tolist()]
+
+    # HACK: Added so that the subset of nodes and edges selected can be tracked
+    node_idx = np.array(data.node_idx)[selected_nodes]
+    edge_idx = np.array(data.edge_idx)[selected_edges]
+
+    data = Data(
+        x=data.x[selected_nodes],
+        edge_index=torch.tensor([src, dst]).to(torch.long),
+        edge_attr=data.edge_attr[selected_edges],
+        # HACK: track subset of selected nodes/edges
+        node_idx=node_idx,
+        edge_idx=edge_idx,
+    )
+
+    return data, desc
+
+
+def batch_knn(query_enc: Tensor, embeds: Tensor,
+              k: int) -> Iterator[Tuple[Tensor, Tensor]]:
+    from torchmetrics.functional import pairwise_cosine_similarity
+    prizes = pairwise_cosine_similarity(query_enc, embeds.to(query_enc.device))
+    topk = min(k, len(embeds))
+    for i, q in enumerate(prizes):
+        _, indices = torch.topk(q, topk, largest=True)
+        yield indices, query_enc[i].unsqueeze(0)
+
+
+# Adapted from LocalGraphStore
+@runtime_checkable
+class ConvertableGraphStore(Protocol):
+    @classmethod
+    def from_data(
+        cls,
+        edge_id: Tensor,
+        edge_index: Tensor,
+        num_nodes: int,
+        is_sorted: bool = False,
+    ) -> GraphStore:
+        ...
+
+    @classmethod
+    def from_hetero_data(
+        cls,
+        edge_id_dict: Dict[EdgeType, Tensor],
+        edge_index_dict: Dict[EdgeType, Tensor],
+        num_nodes_dict: Dict[NodeType, int],
+        is_sorted: bool = False,
+    ) -> GraphStore:
+        ...
+
+    @classmethod
+    def from_partition(cls, root: str, pid: int) -> GraphStore:
+        ...
+
+
+# Adapted from LocalFeatureStore
+@runtime_checkable
+class ConvertableFeatureStore(Protocol):
+    @classmethod
+    def from_data(
+        cls,
+        node_id: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+        edge_id: Optional[Tensor] = None,
+        edge_attr: Optional[Tensor] = None,
+    ) -> FeatureStore:
+        ...
+
+    @classmethod
+    def from_hetero_data(
+        cls,
+        node_id_dict: Dict[NodeType, Tensor],
+        x_dict: Optional[Dict[NodeType, Tensor]] = None,
+        y_dict: Optional[Dict[NodeType, Tensor]] = None,
+        edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None,
+        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
+    ) -> FeatureStore:
+        ...
+
+    @classmethod
+    def from_partition(cls, root: str, pid: int) -> FeatureStore:
+        ...
+
+
+class RemoteDataType(Enum):
+    DATA = auto()
+    PARTITION = auto()
+
+
+@dataclass
+class RemoteGraphBackendLoader:
+    """Utility class to load triplets into a RAG Backend."""
+    path: str
+    datatype: RemoteDataType
+    graph_store_type: Type[ConvertableGraphStore]
+    feature_store_type: Type[ConvertableFeatureStore]
+
+    def load(self, pid: Optional[int] = None) -> RemoteGraphBackend:
+        if self.datatype == RemoteDataType.DATA:
+            data_obj = torch.load(self.path, weights_only=False)
+            # is_sorted=True, as nodes are assumed to come sorted from indexer:
+            graph_store = self.graph_store_type.from_data(
+                edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index,
+                num_nodes=data_obj.num_nodes, is_sorted=True)
+            feature_store = self.feature_store_type.from_data(
+                node_id=data_obj['node_id'], x=data_obj.x,
+                edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr)
+        elif self.datatype == RemoteDataType.PARTITION:
+            if pid is None:
+                assert pid is not None, \
+                    "Partition ID must be defined for loading from a " \
+                    + "partitioned store."
+            graph_store = self.graph_store_type.from_partition(self.path, pid)
+            feature_store = self.feature_store_type.from_partition(
+                self.path, pid)
+        else:
+            raise NotImplementedError
+        return (feature_store, graph_store)
+
+    def __del__(self) -> None:
+        if os.path.exists(self.path):
+            os.remove(self.path)
+
+
+def create_graph_from_triples(
+    triples: Iterable[TripletLike],
+    embedding_model: Union[Module, Callable],
+    embedding_method_kwargs: Optional[Dict[str, Any]] = None,
+    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,
+) -> Data:
+    """Utility function that can be used to create a graph from triples."""
+    # Resolve callable methods:
+    embedding_method_kwargs = embedding_method_kwargs \
+        if embedding_method_kwargs is not None else dict()
+
+    indexer = LargeGraphIndexer.from_triplets(triples,
+                                              pre_transform=pre_transform)
+    node_feats = embedding_model(indexer.get_unique_node_features(),
+                                 **embedding_method_kwargs)
+    indexer.add_node_feature('x', node_feats)
+
+    edge_feats = embedding_model(
+        indexer.get_unique_edge_features(feature_name=EDGE_RELATION),
+        **embedding_method_kwargs)
+    indexer.add_edge_feature(new_feature_name="edge_attr",
+                             new_feature_vals=edge_feats,
+                             map_from_feature=EDGE_RELATION)
+
+    data = indexer.to_data(node_feature_name='x',
+                           edge_feature_name='edge_attr')
+    data = data.to("cpu")
+    return data
+
+
+def create_remote_backend_from_graph_data(
+    graph_data: Data,
+    graph_db: Type[ConvertableGraphStore] = LocalGraphStore,
+    feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore,
+    path: str = '',
+    n_parts: int = 1,
+) -> RemoteGraphBackendLoader:
+    """Utility function to create a RAG Backend from graph data.
+
+    Args:
+        graph_data (Data): Graph data to load into the RAG Backend.
+        graph_db (Type[ConvertableGraphStore], optional): GraphStore class to
+            use. Defaults to LocalGraphStore.
+        feature_db (Type[ConvertableFeatureStore], optional): FeatureStore
+            class to use. Defaults to LocalFeatureStore.
+        path (str, optional): Path to save resulting stores. Defaults to ''.
+        n_parts (int, optional): Number of partitions to store in.
+            Defaults to 1.
+
+    Returns:
+        RemoteGraphBackendLoader: Loader to load RAG backend from disk or
+            memory.
+    """
+    # Will raise attribute errors for missing attributes:
+    if not issubclass(graph_db, ConvertableGraphStore):
+        _ = graph_db.from_data
+        _ = graph_db.from_hetero_data
+        _ = graph_db.from_partition
+    elif not issubclass(feature_db, ConvertableFeatureStore):
+        _ = feature_db.from_data
+        _ = feature_db.from_hetero_data
+        _ = feature_db.from_partition
+
+    if n_parts == 1:
+        torch.save(graph_data, path)
+        return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db,
+                                        feature_db)
+    else:
+        partitioner = Partitioner(data=graph_data, num_parts=n_parts,
+                                  root=path)
+        partitioner.generate_partition()
+        return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION,
+                                        graph_db, feature_db)
+
+
+def make_pcst_filter(triples: List[Tuple[str, str,
+                                         str]], model: SentenceTransformer,
+                     topk: int = 5, topk_e: int = 5, cost_e: float = 0.5,
+                     num_clusters: int = 1) -> Callable[[Data, str], Data]:
+    """Creates a PCST (Prize-Collecting Steiner Tree) filter.
+
+    :param triples: List of triples (head, relation, tail) representing KG data
+    :param model: SentenceTransformer model for embedding text
+    :param topk: Number of top-K node matches to keep (default: 5)
+    :param topk_e: Number of top-K edge matches to keep (default: 5)
+    :param cost_e: Cost of edges (default: 0.5)
+    :param num_clusters: Number of connected components in the PCST output.
+    :return: PCST filter function
+    """
+    if DataFrame is None:
+        raise Exception("PCST requires `pip install pandas`"
+                        )  # Check whether pandas is installed
+
+    # Remove duplicate triples to ensure a unique set:
+    triples = list(dict.fromkeys(triples))
+
+    # Initialize empty list to store nodes (entities) from triples:
+    nodes = []
+
+    # Iterate over triples to extract unique nodes (entities):
+    for h, _, t in triples:
+        for node in (h, t):  # Extract head and tail entities from each triple
+            nodes.append(node)
+
+    # Remove duplicates and create the final list of unique nodes:
+    nodes = list(dict.fromkeys(nodes))
+
+    # Create full list of textual nodes (entities) for filtering:
+    full_textual_nodes = nodes
+
+    def apply_retrieval_via_pcst(
+        graph: Data,  # Input graph data
+        query: str,  # Search query
+    ) -> Data:
+        """Applies PCST filtering for retrieval.
+
+        :param graph: Input graph data
+        :param query: Search query
+        :return: Retrieved graph/query data
+        """
+        # PCST relies on the numpy and pcst_fast PyPI libs, hence to("cpu"):
+        with torch.no_grad():
+            q_emb = model.encode([query]).to("cpu")
+        textual_nodes = [(int(i), full_textual_nodes[i])
+                         for i in graph["node_idx"]]
+        textual_nodes = DataFrame(textual_nodes,
+                                  columns=["node_id", "node_attr"])
+        textual_edges = [triples[i] for i in graph["edge_idx"]]
+        textual_edges = DataFrame(textual_edges,
+                                  columns=["src", "edge_attr", "dst"])
+        out_graph, desc = retrieval_via_pcst(graph.to(q_emb.device), q_emb,
+                                             textual_nodes, textual_edges,
+                                             topk=topk, topk_e=topk_e,
+                                             cost_e=cost_e,
+                                             num_clusters=num_clusters)
+        out_graph["desc"] = desc
+        where_trips_start = desc.find("src,edge_attr,dst")
+        parsed_trips = []
+        for trip in desc[where_trips_start + 18:-1].split("\n"):
+            parsed_trips.append(tuple(trip.split(",")))
+
+        # Handle the case where PCST returns an isolated node.
+        """
+        TODO: Find a better solution, since these failed subgraphs
+        severely hurt accuracy.
+        """
+        if str(parsed_trips) == "[('',)]" or out_graph.edge_index.numel() == 0:
+            out_graph["triples"] = []
+        else:
+            out_graph["triples"] = parsed_trips
+        out_graph["question"] = query
+        return out_graph
+
+    return apply_retrieval_via_pcst
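Taken together, `backend_utils` covers the offline half of the pipeline: triples go in, a saved FeatureStore/GraphStore pair comes out. Below is a minimal end-to-end sketch under stated assumptions: the toy encoder stands in for a real `SentenceTransformer`, and since the `LargeGraphIndexer` internals are not part of this diff, the sketch sets the `node_id`/`edge_id` attributes that `RemoteGraphBackendLoader.load()` expects explicitly rather than relying on `to_data()` to provide them:

    import torch
    from torch_geometric.llm.utils.backend_utils import (
        create_graph_from_triples,
        create_remote_backend_from_graph_data,
        preprocess_triplet,
    )

    triples = [
        ('Inception', 'directed_by', 'Christopher Nolan'),
        ('Interstellar', 'directed_by', 'Christopher Nolan'),
    ]

    def toy_encoder(strings, **kwargs):
        # Stand-in for SentenceTransformer.encode: one deterministic
        # 8-dim vector per input string.
        return torch.stack([torch.full((8,), float(len(s))) for s in strings])

    data = create_graph_from_triples(triples, toy_encoder,
                                     pre_transform=preprocess_triplet)

    # Assumption: ensure the id attributes `load()` reads are present:
    data.node_id = torch.arange(data.num_nodes)
    data.edge_id = torch.arange(data.num_edges)

    # Persist as a single (non-partitioned) backend and reload it as the
    # (feature_store, graph_store) pair that RAGQueryLoader consumes:
    backend = create_remote_backend_from_graph_data(data, path='backend.pt')
    feature_store, graph_store = backend.load()

With `n_parts > 1` the same call routes through `Partitioner` instead of `torch.save`, and `load(pid=...)` then reads a single partition; note that `RemoteGraphBackendLoader.__del__` deletes the saved file when the loader is garbage-collected.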