aiagents4pharma 1.39.5__py3-none-any.whl → 1.40.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,509 @@
+"""
+Tool for performing multimodal subgraph extraction.
+"""
+
+# import datetime
+from typing import Type, Annotated
+import logging
+import hydra
+import pandas as pd
+from pydantic import BaseModel, Field
+from langchain_core.tools import BaseTool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+from pymilvus import Collection
+from ..utils.extractions.milvus_multimodal_pcst import MultimodalPCSTPruning
+from .load_arguments import ArgumentData
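+# `py` and `df` are drop-in aliases: CuPy/cuDF (GPU) when available, otherwise
+# NumPy/pandas (CPU), so the rest of the module stays backend-agnostic.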
+try:
+    import cupy as py
+    import cudf
+    df = cudf
+except ImportError:
+    import numpy as py
+    df = pd
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class MultimodalSubgraphExtractionInput(BaseModel):
+    """
+    MultimodalSubgraphExtractionInput is a Pydantic model representing an input
+    for extracting a subgraph.
+
+    Args:
+        prompt: Prompt to interact with the backend.
+        tool_call_id: Tool call ID.
+        state: Injected state.
+        arg_data: Argument for analytical process over graph data.
+    """
+
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        description="Tool call ID."
+    )
+    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
+    prompt: str = Field(description="Prompt to interact with the backend.")
+    arg_data: ArgumentData = Field(
+        description="Experiment over graph data.", default=None
+    )
+
+
+class MultimodalSubgraphExtractionTool(BaseTool):
+    """
+    This tool performs subgraph extraction based on the user's prompt, taking into
+    account the top-k nodes and edges.
+    """
+
+    name: str = "subgraph_extraction"
+    description: str = "A tool for subgraph extraction based on the user's prompt."
+    args_schema: Type[BaseModel] = MultimodalSubgraphExtractionInput
+
+    def _read_multimodal_files(self,
+                               state: Annotated[dict, InjectedState]) -> df.DataFrame:
+        """
+        Read the uploaded multimodal files and return a DataFrame.
+
+        Args:
+            state: The injected state for the tool.
+
+        Returns:
+            A DataFrame containing the multimodal files.
+        """
+        multimodal_df = df.DataFrame({"name": [], "node_type": []})
+
+        # Loop over the uploaded files and find multimodal files
+        logger.log(logging.INFO, "Looping over uploaded files")
+        for i in range(len(state["uploaded_files"])):
+            # Check if a multimodal file was uploaded
+            if state["uploaded_files"][i]["file_type"] == "multimodal":
+                # Read all sheets of the Excel file
+                multimodal_df = pd.read_excel(state["uploaded_files"][i]["file_path"],
+                                              sheet_name=None)
+
+        # Check if the multimodal_df is empty
+        logger.log(logging.INFO, "Checking if multimodal_df is empty")
+        if len(multimodal_df) > 0:
+            # Prepare multimodal_df
+            logger.log(logging.INFO, "Preparing multimodal_df")
+            # Merge all obtained dataframes into a single dataframe
+            multimodal_df = pd.concat(multimodal_df).reset_index()
+            multimodal_df = df.DataFrame(multimodal_df)
+            multimodal_df.drop(columns=["level_1"], inplace=True)
+            multimodal_df.rename(columns={"level_0": "q_node_type",
+                                          "name": "q_node_name"}, inplace=True)
+            # An Excel sheet name cannot contain a `/`, although node types such as
+            # 'gene/protein' exist in PrimeKG, so the sheet-derived node type is
+            # normalized here.
+            multimodal_df["q_node_type"] = multimodal_df["q_node_type"].str.replace('-', '_')
+
+        return multimodal_df
+
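+    # Note: getattr(obj, "to_pandas", lambda: obj)() is used throughout to obtain a
+    # pandas object whether `obj` is backed by cuDF or by pandas itself.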
+    def _prepare_query_modalities(self,
+                                  prompt: dict,
+                                  state: Annotated[dict, InjectedState],
+                                  cfg_db: dict) -> df.DataFrame:
+        """
+        Prepare the modality-specific query for subgraph extraction.
+
+        Args:
+            prompt: The dictionary containing the user prompt and embeddings.
+            state: The injected state for the tool.
+            cfg_db: The configuration dictionary for the Milvus database.
+
+        Returns:
+            A DataFrame containing the query embeddings and modalities.
+        """
+        # Initialize dataframes
+        logger.log(logging.INFO, "Initializing dataframes")
+        query_df = []
+        prompt_df = df.DataFrame({
+            'node_id': 'user_prompt',
+            'node_name': 'User Prompt',
+            'node_type': 'prompt',
+            'feat': prompt["text"],
+            'feat_emb': prompt["emb"],
+            'desc': prompt["text"],
+            'desc_emb': prompt["emb"],
+            'use_description': True  # set to True for user prompt embedding
+        })
+
+        # Read multimodal files uploaded by the user
+        multimodal_df = self._read_multimodal_files(state)
+
+        # Check if the multimodal_df is empty
+        logger.log(logging.INFO, "Prepare query modalities")
+        if len(multimodal_df) > 0:
+            # Query the Milvus database for each node type in multimodal_df
+            logger.log(logging.INFO,
+                       "Querying Milvus database for each node type in multimodal_df")
+            for node_type, node_type_df in multimodal_df.groupby("q_node_type"):
+                print(f"Processing node type: {node_type}")
+
+                # Load the collection
+                collection = Collection(
+                    name=f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"
+                )
+                collection.load()
+
+                # Query the collection with node names from multimodal_df
+                q_node_names = getattr(node_type_df['q_node_name'],
+                                       "to_pandas",
+                                       lambda: node_type_df['q_node_name'])().tolist()
+                q_columns = ["node_id", "node_name", "node_type",
+                             "feat", "feat_emb", "desc", "desc_emb"]
+                node_names_str = ",".join(f'"{name}"' for name in q_node_names)
+                res = collection.query(
+                    expr=f"node_name IN [{node_names_str}]",
+                    output_fields=q_columns,
+                )
+                # Convert the embeddings into floats
+                for r_ in res:
+                    r_['feat_emb'] = [float(x) for x in r_['feat_emb']]
+                    r_['desc_emb'] = [float(x) for x in r_['desc_emb']]
+
+                # Convert the result to a DataFrame
+                res_df = df.DataFrame(res)[q_columns]
+                res_df["use_description"] = False
+
+                # Append the results to query_df
+                query_df.append(res_df)
+
+            # Concatenate all results into a single DataFrame
+            logger.log(logging.INFO, "Concatenating all results into a single DataFrame")
+            query_df = df.concat(query_df, ignore_index=True)
+
+            # Update the state by adding the selected node IDs
+            logger.log(logging.INFO, "Updating state with selected node IDs")
+            state["selections"] = getattr(query_df,
+                                          "to_pandas",
+                                          lambda: query_df)().groupby(
+                "node_type"
+            )["node_id"].apply(list).to_dict()
+
+            # Append the user prompt to the query dataframe
+            logger.log(logging.INFO, "Adding user prompt to query dataframe")
+            query_df = df.concat([query_df, prompt_df]).reset_index(drop=True)
+        else:
+            # If no multimodal files are uploaded, use the prompt embeddings
+            query_df = prompt_df
+
+        return query_df
+
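+    # Each query row below yields one PCST subgraph; the per-query subgraphs plus a
+    # deduplicated "Unified Subgraph" are returned together as a DataFrame with
+    # columns ["name", "nodes", "edges"].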
+    def _perform_subgraph_extraction(self,
+                                     state: Annotated[dict, InjectedState],
+                                     cfg: dict,
+                                     cfg_db: dict,
+                                     query_df: pd.DataFrame) -> df.DataFrame:
+        """
+        Perform multimodal subgraph extraction based on modality-specific embeddings.
+
+        Args:
+            state: The injected state for the tool.
+            cfg: The configuration dictionary.
+            cfg_db: The configuration dictionary for the Milvus database.
+            query_df: The DataFrame containing the query embeddings and modalities.
+
+        Returns:
+            A DataFrame of extracted subgraphs (including the unified subgraph)
+            with columns "name", "nodes", and "edges".
+        """
+        # Initialize the subgraph containers
+        subgraphs = []
+        unified_subgraph = {
+            "nodes": [],
+            "edges": []
+        }
+        # subgraphs = {}
+        # subgraphs["nodes"] = []
+        # subgraphs["edges"] = []
+
+        # Loop over query embeddings and modalities
+        for q in getattr(query_df, "to_pandas", lambda: query_df)().iterrows():
+            logger.log(logging.INFO, "===========================================")
+            logger.log(logging.INFO, "Processing query: %s", q[1]['node_name'])
+            # Prepare the PCSTPruning object and extract the subgraph
+            # Parameters were set in the configuration file obtained from Hydra
+            # start = datetime.datetime.now()
+            subgraph = MultimodalPCSTPruning(
+                topk=state["topk_nodes"],
+                topk_e=state["topk_edges"],
+                cost_e=cfg.cost_e,
+                c_const=cfg.c_const,
+                root=cfg.root,
+                num_clusters=cfg.num_clusters,
+                pruning=cfg.pruning,
+                verbosity_level=cfg.verbosity_level,
+                use_description=q[1]['use_description'],
+                metric_type=cfg.search_metric_type
+            ).extract_subgraph(q[1]['desc_emb'],
+                               q[1]['feat_emb'],
+                               q[1]['node_type'],
+                               cfg_db)
+
+            # Append the extracted subgraph to the containers
+            unified_subgraph["nodes"].append(subgraph["nodes"].tolist())
+            unified_subgraph["edges"].append(subgraph["edges"].tolist())
+            subgraphs.append((q[1]['node_name'],
+                              subgraph["nodes"].tolist(),
+                              subgraph["edges"].tolist()))
+
+            # end = datetime.datetime.now()
+            # logger.log(logging.INFO, "Subgraph extraction time: %s seconds",
+            #            (end - start).total_seconds())
+
+        # Concatenate and get unique node and edge indices
+        unified_subgraph["nodes"] = py.unique(
+            py.concatenate([py.array(list_) for list_ in unified_subgraph["nodes"]])
+        ).tolist()
+        unified_subgraph["edges"] = py.unique(
+            py.concatenate([py.array(list_) for list_ in unified_subgraph["edges"]])
+        ).tolist()
+
+        # Convert the unified subgraph and the per-query subgraphs to DataFrames
+        unified_subgraph = df.DataFrame([("Unified Subgraph",
+                                          unified_subgraph["nodes"],
+                                          unified_subgraph["edges"])],
+                                        columns=["name", "nodes", "edges"])
+        subgraphs = df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])
+
+        # Concatenate both DataFrames
+        subgraphs = df.concat([unified_subgraph, subgraphs], ignore_index=True)
+
+        return subgraphs
+
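+    # The node and edge attribute tuples assembled below ("hover", "click", "color")
+    # are prepared for the downstream graph visualization; node colors come from
+    # cfg.node_colors_dict keyed by the user's selected node types.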
+    def _prepare_final_subgraph(self,
+                                state: Annotated[dict, InjectedState],
+                                subgraph: df.DataFrame,
+                                cfg: dict,
+                                cfg_db) -> dict:
+        """
+        Prepare the final subgraph dictionary based on the extracted subgraphs.
+
+        Args:
+            state: The injected state for the tool.
+            subgraph: The DataFrame of extracted subgraphs.
+            cfg: The configuration dictionary for the tool.
+            cfg_db: The configuration dictionary for the Milvus database.
+
+        Returns:
+            A dictionary containing the subgraph names, nodes, edges,
+            and the textualized subgraph.
+        """
+        # Map each selected node ID to its color and build a DataFrame
+        node_colors = {n: cfg.node_colors_dict[k]
+                       for k, v in state["selections"].items() for n in v}
+        color_df = df.DataFrame(list(node_colors.items()), columns=["node_id", "color"])
+        # print(color_df)
+
+        # Prepare the subgraph dictionary
+        graph_dict = {
+            "name": [],
+            "nodes": [],
+            "edges": [],
+            "text": ""
+        }
+        for sub in getattr(subgraph, "to_pandas", lambda: subgraph)().itertuples(index=False):
+            # Prepare the graph name
+            print(f"Processing subgraph: {sub.name}")
+            print('---')
+            print(sub.nodes)
+            print('---')
+            print(sub.edges)
+            print('---')
+
+            # Prepare graph dataframes
+            # Nodes
+            coll_name = f"{cfg_db.milvus_db.database_name}_nodes"
+            node_coll = Collection(name=coll_name)
+            node_coll.load()
+            graph_nodes = node_coll.query(
+                expr=f'node_index IN [{",".join(f"{n}" for n in sub.nodes)}]',
+                output_fields=['node_id', 'node_name', 'node_type', 'desc']
+            )
+            graph_nodes = df.DataFrame(graph_nodes)
+            graph_nodes.drop(columns=['node_index'], inplace=True)
+            if not color_df.empty:
+                # Merge the color dataframe with the graph nodes
+                graph_nodes = graph_nodes.merge(color_df, on="node_id", how="left")
+            else:
+                graph_nodes["color"] = 'black'  # Default color
+            graph_nodes['color'].fillna('black', inplace=True)  # Fill NaN colors with black
+            # Edges
+            coll_name = f"{cfg_db.milvus_db.database_name}_edges"
+            edge_coll = Collection(name=coll_name)
+            edge_coll.load()
+            graph_edges = edge_coll.query(
+                expr=f'triplet_index IN [{",".join(f"{e}" for e in sub.edges)}]',
+                output_fields=['head_id', 'tail_id', 'edge_type']
+            )
+            graph_edges = df.DataFrame(graph_edges)
+            graph_edges.drop(columns=['triplet_index'], inplace=True)
+            graph_edges['edge_type'] = graph_edges['edge_type'].str.split('|')
+
+            # Prepare lists for visualization
+            graph_dict["name"].append(sub.name)
+            graph_dict["nodes"].append([(
+                row.node_id,
+                {'hover': "Node Name : " + row.node_name + "\n" +
+                          "Node Type : " + row.node_type + "\n" +
+                          "Desc : " + row.desc,
+                 'click': '$hover',
+                 'color': row.color})
+                for row in getattr(graph_nodes,
+                                   "to_pandas",
+                                   lambda: graph_nodes)().itertuples(index=False)])
+            graph_dict["edges"].append([(
+                row.head_id,
+                row.tail_id,
+                {'label': tuple(row.edge_type)})
+                for row in getattr(graph_edges,
+                                   "to_pandas",
+                                   lambda: graph_edges)().itertuples(index=False)])
+
+            # Prepare the textualized subgraph
+            if sub.name == "Unified Subgraph":
+                graph_nodes = graph_nodes[['node_id', 'desc']]
+                graph_nodes.rename(columns={'desc': 'node_attr'}, inplace=True)
+                graph_edges = graph_edges[['head_id', 'edge_type', 'tail_id']]
+                graph_dict["text"] = (
+                    getattr(graph_nodes, "to_pandas", lambda: graph_nodes)().to_csv(index=False)
+                    + "\n"
+                    + getattr(graph_edges, "to_pandas", lambda: graph_edges)().to_csv(index=False)
+                )
+
+        return graph_dict
+
+    def normalize_vector(self,
+                         v: list) -> list:
+        """
+        Normalize a vector using CuPy (or NumPy when CuPy is unavailable).
+
+        Args:
+            v: Vector to normalize.
+
+        Returns:
+            Normalized vector.
+        """
+        v = py.asarray(v)
+        norm = py.linalg.norm(v)
+        return (v / norm).tolist()
+
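+    # _run is the method LangChain invokes when the tool is called; tool_call_id and
+    # state are injected by the framework rather than generated by the model.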
+    def _run(
+        self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        prompt: str,
+        arg_data: ArgumentData = None,
+    ) -> Command:
+        """
+        Run the subgraph extraction tool.
+
+        Args:
+            tool_call_id: The tool call ID for the tool.
+            state: Injected state for the tool.
+            prompt: The prompt to interact with the backend.
+            arg_data (ArgumentData): The argument data.
+
+        Returns:
+            Command: The command to be executed.
+        """
+        logger.log(logging.INFO, "Invoking subgraph_extraction tool")
+
+        # Load hydra configuration
+        with hydra.initialize(version_base=None, config_path="../configs"):
+            cfg = hydra.compose(
+                config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
+            )
+            cfg_db = cfg.app.frontend
+            cfg = cfg.tools.multimodal_subgraph_extraction
+
+        # Check if the Milvus connection exists
+        # logger.log(logging.INFO, "Checking Milvus connection")
+        # logger.log(logging.INFO, "Milvus connection name: %s", cfg_db.milvus_db.alias)
+        # logger.log(logging.INFO, "Milvus connection DB: %s", cfg_db.milvus_db.database_name)
+        # logger.log(logging.INFO, "Is connection established? %s",
+        #            connections.has_connection(cfg_db.milvus_db.alias))
+        # if connections.has_connection(cfg_db.milvus_db.alias):
+        #     logger.log(logging.INFO, "Milvus connection is established.")
+        #     for collection_name in utility.list_collections():
+        #         logger.log(logging.INFO, "Collection: %s", collection_name)
+
+        # Prepare the query embeddings and modalities
+        logger.log(logging.INFO, "_prepare_query_modalities")
+        # start = datetime.datetime.now()
+        query_df = self._prepare_query_modalities(
+            {"text": prompt,
+             "emb": [self.normalize_vector(
+                 state["embedding_model"].embed_query(prompt)
+             )]},
+            state,
+            cfg_db,
+        )
+        # end = datetime.datetime.now()
+        # logger.log(logging.INFO, "_prepare_query_modalities time: %s seconds",
+        #            (end - start).total_seconds())
+
+        # Perform subgraph extraction
+        logger.log(logging.INFO, "_perform_subgraph_extraction")
+        # start = datetime.datetime.now()
+        subgraphs = self._perform_subgraph_extraction(state,
+                                                      cfg,
+                                                      cfg_db,
+                                                      query_df)
+        # end = datetime.datetime.now()
+        # logger.log(logging.INFO, "_perform_subgraph_extraction time: %s seconds",
+        #            (end - start).total_seconds())
+
+        # Prepare the subgraph for visualization and the textualized graph
+        logger.log(logging.INFO, "_prepare_final_subgraph")
+        logger.log(logging.INFO, "Subgraphs extracted: %s", len(subgraphs))
+        # start = datetime.datetime.now()
+        final_subgraph = self._prepare_final_subgraph(state,
+                                                      subgraphs,
+                                                      cfg,
+                                                      cfg_db)
+        # end = datetime.datetime.now()
+        # logger.log(logging.INFO, "_prepare_final_subgraph time: %s seconds",
+        #            (end - start).total_seconds())
+
+        # Prepare the dictionary of the extracted graph
+        logger.log(logging.INFO, "dic_extracted_graph")
+        # start = datetime.datetime.now()
+        dic_extracted_graph = {
+            "name": arg_data.extraction_name,
+            "tool_call_id": tool_call_id,
+            "graph_source": state["dic_source_graph"][0]["name"],
+            "topk_nodes": state["topk_nodes"],
+            "topk_edges": state["topk_edges"],
+            "graph_dict": {
+                "name": final_subgraph["name"],
+                "nodes": final_subgraph["nodes"],
+                "edges": final_subgraph["edges"],
+            },
+            "graph_text": final_subgraph["text"],
+            "graph_summary": None,
+        }
+        # end = datetime.datetime.now()
+        # logger.log(logging.INFO, "dic_extracted_graph time: %s seconds",
+        #            (end - start).total_seconds())
+
+        # Prepare the dictionary of the updated state
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "dic_extracted_graph": [dic_extracted_graph],
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
+        # Return the updated state of the tool
+        return Command(
+            update=dic_updated_state_for_model | {
+                # update the message history
+                "messages": [
+                    ToolMessage(
+                        content=f"Subgraph Extraction Result of {arg_data.extraction_name}",
+                        tool_call_id=tool_call_id,
+                    )
+                ],
+            }
+        )
@@ -3,3 +3,4 @@ This file is used to import all the models in the package.
 '''
 from . import pcst
 from . import multimodal_pcst
+from . import milvus_multimodal_pcst