aiagents4pharma 1.39.5__py3-none-any.whl → 1.40.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2aiagents4pharma/configs/agents/main_agent/default.yaml +26 -13
- aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +83 -3
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +4 -1
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +36 -5
- aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +509 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +85 -23
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +413 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +175 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +509 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +393 -0
- {aiagents4pharma-1.39.5.dist-info → aiagents4pharma-1.40.0.dist-info}/METADATA +13 -14
- {aiagents4pharma-1.39.5.dist-info → aiagents4pharma-1.40.0.dist-info}/RECORD +17 -12
- {aiagents4pharma-1.39.5.dist-info → aiagents4pharma-1.40.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.39.5.dist-info → aiagents4pharma-1.40.0.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.39.5.dist-info → aiagents4pharma-1.40.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,509 @@
|
|
1
|
+
"""
|
2
|
+
Tool for performing multimodal subgraph extraction.
|
3
|
+
"""
|
4
|
+
|
5
|
+
# import datetime
|
6
|
+
from typing import Type, Annotated
|
7
|
+
import logging
|
8
|
+
import hydra
|
9
|
+
import pandas as pd
|
10
|
+
from pydantic import BaseModel, Field
|
11
|
+
from langchain_core.tools import BaseTool
|
12
|
+
from langchain_core.messages import ToolMessage
|
13
|
+
from langchain_core.tools.base import InjectedToolCallId
|
14
|
+
from langgraph.types import Command
|
15
|
+
from langgraph.prebuilt import InjectedState
|
16
|
+
from pymilvus import Collection
|
17
|
+
from ..utils.extractions.milvus_multimodal_pcst import MultimodalPCSTPruning
|
18
|
+
from .load_arguments import ArgumentData
|
19
|
+
try:
|
20
|
+
import cupy as py
|
21
|
+
import cudf
|
22
|
+
df = cudf
|
23
|
+
except ImportError:
|
24
|
+
import numpy as py
|
25
|
+
df = pd
|
26
|
+
|
27
|
+
# Initialize logger
|
28
|
+
logging.basicConfig(level=logging.INFO)
|
29
|
+
logger = logging.getLogger(__name__)
|
30
|
+
|
31
|
+
|
32
|
+
class MultimodalSubgraphExtractionInput(BaseModel):
    """
    Argument schema for the multimodal subgraph extraction tool.

    Args:
        tool_call_id: Tool call ID.
        state: Injected state.
        prompt: Prompt to interact with the backend.
        arg_data: Argument for analytical process over graph data.
    """

    # These two fields carry injection markers in their Annotated metadata.
    tool_call_id: Annotated[str, InjectedToolCallId] = Field(description="Tool call ID.")
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")

    # Free-text request used to drive the extraction.
    prompt: str = Field(description="Prompt to interact with the backend.")

    # Defaults to None when the caller does not supply any argument data.
    arg_data: ArgumentData = Field(default=None, description="Experiment over graph data.")
|
52
|
+
|
53
|
+
|
54
|
+
class MultimodalSubgraphExtractionTool(BaseTool):
|
55
|
+
"""
|
56
|
+
This tool performs subgraph extraction based on user's prompt by taking into account
|
57
|
+
the top-k nodes and edges.
|
58
|
+
"""
|
59
|
+
|
60
|
+
name: str = "subgraph_extraction"
|
61
|
+
description: str = "A tool for subgraph extraction based on user's prompt."
|
62
|
+
args_schema: Type[BaseModel] = MultimodalSubgraphExtractionInput
|
63
|
+
|
64
|
+
def _read_multimodal_files(self,
                           state: Annotated[dict, InjectedState]) -> df.DataFrame:
    """
    Read the uploaded multimodal Excel file and return its rows as a DataFrame.

    Every sheet of the workbook is read; the sheet name becomes the
    ``q_node_type`` column and the ``name`` column becomes ``q_node_name``.
    If several uploaded files are flagged as multimodal, the last one wins.

    Args:
        state: The injected state for the tool. Its ``uploaded_files`` entry
            is a list of dicts with ``file_type`` and ``file_path`` keys; a
            missing entry is treated as "nothing uploaded".

    Returns:
        A DataFrame with ``q_node_type`` and ``q_node_name`` columns (plus any
        other columns present in the sheets), or an empty DataFrame with
        ``name`` and ``node_type`` columns when no multimodal file is uploaded.
    """
    # Loop over the uploaded files and find multimodal files
    logger.log(logging.INFO, "Looping over uploaded files")
    sheets = None
    for uploaded_file in state.get("uploaded_files", []):
        if uploaded_file["file_type"] == "multimodal":
            # sheet_name=None returns a dict mapping sheet name -> DataFrame
            sheets = pd.read_excel(uploaded_file["file_path"], sheet_name=None)

    # Check if any multimodal content was found
    logger.log(logging.INFO, "Checking if multimodal_df is empty")
    if not sheets:
        return df.DataFrame({"name": [], "node_type": []})

    logger.log(logging.INFO, "Preparing multimodal_df")
    # Concatenating the dict yields a two-level index (sheet name, row number);
    # reset_index() exposes them as the "level_0" and "level_1" columns.
    multimodal_df = df.DataFrame(pd.concat(sheets).reset_index())
    multimodal_df = multimodal_df.drop(columns=["level_1"])
    multimodal_df = multimodal_df.rename(columns={"level_0": "q_node_type",
                                                  "name": "q_node_name"})
    # An Excel sheet name cannot contain '/', so a node type such as PrimeKG's
    # 'gene/protein' is expected as 'gene-protein'; normalise '-' to '_'.
    multimodal_df["q_node_type"] = multimodal_df["q_node_type"].str.replace('-', '_')

    return multimodal_df
|
102
|
+
|
103
|
+
def _prepare_query_modalities(self,
                              prompt: dict,
                              state: Annotated[dict, InjectedState],
                              cfg_db: dict) -> df.DataFrame:
    """
    Prepare the modality-specific query for subgraph extraction.

    The user prompt always contributes one query row. When a multimodal file
    was uploaded, the nodes it names are looked up in the per-node-type Milvus
    collections and added as further query rows; in that case
    ``state["selections"]`` is overwritten with the matched node IDs grouped
    by node type.

    Args:
        prompt: The dictionary containing the user prompt (``text``) and its
            embedding (``emb``).
        state: The injected state for the tool.
        cfg_db: The configuration for Milvus database; ``milvus_db.database_name``
            is read from it.

    Returns:
        A DataFrame containing the query embeddings and modalities. It holds
        only the prompt row when no multimodal file was uploaded or none of
        the uploaded node names was found in Milvus.
    """
    # Initialize dataframes
    logger.log(logging.INFO, "Initializing dataframes")
    prompt_df = df.DataFrame({
        'node_id': 'user_prompt',
        'node_name': 'User Prompt',
        'node_type': 'prompt',
        'feat': prompt["text"],
        'feat_emb': prompt["emb"],
        'desc': prompt["text"],
        'desc_emb': prompt["emb"],
        'use_description': True  # set to True for user prompt embedding
    })

    # Read multimodal files uploaded by the user
    multimodal_df = self._read_multimodal_files(state)

    logger.log(logging.INFO, "Prepare query modalities")
    if len(multimodal_df) == 0:
        # If no multimodal files are uploaded, use the prompt embeddings
        return prompt_df

    # Query the Milvus database for each node type in multimodal_df
    logger.log(logging.INFO, "Querying Milvus database for each node type in multimodal_df")
    q_columns = ["node_id", "node_name", "node_type",
                 "feat", "feat_emb", "desc", "desc_emb"]
    query_frames = []
    for node_type, node_type_df in multimodal_df.groupby("q_node_type"):
        logger.log(logging.INFO, "Processing node type: %s", node_type)

        # Load the collection
        collection_name = (
            f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"
        )
        collection = Collection(name=collection_name)
        collection.load()

        # Node names to look up (cuDF series are converted to pandas first)
        names_column = node_type_df['q_node_name']
        q_node_names = getattr(names_column, "to_pandas", lambda: names_column)().tolist()

        # Escape backslashes and double quotes so that a name containing them
        # cannot terminate the string literal inside the filter expression.
        quoted_names = ",".join(
            '"' + str(name).replace('\\', '\\\\').replace('"', '\\"') + '"'
            for name in q_node_names
        )
        res = collection.query(
            expr=f"node_name IN [{quoted_names}]",
            output_fields=q_columns,
        )
        if not res:
            # Nothing matched for this node type; selecting q_columns from an
            # empty frame would raise, so skip it.
            logger.log(logging.INFO, "No nodes found for node type: %s", node_type)
            continue

        # Convert the embeddings into floats
        for r_ in res:
            r_['feat_emb'] = [float(x) for x in r_['feat_emb']]
            r_['desc_emb'] = [float(x) for x in r_['desc_emb']]

        # Convert the result to a DataFrame
        res_df = df.DataFrame(res)[q_columns]
        res_df["use_description"] = False
        query_frames.append(res_df)

    if not query_frames:
        # None of the uploaded names exists in Milvus: fall back to the prompt
        # alone instead of failing on an empty concat.
        return prompt_df

    # Concatenate all results into a single DataFrame
    logger.log(logging.INFO, "Concatenating all results into a single DataFrame")
    query_df = df.concat(query_frames, ignore_index=True)

    # Update the state by adding the selected node IDs
    logger.log(logging.INFO, "Updating state with selected node IDs")
    state["selections"] = getattr(query_df,
                                  "to_pandas",
                                  lambda: query_df)().groupby(
        "node_type"
    )["node_id"].apply(list).to_dict()

    # Append a user prompt to the query dataframe
    logger.log(logging.INFO, "Adding user prompt to query dataframe")
    return df.concat([query_df, prompt_df]).reset_index(drop=True)
|
191
|
+
|
192
|
+
def _perform_subgraph_extraction(self,
                                 state: Annotated[dict, InjectedState],
                                 cfg: dict,
                                 cfg_db: dict,
                                 query_df: pd.DataFrame) -> dict:
    """
    Run PCST-based subgraph extraction once per query row and merge the results.

    Args:
        state: The injected state for the tool; ``topk_nodes`` and
            ``topk_edges`` are read from it.
        cfg: The tool configuration holding the pruning parameters.
        cfg_db: The configuration for Milvus database, forwarded to the pruner.
        query_df: The DataFrame containing the query embeddings and modalities.

    Returns:
        A DataFrame with ``name``, ``nodes`` and ``edges`` columns. The first
        row, named "Unified Subgraph", holds the unique node and edge indices
        over all queries; one row per query follows.
    """
    per_query_rows = []
    collected_nodes = []
    collected_edges = []

    # Iterate on a pandas view so cuDF and pandas inputs behave the same
    query_frame = getattr(query_df, "to_pandas", lambda: query_df)()
    for _, query in query_frame.iterrows():
        logger.log(logging.INFO, "===========================================")
        logger.log(logging.INFO, "Processing query: %s", query['node_name'])

        # Pruning parameters come from the Hydra configuration and the state
        pruner = MultimodalPCSTPruning(
            topk=state["topk_nodes"],
            topk_e=state["topk_edges"],
            cost_e=cfg.cost_e,
            c_const=cfg.c_const,
            root=cfg.root,
            num_clusters=cfg.num_clusters,
            pruning=cfg.pruning,
            verbosity_level=cfg.verbosity_level,
            use_description=query['use_description'],
            metric_type=cfg.search_metric_type,
        )
        extracted = pruner.extract_subgraph(query['desc_emb'],
                                            query['feat_emb'],
                                            query['node_type'],
                                            cfg_db)

        # Keep the indices both for the union and for the per-query record
        collected_nodes.append(extracted["nodes"].tolist())
        collected_edges.append(extracted["edges"].tolist())
        per_query_rows.append((query['node_name'],
                               extracted["nodes"].tolist(),
                               extracted["edges"].tolist()))

    # Union of node and edge indices over every query
    union_nodes = py.unique(
        py.concatenate([py.array(indices) for indices in collected_nodes])
    ).tolist()
    union_edges = py.unique(
        py.concatenate([py.array(indices) for indices in collected_edges])
    ).tolist()

    column_names = ["name", "nodes", "edges"]
    unified_frame = df.DataFrame([("Unified Subgraph", union_nodes, union_edges)],
                                 columns=column_names)
    per_query_frame = df.DataFrame(per_query_rows, columns=column_names)

    # Unified subgraph first, individual query subgraphs after it
    return df.concat([unified_frame, per_query_frame], ignore_index=True)
|
272
|
+
|
273
|
+
def _prepare_final_subgraph(self,
                            state:Annotated[dict, InjectedState],
                            subgraph: dict,
                            cfg: dict,
                            cfg_db) -> dict:
    """
    Resolve extracted node/edge indices into displayable and textual graphs.

    Args:
        state: The injected state for the tool; ``selections`` (node type ->
            list of node IDs) decides which nodes get a type-specific color.
        subgraph: The extracted subgraphs, one row per subgraph with ``name``,
            ``nodes`` and ``edges`` fields (pandas or cuDF DataFrame).
        cfg: The configuration for the tool; ``node_colors_dict`` is read.
        cfg_db: The configuration for Milvus database;
            ``milvus_db.database_name`` is read.

    Returns:
        A dictionary with parallel lists ``name``, ``nodes`` and ``edges``
        (one entry per subgraph) and ``text``, the CSV rendering of the
        nodes and edges of the row named "Unified Subgraph".
    """
    def _as_pandas(frame):
        # cuDF frames expose to_pandas(); pandas frames are used as they are.
        return getattr(frame, "to_pandas", lambda: frame)()

    # Map every selected node ID to the color configured for its node type
    node_colors = {n: cfg.node_colors_dict[k]
                   for k, v in state["selections"].items() for n in v}
    color_df = df.DataFrame(list(node_colors.items()), columns=["node_id", "color"])

    # The collections are the same for every subgraph, so load them once
    database_name = cfg_db.milvus_db.database_name
    node_coll = Collection(name=f"{database_name}_nodes")
    node_coll.load()
    edge_coll = Collection(name=f"{database_name}_edges")
    edge_coll.load()

    # Prepare the subgraph dictionary
    graph_dict = {
        "name": [],
        "nodes": [],
        "edges": [],
        "text": ""
    }
    for sub in _as_pandas(subgraph).itertuples(index=False):
        logger.log(logging.DEBUG, "Processing subgraph: %s (nodes=%s, edges=%s)",
                   sub.name, sub.nodes, sub.edges)

        # Nodes
        node_expr = "node_index IN [" + ",".join(f"{n}" for n in sub.nodes) + "]"
        graph_nodes = df.DataFrame(node_coll.query(
            expr=node_expr,
            output_fields=['node_id', 'node_name', 'node_type', 'desc']
        ))
        # node_index comes back although it is not requested
        # (presumably the primary key - TODO confirm); it is not needed further.
        graph_nodes.drop(columns=['node_index'], inplace=True)
        if not color_df.empty:
            # Merge the color dataframe with the graph nodes
            graph_nodes = graph_nodes.merge(color_df, on="node_id", how="left")
            # Assign the result: an in-place fillna on the selected column may
            # act on a temporary copy and leave the frame unchanged.
            graph_nodes["color"] = graph_nodes["color"].fillna('black')
        else:
            graph_nodes["color"] = 'black'  # Default color

        # Edges
        edge_expr = "triplet_index IN [" + ",".join(f"{e}" for e in sub.edges) + "]"
        graph_edges = df.DataFrame(edge_coll.query(
            expr=edge_expr,
            output_fields=['head_id', 'tail_id', 'edge_type']
        ))
        graph_edges.drop(columns=['triplet_index'], inplace=True)
        graph_edges['edge_type'] = graph_edges['edge_type'].str.split('|')

        # Prepare lists for visualization
        graph_dict["name"].append(sub.name)
        graph_dict["nodes"].append([
            (row.node_id,
             {'hover': f"Node Name : {row.node_name}\n"
                       f"Node Type : {row.node_type}\n"
                       f"Desc : {row.desc}",
              'click': '$hover',
              'color': row.color})
            for row in _as_pandas(graph_nodes).itertuples(index=False)
        ])
        graph_dict["edges"].append([
            (row.head_id, row.tail_id, {'label': tuple(row.edge_type)})
            for row in _as_pandas(graph_edges).itertuples(index=False)
        ])

        # Prepare the textualized subgraph
        if sub.name == "Unified Subgraph":
            # rename() without inplace avoids mutating a slice of graph_nodes
            text_nodes = graph_nodes[['node_id', 'desc']].rename(
                columns={'desc': 'node_attr'}
            )
            text_edges = graph_edges[['head_id', 'edge_type', 'tail_id']]
            graph_dict["text"] = (
                _as_pandas(text_nodes).to_csv(index=False)
                + "\n"
                + _as_pandas(text_edges).to_csv(index=False)
            )

    return graph_dict
|
374
|
+
|
375
|
+
def normalize_vector(self,
                     v : list) -> list:
    """
    Scale a vector to unit Euclidean length.

    The computation uses the module-level array backend ``py`` (CuPy when it
    can be imported, NumPy otherwise).

    Args:
        v : Vector to normalize.

    Returns:
        The normalized vector as a plain list. A vector whose norm is zero
        (including an empty one) has no direction and is returned unchanged
        rather than as NaNs.
    """
    v = py.asarray(v)
    norm = py.linalg.norm(v)
    if norm == 0:
        # Dividing by a zero norm would fill the result with NaN values.
        return v.tolist()
    return (v / norm).tolist()
|
389
|
+
|
390
|
+
def _run(
    self,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
    prompt: str,
    arg_data: ArgumentData = None,
) -> Command:
    """
    Run the subgraph extraction tool.

    The steps are: load the Hydra configuration, build the query rows (prompt
    plus any uploaded multimodal nodes), extract one subgraph per query,
    resolve the subgraphs into displayable/textual form, and return them as a
    state update.

    Args:
        tool_call_id: The tool call ID for the tool.
        state: Injected state for the tool. ``embedding_model``,
            ``topk_nodes``, ``topk_edges`` and ``dic_source_graph`` are read.
        prompt: The prompt to interact with the backend.
        arg_data (ArgumentData): The argument data. Optional; when omitted the
            extraction is named after the tool.

    Returns:
        Command: A state update carrying ``dic_extracted_graph`` (a
        one-element list) and a ``ToolMessage`` announcing the result.
    """
    logger.log(logging.INFO, "Invoking subgraph_extraction tool")

    # arg_data defaults to None, so do not dereference it unconditionally.
    extraction_name = arg_data.extraction_name if arg_data is not None else self.name

    # Load hydra configuration
    with hydra.initialize(version_base=None, config_path="../configs"):
        cfg = hydra.compose(
            config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
        )
        cfg_db = cfg.app.frontend
        cfg = cfg.tools.multimodal_subgraph_extraction

    # Prepare the query embeddings and modalities
    logger.log(logging.INFO, "_prepare_query_modalities")
    prompt_embedding = self.normalize_vector(
        state["embedding_model"].embed_query(prompt)
    )
    query_df = self._prepare_query_modalities(
        # The embedding is wrapped in a list so it fills a single-row column.
        {"text": prompt, "emb": [prompt_embedding]},
        state,
        cfg_db,
    )

    # Perform subgraph extraction
    logger.log(logging.INFO, "_perform_subgraph_extraction")
    subgraphs = self._perform_subgraph_extraction(state,
                                                  cfg,
                                                  cfg_db,
                                                  query_df)

    # Resolve the subgraphs into visualization lists and a textualized graph
    logger.log(logging.INFO, "_prepare_final_subgraph")
    logger.log(logging.INFO, "Subgraphs extracted: %s", len(subgraphs))
    final_subgraph = self._prepare_final_subgraph(state,
                                                  subgraphs,
                                                  cfg,
                                                  cfg_db)

    # Prepare the dictionary of extracted graph
    logger.log(logging.INFO, "dic_extracted_graph")
    dic_extracted_graph = {
        "name": extraction_name,
        "tool_call_id": tool_call_id,
        "graph_source": state["dic_source_graph"][0]["name"],
        "topk_nodes": state["topk_nodes"],
        "topk_edges": state["topk_edges"],
        "graph_dict": {
            "name": final_subgraph["name"],
            "nodes": final_subgraph["nodes"],
            "edges": final_subgraph["edges"],
        },
        "graph_text": final_subgraph["text"],
        "graph_summary": None,
    }

    # Return the updated state of the tool
    return Command(
        update={
            "dic_extracted_graph": [dic_extracted_graph],
            # update the message history
            "messages": [
                ToolMessage(
                    content=f"Subgraph Extraction Result of {extraction_name}",
                    tool_call_id=tool_call_id,
                )
            ],
        }
    )
|