aiagents4pharma 1.43.0__py3-none-any.whl → 1.44.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,27 +2,25 @@
2
2
  Tool for performing multimodal subgraph extraction.
3
3
  """
4
4
 
5
- # import datetime
6
- from typing import Type, Annotated
7
5
  import logging
6
+ from typing import Annotated, Type
7
+
8
8
  import hydra
9
9
  import pandas as pd
10
- from pydantic import BaseModel, Field
11
- from langchain_core.tools import BaseTool
12
10
  from langchain_core.messages import ToolMessage
11
+ from langchain_core.tools import BaseTool
13
12
  from langchain_core.tools.base import InjectedToolCallId
14
- from langgraph.types import Command
15
13
  from langgraph.prebuilt import InjectedState
14
+ from langgraph.types import Command
15
+ from pydantic import BaseModel, Field
16
16
  from pymilvus import Collection
17
- from ..utils.extractions.milvus_multimodal_pcst import MultimodalPCSTPruning
17
+
18
+ from ..utils.extractions.milvus_multimodal_pcst import (
19
+ DynamicLibraryLoader,
20
+ MultimodalPCSTPruning,
21
+ SystemDetector,
22
+ )
18
23
  from .load_arguments import ArgumentData
19
- try:
20
- import cupy as py
21
- import cudf
22
- df = cudf
23
- except ImportError:
24
- import numpy as py
25
- df = pd
26
24
 
27
25
  # Initialize logger
28
26
  logging.basicConfig(level=logging.INFO)
@@ -61,8 +59,16 @@ class MultimodalSubgraphExtractionTool(BaseTool):
61
59
  description: str = "A tool for subgraph extraction based on user's prompt."
62
60
  args_schema: Type[BaseModel] = MultimodalSubgraphExtractionInput
63
61
 
62
+ def __init__(self, **kwargs):
63
+ super().__init__(**kwargs)
64
+ # Initialize hardware detection and dynamic library loading
65
+ object.__setattr__(self, 'detector', SystemDetector())
66
+ object.__setattr__(self, 'loader', DynamicLibraryLoader(self.detector))
67
+ logger.info("MultimodalSubgraphExtractionTool initialized with %s mode",
68
+ "GPU" if self.loader.use_gpu else "CPU")
69
+
64
70
  def _read_multimodal_files(self,
65
- state: Annotated[dict, InjectedState]) -> df.DataFrame:
71
+ state: Annotated[dict, InjectedState]):
66
72
  """
67
73
  Read the uploaded multimodal files and return a DataFrame.
68
74
 
@@ -72,7 +78,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
72
78
  Returns:
73
79
  A DataFrame containing the multimodal files.
74
80
  """
75
- multimodal_df = df.DataFrame({"name": [], "node_type": []})
81
+ multimodal_df = self.loader.df.DataFrame({"name": [], "node_type": []})
76
82
 
77
83
  # Loop over the uploaded files and find multimodal files
78
84
  logger.log(logging.INFO, "Looping over uploaded files")
@@ -90,7 +96,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
90
96
  logger.log(logging.INFO, "Preparing multimodal_df")
91
97
  # Merge all obtained dataframes into a single dataframe
92
98
  multimodal_df = pd.concat(multimodal_df).reset_index()
93
- multimodal_df = df.DataFrame(multimodal_df)
99
+ multimodal_df = self.loader.df.DataFrame(multimodal_df)
94
100
  multimodal_df.drop(columns=["level_1"], inplace=True)
95
101
  multimodal_df.rename(columns={"level_0": "q_node_type",
96
102
  "name": "q_node_name"}, inplace=True)
@@ -100,10 +106,39 @@ class MultimodalSubgraphExtractionTool(BaseTool):
100
106
 
101
107
  return multimodal_df
102
108
 
109
+ def _query_milvus_collection(self, node_type, node_type_df, cfg_db):
110
+ """Helper method to query Milvus collection for a specific node type."""
111
+ # Load the collection
112
+ collection = Collection(
113
+ name=f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"
114
+ )
115
+ collection.load()
116
+
117
+ # Query the collection with node names from multimodal_df
118
+ node_names_series = node_type_df['q_node_name']
119
+ q_node_names = getattr(node_names_series,
120
+ "to_pandas",
121
+ lambda series=node_names_series: series)().tolist()
122
+ q_columns = ["node_id", "node_name", "node_type",
123
+ "feat", "feat_emb", "desc", "desc_emb"]
124
+ res = collection.query(
125
+ expr=f'node_name IN [{','.join(f'"{name}"' for name in q_node_names)}]',
126
+ output_fields=q_columns,
127
+ )
128
+ # Convert the embeddings into floats
129
+ for r_ in res:
130
+ r_['feat_emb'] = [float(x) for x in r_['feat_emb']]
131
+ r_['desc_emb'] = [float(x) for x in r_['desc_emb']]
132
+
133
+ # Convert the result to a DataFrame
134
+ res_df = self.loader.df.DataFrame(res)[q_columns]
135
+ res_df["use_description"] = False
136
+ return res_df
137
+
103
138
  def _prepare_query_modalities(self,
104
139
  prompt: dict,
105
140
  state: Annotated[dict, InjectedState],
106
- cfg_db: dict) -> df.DataFrame:
141
+ cfg_db: dict):
107
142
  """
108
143
  Prepare the modality-specific query for subgraph extraction.
109
144
 
@@ -118,7 +153,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
118
153
  # Initialize dataframes
119
154
  logger.log(logging.INFO, "Initializing dataframes")
120
155
  query_df = []
121
- prompt_df = df.DataFrame({
156
+ prompt_df = self.loader.df.DataFrame({
122
157
  'node_id': 'user_prompt',
123
158
  'node_name': 'User Prompt',
124
159
  'node_type': 'prompt',
@@ -139,38 +174,12 @@ class MultimodalSubgraphExtractionTool(BaseTool):
139
174
  logger.log(logging.INFO, "Querying Milvus database for each node type in multimodal_df")
140
175
  for node_type, node_type_df in multimodal_df.groupby("q_node_type"):
141
176
  print(f"Processing node type: {node_type}")
142
-
143
- # Load the collection
144
- collection = Collection(
145
- name=f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"
146
- )
147
- collection.load()
148
-
149
- # Query the collection with node names from multimodal_df
150
- q_node_names = getattr(node_type_df['q_node_name'],
151
- "to_pandas",
152
- lambda: node_type_df['q_node_name'])().tolist()
153
- q_columns = ["node_id", "node_name", "node_type",
154
- "feat", "feat_emb", "desc", "desc_emb"]
155
- res = collection.query(
156
- expr=f'node_name IN [{','.join(f'"{name}"' for name in q_node_names)}]',
157
- output_fields=q_columns,
158
- )
159
- # Convert the embeedings into floats
160
- for r_ in res:
161
- r_['feat_emb'] = [float(x) for x in r_['feat_emb']]
162
- r_['desc_emb'] = [float(x) for x in r_['desc_emb']]
163
-
164
- # Convert the result to a DataFrame
165
- res_df = df.DataFrame(res)[q_columns]
166
- res_df["use_description"] = False
167
-
168
- # Append the results to query_df
177
+ res_df = self._query_milvus_collection(node_type, node_type_df, cfg_db)
169
178
  query_df.append(res_df)
170
179
 
171
180
  # Concatenate all results into a single DataFrame
172
181
  logger.log(logging.INFO, "Concatenating all results into a single DataFrame")
173
- query_df = df.concat(query_df, ignore_index=True)
182
+ query_df = self.loader.df.concat(query_df, ignore_index=True)
174
183
 
175
184
  # Update the state by adding the selected node IDs
176
185
  logger.log(logging.INFO, "Updating state with selected node IDs")
@@ -182,7 +191,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
182
191
 
183
192
  # Append a user prompt to the query dataframe
184
193
  logger.log(logging.INFO, "Adding user prompt to query dataframe")
185
- query_df = df.concat([query_df, prompt_df]).reset_index(drop=True)
194
+ query_df = self.loader.df.concat([query_df, prompt_df]).reset_index(drop=True)
186
195
  else:
187
196
  # If no multimodal files are uploaded, use the prompt embeddings
188
197
  query_df = prompt_df
@@ -223,6 +232,19 @@ class MultimodalSubgraphExtractionTool(BaseTool):
223
232
  # Prepare the PCSTPruning object and extract the subgraph
224
233
  # Parameters were set in the configuration file obtained from Hydra
225
234
  # start = datetime.datetime.now()
235
+ # Get dynamic metric type (overrides any config setting)
236
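+ # The loader's metric is used only when cfg.vector_processing exists and its dynamic_metrics flag (default True) is set; otherwise cfg.search_metric_type applies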
237
+ has_vector_processing = hasattr(cfg, 'vector_processing')
238
+ if has_vector_processing:
239
+ dynamic_metrics_enabled = getattr(cfg.vector_processing, 'dynamic_metrics', True)
240
+ else:
241
+ dynamic_metrics_enabled = False
242
+ if has_vector_processing and dynamic_metrics_enabled:
243
+ dynamic_metric_type = self.loader.metric_type
244
+ else:
245
+ dynamic_metric_type = getattr(cfg, 'search_metric_type',
246
+ self.loader.metric_type)
247
+
226
248
  subgraph = MultimodalPCSTPruning(
227
249
  topk=state["topk_nodes"],
228
250
  topk_e=state["topk_edges"],
@@ -233,7 +255,8 @@ class MultimodalSubgraphExtractionTool(BaseTool):
233
255
  pruning=cfg.pruning,
234
256
  verbosity_level=cfg.verbosity_level,
235
257
  use_description=q[1]['use_description'],
236
- metric_type=cfg.search_metric_type
258
+ metric_type=dynamic_metric_type, # Use dynamic or config metric type
259
+ loader=self.loader # Pass the loader instance
237
260
  ).extract_subgraph(q[1]['desc_emb'],
238
261
  q[1]['feat_emb'],
239
262
  q[1]['node_type'],
@@ -251,22 +274,24 @@ class MultimodalSubgraphExtractionTool(BaseTool):
251
274
  # (end - start).total_seconds())
252
275
 
253
276
  # Concatenate and get unique node and edge indices
254
- unified_subgraph["nodes"] = py.unique(
255
- py.concatenate([py.array(list_) for list_ in unified_subgraph["nodes"]])
277
+ nodes_arrays = [self.loader.py.array(list_) for list_ in unified_subgraph["nodes"]]
278
+ unified_subgraph["nodes"] = self.loader.py.unique(
279
+ self.loader.py.concatenate(nodes_arrays)
256
280
  ).tolist()
257
- unified_subgraph["edges"] = py.unique(
258
- py.concatenate([py.array(list_) for list_ in unified_subgraph["edges"]])
281
+ edges_arrays = [self.loader.py.array(list_) for list_ in unified_subgraph["edges"]]
282
+ unified_subgraph["edges"] = self.loader.py.unique(
283
+ self.loader.py.concatenate(edges_arrays)
259
284
  ).tolist()
260
285
 
261
- # Convert the unified subgraph and subgraphs to cudf DataFrames
262
- unified_subgraph = df.DataFrame([("Unified Subgraph",
263
- unified_subgraph["nodes"],
264
- unified_subgraph["edges"])],
265
- columns=["name", "nodes", "edges"])
266
- subgraphs = df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])
286
+ # Convert the unified subgraph and subgraphs to DataFrames
287
+ unified_subgraph = self.loader.df.DataFrame([("Unified Subgraph",
288
+ unified_subgraph["nodes"],
289
+ unified_subgraph["edges"])],
290
+ columns=["name", "nodes", "edges"])
291
+ subgraphs = self.loader.df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])
267
292
 
268
- # Concate both DataFrames
269
- subgraphs = df.concat([unified_subgraph, subgraphs], ignore_index=True)
293
+ # Concatenate both DataFrames
294
+ subgraphs = self.loader.df.concat([unified_subgraph, subgraphs], ignore_index=True)
270
295
 
271
296
  return subgraphs
272
297
 
@@ -288,10 +313,10 @@ class MultimodalSubgraphExtractionTool(BaseTool):
288
313
  Returns:
289
314
  A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
290
315
  """
291
- # Convert the dict to a cudf DataFrame
316
+ # Convert the dict to a DataFrame
292
317
  node_colors = {n: cfg.node_colors_dict[k]
293
318
  for k, v in state["selections"].items() for n in v}
294
- color_df = df.DataFrame(list(node_colors.items()), columns=["node_id", "color"])
319
+ color_df = self.loader.df.DataFrame(list(node_colors.items()), columns=["node_id", "color"])
295
320
  # print(color_df)
296
321
 
297
322
  # Prepare the subgraph dictionary
@@ -302,42 +327,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
302
327
  "text": ""
303
328
  }
304
329
  for sub in getattr(subgraph, "to_pandas", lambda: subgraph)().itertuples(index=False):
305
- # Prepare the graph name
306
- print(f"Processing subgraph: {sub.name}")
307
- print('---')
308
- print(sub.nodes)
309
- print('---')
310
- print(sub.edges)
311
- print('---')
312
-
313
- # Prepare graph dataframes
314
- # Nodes
315
- coll_name = f"{cfg_db.milvus_db.database_name}_nodes"
316
- node_coll = Collection(name=coll_name)
317
- node_coll.load()
318
- graph_nodes = node_coll.query(
319
- expr=f'node_index IN [{",".join(f"{n}" for n in sub.nodes)}]',
320
- output_fields=['node_id', 'node_name', 'node_type', 'desc']
321
- )
322
- graph_nodes = df.DataFrame(graph_nodes)
323
- graph_nodes.drop(columns=['node_index'], inplace=True)
324
- if not color_df.empty:
325
- # Merge the color dataframe with the graph nodes
326
- graph_nodes = graph_nodes.merge(color_df, on="node_id", how="left")
327
- else:
328
- graph_nodes["color"] = 'black' # Default color
329
- graph_nodes['color'].fillna('black', inplace=True) # Fill NaN colors with black
330
- # Edges
331
- coll_name = f"{cfg_db.milvus_db.database_name}_edges"
332
- edge_coll = Collection(name=coll_name)
333
- edge_coll.load()
334
- graph_edges = edge_coll.query(
335
- expr=f'triplet_index IN [{",".join(f"{e}" for e in sub.edges)}]',
336
- output_fields=['head_id', 'tail_id', 'edge_type']
337
- )
338
- graph_edges = df.DataFrame(graph_edges)
339
- graph_edges.drop(columns=['triplet_index'], inplace=True)
340
- graph_edges['edge_type'] = graph_edges['edge_type'].str.split('|')
330
+ graph_nodes, graph_edges = self._process_subgraph_data(sub, cfg_db, color_df)
341
331
 
342
332
  # Prepare lists for visualization
343
333
  graph_dict["name"].append(sub.name)
@@ -350,32 +340,75 @@ class MultimodalSubgraphExtractionTool(BaseTool):
350
340
  'color': row.color})
351
341
  for row in getattr(graph_nodes,
352
342
  "to_pandas",
353
- lambda: graph_nodes)().itertuples(index=False)])
343
+ lambda graph_nodes=graph_nodes: graph_nodes)()
344
+ .itertuples(index=False)])
354
345
  graph_dict["edges"].append([(
355
346
  row.head_id,
356
347
  row.tail_id,
357
348
  {'label': tuple(row.edge_type)})
358
349
  for row in getattr(graph_edges,
359
350
  "to_pandas",
360
- lambda: graph_edges)().itertuples(index=False)])
351
+ lambda graph_edges=graph_edges: graph_edges)()
352
+ .itertuples(index=False)])
361
353
 
362
354
  # Prepare the textualized subgraph
363
355
  if sub.name == "Unified Subgraph":
364
356
  graph_nodes = graph_nodes[['node_id', 'desc']]
365
357
  graph_nodes.rename(columns={'desc': 'node_attr'}, inplace=True)
366
358
  graph_edges = graph_edges[['head_id', 'edge_type', 'tail_id']]
367
- graph_dict["text"] = (
368
- getattr(graph_nodes, "to_pandas", lambda: graph_nodes)().to_csv(index=False)
369
- + "\n"
370
- + getattr(graph_edges, "to_pandas", lambda: graph_edges)().to_csv(index=False)
371
- )
359
+ nodes_pandas = getattr(graph_nodes, "to_pandas",
360
+ lambda graph_nodes=graph_nodes: graph_nodes)()
361
+ nodes_csv = nodes_pandas.to_csv(index=False)
362
+ edges_pandas = getattr(graph_edges, "to_pandas",
363
+ lambda graph_edges=graph_edges: graph_edges)()
364
+ edges_csv = edges_pandas.to_csv(index=False)
365
+ graph_dict["text"] = nodes_csv + "\n" + edges_csv
372
366
 
373
367
  return graph_dict
374
368
 
369
+ def _process_subgraph_data(self, sub, cfg_db, color_df):
370
+ """Helper method to process individual subgraph data."""
371
+ print(f"Processing subgraph: {sub.name}")
372
+ print('---')
373
+ print(sub.nodes)
374
+ print('---')
375
+ print(sub.edges)
376
+ print('---')
377
+
378
+ # Prepare graph dataframes - Nodes
379
+ coll_name = f"{cfg_db.milvus_db.database_name}_nodes"
380
+ node_coll = Collection(name=coll_name)
381
+ node_coll.load()
382
+ graph_nodes = node_coll.query(
383
+ expr=f'node_index IN [{",".join(f"{n}" for n in sub.nodes)}]',
384
+ output_fields=['node_id', 'node_name', 'node_type', 'desc']
385
+ )
386
+ graph_nodes = self.loader.df.DataFrame(graph_nodes)
387
+ graph_nodes.drop(columns=['node_index'], inplace=True)
388
+ if not color_df.empty:
389
+ graph_nodes = graph_nodes.merge(color_df, on="node_id", how="left")
390
+ else:
391
+ graph_nodes["color"] = 'black'
392
+ graph_nodes['color'] = graph_nodes['color'].fillna('black')
393
+
394
+ # Edges
395
+ coll_name = f"{cfg_db.milvus_db.database_name}_edges"
396
+ edge_coll = Collection(name=coll_name)
397
+ edge_coll.load()
398
+ graph_edges = edge_coll.query(
399
+ expr=f'triplet_index IN [{",".join(f"{e}" for e in sub.edges)}]',
400
+ output_fields=['head_id', 'tail_id', 'edge_type']
401
+ )
402
+ graph_edges = self.loader.df.DataFrame(graph_edges)
403
+ graph_edges.drop(columns=['triplet_index'], inplace=True)
404
+ graph_edges['edge_type'] = graph_edges['edge_type'].str.split('|')
405
+
406
+ return graph_nodes, graph_edges
407
+
375
408
  def normalize_vector(self,
376
409
  v : list) -> list:
377
410
  """
378
- Normalize a vector using CuPy.
411
+ Normalize a vector using the appropriate library (CuPy for GPU, NumPy for CPU); the input is returned unchanged when loader.normalize_vectors is false.
379
412
 
380
413
  Args:
381
414
  v : Vector to normalize.
@@ -383,9 +416,13 @@ class MultimodalSubgraphExtractionTool(BaseTool):
383
416
  Returns:
384
417
  Normalized vector.
385
418
  """
386
- v = py.asarray(v)
387
- norm = py.linalg.norm(v)
388
- return (v / norm).tolist()
419
+ if self.loader.normalize_vectors:
420
+ # GPU mode: normalize the vector
421
+ v_array = self.loader.py.asarray(v)
422
+ norm = self.loader.py.linalg.norm(v_array)
423
+ return (v_array / norm).tolist()
424
+ # CPU mode: return as-is for COSINE similarity
425
+ return v
389
426
 
390
427
  def _run(
391
428
  self,