aiagents4pharma 1.42.0__py3-none-any.whl → 1.44.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +17 -2
  2. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +618 -413
  3. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +362 -25
  4. aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +146 -109
  5. aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +240 -83
  6. aiagents4pharma/talk2scholars/agents/paper_download_agent.py +7 -4
  7. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +49 -95
  8. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +15 -1
  9. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +16 -2
  10. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +40 -5
  11. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +15 -5
  12. aiagents4pharma/talk2scholars/configs/config.yaml +1 -3
  13. aiagents4pharma/talk2scholars/configs/tools/paper_download/default.yaml +124 -0
  14. aiagents4pharma/talk2scholars/tests/test_arxiv_downloader.py +478 -0
  15. aiagents4pharma/talk2scholars/tests/test_base_paper_downloader.py +620 -0
  16. aiagents4pharma/talk2scholars/tests/test_biorxiv_downloader.py +697 -0
  17. aiagents4pharma/talk2scholars/tests/test_medrxiv_downloader.py +534 -0
  18. aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +22 -12
  19. aiagents4pharma/talk2scholars/tests/test_paper_downloader.py +545 -0
  20. aiagents4pharma/talk2scholars/tests/test_pubmed_downloader.py +1067 -0
  21. aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +2 -4
  22. aiagents4pharma/talk2scholars/tools/paper_download/paper_downloader.py +457 -0
  23. aiagents4pharma/talk2scholars/tools/paper_download/utils/__init__.py +20 -0
  24. aiagents4pharma/talk2scholars/tools/paper_download/utils/arxiv_downloader.py +209 -0
  25. aiagents4pharma/talk2scholars/tools/paper_download/utils/base_paper_downloader.py +343 -0
  26. aiagents4pharma/talk2scholars/tools/paper_download/utils/biorxiv_downloader.py +321 -0
  27. aiagents4pharma/talk2scholars/tools/paper_download/utils/medrxiv_downloader.py +198 -0
  28. aiagents4pharma/talk2scholars/tools/paper_download/utils/pubmed_downloader.py +337 -0
  29. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +97 -45
  30. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +47 -29
  31. {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/METADATA +3 -1
  32. {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/RECORD +36 -33
  33. aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/default.yaml +0 -4
  34. aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/__init__.py +0 -3
  35. aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/default.yaml +0 -2
  36. aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/__init__.py +0 -3
  37. aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/default.yaml +0 -2
  38. aiagents4pharma/talk2scholars/tests/test_paper_download_biorxiv.py +0 -151
  39. aiagents4pharma/talk2scholars/tests/test_paper_download_medrxiv.py +0 -151
  40. aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +0 -249
  41. aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +0 -177
  42. aiagents4pharma/talk2scholars/tools/paper_download/download_biorxiv_input.py +0 -114
  43. aiagents4pharma/talk2scholars/tools/paper_download/download_medrxiv_input.py +0 -114
  44. /aiagents4pharma/talk2scholars/configs/tools/{download_arxiv_paper → paper_download}/__init__.py +0 -0
  45. {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/WHEEL +0 -0
  46. {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/licenses/LICENSE +0 -0
  47. {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py
@@ -2,24 +2,156 @@
 Exctraction of multimodal subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
 """
 
-from typing import Tuple, NamedTuple
 import logging
 import pickle
+import platform
+import subprocess
+from typing import NamedTuple
+
+import numpy as np
 import pandas as pd
 import pcst_fast
 from pymilvus import Collection
+
 try:
-    import cupy as py
     import cudf
-    df = cudf
+    import cupy as cp
+    CUDF_AVAILABLE = True
 except ImportError:
-    import numpy as py
-    df = pd
+    CUDF_AVAILABLE = False
+    cudf = None
+    cp = None
 
 # Initialize logger
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
+class SystemDetector:
+    """Detect system capabilities and choose appropriate libraries."""
+
+    def __init__(self):
+        self.os_type = platform.system().lower()  # 'windows', 'linux', 'darwin'
+        self.architecture = platform.machine().lower()  # 'x86_64', 'arm64', etc.
+        self.has_nvidia_gpu = self._detect_nvidia_gpu()
+        self.use_gpu = (
+            self.has_nvidia_gpu and self.os_type != "darwin"
+        )  # No CUDA on macOS
+
+        logger.info("System Detection Results:")
+        logger.info(" OS: %s", self.os_type)
+        logger.info(" Architecture: %s", self.architecture)
+        logger.info(" NVIDIA GPU detected: %s", self.has_nvidia_gpu)
+        logger.info(" Will use GPU acceleration: %s", self.use_gpu)
+
+    def _detect_nvidia_gpu(self) -> bool:
+        """Detect if NVIDIA GPU is available."""
+        try:
+            # Try nvidia-smi command
+            result = subprocess.run(
+                ["nvidia-smi"], capture_output=True, text=True, timeout=10, check=False
+            )
+            return result.returncode == 0
+        except (
+            subprocess.TimeoutExpired,
+            FileNotFoundError,
+            subprocess.SubprocessError,
+        ):
+            return False
+
+    def get_system_info(self) -> dict:
+        """Get comprehensive system information."""
+        return {
+            "os_type": self.os_type,
+            "architecture": self.architecture,
+            "has_nvidia_gpu": self.has_nvidia_gpu,
+            "use_gpu": self.use_gpu,
+        }
+
+    def is_gpu_compatible(self) -> bool:
+        """Check if the system is compatible with GPU acceleration."""
+        return self.has_nvidia_gpu and self.os_type != "darwin"
+
+
+class DynamicLibraryLoader:
+    """Dynamically load libraries based on system capabilities."""
+
+    def __init__(self, detector: SystemDetector):
+        self.detector = detector
+        self.use_gpu = detector.use_gpu
+
+        # Initialize attributes that will be set later
+        self.py = None
+        self.df = None
+        self.pd = None
+        self.np = None
+        self.cudf = None
+        self.cp = None
+
+        # Import libraries based on system capabilities
+        self._import_libraries()
+
+        # Dynamic settings based on hardware
+        self.normalize_vectors = self.use_gpu  # Only normalize for GPU
+        self.metric_type = "IP" if self.use_gpu else "COSINE"
+
+        logger.info("Library Configuration:")
+        logger.info(" Using GPU acceleration: %s", self.use_gpu)
+        logger.info(" Vector normalization: %s", self.normalize_vectors)
+        logger.info(" Metric type: %s", self.metric_type)
+
+    def _import_libraries(self):
+        """Dynamically import libraries based on system capabilities."""
+        # Set base libraries
+        self.pd = pd
+        self.np = np
+
+        # Conditionally import GPU libraries
+        if self.detector.use_gpu:
+            if CUDF_AVAILABLE:
+                self.cudf = cudf
+                self.cp = cp
+                self.py = cp  # Use cupy for array operations
+                self.df = cudf  # Use cudf for dataframes
+                logger.info("Successfully imported GPU libraries (cudf, cupy)")
+            else:
+                logger.error("cudf or cupy not found. Falling back to CPU mode.")
+                self.detector.use_gpu = False
+                self.use_gpu = False
+                self._setup_cpu_mode()
+        else:
+            self._setup_cpu_mode()
+
+    def _setup_cpu_mode(self):
+        """Setup CPU mode with numpy and pandas."""
+        self.py = self.np  # Use numpy for array operations
+        self.df = self.pd  # Use pandas for dataframes
+        self.normalize_vectors = False
+        self.metric_type = "COSINE"
+        logger.info("Using CPU mode with numpy and pandas")
+
+    def normalize_matrix(self, matrix, axis: int = 1):
+        """Normalize matrix using appropriate library."""
+        if not self.normalize_vectors:
+            return matrix
+
+        if self.use_gpu:
+            # Use cupy for GPU
+            matrix_cp = self.cp.asarray(matrix).astype(self.cp.float32)
+            norms = self.cp.linalg.norm(matrix_cp, axis=axis, keepdims=True)
+            return matrix_cp / norms
+        # CPU mode doesn't normalize for COSINE similarity
+        return matrix
+
+    def to_list(self, data):
+        """Convert data to list format."""
+        if hasattr(data, "tolist"):
+            return data.tolist()
+        if hasattr(data, "to_arrow"):
+            return data.to_arrow().to_pylist()
+        return list(data)
+
+
 class MultimodalPCSTPruning(NamedTuple):
     """
     Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
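The two helper classes added in this hunk replace the old module-level `py`/`df` aliases: `SystemDetector` probes the platform and runs `nvidia-smi`, while `DynamicLibraryLoader` exposes the chosen array/dataframe namespaces together with the matching Milvus metric type. A minimal usage sketch, not part of the package, with the import path inferred from the file's location:

```python
from aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst import (
    DynamicLibraryLoader,
    SystemDetector,
)

detector = SystemDetector()              # platform.system() + nvidia-smi probe
loader = DynamicLibraryLoader(detector)  # cupy/cudf on GPU, numpy/pandas otherwise

xp = loader.py                           # array namespace (cupy or numpy)
scores = xp.zeros(5, dtype=xp.float32)

print(loader.metric_type)        # "IP" with GPU acceleration, "COSINE" on CPU
print(loader.normalize_vectors)  # vectors are only normalized in GPU mode
print(loader.to_list(scores))    # handles numpy, cupy, and cudf objects alike
```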
@@ -37,7 +169,11 @@ class MultimodalPCSTPruning(NamedTuple):
         num_clusters: The number of clusters.
         pruning: The pruning strategy to use.
         verbosity_level: The verbosity level.
+        use_description: Whether to use description embeddings.
+        metric_type: The similarity metric type (dynamic based on hardware).
+        loader: The dynamic library loader instance.
     """
+
     topk: int = 3
     topk_e: int = 3
     cost_e: float = 0.5
@@ -47,7 +183,8 @@ class MultimodalPCSTPruning(NamedTuple):
     pruning: str = "gw"
     verbosity_level: int = 0
     use_description: bool = False
-    metric_type: str = "IP"  # Inner Product
+    metric_type: str = None  # Will be set dynamically
+    loader: DynamicLibraryLoader = None
 
     def prepare_collections(self, cfg: dict, modality: str) -> dict:
         """
@@ -81,11 +218,9 @@ class MultimodalPCSTPruning(NamedTuple):
 
         return colls
 
-    def _compute_node_prizes(self,
-                             query_emb: list,
-                             colls: dict) -> dict:
+    def _compute_node_prizes(self, query_emb: list, colls: dict) -> dict:
         """
-        Compute the node prizes based on the cosine similarity between the query and nodes.
+        Compute the node prizes based on the similarity between the query and nodes.
 
         Args:
             query_emb: The query embedding. This can be an embedding of
@@ -95,79 +230,91 @@ class MultimodalPCSTPruning(NamedTuple):
         Returns:
             The prizes of the nodes.
         """
-        # Intialize several variables
+        # Initialize several variables
         topk = min(self.topk, colls["nodes"].num_entities)
-        n_prizes = py.zeros(colls["nodes"].num_entities, dtype=py.float32)
+        n_prizes = self.loader.py.zeros(
+            colls["nodes"].num_entities, dtype=self.loader.py.float32
+        )
+
+        # Get the actual metric type to use
+        actual_metric_type = self.metric_type or self.loader.metric_type
 
-        # Calculate cosine similarity for text features and update the score
+        # Calculate similarity for text features and update the score
         if self.use_description:
             # Search the collection with the text embedding
             res = colls["nodes"].search(
                 data=[query_emb],
                 anns_field="desc_emb",
-                param={"metric_type": self.metric_type},
+                param={"metric_type": actual_metric_type},
                 limit=topk,
-                output_fields=["node_id"])
+                output_fields=["node_id"],
+            )
         else:
             # Search the collection with the query embedding
             res = colls["nodes_type"].search(
                 data=[query_emb],
                 anns_field="feat_emb",
-                param={"metric_type": self.metric_type},
+                param={"metric_type": actual_metric_type},
                 limit=topk,
-                output_fields=["node_id"])
+                output_fields=["node_id"],
+            )
 
         # Update the prizes based on the search results
-        n_prizes[[r.id for r in res[0]]] = py.arange(topk, 0, -1).astype(py.float32)
+        n_prizes[[r.id for r in res[0]]] = self.loader.py.arange(topk, 0, -1).astype(
+            self.loader.py.float32
+        )
 
         return n_prizes
 
-    def _compute_edge_prizes(self,
-                             text_emb: list,
-                             colls: dict) -> py.ndarray:
+    def _compute_edge_prizes(self, text_emb: list, colls: dict):
         """
-        Compute the node prizes based on the cosine similarity between the query and nodes.
+        Compute the edge prizes based on the similarity between the query and edges.
 
         Args:
             text_emb: The textual description embedding.
             colls: The collections of nodes, node-type specific nodes, and edges in Milvus.
 
         Returns:
-            The prizes of the nodes.
+            The prizes of the edges.
         """
-        # Intialize several variables
+        # Initialize several variables
         topk_e = min(self.topk_e, colls["edges"].num_entities)
-        e_prizes = py.zeros(colls["edges"].num_entities, dtype=py.float32)
+        e_prizes = self.loader.py.zeros(
+            colls["edges"].num_entities, dtype=self.loader.py.float32
+        )
+
+        # Get the actual metric type to use
+        actual_metric_type = self.metric_type or self.loader.metric_type
 
         # Search the collection with the query embedding
         res = colls["edges"].search(
             data=[text_emb],
             anns_field="feat_emb",
-            param={"metric_type": self.metric_type},
-            limit=topk_e,  # Only retrieve the top-k edges
-            # limit=colls["edges"].num_entities,
-            output_fields=["head_id", "tail_id"])
+            param={"metric_type": actual_metric_type},
+            limit=topk_e,  # Only retrieve the top-k edges
+            output_fields=["head_id", "tail_id"],
+        )
 
         # Update the prizes based on the search results
         e_prizes[[r.id for r in res[0]]] = [r.score for r in res[0]]
 
         # Further process the edge_prizes
-        unique_prizes, inverse_indices = py.unique(e_prizes, return_inverse=True)
-        topk_e_values = unique_prizes[py.argsort(-unique_prizes)[:topk_e]]
-        # e_prizes[e_prizes < topk_e_values[-1]] = 0.0
+        unique_prizes, inverse_indices = self.loader.py.unique(
+            e_prizes, return_inverse=True
+        )
+        topk_e_values = unique_prizes[self.loader.py.argsort(-unique_prizes)[:topk_e]]
         last_topk_e_value = topk_e
         for k in range(topk_e):
-            indices = inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
+            indices = (
+                inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
+            )
             value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
             e_prizes[indices] = value
             last_topk_e_value = value * (1 - self.c_const)
 
         return e_prizes
 
-    def compute_prizes(self,
-                       text_emb: list,
-                       query_emb: list,
-                       colls: dict) -> dict:
+    def compute_prizes(self, text_emb: list, query_emb: list, colls: dict) -> dict:
         """
         Compute the node prizes based on the cosine similarity between the query and nodes,
         as well as the edge prizes based on the cosine similarity between the query and edges.
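The rescaling loop in `_compute_edge_prizes` spreads rank-based prizes over tied similarity scores. A small CPU-only illustration of the same arithmetic; the scores are toy values and the `c_const` used here is an arbitrary choice, not necessarily the class default:

```python
import numpy as np

topk_e, c_const = 3, 0.01
e_prizes = np.array([0.0, 0.9, 0.9, 0.7, 0.2], dtype=np.float32)

# Same steps as the diff above, with numpy standing in for self.loader.py
unique_prizes, inverse_indices = np.unique(e_prizes, return_inverse=True)
topk_e_values = unique_prizes[np.argsort(-unique_prizes)[:topk_e]]
last_topk_e_value = topk_e
for k in range(topk_e):
    indices = inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
    value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
    e_prizes[indices] = value
    last_topk_e_value = value * (1 - c_const)

print(e_prizes)  # the two edges tied at 0.9 share the rank-1 prize; lower ranks shrink
```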
@@ -193,10 +340,7 @@ class MultimodalPCSTPruning(NamedTuple):
 
         return {"nodes": n_prizes, "edges": e_prizes}
 
-    def compute_subgraph_costs(self,
-                               edge_index: py.ndarray,
-                               num_nodes: int,
-                               prizes: dict) -> Tuple[py.ndarray, py.ndarray, py.ndarray]:
+    def compute_subgraph_costs(self, edge_index, num_nodes: int, prizes: dict):
         """
         Compute the costs in constructing the subgraph proposed by G-Retriever paper.
 
@@ -218,7 +362,7 @@ class MultimodalPCSTPruning(NamedTuple):
         # Update edge cost threshold
         updated_cost_e = min(
             self.cost_e,
-            py.max(prizes["edges"]).item() * (1 - self.c_const / 2),
+            self.loader.py.max(prizes["edges"]).item() * (1 - self.c_const / 2),
         )
 
         # Masks for real and virtual edges
@@ -228,19 +372,21 @@ class MultimodalPCSTPruning(NamedTuple):
 
         # Real edge indices
         logger.log(logging.INFO, "Computing real edges")
-        real_["indices"] = py.nonzero(real_["mask"])[0]
+        real_["indices"] = self.loader.py.nonzero(real_["mask"])[0]
         real_["src"] = edge_index[0][real_["indices"]]
         real_["dst"] = edge_index[1][real_["indices"]]
-        real_["edges"] = py.stack([real_["src"], real_["dst"]], axis=1)
+        real_["edges"] = self.loader.py.stack([real_["src"], real_["dst"]], axis=1)
         real_["costs"] = updated_cost_e - prizes["edges"][real_["indices"]]
 
         # Edge index mapping: local real edge idx -> original global index
         logger.log(logging.INFO, "Creating mapping for real edges")
-        mapping_edges = dict(zip(range(len(real_["indices"])), real_["indices"].tolist()))
+        mapping_edges = dict(
+            zip(range(len(real_["indices"])), self.loader.to_list(real_["indices"]))
+        )
 
         # Virtual edge handling
         logger.log(logging.INFO, "Computing virtual edges")
-        virt_["indices"] = py.nonzero(virt_["mask"])[0]
+        virt_["indices"] = self.loader.py.nonzero(virt_["mask"])[0]
         virt_["src"] = edge_index[0][virt_["indices"]]
         virt_["dst"] = edge_index[1][virt_["indices"]]
         virt_["prizes"] = prizes["edges"][virt_["indices"]] - updated_cost_e
@@ -248,28 +394,42 @@ class MultimodalPCSTPruning(NamedTuple):
         # Generate virtual node IDs
         logger.log(logging.INFO, "Generating virtual node IDs")
         virt_["num"] = virt_["indices"].shape[0]
-        virt_["node_ids"] = py.arange(num_nodes, num_nodes + virt_["num"])
+        virt_["node_ids"] = self.loader.py.arange(num_nodes, num_nodes + virt_["num"])
 
         # Virtual edges: (src → virtual), (virtual → dst)
         logger.log(logging.INFO, "Creating virtual edges")
-        virt_["edges_1"] = py.stack([virt_["src"], virt_["node_ids"]], axis=1)
-        virt_["edges_2"] = py.stack([virt_["node_ids"], virt_["dst"]], axis=1)
-        virt_["edges"] = py.concatenate([virt_["edges_1"],
-                                         virt_["edges_2"]], axis=0)
-        virt_["costs"] = py.zeros((virt_["edges"].shape[0],), dtype=real_["costs"].dtype)
+        virt_["edges_1"] = self.loader.py.stack(
+            [virt_["src"], virt_["node_ids"]], axis=1
+        )
+        virt_["edges_2"] = self.loader.py.stack(
+            [virt_["node_ids"], virt_["dst"]], axis=1
+        )
+        virt_["edges"] = self.loader.py.concatenate(
+            [virt_["edges_1"], virt_["edges_2"]], axis=0
+        )
+        virt_["costs"] = self.loader.py.zeros(
+            (virt_["edges"].shape[0],), dtype=real_["costs"].dtype
+        )
 
         # Combine real and virtual edges/costs
         logger.log(logging.INFO, "Combining real and virtual edges/costs")
-        all_edges = py.concatenate([real_["edges"], virt_["edges"]], axis=0)
-        all_costs = py.concatenate([real_["costs"], virt_["costs"]], axis=0)
+        all_edges = self.loader.py.concatenate([real_["edges"], virt_["edges"]], axis=0)
+        all_costs = self.loader.py.concatenate([real_["costs"], virt_["costs"]], axis=0)
 
         # Final prizes
         logger.log(logging.INFO, "Getting final prizes")
-        final_prizes = py.concatenate([prizes["nodes"], virt_["prizes"]], axis=0)
+        final_prizes = self.loader.py.concatenate(
+            [prizes["nodes"], virt_["prizes"]], axis=0
+        )
 
         # Mapping virtual node ID -> edge index in original graph
         logger.log(logging.INFO, "Creating mapping for virtual nodes")
-        mapping_nodes = dict(zip(virt_["node_ids"].tolist(), virt_["indices"].tolist()))
+        mapping_nodes = dict(
+            zip(
+                self.loader.to_list(virt_["node_ids"]),
+                self.loader.to_list(virt_["indices"]),
+            )
+        )
 
         # Build return values
         logger.log(logging.INFO, "Building return values")
@@ -284,11 +444,9 @@ class MultimodalPCSTPruning(NamedTuple):
 
         return edges_dict, final_prizes, all_costs, mapping
 
-    def get_subgraph_nodes_edges(self,
-                                 num_nodes: int,
-                                 vertices: py.ndarray,
-                                 edges_dict: dict,
-                                 mapping: dict) -> dict:
+    def get_subgraph_nodes_edges(
+        self, num_nodes: int, vertices, edges_dict: dict, mapping: dict
+    ) -> dict:
         """
         Get the selected nodes and edges of the subgraph based on the vertices and edges computed
         by the PCST algorithm.
@@ -305,31 +463,26 @@ class MultimodalPCSTPruning(NamedTuple):
         # Get edges information
         edges = edges_dict["edges"]
         num_prior_edges = edges_dict["num_prior_edges"]
-        # Get edges information
-        edges = edges_dict["edges"]
-        num_prior_edges = edges_dict["num_prior_edges"]
+
         # Retrieve the selected nodes and edges based on the given vertices and edges
         subgraph_nodes = vertices[vertices < num_nodes]
-        subgraph_edges = [mapping["edges"][e.item()] for e in edges if e < num_prior_edges]
+        subgraph_edges = [
+            mapping["edges"][e.item()] for e in edges if e < num_prior_edges
+        ]
         virtual_vertices = vertices[vertices >= num_nodes]
         if len(virtual_vertices) > 0:
-            virtual_vertices = vertices[vertices >= num_nodes]
             virtual_edges = [mapping["nodes"][i.item()] for i in virtual_vertices]
-            subgraph_edges = py.array(subgraph_edges + virtual_edges)
+            subgraph_edges = self.loader.py.array(subgraph_edges + virtual_edges)
         edge_index = edges_dict["edge_index"][:, subgraph_edges]
-        subgraph_nodes = py.unique(
-            py.concatenate(
-                [subgraph_nodes, edge_index[0], edge_index[1]]
-            )
+        subgraph_nodes = self.loader.py.unique(
+            self.loader.py.concatenate([subgraph_nodes, edge_index[0], edge_index[1]])
         )
 
         return {"nodes": subgraph_nodes, "edges": subgraph_edges}
 
-    def extract_subgraph(self,
-                         text_emb: list,
-                         query_emb: list,
-                         modality: str,
-                         cfg: dict) -> dict:
+    def extract_subgraph(
+        self, text_emb: list, query_emb: list, modality: str, cfg: dict
+    ) -> dict:
         """
         Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.
 
@@ -352,7 +505,7 @@ class MultimodalPCSTPruning(NamedTuple):
         logger.log(logging.INFO, "Loading cache edge index")
         with open(cfg.milvus_db.cache_edge_index_path, "rb") as f:
             edge_index = pickle.load(f)
-        edge_index = py.array(edge_index)
+        edge_index = self.loader.py.array(edge_index)
 
         # Assert the topk and topk_e values for subgraph retrieval
         assert self.topk > 0, "topk must be greater than or equal to 0"
@@ -365,7 +518,8 @@ class MultimodalPCSTPruning(NamedTuple):
         # Compute costs in constructing the subgraph
         logger.log(logging.INFO, "compute_subgraph_costs")
         edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
-            edge_index, colls["nodes"].num_entities, prizes)
+            edge_index, colls["nodes"].num_entities, prizes
+        )
 
         # Retrieve the subgraph using the PCST algorithm
         logger.log(logging.INFO, "Running PCST algorithm")
@@ -383,11 +537,14 @@ class MultimodalPCSTPruning(NamedTuple):
         logger.log(logging.INFO, "Getting subgraph nodes and edges")
         subgraph = self.get_subgraph_nodes_edges(
             colls["nodes"].num_entities,
-            py.asarray(result_vertices),
-            {"edges": py.asarray(result_edges),
-             "num_prior_edges": edges_dict["num_prior_edges"],
-             "edge_index": edge_index},
-            mapping)
+            self.loader.py.asarray(result_vertices),
+            {
+                "edges": self.loader.py.asarray(result_edges),
+                "num_prior_edges": edges_dict["num_prior_edges"],
+                "edge_index": edge_index,
+            },
+            mapping,
+        )
         print(subgraph)
 
         return subgraph
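With the loader threaded through every step, the refactored pipeline (cached edge index → prizes → costs → `pcst_fast` → node/edge selection) can be driven roughly as below. This is a hedged sketch rather than package code: the modality label is an assumption, and `cfg` stands for the Hydra config whose `milvus_db.cache_edge_index_path` field appears in this diff.

```python
from aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst import (
    DynamicLibraryLoader,
    MultimodalPCSTPruning,
    SystemDetector,
)


def run_multimodal_pcst(text_emb: list, query_emb: list, cfg) -> dict:
    """Sketch: extract one subgraph under the detected hardware mode."""
    loader = DynamicLibraryLoader(SystemDetector())
    pcst = MultimodalPCSTPruning(topk=5, topk_e=5, loader=loader)
    # "gene/protein" is an illustrative modality label, not a value taken from the diff
    return pcst.extract_subgraph(text_emb, query_emb, modality="gene/protein", cfg=cfg)
```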
aiagents4pharma/talk2scholars/agents/paper_download_agent.py
@@ -13,9 +13,8 @@ from langgraph.prebuilt.chat_agent_executor import create_react_agent
 from langgraph.prebuilt.tool_node import ToolNode
 from langgraph.checkpoint.memory import MemorySaver
 from ..state.state_talk2scholars import Talk2Scholars
-from ..tools.paper_download.download_arxiv_input import download_arxiv_paper
-from ..tools.paper_download.download_medrxiv_input import download_medrxiv_paper
-from ..tools.paper_download.download_biorxiv_input import download_biorxiv_paper
+from ..tools.paper_download.paper_downloader import download_papers
+
 
 # Initialize logger
 logging.basicConfig(level=logging.INFO)
@@ -52,7 +51,11 @@ def get_app(uniq_id, llm_model: BaseChatModel):
         cfg = cfg.agents.talk2scholars.paper_download_agent
 
     # Define tools properly
-    tools = ToolNode([download_arxiv_paper, download_medrxiv_paper, download_biorxiv_paper])
+    tools = ToolNode(
+        [
+            download_papers,
+        ]
+    )
 
     # Define the model
     logger.info("Using OpenAI model %s", llm_model)