aiagents4pharma 1.43.0__py3-none-any.whl → 1.44.0__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +17 -2
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +618 -413
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +362 -25
- aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +146 -109
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +240 -83
- {aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.44.0.dist-info}/METADATA +1 -1
- {aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.44.0.dist-info}/RECORD +10 -10
- {aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.44.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.44.0.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.44.0.dist-info}/top_level.txt +0 -0
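The bulk of this release is in `milvus_multimodal_pcst.py`, which now detects the host hardware at runtime and switches between a GPU stack (cudf/cupy with the "IP" metric) and a CPU stack (pandas/numpy with the "COSINE" metric) through the new `SystemDetector` and `DynamicLibraryLoader` classes shown in the diff below. A minimal usage sketch, assuming aiagents4pharma 1.44.0 and its dependencies (pymilvus, pcst_fast) are installed; cudf and cupy remain optional, and the loader falls back to CPU mode when they are missing:

```python
# Sketch only: exercises the two new helper classes added in 1.44.0
# (see the diff of milvus_multimodal_pcst.py below).
from aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst import (
    DynamicLibraryLoader,
    SystemDetector,
)

detector = SystemDetector()              # probes the OS/architecture and runs `nvidia-smi` if available
print(detector.get_system_info())        # {'os_type': ..., 'architecture': ..., 'has_nvidia_gpu': ..., 'use_gpu': ...}

loader = DynamicLibraryLoader(detector)  # binds loader.py / loader.df to cupy/cudf or numpy/pandas
print(loader.metric_type)                # "IP" on a CUDA-capable Linux/Windows box, otherwise "COSINE"
```

Since inner product on L2-normalized vectors is equivalent to cosine similarity, the two paths should rank search results the same way; normalization is only enabled on the GPU path.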
aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +240 -83

@@ -2,24 +2,156 @@
 Exctraction of multimodal subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
 """

-from typing import Tuple, NamedTuple
 import logging
 import pickle
+import platform
+import subprocess
+from typing import NamedTuple
+
+import numpy as np
 import pandas as pd
 import pcst_fast
 from pymilvus import Collection
+
 try:
-    import cupy as py
     import cudf
-
+    import cupy as cp
+    CUDF_AVAILABLE = True
 except ImportError:
-
-
+    CUDF_AVAILABLE = False
+    cudf = None
+    cp = None

 # Initialize logger
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+
+class SystemDetector:
+    """Detect system capabilities and choose appropriate libraries."""
+
+    def __init__(self):
+        self.os_type = platform.system().lower()  # 'windows', 'linux', 'darwin'
+        self.architecture = platform.machine().lower()  # 'x86_64', 'arm64', etc.
+        self.has_nvidia_gpu = self._detect_nvidia_gpu()
+        self.use_gpu = (
+            self.has_nvidia_gpu and self.os_type != "darwin"
+        )  # No CUDA on macOS
+
+        logger.info("System Detection Results:")
+        logger.info("  OS: %s", self.os_type)
+        logger.info("  Architecture: %s", self.architecture)
+        logger.info("  NVIDIA GPU detected: %s", self.has_nvidia_gpu)
+        logger.info("  Will use GPU acceleration: %s", self.use_gpu)
+
+    def _detect_nvidia_gpu(self) -> bool:
+        """Detect if NVIDIA GPU is available."""
+        try:
+            # Try nvidia-smi command
+            result = subprocess.run(
+                ["nvidia-smi"], capture_output=True, text=True, timeout=10, check=False
+            )
+            return result.returncode == 0
+        except (
+            subprocess.TimeoutExpired,
+            FileNotFoundError,
+            subprocess.SubprocessError,
+        ):
+            return False
+
+    def get_system_info(self) -> dict:
+        """Get comprehensive system information."""
+        return {
+            "os_type": self.os_type,
+            "architecture": self.architecture,
+            "has_nvidia_gpu": self.has_nvidia_gpu,
+            "use_gpu": self.use_gpu,
+        }
+
+    def is_gpu_compatible(self) -> bool:
+        """Check if the system is compatible with GPU acceleration."""
+        return self.has_nvidia_gpu and self.os_type != "darwin"
+
+
+class DynamicLibraryLoader:
+    """Dynamically load libraries based on system capabilities."""
+
+    def __init__(self, detector: SystemDetector):
+        self.detector = detector
+        self.use_gpu = detector.use_gpu
+
+        # Initialize attributes that will be set later
+        self.py = None
+        self.df = None
+        self.pd = None
+        self.np = None
+        self.cudf = None
+        self.cp = None
+
+        # Import libraries based on system capabilities
+        self._import_libraries()
+
+        # Dynamic settings based on hardware
+        self.normalize_vectors = self.use_gpu  # Only normalize for GPU
+        self.metric_type = "IP" if self.use_gpu else "COSINE"
+
+        logger.info("Library Configuration:")
+        logger.info("  Using GPU acceleration: %s", self.use_gpu)
+        logger.info("  Vector normalization: %s", self.normalize_vectors)
+        logger.info("  Metric type: %s", self.metric_type)
+
+    def _import_libraries(self):
+        """Dynamically import libraries based on system capabilities."""
+        # Set base libraries
+        self.pd = pd
+        self.np = np
+
+        # Conditionally import GPU libraries
+        if self.detector.use_gpu:
+            if CUDF_AVAILABLE:
+                self.cudf = cudf
+                self.cp = cp
+                self.py = cp  # Use cupy for array operations
+                self.df = cudf  # Use cudf for dataframes
+                logger.info("Successfully imported GPU libraries (cudf, cupy)")
+            else:
+                logger.error("cudf or cupy not found. Falling back to CPU mode.")
+                self.detector.use_gpu = False
+                self.use_gpu = False
+                self._setup_cpu_mode()
+        else:
+            self._setup_cpu_mode()
+
+    def _setup_cpu_mode(self):
+        """Setup CPU mode with numpy and pandas."""
+        self.py = self.np  # Use numpy for array operations
+        self.df = self.pd  # Use pandas for dataframes
+        self.normalize_vectors = False
+        self.metric_type = "COSINE"
+        logger.info("Using CPU mode with numpy and pandas")
+
+    def normalize_matrix(self, matrix, axis: int = 1):
+        """Normalize matrix using appropriate library."""
+        if not self.normalize_vectors:
+            return matrix
+
+        if self.use_gpu:
+            # Use cupy for GPU
+            matrix_cp = self.cp.asarray(matrix).astype(self.cp.float32)
+            norms = self.cp.linalg.norm(matrix_cp, axis=axis, keepdims=True)
+            return matrix_cp / norms
+        # CPU mode doesn't normalize for COSINE similarity
+        return matrix
+
+    def to_list(self, data):
+        """Convert data to list format."""
+        if hasattr(data, "tolist"):
+            return data.tolist()
+        if hasattr(data, "to_arrow"):
+            return data.to_arrow().to_pylist()
+        return list(data)
+
+
 class MultimodalPCSTPruning(NamedTuple):
     """
     Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
@@ -37,7 +169,11 @@ class MultimodalPCSTPruning(NamedTuple):
         num_clusters: The number of clusters.
         pruning: The pruning strategy to use.
         verbosity_level: The verbosity level.
+        use_description: Whether to use description embeddings.
+        metric_type: The similarity metric type (dynamic based on hardware).
+        loader: The dynamic library loader instance.
     """
+
     topk: int = 3
     topk_e: int = 3
     cost_e: float = 0.5
@@ -47,7 +183,8 @@ class MultimodalPCSTPruning(NamedTuple):
     pruning: str = "gw"
     verbosity_level: int = 0
     use_description: bool = False
-    metric_type: str =
+    metric_type: str = None  # Will be set dynamically
+    loader: DynamicLibraryLoader = None

     def prepare_collections(self, cfg: dict, modality: str) -> dict:
         """
@@ -81,11 +218,9 @@ class MultimodalPCSTPruning(NamedTuple):

         return colls

-    def _compute_node_prizes(self,
-                             query_emb: list,
-                             colls: dict) -> dict:
+    def _compute_node_prizes(self, query_emb: list, colls: dict) -> dict:
         """
-        Compute the node prizes based on the
+        Compute the node prizes based on the similarity between the query and nodes.

         Args:
             query_emb: The query embedding. This can be an embedding of
@@ -95,79 +230,91 @@ class MultimodalPCSTPruning(NamedTuple):
         Returns:
             The prizes of the nodes.
         """
-        #
+        # Initialize several variables
        topk = min(self.topk, colls["nodes"].num_entities)
-        n_prizes = py.zeros(
+        n_prizes = self.loader.py.zeros(
+            colls["nodes"].num_entities, dtype=self.loader.py.float32
+        )
+
+        # Get the actual metric type to use
+        actual_metric_type = self.metric_type or self.loader.metric_type

-        # Calculate
+        # Calculate similarity for text features and update the score
         if self.use_description:
             # Search the collection with the text embedding
             res = colls["nodes"].search(
                 data=[query_emb],
                 anns_field="desc_emb",
-                param={"metric_type":
+                param={"metric_type": actual_metric_type},
                 limit=topk,
-                output_fields=["node_id"]
+                output_fields=["node_id"],
+            )
         else:
             # Search the collection with the query embedding
             res = colls["nodes_type"].search(
                 data=[query_emb],
                 anns_field="feat_emb",
-                param={"metric_type":
+                param={"metric_type": actual_metric_type},
                 limit=topk,
-                output_fields=["node_id"]
+                output_fields=["node_id"],
+            )

         # Update the prizes based on the search results
-        n_prizes[[r.id for r in res[0]]] = py.arange(topk, 0, -1).astype(
+        n_prizes[[r.id for r in res[0]]] = self.loader.py.arange(topk, 0, -1).astype(
+            self.loader.py.float32
+        )

         return n_prizes

-    def _compute_edge_prizes(self,
-                             text_emb: list,
-                             colls: dict) -> py.ndarray:
+    def _compute_edge_prizes(self, text_emb: list, colls: dict):
         """
-        Compute the
+        Compute the edge prizes based on the similarity between the query and edges.

         Args:
             text_emb: The textual description embedding.
             colls: The collections of nodes, node-type specific nodes, and edges in Milvus.

         Returns:
-            The prizes of the
+            The prizes of the edges.
         """
-        #
+        # Initialize several variables
         topk_e = min(self.topk_e, colls["edges"].num_entities)
-        e_prizes = py.zeros(
+        e_prizes = self.loader.py.zeros(
+            colls["edges"].num_entities, dtype=self.loader.py.float32
+        )
+
+        # Get the actual metric type to use
+        actual_metric_type = self.metric_type or self.loader.metric_type

         # Search the collection with the query embedding
         res = colls["edges"].search(
             data=[text_emb],
             anns_field="feat_emb",
-            param={"metric_type":
-            limit=topk_e,
-
-
+            param={"metric_type": actual_metric_type},
+            limit=topk_e,  # Only retrieve the top-k edges
+            output_fields=["head_id", "tail_id"],
+        )

         # Update the prizes based on the search results
         e_prizes[[r.id for r in res[0]]] = [r.score for r in res[0]]

         # Further process the edge_prizes
-        unique_prizes, inverse_indices = py.unique(
-
-
+        unique_prizes, inverse_indices = self.loader.py.unique(
+            e_prizes, return_inverse=True
+        )
+        topk_e_values = unique_prizes[self.loader.py.argsort(-unique_prizes)[:topk_e]]
         last_topk_e_value = topk_e
         for k in range(topk_e):
-            indices =
+            indices = (
+                inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
+            )
             value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
             e_prizes[indices] = value
             last_topk_e_value = value * (1 - self.c_const)

         return e_prizes

-    def compute_prizes(self,
-                       text_emb: list,
-                       query_emb: list,
-                       colls: dict) -> dict:
+    def compute_prizes(self, text_emb: list, query_emb: list, colls: dict) -> dict:
         """
         Compute the node prizes based on the cosine similarity between the query and nodes,
         as well as the edge prizes based on the cosine similarity between the query and edges.
@@ -193,10 +340,7 @@ class MultimodalPCSTPruning(NamedTuple):

         return {"nodes": n_prizes, "edges": e_prizes}

-    def compute_subgraph_costs(self,
-                               edge_index: py.ndarray,
-                               num_nodes: int,
-                               prizes: dict) -> Tuple[py.ndarray, py.ndarray, py.ndarray]:
+    def compute_subgraph_costs(self, edge_index, num_nodes: int, prizes: dict):
         """
         Compute the costs in constructing the subgraph proposed by G-Retriever paper.

@@ -218,7 +362,7 @@ class MultimodalPCSTPruning(NamedTuple):
         # Update edge cost threshold
         updated_cost_e = min(
             self.cost_e,
-            py.max(prizes["edges"]).item() * (1 - self.c_const / 2),
+            self.loader.py.max(prizes["edges"]).item() * (1 - self.c_const / 2),
         )

         # Masks for real and virtual edges
@@ -228,19 +372,21 @@ class MultimodalPCSTPruning(NamedTuple):

         # Real edge indices
         logger.log(logging.INFO, "Computing real edges")
-        real_["indices"] = py.nonzero(real_["mask"])[0]
+        real_["indices"] = self.loader.py.nonzero(real_["mask"])[0]
         real_["src"] = edge_index[0][real_["indices"]]
         real_["dst"] = edge_index[1][real_["indices"]]
-        real_["edges"] = py.stack([real_["src"], real_["dst"]], axis=1)
+        real_["edges"] = self.loader.py.stack([real_["src"], real_["dst"]], axis=1)
         real_["costs"] = updated_cost_e - prizes["edges"][real_["indices"]]

         # Edge index mapping: local real edge idx -> original global index
         logger.log(logging.INFO, "Creating mapping for real edges")
-        mapping_edges = dict(
+        mapping_edges = dict(
+            zip(range(len(real_["indices"])), self.loader.to_list(real_["indices"]))
+        )

         # Virtual edge handling
         logger.log(logging.INFO, "Computing virtual edges")
-        virt_["indices"] = py.nonzero(virt_["mask"])[0]
+        virt_["indices"] = self.loader.py.nonzero(virt_["mask"])[0]
         virt_["src"] = edge_index[0][virt_["indices"]]
         virt_["dst"] = edge_index[1][virt_["indices"]]
         virt_["prizes"] = prizes["edges"][virt_["indices"]] - updated_cost_e
@@ -248,28 +394,42 @@ class MultimodalPCSTPruning(NamedTuple):
         # Generate virtual node IDs
         logger.log(logging.INFO, "Generating virtual node IDs")
         virt_["num"] = virt_["indices"].shape[0]
-        virt_["node_ids"] = py.arange(num_nodes, num_nodes + virt_["num"])
+        virt_["node_ids"] = self.loader.py.arange(num_nodes, num_nodes + virt_["num"])

         # Virtual edges: (src → virtual), (virtual → dst)
         logger.log(logging.INFO, "Creating virtual edges")
-        virt_["edges_1"] = py.stack(
-
-
-
+        virt_["edges_1"] = self.loader.py.stack(
+            [virt_["src"], virt_["node_ids"]], axis=1
+        )
+        virt_["edges_2"] = self.loader.py.stack(
+            [virt_["node_ids"], virt_["dst"]], axis=1
+        )
+        virt_["edges"] = self.loader.py.concatenate(
+            [virt_["edges_1"], virt_["edges_2"]], axis=0
+        )
+        virt_["costs"] = self.loader.py.zeros(
+            (virt_["edges"].shape[0],), dtype=real_["costs"].dtype
+        )

         # Combine real and virtual edges/costs
         logger.log(logging.INFO, "Combining real and virtual edges/costs")
-        all_edges = py.concatenate([real_["edges"], virt_["edges"]], axis=0)
-        all_costs = py.concatenate([real_["costs"], virt_["costs"]], axis=0)
+        all_edges = self.loader.py.concatenate([real_["edges"], virt_["edges"]], axis=0)
+        all_costs = self.loader.py.concatenate([real_["costs"], virt_["costs"]], axis=0)

         # Final prizes
         logger.log(logging.INFO, "Getting final prizes")
-        final_prizes = py.concatenate(
+        final_prizes = self.loader.py.concatenate(
+            [prizes["nodes"], virt_["prizes"]], axis=0
+        )

         # Mapping virtual node ID -> edge index in original graph
         logger.log(logging.INFO, "Creating mapping for virtual nodes")
-        mapping_nodes = dict(
+        mapping_nodes = dict(
+            zip(
+                self.loader.to_list(virt_["node_ids"]),
+                self.loader.to_list(virt_["indices"]),
+            )
+        )

         # Build return values
         logger.log(logging.INFO, "Building return values")
@@ -284,11 +444,9 @@ class MultimodalPCSTPruning(NamedTuple):

         return edges_dict, final_prizes, all_costs, mapping

-    def get_subgraph_nodes_edges(
-
-
-            edges_dict: dict,
-            mapping: dict) -> dict:
+    def get_subgraph_nodes_edges(
+        self, num_nodes: int, vertices, edges_dict: dict, mapping: dict
+    ) -> dict:
         """
         Get the selected nodes and edges of the subgraph based on the vertices and edges computed
         by the PCST algorithm.
@@ -305,31 +463,26 @@ class MultimodalPCSTPruning(NamedTuple):
         # Get edges information
         edges = edges_dict["edges"]
         num_prior_edges = edges_dict["num_prior_edges"]
-
-        edges = edges_dict["edges"]
-        num_prior_edges = edges_dict["num_prior_edges"]
+
         # Retrieve the selected nodes and edges based on the given vertices and edges
         subgraph_nodes = vertices[vertices < num_nodes]
-        subgraph_edges = [
+        subgraph_edges = [
+            mapping["edges"][e.item()] for e in edges if e < num_prior_edges
+        ]
         virtual_vertices = vertices[vertices >= num_nodes]
         if len(virtual_vertices) > 0:
-            virtual_vertices = vertices[vertices >= num_nodes]
             virtual_edges = [mapping["nodes"][i.item()] for i in virtual_vertices]
-            subgraph_edges = py.array(subgraph_edges + virtual_edges)
+            subgraph_edges = self.loader.py.array(subgraph_edges + virtual_edges)
         edge_index = edges_dict["edge_index"][:, subgraph_edges]
-        subgraph_nodes = py.unique(
-            py.concatenate(
-                [subgraph_nodes, edge_index[0], edge_index[1]]
-            )
+        subgraph_nodes = self.loader.py.unique(
+            self.loader.py.concatenate([subgraph_nodes, edge_index[0], edge_index[1]])
         )

         return {"nodes": subgraph_nodes, "edges": subgraph_edges}

-    def extract_subgraph(
-
-
-            modality: str,
-            cfg: dict) -> dict:
+    def extract_subgraph(
+        self, text_emb: list, query_emb: list, modality: str, cfg: dict
+    ) -> dict:
         """
         Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.

@@ -352,7 +505,7 @@ class MultimodalPCSTPruning(NamedTuple):
         logger.log(logging.INFO, "Loading cache edge index")
         with open(cfg.milvus_db.cache_edge_index_path, "rb") as f:
             edge_index = pickle.load(f)
-        edge_index = py.array(edge_index)
+        edge_index = self.loader.py.array(edge_index)

         # Assert the topk and topk_e values for subgraph retrieval
         assert self.topk > 0, "topk must be greater than or equal to 0"
@@ -365,7 +518,8 @@ class MultimodalPCSTPruning(NamedTuple):
         # Compute costs in constructing the subgraph
         logger.log(logging.INFO, "compute_subgraph_costs")
         edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
-            edge_index, colls["nodes"].num_entities, prizes
+            edge_index, colls["nodes"].num_entities, prizes
+        )

         # Retrieve the subgraph using the PCST algorithm
         logger.log(logging.INFO, "Running PCST algorithm")
@@ -383,11 +537,14 @@ class MultimodalPCSTPruning(NamedTuple):
         logger.log(logging.INFO, "Getting subgraph nodes and edges")
         subgraph = self.get_subgraph_nodes_edges(
             colls["nodes"].num_entities,
-            py.asarray(result_vertices),
-            {
-
-
-
+            self.loader.py.asarray(result_vertices),
+            {
+                "edges": self.loader.py.asarray(result_edges),
+                "num_prior_edges": edges_dict["num_prior_edges"],
+                "edge_index": edge_index,
+            },
+            mapping,
+        )
         print(subgraph)

         return subgraph
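With the diff above, `MultimodalPCSTPruning` no longer hard-codes its array backend or similarity metric: `metric_type` defaults to `None` and is resolved as `self.metric_type or self.loader.metric_type` at search time, and array operations are dispatched through `self.loader.py`. A hedged construction sketch following the new signature; running `extract_subgraph()` additionally needs a Hydra-style `cfg` with `milvus_db` settings and live Milvus collections, which are not shown here:

```python
# Sketch only: shows how the new loader and metric_type fields fit together.
from aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst import (
    DynamicLibraryLoader,
    MultimodalPCSTPruning,
    SystemDetector,
)

loader = DynamicLibraryLoader(SystemDetector())
pcst = MultimodalPCSTPruning(
    topk=3,
    topk_e=3,
    use_description=False,
    metric_type=None,  # left as None so the loader's metric ("IP" or "COSINE") is used
    loader=loader,     # supplies loader.py (cupy or numpy) for all array operations
)
# subgraph = pcst.extract_subgraph(text_emb, query_emb, modality, cfg)  # requires Milvus + cfg
```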
{aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.44.0.dist-info}/METADATA +1 -1

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: aiagents4pharma
-Version: 1.
+Version: 1.44.0
 Summary: AI Agents for drug discovery, drug development, and other pharmaceutical R&D.
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: MIT License
{aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.44.0.dist-info}/RECORD +10 -10

@@ -86,7 +86,7 @@ aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py,sha256=C1yyRZW8hq
 aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
 aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml,sha256=Ua99yECXiwp4ZCUDgsDskYbKzcJrv7roQuLj31Zky4c,1037
 aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml,sha256=
+aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml,sha256=QEtxF7Fj1DYFEw1qS-JXAptbIgNHc-dnBV6aic0alkk,1330
 aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
 aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml,sha256=U8HvMsYbaOwDwQPATj7EFvLtTy7XZEplE5WMoNjgYYc,1469
 aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
@@ -109,7 +109,7 @@ aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py,sha256=NFUls
 aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py,sha256=Pvu0r93CpnhjkfMxc-EiVLpAJ04FdW9iTamCnetu654,2272
 aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py,sha256=TuIsqcN1Mww3DTqGk6ebgJBWzUWdMWEq2yRQuYSFqvA,4416
 aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py,sha256=aOKHTber2Cg3mjNjfIa6RZU7XdFj5C2ps1YEUXw76CI,10650
-aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py,sha256=
+aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py,sha256=rS2QbSjzHB8TA7JX0K66boILoug9qCJSPp3lti_CL_s,27930
 aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py,sha256=Da-hXcu41_5Ge4DPlOoY6OqBwYnXPc58Q89wuywqVJM,5806
 aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py,sha256=C2HzmAG1XCeV1hwZzz3-9_2dm_84-i1BvTNWA1pqUwM,5393
 aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py,sha256=oBqfspXXOxH04OQuPb8BCW0liIQTGKXtaPNSrPpQtFc,7597
@@ -124,13 +124,13 @@ aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ols.py,sha256=
 aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_pubchem.py,sha256=0SgYvqdvxseUYTHx2KuSNI2hnmQ3VVVz0F-79_-P41o,1769
 aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_reactome.py,sha256=r1D74mavsnSCm4xnWl0n0nM9PZqgm3doD2dulNrKNVQ,1754
 aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_uniprot.py,sha256=G13Diw7cA5TGINUNO1CDnN4rM6KbepxRXNjuzY578DI,1611
-aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py,sha256=
+aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py,sha256=uGo60iLGuXkkiG67JBbRjUPxw1uICxPYlenbb9wrD48,20893
 aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py,sha256=pal76wi7WgQWUNk56BrzfFV8jKpbDaHHdbwtgx_gXLI,2410
 aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py,sha256=31WPX8MrhnztoHUROAlH5KvHeXMbB_Jndp3ypAKJO9E,1543
 aiagents4pharma/talk2knowledgegraphs/tools/__init__.py,sha256=u50fnnIhm7NHt4JhQeXdF_XtNYR2i35p4VRNQzP1CVQ,268
 aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py,sha256=OEuOFncDRdb7TQEGq4rkT5On-jI-R7Nt8K5EBzaND8w,5338
 aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py,sha256=zhmsRp-8vjB5rRekqTA07d3yb-42HWqng9dDMkvK6hM,623
-aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py,sha256=
+aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py,sha256=M-Rrg-mPpe0dpP-AF85tgLSJjPRqFv3UUs9cWTHcCes,24065
 aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py,sha256=Qjl8hXG8Gv5jQ4pBX8me0pGGakqRZmcDfTGgdEHD9pc,15394
 aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py,sha256=te06QMFQfgJWrjaGrqpcOYeaV38jwm0KY_rXVSMHkeI,11468
 aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py,sha256=mDSBOxopDfNhEJeU8fVI8b5lXTYrRzcc97aLbFgYSy4,4413
@@ -151,7 +151,7 @@ aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py,sha256
 aiagents4pharma/talk2knowledgegraphs/utils/enrichments/reactome_pathways.py,sha256=I0cD0Fk2Uk27_4jEaIhpoGhoMh_RphY1VtkMnk4dkPg,2011
 aiagents4pharma/talk2knowledgegraphs/utils/enrichments/uniprot_proteins.py,sha256=z0Jb3tt8VzRjzqI9oVcUvRlPPg6BUdmslfKDIEFE_h8,3013
 aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py,sha256=6xVclMlSuAIHQXBvn5D9zRLDzSv2LWLcAwDQw-nwZgM,153
-aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py,sha256
+aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py,sha256=-Qi35E4Yev0jv6UhAhhVkgaF53p2zvmS6FMytOOx-Xs,21533
 aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py,sha256=Irh5JXEhaLZ6Rxv3h5Anif_rGNItyLOGDWg1RACmoDA,12628
 aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py,sha256=m5p0yoJb7I19ua5yeQfXPf7c4r6S1XPwttsrM7Qoy94,9336
 aiagents4pharma/talk2scholars/__init__.py,sha256=NOZxTklAH1j1ggu97Ib8Xn9LCKudEWt-8dx8w7yxVD8,180
@@ -286,8 +286,8 @@ aiagents4pharma/talk2scholars/tools/zotero/utils/review_helper.py,sha256=IPD1V9y
 aiagents4pharma/talk2scholars/tools/zotero/utils/write_helper.py,sha256=ALwLecy1QVebbsmXJiDj1GhGmyhq2R2tZlAyEl1vfhw,7410
 aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py,sha256=oIrfbOySgts50ksHKyjcWjRkPRIS88g3Lc0v9mBkU8w,6375
 aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_pdf_downloader.py,sha256=ERBha8afU6Q1EaRBe9qB8tchOzZ4_KfFgDW6EElOJoU,4816
-aiagents4pharma-1.
-aiagents4pharma-1.
-aiagents4pharma-1.
-aiagents4pharma-1.
-aiagents4pharma-1.
+aiagents4pharma-1.44.0.dist-info/licenses/LICENSE,sha256=IcIbyB1Hyk5ZDah03VNQvJkbNk2hkBCDqQ8qtnCvB4Q,1077
+aiagents4pharma-1.44.0.dist-info/METADATA,sha256=Ibu0QQcsmO9c65SO-Q5jkhU2Vd-QXCQ5fS9h-2Hoa6I,13281
+aiagents4pharma-1.44.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+aiagents4pharma-1.44.0.dist-info/top_level.txt,sha256=-AH8rMmrSnJtq7HaAObS78UU-cTCwvX660dSxeM7a0A,16
+aiagents4pharma-1.44.0.dist-info/RECORD,,
{aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.44.0.dist-info}/WHEEL: File without changes
{aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.44.0.dist-info}/licenses/LICENSE: File without changes
{aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.44.0.dist-info}/top_level.txt: File without changes