aiagents4pharma 1.42.0__py3-none-any.whl → 1.44.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +17 -2
  2. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +618 -413
  3. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +362 -25
  4. aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +146 -109
  5. aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +240 -83
  6. aiagents4pharma/talk2scholars/agents/paper_download_agent.py +7 -4
  7. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +49 -95
  8. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +15 -1
  9. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +16 -2
  10. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +40 -5
  11. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +15 -5
  12. aiagents4pharma/talk2scholars/configs/config.yaml +1 -3
  13. aiagents4pharma/talk2scholars/configs/tools/paper_download/default.yaml +124 -0
  14. aiagents4pharma/talk2scholars/tests/test_arxiv_downloader.py +478 -0
  15. aiagents4pharma/talk2scholars/tests/test_base_paper_downloader.py +620 -0
  16. aiagents4pharma/talk2scholars/tests/test_biorxiv_downloader.py +697 -0
  17. aiagents4pharma/talk2scholars/tests/test_medrxiv_downloader.py +534 -0
  18. aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +22 -12
  19. aiagents4pharma/talk2scholars/tests/test_paper_downloader.py +545 -0
  20. aiagents4pharma/talk2scholars/tests/test_pubmed_downloader.py +1067 -0
  21. aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +2 -4
  22. aiagents4pharma/talk2scholars/tools/paper_download/paper_downloader.py +457 -0
  23. aiagents4pharma/talk2scholars/tools/paper_download/utils/__init__.py +20 -0
  24. aiagents4pharma/talk2scholars/tools/paper_download/utils/arxiv_downloader.py +209 -0
  25. aiagents4pharma/talk2scholars/tools/paper_download/utils/base_paper_downloader.py +343 -0
  26. aiagents4pharma/talk2scholars/tools/paper_download/utils/biorxiv_downloader.py +321 -0
  27. aiagents4pharma/talk2scholars/tools/paper_download/utils/medrxiv_downloader.py +198 -0
  28. aiagents4pharma/talk2scholars/tools/paper_download/utils/pubmed_downloader.py +337 -0
  29. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +97 -45
  30. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +47 -29
  31. {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/METADATA +3 -1
  32. {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/RECORD +36 -33
  33. aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/default.yaml +0 -4
  34. aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/__init__.py +0 -3
  35. aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/default.yaml +0 -2
  36. aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/__init__.py +0 -3
  37. aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/default.yaml +0 -2
  38. aiagents4pharma/talk2scholars/tests/test_paper_download_biorxiv.py +0 -151
  39. aiagents4pharma/talk2scholars/tests/test_paper_download_medrxiv.py +0 -151
  40. aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +0 -249
  41. aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +0 -177
  42. aiagents4pharma/talk2scholars/tools/paper_download/download_biorxiv_input.py +0 -114
  43. aiagents4pharma/talk2scholars/tools/paper_download/download_medrxiv_input.py +0 -114
  44. /aiagents4pharma/talk2scholars/configs/tools/{download_arxiv_paper → paper_download}/__init__.py +0 -0
  45. {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/WHEEL +0 -0
  46. {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/licenses/LICENSE +0 -0
  47. {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/top_level.txt +0 -0
@@ -1,413 +1,618 @@
1
- """
2
- Test cases for tools/milvus_multimodal_subgraph_extraction.py
3
- """
4
-
5
- import importlib
6
- import unittest
7
- from unittest.mock import patch, MagicMock
8
- import numpy as np
9
- import pandas as pd
10
- from ..tools.milvus_multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
11
-
12
- class TestMultimodalSubgraphExtractionTool(unittest.TestCase):
13
- """
14
- Test cases for MultimodalSubgraphExtractionTool (Milvus)
15
- """
16
- def setUp(self):
17
- self.tool = MultimodalSubgraphExtractionTool()
18
- self.state = {
19
- "uploaded_files": [],
20
- "embedding_model": MagicMock(),
21
- "topk_nodes": 5,
22
- "topk_edges": 5,
23
- "dic_source_graph": [{"name": "TestGraph"}],
24
- }
25
- self.prompt = "Find subgraph for test"
26
- self.arg_data = {"extraction_name": "subkg_12345"}
27
- self.cfg_db = MagicMock()
28
- self.cfg_db.milvus_db.database_name = "testdb"
29
- self.cfg_db.milvus_db.alias = "default"
30
- self.cfg = MagicMock()
31
- self.cfg.cost_e = 1.0
32
- self.cfg.c_const = 1.0
33
- self.cfg.root = 0
34
- self.cfg.num_clusters = 1
35
- self.cfg.pruning = True
36
- self.cfg.verbosity_level = 0
37
- self.cfg.search_metric_type = "L2"
38
- self.cfg.node_colors_dict = {"gene/protein": "red"}
39
-
40
- @patch("aiagents4pharma.talk2knowledgegraphs.tools."
41
- "milvus_multimodal_subgraph_extraction.Collection")
42
- @patch("aiagents4pharma.talk2knowledgegraphs.tools."
43
- "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
44
- @patch("pymilvus.connections")
45
- def test_extract_multimodal_subgraph_wo_doc(self,
46
- mock_connections,
47
- mock_pcst,
48
- mock_collection):
49
- """
50
- Test the multimodal subgraph extraction tool for only text as modality.
51
- """
52
-
53
- # Mock Milvus connection utilities
54
- mock_connections.has_connection.return_value = True
55
-
56
- # No uploaded_files (no doc)
57
- self.state["uploaded_files"] = []
58
- self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
59
- self.state["selections"] = {}
60
-
61
- # Mock Collection for nodes and edges
62
- colls = {}
63
- colls["nodes"] = MagicMock()
64
- colls["nodes"] = MagicMock()
65
- colls["nodes"].query.return_value = [
66
- {"node_index": 0,
67
- "node_id": "id1",
68
- "node_name": "JAK1",
69
- "node_type": "gene/protein",
70
- "feat": "featA",
71
- "feat_emb": [0.1, 0.2, 0.3],
72
- "desc": "descA",
73
- "desc_emb": [0.1, 0.2, 0.3]},
74
- {"node_index": 1,
75
- "node_id": "id2",
76
- "node_name": "JAK2",
77
- "node_type": "gene/protein",
78
- "feat": "featB",
79
- "feat_emb": [0.4, 0.5, 0.6],
80
- "desc": "descB",
81
- "desc_emb": [0.4, 0.5, 0.6]}
82
- ]
83
- colls["nodes"].load.return_value = None
84
-
85
- colls["edges"] = MagicMock()
86
- colls["edges"].query.return_value = [
87
- {"triplet_index": 0,
88
- "head_id": "id1",
89
- "head_index": 0,
90
- "tail_id": "id2",
91
- "tail_index": 1,
92
- "edge_type": "gene/protein,ppi,gene/protein",
93
- "display_relation": "ppi",
94
- "feat": "featC",
95
- "feat_emb": [0.7, 0.8, 0.9]}
96
- ]
97
- colls["edges"].load.return_value = None
98
-
99
- def collection_side_effect(name):
100
- """
101
- Mock side effect for Collection to return nodes or edges based on name.
102
- """
103
- if "nodes" in name:
104
- return colls["nodes"]
105
- if "edges" in name:
106
- return colls["edges"]
107
- return None
108
- mock_collection.side_effect = collection_side_effect
109
-
110
- # Mock MultimodalPCSTPruning
111
- mock_pcst_instance = MagicMock()
112
- mock_pcst_instance.extract_subgraph.return_value = {
113
- "nodes": pd.Series([1, 2]),
114
- "edges": pd.Series([0])
115
- }
116
- mock_pcst.return_value = mock_pcst_instance
117
-
118
- # Patch hydra.compose to return config objects
119
- with patch("aiagents4pharma.talk2knowledgegraphs.tools."
120
- "milvus_multimodal_subgraph_extraction.hydra.initialize"), \
121
- patch("aiagents4pharma.talk2knowledgegraphs.tools."
122
- "milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
123
- mock_compose.return_value = MagicMock()
124
- mock_compose.return_value.app.frontend = self.cfg_db
125
- mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
126
-
127
- response = self.tool.invoke(
128
- input={"prompt": self.prompt,
129
- "tool_call_id": "subgraph_extraction_tool",
130
- "state": self.state,
131
- "arg_data": self.arg_data}
132
- )
133
-
134
- # Check tool message
135
- self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
136
-
137
- # Check extracted subgraph dictionary
138
- dic_extracted_graph = response.update["dic_extracted_graph"][0]
139
- self.assertIsInstance(dic_extracted_graph, dict)
140
- self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
141
- self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
142
- self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
143
- self.assertEqual(dic_extracted_graph["topk_edges"], 5)
144
- self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
145
- self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
146
- self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
147
- self.assertIsInstance(dic_extracted_graph["graph_text"], str)
148
- # Check if the nodes are in the graph_text
149
- self.assertTrue(all(
150
- n[0] in dic_extracted_graph["graph_text"].replace('"', '')
151
- for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
152
- for n in subgraph_nodes
153
- ))
154
- # Check if the edges are in the graph_text
155
- self.assertTrue(all(
156
- ",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
157
- in dic_extracted_graph["graph_text"].replace('"', '').\
158
- replace("[", "").replace("]", "").replace("'", "")
159
- for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
160
- for e in subgraph_edges
161
- ))
162
-
163
- # Another test for unknown collection
164
- result = collection_side_effect("unknown")
165
- self.assertIsNone(result)
166
-
167
- @patch("aiagents4pharma.talk2knowledgegraphs.tools."
168
- "milvus_multimodal_subgraph_extraction.Collection")
169
- @patch("aiagents4pharma.talk2knowledgegraphs.tools."
170
- "milvus_multimodal_subgraph_extraction.pd.read_excel")
171
- @patch("aiagents4pharma.talk2knowledgegraphs.tools."
172
- "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
173
- @patch("pymilvus.connections")
174
- def test_extract_multimodal_subgraph_w_doc(self,
175
- mock_connections,
176
- mock_pcst,
177
- mock_read_excel,
178
- mock_collection):
179
- """
180
- Test the multimodal subgraph extraction tool for text as modality, plus genes.
181
- """
182
- # Mock Milvus connection utilities
183
- mock_connections.has_connection.return_value = True
184
-
185
- # With uploaded_files (with doc)
186
- self.state["uploaded_files"] = [
187
- {"file_type": "multimodal", "file_path": "dummy.xlsx"}
188
- ]
189
- self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
190
- self.state["selections"] = {"gene/protein": ["JAK1", "JAK2"]}
191
-
192
- # Mock pd.read_excel to return a dict of DataFrames
193
- df = pd.DataFrame({
194
- "name": ["JAK1", "JAK2"],
195
- "node_type": ["gene/protein", "gene/protein"]
196
- })
197
- mock_read_excel.return_value = {"gene/protein": df}
198
-
199
- # Mock Collection for nodes and edges
200
- colls = {}
201
- colls["nodes"] = MagicMock()
202
- colls["nodes"] = MagicMock()
203
- colls["nodes"].query.return_value = [
204
- {"node_index": 0,
205
- "node_id": "id1",
206
- "node_name": "JAK1",
207
- "node_type": "gene/protein",
208
- "feat": "featA",
209
- "feat_emb": [0.1, 0.2, 0.3],
210
- "desc": "descA",
211
- "desc_emb": [0.1, 0.2, 0.3]},
212
- {"node_index": 1,
213
- "node_id": "id2",
214
- "node_name": "JAK2",
215
- "node_type": "gene/protein",
216
- "feat": "featB",
217
- "feat_emb": [0.4, 0.5, 0.6],
218
- "desc": "descB",
219
- "desc_emb": [0.4, 0.5, 0.6]}
220
- ]
221
- colls["nodes"].load.return_value = None
222
-
223
- colls["edges"] = MagicMock()
224
- colls["edges"].query.return_value = [
225
- {"triplet_index": 0,
226
- "head_id": "id1",
227
- "head_index": 0,
228
- "tail_id": "id2",
229
- "tail_index": 1,
230
- "edge_type": "gene/protein,ppi,gene/protein",
231
- "display_relation": "ppi",
232
- "feat": "featC",
233
- "feat_emb": [0.7, 0.8, 0.9]}
234
- ]
235
- colls["edges"].load.return_value = None
236
-
237
- def collection_side_effect(name):
238
- """
239
- Mock side effect for Collection to return nodes or edges based on name.
240
- """
241
- if "nodes" in name:
242
- return colls["nodes"]
243
- if "edges" in name:
244
- return colls["edges"]
245
- return None
246
- mock_collection.side_effect = collection_side_effect
247
-
248
- # Mock MultimodalPCSTPruning
249
- mock_pcst_instance = MagicMock()
250
- mock_pcst_instance.extract_subgraph.return_value = {
251
- "nodes": pd.Series([1, 2]),
252
- "edges": pd.Series([0])
253
- }
254
- mock_pcst.return_value = mock_pcst_instance
255
-
256
- # Patch hydra.compose to return config objects
257
- with patch("aiagents4pharma.talk2knowledgegraphs.tools."
258
- "milvus_multimodal_subgraph_extraction.hydra.initialize"), \
259
- patch("aiagents4pharma.talk2knowledgegraphs.tools."
260
- "milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
261
- mock_compose.return_value = MagicMock()
262
- mock_compose.return_value.app.frontend = self.cfg_db
263
- mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
264
-
265
- response = self.tool.invoke(
266
- input={"prompt": self.prompt,
267
- "tool_call_id": "subgraph_extraction_tool",
268
- "state": self.state,
269
- "arg_data": self.arg_data}
270
- )
271
-
272
- # Check tool message
273
- self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
274
-
275
- # Check extracted subgraph dictionary
276
- dic_extracted_graph = response.update["dic_extracted_graph"][0]
277
- self.assertIsInstance(dic_extracted_graph, dict)
278
- self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
279
- self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
280
- self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
281
- self.assertEqual(dic_extracted_graph["topk_edges"], 5)
282
- self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
283
- self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
284
- self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
285
- self.assertIsInstance(dic_extracted_graph["graph_text"], str)
286
- # Check if the nodes are in the graph_text
287
- self.assertTrue(all(
288
- n[0] in dic_extracted_graph["graph_text"].replace('"', '')
289
- for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
290
- for n in subgraph_nodes
291
- ))
292
- # Check if the edges are in the graph_text
293
- self.assertTrue(all(
294
- ",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
295
- in dic_extracted_graph["graph_text"].replace('"', '').\
296
- replace("[", "").replace("]", "").replace("'", "")
297
- for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
298
- for e in subgraph_edges
299
- ))
300
-
301
- # Another test for unknown collection
302
- result = collection_side_effect("unknown")
303
- self.assertIsNone(result)
304
-
305
- def test_extract_multimodal_subgraph_wo_doc_gpu(self):
306
- """
307
- Test the multimodal subgraph extraction tool for only text as modality,
308
- simulating GPU (cudf/cupy) environment.
309
- """
310
- module_name = "aiagents4pharma.talk2knowledgegraphs.tools."+\
311
- "milvus_multimodal_subgraph_extraction"
312
- with patch.dict("sys.modules", {"cupy": np, "cudf": pd}):
313
- mod = importlib.reload(importlib.import_module(module_name))
314
- # Patch Collection and MultimodalPCSTPruning after reload
315
- with patch(f"{module_name}.Collection") as mock_collection, \
316
- patch(f"{module_name}.MultimodalPCSTPruning") as mock_pcst, \
317
- patch("pymilvus.connections") as mock_connections:
318
- # Setup mocks as in the original test
319
- mock_connections.has_connection.return_value = True
320
- colls = {}
321
- colls["nodes"] = MagicMock()
322
- colls["nodes"].query.return_value = [
323
- {"node_index": 0,
324
- "node_id": "id1",
325
- "node_name": "JAK1",
326
- "node_type": "gene/protein",
327
- "feat": "featA",
328
- "feat_emb": [0.1, 0.2, 0.3],
329
- "desc": "descA",
330
- "desc_emb": [0.1, 0.2, 0.3]},
331
- {"node_index": 1,
332
- "node_id": "id2",
333
- "node_name": "JAK2",
334
- "node_type": "gene/protein",
335
- "feat": "featB",
336
- "feat_emb": [0.4, 0.5, 0.6],
337
- "desc": "descB",
338
- "desc_emb": [0.4, 0.5, 0.6]}
339
- ]
340
- colls["nodes"].load.return_value = None
341
- colls["edges"] = MagicMock()
342
- colls["edges"].query.return_value = [
343
- {"triplet_index": 0,
344
- "head_id": "id1",
345
- "head_index": 0,
346
- "tail_id": "id2",
347
- "tail_index": 1,
348
- "edge_type": "gene/protein,ppi,gene/protein",
349
- "display_relation": "ppi",
350
- "feat": "featC",
351
- "feat_emb": [0.7, 0.8, 0.9]}
352
- ]
353
- colls["edges"].load.return_value = None
354
- def collection_side_effect(name):
355
- if "nodes" in name:
356
- return colls["nodes"]
357
- if "edges" in name:
358
- return colls["edges"]
359
- return None
360
- mock_collection.side_effect = collection_side_effect
361
- mock_pcst_instance = MagicMock()
362
- mock_pcst_instance.extract_subgraph.return_value = {
363
- "nodes": pd.Series([1, 2]),
364
- "edges": pd.Series([0])
365
- }
366
- mock_pcst.return_value = mock_pcst_instance
367
- # Setup config mocks
368
- tool_cls = getattr(mod, "MultimodalSubgraphExtractionTool")
369
- tool = tool_cls()
370
-
371
- # Patch hydra.compose
372
- with patch(f"{module_name}.hydra.initialize"), \
373
- patch(f"{module_name}.hydra.compose") as mock_compose:
374
- mock_compose.return_value = MagicMock()
375
- mock_compose.return_value.app.frontend = self.cfg_db
376
- mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
377
- self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
378
- self.state["selections"] = {}
379
- response = tool.invoke(
380
- input={"prompt": self.prompt,
381
- "tool_call_id": "subgraph_extraction_tool",
382
- "state": self.state,
383
- "arg_data": self.arg_data}
384
- )
385
- # Check tool message
386
- self.assertEqual(response.update["messages"][-1].tool_call_id,
387
- "subgraph_extraction_tool")
388
- dic_extracted_graph = response.update["dic_extracted_graph"][0]
389
- self.assertIsInstance(dic_extracted_graph, dict)
390
- self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
391
- self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
392
- self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
393
- self.assertEqual(dic_extracted_graph["topk_edges"], 5)
394
- self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
395
- self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
396
- self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
397
- self.assertIsInstance(dic_extracted_graph["graph_text"], str)
398
- self.assertTrue(all(
399
- n[0] in dic_extracted_graph["graph_text"].replace('"', '')
400
- for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
401
- for n in subgraph_nodes
402
- ))
403
- self.assertTrue(all(
404
- ",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
405
- in dic_extracted_graph["graph_text"].replace('"', '').\
406
- replace("[", "").replace("]", "").replace("'", "")
407
- for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
408
- for e in subgraph_edges
409
- ))
410
-
411
- # Another test for unknown collection
412
- result = collection_side_effect("unknown")
413
- self.assertIsNone(result)
1
+ """
2
+ Test cases for tools/milvus_multimodal_subgraph_extraction.py
3
+ """
4
+
5
+ import importlib
6
+ import unittest
7
+ from unittest.mock import patch, MagicMock
8
+ import numpy as np
9
+ import pandas as pd
10
+ from ..tools.milvus_multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
11
+
12
+ class TestMultimodalSubgraphExtractionTool(unittest.TestCase):
13
+ """
14
+ Test cases for MultimodalSubgraphExtractionTool (Milvus)
15
+ """
16
+ def setUp(self):
17
+ self.tool = MultimodalSubgraphExtractionTool()
18
+ self.state = {
19
+ "uploaded_files": [],
20
+ "embedding_model": MagicMock(),
21
+ "topk_nodes": 5,
22
+ "topk_edges": 5,
23
+ "dic_source_graph": [{"name": "TestGraph"}],
24
+ }
25
+ self.prompt = "Find subgraph for test"
26
+ self.arg_data = {"extraction_name": "subkg_12345"}
27
+ self.cfg_db = MagicMock()
28
+ self.cfg_db.milvus_db.database_name = "testdb"
29
+ self.cfg_db.milvus_db.alias = "default"
30
+ self.cfg = MagicMock()
31
+ self.cfg.cost_e = 1.0
32
+ self.cfg.c_const = 1.0
33
+ self.cfg.root = 0
34
+ self.cfg.num_clusters = 1
35
+ self.cfg.pruning = True
36
+ self.cfg.verbosity_level = 0
37
+ self.cfg.search_metric_type = "L2"
38
+ self.cfg.node_colors_dict = {"gene/protein": "red"}
39
+
40
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
41
+ "milvus_multimodal_subgraph_extraction.Collection")
42
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
43
+ "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
44
+ @patch("pymilvus.connections")
45
+ def test_extract_multimodal_subgraph_wo_doc(self,
46
+ mock_connections,
47
+ mock_pcst,
48
+ mock_collection):
49
+ """
50
+ Test the multimodal subgraph extraction tool for only text as modality.
51
+ """
52
+
53
+ # Mock Milvus connection utilities
54
+ mock_connections.has_connection.return_value = True
55
+
56
+ # No uploaded_files (no doc)
57
+ self.state["uploaded_files"] = []
58
+ self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
59
+ self.state["selections"] = {}
60
+
61
+ # Mock Collection for nodes and edges
62
+ colls = {}
63
+ colls["nodes"] = MagicMock()
64
+ colls["nodes"] = MagicMock()
65
+ colls["nodes"].query.return_value = [
66
+ {"node_index": 0,
67
+ "node_id": "id1",
68
+ "node_name": "JAK1",
69
+ "node_type": "gene/protein",
70
+ "feat": "featA",
71
+ "feat_emb": [0.1, 0.2, 0.3],
72
+ "desc": "descA",
73
+ "desc_emb": [0.1, 0.2, 0.3]},
74
+ {"node_index": 1,
75
+ "node_id": "id2",
76
+ "node_name": "JAK2",
77
+ "node_type": "gene/protein",
78
+ "feat": "featB",
79
+ "feat_emb": [0.4, 0.5, 0.6],
80
+ "desc": "descB",
81
+ "desc_emb": [0.4, 0.5, 0.6]}
82
+ ]
83
+ colls["nodes"].load.return_value = None
84
+
85
+ colls["edges"] = MagicMock()
86
+ colls["edges"].query.return_value = [
87
+ {"triplet_index": 0,
88
+ "head_id": "id1",
89
+ "head_index": 0,
90
+ "tail_id": "id2",
91
+ "tail_index": 1,
92
+ "edge_type": "gene/protein,ppi,gene/protein",
93
+ "display_relation": "ppi",
94
+ "feat": "featC",
95
+ "feat_emb": [0.7, 0.8, 0.9]}
96
+ ]
97
+ colls["edges"].load.return_value = None
98
+
99
+ def collection_side_effect(name):
100
+ """
101
+ Mock side effect for Collection to return nodes or edges based on name.
102
+ """
103
+ if "nodes" in name:
104
+ return colls["nodes"]
105
+ if "edges" in name:
106
+ return colls["edges"]
107
+ return None
108
+ mock_collection.side_effect = collection_side_effect
109
+
110
+ # Mock MultimodalPCSTPruning
111
+ mock_pcst_instance = MagicMock()
112
+ mock_pcst_instance.extract_subgraph.return_value = {
113
+ "nodes": pd.Series([1, 2]),
114
+ "edges": pd.Series([0])
115
+ }
116
+ mock_pcst.return_value = mock_pcst_instance
117
+
118
+ # Patch hydra.compose to return config objects
119
+ with patch("aiagents4pharma.talk2knowledgegraphs.tools."
120
+ "milvus_multimodal_subgraph_extraction.hydra.initialize"), \
121
+ patch("aiagents4pharma.talk2knowledgegraphs.tools."
122
+ "milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
123
+ mock_compose.return_value = MagicMock()
124
+ mock_compose.return_value.app.frontend = self.cfg_db
125
+ mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
126
+
127
+ response = self.tool.invoke(
128
+ input={"prompt": self.prompt,
129
+ "tool_call_id": "subgraph_extraction_tool",
130
+ "state": self.state,
131
+ "arg_data": self.arg_data}
132
+ )
133
+
134
+ # Check tool message
135
+ self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
136
+
137
+ # Check extracted subgraph dictionary
138
+ dic_extracted_graph = response.update["dic_extracted_graph"][0]
139
+ self.assertIsInstance(dic_extracted_graph, dict)
140
+ self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
141
+ self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
142
+ self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
143
+ self.assertEqual(dic_extracted_graph["topk_edges"], 5)
144
+ self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
145
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
146
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
147
+ self.assertIsInstance(dic_extracted_graph["graph_text"], str)
148
+ # Check if the nodes are in the graph_text
149
+ self.assertTrue(all(
150
+ n[0] in dic_extracted_graph["graph_text"].replace('"', '')
151
+ for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
152
+ for n in subgraph_nodes
153
+ ))
154
+ # Check if the edges are in the graph_text
155
+ self.assertTrue(all(
156
+ ",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
157
+ in dic_extracted_graph["graph_text"].replace('"', '').\
158
+ replace("[", "").replace("]", "").replace("'", "")
159
+ for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
160
+ for e in subgraph_edges
161
+ ))
162
+
163
+ # Another test for unknown collection
164
+ result = collection_side_effect("unknown")
165
+ self.assertIsNone(result)
166
+
167
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
168
+ "milvus_multimodal_subgraph_extraction.Collection")
169
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
170
+ "milvus_multimodal_subgraph_extraction.pd.read_excel")
171
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
172
+ "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
173
+ @patch("pymilvus.connections")
174
+ def test_extract_multimodal_subgraph_w_doc(self,
175
+ mock_connections,
176
+ mock_pcst,
177
+ mock_read_excel,
178
+ mock_collection):
179
+ """
180
+ Test the multimodal subgraph extraction tool for text as modality, plus genes.
181
+ """
182
+ # Mock Milvus connection utilities
183
+ mock_connections.has_connection.return_value = True
184
+
185
+ # With uploaded_files (with doc)
186
+ self.state["uploaded_files"] = [
187
+ {"file_type": "multimodal", "file_path": "dummy.xlsx"}
188
+ ]
189
+ self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
190
+ self.state["selections"] = {"gene/protein": ["JAK1", "JAK2"]}
191
+
192
+ # Mock pd.read_excel to return a dict of DataFrames
193
+ df = pd.DataFrame({
194
+ "name": ["JAK1", "JAK2"],
195
+ "node_type": ["gene/protein", "gene/protein"]
196
+ })
197
+ mock_read_excel.return_value = {"gene/protein": df}
198
+
199
+ # Mock Collection for nodes and edges
200
+ colls = {}
201
+ colls["nodes"] = MagicMock()
202
+ colls["nodes"] = MagicMock()
203
+ colls["nodes"].query.return_value = [
204
+ {"node_index": 0,
205
+ "node_id": "id1",
206
+ "node_name": "JAK1",
207
+ "node_type": "gene/protein",
208
+ "feat": "featA",
209
+ "feat_emb": [0.1, 0.2, 0.3],
210
+ "desc": "descA",
211
+ "desc_emb": [0.1, 0.2, 0.3]},
212
+ {"node_index": 1,
213
+ "node_id": "id2",
214
+ "node_name": "JAK2",
215
+ "node_type": "gene/protein",
216
+ "feat": "featB",
217
+ "feat_emb": [0.4, 0.5, 0.6],
218
+ "desc": "descB",
219
+ "desc_emb": [0.4, 0.5, 0.6]}
220
+ ]
221
+ colls["nodes"].load.return_value = None
222
+
223
+ colls["edges"] = MagicMock()
224
+ colls["edges"].query.return_value = [
225
+ {"triplet_index": 0,
226
+ "head_id": "id1",
227
+ "head_index": 0,
228
+ "tail_id": "id2",
229
+ "tail_index": 1,
230
+ "edge_type": "gene/protein,ppi,gene/protein",
231
+ "display_relation": "ppi",
232
+ "feat": "featC",
233
+ "feat_emb": [0.7, 0.8, 0.9]}
234
+ ]
235
+ colls["edges"].load.return_value = None
236
+
237
+ def collection_side_effect(name):
238
+ """
239
+ Mock side effect for Collection to return nodes or edges based on name.
240
+ """
241
+ if "nodes" in name:
242
+ return colls["nodes"]
243
+ if "edges" in name:
244
+ return colls["edges"]
245
+ return None
246
+ mock_collection.side_effect = collection_side_effect
247
+
248
+ # Mock MultimodalPCSTPruning
249
+ mock_pcst_instance = MagicMock()
250
+ mock_pcst_instance.extract_subgraph.return_value = {
251
+ "nodes": pd.Series([1, 2]),
252
+ "edges": pd.Series([0])
253
+ }
254
+ mock_pcst.return_value = mock_pcst_instance
255
+
256
+ # Patch hydra.compose to return config objects
257
+ with patch("aiagents4pharma.talk2knowledgegraphs.tools."
258
+ "milvus_multimodal_subgraph_extraction.hydra.initialize"), \
259
+ patch("aiagents4pharma.talk2knowledgegraphs.tools."
260
+ "milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
261
+ mock_compose.return_value = MagicMock()
262
+ mock_compose.return_value.app.frontend = self.cfg_db
263
+ mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
264
+
265
+ response = self.tool.invoke(
266
+ input={"prompt": self.prompt,
267
+ "tool_call_id": "subgraph_extraction_tool",
268
+ "state": self.state,
269
+ "arg_data": self.arg_data}
270
+ )
271
+
272
+ # Check tool message
273
+ self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
274
+
275
+ # Check extracted subgraph dictionary
276
+ dic_extracted_graph = response.update["dic_extracted_graph"][0]
277
+ self.assertIsInstance(dic_extracted_graph, dict)
278
+ self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
279
+ self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
280
+ self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
281
+ self.assertEqual(dic_extracted_graph["topk_edges"], 5)
282
+ self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
283
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
284
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
285
+ self.assertIsInstance(dic_extracted_graph["graph_text"], str)
286
+ # Check if the nodes are in the graph_text
287
+ self.assertTrue(all(
288
+ n[0] in dic_extracted_graph["graph_text"].replace('"', '')
289
+ for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
290
+ for n in subgraph_nodes
291
+ ))
292
+ # Check if the edges are in the graph_text
293
+ self.assertTrue(all(
294
+ ",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
295
+ in dic_extracted_graph["graph_text"].replace('"', '').\
296
+ replace("[", "").replace("]", "").replace("'", "")
297
+ for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
298
+ for e in subgraph_edges
299
+ ))
300
+
301
+ # Another test for unknown collection
302
+ result = collection_side_effect("unknown")
303
+ self.assertIsNone(result)
304
+
305
+ def test_extract_multimodal_subgraph_wo_doc_gpu(self):
306
+ """
307
+ Test the multimodal subgraph extraction tool for only text as modality,
308
+ simulating GPU (cudf/cupy) environment.
309
+ """
310
+ module_name = "aiagents4pharma.talk2knowledgegraphs.tools."+\
311
+ "milvus_multimodal_subgraph_extraction"
312
+ with patch.dict("sys.modules", {"cupy": np, "cudf": pd}):
313
+ mod = importlib.reload(importlib.import_module(module_name))
314
+ # Patch Collection and MultimodalPCSTPruning after reload
315
+ with patch(f"{module_name}.Collection") as mock_collection, \
316
+ patch(f"{module_name}.MultimodalPCSTPruning") as mock_pcst, \
317
+ patch("pymilvus.connections") as mock_connections:
318
+ # Setup mocks as in the original test
319
+ mock_connections.has_connection.return_value = True
320
+ colls = {}
321
+ colls["nodes"] = MagicMock()
322
+ colls["nodes"].query.return_value = [
323
+ {"node_index": 0,
324
+ "node_id": "id1",
325
+ "node_name": "JAK1",
326
+ "node_type": "gene/protein",
327
+ "feat": "featA",
328
+ "feat_emb": [0.1, 0.2, 0.3],
329
+ "desc": "descA",
330
+ "desc_emb": [0.1, 0.2, 0.3]},
331
+ {"node_index": 1,
332
+ "node_id": "id2",
333
+ "node_name": "JAK2",
334
+ "node_type": "gene/protein",
335
+ "feat": "featB",
336
+ "feat_emb": [0.4, 0.5, 0.6],
337
+ "desc": "descB",
338
+ "desc_emb": [0.4, 0.5, 0.6]}
339
+ ]
340
+ colls["nodes"].load.return_value = None
341
+ colls["edges"] = MagicMock()
342
+ colls["edges"].query.return_value = [
343
+ {"triplet_index": 0,
344
+ "head_id": "id1",
345
+ "head_index": 0,
346
+ "tail_id": "id2",
347
+ "tail_index": 1,
348
+ "edge_type": "gene/protein,ppi,gene/protein",
349
+ "display_relation": "ppi",
350
+ "feat": "featC",
351
+ "feat_emb": [0.7, 0.8, 0.9]}
352
+ ]
353
+ colls["edges"].load.return_value = None
354
+ def collection_side_effect(name):
355
+ if "nodes" in name:
356
+ return colls["nodes"]
357
+ if "edges" in name:
358
+ return colls["edges"]
359
+ return None
360
+ mock_collection.side_effect = collection_side_effect
361
+ mock_pcst_instance = MagicMock()
362
+ mock_pcst_instance.extract_subgraph.return_value = {
363
+ "nodes": pd.Series([1, 2]),
364
+ "edges": pd.Series([0])
365
+ }
366
+ mock_pcst.return_value = mock_pcst_instance
367
+ # Setup config mocks
368
+ tool_cls = getattr(mod, "MultimodalSubgraphExtractionTool")
369
+ tool = tool_cls()
370
+
371
+ # Patch hydra.compose
372
+ with patch(f"{module_name}.hydra.initialize"), \
373
+ patch(f"{module_name}.hydra.compose") as mock_compose:
374
+ mock_compose.return_value = MagicMock()
375
+ mock_compose.return_value.app.frontend = self.cfg_db
376
+ mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
377
+ self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
378
+ self.state["selections"] = {}
379
+ response = tool.invoke(
380
+ input={"prompt": self.prompt,
381
+ "tool_call_id": "subgraph_extraction_tool",
382
+ "state": self.state,
383
+ "arg_data": self.arg_data}
384
+ )
385
+ # Check tool message
386
+ self.assertEqual(response.update["messages"][-1].tool_call_id,
387
+ "subgraph_extraction_tool")
388
+ dic_extracted_graph = response.update["dic_extracted_graph"][0]
389
+ self.assertIsInstance(dic_extracted_graph, dict)
390
+ self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
391
+ self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
392
+ self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
393
+ self.assertEqual(dic_extracted_graph["topk_edges"], 5)
394
+ self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
395
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
396
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
397
+ self.assertIsInstance(dic_extracted_graph["graph_text"], str)
398
+ self.assertTrue(all(
399
+ n[0] in dic_extracted_graph["graph_text"].replace('"', '')
400
+ for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
401
+ for n in subgraph_nodes
402
+ ))
403
+ self.assertTrue(all(
404
+ ",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
405
+ in dic_extracted_graph["graph_text"].replace('"', '').\
406
+ replace("[", "").replace("]", "").replace("'", "")
407
+ for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
408
+ for e in subgraph_edges
409
+ ))
410
+
411
+ # Another test for unknown collection
412
+ result = collection_side_effect("unknown")
413
+ self.assertIsNone(result)
414
+
415
+ def test_normalize_vector_gpu_mode(self):
416
+ """Test normalize_vector method in GPU mode."""
417
+ # Mock the loader to simulate GPU mode
418
+ self.tool.loader.normalize_vectors = True
419
+ self.tool.loader.py = MagicMock()
420
+ # Mock the GPU array operations
421
+ mock_array = MagicMock()
422
+ mock_norm = MagicMock()
423
+ mock_norm.return_value = 2.0
424
+ mock_array.__truediv__ = MagicMock(return_value=mock_array)
425
+ mock_array.tolist.return_value = [0.5, 1.0, 1.5]
426
+ self.tool.loader.py.asarray.return_value = mock_array
427
+ self.tool.loader.py.linalg.norm.return_value = mock_norm
428
+ result = self.tool.normalize_vector([1.0, 2.0, 3.0])
429
+ # Verify the result
430
+ self.assertEqual(result, [0.5, 1.0, 1.5])
431
+ self.tool.loader.py.asarray.assert_called_once_with([1.0, 2.0, 3.0])
432
+ self.tool.loader.py.linalg.norm.assert_called_once_with(mock_array)
433
+
434
+ def test_normalize_vector_cpu_mode(self):
435
+ """Test normalize_vector method in CPU mode."""
436
+ # Mock the loader to simulate CPU mode
437
+ self.tool.loader.normalize_vectors = False
438
+ result = self.tool.normalize_vector([1.0, 2.0, 3.0])
439
+ # In CPU mode, should return the input as-is
440
+ self.assertEqual(result, [1.0, 2.0, 3.0])
441
+
442
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
443
+ "milvus_multimodal_subgraph_extraction.Collection")
444
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
445
+ "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
446
+ @patch("pymilvus.connections")
447
+ def test_extract_multimodal_subgraph_no_vector_processing(self,
448
+ mock_connections,
449
+ mock_pcst,
450
+ mock_collection):
451
+ """Test when vector_processing config is not present."""
452
+ # Mock Milvus connection utilities
453
+ mock_connections.has_connection.return_value = True
454
+
455
+ self.state["uploaded_files"] = []
456
+ self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
457
+ self.state["selections"] = {}
458
+
459
+ # Mock Collection for nodes and edges
460
+ colls = {}
461
+ colls["nodes"] = MagicMock()
462
+ colls["nodes"].query.return_value = [
463
+ {"node_index": 0, "node_id": "id1", "node_name": "JAK1",
464
+ "node_type": "gene/protein", "feat": "featA", "feat_emb": [0.1, 0.2, 0.3],
465
+ "desc": "descA", "desc_emb": [0.1, 0.2, 0.3]}
466
+ ]
467
+ colls["nodes"].load.return_value = None
468
+
469
+ colls["edges"] = MagicMock()
470
+ colls["edges"].query.return_value = [
471
+ {"triplet_index": 0, "head_id": "id1", "tail_id": "id2",
472
+ "edge_type": "gene/protein,ppi,gene/protein"}
473
+ ]
474
+ colls["edges"].load.return_value = None
475
+
476
+ def collection_side_effect(name):
477
+ if "nodes" in name:
478
+ return colls["nodes"]
479
+ if "edges" in name:
480
+ return colls["edges"]
481
+ return None
482
+ mock_collection.side_effect = collection_side_effect
483
+
484
+ # Mock MultimodalPCSTPruning
485
+ mock_pcst_instance = MagicMock()
486
+ mock_pcst_instance.extract_subgraph.return_value = {
487
+ "nodes": pd.Series([1]),
488
+ "edges": pd.Series([0])
489
+ }
490
+ mock_pcst.return_value = mock_pcst_instance
491
+
492
+ # Create config without vector_processing attribute
493
+ cfg_no_vector_processing = MagicMock()
494
+ cfg_no_vector_processing.cost_e = 1.0
495
+ cfg_no_vector_processing.c_const = 1.0
496
+ cfg_no_vector_processing.root = 0
497
+ cfg_no_vector_processing.num_clusters = 1
498
+ cfg_no_vector_processing.pruning = True
499
+ cfg_no_vector_processing.verbosity_level = 0
500
+ cfg_no_vector_processing.search_metric_type = "L2"
501
+ cfg_no_vector_processing.node_colors_dict = {"gene/protein": "red"}
502
+ # Remove vector_processing attribute to test the missing branch
503
+ del cfg_no_vector_processing.vector_processing
504
+
505
+ # Patch hydra.compose to return config without vector_processing
506
+ with patch("aiagents4pharma.talk2knowledgegraphs.tools."
507
+ "milvus_multimodal_subgraph_extraction.hydra.initialize"), \
508
+ patch("aiagents4pharma.talk2knowledgegraphs.tools."
509
+ "milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
510
+ mock_compose.return_value = MagicMock()
511
+ mock_compose.return_value.app.frontend = self.cfg_db
512
+ mock_compose.return_value.tools.multimodal_subgraph_extraction = \
513
+ cfg_no_vector_processing
514
+
515
+ response = self.tool.invoke(
516
+ input={"prompt": self.prompt,
517
+ "tool_call_id": "subgraph_extraction_tool",
518
+ "state": self.state,
519
+ "arg_data": self.arg_data}
520
+ )
521
+
522
+ # Verify the test completed successfully
523
+ self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
524
+
525
+ # Test the collection_side_effect with unknown name for final test
526
+ result = collection_side_effect("final_unknown_collection")
527
+ self.assertIsNone(result)
528
+
529
+ # Test the collection_side_effect with unknown name
530
+ result = collection_side_effect("unknown_collection")
531
+ self.assertIsNone(result)
532
+
533
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
534
+ "milvus_multimodal_subgraph_extraction.Collection")
535
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
536
+ "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
537
+ @patch("pymilvus.connections")
538
+ def test_extract_multimodal_subgraph_dynamic_metrics_disabled(self,
539
+ mock_connections,
540
+ mock_pcst,
541
+ mock_collection):
542
+ """Test when dynamic_metrics is disabled."""
543
+ # Mock Milvus connection utilities
544
+ mock_connections.has_connection.return_value = True
545
+
546
+ self.state["uploaded_files"] = []
547
+ self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
548
+ self.state["selections"] = {}
549
+
550
+ # Mock Collection for nodes and edges
551
+ colls = {}
552
+ colls["nodes"] = MagicMock()
553
+ colls["nodes"].query.return_value = [
554
+ {"node_index": 0, "node_id": "id1", "node_name": "JAK1",
555
+ "node_type": "gene/protein", "feat": "featA", "feat_emb": [0.1, 0.2, 0.3],
556
+ "desc": "descA", "desc_emb": [0.1, 0.2, 0.3]}
557
+ ]
558
+ colls["nodes"].load.return_value = None
559
+
560
+ colls["edges"] = MagicMock()
561
+ colls["edges"].query.return_value = [
562
+ {"triplet_index": 0, "head_id": "id1", "tail_id": "id2",
563
+ "edge_type": "gene/protein,ppi,gene/protein"}
564
+ ]
565
+ colls["edges"].load.return_value = None
566
+
567
+ def collection_side_effect(name):
568
+ if "nodes" in name:
569
+ return colls["nodes"]
570
+ if "edges" in name:
571
+ return colls["edges"]
572
+ return None
573
+ mock_collection.side_effect = collection_side_effect
574
+
575
+ # Mock MultimodalPCSTPruning
576
+ mock_pcst_instance = MagicMock()
577
+ mock_pcst_instance.extract_subgraph.return_value = {
578
+ "nodes": pd.Series([1]),
579
+ "edges": pd.Series([0])
580
+ }
581
+ mock_pcst.return_value = mock_pcst_instance
582
+
583
+ # Create config with dynamic_metrics disabled
584
+ cfg_dynamic_disabled = MagicMock()
585
+ cfg_dynamic_disabled.cost_e = 1.0
586
+ cfg_dynamic_disabled.c_const = 1.0
587
+ cfg_dynamic_disabled.root = 0
588
+ cfg_dynamic_disabled.num_clusters = 1
589
+ cfg_dynamic_disabled.pruning = True
590
+ cfg_dynamic_disabled.verbosity_level = 0
591
+ cfg_dynamic_disabled.search_metric_type = "L2"
592
+ cfg_dynamic_disabled.node_colors_dict = {"gene/protein": "red"}
593
+ # Set dynamic_metrics to False
594
+ cfg_dynamic_disabled.vector_processing = MagicMock()
595
+ cfg_dynamic_disabled.vector_processing.dynamic_metrics = False
596
+
597
+ # Patch hydra.compose to return config with dynamic_metrics disabled
598
+ with patch("aiagents4pharma.talk2knowledgegraphs.tools."
599
+ "milvus_multimodal_subgraph_extraction.hydra.initialize"), \
600
+ patch("aiagents4pharma.talk2knowledgegraphs.tools."
601
+ "milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
602
+ mock_compose.return_value = MagicMock()
603
+ mock_compose.return_value.app.frontend = self.cfg_db
604
+ mock_compose.return_value.tools.multimodal_subgraph_extraction = cfg_dynamic_disabled
605
+
606
+ response = self.tool.invoke(
607
+ input={"prompt": self.prompt,
608
+ "tool_call_id": "subgraph_extraction_tool",
609
+ "state": self.state,
610
+ "arg_data": self.arg_data}
611
+ )
612
+
613
+ # Verify the test completed successfully
614
+ self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
615
+
616
+ # Test the collection_side_effect with unknown name for final test
617
+ result = collection_side_effect("final_unknown_collection")
618
+ self.assertIsNone(result)