aiagents4pharma 1.45.1__py3-none-any.whl → 1.46.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2aiagents4pharma/configs/app/__init__.py +0 -0
- aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/__init__.py +0 -0
- aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/default.yaml +102 -0
- aiagents4pharma/talk2aiagents4pharma/configs/config.yaml +1 -0
- aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +144 -54
- aiagents4pharma/talk2biomodels/api/__init__.py +1 -1
- aiagents4pharma/talk2biomodels/configs/app/__init__.py +0 -0
- aiagents4pharma/talk2biomodels/configs/app/frontend/__init__.py +0 -0
- aiagents4pharma/talk2biomodels/configs/app/frontend/default.yaml +72 -0
- aiagents4pharma/talk2biomodels/configs/config.yaml +1 -0
- aiagents4pharma/talk2biomodels/tests/test_api.py +0 -30
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +1 -1
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +1 -10
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +42 -26
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +1 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +4 -23
- aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/default.yaml +61 -0
- aiagents4pharma/talk2knowledgegraphs/entrypoint.sh +1 -11
- aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +11 -10
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +193 -73
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +1375 -667
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_database_milvus_connection_manager.py +812 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +723 -539
- aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +474 -58
- aiagents4pharma/talk2knowledgegraphs/utils/database/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py +586 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +240 -8
- aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +67 -31
- {aiagents4pharma-1.45.1.dist-info → aiagents4pharma-1.46.1.dist-info}/METADATA +10 -1
- {aiagents4pharma-1.45.1.dist-info → aiagents4pharma-1.46.1.dist-info}/RECORD +33 -23
- aiagents4pharma/talk2biomodels/api/kegg.py +0 -87
- {aiagents4pharma-1.45.1.dist-info → aiagents4pharma-1.46.1.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.45.1.dist-info → aiagents4pharma-1.46.1.dist-info}/licenses/LICENSE +0 -0
@@ -2,11 +2,15 @@
|
|
2
2
|
Tool for performing multimodal subgraph extraction.
|
3
3
|
"""
|
4
4
|
|
5
|
+
import asyncio
|
6
|
+
import concurrent.futures
|
5
7
|
import logging
|
8
|
+
from dataclasses import dataclass
|
6
9
|
from typing import Annotated
|
7
10
|
|
8
11
|
import hydra
|
9
12
|
import pandas as pd
|
13
|
+
import pcst_fast
|
10
14
|
from langchain_core.messages import ToolMessage
|
11
15
|
from langchain_core.tools import BaseTool
|
12
16
|
from langchain_core.tools.base import InjectedToolCallId
|
@@ -15,6 +19,8 @@ from langgraph.types import Command
|
|
15
19
|
from pydantic import BaseModel, Field
|
16
20
|
from pymilvus import Collection
|
17
21
|
|
22
|
+
from ..utils.database import MilvusConnectionManager
|
23
|
+
from ..utils.database.milvus_connection_manager import QueryParams
|
18
24
|
from ..utils.extractions.milvus_multimodal_pcst import (
|
19
25
|
DynamicLibraryLoader,
|
20
26
|
MultimodalPCSTPruning,
|
@@ -22,11 +28,23 @@ from ..utils.extractions.milvus_multimodal_pcst import (
|
|
22
28
|
)
|
23
29
|
from .load_arguments import ArgumentData
|
24
30
|
|
31
|
+
# pylint: disable=too-many-lines
|
25
32
|
# Initialize logger
|
26
33
|
logging.basicConfig(level=logging.INFO)
|
27
34
|
logger = logging.getLogger(__name__)
|
28
35
|
|
29
36
|
|
37
|
+
@dataclass()
class ExtractionParams:
    """
    Bundle of inputs for the asynchronous subgraph-extraction step.

    Grouping these values keeps ``_perform_subgraph_extraction_async`` down to a
    single argument.
    """

    # Injected tool state (the PCST setup reads "topk_nodes" and "topk_edges").
    state: dict
    # Tool configuration (tools/multimodal_subgraph_extraction section).
    cfg: dict
    # Milvus database configuration (utils/database/milvus section).
    cfg_db: dict
    # Query rows (node and prompt embeddings) to extract subgraphs for.
    query_df: object
    # Connection manager used for every async Milvus call.
    connection_manager: object
|
46
|
+
|
47
|
+
|
30
48
|
class MultimodalSubgraphExtractionInput(BaseModel):
|
31
49
|
"""
|
32
50
|
MultimodalSubgraphExtractionInput is a Pydantic model representing an input
|
@@ -118,7 +136,15 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
118
136
|
q_node_names = getattr(
|
119
137
|
node_names_series, "to_pandas", lambda series=node_names_series: series
|
120
138
|
)().tolist()
|
121
|
-
q_columns = [
|
139
|
+
q_columns = [
|
140
|
+
"node_id",
|
141
|
+
"node_name",
|
142
|
+
"node_type",
|
143
|
+
"feat",
|
144
|
+
"feat_emb",
|
145
|
+
"desc",
|
146
|
+
"desc_emb",
|
147
|
+
]
|
122
148
|
res = collection.query(
|
123
149
|
expr=f"node_name IN [{','.join(f'"{name}"' for name in q_node_names)}]",
|
124
150
|
output_fields=q_columns,
|
@@ -133,6 +159,52 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
133
159
|
res_df["use_description"] = False
|
134
160
|
return res_df
|
135
161
|
|
162
|
+
async def _query_milvus_collection_async(
    self, node_type, node_type_df, cfg_db, connection_manager
):
    """
    Query the Milvus node collection of one node type asynchronously.

    Args:
        node_type: Node type whose collection is queried. Any "/" in it is
            replaced by "_" to form the collection name.
        node_type_df: Rows of the multimodal DataFrame for this node type. It
            must contain a "q_node_name" column.
        cfg_db: Milvus configuration. ``cfg_db.milvus_db.database_name`` is used
            as the collection-name prefix.
        connection_manager: Object that provides ``async_query(QueryParams)``.

    Returns:
        A DataFrame with the node columns listed in ``q_columns`` (embeddings
        cast to float) plus a "use_description" column set to False. It is empty,
        with the same columns, when nothing matched.
    """
    collection_name = f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"

    # Node names to look up; cuDF series are converted to pandas first.
    node_names_series = node_type_df["q_node_name"]
    q_node_names = getattr(
        node_names_series, "to_pandas", lambda series=node_names_series: series
    )().tolist()

    def _quote(name) -> str:
        # Names come from user-uploaded files. Escape backslashes and double
        # quotes so that a name cannot terminate the string literal early and
        # break (or alter) the filter expression.
        escaped = str(name).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    expr = f"node_name IN [{','.join(_quote(name) for name in q_node_names)}]"

    q_columns = [
        "node_id",
        "node_name",
        "node_type",
        "feat",
        "feat_emb",
        "desc",
        "desc_emb",
    ]

    query_params = QueryParams(
        collection_name=collection_name, expr=expr, output_fields=q_columns
    )
    res = await connection_manager.async_query(query_params)

    # Cast the embedding values to Python floats.
    for r_ in res:
        r_["feat_emb"] = [float(x) for x in r_["feat_emb"]]
        r_["desc_emb"] = [float(x) for x in r_["desc_emb"]]

    res_df = (
        self.loader.df.DataFrame(res)[q_columns]
        if res
        else self.loader.df.DataFrame(columns=q_columns)
    )
    res_df["use_description"] = False
    return res_df
|
207
|
+
|
136
208
|
def _prepare_query_modalities(
|
137
209
|
self, prompt: dict, state: Annotated[dict, InjectedState], cfg_db: dict
|
138
210
|
):
|
@@ -201,6 +273,97 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
201
273
|
|
202
274
|
return query_df
|
203
275
|
|
276
|
+
async def _prepare_query_modalities_async(
    self,
    prompt: dict,
    state: Annotated[dict, InjectedState],
    cfg_db: dict,
    connection_manager,
):
    """
    Prepare the modality-specific query for subgraph extraction asynchronously.

    For every node type present in the user's uploaded multimodal files, the
    matching nodes are fetched from that type's Milvus collection. The matched
    node IDs are then written to ``state["selections"]`` (node_type -> list of
    node_id), and the user-prompt row is appended to the matches.

    Args:
        prompt: Dictionary with the user prompt ("text") and its embedding ("emb").
        state: The injected state for the tool. ``state["selections"]`` is
            overwritten when multimodal files are present.
        cfg_db: The configuration dictionary for the Milvus database.
        connection_manager: The MilvusConnectionManager instance.

    Returns:
        A DataFrame containing the query embeddings and modalities. When no
        multimodal files were uploaded, it contains only the user-prompt row.
    """
    logger.log(logging.INFO, "Initializing dataframes (async)")
    prompt_df = self.loader.df.DataFrame(
        {
            "node_id": "user_prompt",
            "node_name": "User Prompt",
            "node_type": "prompt",
            "feat": prompt["text"],
            "feat_emb": prompt["emb"],
            "desc": prompt["text"],
            "desc_emb": prompt["emb"],
            "use_description": True,  # set to True for user prompt embedding
        }
    )

    # Read multimodal files uploaded by the user
    multimodal_df = self._read_multimodal_files(state)

    logger.log(logging.INFO, "Prepare query modalities (async)")
    if len(multimodal_df) == 0:
        # No multimodal files uploaded: the prompt embedding is the only query.
        return prompt_df

    logger.log(
        logging.INFO,
        "Querying Milvus database for each node type in multimodal_df (parallel)",
    )

    # The per-type queries are awaited one after another (not gathered) to avoid
    # event-loop issues. Awaiting each coroutine right after creating it also
    # ensures that no coroutine is left un-awaited if an earlier query raises.
    query_results = []
    for node_type, node_type_df in multimodal_df.groupby("q_node_type"):
        logger.log(logging.INFO, "Processing node type: %s", node_type)
        query_results.append(
            await self._query_milvus_collection_async(
                node_type, node_type_df, cfg_db, connection_manager
            )
        )

    logger.log(logging.INFO, "Concatenating all results into a single DataFrame")
    query_df = self.loader.df.concat(query_results, ignore_index=True)

    # Record the selected node IDs per node type in the state
    logger.log(logging.INFO, "Updating state with selected node IDs")
    state["selections"] = (
        getattr(query_df, "to_pandas", lambda: query_df)()
        .groupby("node_type")["node_id"]
        .apply(list)
        .to_dict()
    )

    logger.log(logging.INFO, "Adding user prompt to query dataframe")
    return self.loader.df.concat([query_df, prompt_df]).reset_index(drop=True)
|
366
|
+
|
204
367
|
def _perform_subgraph_extraction(
|
205
368
|
self,
|
206
369
|
state: Annotated[dict, InjectedState],
|
@@ -287,7 +450,13 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
287
450
|
|
288
451
|
# Convert the unified subgraph and subgraphs to DataFrames
|
289
452
|
unified_subgraph = self.loader.df.DataFrame(
|
290
|
-
[
|
453
|
+
[
|
454
|
+
(
|
455
|
+
"Unified Subgraph",
|
456
|
+
unified_subgraph["nodes"],
|
457
|
+
unified_subgraph["edges"],
|
458
|
+
)
|
459
|
+
],
|
291
460
|
columns=["name", "nodes", "edges"],
|
292
461
|
)
|
293
462
|
subgraphs = self.loader.df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])
|
@@ -297,8 +466,199 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
297
466
|
|
298
467
|
return subgraphs
|
299
468
|
|
469
|
+
async def _perform_subgraph_extraction_async(self, params: ExtractionParams) -> dict:
    """
    Extract one PCST subgraph per query row asynchronously and merge them.

    Args:
        params: Extraction inputs:
            - ``state`` (provides "topk_nodes"/"topk_edges");
            - the tool config ``cfg``;
            - the Milvus config ``cfg_db``;
            - the query DataFrame ``query_df`` (one row per query, with
              "node_name", "node_type", "desc_emb", "feat_emb" and
              "use_description");
            - the ``connection_manager``.

    Returns:
        The result of ``_finalize_subgraph_results``: a table with a
        "Unified Subgraph" row followed by one (name, nodes, edges) row per query.
    """
    subgraphs = []
    unified_subgraph = {"nodes": [], "edges": []}

    query_rows = [
        row
        for _, row in getattr(params.query_df, "to_pandas", lambda: params.query_df)().iterrows()
    ]

    # The metric type depends only on the tool config, so it is computed once.
    dynamic_metric_type = self._get_dynamic_metric_type(params.cfg)

    # Extractions run sequentially to avoid event loop conflicts. Each coroutine
    # is awaited right after creation, so none is left pending if one fails.
    subgraph_results = []
    for i, query_row in enumerate(query_rows):
        logger.log(logging.INFO, "===========================================")
        logger.log(logging.INFO, "Processing query: %s", query_row["node_name"])
        pcst_instance = self._create_pcst_instance(params, query_row, dynamic_metric_type)

        logger.log(logging.INFO, "Processing subgraph %d/%d", i + 1, len(query_rows))
        subgraph_results.append(
            await self._extract_single_subgraph_async(
                pcst_instance, query_row, params.cfg_db, params.connection_manager
            )
        )

    self._process_subgraph_results(subgraph_results, query_rows, unified_subgraph, subgraphs)
    return self._finalize_subgraph_results(subgraphs, unified_subgraph)
|
520
|
+
|
521
|
+
def _process_subgraph_results(self, subgraph_results, query_info, unified_subgraph, subgraphs):
|
522
|
+
"""Process individual subgraph results."""
|
523
|
+
for i, subgraph in enumerate(subgraph_results):
|
524
|
+
query_row = query_info[i]
|
525
|
+
unified_subgraph["nodes"].append(subgraph["nodes"].tolist())
|
526
|
+
unified_subgraph["edges"].append(subgraph["edges"].tolist())
|
527
|
+
subgraphs.append(
|
528
|
+
(
|
529
|
+
query_row["node_name"],
|
530
|
+
subgraph["nodes"].tolist(),
|
531
|
+
subgraph["edges"].tolist(),
|
532
|
+
)
|
533
|
+
)
|
534
|
+
|
535
|
+
def _finalize_subgraph_results(self, subgraphs, unified_subgraph):
    """
    Build the result DataFrame from the accumulated subgraphs.

    The first row is the "Unified Subgraph": the union of all per-query node
    indices and edge indices, de-duplicated with ``loader.py.unique``. It is
    followed by one row per query subgraph.

    Args:
        subgraphs: List of ``(name, nodes, edges)`` tuples, one per query.
        unified_subgraph: Dict whose "nodes" and "edges" entries hold one index
            list per query. Both entries are replaced in place by the merged,
            unique index lists.

    Returns:
        A DataFrame with columns ["name", "nodes", "edges"].
    """
    py = self.loader.py
    for key in ("nodes", "edges"):
        # Skip empty per-query lists. An empty list becomes a float64 array,
        # which would silently promote every merged index to float after
        # concatenation. It would also make concatenate fail when no list is
        # non-empty.
        arrays = [py.array(list_) for list_ in unified_subgraph[key] if len(list_) > 0]
        unified_subgraph[key] = py.unique(py.concatenate(arrays)).tolist() if arrays else []

    unified_subgraph_df = self.loader.df.DataFrame(
        [
            (
                "Unified Subgraph",
                unified_subgraph["nodes"],
                unified_subgraph["edges"],
            )
        ],
        columns=["name", "nodes", "edges"],
    )
    subgraphs_df = self.loader.df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])

    # Unified row first, then the individual subgraphs
    return self.loader.df.concat([unified_subgraph_df, subgraphs_df], ignore_index=True)
|
562
|
+
|
563
|
+
async def _extract_single_subgraph_async(
    self, pcst_instance, query_row, cfg_db, connection_manager
):
    """
    Run the complete extraction pipeline for a single query row.

    The graph inputs (edge index, prizes, node count) are gathered by
    ``_load_subgraph_data``. They are then passed to ``_run_pcst_algorithm``,
    whose result is returned unchanged.
    """
    loaded = await self._load_subgraph_data(
        pcst_instance, query_row, cfg_db, connection_manager
    )
    edge_index, prizes, num_nodes = loaded
    return self._run_pcst_algorithm(pcst_instance, edge_index, num_nodes, prizes)
|
576
|
+
|
577
|
+
async def _load_subgraph_data(self, pcst_instance, query_row, cfg_db, connection_manager):
    """
    Gather the inputs the PCST step needs for one query.

    Three awaits are made, in order: the edge index, the node prizes for this
    query's embeddings, and the statistics of the ``<database_name>_nodes``
    collection.

    Returns:
        Tuple ``(edge_index, prizes, num_nodes)``, where ``num_nodes`` is the
        collection's "num_entities" statistic.
    """
    edge_index = await pcst_instance.load_edge_index_async(cfg_db, connection_manager)

    prizes = await pcst_instance.compute_prizes_async(
        query_row["desc_emb"], query_row["feat_emb"], cfg_db, query_row["node_type"]
    )

    stats = await connection_manager.async_get_collection_stats(
        f"{cfg_db.milvus_db.database_name}_nodes"
    )
    return edge_index, prizes, stats["num_entities"]
|
596
|
+
|
597
|
+
def _run_pcst_algorithm(self, pcst_instance, edge_index, num_nodes, prizes):
    """
    Solve the prize-collecting Steiner tree problem for one query.

    The edge costs and adjusted prizes are computed by the PCST instance, then
    solved with ``pcst_fast`` using the instance's root / cluster / pruning /
    verbosity settings. The selected vertices and edges are translated back
    into subgraph nodes and edges.

    Returns:
        Whatever ``pcst_instance.get_subgraph_nodes_edges`` returns.
    """
    edges_dict, prizes_final, costs, mapping = pcst_instance.compute_subgraph_costs(
        edge_index, num_nodes, prizes
    )

    selected_vertices, selected_edges = pcst_fast.pcst_fast(
        edges_dict["edges"].tolist(),
        prizes_final.tolist(),
        costs.tolist(),
        pcst_instance.root,
        pcst_instance.num_clusters,
        pcst_instance.pruning,
        pcst_instance.verbosity_level,
    )

    array_lib = pcst_instance.loader.py
    vertex_array = array_lib.asarray(selected_vertices)
    edge_selection = {
        "edges": array_lib.asarray(selected_edges),
        "num_prior_edges": edges_dict["num_prior_edges"],
        "edge_index": edge_index,
    }
    return pcst_instance.get_subgraph_nodes_edges(
        num_nodes, vertex_array, edge_selection, mapping
    )
|
626
|
+
|
627
|
+
def _run(
    self,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
    prompt: str,
    arg_data: ArgumentData = None,
) -> Command:
    """
    Execute the tool synchronously.

    LangGraph calls this synchronous entry point, while the actual work is done
    by the coroutine ``_run_async``. The coroutine always runs on a dedicated
    worker thread with a freshly created event loop, so it never collides with
    an event loop the calling thread may already be running. The caller blocks
    until the result is ready. Exceptions raised by ``_run_async`` propagate
    through ``Future.result()``.
    """

    def _drive_coroutine():
        # Private loop for this worker thread only.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(
                self._run_async(tool_call_id, state, prompt, arg_data)
            )
        finally:
            # Close the loop and detach it from the worker thread.
            loop.close()
            asyncio.set_event_loop(None)

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(_drive_coroutine).result()
|
659
|
+
|
300
660
|
def _prepare_final_subgraph(
|
301
|
-
self, state: Annotated[dict, InjectedState], subgraph: dict,
|
661
|
+
self, state: Annotated[dict, InjectedState], subgraph: dict, cfg_db
|
302
662
|
) -> dict:
|
303
663
|
"""
|
304
664
|
Prepare the subgraph based on the extracted subgraph.
|
@@ -306,8 +666,6 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
306
666
|
Args:
|
307
667
|
state: The injected state for the tool.
|
308
668
|
subgraph: The extracted subgraph.
|
309
|
-
graph: The graph dictionary.
|
310
|
-
cfg: The configuration dictionary for the tool.
|
311
669
|
cfg_db: The configuration dictionary for Milvus database.
|
312
670
|
|
313
671
|
Returns:
|
@@ -315,7 +673,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
315
673
|
"""
|
316
674
|
# Convert the dict to a DataFrame
|
317
675
|
node_colors = {
|
318
|
-
n:
|
676
|
+
n: cfg_db.node_colors_dict[k] for k, v in state["selections"].items() for n in v
|
319
677
|
}
|
320
678
|
color_df = self.loader.df.DataFrame(list(node_colors.items()), columns=["node_id", "color"])
|
321
679
|
# print(color_df)
|
@@ -345,7 +703,9 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
345
703
|
},
|
346
704
|
)
|
347
705
|
for row in getattr(
|
348
|
-
graph_nodes,
|
706
|
+
graph_nodes,
|
707
|
+
"to_pandas",
|
708
|
+
lambda graph_nodes=graph_nodes: graph_nodes,
|
349
709
|
)().itertuples(index=False)
|
350
710
|
]
|
351
711
|
)
|
@@ -353,7 +713,9 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
353
713
|
[
|
354
714
|
(row.head_id, row.tail_id, {"label": tuple(row.edge_type)})
|
355
715
|
for row in getattr(
|
356
|
-
graph_edges,
|
716
|
+
graph_edges,
|
717
|
+
"to_pandas",
|
718
|
+
lambda graph_edges=graph_edges: graph_edges,
|
357
719
|
)().itertuples(index=False)
|
358
720
|
]
|
359
721
|
)
|
@@ -364,11 +726,15 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
364
726
|
graph_nodes.rename(columns={"desc": "node_attr"}, inplace=True)
|
365
727
|
graph_edges = graph_edges[["head_id", "edge_type", "tail_id"]]
|
366
728
|
nodes_pandas = getattr(
|
367
|
-
graph_nodes,
|
729
|
+
graph_nodes,
|
730
|
+
"to_pandas",
|
731
|
+
lambda graph_nodes=graph_nodes: graph_nodes,
|
368
732
|
)()
|
369
733
|
nodes_csv = nodes_pandas.to_csv(index=False)
|
370
734
|
edges_pandas = getattr(
|
371
|
-
graph_edges,
|
735
|
+
graph_edges,
|
736
|
+
"to_pandas",
|
737
|
+
lambda graph_edges=graph_edges: graph_edges,
|
372
738
|
)()
|
373
739
|
edges_csv = edges_pandas.to_csv(index=False)
|
374
740
|
graph_dict["text"] = nodes_csv + "\n" + edges_csv
|
@@ -414,6 +780,35 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
414
780
|
|
415
781
|
return graph_nodes, graph_edges
|
416
782
|
|
783
|
+
def _get_dynamic_metric_type(self, cfg: dict) -> str:
    """
    Choose the vector-search metric type used for PCST pruning.

    The loader's metric type is used when ``cfg`` has a ``vector_processing``
    section whose ``dynamic_metrics`` flag is truthy (a missing flag counts as
    enabled). Otherwise ``cfg.search_metric_type`` is used, falling back to the
    loader's metric type when that key is absent.
    """
    use_loader_metric = hasattr(cfg, "vector_processing") and getattr(
        cfg.vector_processing, "dynamic_metrics", True
    )
    if not use_loader_metric:
        return getattr(cfg, "search_metric_type", self.loader.metric_type)
    return self.loader.metric_type
|
793
|
+
|
794
|
+
def _create_pcst_instance(
    self, params: ExtractionParams, query_row: dict, dynamic_metric_type: str
) -> MultimodalPCSTPruning:
    """
    Build the PCST pruning helper for one query row.

    The top-k limits come from the tool state, and the cost and solver settings
    come from the tool config. Whether node descriptions are used is taken from
    the query row's "use_description" flag.
    """
    tool_cfg = params.cfg
    tool_state = params.state
    return MultimodalPCSTPruning(
        topk=tool_state["topk_nodes"],
        topk_e=tool_state["topk_edges"],
        cost_e=tool_cfg.cost_e,
        c_const=tool_cfg.c_const,
        root=tool_cfg.root,
        num_clusters=tool_cfg.num_clusters,
        pruning=tool_cfg.pruning,
        verbosity_level=tool_cfg.verbosity_level,
        use_description=query_row["use_description"],
        metric_type=dynamic_metric_type,
        loader=self.loader,
    )
|
811
|
+
|
417
812
|
def normalize_vector(self, v: list) -> list:
|
418
813
|
"""
|
419
814
|
Normalize a vector using appropriate library (CuPy for GPU, NumPy for CPU).
|
@@ -432,7 +827,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
432
827
|
# CPU mode: return as-is for COSINE similarity
|
433
828
|
return v
|
434
829
|
|
435
|
-
def
|
830
|
+
async def _run_async(
|
436
831
|
self,
|
437
832
|
tool_call_id: Annotated[str, InjectedToolCallId],
|
438
833
|
state: Annotated[dict, InjectedState],
|
@@ -459,55 +854,71 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
459
854
|
config_name="config",
|
460
855
|
overrides=["tools/multimodal_subgraph_extraction=default"],
|
461
856
|
)
|
462
|
-
cfg_db = cfg.app.frontend
|
463
857
|
cfg = cfg.tools.multimodal_subgraph_extraction
|
464
858
|
|
465
|
-
#
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
#
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
859
|
+
# Load database configuration separately
|
860
|
+
with hydra.initialize(version_base=None, config_path="../configs"):
|
861
|
+
cfg_all = hydra.compose(config_name="config")
|
862
|
+
cfg_db = cfg_all.utils.database.milvus
|
863
|
+
|
864
|
+
# Establish Milvus connection using singleton connection manager
|
865
|
+
logger.log(logging.INFO, "Getting Milvus connection manager (singleton)")
|
866
|
+
connection_manager = MilvusConnectionManager(cfg_db)
|
867
|
+
try:
|
868
|
+
connection_manager.ensure_connection()
|
869
|
+
logger.log(logging.INFO, "Milvus connection established successfully")
|
870
|
+
|
871
|
+
# Log connection info
|
872
|
+
conn_info = connection_manager.get_connection_info()
|
873
|
+
logger.log(logging.INFO, "Connected to database: %s", conn_info.get("database"))
|
874
|
+
logger.log(
|
875
|
+
logging.INFO,
|
876
|
+
"Connection healthy: %s",
|
877
|
+
connection_manager.test_connection(),
|
878
|
+
)
|
879
|
+
except Exception as e:
|
880
|
+
logger.error("Failed to establish Milvus connection: %s", str(e))
|
881
|
+
raise RuntimeError(f"Cannot connect to Milvus database: {str(e)}") from e
|
882
|
+
|
883
|
+
# Prepare the query embeddings and modalities (async)
|
884
|
+
logger.log(logging.INFO, "_prepare_query_modalities_async")
|
885
|
+
query_df = await self._prepare_query_modalities_async(
|
480
886
|
{
|
481
887
|
"text": prompt,
|
482
888
|
"emb": [self.normalize_vector(state["embedding_model"].embed_query(prompt))],
|
483
889
|
},
|
484
890
|
state,
|
485
891
|
cfg_db,
|
892
|
+
connection_manager,
|
486
893
|
)
|
487
|
-
# end = datetime.datetime.now()
|
488
|
-
# logger.log(logging.INFO, "_prepare_query_modalities time: %s seconds",
|
489
|
-
# (end - start).total_seconds())
|
490
894
|
|
491
|
-
# Perform subgraph extraction
|
492
|
-
logger.log(logging.INFO, "
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
895
|
+
# Perform subgraph extraction (async)
|
896
|
+
logger.log(logging.INFO, "_perform_subgraph_extraction_async")
|
897
|
+
extraction_params = ExtractionParams(
|
898
|
+
state=state,
|
899
|
+
cfg=cfg,
|
900
|
+
cfg_db=cfg_db,
|
901
|
+
query_df=query_df,
|
902
|
+
connection_manager=connection_manager,
|
903
|
+
)
|
904
|
+
subgraphs = await self._perform_subgraph_extraction_async(extraction_params)
|
498
905
|
|
499
906
|
# Prepare subgraph as a NetworkX graph and textualized graph
|
500
907
|
logger.log(logging.INFO, "_prepare_final_subgraph")
|
501
908
|
logger.log(logging.INFO, "Subgraphs extracted: %s", len(subgraphs))
|
502
909
|
# start = datetime.datetime.now()
|
503
|
-
final_subgraph = self._prepare_final_subgraph(state, subgraphs,
|
910
|
+
final_subgraph = self._prepare_final_subgraph(state, subgraphs, cfg_db)
|
504
911
|
# end = datetime.datetime.now()
|
505
912
|
# logger.log(logging.INFO, "_prepare_final_subgraph time: %s seconds",
|
506
913
|
# (end - start).total_seconds())
|
507
914
|
|
915
|
+
# Create final result and return command
|
916
|
+
return self._create_extraction_result(tool_call_id, state, final_subgraph, arg_data)
|
917
|
+
|
918
|
+
def _create_extraction_result(self, tool_call_id, state, final_subgraph, arg_data):
|
919
|
+
"""Create the final extraction result and command."""
|
508
920
|
# Prepare the dictionary of extracted graph
|
509
921
|
logger.log(logging.INFO, "dic_extracted_graph")
|
510
|
-
# start = datetime.datetime.now()
|
511
922
|
dic_extracted_graph = {
|
512
923
|
"name": arg_data.extraction_name,
|
513
924
|
"tool_call_id": tool_call_id,
|
@@ -522,28 +933,33 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
522
933
|
"graph_text": final_subgraph["text"],
|
523
934
|
"graph_summary": None,
|
524
935
|
}
|
525
|
-
# end = datetime.datetime.now()
|
526
|
-
# logger.log(logging.INFO, "dic_extracted_graph time: %s seconds",
|
527
|
-
# (end - start).total_seconds())
|
528
936
|
|
529
|
-
#
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
937
|
+
# Debug logging
|
938
|
+
logger.info(
|
939
|
+
"Created dic_extracted_graph with keys: %s",
|
940
|
+
list(dic_extracted_graph.keys()),
|
941
|
+
)
|
942
|
+
logger.info(
|
943
|
+
"Graph dict structure - name count: %d, nodes count: %d, edges count: %d",
|
944
|
+
len(dic_extracted_graph["graph_dict"]["name"]),
|
945
|
+
len(dic_extracted_graph["graph_dict"]["nodes"]),
|
946
|
+
len(dic_extracted_graph["graph_dict"]["edges"]),
|
947
|
+
)
|
536
948
|
|
537
|
-
#
|
949
|
+
# Create success message
|
950
|
+
success_message = (
|
951
|
+
f"Successfully extracted subgraph '{arg_data.extraction_name}' "
|
952
|
+
f"with {len(final_subgraph['name'])} graph(s). The subgraph contains "
|
953
|
+
f"{sum(len(nodes) for nodes in final_subgraph['nodes'])} nodes and "
|
954
|
+
f"{sum(len(edges) for edges in final_subgraph['edges'])} edges. "
|
955
|
+
"The extracted subgraph has been stored and is ready for "
|
956
|
+
"visualization and analysis."
|
957
|
+
)
|
958
|
+
|
959
|
+
# Return the command with updated state
|
538
960
|
return Command(
|
539
|
-
update=
|
961
|
+
update={"dic_extracted_graph": [dic_extracted_graph]}
|
540
962
|
| {
|
541
|
-
|
542
|
-
"messages": [
|
543
|
-
ToolMessage(
|
544
|
-
content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
|
545
|
-
tool_call_id=tool_call_id,
|
546
|
-
)
|
547
|
-
],
|
963
|
+
"messages": [ToolMessage(content=success_message, tool_call_id=tool_call_id)],
|
548
964
|
}
|
549
965
|
)
|