aiagents4pharma 1.45.1__py3-none-any.whl → 1.46.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. aiagents4pharma/talk2aiagents4pharma/configs/app/__init__.py +0 -0
  2. aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/__init__.py +0 -0
  3. aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/default.yaml +102 -0
  4. aiagents4pharma/talk2aiagents4pharma/configs/config.yaml +1 -0
  5. aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +144 -54
  6. aiagents4pharma/talk2biomodels/api/__init__.py +1 -1
  7. aiagents4pharma/talk2biomodels/configs/app/__init__.py +0 -0
  8. aiagents4pharma/talk2biomodels/configs/app/frontend/__init__.py +0 -0
  9. aiagents4pharma/talk2biomodels/configs/app/frontend/default.yaml +72 -0
  10. aiagents4pharma/talk2biomodels/configs/config.yaml +1 -0
  11. aiagents4pharma/talk2biomodels/tests/test_api.py +0 -30
  12. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +1 -1
  13. aiagents4pharma/talk2biomodels/tools/get_annotation.py +1 -10
  14. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +42 -26
  15. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +1 -0
  16. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +4 -23
  17. aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/__init__.py +3 -0
  18. aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/default.yaml +61 -0
  19. aiagents4pharma/talk2knowledgegraphs/entrypoint.sh +1 -11
  20. aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +11 -10
  21. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +193 -73
  22. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +1375 -667
  23. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_database_milvus_connection_manager.py +812 -0
  24. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +723 -539
  25. aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +474 -58
  26. aiagents4pharma/talk2knowledgegraphs/utils/database/__init__.py +5 -0
  27. aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py +586 -0
  28. aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +240 -8
  29. aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +67 -31
  30. {aiagents4pharma-1.45.1.dist-info → aiagents4pharma-1.46.1.dist-info}/METADATA +10 -1
  31. {aiagents4pharma-1.45.1.dist-info → aiagents4pharma-1.46.1.dist-info}/RECORD +33 -23
  32. aiagents4pharma/talk2biomodels/api/kegg.py +0 -87
  33. {aiagents4pharma-1.45.1.dist-info → aiagents4pharma-1.46.1.dist-info}/WHEEL +0 -0
  34. {aiagents4pharma-1.45.1.dist-info → aiagents4pharma-1.46.1.dist-info}/licenses/LICENSE +0 -0
@@ -2,11 +2,15 @@
2
2
  Tool for performing multimodal subgraph extraction.
3
3
  """
4
4
 
5
+ import asyncio
6
+ import concurrent.futures
5
7
  import logging
8
+ from dataclasses import dataclass
6
9
  from typing import Annotated
7
10
 
8
11
  import hydra
9
12
  import pandas as pd
13
+ import pcst_fast
10
14
  from langchain_core.messages import ToolMessage
11
15
  from langchain_core.tools import BaseTool
12
16
  from langchain_core.tools.base import InjectedToolCallId
@@ -15,6 +19,8 @@ from langgraph.types import Command
15
19
  from pydantic import BaseModel, Field
16
20
  from pymilvus import Collection
17
21
 
22
+ from ..utils.database import MilvusConnectionManager
23
+ from ..utils.database.milvus_connection_manager import QueryParams
18
24
  from ..utils.extractions.milvus_multimodal_pcst import (
19
25
  DynamicLibraryLoader,
20
26
  MultimodalPCSTPruning,
@@ -22,11 +28,23 @@ from ..utils.extractions.milvus_multimodal_pcst import (
22
28
  )
23
29
  from .load_arguments import ArgumentData
24
30
 
31
+ # pylint: disable=too-many-lines
25
32
  # Initialize logger
26
33
  logging.basicConfig(level=logging.INFO)
27
34
  logger = logging.getLogger(__name__)
28
35
 
29
36
 
37
@dataclass
class ExtractionParams:
    """Inputs shared by the asynchronous subgraph-extraction helpers.

    Passed as a single object so the helpers take one argument instead of five.
    """

    # Injected tool state; the helpers read e.g. "topk_nodes" and "topk_edges".
    state: dict
    # Tool configuration (the tools/multimodal_subgraph_extraction section).
    cfg: dict
    # Milvus database configuration (the utils/database/milvus section).
    cfg_db: dict
    # DataFrame of query embeddings/modalities (pandas, or a frame exposing
    # to_pandas()).
    query_df: object
    # MilvusConnectionManager used for the asynchronous queries.
    connection_manager: object
46
+
47
+
30
48
  class MultimodalSubgraphExtractionInput(BaseModel):
31
49
  """
32
50
  MultimodalSubgraphExtractionInput is a Pydantic model representing an input
@@ -118,7 +136,15 @@ class MultimodalSubgraphExtractionTool(BaseTool):
118
136
  q_node_names = getattr(
119
137
  node_names_series, "to_pandas", lambda series=node_names_series: series
120
138
  )().tolist()
121
- q_columns = ["node_id", "node_name", "node_type", "feat", "feat_emb", "desc", "desc_emb"]
139
+ q_columns = [
140
+ "node_id",
141
+ "node_name",
142
+ "node_type",
143
+ "feat",
144
+ "feat_emb",
145
+ "desc",
146
+ "desc_emb",
147
+ ]
122
148
  res = collection.query(
123
149
  expr=f"node_name IN [{','.join(f'"{name}"' for name in q_node_names)}]",
124
150
  output_fields=q_columns,
@@ -133,6 +159,52 @@ class MultimodalSubgraphExtractionTool(BaseTool):
133
159
  res_df["use_description"] = False
134
160
  return res_df
135
161
 
162
async def _query_milvus_collection_async(
    self, node_type, node_type_df, cfg_db, connection_manager
):
    """
    Query the Milvus node collection of one node type for the given names (async).

    Args:
        node_type: Node type to look up. "/" is replaced by "_" to build the
            collection name ``<database_name>_nodes_<node_type>``.
        node_type_df: Uploaded multimodal rows for this node type. Its
            "q_node_name" column supplies the names to match.
        cfg_db: Milvus database configuration (reads ``milvus_db.database_name``).
        connection_manager: Connection manager whose ``async_query`` runs the query.

    Returns:
        A DataFrame with node_id, node_name, node_type, feat, feat_emb, desc
        and desc_emb (embeddings converted to float lists), plus a
        ``use_description`` column set to False. When nothing matched, the
        result is empty but has the same columns.
    """
    collection_name = f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"

    # Query the collection with node names from multimodal_df.
    # Frames exposing to_pandas() are converted; pandas series are used as-is.
    node_names_series = node_type_df["q_node_name"]
    q_node_names = getattr(
        node_names_series, "to_pandas", lambda series=node_names_series: series
    )().tolist()

    # Build the filter expression. Backslashes and double quotes inside a name
    # are escaped so each name stays a single string literal. Otherwise a name
    # such as 'a"b' would terminate the literal early and produce a malformed
    # (or altered) expression.
    quoted_names = []
    for name in q_node_names:
        escaped = str(name).replace("\\", "\\\\").replace('"', '\\"')
        quoted_names.append(f'"{escaped}"')
    node_names_str = ",".join(quoted_names)
    expr = f"node_name IN [{node_names_str}]"

    q_columns = [
        "node_id",
        "node_name",
        "node_type",
        "feat",
        "feat_emb",
        "desc",
        "desc_emb",
    ]

    # Create query parameters and perform the async query
    query_params = QueryParams(
        collection_name=collection_name, expr=expr, output_fields=q_columns
    )
    res = await connection_manager.async_query(query_params)

    # Convert the embeddings into floats
    for r_ in res:
        r_["feat_emb"] = [float(x) for x in r_["feat_emb"]]
        r_["desc_emb"] = [float(x) for x in r_["desc_emb"]]

    # Convert the result to a DataFrame (keeping the column set stable when empty)
    res_df = (
        self.loader.df.DataFrame(res)[q_columns]
        if res
        else self.loader.df.DataFrame(columns=q_columns)
    )
    res_df["use_description"] = False
    return res_df
207
+
136
208
  def _prepare_query_modalities(
137
209
  self, prompt: dict, state: Annotated[dict, InjectedState], cfg_db: dict
138
210
  ):
@@ -201,6 +273,97 @@ class MultimodalSubgraphExtractionTool(BaseTool):
201
273
 
202
274
  return query_df
203
275
 
276
async def _prepare_query_modalities_async(
    self,
    prompt: dict,
    state: Annotated[dict, InjectedState],
    cfg_db: dict,
    connection_manager,
):
    """
    Prepare the modality-specific query for subgraph extraction asynchronously.

    Node names from the uploaded multimodal files are looked up in Milvus,
    one node-type collection at a time. A row for the user prompt itself is
    then appended.

    Args:
        prompt: Dictionary with the user prompt under "text" and its
            embedding under "emb".
        state: The injected state for the tool. When multimodal files are
            present, ``state["selections"]`` is overwritten with a mapping
            from node_type to the list of matched node_ids.
        cfg_db: The configuration dictionary for Milvus database.
        connection_manager: The MilvusConnectionManager instance.

    Returns:
        A DataFrame containing the query embeddings and modalities. It always
        includes the user-prompt row (``use_description=True``), which comes
        last when multimodal rows are present.
    """
    # Initialize the user-prompt row
    logger.log(logging.INFO, "Initializing dataframes (async)")
    prompt_df = self.loader.df.DataFrame(
        {
            "node_id": "user_prompt",
            "node_name": "User Prompt",
            "node_type": "prompt",
            "feat": prompt["text"],
            "feat_emb": prompt["emb"],
            "desc": prompt["text"],
            "desc_emb": prompt["emb"],
            "use_description": True,  # set to True for user prompt embedding
        }
    )

    # Read multimodal files uploaded by the user
    multimodal_df = self._read_multimodal_files(state)

    logger.log(logging.INFO, "Prepare query modalities (async)")
    if len(multimodal_df) == 0:
        # If no multimodal files are uploaded, use the prompt embeddings only
        return prompt_df

    logger.log(
        logging.INFO,
        "Querying Milvus database for each node type in multimodal_df (sequential)",
    )

    # Queries are awaited one at a time (not gathered) to avoid event loop
    # issues with the shared connection manager. Each coroutine is awaited
    # right after it is created. That way none is left un-awaited (which
    # triggers a RuntimeWarning) if an earlier query raises.
    query_results = []
    for node_type, node_type_df in multimodal_df.groupby("q_node_type"):
        logger.log(logging.INFO, "Processing node type: %s", node_type)
        query_results.append(
            await self._query_milvus_collection_async(
                node_type, node_type_df, cfg_db, connection_manager
            )
        )

    # Concatenate all results into a single DataFrame
    logger.log(logging.INFO, "Concatenating all results into a single DataFrame")
    query_df = self.loader.df.concat(query_results, ignore_index=True)

    # Update the state by adding the selected node IDs
    logger.log(logging.INFO, "Updating state with selected node IDs")
    state["selections"] = (
        getattr(query_df, "to_pandas", lambda: query_df)()
        .groupby("node_type")["node_id"]
        .apply(list)
        .to_dict()
    )

    # Append the user prompt to the query dataframe
    logger.log(logging.INFO, "Adding user prompt to query dataframe")
    return self.loader.df.concat([query_df, prompt_df]).reset_index(drop=True)
366
+
204
367
  def _perform_subgraph_extraction(
205
368
  self,
206
369
  state: Annotated[dict, InjectedState],
@@ -287,7 +450,13 @@ class MultimodalSubgraphExtractionTool(BaseTool):
287
450
 
288
451
  # Convert the unified subgraph and subgraphs to DataFrames
289
452
  unified_subgraph = self.loader.df.DataFrame(
290
- [("Unified Subgraph", unified_subgraph["nodes"], unified_subgraph["edges"])],
453
+ [
454
+ (
455
+ "Unified Subgraph",
456
+ unified_subgraph["nodes"],
457
+ unified_subgraph["edges"],
458
+ )
459
+ ],
291
460
  columns=["name", "nodes", "edges"],
292
461
  )
293
462
  subgraphs = self.loader.df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])
@@ -297,8 +466,199 @@ class MultimodalSubgraphExtractionTool(BaseTool):
297
466
 
298
467
  return subgraphs
299
468
 
469
async def _perform_subgraph_extraction_async(self, params: ExtractionParams) -> dict:
    """
    Perform multimodal subgraph extraction based on modal-specific embeddings asynchronously.

    One subgraph is extracted per row of ``params.query_df``.

    Args:
        params: Bundled extraction inputs:
            - state: the injected tool state (supplies topk_nodes / topk_edges)
            - cfg: the tool configuration
            - cfg_db: the Milvus database configuration
            - query_df: DataFrame of query embeddings and modalities
            - connection_manager: the MilvusConnectionManager instance

    Returns:
        The DataFrame built by ``_finalize_subgraph_results``: a
        "Unified Subgraph" row followed by one row per query, with columns
        name / nodes / edges. The ``-> dict`` annotation is kept only for
        interface compatibility.
    """
    subgraphs = []
    unified_subgraph = {"nodes": [], "edges": []}

    # Frames exposing to_pandas() are converted first; pandas frames are
    # iterated directly.
    query_frame = getattr(params.query_df, "to_pandas", lambda: params.query_df)()
    query_rows = [row for _, row in query_frame.iterrows()]

    # The metric type depends only on cfg, so resolve it once for all queries.
    dynamic_metric_type = self._get_dynamic_metric_type(params.cfg)

    # Extract subgraphs sequentially to avoid event loop conflicts. Each
    # coroutine is awaited immediately so none is left pending if one fails.
    subgraph_results = []
    for position, query_row in enumerate(query_rows, start=1):
        logger.log(logging.INFO, "===========================================")
        logger.log(logging.INFO, "Processing query: %s", query_row["node_name"])

        pcst_instance = self._create_pcst_instance(params, query_row, dynamic_metric_type)

        logger.log(logging.INFO, "Processing subgraph %d/%d", position, len(query_rows))
        subgraph_results.append(
            await self._extract_single_subgraph_async(
                pcst_instance, query_row, params.cfg_db, params.connection_manager
            )
        )

    # Process results and finalize
    self._process_subgraph_results(subgraph_results, query_rows, unified_subgraph, subgraphs)
    return self._finalize_subgraph_results(subgraphs, unified_subgraph)
520
+
521
+ def _process_subgraph_results(self, subgraph_results, query_info, unified_subgraph, subgraphs):
522
+ """Process individual subgraph results."""
523
+ for i, subgraph in enumerate(subgraph_results):
524
+ query_row = query_info[i]
525
+ unified_subgraph["nodes"].append(subgraph["nodes"].tolist())
526
+ unified_subgraph["edges"].append(subgraph["edges"].tolist())
527
+ subgraphs.append(
528
+ (
529
+ query_row["node_name"],
530
+ subgraph["nodes"].tolist(),
531
+ subgraph["edges"].tolist(),
532
+ )
533
+ )
534
+
535
def _finalize_subgraph_results(self, subgraphs, unified_subgraph):
    """
    Deduplicate the unified subgraph and package every subgraph into one DataFrame.

    ``unified_subgraph["nodes"]`` and ``unified_subgraph["edges"]`` start as
    one list per query. Each is replaced in place by the unique ids across
    all queries, as returned by ``self.loader.py.unique``.

    Args:
        subgraphs: ``(name, nodes, edges)`` tuples, one per query.
        unified_subgraph: Dict holding the per-query "nodes" and "edges" lists.

    Returns:
        A DataFrame with columns name / nodes / edges. The "Unified Subgraph"
        row comes first, followed by the per-query rows.
    """
    xp = self.loader.py
    frame_lib = self.loader.df
    columns = ["name", "nodes", "edges"]

    # Merge the per-query id lists and keep each id once.
    for key in ("nodes", "edges"):
        merged = xp.concatenate([xp.array(part) for part in unified_subgraph[key]])
        unified_subgraph[key] = xp.unique(merged).tolist()

    unified_row = ("Unified Subgraph", unified_subgraph["nodes"], unified_subgraph["edges"])
    unified_frame = frame_lib.DataFrame([unified_row], columns=columns)
    per_query_frame = frame_lib.DataFrame(subgraphs, columns=columns)

    return frame_lib.concat([unified_frame, per_query_frame], ignore_index=True)
562
+
563
async def _extract_single_subgraph_async(
    self, pcst_instance, query_row, cfg_db, connection_manager
):
    """
    Extract the subgraph for a single query row.

    First awaits ``_load_subgraph_data`` for the edge index, prizes and node
    count. Then hands them to ``_run_pcst_algorithm``.

    Returns:
        Whatever ``_run_pcst_algorithm`` returns for this query.
    """
    loaded = await self._load_subgraph_data(
        pcst_instance, query_row, cfg_db, connection_manager
    )
    edge_index, prizes, num_nodes = loaded

    return self._run_pcst_algorithm(pcst_instance, edge_index, num_nodes, prizes)
576
+
577
async def _load_subgraph_data(self, pcst_instance, query_row, cfg_db, connection_manager):
    """
    Gather the inputs the PCST step needs for one query.

    Args:
        pcst_instance: Pruning object providing ``load_edge_index_async`` and
            ``compute_prizes_async``.
        query_row: Query row; "desc_emb", "feat_emb" and "node_type" are read.
        cfg_db: Milvus database configuration (``milvus_db.database_name``).
        connection_manager: Provides ``async_get_collection_stats``.

    Returns:
        Tuple ``(edge_index, prizes, num_nodes)``. ``num_nodes`` is the
        ``num_entities`` statistic of the ``<database_name>_nodes`` collection.
    """
    edge_index = await pcst_instance.load_edge_index_async(cfg_db, connection_manager)

    desc_embedding = query_row["desc_emb"]
    feat_embedding = query_row["feat_emb"]
    prizes = await pcst_instance.compute_prizes_async(
        desc_embedding, feat_embedding, cfg_db, query_row["node_type"]
    )

    node_collection_name = f"{cfg_db.milvus_db.database_name}_nodes"
    collection_stats = await connection_manager.async_get_collection_stats(
        node_collection_name
    )

    return edge_index, prizes, collection_stats["num_entities"]
596
+
597
def _run_pcst_algorithm(self, pcst_instance, edge_index, num_nodes, prizes):
    """
    Solve the PCST problem for one query and map the result back to graph ids.

    Args:
        pcst_instance: Pruning object supplying the cost computation, the
            solver settings (root, num_clusters, pruning, verbosity_level)
            and the final node/edge mapping.
        edge_index: Edge index loaded for the graph.
        num_nodes: Number of nodes in the graph.
        prizes: Node/edge prizes computed for this query.

    Returns:
        The result of ``pcst_instance.get_subgraph_nodes_edges``.
    """
    costs_result = pcst_instance.compute_subgraph_costs(edge_index, num_nodes, prizes)
    edges_dict, prizes_final, costs, mapping = costs_result

    # pcst_fast is fed plain Python lists.
    solver_args = (
        edges_dict["edges"].tolist(),
        prizes_final.tolist(),
        costs.tolist(),
        pcst_instance.root,
        pcst_instance.num_clusters,
        pcst_instance.pruning,
        pcst_instance.verbosity_level,
    )
    selected_vertices, selected_edges = pcst_fast.pcst_fast(*solver_args)

    array_lib = pcst_instance.loader.py
    vertex_array = array_lib.asarray(selected_vertices)
    edge_info = {
        "edges": array_lib.asarray(selected_edges),
        "num_prior_edges": edges_dict["num_prior_edges"],
        "edge_index": edge_index,
    }
    return pcst_instance.get_subgraph_nodes_edges(num_nodes, vertex_array, edge_info, mapping)
626
+
627
def _run(
    self,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
    prompt: str,
    arg_data: ArgumentData = None,
) -> Command:
    """
    Synchronous wrapper for the async ``_run_async`` method.

    The coroutine always runs via ``asyncio.run`` in a dedicated worker
    thread. This works whether or not the calling thread already has a
    running event loop; ``asyncio.run`` refuses to start inside a running
    loop, but the worker thread never has one.

    Args:
        tool_call_id: The injected tool call id.
        state: The injected state for the tool.
        prompt: The user prompt.
        arg_data: Extraction arguments forwarded unchanged.

    Returns:
        The Command produced by ``_run_async``.
    """

    def run_in_thread():
        """Run ``_run_async`` to completion on a fresh event loop in this thread."""
        # asyncio.run creates and installs a new loop. On exit it cancels
        # leftover tasks, shuts down async generators and the default
        # executor, then closes the loop and uninstalls it. The previous
        # hand-rolled loop management skipped that cleanup.
        return asyncio.run(self._run_async(tool_call_id, state, prompt, arg_data))

    # Always use a separate thread to avoid event loop conflicts
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(run_in_thread).result()
659
+
300
660
  def _prepare_final_subgraph(
301
- self, state: Annotated[dict, InjectedState], subgraph: dict, cfg: dict, cfg_db
661
+ self, state: Annotated[dict, InjectedState], subgraph: dict, cfg_db
302
662
  ) -> dict:
303
663
  """
304
664
  Prepare the subgraph based on the extracted subgraph.
@@ -306,8 +666,6 @@ class MultimodalSubgraphExtractionTool(BaseTool):
306
666
  Args:
307
667
  state: The injected state for the tool.
308
668
  subgraph: The extracted subgraph.
309
- graph: The graph dictionary.
310
- cfg: The configuration dictionary for the tool.
311
669
  cfg_db: The configuration dictionary for Milvus database.
312
670
 
313
671
  Returns:
@@ -315,7 +673,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
315
673
  """
316
674
  # Convert the dict to a DataFrame
317
675
  node_colors = {
318
- n: cfg.node_colors_dict[k] for k, v in state["selections"].items() for n in v
676
+ n: cfg_db.node_colors_dict[k] for k, v in state["selections"].items() for n in v
319
677
  }
320
678
  color_df = self.loader.df.DataFrame(list(node_colors.items()), columns=["node_id", "color"])
321
679
  # print(color_df)
@@ -345,7 +703,9 @@ class MultimodalSubgraphExtractionTool(BaseTool):
345
703
  },
346
704
  )
347
705
  for row in getattr(
348
- graph_nodes, "to_pandas", lambda graph_nodes=graph_nodes: graph_nodes
706
+ graph_nodes,
707
+ "to_pandas",
708
+ lambda graph_nodes=graph_nodes: graph_nodes,
349
709
  )().itertuples(index=False)
350
710
  ]
351
711
  )
@@ -353,7 +713,9 @@ class MultimodalSubgraphExtractionTool(BaseTool):
353
713
  [
354
714
  (row.head_id, row.tail_id, {"label": tuple(row.edge_type)})
355
715
  for row in getattr(
356
- graph_edges, "to_pandas", lambda graph_edges=graph_edges: graph_edges
716
+ graph_edges,
717
+ "to_pandas",
718
+ lambda graph_edges=graph_edges: graph_edges,
357
719
  )().itertuples(index=False)
358
720
  ]
359
721
  )
@@ -364,11 +726,15 @@ class MultimodalSubgraphExtractionTool(BaseTool):
364
726
  graph_nodes.rename(columns={"desc": "node_attr"}, inplace=True)
365
727
  graph_edges = graph_edges[["head_id", "edge_type", "tail_id"]]
366
728
  nodes_pandas = getattr(
367
- graph_nodes, "to_pandas", lambda graph_nodes=graph_nodes: graph_nodes
729
+ graph_nodes,
730
+ "to_pandas",
731
+ lambda graph_nodes=graph_nodes: graph_nodes,
368
732
  )()
369
733
  nodes_csv = nodes_pandas.to_csv(index=False)
370
734
  edges_pandas = getattr(
371
- graph_edges, "to_pandas", lambda graph_edges=graph_edges: graph_edges
735
+ graph_edges,
736
+ "to_pandas",
737
+ lambda graph_edges=graph_edges: graph_edges,
372
738
  )()
373
739
  edges_csv = edges_pandas.to_csv(index=False)
374
740
  graph_dict["text"] = nodes_csv + "\n" + edges_csv
@@ -414,6 +780,35 @@ class MultimodalSubgraphExtractionTool(BaseTool):
414
780
 
415
781
  return graph_nodes, graph_edges
416
782
 
783
def _get_dynamic_metric_type(self, cfg: dict) -> str:
    """
    Choose the metric type forwarded to MultimodalPCSTPruning.

    The loader's metric (``self.loader.metric_type``) is used when ``cfg``
    has a ``vector_processing`` section whose ``dynamic_metrics`` flag is
    truthy; the flag counts as True when absent. Otherwise
    ``cfg.search_metric_type`` is used, falling back to the loader's metric
    when that attribute is missing.
    """
    loader_metric = self.loader.metric_type
    if hasattr(cfg, "vector_processing") and getattr(
        cfg.vector_processing, "dynamic_metrics", True
    ):
        return loader_metric
    return getattr(cfg, "search_metric_type", loader_metric)
793
+
794
def _create_pcst_instance(
    self, params: ExtractionParams, query_row: dict, dynamic_metric_type: str
) -> MultimodalPCSTPruning:
    """
    Build a MultimodalPCSTPruning configured for one query row.

    Settings come from three places:
    - top-k limits from the tool state,
    - PCST settings from the tool config,
    - ``use_description`` from the query row.

    Args:
        params: Bundled extraction inputs (``state`` and ``cfg`` are read).
        query_row: Query row; only "use_description" is read.
        dynamic_metric_type: Metric forwarded as ``metric_type``.

    Returns:
        A new MultimodalPCSTPruning that shares this tool's loader.
    """
    state, cfg = params.state, params.cfg
    pcst_kwargs = {
        "topk": state["topk_nodes"],
        "topk_e": state["topk_edges"],
        "cost_e": cfg.cost_e,
        "c_const": cfg.c_const,
        "root": cfg.root,
        "num_clusters": cfg.num_clusters,
        "pruning": cfg.pruning,
        "verbosity_level": cfg.verbosity_level,
        "use_description": query_row["use_description"],
        "metric_type": dynamic_metric_type,
        "loader": self.loader,
    }
    return MultimodalPCSTPruning(**pcst_kwargs)
811
+
417
812
  def normalize_vector(self, v: list) -> list:
418
813
  """
419
814
  Normalize a vector using appropriate library (CuPy for GPU, NumPy for CPU).
@@ -432,7 +827,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
432
827
  # CPU mode: return as-is for COSINE similarity
433
828
  return v
434
829
 
435
- def _run(
830
+ async def _run_async(
436
831
  self,
437
832
  tool_call_id: Annotated[str, InjectedToolCallId],
438
833
  state: Annotated[dict, InjectedState],
@@ -459,55 +854,71 @@ class MultimodalSubgraphExtractionTool(BaseTool):
459
854
  config_name="config",
460
855
  overrides=["tools/multimodal_subgraph_extraction=default"],
461
856
  )
462
- cfg_db = cfg.app.frontend
463
857
  cfg = cfg.tools.multimodal_subgraph_extraction
464
858
 
465
- # Check if the Milvus connection exists
466
- # logger.log(logging.INFO, "Checking Milvus connection")
467
- # logger.log(logging.INFO, "Milvus connection name: %s", cfg_db.milvus_db.alias)
468
- # logger.log(logging.INFO, "Milvus connection DB: %s", cfg_db.milvus_db.database_name)
469
- # logger.log(logging.INFO, "Is connection established? %s",
470
- # connections.has_connection(cfg_db.milvus_db.alias))
471
- # if connections.has_connection(cfg_db.milvus_db.alias):
472
- # logger.log(logging.INFO, "Milvus connection is established.")
473
- # for collection_name in utility.list_collections():
474
- # logger.log(logging.INFO, "Collection: %s", collection_name)
475
-
476
- # Prepare the query embeddings and modalities
477
- logger.log(logging.INFO, "_prepare_query_modalities")
478
- # start = datetime.datetime.now()
479
- query_df = self._prepare_query_modalities(
859
+ # Load database configuration separately
860
+ with hydra.initialize(version_base=None, config_path="../configs"):
861
+ cfg_all = hydra.compose(config_name="config")
862
+ cfg_db = cfg_all.utils.database.milvus
863
+
864
+ # Establish Milvus connection using singleton connection manager
865
+ logger.log(logging.INFO, "Getting Milvus connection manager (singleton)")
866
+ connection_manager = MilvusConnectionManager(cfg_db)
867
+ try:
868
+ connection_manager.ensure_connection()
869
+ logger.log(logging.INFO, "Milvus connection established successfully")
870
+
871
+ # Log connection info
872
+ conn_info = connection_manager.get_connection_info()
873
+ logger.log(logging.INFO, "Connected to database: %s", conn_info.get("database"))
874
+ logger.log(
875
+ logging.INFO,
876
+ "Connection healthy: %s",
877
+ connection_manager.test_connection(),
878
+ )
879
+ except Exception as e:
880
+ logger.error("Failed to establish Milvus connection: %s", str(e))
881
+ raise RuntimeError(f"Cannot connect to Milvus database: {str(e)}") from e
882
+
883
+ # Prepare the query embeddings and modalities (async)
884
+ logger.log(logging.INFO, "_prepare_query_modalities_async")
885
+ query_df = await self._prepare_query_modalities_async(
480
886
  {
481
887
  "text": prompt,
482
888
  "emb": [self.normalize_vector(state["embedding_model"].embed_query(prompt))],
483
889
  },
484
890
  state,
485
891
  cfg_db,
892
+ connection_manager,
486
893
  )
487
- # end = datetime.datetime.now()
488
- # logger.log(logging.INFO, "_prepare_query_modalities time: %s seconds",
489
- # (end - start).total_seconds())
490
894
 
491
- # Perform subgraph extraction
492
- logger.log(logging.INFO, "_perform_subgraph_extraction")
493
- # start = datetime.datetime.now()
494
- subgraphs = self._perform_subgraph_extraction(state, cfg, cfg_db, query_df)
495
- # end = datetime.datetime.now()
496
- # logger.log(logging.INFO, "_perform_subgraph_extraction time: %s seconds",
497
- # (end - start).total_seconds())
895
+ # Perform subgraph extraction (async)
896
+ logger.log(logging.INFO, "_perform_subgraph_extraction_async")
897
+ extraction_params = ExtractionParams(
898
+ state=state,
899
+ cfg=cfg,
900
+ cfg_db=cfg_db,
901
+ query_df=query_df,
902
+ connection_manager=connection_manager,
903
+ )
904
+ subgraphs = await self._perform_subgraph_extraction_async(extraction_params)
498
905
 
499
906
  # Prepare subgraph as a NetworkX graph and textualized graph
500
907
  logger.log(logging.INFO, "_prepare_final_subgraph")
501
908
  logger.log(logging.INFO, "Subgraphs extracted: %s", len(subgraphs))
502
909
  # start = datetime.datetime.now()
503
- final_subgraph = self._prepare_final_subgraph(state, subgraphs, cfg, cfg_db)
910
+ final_subgraph = self._prepare_final_subgraph(state, subgraphs, cfg_db)
504
911
  # end = datetime.datetime.now()
505
912
  # logger.log(logging.INFO, "_prepare_final_subgraph time: %s seconds",
506
913
  # (end - start).total_seconds())
507
914
 
915
+ # Create final result and return command
916
+ return self._create_extraction_result(tool_call_id, state, final_subgraph, arg_data)
917
+
918
+ def _create_extraction_result(self, tool_call_id, state, final_subgraph, arg_data):
919
+ """Create the final extraction result and command."""
508
920
  # Prepare the dictionary of extracted graph
509
921
  logger.log(logging.INFO, "dic_extracted_graph")
510
- # start = datetime.datetime.now()
511
922
  dic_extracted_graph = {
512
923
  "name": arg_data.extraction_name,
513
924
  "tool_call_id": tool_call_id,
@@ -522,28 +933,33 @@ class MultimodalSubgraphExtractionTool(BaseTool):
522
933
  "graph_text": final_subgraph["text"],
523
934
  "graph_summary": None,
524
935
  }
525
- # end = datetime.datetime.now()
526
- # logger.log(logging.INFO, "dic_extracted_graph time: %s seconds",
527
- # (end - start).total_seconds())
528
936
 
529
- # Prepare the dictionary of updated state
530
- dic_updated_state_for_model = {}
531
- for key, value in {
532
- "dic_extracted_graph": [dic_extracted_graph],
533
- }.items():
534
- if value:
535
- dic_updated_state_for_model[key] = value
937
+ # Debug logging
938
+ logger.info(
939
+ "Created dic_extracted_graph with keys: %s",
940
+ list(dic_extracted_graph.keys()),
941
+ )
942
+ logger.info(
943
+ "Graph dict structure - name count: %d, nodes count: %d, edges count: %d",
944
+ len(dic_extracted_graph["graph_dict"]["name"]),
945
+ len(dic_extracted_graph["graph_dict"]["nodes"]),
946
+ len(dic_extracted_graph["graph_dict"]["edges"]),
947
+ )
536
948
 
537
- # Return the updated state of the tool
949
+ # Create success message
950
+ success_message = (
951
+ f"Successfully extracted subgraph '{arg_data.extraction_name}' "
952
+ f"with {len(final_subgraph['name'])} graph(s). The subgraph contains "
953
+ f"{sum(len(nodes) for nodes in final_subgraph['nodes'])} nodes and "
954
+ f"{sum(len(edges) for edges in final_subgraph['edges'])} edges. "
955
+ "The extracted subgraph has been stored and is ready for "
956
+ "visualization and analysis."
957
+ )
958
+
959
+ # Return the command with updated state
538
960
  return Command(
539
- update=dic_updated_state_for_model
961
+ update={"dic_extracted_graph": [dic_extracted_graph]}
540
962
  | {
541
- # update the message history
542
- "messages": [
543
- ToolMessage(
544
- content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
545
- tool_call_id=tool_call_id,
546
- )
547
- ],
963
+ "messages": [ToolMessage(content=success_message, tool_call_id=tool_call_id)],
548
964
  }
549
965
  )