aiagents4pharma 1.36.0__py3-none-any.whl → 1.38.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +12 -4
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +2 -2
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +7 -6
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +1 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/__init__.py +0 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +12 -11
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py +152 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +36 -65
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py +374 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py +292 -0
- aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +1 -0
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +33 -7
- aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +59 -3
- aiagents4pharma/talk2scholars/tests/test_read_helper_utils.py +110 -0
- aiagents4pharma/talk2scholars/tests/test_s2_display.py +20 -1
- aiagents4pharma/talk2scholars/tests/test_s2_query.py +17 -0
- aiagents4pharma/talk2scholars/tests/test_state.py +25 -1
- aiagents4pharma/talk2scholars/tests/test_zotero_pdf_downloader_utils.py +46 -0
- aiagents4pharma/talk2scholars/tests/test_zotero_read.py +35 -40
- aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +62 -40
- aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +6 -2
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +2 -1
- aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +7 -3
- aiagents4pharma/talk2scholars/tools/s2/search.py +2 -1
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +2 -1
- aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py +79 -136
- aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_pdf_downloader.py +147 -0
- aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +42 -9
- {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/METADATA +2 -1
- {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/RECORD +36 -29
- {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/WHEEL +1 -1
- {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.36.0.dist-info → aiagents4pharma-1.38.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,374 @@
|
|
1
|
+
"""
|
2
|
+
Tool for performing multimodal subgraph extraction.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from typing import Type, Annotated
|
6
|
+
import logging
|
7
|
+
import pickle
|
8
|
+
import numpy as np
|
9
|
+
import pandas as pd
|
10
|
+
import hydra
|
11
|
+
import networkx as nx
|
12
|
+
from pydantic import BaseModel, Field
|
13
|
+
from langchain_core.tools import BaseTool
|
14
|
+
from langchain_core.messages import ToolMessage
|
15
|
+
from langchain_core.tools.base import InjectedToolCallId
|
16
|
+
from langgraph.types import Command
|
17
|
+
from langgraph.prebuilt import InjectedState
|
18
|
+
import torch
|
19
|
+
from torch_geometric.data import Data
|
20
|
+
from ..utils.extractions.multimodal_pcst import MultimodalPCSTPruning
|
21
|
+
from ..utils.embeddings.ollama import EmbeddingWithOllama
|
22
|
+
from .load_arguments import ArgumentData
|
23
|
+
|
24
|
+
# Module-level logger. basicConfig requests INFO-level root logging; it is a
# no-op when the root logger already has handlers configured.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)
|
27
|
+
|
28
|
+
|
29
|
+
class MultimodalSubgraphExtractionInput(BaseModel):
    """
    Input schema of the multimodal subgraph extraction tool.

    Args:
        tool_call_id: Identifier of the tool call (injected, not supplied by the model).
        state: Agent state (injected, not supplied by the model).
        prompt: Prompt to interact with the backend.
        arg_data: Argument for analytical process over graph data; defaults to None.
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(description="Tool call ID.")
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    arg_data: ArgumentData = Field(default=None, description="Experiment over graph data.")
|
49
|
+
|
50
|
+
|
51
|
+
class MultimodalSubgraphExtractionTool(BaseTool):
    """
    This tool performs subgraph extraction based on user's prompt by taking into account
    the top-k nodes and edges. Nodes listed in an uploaded multimodal file are used as
    additional, modality-specific queries.
    """

    name: str = "subgraph_extraction"
    description: str = "A tool for subgraph extraction based on user's prompt."
    args_schema: Type[BaseModel] = MultimodalSubgraphExtractionInput

    def _prepare_query_modalities(self,
                                  prompt_emb: list,
                                  state: Annotated[dict, InjectedState],
                                  pyg_graph: Data) -> pd.DataFrame:
        """
        Prepare the modality-specific queries for subgraph extraction.

        The last uploaded file whose ``file_type`` is ``"multimodal"`` is read as an
        Excel workbook (all sheets). Each sheet name is a node type (``-`` is mapped
        to ``/``) and its ``name`` column lists node names. Every graph node whose
        type equals the sheet's node type and whose name contains a listed name
        (case-insensitive substring match) becomes a query row using its
        modality-specific embedding. A row for the user prompt, using description
        embeddings, is always appended last.

        Side effect: ``state["selections"]`` is set to a mapping from node type to the
        list of matched node IDs (an empty dict when nothing was matched).

        Args:
            prompt_emb: The embedding of the user prompt wrapped in a one-element list.
            state: The injected state for the tool.
            pyg_graph: The PyTorch Geometric graph Data.

        Returns:
            A DataFrame with columns ``node_id``, ``node_type``, ``x``, ``desc_x`` and
            ``use_description``, one row per query.
        """
        # Defaults used when no multimodal file has been uploaded
        multimodal_df = pd.DataFrame({"name": []})
        query_df = pd.DataFrame({"node_id": [],
                                 "node_type": [],
                                 "x": [],
                                 "desc_x": [],
                                 "use_description": []})

        # Find the multimodal file; a later one overrides an earlier one
        for uploaded_file in state["uploaded_files"]:
            if uploaded_file["file_type"] == "multimodal":
                # sheet_name=None reads every sheet into a {sheet name: DataFrame} dict
                multimodal_df = pd.read_excel(uploaded_file["file_path"], sheet_name=None)

        if len(multimodal_df) > 0:
            # Stack all sheets; the sheet name ends up in the `level_0` column
            multimodal_df = pd.concat(multimodal_df).reset_index()
            multimodal_df.drop(columns=["level_1"], inplace=True)
            multimodal_df.rename(columns={"level_0": "q_node_type",
                                          "name": "q_node_name"}, inplace=True)
            # An Excel sheet name cannot contain `/`, but a node type can
            # (e.g. 'gene/protein' in PrimeKG), so sheet names use `-` instead
            multimodal_df["q_node_type"] = multimodal_df.q_node_type.apply(
                lambda x: x.replace('-', '/')
            )

            # Convert PyG graph to a DataFrame for easier filtering
            graph_df = pd.DataFrame({
                "node_id": pyg_graph.node_id,
                "node_name": pyg_graph.node_name,
                "node_type": pyg_graph.node_type,
                "x": pyg_graph.x,
                "desc_x": pyg_graph.desc_x.tolist(),
            })

            # Pair every graph node with every requested node and keep the pairs whose
            # node type matches and whose name contains the requested name. A plain
            # comprehension replaces the row-wise `apply`: it is much faster and also
            # yields a proper (empty) mask when the cross product is empty.
            query_df = graph_df.merge(multimodal_df, how='cross')
            is_match = np.array(
                [
                    node_type == q_node_type
                    and q_node_name.lower() in node_name.lower()
                    for node_type, node_name, q_node_type, q_node_name in zip(
                        query_df["node_type"],
                        query_df["node_name"],
                        query_df["q_node_type"],
                        query_df["q_node_name"],
                    )
                ],
                dtype=bool,
            )
            query_df = query_df[is_match]
            query_df = query_df[['node_id', 'node_type', 'x', 'desc_x']].reset_index(drop=True)
            query_df['use_description'] = False  # modality-specific embeddings

        # Record the selected node IDs per node type. This always runs, so selections
        # from an earlier call never linger when no multimodal file is present.
        state["selections"] = query_df.groupby("node_type")["node_id"].apply(list).to_dict()

        # Append the user prompt as the final query
        query_df = pd.concat([
            query_df,
            pd.DataFrame({
                'node_id': 'user_prompt',
                'node_type': 'prompt',
                'x': prompt_emb,
                'desc_x': prompt_emb,
                'use_description': True  # description embeddings for the prompt
            })
        ]).reset_index(drop=True)

        return query_df

    def _perform_subgraph_extraction(self,
                                     state: Annotated[dict, InjectedState],
                                     cfg: dict,
                                     pyg_graph: Data,
                                     query_df: pd.DataFrame) -> dict:
        """
        Perform multimodal subgraph extraction based on modal-specific embeddings.

        One PCST pruning run is performed per row of ``query_df``. The selected node
        and edge indices of all runs are merged.

        Args:
            state: The injected state for the tool (provides ``topk_nodes`` and
                ``topk_edges``).
            cfg: The tool configuration holding the PCST parameters.
            pyg_graph: The PyTorch Geometric graph Data.
            query_df: The DataFrame containing the query embeddings and modalities.

        Returns:
            A dictionary with sorted, unique int64 index arrays ``"nodes"`` and
            ``"edges"`` into the original graph.
        """
        node_chunks = []
        edge_chunks = []

        for _, query in query_df.iterrows():
            # PCST parameters come from the Hydra configuration, top-k values from the state
            subgraph = MultimodalPCSTPruning(
                topk=state["topk_nodes"],
                topk_e=state["topk_edges"],
                cost_e=cfg.cost_e,
                c_const=cfg.c_const,
                root=cfg.root,
                num_clusters=cfg.num_clusters,
                pruning=cfg.pruning,
                verbosity_level=cfg.verbosity_level,
                use_description=query['use_description'],
            ).extract_subgraph(pyg_graph,
                               torch.tensor(query['desc_x']),  # description embedding
                               torch.tensor(query['x']),  # modal-specific embedding
                               query['node_type'])

            # Normalise to int64 arrays. A plain list has no `.tolist()`, and an empty
            # float array would turn the concatenation below into floats, which can no
            # longer be used as indices.
            node_chunks.append(np.asarray(subgraph["nodes"], dtype=np.int64))
            edge_chunks.append(np.asarray(subgraph["edges"], dtype=np.int64))

        empty = np.array([], dtype=np.int64)
        return {
            "nodes": np.unique(np.concatenate(node_chunks)) if node_chunks else empty,
            "edges": np.unique(np.concatenate(edge_chunks)) if edge_chunks else empty,
        }

    def _prepare_final_subgraph(self,
                                state: Annotated[dict, InjectedState],
                                subgraph: dict,
                                graph: dict,
                                cfg) -> dict:
        """
        Build the PyG, NetworkX and textual representations of the extracted subgraph.

        Args:
            state: The injected state for the tool. ``state["selections"]`` (node type
                -> node IDs), when present, determines which nodes get a color.
            subgraph: The extracted subgraph with integer index arrays ``"nodes"`` and
                ``"edges"`` into the initial graph.
            graph: The initial graph: the PyG graph under ``"pyg"`` and the textualized
                graph under ``"text"`` (``"nodes"`` / ``"edges"`` DataFrames).
            cfg: The tool configuration (provides ``node_colors_dict``).

        Returns:
            A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
        """
        source = graph["pyg"]
        nodes = subgraph["nodes"]
        edges = subgraph["edges"]

        # Map original node indices to their positions in the subgraph
        mapping = {n: i for i, n in enumerate(nodes.tolist())}
        selected_edge_index = source.edge_index[:, edges]

        pyg_graph = Data(
            # Node features
            x=[source.x[i] for i in nodes],
            node_id=np.array(source.node_id)[nodes].tolist(),
            # NOTE(review): node_name is filled with node IDs (unchanged behavior). The
            # NetworkX nodes below are keyed by it, and so are the selection colors,
            # which are keyed by node ID. Confirm this is intended.
            node_name=np.array(source.node_id)[nodes].tolist(),
            enriched_node=np.array(source.enriched_node)[nodes].tolist(),
            num_nodes=len(nodes),
            # Edge features, re-indexed into the subgraph's node positions
            edge_index=torch.LongTensor(
                [
                    [mapping[i] for i in selected_edge_index[0].tolist()],
                    [mapping[i] for i in selected_edge_index[1].tolist()],
                ]
            ),
            edge_attr=source.edge_attr[edges],
            edge_type=np.array(source.edge_type)[edges].tolist(),
            relation=np.array(source.edge_type)[edges].tolist(),
            label=np.array(source.edge_type)[edges].tolist(),
            enriched_edge=np.array(source.enriched_edge)[edges].tolist(),
        )

        # NetworkX DiGraph construction to be visualized in the frontend.
        # Node types without a configured color get None instead of raising KeyError.
        nx_graph = nx.DiGraph()
        node_colors = {
            node: cfg.node_colors_dict.get(node_type)
            for node_type, node_ids in state.get("selections", {}).items()
            for node in node_ids
        }
        for node in pyg_graph.node_name:
            nx_graph.add_node(node, color=node_colors.get(node, None))

        for src, dst, edge_type in zip(pyg_graph.edge_index[0].tolist(),
                                       pyg_graph.edge_index[1].tolist(),
                                       pyg_graph.edge_type):
            nx_graph.add_edge(
                pyg_graph.node_name[src],
                pyg_graph.node_name[dst],
                relation=edge_type,
                label=edge_type,
            )

        # Textualized subgraph: node rows, a blank line, then edge rows (CSV)
        textualized_graph = (
            graph["text"]["nodes"].iloc[nodes].to_csv(index=False)
            + "\n"
            + graph["text"]["edges"].iloc[edges].to_csv(index=False)
        )

        return {
            "graph_pyg": pyg_graph,
            "graph_nx": nx_graph,
            "graph_text": textualized_graph,
        }

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        arg_data: ArgumentData = None,
    ) -> Command:
        """
        Run the subgraph extraction tool.

        Args:
            tool_call_id: The tool call ID for the tool.
            state: Injected state for the tool.
            prompt: The prompt to interact with the backend.
            arg_data (ArgumentData): The argument data; ``extraction_name`` is read from
                it, so it must not be None.

        Returns:
            Command: The command to be executed.
        """
        logger.log(logging.INFO, "Invoking subgraph_extraction tool")

        # Load hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
            )
            cfg = cfg.tools.multimodal_subgraph_extraction

        # Use the most recently added source graph
        initial_graph = {"source": state["dic_source_graph"][-1]}

        # Load the knowledge graph.
        # NOTE(review): pickle.load can execute arbitrary code; these paths come from the
        # agent state and must only ever point to trusted artifacts.
        with open(initial_graph["source"]["kg_pyg_path"], "rb") as f:
            initial_graph["pyg"] = pickle.load(f)
        with open(initial_graph["source"]["kg_text_path"], "rb") as f:
            initial_graph["text"] = pickle.load(f)

        # Prepare the query embeddings and modalities
        prompt_emb = EmbeddingWithOllama(model_name=cfg.ollama_embeddings[0]).embed_query(prompt)
        query_df = self._prepare_query_modalities([prompt_emb], state, initial_graph["pyg"])

        # Perform subgraph extraction
        subgraphs = self._perform_subgraph_extraction(state,
                                                      cfg,
                                                      initial_graph["pyg"],
                                                      query_df)

        # Prepare subgraph as a NetworkX graph and textualized graph
        final_subgraph = self._prepare_final_subgraph(state,
                                                      subgraphs,
                                                      initial_graph,
                                                      cfg)

        # Prepare the dictionary of extracted graph
        dic_extracted_graph = {
            "name": arg_data.extraction_name,
            "tool_call_id": tool_call_id,
            "graph_source": initial_graph["source"]["name"],
            "topk_nodes": state["topk_nodes"],
            "topk_edges": state["topk_edges"],
            "graph_dict": {
                "nodes": list(final_subgraph["graph_nx"].nodes(data=True)),
                "edges": list(final_subgraph["graph_nx"].edges(data=True)),
            },
            "graph_text": final_subgraph["graph_text"],
            "graph_summary": None,
        }

        # Return the updated state of the tool together with the message history update
        return Command(
            update={
                "dic_extracted_graph": [dic_extracted_graph],
                "messages": [
                    ToolMessage(
                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
                        tool_call_id=tool_call_id,
                    )
                ],
            }
        )
|
@@ -0,0 +1,292 @@
|
|
1
|
+
"""
|
2
|
+
Extraction of multimodal subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from typing import Tuple, NamedTuple
|
6
|
+
import numpy as np
|
7
|
+
import pandas as pd
|
8
|
+
import torch
|
9
|
+
import pcst_fast
|
10
|
+
from torch_geometric.data.data import Data
|
11
|
+
|
12
|
+
class MultimodalPCSTPruning(NamedTuple):
    """
    Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
    (He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and
    Question Answering', NeurIPS 2024) paper.
    https://arxiv.org/abs/2402.07630
    https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py

    Args:
        topk: The number of top nodes to consider.
        topk_e: The number of top edges to consider.
        cost_e: The cost of the edges.
        c_const: The constant value for the cost of the edges computation.
        root: The root node of the subgraph, -1 for unrooted.
        num_clusters: The number of clusters.
        pruning: The pruning strategy to use.
        verbosity_level: The verbosity level.
        use_description: If True, every node is scored with its textual description
            embedding (``desc_x``). Otherwise only nodes of the requested modality are
            scored, using their modality-specific embedding (``x``).
    """
    topk: int = 3
    topk_e: int = 3
    cost_e: float = 0.5
    c_const: float = 0.01
    root: int = -1
    num_clusters: int = 1
    pruning: str = "gw"
    verbosity_level: int = 0
    use_description: bool = False

    def _compute_node_prizes(self,
                             graph: Data,
                             query_emb: torch.Tensor,
                             modality: str) :
        """
        Compute the node prizes based on the cosine similarity between the query and nodes.

        Only the ``topk`` most similar nodes get a non-zero prize: ``topk`` for the most
        similar one, down to 1; every other node gets 0.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
                a prompt, sequence, or any other feature to be used for the subgraph extraction.
            modality: The modality to use for the subgraph extraction based on the node type.

        Returns:
            The prizes of the nodes.
        """
        cosine = torch.nn.CosineSimilarity(dim=-1)
        # One similarity score per node; nodes that are not scored keep 0.0
        scores = np.zeros(len(graph.node_id), dtype=np.float64)

        # Only the embeddings that are actually needed are converted.
        if self.use_description:
            # Textual description features of every node
            desc_emb = torch.tensor([emb.tolist() for emb in graph.desc_x])
            scores[:] = cosine(query_emb, desc_emb).tolist()
        else:
            # Modality-specific features, only for nodes of the requested type
            mask = np.array([node_type == modality for node_type in graph.node_type],
                            dtype=bool)
            if mask.any():
                modal_emb = torch.tensor(
                    [list(emb) for emb, keep in zip(graph.x, mask) if keep]
                )
                scores[mask] = cosine(query_emb, modal_emb).tolist()

        # Set the prizes for nodes based on the similarity scores
        n_prizes = torch.tensor(scores, dtype=torch.float32)
        topk = min(self.topk, graph.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()

        return n_prizes

    def _compute_edge_prizes(self,
                             graph: Data,
                             text_emb: torch.Tensor) :
        """
        Compute the edge prizes from the cosine similarity between the textual
        description embedding and the edge features (``graph.edge_attr``).

        Only edges whose similarity is among the ``topk_e`` largest distinct values get a
        prize. Edges sharing the k-th largest value split ``topk_e - k`` evenly. Each
        level is additionally capped at the previous level's value times
        ``(1 - c_const)``.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            text_emb: The textual description embedding in PyTorch Tensor format.

        Returns:
            The prizes of the edges.
        """
        # Note that as of now, the edge features are based on textual features
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(text_emb, graph.edge_attr)
        unique_prizes, inverse_indices = e_prizes.unique(return_inverse=True)
        topk_e = min(self.topk_e, unique_prizes.size(0))
        topk_e_values, _ = torch.topk(unique_prizes, topk_e, largest=True)
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            # Boolean mask of the edges whose similarity equals the k-th largest value
            indices = inverse_indices == (
                unique_prizes == topk_e_values[k]
            ).nonzero(as_tuple=True)[0]
            value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - self.c_const)

        return e_prizes

    def compute_prizes(self,
                       graph: Data,
                       text_emb: torch.Tensor,
                       query_emb: torch.Tensor,
                       modality: str):
        """
        Compute the node prizes based on the cosine similarity between the query and nodes,
        as well as the edge prizes based on the cosine similarity between the query and edges.
        Note that the node and edge embeddings shall use the same embedding model and dimensions
        with the query.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            text_emb: The textual description embedding in PyTorch Tensor format.
            query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
                a prompt, sequence, or any other feature to be used for the subgraph extraction.
            modality: The modality to use for the subgraph extraction based on node type.

        Returns:
            The prizes of the nodes and edges.
        """
        # Compute prizes for nodes
        n_prizes = self._compute_node_prizes(graph, query_emb, modality)

        # Compute prizes for edges
        e_prizes = self._compute_edge_prizes(graph, text_emb)

        return {"nodes": n_prizes, "edges": e_prizes}

    def compute_subgraph_costs(self,
                               graph: Data,
                               prizes: dict) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
        """
        Compute the costs in constructing the subgraph proposed by G-Retriever paper.

        An edge whose prize does not exceed the (possibly reduced) edge cost is kept as a
        real edge with cost ``cost - prize``. An edge with a larger prize is replaced by a
        virtual node with prize ``prize - cost``, attached to the edge's endpoints by two
        zero-cost virtual edges.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            prizes: The prizes of the nodes and the edges.

        Returns:
            edges: A dict with the edges for PCST (``"edges"``) and the number of real
                edges preceding the virtual ones (``"num_prior_edges"``).
            prizes: The prizes of the nodes followed by those of the virtual nodes.
            costs: The costs of the edges.
            mapping: ``"edges"`` maps a real-edge position to the original edge index;
                ``"nodes"`` maps a virtual node ID to the original edge index.
        """
        # Logic to reduce the cost of the edges such that at least one edge is selected
        updated_cost_e = min(
            self.cost_e,
            prizes["edges"].max().item() * (1 - self.c_const / 2),
        )

        # Initialize variables
        edges = []
        costs = []
        virtual = {
            "n_prizes": [],
            "edges": [],
            "costs": [],
        }
        mapping = {"nodes": {}, "edges": {}}

        # Compute the costs, edges, and virtual variables based on the prizes
        for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
            prize_e = prizes["edges"][i]
            if prize_e <= updated_cost_e:
                mapping["edges"][len(edges)] = i
                edges.append((src, dst))
                costs.append(updated_cost_e - prize_e)
            else:
                virtual_node_id = graph.num_nodes + len(virtual["n_prizes"])
                mapping["nodes"][virtual_node_id] = i
                virtual["edges"].append((src, virtual_node_id))
                virtual["edges"].append((virtual_node_id, dst))
                virtual["costs"].append(0)
                virtual["costs"].append(0)
                virtual["n_prizes"].append(prize_e - updated_cost_e)
        all_prizes = np.concatenate([prizes["nodes"], np.array(virtual["n_prizes"])])
        edges_dict = {}
        edges_dict["edges"] = edges
        edges_dict["num_prior_edges"] = len(edges)
        # Final computation of the costs and edges based on the virtual costs and virtual edges
        if len(virtual["costs"]) > 0:
            costs = np.array(costs + virtual["costs"])
            edges = np.array(edges + virtual["edges"])
            edges_dict["edges"] = edges

        return edges_dict, all_prizes, costs, mapping

    def get_subgraph_nodes_edges(
        self, graph: Data, vertices: np.ndarray, edges_dict: dict, mapping: dict,
    ) -> dict:
        """
        Get the selected nodes and edges of the subgraph based on the vertices and edges computed
        by the PCST algorithm.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            vertices: The vertices of the subgraph computed by the PCST algorithm.
            edges_dict: The edges of the subgraph computed by the PCST algorithm
                (``"edges"``) and the number of edges before the virtual ones were added
                (``"num_prior_edges"``).
            mapping: The mapping dictionary of the nodes and edges.

        Returns:
            The selected node indices and edge indices (always an int64 array, possibly
            empty) of the extracted subgraph.
        """
        # Get edges information
        edges = edges_dict["edges"]
        num_prior_edges = edges_dict["num_prior_edges"]

        # Real nodes, plus real edges mapped back to their original edge indices
        subgraph_nodes = vertices[vertices < graph.num_nodes]
        real_edges = [mapping["edges"][e] for e in edges if e < num_prior_edges]
        # Each selected virtual vertex stands for one original (high-prize) edge
        virtual_vertices = vertices[vertices >= graph.num_nodes]
        virtual_edges = [mapping["nodes"][i] for i in virtual_vertices]

        # Always return an integer array (also when there are no virtual vertices or no
        # edges at all), so callers can call `.tolist()` on it and use it as an index.
        subgraph_edges = np.array(real_edges + virtual_edges, dtype=np.int64)

        # Include both endpoints of every selected edge
        edge_index = graph.edge_index[:, subgraph_edges]
        subgraph_nodes = np.unique(
            np.concatenate(
                [subgraph_nodes, edge_index[0].numpy(), edge_index[1].numpy()]
            )
        )

        return {"nodes": subgraph_nodes, "edges": subgraph_edges}

    def extract_subgraph(self,
                         graph: Data,
                         text_emb: torch.Tensor,
                         query_emb: torch.Tensor,
                         modality: str) -> dict:
        """
        Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            text_emb: The textual description embedding in PyTorch Tensor format.
            query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
                a prompt, sequence, or any other feature to be used for the subgraph extraction.
            modality: The modality to use for the subgraph extraction
                (e.g., "text", "sequence", "smiles").

        Returns:
            The selected nodes and edges of the subgraph.

        Raises:
            AssertionError: If ``topk`` or ``topk_e`` is not positive.
        """
        # Assert the topk and topk_e values for subgraph retrieval
        assert self.topk > 0, "topk must be greater than 0"
        assert self.topk_e > 0, "topk_e must be greater than 0"

        # Retrieve the top-k nodes and edges based on the query embedding
        prizes = self.compute_prizes(graph, text_emb, query_emb, modality)

        # Compute costs in constructing the subgraph
        edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
            graph, prizes
        )

        # Retrieve the subgraph using the PCST algorithm
        result_vertices, result_edges = pcst_fast.pcst_fast(
            edges_dict["edges"],
            prizes,
            costs,
            self.root,
            self.num_clusters,
            self.pruning,
            self.verbosity_level,
        )

        subgraph = self.get_subgraph_nodes_edges(
            graph,
            result_vertices,
            {"edges": result_edges, "num_prior_edges": edges_dict["num_prior_edges"]},
            mapping)

        return subgraph
|
@@ -2,6 +2,7 @@
|
|
2
2
|
library_type: "user" # Type of library ('user' or 'group')
|
3
3
|
default_limit: 2
|
4
4
|
request_timeout: 10
|
5
|
+
chunk_size: 16384 # Size (in bytes) for streaming PDF download chunks
|
5
6
|
user_id: ${oc.env:ZOTERO_USER_ID} # Load from environment variable
|
6
7
|
api_key: ${oc.env:ZOTERO_API_KEY} # Load from environment variable
|
7
8
|
|