aiagents4pharma 1.39.5__py3-none-any.whl → 1.40.0__py3-none-any.whl

This diff shows the content changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -0,0 +1,509 @@
1
+ #!/usr/bin/env python3
2
+ # pylint: disable=wrong-import-position
3
+ """
4
+ Script to load PrimeKG multimodal data into Milvus database.
5
+ This script runs after the Milvus container is ready and loads the parquet data files.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import subprocess
11
+ import glob
12
+ import logging
13
+ from typing import Dict, Any, List
14
+
15
+ def install_packages():
16
+ """Install required packages."""
17
+ packages = [
18
+ "pip install --extra-index-url=https://pypi.nvidia.com cudf-cu12",
19
+ "pip install --extra-index-url=https://pypi.nvidia.com dask-cudf-cu12",
20
+ "pip install pymilvus==2.5.11",
21
+ "pip install numpy==1.26.4",
22
+ "pip install pandas==2.1.3",
23
+ "pip install tqdm==4.67.1",
24
+ ]
25
+
26
+ print("[DATA LOADER] Installing required packages...")
27
+ for package_cmd in packages:
28
+ print(f"[DATA LOADER] Running: {package_cmd}")
29
+ result = subprocess.run(package_cmd.split(), capture_output=True, text=True, check=False)
30
+ if result.returncode != 0:
31
+ print(f"[DATA LOADER] Error installing package: {result.stderr}")
32
+ sys.exit(1)
33
+ print("[DATA LOADER] All packages installed successfully!")
34
+
35
+ # Install packages first
36
+ install_packages()
37
+
38
+ try:
39
+ import cudf
40
+ import cupy as cp
41
+ except ImportError as e:
42
+ print("[DATA LOADER] cudf or cupy not found. Please ensure they are installed correctly.")
43
+ sys.exit(1)
44
+
45
+ from pymilvus import (
46
+ db,
47
+ connections,
48
+ FieldSchema,
49
+ CollectionSchema,
50
+ DataType,
51
+ Collection,
52
+ utility
53
+ )
54
+ from tqdm import tqdm
55
+
56
+ # Configure logging
57
+ logging.basicConfig(level=logging.INFO, format='[DATA LOADER] %(message)s')
58
+ logger = logging.getLogger(__name__)
59
+
60
+ class MilvusDataLoader:
61
+ """
62
+ Class to handle loading of BioBridge-PrimeKG multimodal data into Milvus.
63
+ """
64
+ def __init__(self, config: Dict[str, Any]):
65
+ """Initialize the MilvusDataLoader with configuration parameters."""
66
+ self.config = config
67
+ self.milvus_host = config.get('milvus_host', 'localhost')
68
+ self.milvus_port = config.get('milvus_port', '19530')
69
+ self.milvus_user = config.get('milvus_user', 'root')
70
+ self.milvus_password = config.get('milvus_password', 'Milvus')
71
+ self.milvus_database = config.get('milvus_database', 't2kg_primekg')
72
+ self.data_dir = config.get('data_dir',
73
+ 'tests/files/biobridge_multimodal/')
74
+ self.batch_size = config.get('batch_size', 500)
75
+ self.chunk_size = config.get('chunk_size', 5)
76
+
77
+ def normalize_matrix(self, m, axis=1):
78
+ """Normalize each row of a 2D matrix using CuPy."""
79
+ norms = cp.linalg.norm(m, axis=axis, keepdims=True)
80
+ return m / norms
81
+
82
+ def normalize_vector(self, v):
83
+ """Normalize a vector using CuPy."""
84
+ v = cp.asarray(v)
85
+ norm = cp.linalg.norm(v)
86
+ return v / norm
87
+
88
+ def connect_to_milvus(self):
89
+ """Connect to Milvus and setup database."""
90
+ logger.info("Connecting to Milvus at %s:%s", self.milvus_host, self.milvus_port)
91
+
92
+ connections.connect(
93
+ alias="default",
94
+ host=self.milvus_host,
95
+ port=self.milvus_port,
96
+ user=self.milvus_user,
97
+ password=self.milvus_password
98
+ )
99
+
100
+ # Check if database exists, create if it doesn't
101
+ if self.milvus_database not in db.list_database():
102
+ logger.info("Creating database: %s", self.milvus_database)
103
+ db.create_database(self.milvus_database)
104
+
105
+ # Switch to the desired database
106
+ db.using_database(self.milvus_database)
107
+ logger.info("Using database: %s", self.milvus_database)
108
+
109
+ def load_graph_data(self):
110
+ """Load the pickle file containing graph data."""
111
+ logger.info("Loading graph data from: %s", self.data_dir)
112
+
113
+ if not os.path.exists(self.data_dir):
114
+ raise FileNotFoundError(f"Data directory not found: {self.data_dir}")
115
+
116
+ # Load dataframes containing nodes and edges
117
+ # Loop over nodes and edges
118
+ graph = {}
119
+ for element in ["nodes", "edges"]:
120
+ # Make an empty dictionary for each folder
121
+ graph[element] = {}
122
+ for stage in ["enrichment", "embedding"]:
123
+ logger.info("Loading %s %s data", element, stage)
124
+ # Create the file pattern for the current subfolder
125
+ file_list = glob.glob(os.path.join(self.data_dir,
126
+ element,
127
+ stage,
128
+ '*.parquet.gzip'))
129
+ logger.info("Found %d parquet files", len(file_list))
130
+ # Read and concatenate all dataframes in the folder
131
+ # Except the edges embedding, which is too large to read in one go
132
+ # We are using a chunk size to read the edges embedding in smaller parts instead
133
+ if element == "edges" and stage == "embedding":
134
+ # For edges embedding, only read two columns:
135
+ # triplet_index and edge_emb
136
+ # Loop by chunks
137
+ chunk_size = self.chunk_size
138
+ graph[element][stage] = []
139
+ for i in range(0, len(file_list), chunk_size):
140
+ chunk_files = file_list[i:i+chunk_size]
141
+ chunk_df = cudf.concat([
142
+ cudf.read_parquet(f, columns=["triplet_index", "edge_emb"])
143
+ for f in chunk_files
144
+ ], ignore_index=True)
145
+ graph[element][stage].append(chunk_df)
146
+ else:
147
+ # For nodes and edges enrichment,
148
+ # read and concatenate all dataframes in the folder
149
+ # This includes the nodes embedding,
150
+ # which is small enough to read in one go
151
+ graph[element][stage] = cudf.concat([
152
+ cudf.read_parquet(f) for f in file_list
153
+ ], ignore_index=True)
154
+
155
+ logger.info("Graph data loaded successfully")
156
+ return graph
157
+
158
+ def create_nodes_collection(self, nodes_df: cudf.DataFrame):
159
+ """Create and populate the main nodes collection."""
160
+ logger.info("Creating main nodes collection...")
161
+ node_coll_name = f"{self.milvus_database}_nodes"
162
+
163
+ node_fields = [
164
+ FieldSchema(name="node_index",
165
+ dtype=DataType.INT64,
166
+ is_primary=True),
167
+ FieldSchema(name="node_id",
168
+ dtype=DataType.VARCHAR,
169
+ max_length=1024),
170
+ FieldSchema(name="node_name",
171
+ dtype=DataType.VARCHAR,
172
+ max_length=1024,
173
+ enable_analyzer=True,
174
+ enable_match=True),
175
+ FieldSchema(name="node_type",
176
+ dtype=DataType.VARCHAR,
177
+ max_length=1024,
178
+ enable_analyzer=True,
179
+ enable_match=True),
180
+ FieldSchema(name="desc",
181
+ dtype=DataType.VARCHAR,
182
+ max_length=40960,
183
+ enable_analyzer=True,
184
+ enable_match=True),
185
+ FieldSchema(name="desc_emb",
186
+ dtype=DataType.FLOAT_VECTOR,
187
+ dim=len(nodes_df.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])),
188
+ ]
189
+ schema = CollectionSchema(fields=node_fields,
190
+ description=f"Schema for collection {node_coll_name}")
191
+
192
+ # Create collection if it doesn't exist
193
+ if not utility.has_collection(node_coll_name):
194
+ collection = Collection(name=node_coll_name, schema=schema)
195
+ else:
196
+ collection = Collection(name=node_coll_name)
197
+
198
+ # Create indexes
199
+ collection.create_index(field_name="node_index",
200
+ index_params={"index_type": "STL_SORT"},
201
+ index_name="node_index_index")
202
+ collection.create_index(field_name="node_name",
203
+ index_params={"index_type": "INVERTED"},
204
+ index_name="node_name_index")
205
+ collection.create_index(field_name="node_type",
206
+ index_params={"index_type": "INVERTED"},
207
+ index_name="node_type_index")
208
+ collection.create_index(field_name="desc",
209
+ index_params={"index_type": "INVERTED"},
210
+ index_name="desc_index")
211
+ collection.create_index(field_name="desc_emb",
212
+ index_params={"index_type": "GPU_CAGRA",
213
+ "metric_type": "IP"},
214
+ index_name="desc_emb_index")
215
+
216
+ # Prepare and insert data
217
+ desc_emb_norm = cp.asarray(nodes_df["desc_emb"].list.leaves).astype(cp.float32).\
218
+ reshape(nodes_df.shape[0], -1)
219
+ desc_emb_norm = self.normalize_matrix(desc_emb_norm, axis=1)
220
+ data = [
221
+ nodes_df["node_index"].to_arrow().to_pylist(),
222
+ nodes_df["node_id"].to_arrow().to_pylist(),
223
+ nodes_df["node_name"].to_arrow().to_pylist(),
224
+ nodes_df["node_type"].to_arrow().to_pylist(),
225
+ nodes_df["desc"].to_arrow().to_pylist(),
226
+ desc_emb_norm.tolist(), # Use normalized embeddings
227
+ ]
228
+
229
+ # Insert data in batches
230
+ total = len(data[0])
231
+ for i in tqdm(range(0, total, self.batch_size), desc="Inserting nodes"):
232
+ batch = [col[i:i+self.batch_size] for col in data]
233
+ collection.insert(batch)
234
+
235
+ collection.flush()
236
+ logger.info("Nodes collection created with %d entities", collection.num_entities)
237
+
238
+ def create_node_type_collections(self, nodes_df: cudf.DataFrame):
239
+ """Create separate collections for each node type."""
240
+ logger.info("Creating node type-specific collections...")
241
+
242
+ for node_type, nodes_df_ in tqdm(nodes_df.groupby('node_type'),
243
+ desc="Processing node types"):
244
+ node_coll_name = f"{self.milvus_database}_nodes_{node_type.replace('/', '_')}"
245
+
246
+ node_fields = [
247
+ FieldSchema(name="node_index",
248
+ dtype=DataType.INT64,
249
+ is_primary=True,
250
+ auto_id=False),
251
+ FieldSchema(name="node_id",
252
+ dtype=DataType.VARCHAR,
253
+ max_length=1024),
254
+ FieldSchema(name="node_name",
255
+ dtype=DataType.VARCHAR,
256
+ max_length=1024,
257
+ enable_analyzer=True,
258
+ enable_match=True),
259
+ FieldSchema(name="node_type",
260
+ dtype=DataType.VARCHAR,
261
+ max_length=1024,
262
+ enable_analyzer=True,
263
+ enable_match=True),
264
+ FieldSchema(name="desc",
265
+ dtype=DataType.VARCHAR,
266
+ max_length=40960,
267
+ enable_analyzer=True,
268
+ enable_match=True),
269
+ FieldSchema(name="desc_emb",
270
+ dtype=DataType.FLOAT_VECTOR,
271
+ dim=len(nodes_df_.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])),
272
+ FieldSchema(name="feat",
273
+ dtype=DataType.VARCHAR,
274
+ max_length=40960,
275
+ enable_analyzer=True,
276
+ enable_match=True),
277
+ FieldSchema(name="feat_emb",
278
+ dtype=DataType.FLOAT_VECTOR,
279
+ dim=len(nodes_df_.iloc[0]['feat_emb'].to_arrow().to_pylist()[0])),
280
+ ]
281
+ schema = CollectionSchema(fields=node_fields,
282
+ description=f"schema for collection {node_coll_name}")
283
+
284
+ if not utility.has_collection(node_coll_name):
285
+ collection = Collection(name=node_coll_name, schema=schema)
286
+ else:
287
+ collection = Collection(name=node_coll_name)
288
+
289
+ # Create indexes
290
+ collection.create_index(field_name="node_index",
291
+ index_params={"index_type": "STL_SORT"},
292
+ index_name="node_index_index")
293
+ collection.create_index(field_name="node_name",
294
+ index_params={"index_type": "INVERTED"},
295
+ index_name="node_name_index")
296
+ collection.create_index(field_name="node_type",
297
+ index_params={"index_type": "INVERTED"},
298
+ index_name="node_type_index")
299
+ collection.create_index(field_name="desc",
300
+ index_params={"index_type": "INVERTED"},
301
+ index_name="desc_index")
302
+ collection.create_index(field_name="desc_emb",
303
+ index_params={"index_type": "GPU_CAGRA",
304
+ "metric_type": "IP"},
305
+ index_name="desc_emb_index")
306
+ collection.create_index(field_name="feat_emb",
307
+ index_params={"index_type": "GPU_CAGRA",
308
+ "metric_type": "IP"},
309
+ index_name="feat_emb_index")
310
+
311
+ # Prepare data
312
+ desc_emb_norm = cp.asarray(nodes_df_["desc_emb"].list.leaves).astype(cp.float32).\
313
+ reshape(nodes_df_.shape[0], -1)
314
+ desc_emb_norm = self.normalize_matrix(desc_emb_norm, axis=1)
315
+ feat_emb_norm = cp.asarray(nodes_df_["feat_emb"].list.leaves).astype(cp.float32).\
316
+ reshape(nodes_df_.shape[0], -1)
317
+ feat_emb_norm = self.normalize_matrix(feat_emb_norm, axis=1)
318
+ data = [
319
+ nodes_df_["node_index"].to_arrow().to_pylist(),
320
+ nodes_df_["node_id"].to_arrow().to_pylist(),
321
+ nodes_df_["node_name"].to_arrow().to_pylist(),
322
+ nodes_df_["node_type"].to_arrow().to_pylist(),
323
+ nodes_df_["desc"].to_arrow().to_pylist(),
324
+ desc_emb_norm.tolist(), # Use normalized embeddings
325
+ nodes_df_["feat"].to_arrow().to_pylist(),
326
+ feat_emb_norm.tolist(), # Use normalized embeddings
327
+ ]
328
+
329
+ # Insert data in batches
330
+ total_rows = len(data[0])
331
+ for i in range(0, total_rows, self.batch_size):
332
+ batch = [col[i:i + self.batch_size] for col in data]
333
+ collection.insert(batch)
334
+
335
+ collection.flush()
336
+ logger.info("Collection %s created with %d entities",
337
+ node_coll_name, collection.num_entities)
338
+
339
+ def create_edges_collection(self,
340
+ edges_enrichment_df: cudf.DataFrame,
341
+ edges_embedding_df: List[cudf.DataFrame]):
342
+ """Create and populate the edges collection."""
343
+ logger.info("Creating edges collection...")
344
+
345
+ edge_coll_name = f"{self.milvus_database}_edges"
346
+
347
+ edge_fields = [
348
+ FieldSchema(name="triplet_index",
349
+ dtype=DataType.INT64,
350
+ is_primary=True,
351
+ auto_id=False),
352
+ FieldSchema(name="head_id",
353
+ dtype=DataType.VARCHAR,
354
+ max_length=1024),
355
+ FieldSchema(name="head_index",
356
+ dtype=DataType.INT64),
357
+ FieldSchema(name="tail_id",
358
+ dtype=DataType.VARCHAR,
359
+ max_length=1024),
360
+ FieldSchema(name="tail_index",
361
+ dtype=DataType.INT64),
362
+ FieldSchema(name="edge_type",
363
+ dtype=DataType.VARCHAR,
364
+ max_length=1024),
365
+ FieldSchema(name="display_relation",
366
+ dtype=DataType.VARCHAR,
367
+ max_length=1024),
368
+ FieldSchema(name="feat",
369
+ dtype=DataType.VARCHAR,
370
+ max_length=40960),
371
+ FieldSchema(name="feat_emb",
372
+ dtype=DataType.FLOAT_VECTOR,
373
+ dim=len(edges_embedding_df[0].loc[0, 'edge_emb'])),
374
+ ]
375
+ edge_schema = CollectionSchema(fields=edge_fields,
376
+ description="Schema for edges collection")
377
+
378
+ if not utility.has_collection(edge_coll_name):
379
+ collection = Collection(name=edge_coll_name, schema=edge_schema)
380
+ else:
381
+ collection = Collection(name=edge_coll_name)
382
+
383
+ # Create indexes
384
+ collection.create_index(field_name="triplet_index",
385
+ index_params={"index_type": "STL_SORT"},
386
+ index_name="triplet_index_index")
387
+ collection.create_index(field_name="head_index",
388
+ index_params={"index_type": "STL_SORT"},
389
+ index_name="head_index_index")
390
+ collection.create_index(field_name="tail_index",
391
+ index_params={"index_type": "STL_SORT"},
392
+ index_name="tail_index_index")
393
+ collection.create_index(field_name="feat_emb",
394
+ index_params={"index_type": "GPU_CAGRA",
395
+ "metric_type": "IP"},
396
+ index_name="feat_emb_index")
397
+
398
+ # Iterate over chunked edges embedding df
399
+ for edges_df in tqdm(edges_embedding_df):
400
+ # Merge enrichment with embedding
401
+ merged_edges_df = edges_enrichment_df.merge(
402
+ edges_df[["triplet_index", "edge_emb"]],
403
+ on="triplet_index",
404
+ how="inner"
405
+ )
406
+
407
+ # Prepare data
408
+ edge_emb_cp = cp.asarray(merged_edges_df["edge_emb"].list.leaves).astype(cp.float32).\
409
+ reshape(merged_edges_df.shape[0], -1)
410
+ edge_emb_norm = self.normalize_matrix(edge_emb_cp, axis=1)
411
+ data = [
412
+ merged_edges_df["triplet_index"].to_arrow().to_pylist(),
413
+ merged_edges_df["head_id"].to_arrow().to_pylist(),
414
+ merged_edges_df["head_index"].to_arrow().to_pylist(),
415
+ merged_edges_df["tail_id"].to_arrow().to_pylist(),
416
+ merged_edges_df["tail_index"].to_arrow().to_pylist(),
417
+ merged_edges_df["edge_type_str"].to_arrow().to_pylist(),
418
+ merged_edges_df["display_relation"].to_arrow().to_pylist(),
419
+ merged_edges_df["feat"].to_arrow().to_pylist(),
420
+ edge_emb_norm.tolist(), # Use normalized embeddings
421
+ ]
422
+
423
+ # Insert data in batches
424
+ total = len(data[0])
425
+ for i in tqdm(range(0, total, self.batch_size), desc="Inserting edges"):
426
+ batch_data = [d[i:i+self.batch_size] for d in data]
427
+ collection.insert(batch_data)
428
+
429
+ collection.flush()
430
+ logger.info("Edges collection created with %d entities", collection.num_entities)
431
+
432
+ def run(self):
433
+ """Main execution method."""
434
+ try:
435
+ logger.info("Starting Milvus data loading process...")
436
+
437
+ # Connect to Milvus
438
+ self.connect_to_milvus()
439
+
440
+ # Load graph data
441
+ graph = self.load_graph_data()
442
+
443
+ # Prepare data
444
+ logger.info("Data Preparation started...")
445
+ # Get nodes enrichment and embedding dataframes
446
+ nodes_enrichment_df = graph['nodes']['enrichment']
447
+ nodes_embedding_df = graph['nodes']['embedding']
448
+
449
+ # Get edges enrichment and embedding dataframes
450
+ edges_enrichment_df = graph['edges']['enrichment']
451
+ # Note: this is a list of dataframes (the edges embedding is loaded in chunks)
452
+ edges_embedding_df = graph['edges']['embedding']
453
+
454
+ # For nodes, we can directly merge enrichment and embedding
455
+ # Merge nodes enrichment and embedding dataframes
456
+ merged_nodes_df = nodes_enrichment_df.merge(
457
+ nodes_embedding_df[["node_id", "desc_emb", "feat_emb"]],
458
+ on="node_id",
459
+ how="left"
460
+ )
461
+
462
+ # Create collections and load data
463
+ self.create_nodes_collection(merged_nodes_df)
464
+ self.create_node_type_collections(merged_nodes_df)
465
+ self.create_edges_collection(edges_enrichment_df,
466
+ edges_embedding_df)
467
+
468
+ # List all collections for verification
469
+ logger.info("Data loading completed successfully!")
470
+ logger.info("Created collections:")
471
+ for coll in utility.list_collections():
472
+ collection = Collection(name=coll)
473
+ logger.info(" %s: %d entities", coll, collection.num_entities)
474
+
475
+ except Exception as e:
476
+ logger.error("Error during data loading: %s", str(e))
477
+ raise
478
+
479
+
480
+ def main():
481
+ """Main function to run the data loader."""
482
+ # Resolve the fallback data path relative to this script's location
483
+ script_dir = os.path.dirname(os.path.abspath(__file__))
484
+ default_data_dir = os.path.join(script_dir, "tests/files/biobridge_multimodal/")
485
+
486
+ # Configuration
487
+ config = {
488
+ 'milvus_host': os.getenv('MILVUS_HOST', 'localhost'),
489
+ 'milvus_port': os.getenv('MILVUS_PORT', '19530'),
490
+ 'milvus_user': os.getenv('MILVUS_USER', 'root'),
491
+ 'milvus_password': os.getenv('MILVUS_PASSWORD', 'Milvus'),
492
+ 'milvus_database': os.getenv('MILVUS_DATABASE', 't2kg_primekg'),
493
+ 'data_dir': os.getenv('DATA_DIR', default_data_dir),
494
+ 'batch_size': int(os.getenv('BATCH_SIZE', '500')),
495
+ 'chunk_size': int(os.getenv('CHUNK_SIZE', '5')),
496
+ }
497
+
498
+ # Print configuration for debugging
499
+ print("[DATA LOADER] Configuration:")
500
+ for key, value in config.items():
501
+ print(f"[DATA LOADER] {key}: {value}")
502
+
503
+ # Create and run data loader
504
+ loader = MilvusDataLoader(config)
505
+ loader.run()
506
+
507
+
508
+ if __name__ == "__main__":
509
+ main()
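
A note on the vector fields created above: every embedding column is row-normalized with normalize_matrix before insertion, and the GPU_CAGRA indexes are built with metric_type "IP". For unit-length vectors, inner-product ranking is identical to cosine-similarity ranking, so query vectors should be normalized the same way. Below is a minimal sketch of how a client could query the resulting nodes collection; it assumes the loader's defaults (localhost:19530, database t2kg_primekg, collection t2kg_primekg_nodes) and uses a placeholder 768-dimensional query vector, the real dimension being whatever the desc_emb column carries.

# Minimal sketch, assuming the loader's default connection settings and
# collection naming; the 768-dim query vector is a placeholder.
import numpy as np
from pymilvus import connections, db, Collection

connections.connect(alias="default", host="localhost", port="19530",
                    user="root", password="Milvus")
db.using_database("t2kg_primekg")

collection = Collection(name="t2kg_primekg_nodes")
collection.load()

query = np.random.rand(768).astype(np.float32)   # placeholder embedding
query = query / np.linalg.norm(query)            # normalize so IP behaves like cosine

hits = collection.search(
    data=[query.tolist()],
    anns_field="desc_emb",
    param={"metric_type": "IP"},
    limit=5,
    output_fields=["node_name", "node_type"],
)
for hit in hits[0]:
    print(hit.entity.get("node_name"), hit.distance)
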
@@ -1,24 +1,23 @@
1
1
  """
2
2
  Test cases for agents/t2kg_agent.py
3
3
  """
4
-
4
+ from unittest.mock import patch, MagicMock
5
5
  import pytest
6
6
  from langchain_core.messages import HumanMessage
7
7
  from langchain_openai import ChatOpenAI, OpenAIEmbeddings
8
+ import pandas as pd
8
9
  from ..agents.t2kg_agent import get_app
9
10
 
10
- # Define the data path
11
11
  DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
12
12
 
13
-
14
13
  @pytest.fixture(name="input_dict")
15
14
  def input_dict_fixture():
16
15
  """
17
16
  Input dictionary fixture.
18
17
  """
19
18
  input_dict = {
20
- "llm_model": None, # TBA for each test case
21
- "embedding_model": None, # TBA for each test case
19
+ "llm_model": None,
20
+ "embedding_model": None,
22
21
  "selections": {
23
22
  "gene/protein": [],
24
23
  "molecular_function": [],
@@ -47,30 +46,62 @@ def input_dict_fixture():
47
46
  ],
48
47
  "dic_extracted_graph": []
49
48
  }
50
-
51
49
  return input_dict
52
50
 
51
+ def mock_milvus_collection(name):
52
+ """
53
+ Mock Milvus collection for testing.
54
+ """
55
+ nodes = MagicMock()
56
+ nodes.query.return_value = [
57
+ {"node_index": 0,
58
+ "node_id": "id1",
59
+ "node_name": "Adalimumab",
60
+ "node_type": "drug",
61
+ "feat": "featA", "feat_emb": [0.1, 0.2, 0.3],
62
+ "desc": "descA", "desc_emb": [0.1, 0.2, 0.3]},
63
+ {"node_index": 1,
64
+ "node_id": "id2",
65
+ "node_name": "TNF",
66
+ "node_type": "gene/protein",
67
+ "feat": "featB", "feat_emb": [0.4, 0.5, 0.6],
68
+ "desc": "descB", "desc_emb": [0.4, 0.5, 0.6]}
69
+ ]
70
+ nodes.load.return_value = None
71
+
72
+ edges = MagicMock()
73
+ edges.query.return_value = [
74
+ {"triplet_index": 0,
75
+ "head_id": "id1",
76
+ "head_index": 0,
77
+ "tail_id": "id2",
78
+ "tail_index": 1,
79
+ "edge_type": "drug,acts_on,gene/protein",
80
+ "display_relation": "acts_on",
81
+ "feat": "featC",
82
+ "feat_emb": [0.7, 0.8, 0.9]}
83
+ ]
84
+ edges.load.return_value = None
53
85
 
54
- def test_t2kg_agent_openai(input_dict):
86
+ if "nodes" in name:
87
+ return nodes
88
+ if "edges" in name:
89
+ return edges
90
+ return None
91
+
92
+ def test_t2kg_agent_openai_milvus_mock(input_dict):
55
93
  """
56
- Test the T2KG agent using OpenAI model.
94
+ Test the T2KG agent using OpenAI model and Milvus mock.
57
95
 
58
96
  Args:
59
97
  input_dict: Input dictionary
60
98
  """
61
- # Prepare LLM and embedding model
62
99
  input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
63
100
  input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
64
-
65
- # Setup the app
66
101
  unique_id = 12345
67
102
  app = get_app(unique_id, llm_model=input_dict["llm_model"])
68
103
  config = {"configurable": {"thread_id": unique_id}}
69
- # Update state
70
- app.update_state(
71
- config,
72
- input_dict,
73
- )
104
+ app.update_state(config, input_dict)
74
105
  prompt = """
75
106
  Adalimumab is a fully human monoclonal antibody (IgG1)
76
107
  that specifically binds to tumor necrosis factor-alpha (TNF-α), a pro-inflammatory cytokine.
@@ -85,14 +116,43 @@ def test_t2kg_agent_openai(input_dict):
85
116
  Please set the extraction name for the extraction process as `subkg_12345`.
86
117
  """
87
118
 
88
- # Test the tool get_modelinfo
89
- response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
119
+ with patch("aiagents4pharma.talk2knowledgegraphs.tools."
120
+ "milvus_multimodal_subgraph_extraction.Collection",
121
+ side_effect=mock_milvus_collection), \
122
+ patch("aiagents4pharma.talk2knowledgegraphs.tools."
123
+ "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning") as mock_pcst, \
124
+ patch("pymilvus.connections") as mock_connections, \
125
+ patch("aiagents4pharma.talk2knowledgegraphs.tools."
126
+ "milvus_multimodal_subgraph_extraction.hydra.initialize"), \
127
+ patch("aiagents4pharma.talk2knowledgegraphs.tools."
128
+ "milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
129
+ mock_connections.has_connection.return_value = True
130
+ mock_pcst_instance = MagicMock()
131
+ mock_pcst_instance.extract_subgraph.return_value = {
132
+ "nodes": pd.Series([0, 1]),
133
+ "edges": pd.Series([0])
134
+ }
135
+ mock_pcst.return_value = mock_pcst_instance
136
+ mock_cfg = MagicMock()
137
+ mock_cfg.cost_e = 1.0
138
+ mock_cfg.c_const = 1.0
139
+ mock_cfg.root = 0
140
+ mock_cfg.num_clusters = 1
141
+ mock_cfg.pruning = True
142
+ mock_cfg.verbosity_level = 0
143
+ mock_cfg.search_metric_type = "L2"
144
+ mock_cfg.node_colors_dict = {"drug": "blue", "gene/protein": "red"}
145
+ mock_compose.return_value = MagicMock()
146
+ mock_compose.return_value.tools.multimodal_subgraph_extraction = mock_cfg
147
+ mock_compose.return_value.tools.subgraph_summarization.\
148
+ prompt_subgraph_summarization = (
149
+ "Summarize the following subgraph: {textualized_subgraph}"
150
+ )
151
+
152
+ response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
90
153
 
91
- # Check assistant message
92
154
  assistant_msg = response["messages"][-1].content
93
155
  assert isinstance(assistant_msg, str)
94
-
95
- # Check extracted subgraph dictionary
96
156
  current_state = app.get_state(config)
97
157
  dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
98
158
  assert isinstance(dic_extracted_graph, dict)
@@ -104,8 +164,10 @@ def test_t2kg_agent_openai(input_dict):
104
164
  assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
105
165
  assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
106
166
  assert isinstance(dic_extracted_graph["graph_text"], str)
107
- # Check summarized subgraph
108
167
  assert isinstance(dic_extracted_graph["graph_summary"], str)
109
- # Check reasoning results
110
168
  assert "Adalimumab" in assistant_msg
111
169
  assert "TNF" in assistant_msg
170
+
171
+ # Also exercise the unknown-collection branch of the mock helper
172
+ result = mock_milvus_collection("unknown")
173
+ assert result is None
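
The mocking pattern above hinges on patch(..., side_effect=mock_milvus_collection): every Collection(name) call made by the subgraph-extraction tool is routed to the node fake, the edge fake, or None depending on the collection name, so the test never needs a running Milvus instance. A minimal, self-contained sketch of the same pattern, with purely illustrative names, looks like this:

# Minimal sketch of the side_effect dispatch pattern used above,
# independent of the aiagents4pharma module paths (names are illustrative).
from unittest.mock import MagicMock

def fake_collection(name):
    """Return a node fake for *nodes* collections, None otherwise."""
    fake = MagicMock()
    fake.query.return_value = [{"node_index": 0, "node_name": "Adalimumab"}]
    return fake if "nodes" in name else None

collection_cls = MagicMock(side_effect=fake_collection)

nodes = collection_cls("t2kg_primekg_nodes")   # routed to the node fake
assert nodes.query()[0]["node_name"] == "Adalimumab"
assert collection_cls("unknown") is None       # unmatched names return None
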