aiagents4pharma 1.39.4__py3-none-any.whl → 1.40.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22) hide show
  1. aiagents4pharma/talk2aiagents4pharma/configs/agents/main_agent/default.yaml +26 -13
  2. aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +83 -3
  3. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +4 -1
  4. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +36 -5
  5. aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +509 -0
  6. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +85 -23
  7. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +413 -0
  8. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ols.py +10 -10
  9. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +175 -0
  10. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py +11 -0
  11. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +1 -0
  12. aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +509 -0
  13. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ols_terms.py +15 -7
  14. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py +31 -9
  15. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +1 -0
  16. aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +393 -0
  17. aiagents4pharma/talk2knowledgegraphs/utils/pubchem_utils.py +33 -2
  18. {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/METADATA +13 -14
  19. {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/RECORD +22 -17
  20. {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/WHEEL +0 -0
  21. {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/licenses/LICENSE +0 -0
  22. {aiagents4pharma-1.39.4.dist-info → aiagents4pharma-1.40.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,413 @@
1
+ """
2
+ Test cases for tools/milvus_multimodal_subgraph_extraction.py
3
+ """
4
+
5
+ import importlib
6
+ import unittest
7
+ from unittest.mock import patch, MagicMock
8
+ import numpy as np
9
+ import pandas as pd
10
+ from ..tools.milvus_multimodal_subgraph_extraction import MultimodalSubgraphExtractionTool
11
+
12
+ class TestMultimodalSubgraphExtractionTool(unittest.TestCase):
13
+ """
14
+ Test cases for MultimodalSubgraphExtractionTool (Milvus)
15
+ """
16
+ def setUp(self):
17
+ self.tool = MultimodalSubgraphExtractionTool()
18
+ self.state = {
19
+ "uploaded_files": [],
20
+ "embedding_model": MagicMock(),
21
+ "topk_nodes": 5,
22
+ "topk_edges": 5,
23
+ "dic_source_graph": [{"name": "TestGraph"}],
24
+ }
25
+ self.prompt = "Find subgraph for test"
26
+ self.arg_data = {"extraction_name": "subkg_12345"}
27
+ self.cfg_db = MagicMock()
28
+ self.cfg_db.milvus_db.database_name = "testdb"
29
+ self.cfg_db.milvus_db.alias = "default"
30
+ self.cfg = MagicMock()
31
+ self.cfg.cost_e = 1.0
32
+ self.cfg.c_const = 1.0
33
+ self.cfg.root = 0
34
+ self.cfg.num_clusters = 1
35
+ self.cfg.pruning = True
36
+ self.cfg.verbosity_level = 0
37
+ self.cfg.search_metric_type = "L2"
38
+ self.cfg.node_colors_dict = {"gene/protein": "red"}
39
+
40
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
41
+ "milvus_multimodal_subgraph_extraction.Collection")
42
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
43
+ "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
44
+ @patch("pymilvus.connections")
45
+ def test_extract_multimodal_subgraph_wo_doc(self,
46
+ mock_connections,
47
+ mock_pcst,
48
+ mock_collection):
49
+ """
50
+ Test the multimodal subgraph extraction tool for only text as modality.
51
+ """
52
+
53
+ # Mock Milvus connection utilities
54
+ mock_connections.has_connection.return_value = True
55
+
56
+ # No uploaded_files (no doc)
57
+ self.state["uploaded_files"] = []
58
+ self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
59
+ self.state["selections"] = {}
60
+
61
+ # Mock Collection for nodes and edges
62
+ colls = {}
63
+ colls["nodes"] = MagicMock()
64
+ colls["nodes"] = MagicMock()
65
+ colls["nodes"].query.return_value = [
66
+ {"node_index": 0,
67
+ "node_id": "id1",
68
+ "node_name": "JAK1",
69
+ "node_type": "gene/protein",
70
+ "feat": "featA",
71
+ "feat_emb": [0.1, 0.2, 0.3],
72
+ "desc": "descA",
73
+ "desc_emb": [0.1, 0.2, 0.3]},
74
+ {"node_index": 1,
75
+ "node_id": "id2",
76
+ "node_name": "JAK2",
77
+ "node_type": "gene/protein",
78
+ "feat": "featB",
79
+ "feat_emb": [0.4, 0.5, 0.6],
80
+ "desc": "descB",
81
+ "desc_emb": [0.4, 0.5, 0.6]}
82
+ ]
83
+ colls["nodes"].load.return_value = None
84
+
85
+ colls["edges"] = MagicMock()
86
+ colls["edges"].query.return_value = [
87
+ {"triplet_index": 0,
88
+ "head_id": "id1",
89
+ "head_index": 0,
90
+ "tail_id": "id2",
91
+ "tail_index": 1,
92
+ "edge_type": "gene/protein,ppi,gene/protein",
93
+ "display_relation": "ppi",
94
+ "feat": "featC",
95
+ "feat_emb": [0.7, 0.8, 0.9]}
96
+ ]
97
+ colls["edges"].load.return_value = None
98
+
99
+ def collection_side_effect(name):
100
+ """
101
+ Mock side effect for Collection to return nodes or edges based on name.
102
+ """
103
+ if "nodes" in name:
104
+ return colls["nodes"]
105
+ if "edges" in name:
106
+ return colls["edges"]
107
+ return None
108
+ mock_collection.side_effect = collection_side_effect
109
+
110
+ # Mock MultimodalPCSTPruning
111
+ mock_pcst_instance = MagicMock()
112
+ mock_pcst_instance.extract_subgraph.return_value = {
113
+ "nodes": pd.Series([1, 2]),
114
+ "edges": pd.Series([0])
115
+ }
116
+ mock_pcst.return_value = mock_pcst_instance
117
+
118
+ # Patch hydra.compose to return config objects
119
+ with patch("aiagents4pharma.talk2knowledgegraphs.tools."
120
+ "milvus_multimodal_subgraph_extraction.hydra.initialize"), \
121
+ patch("aiagents4pharma.talk2knowledgegraphs.tools."
122
+ "milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
123
+ mock_compose.return_value = MagicMock()
124
+ mock_compose.return_value.app.frontend = self.cfg_db
125
+ mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
126
+
127
+ response = self.tool.invoke(
128
+ input={"prompt": self.prompt,
129
+ "tool_call_id": "subgraph_extraction_tool",
130
+ "state": self.state,
131
+ "arg_data": self.arg_data}
132
+ )
133
+
134
+ # Check tool message
135
+ self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
136
+
137
+ # Check extracted subgraph dictionary
138
+ dic_extracted_graph = response.update["dic_extracted_graph"][0]
139
+ self.assertIsInstance(dic_extracted_graph, dict)
140
+ self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
141
+ self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
142
+ self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
143
+ self.assertEqual(dic_extracted_graph["topk_edges"], 5)
144
+ self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
145
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
146
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
147
+ self.assertIsInstance(dic_extracted_graph["graph_text"], str)
148
+ # Check if the nodes are in the graph_text
149
+ self.assertTrue(all(
150
+ n[0] in dic_extracted_graph["graph_text"].replace('"', '')
151
+ for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
152
+ for n in subgraph_nodes
153
+ ))
154
+ # Check if the edges are in the graph_text
155
+ self.assertTrue(all(
156
+ ",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
157
+ in dic_extracted_graph["graph_text"].replace('"', '').\
158
+ replace("[", "").replace("]", "").replace("'", "")
159
+ for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
160
+ for e in subgraph_edges
161
+ ))
162
+
163
+ # Another test for unknown collection
164
+ result = collection_side_effect("unknown")
165
+ self.assertIsNone(result)
166
+
167
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
168
+ "milvus_multimodal_subgraph_extraction.Collection")
169
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
170
+ "milvus_multimodal_subgraph_extraction.pd.read_excel")
171
+ @patch("aiagents4pharma.talk2knowledgegraphs.tools."
172
+ "milvus_multimodal_subgraph_extraction.MultimodalPCSTPruning")
173
+ @patch("pymilvus.connections")
174
+ def test_extract_multimodal_subgraph_w_doc(self,
175
+ mock_connections,
176
+ mock_pcst,
177
+ mock_read_excel,
178
+ mock_collection):
179
+ """
180
+ Test the multimodal subgraph extraction tool for text as modality, plus genes.
181
+ """
182
+ # Mock Milvus connection utilities
183
+ mock_connections.has_connection.return_value = True
184
+
185
+ # With uploaded_files (with doc)
186
+ self.state["uploaded_files"] = [
187
+ {"file_type": "multimodal", "file_path": "dummy.xlsx"}
188
+ ]
189
+ self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
190
+ self.state["selections"] = {"gene/protein": ["JAK1", "JAK2"]}
191
+
192
+ # Mock pd.read_excel to return a dict of DataFrames
193
+ df = pd.DataFrame({
194
+ "name": ["JAK1", "JAK2"],
195
+ "node_type": ["gene/protein", "gene/protein"]
196
+ })
197
+ mock_read_excel.return_value = {"gene/protein": df}
198
+
199
+ # Mock Collection for nodes and edges
200
+ colls = {}
201
+ colls["nodes"] = MagicMock()
202
+ colls["nodes"] = MagicMock()
203
+ colls["nodes"].query.return_value = [
204
+ {"node_index": 0,
205
+ "node_id": "id1",
206
+ "node_name": "JAK1",
207
+ "node_type": "gene/protein",
208
+ "feat": "featA",
209
+ "feat_emb": [0.1, 0.2, 0.3],
210
+ "desc": "descA",
211
+ "desc_emb": [0.1, 0.2, 0.3]},
212
+ {"node_index": 1,
213
+ "node_id": "id2",
214
+ "node_name": "JAK2",
215
+ "node_type": "gene/protein",
216
+ "feat": "featB",
217
+ "feat_emb": [0.4, 0.5, 0.6],
218
+ "desc": "descB",
219
+ "desc_emb": [0.4, 0.5, 0.6]}
220
+ ]
221
+ colls["nodes"].load.return_value = None
222
+
223
+ colls["edges"] = MagicMock()
224
+ colls["edges"].query.return_value = [
225
+ {"triplet_index": 0,
226
+ "head_id": "id1",
227
+ "head_index": 0,
228
+ "tail_id": "id2",
229
+ "tail_index": 1,
230
+ "edge_type": "gene/protein,ppi,gene/protein",
231
+ "display_relation": "ppi",
232
+ "feat": "featC",
233
+ "feat_emb": [0.7, 0.8, 0.9]}
234
+ ]
235
+ colls["edges"].load.return_value = None
236
+
237
+ def collection_side_effect(name):
238
+ """
239
+ Mock side effect for Collection to return nodes or edges based on name.
240
+ """
241
+ if "nodes" in name:
242
+ return colls["nodes"]
243
+ if "edges" in name:
244
+ return colls["edges"]
245
+ return None
246
+ mock_collection.side_effect = collection_side_effect
247
+
248
+ # Mock MultimodalPCSTPruning
249
+ mock_pcst_instance = MagicMock()
250
+ mock_pcst_instance.extract_subgraph.return_value = {
251
+ "nodes": pd.Series([1, 2]),
252
+ "edges": pd.Series([0])
253
+ }
254
+ mock_pcst.return_value = mock_pcst_instance
255
+
256
+ # Patch hydra.compose to return config objects
257
+ with patch("aiagents4pharma.talk2knowledgegraphs.tools."
258
+ "milvus_multimodal_subgraph_extraction.hydra.initialize"), \
259
+ patch("aiagents4pharma.talk2knowledgegraphs.tools."
260
+ "milvus_multimodal_subgraph_extraction.hydra.compose") as mock_compose:
261
+ mock_compose.return_value = MagicMock()
262
+ mock_compose.return_value.app.frontend = self.cfg_db
263
+ mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
264
+
265
+ response = self.tool.invoke(
266
+ input={"prompt": self.prompt,
267
+ "tool_call_id": "subgraph_extraction_tool",
268
+ "state": self.state,
269
+ "arg_data": self.arg_data}
270
+ )
271
+
272
+ # Check tool message
273
+ self.assertEqual(response.update["messages"][-1].tool_call_id, "subgraph_extraction_tool")
274
+
275
+ # Check extracted subgraph dictionary
276
+ dic_extracted_graph = response.update["dic_extracted_graph"][0]
277
+ self.assertIsInstance(dic_extracted_graph, dict)
278
+ self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
279
+ self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
280
+ self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
281
+ self.assertEqual(dic_extracted_graph["topk_edges"], 5)
282
+ self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
283
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
284
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
285
+ self.assertIsInstance(dic_extracted_graph["graph_text"], str)
286
+ # Check if the nodes are in the graph_text
287
+ self.assertTrue(all(
288
+ n[0] in dic_extracted_graph["graph_text"].replace('"', '')
289
+ for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
290
+ for n in subgraph_nodes
291
+ ))
292
+ # Check if the edges are in the graph_text
293
+ self.assertTrue(all(
294
+ ",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
295
+ in dic_extracted_graph["graph_text"].replace('"', '').\
296
+ replace("[", "").replace("]", "").replace("'", "")
297
+ for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
298
+ for e in subgraph_edges
299
+ ))
300
+
301
+ # Another test for unknown collection
302
+ result = collection_side_effect("unknown")
303
+ self.assertIsNone(result)
304
+
305
+ def test_extract_multimodal_subgraph_wo_doc_gpu(self):
306
+ """
307
+ Test the multimodal subgraph extraction tool for only text as modality,
308
+ simulating GPU (cudf/cupy) environment.
309
+ """
310
+ module_name = "aiagents4pharma.talk2knowledgegraphs.tools."+\
311
+ "milvus_multimodal_subgraph_extraction"
312
+ with patch.dict("sys.modules", {"cupy": np, "cudf": pd}):
313
+ mod = importlib.reload(importlib.import_module(module_name))
314
+ # Patch Collection and MultimodalPCSTPruning after reload
315
+ with patch(f"{module_name}.Collection") as mock_collection, \
316
+ patch(f"{module_name}.MultimodalPCSTPruning") as mock_pcst, \
317
+ patch("pymilvus.connections") as mock_connections:
318
+ # Setup mocks as in the original test
319
+ mock_connections.has_connection.return_value = True
320
+ colls = {}
321
+ colls["nodes"] = MagicMock()
322
+ colls["nodes"].query.return_value = [
323
+ {"node_index": 0,
324
+ "node_id": "id1",
325
+ "node_name": "JAK1",
326
+ "node_type": "gene/protein",
327
+ "feat": "featA",
328
+ "feat_emb": [0.1, 0.2, 0.3],
329
+ "desc": "descA",
330
+ "desc_emb": [0.1, 0.2, 0.3]},
331
+ {"node_index": 1,
332
+ "node_id": "id2",
333
+ "node_name": "JAK2",
334
+ "node_type": "gene/protein",
335
+ "feat": "featB",
336
+ "feat_emb": [0.4, 0.5, 0.6],
337
+ "desc": "descB",
338
+ "desc_emb": [0.4, 0.5, 0.6]}
339
+ ]
340
+ colls["nodes"].load.return_value = None
341
+ colls["edges"] = MagicMock()
342
+ colls["edges"].query.return_value = [
343
+ {"triplet_index": 0,
344
+ "head_id": "id1",
345
+ "head_index": 0,
346
+ "tail_id": "id2",
347
+ "tail_index": 1,
348
+ "edge_type": "gene/protein,ppi,gene/protein",
349
+ "display_relation": "ppi",
350
+ "feat": "featC",
351
+ "feat_emb": [0.7, 0.8, 0.9]}
352
+ ]
353
+ colls["edges"].load.return_value = None
354
+ def collection_side_effect(name):
355
+ if "nodes" in name:
356
+ return colls["nodes"]
357
+ if "edges" in name:
358
+ return colls["edges"]
359
+ return None
360
+ mock_collection.side_effect = collection_side_effect
361
+ mock_pcst_instance = MagicMock()
362
+ mock_pcst_instance.extract_subgraph.return_value = {
363
+ "nodes": pd.Series([1, 2]),
364
+ "edges": pd.Series([0])
365
+ }
366
+ mock_pcst.return_value = mock_pcst_instance
367
+ # Instantiate the tool class from the reloaded module
368
+ tool_cls = getattr(mod, "MultimodalSubgraphExtractionTool")
369
+ tool = tool_cls()
370
+
371
+ # Patch hydra.compose
372
+ with patch(f"{module_name}.hydra.initialize"), \
373
+ patch(f"{module_name}.hydra.compose") as mock_compose:
374
+ mock_compose.return_value = MagicMock()
375
+ mock_compose.return_value.app.frontend = self.cfg_db
376
+ mock_compose.return_value.tools.multimodal_subgraph_extraction = self.cfg
377
+ self.state["embedding_model"].embed_query.return_value = [0.1, 0.2, 0.3]
378
+ self.state["selections"] = {}
379
+ response = tool.invoke(
380
+ input={"prompt": self.prompt,
381
+ "tool_call_id": "subgraph_extraction_tool",
382
+ "state": self.state,
383
+ "arg_data": self.arg_data}
384
+ )
385
+ # Check tool message
386
+ self.assertEqual(response.update["messages"][-1].tool_call_id,
387
+ "subgraph_extraction_tool")
388
+ dic_extracted_graph = response.update["dic_extracted_graph"][0]
389
+ self.assertIsInstance(dic_extracted_graph, dict)
390
+ self.assertEqual(dic_extracted_graph["name"], self.arg_data["extraction_name"])
391
+ self.assertEqual(dic_extracted_graph["graph_source"], "TestGraph")
392
+ self.assertEqual(dic_extracted_graph["topk_nodes"], 5)
393
+ self.assertEqual(dic_extracted_graph["topk_edges"], 5)
394
+ self.assertIsInstance(dic_extracted_graph["graph_dict"], dict)
395
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["nodes"]), 0)
396
+ self.assertGreater(len(dic_extracted_graph["graph_dict"]["edges"]), 0)
397
+ self.assertIsInstance(dic_extracted_graph["graph_text"], str)
398
+ self.assertTrue(all(
399
+ n[0] in dic_extracted_graph["graph_text"].replace('"', '')
400
+ for subgraph_nodes in dic_extracted_graph["graph_dict"]["nodes"]
401
+ for n in subgraph_nodes
402
+ ))
403
+ self.assertTrue(all(
404
+ ",".join([str(e[0])] + str(e[2]['label'][0]).split(",") + [str(e[1])])
405
+ in dic_extracted_graph["graph_text"].replace('"', '').\
406
+ replace("[", "").replace("]", "").replace("'", "")
407
+ for subgraph_edges in dic_extracted_graph["graph_dict"]["edges"]
408
+ for e in subgraph_edges
409
+ ))
410
+
411
+ # Another test for unknown collection
412
+ result = collection_side_effect("unknown")
413
+ self.assertIsNone(result)
@@ -42,11 +42,11 @@ def test_enrich_documents(enrich_obj):
42
42
  "XYZ_0000000",
43
43
  ]
44
44
  descriptions = enrich_obj.enrich_documents(ols_terms)
45
- assert descriptions[0].startswith(CL_DESC)
46
- assert descriptions[1].startswith(GO_DESC)
47
- assert descriptions[2].startswith(UBERON_DESC)
48
- assert descriptions[3].startswith(HP_DESC)
49
- assert descriptions[4].startswith(MONDO_DESC)
45
+ assert CL_DESC in descriptions[0]
46
+ assert GO_DESC in descriptions[1]
47
+ assert UBERON_DESC in descriptions[2]
48
+ assert HP_DESC in descriptions[3]
49
+ assert MONDO_DESC in descriptions[4]
50
50
  assert descriptions[5] is None
51
51
 
52
52
 
@@ -61,9 +61,9 @@ def test_enrich_documents_with_rag(enrich_obj):
61
61
  "XYZ_0000000",
62
62
  ]
63
63
  descriptions = enrich_obj.enrich_documents_with_rag(ols_terms, None)
64
- assert descriptions[0].startswith(CL_DESC)
65
- assert descriptions[1].startswith(GO_DESC)
66
- assert descriptions[2].startswith(UBERON_DESC)
67
- assert descriptions[3].startswith(HP_DESC)
68
- assert descriptions[4].startswith(MONDO_DESC)
64
+ assert CL_DESC in descriptions[0]
65
+ assert GO_DESC in descriptions[1]
66
+ assert UBERON_DESC in descriptions[2]
67
+ assert HP_DESC in descriptions[3]
68
+ assert MONDO_DESC in descriptions[4]
69
69
  assert descriptions[5] is None
@@ -0,0 +1,175 @@
1
+ """
2
+ Test cases for utils/extractions/milvus_multimodal_pcst.py
3
+ """
4
+
5
+ import importlib
6
+ import sys
7
+ import unittest
8
+ from unittest.mock import patch, MagicMock, mock_open
9
+ import numpy as np
10
+ import pandas as pd
11
+ from ..utils.extractions.milvus_multimodal_pcst import MultimodalPCSTPruning
12
+
13
+ class TestMultimodalPCSTPruning(unittest.TestCase):
14
+ """
15
+ Test cases for MultimodalPCSTPruning class (Milvus-based PCST pruning).
16
+ """
17
+ def setUp(self):
18
+ # Patch cupy and cudf to simulate GPU environment
19
+ patcher_cupy = patch.dict('sys.modules', {'cupy': MagicMock(), 'cudf': MagicMock()})
20
+ patcher_cupy.start()
21
+ self.addCleanup(patcher_cupy.stop)
22
+
23
+ # Patch pcst_fast
24
+ self.pcst_fast_patcher = patch("aiagents4pharma.talk2knowledgegraphs.utils."
25
+ "extractions.milvus_multimodal_pcst.pcst_fast")
26
+ self.mock_pcst_fast = self.pcst_fast_patcher.start()
27
+ self.addCleanup(self.pcst_fast_patcher.stop)
28
+ self.mock_pcst_fast.pcst_fast.return_value = ([0, 1], [0])
29
+
30
+ # Patch Collection
31
+ self.collection_patcher = patch("aiagents4pharma.talk2knowledgegraphs.utils."
32
+ "extractions.milvus_multimodal_pcst.Collection")
33
+ self.mock_collection = self.collection_patcher.start()
34
+ self.addCleanup(self.collection_patcher.stop)
35
+
36
+ # Patch open for cache_edge_index_path
37
+ self.open_patcher = patch('builtins.open', mock_open(read_data='[[0,1],[1,2]]'))
38
+ self.mock_open = self.open_patcher.start()
39
+ self.addCleanup(self.open_patcher.stop)
40
+
41
+ # Patch pickle.load to return a numpy array for edge_index
42
+ self.pickle_patcher = patch("aiagents4pharma.talk2knowledgegraphs.utils."
43
+ "extractions.milvus_multimodal_pcst.pickle")
44
+ self.mock_pickle = self.pickle_patcher.start()
45
+ self.addCleanup(self.pickle_patcher.stop)
46
+ self.mock_pickle.load.return_value = np.array([[0, 1], [1, 2]])
47
+
48
+ # Setup config mock
49
+ self.cfg = MagicMock()
50
+ self.cfg.milvus_db.database_name = "testdb"
51
+ self.cfg.milvus_db.cache_edge_index_path = "dummy_cache.pkl"
52
+
53
+ # Setup Collection mocks
54
+ node_coll = MagicMock()
55
+ node_coll.num_entities = 2
56
+ node_coll.search.return_value = [[MagicMock(id=0), MagicMock(id=1)]]
57
+ edge_coll = MagicMock()
58
+ edge_coll.num_entities = 2
59
+ edge_coll.search.return_value = [[MagicMock(id=0, score=1.0), MagicMock(id=1, score=0.5)]]
60
+ self.mock_collection.side_effect = lambda name: node_coll if "nodes" in name else edge_coll
61
+
62
+ def test_extract_subgraph_use_description_true(self):
63
+ """
64
+ Test the extract_subgraph method of MultimodalPCSTPruning with use_description=True.
65
+ """
66
+ # Create instance
67
+ pcst = MultimodalPCSTPruning(
68
+ topk=3, topk_e=3, cost_e=0.5, c_const=0.01, root=-1,
69
+ num_clusters=1, pruning="gw", verbosity_level=0, use_description=True, metric_type="IP"
70
+ )
71
+ # Dummy embeddings
72
+ text_emb = [0.1, 0.2, 0.3]
73
+ query_emb = [0.1, 0.2, 0.3]
74
+ modality = "gene/protein"
75
+
76
+ # Call extract_subgraph
77
+ result = pcst.extract_subgraph(text_emb, query_emb, modality, self.cfg)
78
+
79
+ # Assertions
80
+ self.assertIn("nodes", result)
81
+ self.assertIn("edges", result)
82
+ self.assertGreaterEqual(len(result["nodes"]), 0)
83
+ self.assertGreaterEqual(len(result["edges"]), 0)
84
+
85
+ def test_extract_subgraph_use_description_false(self):
86
+ """
87
+ Test the extract_subgraph method of MultimodalPCSTPruning with use_description=False.
88
+ """
89
+ # Create instance
90
+ pcst = MultimodalPCSTPruning(
91
+ topk=3, topk_e=3, cost_e=0.5, c_const=0.01, root=-1,
92
+ num_clusters=1, pruning="gw", verbosity_level=0, use_description=False, metric_type="IP"
93
+ )
94
+ # Dummy embeddings
95
+ text_emb = [0.1, 0.2, 0.3]
96
+ query_emb = [0.1, 0.2, 0.3]
97
+ modality = "gene/protein"
98
+
99
+ # Call extract_subgraph
100
+ result = pcst.extract_subgraph(text_emb, query_emb, modality, self.cfg)
101
+
102
+ # Assertions
103
+ self.assertIn("nodes", result)
104
+ self.assertIn("edges", result)
105
+ self.assertGreaterEqual(len(result["nodes"]), 0)
106
+ self.assertGreaterEqual(len(result["edges"]), 0)
107
+
108
+ def test_extract_subgraph_with_virtual_vertices(self):
109
+ """
110
+ Test get_subgraph_nodes_edges with virtual vertices present (len(virtual_vertices) > 0).
111
+ """
112
+ pcst = MultimodalPCSTPruning(
113
+ topk=3, topk_e=3, cost_e=0.5, c_const=0.01, root=-1,
114
+ num_clusters=1, pruning="gw", verbosity_level=0, use_description=True, metric_type="IP"
115
+ )
116
+ # Simulate num_nodes = 2, vertices contains [0, 1, 2, 3] (2 and 3 are virtual)
117
+ num_nodes = 2
118
+ # vertices: [0, 1, 2, 3] (2 and 3 are virtual)
119
+ vertices = np.array([0, 1, 2, 3])
120
+ # edges_dict simulates prior edges and edge_index
121
+ edges_dict = {
122
+ "edges": np.array([0, 1, 2]),
123
+ "num_prior_edges": 2,
124
+ "edge_index": np.array([[0, 1, 2, 3], [1, 2, 3, 4]])
125
+ }
126
+ # mapping simulates mapping for edges and nodes
127
+ mapping = {
128
+ "edges": {0: 0, 1: 1},
129
+ "nodes": {2: 2, 3: 3}
130
+ }
131
+
132
+ # Call get_subgraph_nodes_edges
133
+ result = pcst.get_subgraph_nodes_edges(num_nodes, vertices, edges_dict, mapping)
134
+
135
+ # Assertions
136
+ self.assertIn("nodes", result)
137
+ self.assertIn("edges", result)
138
+ self.assertGreaterEqual(len(result["nodes"]), 0)
139
+ self.assertGreaterEqual(len(result["edges"]), 0)
140
+ # Check that virtual edges are included
141
+ self.assertTrue(any(e in [2, 3] for e in result["edges"]))
142
+
143
+ def test_gpu_import_branch(self):
144
+ """
145
+ Test coverage for GPU import branch by patching sys.modules to mock cupy and
146
+ cudf as numpy and pandas.
147
+ """
148
+ module_name = "aiagents4pharma.talk2knowledgegraphs.utils" + \
149
+ ".extractions.milvus_multimodal_pcst"
150
+ with patch.dict("sys.modules", {"cupy": np, "cudf": pd}):
151
+ # Reload the module to trigger the GPU branch
152
+ mod = importlib.reload(sys.modules[module_name])
153
+ # Patch Collection, pcst_fast, and pickle after reload
154
+ with patch(f"{module_name}.Collection", self.mock_collection), \
155
+ patch(f"{module_name}.pcst_fast", self.mock_pcst_fast), \
156
+ patch(f"{module_name}.pickle", self.mock_pickle):
157
+ pcst_pruning_cls = getattr(mod, "MultimodalPCSTPruning")
158
+ pcst = pcst_pruning_cls(
159
+ topk=3, topk_e=3, cost_e=0.5, c_const=0.01, root=-1,
160
+ num_clusters=1, pruning="gw", verbosity_level=0, use_description=True,
161
+ metric_type="IP"
162
+ )
163
+ # Dummy embeddings
164
+ text_emb = [0.1, 0.2, 0.3]
165
+ query_emb = [0.1, 0.2, 0.3]
166
+ modality = "gene/protein"
167
+
168
+ # Call extract_subgraph
169
+ result = pcst.extract_subgraph(text_emb, query_emb, modality, self.cfg)
170
+
171
+ # Assertions
172
+ self.assertIn("nodes", result)
173
+ self.assertIn("edges", result)
174
+ self.assertGreaterEqual(len(result["nodes"]), 0)
175
+ self.assertGreaterEqual(len(result["edges"]), 0)
@@ -4,6 +4,17 @@ Test cases for utils/pubchem_utils.py
4
4
 
5
5
  from ..utils import pubchem_utils
6
6
 
7
+ def test_cas_rn2pubchem_cid():
8
+ """
9
+ Test the cas_rn2pubchem_cid function.
10
+
11
+ The CAS RN for ethyl carbonate is 105-58-8.
12
+ The PubChem CID for ethyl carbonate is 7766.
13
+ """
14
+ casrn = "105-58-8"
15
+ pubchem_cid = pubchem_utils.cas_rn2pubchem_cid(casrn)
16
+ assert pubchem_cid == 7766
17
+
7
18
  def test_external_id2pubchem_cid():
8
19
  """
9
20
  Test the external_id2pubchem_cid function.
@@ -3,5 +3,6 @@ This file is used to import all the models in the package.
3
3
  '''
4
4
  from . import subgraph_extraction
5
5
  from . import multimodal_subgraph_extraction
6
+ from . import milvus_multimodal_subgraph_extraction
6
7
  from . import subgraph_summarization
7
8
  from . import graphrag_reasoning