aiagents4pharma 1.42.0__py3-none-any.whl → 1.44.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +17 -2
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +618 -413
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +362 -25
- aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +146 -109
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +240 -83
- aiagents4pharma/talk2scholars/agents/paper_download_agent.py +7 -4
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +49 -95
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +15 -1
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +16 -2
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +40 -5
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +15 -5
- aiagents4pharma/talk2scholars/configs/config.yaml +1 -3
- aiagents4pharma/talk2scholars/configs/tools/paper_download/default.yaml +124 -0
- aiagents4pharma/talk2scholars/tests/test_arxiv_downloader.py +478 -0
- aiagents4pharma/talk2scholars/tests/test_base_paper_downloader.py +620 -0
- aiagents4pharma/talk2scholars/tests/test_biorxiv_downloader.py +697 -0
- aiagents4pharma/talk2scholars/tests/test_medrxiv_downloader.py +534 -0
- aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +22 -12
- aiagents4pharma/talk2scholars/tests/test_paper_downloader.py +545 -0
- aiagents4pharma/talk2scholars/tests/test_pubmed_downloader.py +1067 -0
- aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +2 -4
- aiagents4pharma/talk2scholars/tools/paper_download/paper_downloader.py +457 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/__init__.py +20 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/arxiv_downloader.py +209 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/base_paper_downloader.py +343 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/biorxiv_downloader.py +321 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/medrxiv_downloader.py +198 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/pubmed_downloader.py +337 -0
- aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +97 -45
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +47 -29
- {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/METADATA +3 -1
- {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/RECORD +36 -33
- aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/default.yaml +0 -4
- aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/__init__.py +0 -3
- aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/default.yaml +0 -2
- aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/__init__.py +0 -3
- aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/default.yaml +0 -2
- aiagents4pharma/talk2scholars/tests/test_paper_download_biorxiv.py +0 -151
- aiagents4pharma/talk2scholars/tests/test_paper_download_medrxiv.py +0 -151
- aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +0 -249
- aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +0 -177
- aiagents4pharma/talk2scholars/tools/paper_download/download_biorxiv_input.py +0 -114
- aiagents4pharma/talk2scholars/tools/paper_download/download_medrxiv_input.py +0 -114
- /aiagents4pharma/talk2scholars/configs/tools/{download_arxiv_paper → paper_download}/__init__.py +0 -0
- {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.42.0.dist-info → aiagents4pharma-1.44.0.dist-info}/top_level.txt +0 -0
@@ -2,27 +2,25 @@
|
|
2
2
|
Tool for performing multimodal subgraph extraction.
|
3
3
|
"""
|
4
4
|
|
5
|
-
# import datetime
|
6
|
-
from typing import Type, Annotated
|
7
5
|
import logging
|
6
|
+
from typing import Annotated, Type
|
7
|
+
|
8
8
|
import hydra
|
9
9
|
import pandas as pd
|
10
|
-
from pydantic import BaseModel, Field
|
11
|
-
from langchain_core.tools import BaseTool
|
12
10
|
from langchain_core.messages import ToolMessage
|
11
|
+
from langchain_core.tools import BaseTool
|
13
12
|
from langchain_core.tools.base import InjectedToolCallId
|
14
|
-
from langgraph.types import Command
|
15
13
|
from langgraph.prebuilt import InjectedState
|
14
|
+
from langgraph.types import Command
|
15
|
+
from pydantic import BaseModel, Field
|
16
16
|
from pymilvus import Collection
|
17
|
-
|
17
|
+
|
18
|
+
from ..utils.extractions.milvus_multimodal_pcst import (
|
19
|
+
DynamicLibraryLoader,
|
20
|
+
MultimodalPCSTPruning,
|
21
|
+
SystemDetector,
|
22
|
+
)
|
18
23
|
from .load_arguments import ArgumentData
|
19
|
-
try:
|
20
|
-
import cupy as py
|
21
|
-
import cudf
|
22
|
-
df = cudf
|
23
|
-
except ImportError:
|
24
|
-
import numpy as py
|
25
|
-
df = pd
|
26
24
|
|
27
25
|
# Initialize logger
|
28
26
|
logging.basicConfig(level=logging.INFO)
|
@@ -61,8 +59,16 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
61
59
|
description: str = "A tool for subgraph extraction based on user's prompt."
|
62
60
|
args_schema: Type[BaseModel] = MultimodalSubgraphExtractionInput
|
63
61
|
|
62
|
+
def __init__(self, **kwargs):
|
63
|
+
super().__init__(**kwargs)
|
64
|
+
# Initialize hardware detection and dynamic library loading
|
65
|
+
object.__setattr__(self, 'detector', SystemDetector())
|
66
|
+
object.__setattr__(self, 'loader', DynamicLibraryLoader(self.detector))
|
67
|
+
logger.info("MultimodalSubgraphExtractionTool initialized with %s mode",
|
68
|
+
"GPU" if self.loader.use_gpu else "CPU")
|
69
|
+
|
64
70
|
def _read_multimodal_files(self,
|
65
|
-
state: Annotated[dict, InjectedState])
|
71
|
+
state: Annotated[dict, InjectedState]):
|
66
72
|
"""
|
67
73
|
Read the uploaded multimodal files and return a DataFrame.
|
68
74
|
|
@@ -72,7 +78,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
72
78
|
Returns:
|
73
79
|
A DataFrame containing the multimodal files.
|
74
80
|
"""
|
75
|
-
multimodal_df = df.DataFrame({"name": [], "node_type": []})
|
81
|
+
multimodal_df = self.loader.df.DataFrame({"name": [], "node_type": []})
|
76
82
|
|
77
83
|
# Loop over the uploaded files and find multimodal files
|
78
84
|
logger.log(logging.INFO, "Looping over uploaded files")
|
@@ -90,7 +96,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
90
96
|
logger.log(logging.INFO, "Preparing multimodal_df")
|
91
97
|
# Merge all obtained dataframes into a single dataframe
|
92
98
|
multimodal_df = pd.concat(multimodal_df).reset_index()
|
93
|
-
multimodal_df = df.DataFrame(multimodal_df)
|
99
|
+
multimodal_df = self.loader.df.DataFrame(multimodal_df)
|
94
100
|
multimodal_df.drop(columns=["level_1"], inplace=True)
|
95
101
|
multimodal_df.rename(columns={"level_0": "q_node_type",
|
96
102
|
"name": "q_node_name"}, inplace=True)
|
@@ -100,10 +106,39 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
100
106
|
|
101
107
|
return multimodal_df
|
102
108
|
|
109
|
+
def _query_milvus_collection(self, node_type, node_type_df, cfg_db):
|
110
|
+
"""Helper method to query Milvus collection for a specific node type."""
|
111
|
+
# Load the collection
|
112
|
+
collection = Collection(
|
113
|
+
name=f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"
|
114
|
+
)
|
115
|
+
collection.load()
|
116
|
+
|
117
|
+
# Query the collection with node names from multimodal_df
|
118
|
+
node_names_series = node_type_df['q_node_name']
|
119
|
+
q_node_names = getattr(node_names_series,
|
120
|
+
"to_pandas",
|
121
|
+
lambda series=node_names_series: series)().tolist()
|
122
|
+
q_columns = ["node_id", "node_name", "node_type",
|
123
|
+
"feat", "feat_emb", "desc", "desc_emb"]
|
124
|
+
res = collection.query(
|
125
|
+
expr=f'node_name IN [{','.join(f'"{name}"' for name in q_node_names)}]',
|
126
|
+
output_fields=q_columns,
|
127
|
+
)
|
128
|
+
# Convert the embeedings into floats
|
129
|
+
for r_ in res:
|
130
|
+
r_['feat_emb'] = [float(x) for x in r_['feat_emb']]
|
131
|
+
r_['desc_emb'] = [float(x) for x in r_['desc_emb']]
|
132
|
+
|
133
|
+
# Convert the result to a DataFrame
|
134
|
+
res_df = self.loader.df.DataFrame(res)[q_columns]
|
135
|
+
res_df["use_description"] = False
|
136
|
+
return res_df
|
137
|
+
|
103
138
|
def _prepare_query_modalities(self,
|
104
139
|
prompt: dict,
|
105
140
|
state: Annotated[dict, InjectedState],
|
106
|
-
cfg_db: dict)
|
141
|
+
cfg_db: dict):
|
107
142
|
"""
|
108
143
|
Prepare the modality-specific query for subgraph extraction.
|
109
144
|
|
@@ -118,7 +153,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
118
153
|
# Initialize dataframes
|
119
154
|
logger.log(logging.INFO, "Initializing dataframes")
|
120
155
|
query_df = []
|
121
|
-
prompt_df = df.DataFrame({
|
156
|
+
prompt_df = self.loader.df.DataFrame({
|
122
157
|
'node_id': 'user_prompt',
|
123
158
|
'node_name': 'User Prompt',
|
124
159
|
'node_type': 'prompt',
|
@@ -139,38 +174,12 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
139
174
|
logger.log(logging.INFO, "Querying Milvus database for each node type in multimodal_df")
|
140
175
|
for node_type, node_type_df in multimodal_df.groupby("q_node_type"):
|
141
176
|
print(f"Processing node type: {node_type}")
|
142
|
-
|
143
|
-
# Load the collection
|
144
|
-
collection = Collection(
|
145
|
-
name=f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"
|
146
|
-
)
|
147
|
-
collection.load()
|
148
|
-
|
149
|
-
# Query the collection with node names from multimodal_df
|
150
|
-
q_node_names = getattr(node_type_df['q_node_name'],
|
151
|
-
"to_pandas",
|
152
|
-
lambda: node_type_df['q_node_name'])().tolist()
|
153
|
-
q_columns = ["node_id", "node_name", "node_type",
|
154
|
-
"feat", "feat_emb", "desc", "desc_emb"]
|
155
|
-
res = collection.query(
|
156
|
-
expr=f'node_name IN [{','.join(f'"{name}"' for name in q_node_names)}]',
|
157
|
-
output_fields=q_columns,
|
158
|
-
)
|
159
|
-
# Convert the embeedings into floats
|
160
|
-
for r_ in res:
|
161
|
-
r_['feat_emb'] = [float(x) for x in r_['feat_emb']]
|
162
|
-
r_['desc_emb'] = [float(x) for x in r_['desc_emb']]
|
163
|
-
|
164
|
-
# Convert the result to a DataFrame
|
165
|
-
res_df = df.DataFrame(res)[q_columns]
|
166
|
-
res_df["use_description"] = False
|
167
|
-
|
168
|
-
# Append the results to query_df
|
177
|
+
res_df = self._query_milvus_collection(node_type, node_type_df, cfg_db)
|
169
178
|
query_df.append(res_df)
|
170
179
|
|
171
180
|
# Concatenate all results into a single DataFrame
|
172
181
|
logger.log(logging.INFO, "Concatenating all results into a single DataFrame")
|
173
|
-
query_df = df.concat(query_df, ignore_index=True)
|
182
|
+
query_df = self.loader.df.concat(query_df, ignore_index=True)
|
174
183
|
|
175
184
|
# Update the state by adding the the selected node IDs
|
176
185
|
logger.log(logging.INFO, "Updating state with selected node IDs")
|
@@ -182,7 +191,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
182
191
|
|
183
192
|
# Append a user prompt to the query dataframe
|
184
193
|
logger.log(logging.INFO, "Adding user prompt to query dataframe")
|
185
|
-
query_df = df.concat([query_df, prompt_df]).reset_index(drop=True)
|
194
|
+
query_df = self.loader.df.concat([query_df, prompt_df]).reset_index(drop=True)
|
186
195
|
else:
|
187
196
|
# If no multimodal files are uploaded, use the prompt embeddings
|
188
197
|
query_df = prompt_df
|
@@ -223,6 +232,19 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
223
232
|
# Prepare the PCSTPruning object and extract the subgraph
|
224
233
|
# Parameters were set in the configuration file obtained from Hydra
|
225
234
|
# start = datetime.datetime.now()
|
235
|
+
# Get dynamic metric type (overrides any config setting)
|
236
|
+
# Get dynamic metric type (overrides any config setting)
|
237
|
+
has_vector_processing = hasattr(cfg, 'vector_processing')
|
238
|
+
if has_vector_processing:
|
239
|
+
dynamic_metrics_enabled = getattr(cfg.vector_processing, 'dynamic_metrics', True)
|
240
|
+
else:
|
241
|
+
dynamic_metrics_enabled = False
|
242
|
+
if has_vector_processing and dynamic_metrics_enabled:
|
243
|
+
dynamic_metric_type = self.loader.metric_type
|
244
|
+
else:
|
245
|
+
dynamic_metric_type = getattr(cfg, 'search_metric_type',
|
246
|
+
self.loader.metric_type)
|
247
|
+
|
226
248
|
subgraph = MultimodalPCSTPruning(
|
227
249
|
topk=state["topk_nodes"],
|
228
250
|
topk_e=state["topk_edges"],
|
@@ -233,7 +255,8 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
233
255
|
pruning=cfg.pruning,
|
234
256
|
verbosity_level=cfg.verbosity_level,
|
235
257
|
use_description=q[1]['use_description'],
|
236
|
-
metric_type=
|
258
|
+
metric_type=dynamic_metric_type, # Use dynamic or config metric type
|
259
|
+
loader=self.loader # Pass the loader instance
|
237
260
|
).extract_subgraph(q[1]['desc_emb'],
|
238
261
|
q[1]['feat_emb'],
|
239
262
|
q[1]['node_type'],
|
@@ -251,22 +274,24 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
251
274
|
# (end - start).total_seconds())
|
252
275
|
|
253
276
|
# Concatenate and get unique node and edge indices
|
254
|
-
unified_subgraph["nodes"]
|
255
|
-
|
277
|
+
nodes_arrays = [self.loader.py.array(list_) for list_ in unified_subgraph["nodes"]]
|
278
|
+
unified_subgraph["nodes"] = self.loader.py.unique(
|
279
|
+
self.loader.py.concatenate(nodes_arrays)
|
256
280
|
).tolist()
|
257
|
-
unified_subgraph["edges"]
|
258
|
-
|
281
|
+
edges_arrays = [self.loader.py.array(list_) for list_ in unified_subgraph["edges"]]
|
282
|
+
unified_subgraph["edges"] = self.loader.py.unique(
|
283
|
+
self.loader.py.concatenate(edges_arrays)
|
259
284
|
).tolist()
|
260
285
|
|
261
|
-
# Convert the unified subgraph and subgraphs to
|
262
|
-
unified_subgraph = df.DataFrame([("Unified Subgraph",
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
subgraphs = df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])
|
286
|
+
# Convert the unified subgraph and subgraphs to DataFrames
|
287
|
+
unified_subgraph = self.loader.df.DataFrame([("Unified Subgraph",
|
288
|
+
unified_subgraph["nodes"],
|
289
|
+
unified_subgraph["edges"])],
|
290
|
+
columns=["name", "nodes", "edges"])
|
291
|
+
subgraphs = self.loader.df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])
|
267
292
|
|
268
|
-
#
|
269
|
-
subgraphs = df.concat([unified_subgraph, subgraphs], ignore_index=True)
|
293
|
+
# Concatenate both DataFrames
|
294
|
+
subgraphs = self.loader.df.concat([unified_subgraph, subgraphs], ignore_index=True)
|
270
295
|
|
271
296
|
return subgraphs
|
272
297
|
|
@@ -288,10 +313,10 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
288
313
|
Returns:
|
289
314
|
A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
|
290
315
|
"""
|
291
|
-
# Convert the dict to a
|
316
|
+
# Convert the dict to a DataFrame
|
292
317
|
node_colors = {n: cfg.node_colors_dict[k]
|
293
318
|
for k, v in state["selections"].items() for n in v}
|
294
|
-
color_df = df.DataFrame(list(node_colors.items()), columns=["node_id", "color"])
|
319
|
+
color_df = self.loader.df.DataFrame(list(node_colors.items()), columns=["node_id", "color"])
|
295
320
|
# print(color_df)
|
296
321
|
|
297
322
|
# Prepare the subgraph dictionary
|
@@ -302,42 +327,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
302
327
|
"text": ""
|
303
328
|
}
|
304
329
|
for sub in getattr(subgraph, "to_pandas", lambda: subgraph)().itertuples(index=False):
|
305
|
-
|
306
|
-
print(f"Processing subgraph: {sub.name}")
|
307
|
-
print('---')
|
308
|
-
print(sub.nodes)
|
309
|
-
print('---')
|
310
|
-
print(sub.edges)
|
311
|
-
print('---')
|
312
|
-
|
313
|
-
# Prepare graph dataframes
|
314
|
-
# Nodes
|
315
|
-
coll_name = f"{cfg_db.milvus_db.database_name}_nodes"
|
316
|
-
node_coll = Collection(name=coll_name)
|
317
|
-
node_coll.load()
|
318
|
-
graph_nodes = node_coll.query(
|
319
|
-
expr=f'node_index IN [{",".join(f"{n}" for n in sub.nodes)}]',
|
320
|
-
output_fields=['node_id', 'node_name', 'node_type', 'desc']
|
321
|
-
)
|
322
|
-
graph_nodes = df.DataFrame(graph_nodes)
|
323
|
-
graph_nodes.drop(columns=['node_index'], inplace=True)
|
324
|
-
if not color_df.empty:
|
325
|
-
# Merge the color dataframe with the graph nodes
|
326
|
-
graph_nodes = graph_nodes.merge(color_df, on="node_id", how="left")
|
327
|
-
else:
|
328
|
-
graph_nodes["color"] = 'black' # Default color
|
329
|
-
graph_nodes['color'].fillna('black', inplace=True) # Fill NaN colors with black
|
330
|
-
# Edges
|
331
|
-
coll_name = f"{cfg_db.milvus_db.database_name}_edges"
|
332
|
-
edge_coll = Collection(name=coll_name)
|
333
|
-
edge_coll.load()
|
334
|
-
graph_edges = edge_coll.query(
|
335
|
-
expr=f'triplet_index IN [{",".join(f"{e}" for e in sub.edges)}]',
|
336
|
-
output_fields=['head_id', 'tail_id', 'edge_type']
|
337
|
-
)
|
338
|
-
graph_edges = df.DataFrame(graph_edges)
|
339
|
-
graph_edges.drop(columns=['triplet_index'], inplace=True)
|
340
|
-
graph_edges['edge_type'] = graph_edges['edge_type'].str.split('|')
|
330
|
+
graph_nodes, graph_edges = self._process_subgraph_data(sub, cfg_db, color_df)
|
341
331
|
|
342
332
|
# Prepare lists for visualization
|
343
333
|
graph_dict["name"].append(sub.name)
|
@@ -350,32 +340,75 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
350
340
|
'color': row.color})
|
351
341
|
for row in getattr(graph_nodes,
|
352
342
|
"to_pandas",
|
353
|
-
lambda: graph_nodes)()
|
343
|
+
lambda graph_nodes=graph_nodes: graph_nodes)()
|
344
|
+
.itertuples(index=False)])
|
354
345
|
graph_dict["edges"].append([(
|
355
346
|
row.head_id,
|
356
347
|
row.tail_id,
|
357
348
|
{'label': tuple(row.edge_type)})
|
358
349
|
for row in getattr(graph_edges,
|
359
350
|
"to_pandas",
|
360
|
-
lambda: graph_edges)()
|
351
|
+
lambda graph_edges=graph_edges: graph_edges)()
|
352
|
+
.itertuples(index=False)])
|
361
353
|
|
362
354
|
# Prepare the textualized subgraph
|
363
355
|
if sub.name == "Unified Subgraph":
|
364
356
|
graph_nodes = graph_nodes[['node_id', 'desc']]
|
365
357
|
graph_nodes.rename(columns={'desc': 'node_attr'}, inplace=True)
|
366
358
|
graph_edges = graph_edges[['head_id', 'edge_type', 'tail_id']]
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
359
|
+
nodes_pandas = getattr(graph_nodes, "to_pandas",
|
360
|
+
lambda graph_nodes=graph_nodes: graph_nodes)()
|
361
|
+
nodes_csv = nodes_pandas.to_csv(index=False)
|
362
|
+
edges_pandas = getattr(graph_edges, "to_pandas",
|
363
|
+
lambda graph_edges=graph_edges: graph_edges)()
|
364
|
+
edges_csv = edges_pandas.to_csv(index=False)
|
365
|
+
graph_dict["text"] = nodes_csv + "\n" + edges_csv
|
372
366
|
|
373
367
|
return graph_dict
|
374
368
|
|
369
|
+
def _process_subgraph_data(self, sub, cfg_db, color_df):
|
370
|
+
"""Helper method to process individual subgraph data."""
|
371
|
+
print(f"Processing subgraph: {sub.name}")
|
372
|
+
print('---')
|
373
|
+
print(sub.nodes)
|
374
|
+
print('---')
|
375
|
+
print(sub.edges)
|
376
|
+
print('---')
|
377
|
+
|
378
|
+
# Prepare graph dataframes - Nodes
|
379
|
+
coll_name = f"{cfg_db.milvus_db.database_name}_nodes"
|
380
|
+
node_coll = Collection(name=coll_name)
|
381
|
+
node_coll.load()
|
382
|
+
graph_nodes = node_coll.query(
|
383
|
+
expr=f'node_index IN [{",".join(f"{n}" for n in sub.nodes)}]',
|
384
|
+
output_fields=['node_id', 'node_name', 'node_type', 'desc']
|
385
|
+
)
|
386
|
+
graph_nodes = self.loader.df.DataFrame(graph_nodes)
|
387
|
+
graph_nodes.drop(columns=['node_index'], inplace=True)
|
388
|
+
if not color_df.empty:
|
389
|
+
graph_nodes = graph_nodes.merge(color_df, on="node_id", how="left")
|
390
|
+
else:
|
391
|
+
graph_nodes["color"] = 'black'
|
392
|
+
graph_nodes['color'] = graph_nodes['color'].fillna('black')
|
393
|
+
|
394
|
+
# Edges
|
395
|
+
coll_name = f"{cfg_db.milvus_db.database_name}_edges"
|
396
|
+
edge_coll = Collection(name=coll_name)
|
397
|
+
edge_coll.load()
|
398
|
+
graph_edges = edge_coll.query(
|
399
|
+
expr=f'triplet_index IN [{",".join(f"{e}" for e in sub.edges)}]',
|
400
|
+
output_fields=['head_id', 'tail_id', 'edge_type']
|
401
|
+
)
|
402
|
+
graph_edges = self.loader.df.DataFrame(graph_edges)
|
403
|
+
graph_edges.drop(columns=['triplet_index'], inplace=True)
|
404
|
+
graph_edges['edge_type'] = graph_edges['edge_type'].str.split('|')
|
405
|
+
|
406
|
+
return graph_nodes, graph_edges
|
407
|
+
|
375
408
|
def normalize_vector(self,
|
376
409
|
v : list) -> list:
|
377
410
|
"""
|
378
|
-
Normalize a vector using CuPy.
|
411
|
+
Normalize a vector using appropriate library (CuPy for GPU, NumPy for CPU).
|
379
412
|
|
380
413
|
Args:
|
381
414
|
v : Vector to normalize.
|
@@ -383,9 +416,13 @@ class MultimodalSubgraphExtractionTool(BaseTool):
|
|
383
416
|
Returns:
|
384
417
|
Normalized vector.
|
385
418
|
"""
|
386
|
-
|
387
|
-
|
388
|
-
|
419
|
+
if self.loader.normalize_vectors:
|
420
|
+
# GPU mode: normalize the vector
|
421
|
+
v_array = self.loader.py.asarray(v)
|
422
|
+
norm = self.loader.py.linalg.norm(v_array)
|
423
|
+
return (v_array / norm).tolist()
|
424
|
+
# CPU mode: return as-is for COSINE similarity
|
425
|
+
return v
|
389
426
|
|
390
427
|
def _run(
|
391
428
|
self,
|