aiagents4pharma 1.39.4__py3-none-any.whl → 1.40.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2aiagents4pharma/configs/agents/main_agent/default.yaml +26 -13
- aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +83 -3
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +4 -1
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +36 -5
- aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +509 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +85 -23
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +413 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ols.py +10 -10
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +175 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py +11 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +509 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ols_terms.py +15 -7
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py +31 -9
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +393 -0
- aiagents4pharma/talk2knowledgegraphs/utils/pubchem_utils.py +33 -2
- {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/METADATA +13 -14
- {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/RECORD +22 -17
- {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/top_level.txt +0 -0
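The headline change in 1.40.0 is Milvus support for Talk2KnowledgeGraphs: a standalone data-dump script that loads the BioBridge-PrimeKG parquet files into Milvus, a Milvus-backed multimodal subgraph-extraction tool, and matching tests. The loader script is configured entirely through environment variables (see its `main()` in the diff below), so a minimal sketch of invoking it looks like the following; the variable values are placeholders, while the variable names and the script path come from this release:

```python
import os
import subprocess

# Placeholder values; the variable names match those read in the script's main().
env = {
    **os.environ,
    "MILVUS_HOST": "localhost",
    "MILVUS_PORT": "19530",
    "MILVUS_DATABASE": "t2kg_primekg",
    "DATA_DIR": "/data/biobridge_multimodal/",
    "BATCH_SIZE": "500",
    "CHUNK_SIZE": "5",
}
subprocess.run(
    ["python", "aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py"],
    env=env,
    check=True,
)
```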
aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py (new file)
@@ -0,0 +1,509 @@
+#!/usr/bin/env python3
+# pylint: disable=wrong-import-position
+"""
+Script to load PrimeKG multimodal data into a Milvus database.
+This script runs after the Milvus container is ready and loads the parquet data files.
+"""
+
+import os
+import sys
+import subprocess
+import glob
+import logging
+from typing import Dict, Any, List
+
+def install_packages():
+    """Install required packages."""
+    packages = [
+        "pip install --extra-index-url=https://pypi.nvidia.com cudf-cu12",
+        "pip install --extra-index-url=https://pypi.nvidia.com dask-cudf-cu12",
+        "pip install pymilvus==2.5.11",
+        "pip install numpy==1.26.4",
+        "pip install pandas==2.1.3",
+        "pip install tqdm==4.67.1",
+    ]
+
+    print("[DATA LOADER] Installing required packages...")
+    for package_cmd in packages:
+        print(f"[DATA LOADER] Running: {package_cmd}")
+        result = subprocess.run(package_cmd.split(), capture_output=True, text=True, check=False)
+        if result.returncode != 0:
+            print(f"[DATA LOADER] Error installing package: {result.stderr}")
+            sys.exit(1)
+    print("[DATA LOADER] All packages installed successfully!")
+
+# Install packages first
+install_packages()
+
+try:
+    import cudf
+    import cupy as cp
+except ImportError as e:
+    print(f"[DATA LOADER] cudf or cupy not found: {e}. Please ensure they are installed correctly.")
+    sys.exit(1)
+
+from pymilvus import (
+    db,
+    connections,
+    FieldSchema,
+    CollectionSchema,
+    DataType,
+    Collection,
+    utility
+)
+from tqdm import tqdm
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='[DATA LOADER] %(message)s')
+logger = logging.getLogger(__name__)
+
+class MilvusDataLoader:
+    """
+    Class to handle loading of BioBridge-PrimeKG multimodal data into Milvus.
+    """
+    def __init__(self, config: Dict[str, Any]):
+        """Initialize the MilvusDataLoader with configuration parameters."""
+        self.config = config
+        self.milvus_host = config.get('milvus_host', 'localhost')
+        self.milvus_port = config.get('milvus_port', '19530')
+        self.milvus_user = config.get('milvus_user', 'root')
+        self.milvus_password = config.get('milvus_password', 'Milvus')
+        self.milvus_database = config.get('milvus_database', 't2kg_primekg')
+        self.data_dir = config.get('data_dir',
+                                   'tests/files/biobridge_multimodal/')
+        self.batch_size = config.get('batch_size', 500)
+        self.chunk_size = config.get('chunk_size', 5)
+
+    def normalize_matrix(self, m, axis=1):
+        """Normalize each row of a 2D matrix using CuPy."""
+        norms = cp.linalg.norm(m, axis=axis, keepdims=True)
+        return m / norms
+
+    def normalize_vector(self, v):
+        """Normalize a vector using CuPy."""
+        v = cp.asarray(v)
+        norm = cp.linalg.norm(v)
+        return v / norm
+
+    def connect_to_milvus(self):
+        """Connect to Milvus and set up the database."""
+        logger.info("Connecting to Milvus at %s:%s", self.milvus_host, self.milvus_port)
+
+        connections.connect(
+            alias="default",
+            host=self.milvus_host,
+            port=self.milvus_port,
+            user=self.milvus_user,
+            password=self.milvus_password
+        )
+
+        # Check if database exists, create if it doesn't
+        if self.milvus_database not in db.list_database():
+            logger.info("Creating database: %s", self.milvus_database)
+            db.create_database(self.milvus_database)
+
+        # Switch to the desired database
+        db.using_database(self.milvus_database)
+        logger.info("Using database: %s", self.milvus_database)
+
+    def load_graph_data(self):
+        """Load the parquet files containing graph data."""
+        logger.info("Loading graph data from: %s", self.data_dir)
+
+        if not os.path.exists(self.data_dir):
+            raise FileNotFoundError(f"Data directory not found: {self.data_dir}")
+
+        # Load dataframes containing nodes and edges
+        # Loop over nodes and edges
+        graph = {}
+        for element in ["nodes", "edges"]:
+            # Make an empty dictionary for each folder
+            graph[element] = {}
+            for stage in ["enrichment", "embedding"]:
+                logger.info("Processing %s %s", element, stage)
+                # Create the file pattern for the current subfolder
+                file_list = glob.glob(os.path.join(self.data_dir,
+                                                   element,
+                                                   stage,
+                                                   '*.parquet.gzip'))
+                logger.info("Found files: %s", file_list)
+                # Read and concatenate all dataframes in the folder,
+                # except the edges embedding, which is too large to read in one go.
+                # We use a chunk size to read the edges embedding in smaller parts instead.
+                if element == "edges" and stage == "embedding":
+                    # For edges embedding, only read two columns:
+                    # triplet_index and edge_emb
+                    # Loop by chunks
+                    chunk_size = self.chunk_size
+                    graph[element][stage] = []
+                    for i in range(0, len(file_list), chunk_size):
+                        chunk_files = file_list[i:i+chunk_size]
+                        chunk_df = cudf.concat([
+                            cudf.read_parquet(f, columns=["triplet_index", "edge_emb"])
+                            for f in chunk_files
+                        ], ignore_index=True)
+                        graph[element][stage].append(chunk_df)
+                else:
+                    # For nodes and edges enrichment,
+                    # read and concatenate all dataframes in the folder.
+                    # This includes the nodes embedding,
+                    # which is small enough to read in one go.
+                    graph[element][stage] = cudf.concat([
+                        cudf.read_parquet(f) for f in file_list
+                    ], ignore_index=True)
+
+        logger.info("Graph data loaded successfully")
+        return graph
+
+    def create_nodes_collection(self, nodes_df: cudf.DataFrame):
+        """Create and populate the main nodes collection."""
+        logger.info("Creating main nodes collection...")
+        node_coll_name = f"{self.milvus_database}_nodes"
+
+        node_fields = [
+            FieldSchema(name="node_index",
+                        dtype=DataType.INT64,
+                        is_primary=True),
+            FieldSchema(name="node_id",
+                        dtype=DataType.VARCHAR,
+                        max_length=1024),
+            FieldSchema(name="node_name",
+                        dtype=DataType.VARCHAR,
+                        max_length=1024,
+                        enable_analyzer=True,
+                        enable_match=True),
+            FieldSchema(name="node_type",
+                        dtype=DataType.VARCHAR,
+                        max_length=1024,
+                        enable_analyzer=True,
+                        enable_match=True),
+            FieldSchema(name="desc",
+                        dtype=DataType.VARCHAR,
+                        max_length=40960,
+                        enable_analyzer=True,
+                        enable_match=True),
+            FieldSchema(name="desc_emb",
+                        dtype=DataType.FLOAT_VECTOR,
+                        dim=len(nodes_df.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])),
+        ]
+        schema = CollectionSchema(fields=node_fields,
+                                  description=f"Schema for collection {node_coll_name}")
+
+        # Create collection if it doesn't exist
+        if not utility.has_collection(node_coll_name):
+            collection = Collection(name=node_coll_name, schema=schema)
+        else:
+            collection = Collection(name=node_coll_name)
+
+        # Create indexes
+        collection.create_index(field_name="node_index",
+                                index_params={"index_type": "STL_SORT"},
+                                index_name="node_index_index")
+        collection.create_index(field_name="node_name",
+                                index_params={"index_type": "INVERTED"},
+                                index_name="node_name_index")
+        collection.create_index(field_name="node_type",
+                                index_params={"index_type": "INVERTED"},
+                                index_name="node_type_index")
+        collection.create_index(field_name="desc",
+                                index_params={"index_type": "INVERTED"},
+                                index_name="desc_index")
+        collection.create_index(field_name="desc_emb",
+                                index_params={"index_type": "GPU_CAGRA",
+                                              "metric_type": "IP"},
+                                index_name="desc_emb_index")
+
+        # Prepare and insert data
+        desc_emb_norm = cp.asarray(nodes_df["desc_emb"].list.leaves).astype(cp.float32).\
+            reshape(nodes_df.shape[0], -1)
+        desc_emb_norm = self.normalize_matrix(desc_emb_norm, axis=1)
+        data = [
+            nodes_df["node_index"].to_arrow().to_pylist(),
+            nodes_df["node_id"].to_arrow().to_pylist(),
+            nodes_df["node_name"].to_arrow().to_pylist(),
+            nodes_df["node_type"].to_arrow().to_pylist(),
+            nodes_df["desc"].to_arrow().to_pylist(),
+            desc_emb_norm.tolist(),  # Use normalized embeddings
+        ]
+
+        # Insert data in batches
+        total = len(data[0])
+        for i in tqdm(range(0, total, self.batch_size), desc="Inserting nodes"):
+            batch = [col[i:i+self.batch_size] for col in data]
+            collection.insert(batch)
+
+        collection.flush()
+        logger.info("Nodes collection created with %d entities", collection.num_entities)
+
+    def create_node_type_collections(self, nodes_df: cudf.DataFrame):
+        """Create separate collections for each node type."""
+        logger.info("Creating node type-specific collections...")
+
+        for node_type, nodes_df_ in tqdm(nodes_df.groupby('node_type'),
+                                         desc="Processing node types"):
+            node_coll_name = f"{self.milvus_database}_nodes_{node_type.replace('/', '_')}"
+
+            node_fields = [
+                FieldSchema(name="node_index",
+                            dtype=DataType.INT64,
+                            is_primary=True,
+                            auto_id=False),
+                FieldSchema(name="node_id",
+                            dtype=DataType.VARCHAR,
+                            max_length=1024),
+                FieldSchema(name="node_name",
+                            dtype=DataType.VARCHAR,
+                            max_length=1024,
+                            enable_analyzer=True,
+                            enable_match=True),
+                FieldSchema(name="node_type",
+                            dtype=DataType.VARCHAR,
+                            max_length=1024,
+                            enable_analyzer=True,
+                            enable_match=True),
+                FieldSchema(name="desc",
+                            dtype=DataType.VARCHAR,
+                            max_length=40960,
+                            enable_analyzer=True,
+                            enable_match=True),
+                FieldSchema(name="desc_emb",
+                            dtype=DataType.FLOAT_VECTOR,
+                            dim=len(nodes_df_.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])),
+                FieldSchema(name="feat",
+                            dtype=DataType.VARCHAR,
+                            max_length=40960,
+                            enable_analyzer=True,
+                            enable_match=True),
+                FieldSchema(name="feat_emb",
+                            dtype=DataType.FLOAT_VECTOR,
+                            dim=len(nodes_df_.iloc[0]['feat_emb'].to_arrow().to_pylist()[0])),
+            ]
+            schema = CollectionSchema(fields=node_fields,
+                                      description=f"schema for collection {node_coll_name}")
+
+            if not utility.has_collection(node_coll_name):
+                collection = Collection(name=node_coll_name, schema=schema)
+            else:
+                collection = Collection(name=node_coll_name)
+
+            # Create indexes
+            collection.create_index(field_name="node_index",
+                                    index_params={"index_type": "STL_SORT"},
+                                    index_name="node_index_index")
+            collection.create_index(field_name="node_name",
+                                    index_params={"index_type": "INVERTED"},
+                                    index_name="node_name_index")
+            collection.create_index(field_name="node_type",
+                                    index_params={"index_type": "INVERTED"},
+                                    index_name="node_type_index")
+            collection.create_index(field_name="desc",
+                                    index_params={"index_type": "INVERTED"},
+                                    index_name="desc_index")
+            collection.create_index(field_name="desc_emb",
+                                    index_params={"index_type": "GPU_CAGRA",
+                                                  "metric_type": "IP"},
+                                    index_name="desc_emb_index")
+            collection.create_index(field_name="feat_emb",
+                                    index_params={"index_type": "GPU_CAGRA",
+                                                  "metric_type": "IP"},
+                                    index_name="feat_emb_index")
+
+            # Prepare data
+            desc_emb_norm = cp.asarray(nodes_df_["desc_emb"].list.leaves).astype(cp.float32).\
+                reshape(nodes_df_.shape[0], -1)
+            desc_emb_norm = self.normalize_matrix(desc_emb_norm, axis=1)
+            feat_emb_norm = cp.asarray(nodes_df_["feat_emb"].list.leaves).astype(cp.float32).\
+                reshape(nodes_df_.shape[0], -1)
+            feat_emb_norm = self.normalize_matrix(feat_emb_norm, axis=1)
+            data = [
+                nodes_df_["node_index"].to_arrow().to_pylist(),
+                nodes_df_["node_id"].to_arrow().to_pylist(),
+                nodes_df_["node_name"].to_arrow().to_pylist(),
+                nodes_df_["node_type"].to_arrow().to_pylist(),
+                nodes_df_["desc"].to_arrow().to_pylist(),
+                desc_emb_norm.tolist(),  # Use normalized embeddings
+                nodes_df_["feat"].to_arrow().to_pylist(),
+                feat_emb_norm.tolist(),  # Use normalized embeddings
+            ]
+
+            # Insert data in batches
+            total_rows = len(data[0])
+            for i in range(0, total_rows, self.batch_size):
+                batch = [col[i:i + self.batch_size] for col in data]
+                collection.insert(batch)
+
+            collection.flush()
+            logger.info("Collection %s created with %d entities",
+                        node_coll_name, collection.num_entities)
+
+    def create_edges_collection(self,
+                                edges_enrichment_df: cudf.DataFrame,
+                                edges_embedding_df: List[cudf.DataFrame]):
+        """Create and populate the edges collection."""
+        logger.info("Creating edges collection...")
+
+        edge_coll_name = f"{self.milvus_database}_edges"
+
+        edge_fields = [
+            FieldSchema(name="triplet_index",
+                        dtype=DataType.INT64,
+                        is_primary=True,
+                        auto_id=False),
+            FieldSchema(name="head_id",
+                        dtype=DataType.VARCHAR,
+                        max_length=1024),
+            FieldSchema(name="head_index",
+                        dtype=DataType.INT64),
+            FieldSchema(name="tail_id",
+                        dtype=DataType.VARCHAR,
+                        max_length=1024),
+            FieldSchema(name="tail_index",
+                        dtype=DataType.INT64),
+            FieldSchema(name="edge_type",
+                        dtype=DataType.VARCHAR,
+                        max_length=1024),
+            FieldSchema(name="display_relation",
+                        dtype=DataType.VARCHAR,
+                        max_length=1024),
+            FieldSchema(name="feat",
+                        dtype=DataType.VARCHAR,
+                        max_length=40960),
+            FieldSchema(name="feat_emb",
+                        dtype=DataType.FLOAT_VECTOR,
+                        dim=len(edges_embedding_df[0].loc[0, 'edge_emb'])),
+        ]
+        edge_schema = CollectionSchema(fields=edge_fields,
+                                       description="Schema for edges collection")
+
+        if not utility.has_collection(edge_coll_name):
+            collection = Collection(name=edge_coll_name, schema=edge_schema)
+        else:
+            collection = Collection(name=edge_coll_name)
+
+        # Create indexes
+        collection.create_index(field_name="triplet_index",
+                                index_params={"index_type": "STL_SORT"},
+                                index_name="triplet_index_index")
+        collection.create_index(field_name="head_index",
+                                index_params={"index_type": "STL_SORT"},
+                                index_name="head_index_index")
+        collection.create_index(field_name="tail_index",
+                                index_params={"index_type": "STL_SORT"},
+                                index_name="tail_index_index")
+        collection.create_index(field_name="feat_emb",
+                                index_params={"index_type": "GPU_CAGRA",
+                                              "metric_type": "IP"},
+                                index_name="feat_emb_index")
+
+        # Iterate over the chunked edges embedding dataframes
+        for edges_df in tqdm(edges_embedding_df):
+            # Merge enrichment with embedding
+            merged_edges_df = edges_enrichment_df.merge(
+                edges_df[["triplet_index", "edge_emb"]],
+                on="triplet_index",
+                how="inner"
+            )
+
+            # Prepare data
+            edge_emb_cp = cp.asarray(merged_edges_df["edge_emb"].list.leaves).astype(cp.float32).\
+                reshape(merged_edges_df.shape[0], -1)
+            edge_emb_norm = self.normalize_matrix(edge_emb_cp, axis=1)
+            data = [
+                merged_edges_df["triplet_index"].to_arrow().to_pylist(),
+                merged_edges_df["head_id"].to_arrow().to_pylist(),
+                merged_edges_df["head_index"].to_arrow().to_pylist(),
+                merged_edges_df["tail_id"].to_arrow().to_pylist(),
+                merged_edges_df["tail_index"].to_arrow().to_pylist(),
+                merged_edges_df["edge_type_str"].to_arrow().to_pylist(),
+                merged_edges_df["display_relation"].to_arrow().to_pylist(),
+                merged_edges_df["feat"].to_arrow().to_pylist(),
+                edge_emb_norm.tolist(),  # Use normalized embeddings
+            ]
+
+            # Insert data in batches
+            total = len(data[0])
+            for i in tqdm(range(0, total, self.batch_size), desc="Inserting edges"):
+                batch_data = [d[i:i+self.batch_size] for d in data]
+                collection.insert(batch_data)
+
+        collection.flush()
+        logger.info("Edges collection created with %d entities", collection.num_entities)
+
+    def run(self):
+        """Main execution method."""
+        try:
+            logger.info("Starting Milvus data loading process...")
+
+            # Connect to Milvus
+            self.connect_to_milvus()
+
+            # Load graph data
+            graph = self.load_graph_data()
+
+            # Prepare data
+            logger.info("Data preparation started...")
+            # Get nodes enrichment and embedding dataframes
+            nodes_enrichment_df = graph['nodes']['enrichment']
+            nodes_embedding_df = graph['nodes']['embedding']
+
+            # Get edges enrichment and embedding dataframes
+            edges_enrichment_df = graph['edges']['enrichment']
+            # Note: the edges embedding is a list of chunked dataframes
+            edges_embedding_df = graph['edges']['embedding']
+
+            # For nodes, we can directly merge enrichment and embedding
+            # Merge nodes enrichment and embedding dataframes
+            merged_nodes_df = nodes_enrichment_df.merge(
+                nodes_embedding_df[["node_id", "desc_emb", "feat_emb"]],
+                on="node_id",
+                how="left"
+            )
+
+            # Create collections and load data
+            self.create_nodes_collection(merged_nodes_df)
+            self.create_node_type_collections(merged_nodes_df)
+            self.create_edges_collection(edges_enrichment_df,
+                                         edges_embedding_df)
+
+            # List all collections for verification
+            logger.info("Data loading completed successfully!")
+            logger.info("Created collections:")
+            for coll in utility.list_collections():
+                collection = Collection(name=coll)
+                logger.info("  %s: %d entities", coll, collection.num_entities)
+
+        except Exception as e:
+            logger.error("Error during data loading: %s", str(e))
+            raise
478
|
+
|
479
|
+
|
480
|
+
def main():
|
481
|
+
"""Main function to run the data loader."""
|
482
|
+
# Resolve the fallback data path relative to this script's location
|
483
|
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
484
|
+
default_data_dir = os.path.join(script_dir, "tests/files/biobridge_multimodal/")
|
485
|
+
|
486
|
+
# Configuration
|
487
|
+
config = {
|
488
|
+
'milvus_host': os.getenv('MILVUS_HOST', 'localhost'),
|
489
|
+
'milvus_port': os.getenv('MILVUS_PORT', '19530'),
|
490
|
+
'milvus_user': os.getenv('MILVUS_USER', 'root'),
|
491
|
+
'milvus_password': os.getenv('MILVUS_PASSWORD', 'Milvus'),
|
492
|
+
'milvus_database': os.getenv('MILVUS_DATABASE', 't2kg_primekg'),
|
493
|
+
'data_dir': os.getenv('DATA_DIR', default_data_dir),
|
494
|
+
'batch_size': int(os.getenv('BATCH_SIZE', '500')),
|
495
|
+
'chunk_size': int(os.getenv('CHUNK_SIZE', '5')),
|
496
|
+
}
|
497
|
+
|
498
|
+
# Print configuration for debugging
|
499
|
+
print("[DATA LOADER] Configuration:")
|
500
|
+
for key, value in config.items():
|
501
|
+
print(f"[DATA LOADER] {key}: {value}")
|
502
|
+
|
503
|
+
# Create and run data loader
|
504
|
+
loader = MilvusDataLoader(config)
|
505
|
+
loader.run()
|
506
|
+
|
507
|
+
|
508
|
+
if __name__ == "__main__":
|
509
|
+
main()
|
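Once the loader has run, the collections follow the `{database}_nodes`, `{database}_nodes_{node_type}`, and `{database}_edges` naming pattern above and can be queried with the stock pymilvus ORM. A minimal sketch, assuming the default database name; the query vector and its dimension are placeholders. Because the script L2-normalizes every embedding before insertion, the IP metric used by the indexes behaves as cosine similarity:

```python
from pymilvus import Collection, connections, db

connections.connect(host="localhost", port="19530", user="root", password="Milvus")
db.using_database("t2kg_primekg")

nodes = Collection("t2kg_primekg_nodes")
nodes.load()
hits = nodes.search(
    data=[[0.0] * 1536],  # placeholder query embedding; dim must match desc_emb
    anns_field="desc_emb",
    param={"metric_type": "IP"},
    limit=5,
    output_fields=["node_name", "node_type"],
)
for hit in hits[0]:
    print(hit.entity.get("node_name"), hit.distance)
```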
aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py
@@ -1,24 +1,23 @@
 """
 Test cases for agents/t2kg_agent.py
 """
-
+from unittest.mock import patch, MagicMock
 import pytest
 from langchain_core.messages import HumanMessage
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+import pandas as pd
 from ..agents.t2kg_agent import get_app
 
-# Define the data path
 DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
 
-
 @pytest.fixture(name="input_dict")
 def input_dict_fixture():
     """
     Input dictionary fixture.
     """
     input_dict = {
-        "llm_model": None,
-        "embedding_model": None,
+        "llm_model": None,
+        "embedding_model": None,
         "selections": {
             "gene/protein": [],
             "molecular_function": [],
@@ -47,30 +46,62 @@ def input_dict_fixture():
         ],
         "dic_extracted_graph": []
     }
-
     return input_dict
 
+def mock_milvus_collection(name):
+    """
+    Mock Milvus collection for testing.
+    """
+    nodes = MagicMock()
+    nodes.query.return_value = [
+        {"node_index": 0,
+         "node_id": "id1",
+         "node_name": "Adalimumab",
+         "node_type": "drug",
+         "feat": "featA", "feat_emb": [0.1, 0.2, 0.3],
+         "desc": "descA", "desc_emb": [0.1, 0.2, 0.3]},
+        {"node_index": 1,
+         "node_id": "id2",
+         "node_name": "TNF",
+         "node_type": "gene/protein",
+         "feat": "featB", "feat_emb": [0.4, 0.5, 0.6],
+         "desc": "descB", "desc_emb": [0.4, 0.5, 0.6]}
+    ]
+    nodes.load.return_value = None
+
+    edges = MagicMock()
+    edges.query.return_value = [
+        {"triplet_index": 0,
+         "head_id": "id1",
+         "head_index": 0,
+         "tail_id": "id2",
+         "tail_index": 1,
+         "edge_type": "drug,acts_on,gene/protein",
+         "display_relation": "acts_on",
+         "feat": "featC",
+         "feat_emb": [0.7, 0.8, 0.9]}
+    ]
+    edges.load.return_value = None
 
-
+    if "nodes" in name:
+        return nodes
+    if "edges" in name:
+        return edges
+    return None
+
+def test_t2kg_agent_openai_milvus_mock(input_dict):
     """
-    Test the T2KG agent using OpenAI model.
+    Test the T2KG agent using OpenAI model and Milvus mock.
 
     Args:
         input_dict: Input dictionary
     """
-    # Prepare LLM and embedding model
     input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
     input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
-
-    # Setup the app
     unique_id = 12345
     app = get_app(unique_id, llm_model=input_dict["llm_model"])
    config = {"configurable": {"thread_id": unique_id}}
-
-    app.update_state(
-        config,
-        input_dict,
-    )
+    app.update_state(config, input_dict)
     prompt = """
     Adalimumab is a fully human monoclonal antibody (IgG1)
     that specifically binds to tumor necrosis factor-alpha (TNF-α), a pro-inflammatory cytokine.
@@ -85,14 +116,43 @@ def test_t2kg_agent_openai(input_dict):
     Please set the extraction name for the extraction process as `subkg_12345`.
     """
 
-
-
+    with patch("aiagents4pharma.talk2knowledgegraphs.tools."
+               "milvus_multimodal_subgraph_extraction.Collection",
+               side_effect=mock_milvus_collection), \
+        patch("aiagents4pharma.talk2knowledgegraphs.tools."
+              "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning") as mock_pcst, \
+        patch("pymilvus.connections") as mock_connections, \
+        patch("aiagents4pharma.talk2knowledgegraphs.tools."
+              "milvus_multimodal_subgraph_extraction.hydra.initialize"), \
+        patch("aiagents4pharma.talk2knowledgegraphs.tools."
+              "milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
+        mock_connections.has_connection.return_value = True
+        mock_pcst_instance = MagicMock()
+        mock_pcst_instance.extract_subgraph.return_value = {
+            "nodes": pd.Series([0, 1]),
+            "edges": pd.Series([0])
+        }
+        mock_pcst.return_value = mock_pcst_instance
+        mock_cfg = MagicMock()
+        mock_cfg.cost_e = 1.0
+        mock_cfg.c_const = 1.0
+        mock_cfg.root = 0
+        mock_cfg.num_clusters = 1
+        mock_cfg.pruning = True
+        mock_cfg.verbosity_level = 0
+        mock_cfg.search_metric_type = "L2"
+        mock_cfg.node_colors_dict = {"drug": "blue", "gene/protein": "red"}
+        mock_compose.return_value = MagicMock()
+        mock_compose.return_value.tools.multimodal_subgraph_extraction = mock_cfg
+        mock_compose.return_value.tools.subgraph_summarization.\
+            prompt_subgraph_summarization = (
+                "Summarize the following subgraph: {textualized_subgraph}"
+            )
+
+        response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
 
-    # Check assistant message
     assistant_msg = response["messages"][-1].content
     assert isinstance(assistant_msg, str)
-
-    # Check extracted subgraph dictionary
     current_state = app.get_state(config)
     dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
     assert isinstance(dic_extracted_graph, dict)
@@ -104,8 +164,10 @@ def test_t2kg_agent_openai(input_dict):
     assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
     assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
     assert isinstance(dic_extracted_graph["graph_text"], str)
-    # Check summarized subgraph
     assert isinstance(dic_extracted_graph["graph_summary"], str)
-    # Check reasoning results
     assert "Adalimumab" in assistant_msg
     assert "TNF" in assistant_msg
+
+    # Another test for unknown collection
+    result = mock_milvus_collection("unknown")
+    assert result is None