aiagents4pharma 1.42.0__py3-none-any.whl → 1.44.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +17 -2
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +618 -413
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +362 -25
- aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +146 -109
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +240 -83
- aiagents4pharma/talk2scholars/agents/paper_download_agent.py +7 -4
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +49 -95
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +15 -1
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +16 -2
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +40 -5
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +15 -5
- aiagents4pharma/talk2scholars/configs/config.yaml +1 -3
- aiagents4pharma/talk2scholars/configs/tools/paper_download/default.yaml +124 -0
- aiagents4pharma/talk2scholars/tests/test_arxiv_downloader.py +478 -0
- aiagents4pharma/talk2scholars/tests/test_base_paper_downloader.py +620 -0
- aiagents4pharma/talk2scholars/tests/test_biorxiv_downloader.py +697 -0
- aiagents4pharma/talk2scholars/tests/test_medrxiv_downloader.py +534 -0
- aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +22 -12
- aiagents4pharma/talk2scholars/tests/test_paper_downloader.py +545 -0
- aiagents4pharma/talk2scholars/tests/test_pubmed_downloader.py +1067 -0
- aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +2 -4
- aiagents4pharma/talk2scholars/tools/paper_download/paper_downloader.py +457 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/__init__.py +20 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/arxiv_downloader.py +209 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/base_paper_downloader.py +343 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/biorxiv_downloader.py +321 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/medrxiv_downloader.py +198 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/pubmed_downloader.py +337 -0
- aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +97 -45
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +47 -29
- {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/METADATA +3 -1
- {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/RECORD +36 -33
- aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/default.yaml +0 -4
- aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/__init__.py +0 -3
- aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/default.yaml +0 -2
- aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/__init__.py +0 -3
- aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/default.yaml +0 -2
- aiagents4pharma/talk2scholars/tests/test_paper_download_biorxiv.py +0 -151
- aiagents4pharma/talk2scholars/tests/test_paper_download_medrxiv.py +0 -151
- aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +0 -249
- aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +0 -177
- aiagents4pharma/talk2scholars/tools/paper_download/download_biorxiv_input.py +0 -114
- aiagents4pharma/talk2scholars/tools/paper_download/download_medrxiv_input.py +0 -114
- /aiagents4pharma/talk2scholars/configs/tools/{download_arxiv_paper → paper_download}/__init__.py +0 -0
- {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py
CHANGED
@@ -1,413 +1,618 @@
|
|
1
|
-
"""
|
2
|
-
Test cases for tools/milvus_multimodal_subgraph_extraction.py
|
3
|
-
"""
|
4
|
-
|
5
|
-
import importlib
|
6
|
-
import unittest
|
7
|
-
from unittest.mock import patch, MagicMock
|
8
|
-
import numpy as np
|
9
|
-
import pandas as pd
|
10
|
-
from ..tools.milvus_multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
|
11
|
-
|
12
|
-
class TestMultimodalSubgraphExtractionTool(unittest.TestCase):
|
13
|
-
"""
|
14
|
-
Test cases for MultimodalSubgraphExtractionTool (Milvus)
|
15
|
-
"""
|
16
|
-
def setUp(self):
|
17
|
-
self.tool = MultimodalSubgraphExtractionTool()
|
18
|
-
self.state = {
|
19
|
-
"uploaded_files": [],
|
20
|
-
"embedding_model": MagicMock(),
|
21
|
-
"topk_nodes": 5,
|
22
|
-
"topk_edges": 5,
|
23
|
-
"dic_source_graph": [{"name": "TestGraph"}],
|
24
|
-
}
|
25
|
-
self.prompt = "Find subgraph for test"
|
26
|
-
self.arg_data = {"extraction_name": "subkg_12345"}
|
27
|
-
self.cfg_db = MagicMock()
|
28
|
-
self.cfg_db.milvus_db.database_name = "testdb"
|
29
|
-
self.cfg_db.milvus_db.alias = "default"
|
30
|
-
self.cfg = MagicMock()
|
31
|
-
self.cfg.cost_e = 1.0
|
32
|
-
self.cfg.c_const = 1.0
|
33
|
-
self.cfg.root = 0
|
34
|
-
self.cfg.num_clusters = 1
|
35
|
-
self.cfg.pruning = True
|
36
|
-
self.cfg.verbosity_level = 0
|
37
|
-
self.cfg.search_metric_type = "L2"
|
38
|
-
self.cfg.node_colors_dict = {"gene/protein": "red"}
|
39
|
-
|
40
|
-
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
41
|
-
"milvus_multimodal_subgraph_extraction.Collection")
|
42
|
-
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
43
|
-
"milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
|
44
|
-
@patch("pymilvus.connections")
|
45
|
-
def test_extract_multimodal_subgraph_wo_doc(self,
|
46
|
-
mock_connections,
|
47
|
-
mock_pcst,
|
48
|
-
mock_collection):
|
49
|
-
"""
|
50
|
-
Test the multimodal subgraph extraction tool for only text as modality.
|
51
|
-
"""
|
52
|
-
|
53
|
-
# Mock Milvus connection utilities
|
54
|
-
mock_connections.has_connection.return_value = True
|
55
|
-
|
56
|
-
# No uploaded_files (no doc)
|
57
|
-
self.state["uploaded_files"] = []
|
58
|
-
self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
|
59
|
-
self.state["selections"] = {}
|
60
|
-
|
61
|
-
# Mock Collection for nodes and edges
|
62
|
-
colls = {}
|
63
|
-
colls["nodes"] = MagicMock()
|
64
|
-
colls["nodes"] = MagicMock()
|
65
|
-
colls["nodes"].query.return_value = [
|
66
|
-
{"node_index": 0,
|
67
|
-
"node_id": "id1",
|
68
|
-
"node_name": "JAK1",
|
69
|
-
"node_type": "gene/protein",
|
70
|
-
"feat": "featA",
|
71
|
-
"feat_emb": [0.1, 0.2, 0.3],
|
72
|
-
"desc": "descA",
|
73
|
-
"desc_emb": [0.1, 0.2, 0.3]},
|
74
|
-
{"node_index": 1,
|
75
|
-
"node_id": "id2",
|
76
|
-
"node_name": "JAK2",
|
77
|
-
"node_type": "gene/protein",
|
78
|
-
"feat": "featB",
|
79
|
-
"feat_emb": [0.4, 0.5, 0.6],
|
80
|
-
"desc": "descB",
|
81
|
-
"desc_emb": [0.4, 0.5, 0.6]}
|
82
|
-
]
|
83
|
-
colls["nodes"].load.return_value = None
|
84
|
-
|
85
|
-
colls["edges"] = MagicMock()
|
86
|
-
colls["edges"].query.return_value = [
|
87
|
-
{"triplet_index": 0,
|
88
|
-
"head_id": "id1",
|
89
|
-
"head_index": 0,
|
90
|
-
"tail_id": "id2",
|
91
|
-
"tail_index": 1,
|
92
|
-
"edge_type": "gene/protein,ppi,gene/protein",
|
93
|
-
"display_relation": "ppi",
|
94
|
-
"feat": "featC",
|
95
|
-
"feat_emb": [0.7, 0.8, 0.9]}
|
96
|
-
]
|
97
|
-
colls["edges"].load.return_value = None
|
98
|
-
|
99
|
-
def collection_side_effect(name):
|
100
|
-
"""
|
101
|
-
Mock side effect for Collection to return nodes or edges based on name.
|
102
|
-
"""
|
103
|
-
if "nodes" in name:
|
104
|
-
return colls["nodes"]
|
105
|
-
if "edges" in name:
|
106
|
-
return colls["edges"]
|
107
|
-
return None
|
108
|
-
mock_collection.side_effect = collection_side_effect
|
109
|
-
|
110
|
-
# Mock MultimodalPCSTPruning
|
111
|
-
mock_pcst_instance = MagicMock()
|
112
|
-
mock_pcst_instance.extract_subgraph.return_value = {
|
113
|
-
"nodes": pd.Series([1, 2]),
|
114
|
-
"edges": pd.Series([0])
|
115
|
-
}
|
116
|
-
mock_pcst.return_value = mock_pcst_instance
|
117
|
-
|
118
|
-
# Patch hydra.compose to return config objects
|
119
|
-
with patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
120
|
-
"milvus_multimodal_subgraph_extraction.hydra.initialize"), \
|
121
|
-
patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
122
|
-
"milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
|
123
|
-
mock_compose.return_value = MagicMock()
|
124
|
-
mock_compose.return_value.app.frontend = self.cfg_db
|
125
|
-
mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
|
126
|
-
|
127
|
-
response = self.tool.invoke(
|
128
|
-
input={"prompt": self.prompt,
|
129
|
-
"tool_call_id": "subgraph_extraction_tool",
|
130
|
-
"state": self.state,
|
131
|
-
"arg_data": self.arg_data}
|
132
|
-
)
|
133
|
-
|
134
|
-
# Check tool message
|
135
|
-
self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
|
136
|
-
|
137
|
-
# Check extracted subgraph dictionary
|
138
|
-
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
139
|
-
self.assertIsInstance(dic_extracted_graph, dict)
|
140
|
-
self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
|
141
|
-
self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
|
142
|
-
self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
|
143
|
-
self.assertEqual(dic_extracted_graph["topk_edges"], 5)
|
144
|
-
self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
|
145
|
-
self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
|
146
|
-
self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
|
147
|
-
self.assertIsInstance(dic_extracted_graph["graph_text"], str)
|
148
|
-
# Check if the nodes are in the graph_text
|
149
|
-
self.assertTrue(all(
|
150
|
-
n[0] in dic_extracted_graph["graph_text"].replace('"', '')
|
151
|
-
for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
|
152
|
-
for n in subgraph_nodes
|
153
|
-
))
|
154
|
-
# Check if the edges are in the graph_text
|
155
|
-
self.assertTrue(all(
|
156
|
-
",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
|
157
|
-
in dic_extracted_graph["graph_text"].replace('"', '').\
|
158
|
-
replace("[", "").replace("]", "").replace("'", "")
|
159
|
-
for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
|
160
|
-
for e in subgraph_edges
|
161
|
-
))
|
162
|
-
|
163
|
-
# Another test for unknown collection
|
164
|
-
result = collection_side_effect("unknown")
|
165
|
-
self.assertIsNone(result)
|
166
|
-
|
167
|
-
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
168
|
-
"milvus_multimodal_subgraph_extraction.Collection")
|
169
|
-
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
170
|
-
"milvus_multimodal_subgraph_extraction.pd.read_excel")
|
171
|
-
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
172
|
-
"milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
|
173
|
-
@patch("pymilvus.connections")
|
174
|
-
def test_extract_multimodal_subgraph_w_doc(self,
|
175
|
-
mock_connections,
|
176
|
-
mock_pcst,
|
177
|
-
mock_read_excel,
|
178
|
-
mock_collection):
|
179
|
-
"""
|
180
|
-
Test the multimodal subgraph extraction tool for text as modality, plus genes.
|
181
|
-
"""
|
182
|
-
# Mock Milvus connection utilities
|
183
|
-
mock_connections.has_connection.return_value = True
|
184
|
-
|
185
|
-
# With uploaded_files (with doc)
|
186
|
-
self.state["uploaded_files"] = [
|
187
|
-
{"file_type": "multimodal", "file_path": "dummy.xlsx"}
|
188
|
-
]
|
189
|
-
self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
|
190
|
-
self.state["selections"] = {"gene/protein": ["JAK1", "JAK2"]}
|
191
|
-
|
192
|
-
# Mock pd.read_excel to return a dict of DataFrames
|
193
|
-
df = pd.DataFrame({
|
194
|
-
"name": ["JAK1", "JAK2"],
|
195
|
-
"node_type": ["gene/protein", "gene/protein"]
|
196
|
-
})
|
197
|
-
mock_read_excel.return_value = {"gene/protein": df}
|
198
|
-
|
199
|
-
# Mock Collection for nodes and edges
|
200
|
-
colls = {}
|
201
|
-
colls["nodes"] = MagicMock()
|
202
|
-
colls["nodes"] = MagicMock()
|
203
|
-
colls["nodes"].query.return_value = [
|
204
|
-
{"node_index": 0,
|
205
|
-
"node_id": "id1",
|
206
|
-
"node_name": "JAK1",
|
207
|
-
"node_type": "gene/protein",
|
208
|
-
"feat": "featA",
|
209
|
-
"feat_emb": [0.1, 0.2, 0.3],
|
210
|
-
"desc": "descA",
|
211
|
-
"desc_emb": [0.1, 0.2, 0.3]},
|
212
|
-
{"node_index": 1,
|
213
|
-
"node_id": "id2",
|
214
|
-
"node_name": "JAK2",
|
215
|
-
"node_type": "gene/protein",
|
216
|
-
"feat": "featB",
|
217
|
-
"feat_emb": [0.4, 0.5, 0.6],
|
218
|
-
"desc": "descB",
|
219
|
-
"desc_emb": [0.4, 0.5, 0.6]}
|
220
|
-
]
|
221
|
-
colls["nodes"].load.return_value = None
|
222
|
-
|
223
|
-
colls["edges"] = MagicMock()
|
224
|
-
colls["edges"].query.return_value = [
|
225
|
-
{"triplet_index": 0,
|
226
|
-
"head_id": "id1",
|
227
|
-
"head_index": 0,
|
228
|
-
"tail_id": "id2",
|
229
|
-
"tail_index": 1,
|
230
|
-
"edge_type": "gene/protein,ppi,gene/protein",
|
231
|
-
"display_relation": "ppi",
|
232
|
-
"feat": "featC",
|
233
|
-
"feat_emb": [0.7, 0.8, 0.9]}
|
234
|
-
]
|
235
|
-
colls["edges"].load.return_value = None
|
236
|
-
|
237
|
-
def collection_side_effect(name):
|
238
|
-
"""
|
239
|
-
Mock side effect for Collection to return nodes or edges based on name.
|
240
|
-
"""
|
241
|
-
if "nodes" in name:
|
242
|
-
return colls["nodes"]
|
243
|
-
if "edges" in name:
|
244
|
-
return colls["edges"]
|
245
|
-
return None
|
246
|
-
mock_collection.side_effect = collection_side_effect
|
247
|
-
|
248
|
-
# Mock MultimodalPCSTPruning
|
249
|
-
mock_pcst_instance = MagicMock()
|
250
|
-
mock_pcst_instance.extract_subgraph.return_value = {
|
251
|
-
"nodes": pd.Series([1, 2]),
|
252
|
-
"edges": pd.Series([0])
|
253
|
-
}
|
254
|
-
mock_pcst.return_value = mock_pcst_instance
|
255
|
-
|
256
|
-
# Patch hydra.compose to return config objects
|
257
|
-
with patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
258
|
-
"milvus_multimodal_subgraph_extraction.hydra.initialize"), \
|
259
|
-
patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
260
|
-
"milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
|
261
|
-
mock_compose.return_value = MagicMock()
|
262
|
-
mock_compose.return_value.app.frontend = self.cfg_db
|
263
|
-
mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
|
264
|
-
|
265
|
-
response = self.tool.invoke(
|
266
|
-
input={"prompt": self.prompt,
|
267
|
-
"tool_call_id": "subgraph_extraction_tool",
|
268
|
-
"state": self.state,
|
269
|
-
"arg_data": self.arg_data}
|
270
|
-
)
|
271
|
-
|
272
|
-
|
273
|
-
self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
|
274
|
-
|
275
|
-
# Check extracted subgraph dictionary
|
276
|
-
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
277
|
-
self.assertIsInstance(dic_extracted_graph, dict)
|
278
|
-
self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
|
279
|
-
self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
|
280
|
-
self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
|
281
|
-
self.assertEqual(dic_extracted_graph["topk_edges"], 5)
|
282
|
-
self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
|
283
|
-
self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
|
284
|
-
self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
|
285
|
-
self.assertIsInstance(dic_extracted_graph["graph_text"], str)
|
286
|
-
# Check if the nodes are in the graph_text
|
287
|
-
self.assertTrue(all(
|
288
|
-
n[0] in dic_extracted_graph["graph_text"].replace('"', '')
|
289
|
-
for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
|
290
|
-
for n in subgraph_nodes
|
291
|
-
))
|
292
|
-
# Check if the edges are in the graph_text
|
293
|
-
self.assertTrue(all(
|
294
|
-
",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
|
295
|
-
in dic_extracted_graph["graph_text"].replace('"', '').\
|
296
|
-
replace("[", "").replace("]", "").replace("'", "")
|
297
|
-
for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
|
298
|
-
for e in subgraph_edges
|
299
|
-
))
|
300
|
-
|
301
|
-
# Another test for unknown collection
|
302
|
-
result = collection_side_effect("unknown")
|
303
|
-
self.assertIsNone(result)
|
304
|
-
|
305
|
-
def test_extract_multimodal_subgraph_wo_doc_gpu(self):
|
306
|
-
"""
|
307
|
-
Test the multimodal subgraph extraction tool for only text as modality,
|
308
|
-
simulating GPU (cudf/cupy) environment.
|
309
|
-
"""
|
310
|
-
module_name = "aiagents4pharma.talk2knowledgegraphs.tools."+\
|
311
|
-
"milvus_multimodal_subgraph_extraction"
|
312
|
-
with patch.dict("sys.modules", {"cupy": np, "cudf": pd}):
|
313
|
-
mod = importlib.reload(importlib.import_module(module_name))
|
314
|
-
# Patch Collection and MultimodalPCSTPruning after reload
|
315
|
-
with patch(f"{module_name}.Collection") as mock_collection, \
|
316
|
-
patch(f"{module_name}.MultimodalPCSTPruning") as mock_pcst, \
|
317
|
-
patch("pymilvus.connections") as mock_connections:
|
318
|
-
# Setup mocks as in the original test
|
319
|
-
mock_connections.has_connection.return_value = True
|
320
|
-
colls = {}
|
321
|
-
colls["nodes"] = MagicMock()
|
322
|
-
colls["nodes"].query.return_value = [
|
323
|
-
{"node_index": 0,
|
324
|
-
"node_id": "id1",
|
325
|
-
"node_name": "JAK1",
|
326
|
-
"node_type": "gene/protein",
|
327
|
-
"feat": "featA",
|
328
|
-
"feat_emb": [0.1, 0.2, 0.3],
|
329
|
-
"desc": "descA",
|
330
|
-
"desc_emb": [0.1, 0.2, 0.3]},
|
331
|
-
{"node_index": 1,
|
332
|
-
"node_id": "id2",
|
333
|
-
"node_name": "JAK2",
|
334
|
-
"node_type": "gene/protein",
|
335
|
-
"feat": "featB",
|
336
|
-
"feat_emb": [0.4, 0.5, 0.6],
|
337
|
-
"desc": "descB",
|
338
|
-
"desc_emb": [0.4, 0.5, 0.6]}
|
339
|
-
]
|
340
|
-
colls["nodes"].load.return_value = None
|
341
|
-
colls["edges"] = MagicMock()
|
342
|
-
colls["edges"].query.return_value = [
|
343
|
-
{"triplet_index": 0,
|
344
|
-
"head_id": "id1",
|
345
|
-
"head_index": 0,
|
346
|
-
"tail_id": "id2",
|
347
|
-
"tail_index": 1,
|
348
|
-
"edge_type": "gene/protein,ppi,gene/protein",
|
349
|
-
"display_relation": "ppi",
|
350
|
-
"feat": "featC",
|
351
|
-
"feat_emb": [0.7, 0.8, 0.9]}
|
352
|
-
]
|
353
|
-
colls["edges"].load.return_value = None
|
354
|
-
def collection_side_effect(name):
|
355
|
-
if "nodes" in name:
|
356
|
-
return colls["nodes"]
|
357
|
-
if "edges" in name:
|
358
|
-
return colls["edges"]
|
359
|
-
return None
|
360
|
-
mock_collection.side_effect = collection_side_effect
|
361
|
-
mock_pcst_instance = MagicMock()
|
362
|
-
mock_pcst_instance.extract_subgraph.return_value = {
|
363
|
-
"nodes": pd.Series([1, 2]),
|
364
|
-
"edges": pd.Series([0])
|
365
|
-
}
|
366
|
-
mock_pcst.return_value = mock_pcst_instance
|
367
|
-
# Setup config mocks
|
368
|
-
tool_cls = getattr(mod, "MultimodalSubgraphExtractionTool")
|
369
|
-
tool = tool_cls()
|
370
|
-
|
371
|
-
# Patch hydra.compose
|
372
|
-
with patch(f"{module_name}.hydra.initialize"), \
|
373
|
-
patch(f"{module_name}.hydra.compose") as mock_compose:
|
374
|
-
mock_compose.return_value = MagicMock()
|
375
|
-
mock_compose.return_value.app.frontend = self.cfg_db
|
376
|
-
mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
|
377
|
-
self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
|
378
|
-
self.state["selections"] = {}
|
379
|
-
response = tool.invoke(
|
380
|
-
input={"prompt": self.prompt,
|
381
|
-
"tool_call_id": "subgraph_extraction_tool",
|
382
|
-
"state": self.state,
|
383
|
-
"arg_data": self.arg_data}
|
384
|
-
)
|
385
|
-
# Check tool message
|
386
|
-
self.assertEqual(response.update["messages"][-1].tool_call_id,
|
387
|
-
"subgraph_extraction_tool")
|
388
|
-
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
389
|
-
self.assertIsInstance(dic_extracted_graph, dict)
|
390
|
-
self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
|
391
|
-
self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
|
392
|
-
self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
|
393
|
-
self.assertEqual(dic_extracted_graph["topk_edges"], 5)
|
394
|
-
self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
|
395
|
-
self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
|
396
|
-
self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
|
397
|
-
self.assertIsInstance(dic_extracted_graph["graph_text"], str)
|
398
|
-
self.assertTrue(all(
|
399
|
-
n[0] in dic_extracted_graph["graph_text"].replace('"', '')
|
400
|
-
for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
|
401
|
-
for n in subgraph_nodes
|
402
|
-
))
|
403
|
-
self.assertTrue(all(
|
404
|
-
",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
|
405
|
-
in dic_extracted_graph["graph_text"].replace('"', '').\
|
406
|
-
replace("[", "").replace("]", "").replace("'", "")
|
407
|
-
for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
|
408
|
-
for e in subgraph_edges
|
409
|
-
))
|
410
|
-
|
411
|
-
# Another test for unknown collection
|
412
|
-
result = collection_side_effect("unknown")
|
413
|
-
self.assertIsNone(result)
|
1
|
+
"""
|
2
|
+
Test cases for tools/milvus_multimodal_subgraph_extraction.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import importlib
|
6
|
+
import unittest
|
7
|
+
from unittest.mock import patch, MagicMock
|
8
|
+
import numpy as np
|
9
|
+
import pandas as pd
|
10
|
+
from ..tools.milvus_multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
|
11
|
+
|
12
|
+
class TestMultimodalSubgraphExtractionTool(unittest.TestCase):
|
13
|
+
"""
|
14
|
+
Test cases for MultimodalSubgraphExtractionTool (Milvus)
|
15
|
+
"""
|
16
|
+
def setUp(self):
|
17
|
+
self.tool = MultimodalSubgraphExtractionTool()
|
18
|
+
self.state = {
|
19
|
+
"uploaded_files": [],
|
20
|
+
"embedding_model": MagicMock(),
|
21
|
+
"topk_nodes": 5,
|
22
|
+
"topk_edges": 5,
|
23
|
+
"dic_source_graph": [{"name": "TestGraph"}],
|
24
|
+
}
|
25
|
+
self.prompt = "Find subgraph for test"
|
26
|
+
self.arg_data = {"extraction_name": "subkg_12345"}
|
27
|
+
self.cfg_db = MagicMock()
|
28
|
+
self.cfg_db.milvus_db.database_name = "testdb"
|
29
|
+
self.cfg_db.milvus_db.alias = "default"
|
30
|
+
self.cfg = MagicMock()
|
31
|
+
self.cfg.cost_e = 1.0
|
32
|
+
self.cfg.c_const = 1.0
|
33
|
+
self.cfg.root = 0
|
34
|
+
self.cfg.num_clusters = 1
|
35
|
+
self.cfg.pruning = True
|
36
|
+
self.cfg.verbosity_level = 0
|
37
|
+
self.cfg.search_metric_type = "L2"
|
38
|
+
self.cfg.node_colors_dict = {"gene/protein": "red"}
|
39
|
+
|
40
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
41
|
+
"milvus_multimodal_subgraph_extraction.Collection")
|
42
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
43
|
+
"milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
|
44
|
+
@patch("pymilvus.connections")
|
45
|
+
def test_extract_multimodal_subgraph_wo_doc(self,
|
46
|
+
mock_connections,
|
47
|
+
mock_pcst,
|
48
|
+
mock_collection):
|
49
|
+
"""
|
50
|
+
Test the multimodal subgraph extraction tool for only text as modality.
|
51
|
+
"""
|
52
|
+
|
53
|
+
# Mock Milvus connection utilities
|
54
|
+
mock_connections.has_connection.return_value = True
|
55
|
+
|
56
|
+
# No uploaded_files (no doc)
|
57
|
+
self.state["uploaded_files"] = []
|
58
|
+
self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
|
59
|
+
self.state["selections"] = {}
|
60
|
+
|
61
|
+
# Mock Collection for nodes and edges
|
62
|
+
colls = {}
|
63
|
+
colls["nodes"] = MagicMock()
|
64
|
+
colls["nodes"] = MagicMock()
|
65
|
+
colls["nodes"].query.return_value = [
|
66
|
+
{"node_index": 0,
|
67
|
+
"node_id": "id1",
|
68
|
+
"node_name": "JAK1",
|
69
|
+
"node_type": "gene/protein",
|
70
|
+
"feat": "featA",
|
71
|
+
"feat_emb": [0.1, 0.2, 0.3],
|
72
|
+
"desc": "descA",
|
73
|
+
"desc_emb": [0.1, 0.2, 0.3]},
|
74
|
+
{"node_index": 1,
|
75
|
+
"node_id": "id2",
|
76
|
+
"node_name": "JAK2",
|
77
|
+
"node_type": "gene/protein",
|
78
|
+
"feat": "featB",
|
79
|
+
"feat_emb": [0.4, 0.5, 0.6],
|
80
|
+
"desc": "descB",
|
81
|
+
"desc_emb": [0.4, 0.5, 0.6]}
|
82
|
+
]
|
83
|
+
colls["nodes"].load.return_value = None
|
84
|
+
|
85
|
+
colls["edges"] = MagicMock()
|
86
|
+
colls["edges"].query.return_value = [
|
87
|
+
{"triplet_index": 0,
|
88
|
+
"head_id": "id1",
|
89
|
+
"head_index": 0,
|
90
|
+
"tail_id": "id2",
|
91
|
+
"tail_index": 1,
|
92
|
+
"edge_type": "gene/protein,ppi,gene/protein",
|
93
|
+
"display_relation": "ppi",
|
94
|
+
"feat": "featC",
|
95
|
+
"feat_emb": [0.7, 0.8, 0.9]}
|
96
|
+
]
|
97
|
+
colls["edges"].load.return_value = None
|
98
|
+
|
99
|
+
def collection_side_effect(name):
|
100
|
+
"""
|
101
|
+
Mock side effect for Collection to return nodes or edges based on name.
|
102
|
+
"""
|
103
|
+
if "nodes" in name:
|
104
|
+
return colls["nodes"]
|
105
|
+
if "edges" in name:
|
106
|
+
return colls["edges"]
|
107
|
+
return None
|
108
|
+
mock_collection.side_effect = collection_side_effect
|
109
|
+
|
110
|
+
# Mock MultimodalPCSTPruning
|
111
|
+
mock_pcst_instance = MagicMock()
|
112
|
+
mock_pcst_instance.extract_subgraph.return_value = {
|
113
|
+
"nodes": pd.Series([1, 2]),
|
114
|
+
"edges": pd.Series([0])
|
115
|
+
}
|
116
|
+
mock_pcst.return_value = mock_pcst_instance
|
117
|
+
|
118
|
+
# Patch hydra.compose to return config objects
|
119
|
+
with patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
120
|
+
"milvus_multimodal_subgraph_extraction.hydra.initialize"), \
|
121
|
+
patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
122
|
+
"milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
|
123
|
+
mock_compose.return_value = MagicMock()
|
124
|
+
mock_compose.return_value.app.frontend = self.cfg_db
|
125
|
+
mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
|
126
|
+
|
127
|
+
response = self.tool.invoke(
|
128
|
+
input={"prompt": self.prompt,
|
129
|
+
"tool_call_id": "subgraph_extraction_tool",
|
130
|
+
"state": self.state,
|
131
|
+
"arg_data": self.arg_data}
|
132
|
+
)
|
133
|
+
|
134
|
+
# Check tool message
|
135
|
+
self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
|
136
|
+
|
137
|
+
# Check extracted subgraph dictionary
|
138
|
+
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
139
|
+
self.assertIsInstance(dic_extracted_graph, dict)
|
140
|
+
self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
|
141
|
+
self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
|
142
|
+
self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
|
143
|
+
self.assertEqual(dic_extracted_graph["topk_edges"], 5)
|
144
|
+
self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
|
145
|
+
self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
|
146
|
+
self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
|
147
|
+
self.assertIsInstance(dic_extracted_graph["graph_text"], str)
|
148
|
+
# Check if the nodes are in the graph_text
|
149
|
+
self.assertTrue(all(
|
150
|
+
n[0] in dic_extracted_graph["graph_text"].replace('"', '')
|
151
|
+
for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
|
152
|
+
for n in subgraph_nodes
|
153
|
+
))
|
154
|
+
# Check if the edges are in the graph_text
|
155
|
+
self.assertTrue(all(
|
156
|
+
",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
|
157
|
+
in dic_extracted_graph["graph_text"].replace('"', '').\
|
158
|
+
replace("[", "").replace("]", "").replace("'", "")
|
159
|
+
for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
|
160
|
+
for e in subgraph_edges
|
161
|
+
))
|
162
|
+
|
163
|
+
# Another test for unknown collection
|
164
|
+
result = collection_side_effect("unknown")
|
165
|
+
self.assertIsNone(result)
|
166
|
+
|
167
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
168
|
+
"milvus_multimodal_subgraph_extraction.Collection")
|
169
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
170
|
+
"milvus_multimodal_subgraph_extraction.pd.read_excel")
|
171
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
172
|
+
"milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
|
173
|
+
@patch("pymilvus.connections")
|
174
|
+
def test_extract_multimodal_subgraph_w_doc(self,
|
175
|
+
mock_connections,
|
176
|
+
mock_pcst,
|
177
|
+
mock_read_excel,
|
178
|
+
mock_collection):
|
179
|
+
"""
|
180
|
+
Test the multimodal subgraph extraction tool for text as modality, plus genes.
|
181
|
+
"""
|
182
|
+
# Mock Milvus connection utilities
|
183
|
+
mock_connections.has_connection.return_value = True
|
184
|
+
|
185
|
+
# With uploaded_files (with doc)
|
186
|
+
self.state["uploaded_files"] = [
|
187
|
+
{"file_type": "multimodal", "file_path": "dummy.xlsx"}
|
188
|
+
]
|
189
|
+
self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
|
190
|
+
self.state["selections"] = {"gene/protein": ["JAK1", "JAK2"]}
|
191
|
+
|
192
|
+
# Mock pd.read_excel to return a dict of DataFrames
|
193
|
+
df = pd.DataFrame({
|
194
|
+
"name": ["JAK1", "JAK2"],
|
195
|
+
"node_type": ["gene/protein", "gene/protein"]
|
196
|
+
})
|
197
|
+
mock_read_excel.return_value = {"gene/protein": df}
|
198
|
+
|
199
|
+
# Mock Collection for nodes and edges
|
200
|
+
colls = {}
|
201
|
+
colls["nodes"] = MagicMock()
|
202
|
+
colls["nodes"] = MagicMock()
|
203
|
+
colls["nodes"].query.return_value = [
|
204
|
+
{"node_index": 0,
|
205
|
+
"node_id": "id1",
|
206
|
+
"node_name": "JAK1",
|
207
|
+
"node_type": "gene/protein",
|
208
|
+
"feat": "featA",
|
209
|
+
"feat_emb": [0.1, 0.2, 0.3],
|
210
|
+
"desc": "descA",
|
211
|
+
"desc_emb": [0.1, 0.2, 0.3]},
|
212
|
+
{"node_index": 1,
|
213
|
+
"node_id": "id2",
|
214
|
+
"node_name": "JAK2",
|
215
|
+
"node_type": "gene/protein",
|
216
|
+
"feat": "featB",
|
217
|
+
"feat_emb": [0.4, 0.5, 0.6],
|
218
|
+
"desc": "descB",
|
219
|
+
"desc_emb": [0.4, 0.5, 0.6]}
|
220
|
+
]
|
221
|
+
colls["nodes"].load.return_value = None
|
222
|
+
|
223
|
+
colls["edges"] = MagicMock()
|
224
|
+
colls["edges"].query.return_value = [
|
225
|
+
{"triplet_index": 0,
|
226
|
+
"head_id": "id1",
|
227
|
+
"head_index": 0,
|
228
|
+
"tail_id": "id2",
|
229
|
+
"tail_index": 1,
|
230
|
+
"edge_type": "gene/protein,ppi,gene/protein",
|
231
|
+
"display_relation": "ppi",
|
232
|
+
"feat": "featC",
|
233
|
+
"feat_emb": [0.7, 0.8, 0.9]}
|
234
|
+
]
|
235
|
+
colls["edges"].load.return_value = None
|
236
|
+
|
237
|
+
def collection_side_effect(name):
|
238
|
+
"""
|
239
|
+
Mock side effect for Collection to return nodes or edges based on name.
|
240
|
+
"""
|
241
|
+
if "nodes" in name:
|
242
|
+
return colls["nodes"]
|
243
|
+
if "edges" in name:
|
244
|
+
return colls["edges"]
|
245
|
+
return None
|
246
|
+
mock_collection.side_effect = collection_side_effect
|
247
|
+
|
248
|
+
# Mock MultimodalPCSTPruning
|
249
|
+
mock_pcst_instance = MagicMock()
|
250
|
+
mock_pcst_instance.extract_subgraph.return_value = {
|
251
|
+
"nodes": pd.Series([1, 2]),
|
252
|
+
"edges": pd.Series([0])
|
253
|
+
}
|
254
|
+
mock_pcst.return_value = mock_pcst_instance
|
255
|
+
|
256
|
+
# Patch hydra.compose to return config objects
|
257
|
+
with patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
258
|
+
"milvus_multimodal_subgraph_extraction.hydra.initialize"), \
|
259
|
+
patch("aiagents4pharma.talk2knowledgegraphs.tools."
|
260
|
+
"milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
|
261
|
+
mock_compose.return_value = MagicMock()
|
262
|
+
mock_compose.return_value.app.frontend = self.cfg_db
|
263
|
+
mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
|
264
|
+
|
265
|
+
response = self.tool.invoke(
|
266
|
+
input={"prompt": self.prompt,
|
267
|
+
"tool_call_id": "subgraph_extraction_tool",
|
268
|
+
"state": self.state,
|
269
|
+
"arg_data": self.arg_data}
|
270
|
+
)
|
271
|
+
|
272
|
+
# Check tool message
|
273
|
+
self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
|
274
|
+
|
275
|
+
# Check extracted subgraph dictionary
|
276
|
+
dic_extracted_graph = response.update["dic_extracted_graph"][0]
|
277
|
+
self.assertIsInstance(dic_extracted_graph, dict)
|
278
|
+
self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
|
279
|
+
self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
|
280
|
+
self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
|
281
|
+
self.assertEqual(dic_extracted_graph["topk_edges"], 5)
|
282
|
+
self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
|
283
|
+
self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
|
284
|
+
self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
|
285
|
+
self.assertIsInstance(dic_extracted_graph["graph_text"], str)
|
286
|
+
# Check if the nodes are in the graph_text
|
287
|
+
self.assertTrue(all(
|
288
|
+
n[0] in dic_extracted_graph["graph_text"].replace('"', '')
|
289
|
+
for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
|
290
|
+
for n in subgraph_nodes
|
291
|
+
))
|
292
|
+
# Check if the edges are in the graph_text
|
293
|
+
self.assertTrue(all(
|
294
|
+
",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
|
295
|
+
in dic_extracted_graph["graph_text"].replace('"', '').\
|
296
|
+
replace("[", "").replace("]", "").replace("'", "")
|
297
|
+
for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
|
298
|
+
for e in subgraph_edges
|
299
|
+
))
|
300
|
+
|
301
|
+
# Another test for unknown collection
|
302
|
+
result = collection_side_effect("unknown")
|
303
|
+
self.assertIsNone(result)
|
304
|
+
|
305
|
+
def test_extract_multimodal_subgraph_wo_doc_gpu(self):
    """
    Test the multimodal subgraph extraction tool for only text as modality,
    simulating GPU (cudf/cupy) environment.

    ``cupy`` and ``cudf`` are replaced in ``sys.modules`` by numpy and pandas
    and the tool module is reloaded under that patch. A cleanup reloads the
    module once more after the test, so the simulated-GPU import state does
    not leak into the tests that run afterwards.
    """
    module_name = "aiagents4pharma.talk2knowledgegraphs.tools."+\
        "milvus_multimodal_subgraph_extraction"
    with patch.dict("sys.modules", {"cupy": np, "cudf": pd}):
        mod = importlib.reload(importlib.import_module(module_name))
        # Cleanups run after the test body, i.e. after patch.dict has
        # restored sys.modules, so this re-imports the module against the
        # real environment.
        self.addCleanup(importlib.reload, mod)
        # Patch Collection and MultimodalPCSTPruning after reload
        with patch(f"{module_name}.Collection") as mock_collection, \
             patch(f"{module_name}.MultimodalPCSTPruning") as mock_pcst, \
             patch("pymilvus.connections") as mock_connections:
            # Setup mocks as in the original test
            mock_connections.has_connection.return_value = True
            colls = {}
            colls["nodes"] = MagicMock()
            colls["nodes"].query.return_value = [
                {"node_index": 0,
                 "node_id": "id1",
                 "node_name": "JAK1",
                 "node_type": "gene/protein",
                 "feat": "featA",
                 "feat_emb": [0.1, 0.2, 0.3],
                 "desc": "descA",
                 "desc_emb": [0.1, 0.2, 0.3]},
                {"node_index": 1,
                 "node_id": "id2",
                 "node_name": "JAK2",
                 "node_type": "gene/protein",
                 "feat": "featB",
                 "feat_emb": [0.4, 0.5, 0.6],
                 "desc": "descB",
                 "desc_emb": [0.4, 0.5, 0.6]}
            ]
            colls["nodes"].load.return_value = None
            colls["edges"] = MagicMock()
            colls["edges"].query.return_value = [
                {"triplet_index": 0,
                 "head_id": "id1",
                 "head_index": 0,
                 "tail_id": "id2",
                 "tail_index": 1,
                 "edge_type": "gene/protein,ppi,gene/protein",
                 "display_relation": "ppi",
                 "feat": "featC",
                 "feat_emb": [0.7, 0.8, 0.9]}
            ]
            colls["edges"].load.return_value = None

            def collection_side_effect(name):
                """
                Mock side effect for Collection to return nodes or edges based on name.
                """
                if "nodes" in name:
                    return colls["nodes"]
                if "edges" in name:
                    return colls["edges"]
                return None
            mock_collection.side_effect = collection_side_effect

            # The pruning step is mocked to select fixed node/edge indices
            mock_pcst_instance = MagicMock()
            mock_pcst_instance.extract_subgraph.return_value = {
                "nodes": pd.Series([1, 2]),
                "edges": pd.Series([0])
            }
            mock_pcst.return_value = mock_pcst_instance

            # Instantiate the tool from the reloaded module, not from the
            # class object imported before the reload
            tool_cls = getattr(mod, "MultimodalSubgraphExtractionTool")
            tool = tool_cls()

            # Patch hydra.compose
            with patch(f"{module_name}.hydra.initialize"), \
                 patch(f"{module_name}.hydra.compose") as mock_compose:
                mock_compose.return_value = MagicMock()
                mock_compose.return_value.app.frontend = self.cfg_db
                mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
                self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
                self.state["selections"] = {}
                response = tool.invoke(
                    input={"prompt": self.prompt,
                           "tool_call_id": "subgraph_extraction_tool",
                           "state": self.state,
                           "arg_data": self.arg_data}
                )

            # Check tool message
            self.assertEqual(response.update["messages"][-1].tool_call_id,
                             "subgraph_extraction_tool")

            # Check extracted subgraph dictionary
            dic_extracted_graph = response.update["dic_extracted_graph"][0]
            self.assertIsInstance(dic_extracted_graph, dict)
            self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
            self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
            self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
            self.assertEqual(dic_extracted_graph["topk_edges"], 5)
            self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
            self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
            self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
            self.assertIsInstance(dic_extracted_graph["graph_text"], str)
            # Check if the nodes are in the graph_text
            self.assertTrue(all(
                n[0] in dic_extracted_graph["graph_text"].replace('"', '')
                for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
                for n in subgraph_nodes
            ))
            # Check if the edges are in the graph_text
            self.assertTrue(all(
                ",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
                in dic_extracted_graph["graph_text"].replace('"', '').\
                    replace("[", "").replace("]", "").replace("'", "")
                for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
                for e in subgraph_edges
            ))

            # Another test for unknown collection
            result = collection_side_effect("unknown")
            self.assertIsNone(result)
|
414
|
+
|
415
|
+
def test_normalize_vector_gpu_mode(self):
    """Test normalize_vector method in GPU mode.

    The loader is switched to ``normalize_vectors = True`` and its array
    backend (``loader.py``) is swapped for a mock, so no GPU library is
    required to run this test.
    """
    loader = self.tool.loader
    loader.normalize_vectors = True
    loader.py = MagicMock()
    backend = loader.py

    # Stand-in for the backend array: dividing it yields itself, and
    # converting it to a list yields the expected normalized values.
    fake_array = MagicMock()
    fake_array.__truediv__ = MagicMock(return_value=fake_array)
    fake_array.tolist.return_value = [0.5, 1.0, 1.5]

    # NOTE(review): linalg.norm returns this mock object itself; its 2.0
    # return value is only seen if the norm result gets called - confirm
    # that is intended.
    fake_norm = MagicMock(return_value=2.0)

    backend.asarray.return_value = fake_array
    backend.linalg.norm.return_value = fake_norm

    result = self.tool.normalize_vector([1.0, 2.0, 3.0])

    # The mocked list comes back, and the backend was used exactly once
    # for the array conversion and once for the norm.
    self.assertEqual(result, [0.5, 1.0, 1.5])
    backend.asarray.assert_called_once_with([1.0, 2.0, 3.0])
    backend.linalg.norm.assert_called_once_with(fake_array)
|
433
|
+
|
434
|
+
def test_normalize_vector_cpu_mode(self):
    """Test normalize_vector method in CPU mode.

    With ``normalize_vectors`` switched off on the loader, the vector is
    expected to come back unchanged.
    """
    expected = [1.0, 2.0, 3.0]
    loader = self.tool.loader
    loader.normalize_vectors = False

    # Pass a fresh list so the comparison target stays independent of the
    # argument handed to the tool.
    result = self.tool.normalize_vector(list(expected))

    self.assertEqual(result, expected)
|
441
|
+
|
442
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
       "milvus_multimodal_subgraph_extraction.Collection")
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
       "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
@patch("pymilvus.connections")
def test_extract_multimodal_subgraph_no_vector_processing(self,
                                                          mock_connections,
                                                          mock_pcst,
                                                          mock_collection):
    """Test when vector_processing config is not present.

    The tool configuration is a mock whose ``vector_processing`` attribute
    has been deleted, so any access to it raises ``AttributeError``. The
    test passes if the tool still returns a tool message for the call id.
    """
    # Mock Milvus connection utilities
    mock_connections.has_connection.return_value = True

    # Text-only request: no uploaded files and no selections
    self.state["uploaded_files"] = []
    self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
    self.state["selections"] = {}

    # Mock Collection for nodes and edges
    colls = {}
    colls["nodes"] = MagicMock()
    colls["nodes"].query.return_value = [
        {"node_index": 0, "node_id": "id1", "node_name": "JAK1",
         "node_type": "gene/protein", "feat": "featA", "feat_emb": [0.1, 0.2, 0.3],
         "desc": "descA", "desc_emb": [0.1, 0.2, 0.3]}
    ]
    colls["nodes"].load.return_value = None

    colls["edges"] = MagicMock()
    colls["edges"].query.return_value = [
        {"triplet_index": 0, "head_id": "id1", "tail_id": "id2",
         "edge_type": "gene/protein,ppi,gene/protein"}
    ]
    colls["edges"].load.return_value = None

    def collection_side_effect(name):
        """Return the nodes or edges mock based on the collection name."""
        if "nodes" in name:
            return colls["nodes"]
        if "edges" in name:
            return colls["edges"]
        return None
    mock_collection.side_effect = collection_side_effect

    # Mock MultimodalPCSTPruning
    mock_pcst_instance = MagicMock()
    mock_pcst_instance.extract_subgraph.return_value = {
        "nodes": pd.Series([1]),
        "edges": pd.Series([0])
    }
    mock_pcst.return_value = mock_pcst_instance

    # Create config without vector_processing attribute
    cfg_no_vector_processing = MagicMock()
    cfg_no_vector_processing.cost_e = 1.0
    cfg_no_vector_processing.c_const = 1.0
    cfg_no_vector_processing.root = 0
    cfg_no_vector_processing.num_clusters = 1
    cfg_no_vector_processing.pruning = True
    cfg_no_vector_processing.verbosity_level = 0
    cfg_no_vector_processing.search_metric_type = "L2"
    cfg_no_vector_processing.node_colors_dict = {"gene/protein": "red"}
    # Remove vector_processing attribute to test the missing branch:
    # a deleted mock attribute raises AttributeError instead of being
    # auto-created on access.
    del cfg_no_vector_processing.vector_processing

    # Patch hydra.compose to return config without vector_processing
    with patch("aiagents4pharma.talk2knowledgegraphs.tools."
               "milvus_multimodal_subgraph_extraction.hydra.initialize"), \
         patch("aiagents4pharma.talk2knowledgegraphs.tools."
               "milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
        mock_compose.return_value = MagicMock()
        mock_compose.return_value.app.frontend = self.cfg_db
        mock_compose.return_value.tools.multimodal_subgraph_extraction = \
            cfg_no_vector_processing

        response = self.tool.invoke(
            input={"prompt": self.prompt,
                   "tool_call_id": "subgraph_extraction_tool",
                   "state": self.state,
                   "arg_data": self.arg_data}
        )

    # Verify the test completed successfully
    self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")

    # Exercise the fall-through branch of collection_side_effect once;
    # the original repeated this identical check a second time.
    result = collection_side_effect("final_unknown_collection")
    self.assertIsNone(result)
|
532
|
+
|
533
|
+
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
       "milvus_multimodal_subgraph_extraction.Collection")
@patch("aiagents4pharma.talk2knowledgegraphs.tools."
       "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
@patch("pymilvus.connections")
def test_extract_multimodal_subgraph_dynamic_metrics_disabled(self,
                                                              mock_connections,
                                                              mock_pcst,
                                                              mock_collection):
    """Test when dynamic_metrics is disabled.

    The tool configuration carries a ``vector_processing`` section whose
    ``dynamic_metrics`` flag is ``False``. The test passes if the tool
    still returns a tool message for the call id.
    """
    tool_module = ("aiagents4pharma.talk2knowledgegraphs.tools."
                   "milvus_multimodal_subgraph_extraction")

    # Report an existing Milvus connection
    mock_connections.has_connection.return_value = True

    # Text-only request: nothing uploaded, nothing selected
    self.state["uploaded_files"] = []
    self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
    self.state["selections"] = {}

    # One fake node collection and one fake edge collection
    node_collection = MagicMock()
    node_collection.query.return_value = [
        {
            "node_index": 0,
            "node_id": "id1",
            "node_name": "JAK1",
            "node_type": "gene/protein",
            "feat": "featA",
            "feat_emb": [0.1, 0.2, 0.3],
            "desc": "descA",
            "desc_emb": [0.1, 0.2, 0.3],
        }
    ]
    node_collection.load.return_value = None

    edge_collection = MagicMock()
    edge_collection.query.return_value = [
        {
            "triplet_index": 0,
            "head_id": "id1",
            "tail_id": "id2",
            "edge_type": "gene/protein,ppi,gene/protein",
        }
    ]
    edge_collection.load.return_value = None

    colls = {"nodes": node_collection, "edges": edge_collection}

    def collection_side_effect(name):
        """Pick the fake collection whose key occurs in the requested name."""
        for key in ("nodes", "edges"):
            if key in name:
                return colls[key]
        return None

    mock_collection.side_effect = collection_side_effect

    # The pruning step hands back fixed node/edge indices
    pruner = MagicMock()
    pruner.extract_subgraph.return_value = {
        "nodes": pd.Series([1]),
        "edges": pd.Series([0])
    }
    mock_pcst.return_value = pruner

    # Tool configuration with dynamic metrics switched off
    cfg_dynamic_disabled = MagicMock(
        cost_e=1.0,
        c_const=1.0,
        root=0,
        num_clusters=1,
        pruning=True,
        verbosity_level=0,
        search_metric_type="L2",
        node_colors_dict={"gene/protein": "red"},
        vector_processing=MagicMock(dynamic_metrics=False),
    )

    # What hydra.compose hands back to the tool
    composed_cfg = MagicMock()
    composed_cfg.app.frontend = self.cfg_db
    composed_cfg.tools.multimodal_subgraph_extraction = cfg_dynamic_disabled

    with patch(f"{tool_module}.hydra.initialize"), \
         patch(f"{tool_module}.hydra.compose") as mock_compose:
        mock_compose.return_value = composed_cfg
        tool_input = {
            "prompt": self.prompt,
            "tool_call_id": "subgraph_extraction_tool",
            "state": self.state,
            "arg_data": self.arg_data,
        }
        response = self.tool.invoke(input=tool_input)

    # The tool answered under the expected call id
    last_message = response.update["messages"][-1]
    self.assertEqual(last_message.tool_call_id, "subgraph_extraction_tool")

    # Names matching neither collection fall through to None
    self.assertIsNone(collection_side_effect("final_unknown_collection"))
|