aiagents4pharma 1.39.4__py3-none-any.whl → 1.40.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2aiagents4pharma/configs/agents/main_agent/default.yaml +26 -13
- aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +83 -3
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +4 -1
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +36 -5
- aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +509 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +85 -23
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +413 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ols.py +10 -10
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +175 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py +11 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +509 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ols_terms.py +15 -7
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py +31 -9
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +393 -0
- aiagents4pharma/talk2knowledgegraphs/utils/pubchem_utils.py +33 -2
- {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/METADATA +13 -14
- {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/RECORD +22 -17
- {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py
ADDED
@@ -0,0 +1,413 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for tools/milvus_multimodal_subgraph_extraction.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import importlib
|
6
|
+
import unittest
|
7
|
+
from unittest.mock import patch, MagicMock
|
8
|
+
import numpy as np
|
9
|
+
import pandas as pd
|
10
|
+
from ..tools.milvus_multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
|
11
|
+
|
12
|
+
class TestMultimodalSubgraphExtractionTool(unittest.TestCase):
|
13
|
+
"""
|
14
|
+
Test cases for MultimodalSubgraphExtractionTool (Milvus)
|
15
|
+
"""
|
16
|
+
def setUp(self):
|
17
|
+
self.tool = MultimodalSubgraphExtractionTool()
|
18
|
+
self.state = {
|
19
|
+
"uploaded_files": [],
|
20
|
+
"embedding_model": MagicMock(),
|
21
|
+
"topk_nodes": 5,
|
22
|
+
"topk_edges": 5,
|
23
|
+
"dic_source_graph": [{"name": "TestGraph"}],
|
24
|
+
}
|
25
|
+
self.prompt = "Find subgraph for test"
|
26
|
+
self.arg_data = {"extraction_name": "subkg_12345"}
|
27
|
+
self.cfg_db = MagicMock()
|
28
|
+
self.cfg_db.milvus_db.database_name = "testdb"
|
29
|
+
self.cfg_db.milvus_db.alias = "default"
|
30
|
+
self.cfg = MagicMock()
|
31
|
+
self.cfg.cost_e = 1.0
|
32
|
+
self.cfg.c_const = 1.0
|
33
|
+
self.cfg.root = 0
|
34
|
+
self.cfg.num_clusters = 1
|
35
|
+
self.cfg.pruning = True
|
36
|
+
self.cfg.verbosity_level = 0
|
37
|
+
self.cfg.search_metric_type = "L2"
|
38
|
+
self.cfg.node_colors_dict = {"gene/protein": "red"}
|
39
|
+
|
40
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
41
|
+
"milvus_multimodal_subgraph_extraction.Collection")
|
42
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
43
|
+
"milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
|
44
|
+
@patch("pymilvus.connections")
|
45
|
+
def test_extract_multimodal_subgraph_wo_doc(self,
|
46
|
+
mock_connections,
|
47
|
+
mock_pcst,
|
48
|
+
mock_collection):
|
49
|
+
"""
|
50
|
+
Test the multimodal subgraph extraction tool for only text as modality.
|
51
|
+
"""
|
52
|
+
|
53
|
+
# Mock Milvus connection utilities
|
54
|
+
mock_connections.has_connection.return_value = True
|
55
|
+
|
56
|
+
# No uploaded_files (no doc)
|
57
|
+
self.state["uploaded_files"] = []
|
58
|
+
self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
|
59
|
+
self.state["selections"] = {}
|
60
|
+
|
61
|
+
# Mock Collection for nodes and edges
|
62
|
+
colls = {}
|
63
|
+
colls["nodes"] = MagicMock()
|
64
|
+
colls["nodes"] = MagicMock()
|
65
|
+
colls["nodes"].query.return_value = [
|
66
|
+
{"node_index": 0,
|
67
|
+
"node_id": "id1",
|
68
|
+
"node_name": "JAK1",
|
69
|
+
"node_type": "gene/protein",
|
70
|
+
"feat": "featA",
|
71
|
+
"feat_emb": [0.1, 0.2, 0.3],
|
72
|
+
"desc": "descA",
|
73
|
+
"desc_emb": [0.1, 0.2, 0.3]},
|
74
|
+
{"node_index": 1,
|
75
|
+
"node_id": "id2",
|
76
|
+
"node_name": "JAK2",
|
77
|
+
"node_type": "gene/protein",
|
78
|
+
"feat": "featB",
|
79
|
+
"feat_emb": [0.4, 0.5, 0.6],
|
80
|
+
"desc": "descB",
|
81
|
+
"desc_emb": [0.4, 0.5, 0.6]}
|
82
|
+
]
|
83
|
+
colls["nodes"].load.return_value = None
|
84
|
+
|
85
|
+
colls["edges"] = MagicMock()
|
86
|
+
colls["edges"].query.return_value = [
|
87
|
+
{"triplet_index": 0,
|
88
|
+
"head_id": "id1",
|
89
|
+
"head_index": 0,
|
90
|
+
"tail_id": "id2",
|
91
|
+
"tail_index": 1,
|
92
|
+
"edge_type": "gene/protein,ppi,gene/protein",
|
93
|
+
"display_relation": "ppi",
|
94
|
+
"feat": "featC",
|
95
|
+
"feat_emb": [0.7, 0.8, 0.9]}
|
96
|
+
]
|
97
|
+
colls["edges"].load.return_value = None
|
98
|
+
|
99
|
+
def collection_side_effect(name):
|
100
|
+
"""
|
101
|
+
Mock side effect for Collection to return nodes or edges based on name.
|
102
|
+
"""
|
103
|
+
if "nodes" in name:
|
104
|
+
return colls["nodes"]
|
105
|
+
if "edges" in name:
|
106
|
+
return colls["edges"]
|
107
|
+
return None
|
108
|
+
mock_collection.side_effect = collection_side_effect
|
109
|
+
|
110
|
+
# Mock MultimodalPCSTPruning
|
111
|
+
mock_pcst_instance = MagicMock()
|
112
|
+
mock_pcst_instance.extract_subgraph.return_value = {
|
113
|
+
"nodes": pd.Series([1, 2]),
|
114
|
+
"edges": pd.Series([0])
|
115
|
+
}
|
116
|
+
mock_pcst.return_value = mock_pcst_instance
|
117
|
+
|
118
|
+
# Patch hydra.compose to return config objects
|
119
|
+
with patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
120
|
+
"milvus_multimodal_subgraph_extraction.hydra.initialize"), \
|
121
|
+
patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
122
|
+
"milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
|
123
|
+
mock_compose.return_value = MagicMock()
|
124
|
+
mock_compose.return_value.app.frontend = self.cfg_db
|
125
|
+
mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
|
126
|
+
|
127
|
+
response = self.tool.invoke(
|
128
|
+
input={"prompt": self.prompt,
|
129
|
+
"tool_call_id": "subgraph_extraction_tool",
|
130
|
+
"state": self.state,
|
131
|
+
"arg_data": self.arg_data}
|
132
|
+
)
|
133
|
+
|
134
|
+
# Check tool message
|
135
|
+
self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
|
136
|
+
|
137
|
+
# Check extracted subgraph dictionary
|
138
|
+
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
139
|
+
self.assertIsInstance(dic_extracted_graph, dict)
|
140
|
+
self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
|
141
|
+
self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
|
142
|
+
self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
|
143
|
+
self.assertEqual(dic_extracted_graph["topk_edges"], 5)
|
144
|
+
self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
|
145
|
+
self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
|
146
|
+
self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
|
147
|
+
self.assertIsInstance(dic_extracted_graph["graph_text"], str)
|
148
|
+
# Check if the nodes are in the graph_text
|
149
|
+
self.assertTrue(all(
|
150
|
+
n[0] in dic_extracted_graph["graph_text"].replace('"', '')
|
151
|
+
for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
|
152
|
+
for n in subgraph_nodes
|
153
|
+
))
|
154
|
+
# Check if the edges are in the graph_text
|
155
|
+
self.assertTrue(all(
|
156
|
+
",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
|
157
|
+
in dic_extracted_graph["graph_text"].replace('"', '').\
|
158
|
+
replace("[", "").replace("]", "").replace("'", "")
|
159
|
+
for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
|
160
|
+
for e in subgraph_edges
|
161
|
+
))
|
162
|
+
|
163
|
+
# Another test for unknown collection
|
164
|
+
result = collection_side_effect("unknown")
|
165
|
+
self.assertIsNone(result)
|
166
|
+
|
167
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
168
|
+
"milvus_multimodal_subgraph_extraction.Collection")
|
169
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
170
|
+
"milvus_multimodal_subgraph_extraction.pd.read_excel")
|
171
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
172
|
+
"milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
|
173
|
+
@patch("pymilvus.connections")
|
174
|
+
def test_extract_multimodal_subgraph_w_doc(self,
|
175
|
+
mock_connections,
|
176
|
+
mock_pcst,
|
177
|
+
mock_read_excel,
|
178
|
+
mock_collection):
|
179
|
+
"""
|
180
|
+
Test the multimodal subgraph extraction tool for text as modality, plus genes.
|
181
|
+
"""
|
182
|
+
# Mock Milvus connection utilities
|
183
|
+
mock_connections.has_connection.return_value = True
|
184
|
+
|
185
|
+
# With uploaded_files (with doc)
|
186
|
+
self.state["uploaded_files"] = [
|
187
|
+
{"file_type": "multimodal", "file_path": "dummy.xlsx"}
|
188
|
+
]
|
189
|
+
self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
|
190
|
+
self.state["selections"] = {"gene/protein": ["JAK1", "JAK2"]}
|
191
|
+
|
192
|
+
# Mock pd.read_excel to return a dict of DataFrames
|
193
|
+
df = pd.DataFrame({
|
194
|
+
"name": ["JAK1", "JAK2"],
|
195
|
+
"node_type": ["gene/protein", "gene/protein"]
|
196
|
+
})
|
197
|
+
mock_read_excel.return_value = {"gene/protein": df}
|
198
|
+
|
199
|
+
# Mock Collection for nodes and edges
|
200
|
+
colls = {}
|
201
|
+
colls["nodes"] = MagicMock()
|
202
|
+
colls["nodes"] = MagicMock()
|
203
|
+
colls["nodes"].query.return_value = [
|
204
|
+
{"node_index": 0,
|
205
|
+
"node_id": "id1",
|
206
|
+
"node_name": "JAK1",
|
207
|
+
"node_type": "gene/protein",
|
208
|
+
"feat": "featA",
|
209
|
+
"feat_emb": [0.1, 0.2, 0.3],
|
210
|
+
"desc": "descA",
|
211
|
+
"desc_emb": [0.1, 0.2, 0.3]},
|
212
|
+
{"node_index": 1,
|
213
|
+
"node_id": "id2",
|
214
|
+
"node_name": "JAK2",
|
215
|
+
"node_type": "gene/protein",
|
216
|
+
"feat": "featB",
|
217
|
+
"feat_emb": [0.4, 0.5, 0.6],
|
218
|
+
"desc": "descB",
|
219
|
+
"desc_emb": [0.4, 0.5, 0.6]}
|
220
|
+
]
|
221
|
+
colls["nodes"].load.return_value = None
|
222
|
+
|
223
|
+
colls["edges"] = MagicMock()
|
224
|
+
colls["edges"].query.return_value = [
|
225
|
+
{"triplet_index": 0,
|
226
|
+
"head_id": "id1",
|
227
|
+
"head_index": 0,
|
228
|
+
"tail_id": "id2",
|
229
|
+
"tail_index": 1,
|
230
|
+
"edge_type": "gene/protein,ppi,gene/protein",
|
231
|
+
"display_relation": "ppi",
|
232
|
+
"feat": "featC",
|
233
|
+
"feat_emb": [0.7, 0.8, 0.9]}
|
234
|
+
]
|
235
|
+
colls["edges"].load.return_value = None
|
236
|
+
|
237
|
+
def collection_side_effect(name):
|
238
|
+
"""
|
239
|
+
Mock side effect for Collection to return nodes or edges based on name.
|
240
|
+
"""
|
241
|
+
if "nodes" in name:
|
242
|
+
return colls["nodes"]
|
243
|
+
if "edges" in name:
|
244
|
+
return colls["edges"]
|
245
|
+
return None
|
246
|
+
mock_collection.side_effect = collection_side_effect
|
247
|
+
|
248
|
+
# Mock MultimodalPCSTPruning
|
249
|
+
mock_pcst_instance = MagicMock()
|
250
|
+
mock_pcst_instance.extract_subgraph.return_value = {
|
251
|
+
"nodes": pd.Series([1, 2]),
|
252
|
+
"edges": pd.Series([0])
|
253
|
+
}
|
254
|
+
mock_pcst.return_value = mock_pcst_instance
|
255
|
+
|
256
|
+
# Patch hydra.compose to return config objects
|
257
|
+
with patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
258
|
+
"milvus_multimodal_subgraph_extraction.hydra.initialize"), \
|
259
|
+
patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
260
|
+
"milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
|
261
|
+
mock_compose.return_value = MagicMock()
|
262
|
+
mock_compose.return_value.app.frontend = self.cfg_db
|
263
|
+
mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
|
264
|
+
|
265
|
+
response = self.tool.invoke(
|
266
|
+
input={"prompt": self.prompt,
|
267
|
+
"tool_call_id": "subgraph_extraction_tool",
|
268
|
+
"state": self.state,
|
269
|
+
"arg_data": self.arg_data}
|
270
|
+
)
|
271
|
+
|
272
|
+
# Check tool message
|
273
|
+
self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
|
274
|
+
|
275
|
+
# Check extracted subgraph dictionary
|
276
|
+
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
277
|
+
self.assertIsInstance(dic_extracted_graph, dict)
|
278
|
+
self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
|
279
|
+
self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
|
280
|
+
self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
|
281
|
+
self.assertEqual(dic_extracted_graph["topk_edges"], 5)
|
282
|
+
self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
|
283
|
+
self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
|
284
|
+
self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
|
285
|
+
self.assertIsInstance(dic_extracted_graph["graph_text"], str)
|
286
|
+
# Check if the nodes are in the graph_text
|
287
|
+
self.assertTrue(all(
|
288
|
+
n[0] in dic_extracted_graph["graph_text"].replace('"', '')
|
289
|
+
for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
|
290
|
+
for n in subgraph_nodes
|
291
|
+
))
|
292
|
+
# Check if the edges are in the graph_text
|
293
|
+
self.assertTrue(all(
|
294
|
+
",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
|
295
|
+
in dic_extracted_graph["graph_text"].replace('"', '').\
|
296
|
+
replace("[", "").replace("]", "").replace("'", "")
|
297
|
+
for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
|
298
|
+
for e in subgraph_edges
|
299
|
+
))
|
300
|
+
|
301
|
+
# Another test for unknown collection
|
302
|
+
result = collection_side_effect("unknown")
|
303
|
+
self.assertIsNone(result)
|
304
|
+
|
305
|
+
def test_extract_multimodal_subgraph_wo_doc_gpu(self):
|
306
|
+
"""
|
307
|
+
Test the multimodal subgraph extraction tool for only text as modality,
|
308
|
+
simulating GPU (cudf/cupy) environment.
|
309
|
+
"""
|
310
|
+
module_name = "aiagents4pharma.talk2knowledgegraphs.tools."+\
|
311
|
+
"milvus_multimodal_subgraph_extraction"
|
312
|
+
with patch.dict("sys.modules", {"cupy": np, "cudf": pd}):
|
313
|
+
mod = importlib.reload(importlib.import_module(module_name))
|
314
|
+
# Patch Collection and MultimodalPCSTPruning after reload
|
315
|
+
with patch(f"{module_name}.Collection") as mock_collection, \
|
316
|
+
patch(f"{module_name}.MultimodalPCSTPruning") as mock_pcst, \
|
317
|
+
patch("pymilvus.connections") as mock_connections:
|
318
|
+
# Setup mocks as in the original test
|
319
|
+
mock_connections.has_connection.return_value = True
|
320
|
+
colls = {}
|
321
|
+
colls["nodes"] = MagicMock()
|
322
|
+
colls["nodes"].query.return_value = [
|
323
|
+
{"node_index": 0,
|
324
|
+
"node_id": "id1",
|
325
|
+
"node_name": "JAK1",
|
326
|
+
"node_type": "gene/protein",
|
327
|
+
"feat": "featA",
|
328
|
+
"feat_emb": [0.1, 0.2, 0.3],
|
329
|
+
"desc": "descA",
|
330
|
+
"desc_emb": [0.1, 0.2, 0.3]},
|
331
|
+
{"node_index": 1,
|
332
|
+
"node_id": "id2",
|
333
|
+
"node_name": "JAK2",
|
334
|
+
"node_type": "gene/protein",
|
335
|
+
"feat": "featB",
|
336
|
+
"feat_emb": [0.4, 0.5, 0.6],
|
337
|
+
"desc": "descB",
|
338
|
+
"desc_emb": [0.4, 0.5, 0.6]}
|
339
|
+
]
|
340
|
+
colls["nodes"].load.return_value = None
|
341
|
+
colls["edges"] = MagicMock()
|
342
|
+
colls["edges"].query.return_value = [
|
343
|
+
{"triplet_index": 0,
|
344
|
+
"head_id": "id1",
|
345
|
+
"head_index": 0,
|
346
|
+
"tail_id": "id2",
|
347
|
+
"tail_index": 1,
|
348
|
+
"edge_type": "gene/protein,ppi,gene/protein",
|
349
|
+
"display_relation": "ppi",
|
350
|
+
"feat": "featC",
|
351
|
+
"feat_emb": [0.7, 0.8, 0.9]}
|
352
|
+
]
|
353
|
+
colls["edges"].load.return_value = None
|
354
|
+
def collection_side_effect(name):
|
355
|
+
if "nodes" in name:
|
356
|
+
return colls["nodes"]
|
357
|
+
if "edges" in name:
|
358
|
+
return colls["edges"]
|
359
|
+
return None
|
360
|
+
mock_collection.side_effect = collection_side_effect
|
361
|
+
mock_pcst_instance = MagicMock()
|
362
|
+
mock_pcst_instance.extract_subgraph.return_value = {
|
363
|
+
"nodes": pd.Series([1, 2]),
|
364
|
+
"edges": pd.Series([0])
|
365
|
+
}
|
366
|
+
mock_pcst.return_value = mock_pcst_instance
|
367
|
+
# Instantiate the tool from the reloaded module
|
368
|
+
tool_cls = getattr(mod, "MultimodalSubgraphExtractionTool")
|
369
|
+
tool = tool_cls()
|
370
|
+
|
371
|
+
# Patch hydra.compose
|
372
|
+
with patch(f"{module_name}.hydra.initialize"), \
|
373
|
+
patch(f"{module_name}.hydra.compose") as mock_compose:
|
374
|
+
mock_compose.return_value = MagicMock()
|
375
|
+
mock_compose.return_value.app.frontend = self.cfg_db
|
376
|
+
mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
|
377
|
+
self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
|
378
|
+
self.state["selections"] = {}
|
379
|
+
response = tool.invoke(
|
380
|
+
input={"prompt": self.prompt,
|
381
|
+
"tool_call_id": "subgraph_extraction_tool",
|
382
|
+
"state": self.state,
|
383
|
+
"arg_data": self.arg_data}
|
384
|
+
)
|
385
|
+
# Check tool message
|
386
|
+
self.assertEqual(response.update["messages"][-1].tool_call_id,
|
387
|
+
"subgraph_extraction_tool")
|
388
|
+
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
389
|
+
self.assertIsInstance(dic_extracted_graph, dict)
|
390
|
+
self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
|
391
|
+
self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
|
392
|
+
self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
|
393
|
+
self.assertEqual(dic_extracted_graph["topk_edges"], 5)
|
394
|
+
self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
|
395
|
+
self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
|
396
|
+
self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
|
397
|
+
self.assertIsInstance(dic_extracted_graph["graph_text"], str)
|
398
|
+
self.assertTrue(all(
|
399
|
+
n[0] in dic_extracted_graph["graph_text"].replace('"', '')
|
400
|
+
for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
|
401
|
+
for n in subgraph_nodes
|
402
|
+
))
|
403
|
+
self.assertTrue(all(
|
404
|
+
",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
|
405
|
+
in dic_extracted_graph["graph_text"].replace('"', '').\
|
406
|
+
replace("[", "").replace("]", "").replace("'", "")
|
407
|
+
for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
|
408
|
+
for e in subgraph_edges
|
409
|
+
))
|
410
|
+
|
411
|
+
# Another test for unknown collection
|
412
|
+
result = collection_side_effect("unknown")
|
413
|
+
self.assertIsNone(result)
|
@@ -42,11 +42,11 @@ def test_enrich_documents(enrich_obj):
|
|
42
42
|
"XYZ_0000000",
|
43
43
|
]
|
44
44
|
descriptions = enrich_obj.enrich_documents(ols_terms)
|
45
|
-
assert descriptions[0]
|
46
|
-
assert descriptions[1]
|
47
|
-
assert descriptions[2]
|
48
|
-
assert descriptions[3]
|
49
|
-
assert descriptions[4]
|
45
|
+
assert CL_DESC in descriptions[0]
|
46
|
+
assert GO_DESC in descriptions[1]
|
47
|
+
assert UBERON_DESC in descriptions[2]
|
48
|
+
assert HP_DESC in descriptions[3]
|
49
|
+
assert MONDO_DESC in descriptions[4]
|
50
50
|
assert descriptions[5] is None
|
51
51
|
|
52
52
|
|
@@ -61,9 +61,9 @@ def test_enrich_documents_with_rag(enrich_obj):
|
|
61
61
|
"XYZ_0000000",
|
62
62
|
]
|
63
63
|
descriptions = enrich_obj.enrich_documents_with_rag(ols_terms, None)
|
64
|
-
assert descriptions[0]
|
65
|
-
assert descriptions[1]
|
66
|
-
assert descriptions[2]
|
67
|
-
assert descriptions[3]
|
68
|
-
assert descriptions[4]
|
64
|
+
assert CL_DESC in descriptions[0]
|
65
|
+
assert GO_DESC in descriptions[1]
|
66
|
+
assert UBERON_DESC in descriptions[2]
|
67
|
+
assert HP_DESC in descriptions[3]
|
68
|
+
assert MONDO_DESC in descriptions[4]
|
69
69
|
assert descriptions[5] is None
|
@@ -0,0 +1,175 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for utils/extractions/milvus_multimodal_pcst.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import importlib
|
6
|
+
import sys
|
7
|
+
import unittest
|
8
|
+
from unittest.mock import patch, MagicMock, mock_open
|
9
|
+
import numpy as np
|
10
|
+
import pandas as pd
|
11
|
+
from ..utils.extractions.milvus_multimodal_pcst import MultimodalPCSTPruning
|
12
|
+
|
13
|
+
class TestMultimodalPCSTPruning(unittest.TestCase):
|
14
|
+
"""
|
15
|
+
Test cases for MultimodalPCSTPruning class (Milvus-based PCST pruning).
|
16
|
+
"""
|
17
|
+
def setUp(self):
|
18
|
+
# Patch cupy and cudf to simulate GPU environment
|
19
|
+
patcher_cupy = patch.dict('sys.modules', {'cupy': MagicMock(), 'cudf': MagicMock()})
|
20
|
+
patcher_cupy.start()
|
21
|
+
self.addCleanup(patcher_cupy.stop)
|
22
|
+
|
23
|
+
# Patch pcst_fast
|
24
|
+
self.pcst_fast_patcher = patch("aiagents4pharma.talk2knowledgegraphs.utils."
|
25
|
+
"extractions.milvus_multimodal_pcst.pcst_fast")
|
26
|
+
self.mock_pcst_fast = self.pcst_fast_patcher.start()
|
27
|
+
self.addCleanup(self.pcst_fast_patcher.stop)
|
28
|
+
self.mock_pcst_fast.pcst_fast.return_value = ([0, 1], [0])
|
29
|
+
|
30
|
+
# Patch Collection
|
31
|
+
self.collection_patcher = patch("aiagents4pharma.talk2knowledgegraphs.utils."
|
32
|
+
"extractions.milvus_multimodal_pcst.Collection")
|
33
|
+
self.mock_collection = self.collection_patcher.start()
|
34
|
+
self.addCleanup(self.collection_patcher.stop)
|
35
|
+
|
36
|
+
# Patch open for cache_edge_index_path
|
37
|
+
self.open_patcher = patch('builtins.open', mock_open(read_data='[[0,1],[1,2]]'))
|
38
|
+
self.mock_open = self.open_patcher.start()
|
39
|
+
self.addCleanup(self.open_patcher.stop)
|
40
|
+
|
41
|
+
# Patch pickle.load to return a numpy array for edge_index
|
42
|
+
self.pickle_patcher = patch("aiagents4pharma.talk2knowledgegraphs.utils."
|
43
|
+
"extractions.milvus_multimodal_pcst.pickle")
|
44
|
+
self.mock_pickle = self.pickle_patcher.start()
|
45
|
+
self.addCleanup(self.pickle_patcher.stop)
|
46
|
+
self.mock_pickle.load.return_value = np.array([[0, 1], [1, 2]])
|
47
|
+
|
48
|
+
# Setup config mock
|
49
|
+
self.cfg = MagicMock()
|
50
|
+
self.cfg.milvus_db.database_name = "testdb"
|
51
|
+
self.cfg.milvus_db.cache_edge_index_path = "dummy_cache.pkl"
|
52
|
+
|
53
|
+
# Setup Collection mocks
|
54
|
+
node_coll = MagicMock()
|
55
|
+
node_coll.num_entities = 2
|
56
|
+
node_coll.search.return_value = [[MagicMock(id=0), MagicMock(id=1)]]
|
57
|
+
edge_coll = MagicMock()
|
58
|
+
edge_coll.num_entities = 2
|
59
|
+
edge_coll.search.return_value = [[MagicMock(id=0, score=1.0), MagicMock(id=1, score=0.5)]]
|
60
|
+
self.mock_collection.side_effect = lambda name: node_coll if "nodes" in name else edge_coll
|
61
|
+
|
62
|
+
def test_extract_subgraph_use_description_true(self):
|
63
|
+
"""
|
64
|
+
Test the extract_subgraph method of MultimodalPCSTPruning with use_description=True.
|
65
|
+
"""
|
66
|
+
# Create instance
|
67
|
+
pcst = MultimodalPCSTPruning(
|
68
|
+
topk=3, topk_e=3, cost_e=0.5, c_const=0.01, root=-1,
|
69
|
+
num_clusters=1, pruning="gw", verbosity_level=0, use_description=True, metric_type="IP"
|
70
|
+
)
|
71
|
+
# Dummy embeddings
|
72
|
+
text_emb = [0.1, 0.2, 0.3]
|
73
|
+
query_emb = [0.1, 0.2, 0.3]
|
74
|
+
modality = "gene/protein"
|
75
|
+
|
76
|
+
# Call extract_subgraph
|
77
|
+
result = pcst.extract_subgraph(text_emb, query_emb, modality, self.cfg)
|
78
|
+
|
79
|
+
# Assertions
|
80
|
+
self.assertIn("nodes", result)
|
81
|
+
self.assertIn("edges", result)
|
82
|
+
self.assertGreaterEqual(len(result["nodes"]), 0)
|
83
|
+
self.assertGreaterEqual(len(result["edges"]), 0)
|
84
|
+
|
85
|
+
def test_extract_subgraph_use_description_false(self):
|
86
|
+
"""
|
87
|
+
Test the extract_subgraph method of MultimodalPCSTPruning with use_description=False.
|
88
|
+
"""
|
89
|
+
# Create instance
|
90
|
+
pcst = MultimodalPCSTPruning(
|
91
|
+
topk=3, topk_e=3, cost_e=0.5, c_const=0.01, root=-1,
|
92
|
+
num_clusters=1, pruning="gw", verbosity_level=0, use_description=False, metric_type="IP"
|
93
|
+
)
|
94
|
+
# Dummy embeddings
|
95
|
+
text_emb = [0.1, 0.2, 0.3]
|
96
|
+
query_emb = [0.1, 0.2, 0.3]
|
97
|
+
modality = "gene/protein"
|
98
|
+
|
99
|
+
# Call extract_subgraph
|
100
|
+
result = pcst.extract_subgraph(text_emb, query_emb, modality, self.cfg)
|
101
|
+
|
102
|
+
# Assertions
|
103
|
+
self.assertIn("nodes", result)
|
104
|
+
self.assertIn("edges", result)
|
105
|
+
self.assertGreaterEqual(len(result["nodes"]), 0)
|
106
|
+
self.assertGreaterEqual(len(result["edges"]), 0)
|
107
|
+
|
108
|
+
def test_extract_subgraph_with_virtual_vertices(self):
|
109
|
+
"""
|
110
|
+
Test get_subgraph_nodes_edges with virtual vertices present (len(virtual_vertices) > 0).
|
111
|
+
"""
|
112
|
+
pcst = MultimodalPCSTPruning(
|
113
|
+
topk=3, topk_e=3, cost_e=0.5, c_const=0.01, root=-1,
|
114
|
+
num_clusters=1, pruning="gw", verbosity_level=0, use_description=True, metric_type="IP"
|
115
|
+
)
|
116
|
+
# Simulate num_nodes = 2, vertices contains [0, 1, 2, 3] (2 and 3 are virtual)
|
117
|
+
num_nodes = 2
|
118
|
+
# vertices: [0, 1, 2, 3] (2 and 3 are virtual)
|
119
|
+
vertices = np.array([0, 1, 2, 3])
|
120
|
+
# edges_dict simulates prior edges and edge_index
|
121
|
+
edges_dict = {
|
122
|
+
"edges": np.array([0, 1, 2]),
|
123
|
+
"num_prior_edges": 2,
|
124
|
+
"edge_index": np.array([[0, 1, 2, 3], [1, 2, 3, 4]])
|
125
|
+
}
|
126
|
+
# mapping simulates mapping for edges and nodes
|
127
|
+
mapping = {
|
128
|
+
"edges": {0: 0, 1: 1},
|
129
|
+
"nodes": {2: 2, 3: 3}
|
130
|
+
}
|
131
|
+
|
132
|
+
# Call get_subgraph_nodes_edges
|
133
|
+
result = pcst.get_subgraph_nodes_edges(num_nodes, vertices, edges_dict, mapping)
|
134
|
+
|
135
|
+
# Assertions
|
136
|
+
self.assertIn("nodes", result)
|
137
|
+
self.assertIn("edges", result)
|
138
|
+
self.assertGreaterEqual(len(result["nodes"]), 0)
|
139
|
+
self.assertGreaterEqual(len(result["edges"]), 0)
|
140
|
+
# Check that virtual edges are included
|
141
|
+
self.assertTrue(any(e in [2, 3] for e in result["edges"]))
|
142
|
+
|
143
|
+
def test_gpu_import_branch(self):
|
144
|
+
"""
|
145
|
+
Test coverage for GPU import branch by patching sys.modules to mock cupy and
|
146
|
+
cudf as numpy and pandas.
|
147
|
+
"""
|
148
|
+
module_name = "aiagents4pharma.talk2knowledgegraphs.utils" + \
|
149
|
+
".extractions.milvus_multimodal_pcst"
|
150
|
+
with patch.dict("sys.modules", {"cupy": np, "cudf": pd}):
|
151
|
+
# Reload the module to trigger the GPU branch
|
152
|
+
mod = importlib.reload(sys.modules[module_name])
|
153
|
+
# Patch Collection, pcst_fast, and pickle after reload
|
154
|
+
with patch(f"{module_name}.Collection", self.mock_collection), \
|
155
|
+
patch(f"{module_name}.pcst_fast", self.mock_pcst_fast), \
|
156
|
+
patch(f"{module_name}.pickle", self.mock_pickle):
|
157
|
+
pcst_pruning_cls = getattr(mod, "MultimodalPCSTPruning")
|
158
|
+
pcst = pcst_pruning_cls(
|
159
|
+
topk=3, topk_e=3, cost_e=0.5, c_const=0.01, root=-1,
|
160
|
+
num_clusters=1, pruning="gw", verbosity_level=0, use_description=True,
|
161
|
+
metric_type="IP"
|
162
|
+
)
|
163
|
+
# Dummy embeddings
|
164
|
+
text_emb = [0.1, 0.2, 0.3]
|
165
|
+
query_emb = [0.1, 0.2, 0.3]
|
166
|
+
modality = "gene/protein"
|
167
|
+
|
168
|
+
# Call extract_subgraph
|
169
|
+
result = pcst.extract_subgraph(text_emb, query_emb, modality, self.cfg)
|
170
|
+
|
171
|
+
# Assertions
|
172
|
+
self.assertIn("nodes", result)
|
173
|
+
self.assertIn("edges", result)
|
174
|
+
self.assertGreaterEqual(len(result["nodes"]), 0)
|
175
|
+
self.assertGreaterEqual(len(result["edges"]), 0)
|
@@ -4,6 +4,17 @@ Test cases for utils/pubchem_utils.py
|
|
4
4
|
|
5
5
|
from ..utils import pubchem_utils
|
6
6
|
|
7
|
+
def test_cas_rn2pubchem_cid():
|
8
|
+
"""
|
9
|
+
Test the cas_rn2pubchem_cid function.
|
10
|
+
|
11
|
+
The CAS RN for ethyl carbonate is 105-58-8.
|
12
|
+
The PubChem CID for ethyl carbonate is 7766.
|
13
|
+
"""
|
14
|
+
casrn = "105-58-8"
|
15
|
+
pubchem_cid = pubchem_utils.cas_rn2pubchem_cid(casrn)
|
16
|
+
assert pubchem_cid == 7766
|
17
|
+
|
7
18
|
def test_external_id2pubchem_cid():
|
8
19
|
"""
|
9
20
|
Test the external_id2pubchem_cid function.
|
@@ -3,5 +3,6 @@ This file is used to import all the models in the package.
|
|
3
3
|
'''
|
4
4
|
from . import subgraph_extraction
|
5
5
|
from . import multimodal_subgraph_extraction
|
6
|
+
from . import milvus_multimodal_subgraph_extraction
|
6
7
|
from . import subgraph_summarization
|
7
8
|
from . import graphrag_reasoning
|