aiagents4pharma 1.41.0__py3-none-any.whl → 1.43.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +1 -1
- aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +37 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/ols_terms/default.yaml +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/reactome_pathways/default.yaml +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/uniprot_proteins/default.yaml +6 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/pubchem_utils/default.yaml +5 -0
- aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +752 -350
- aiagents4pharma/talk2scholars/agents/paper_download_agent.py +7 -4
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +49 -95
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +15 -1
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +16 -2
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +40 -5
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +15 -5
- aiagents4pharma/talk2scholars/configs/config.yaml +1 -3
- aiagents4pharma/talk2scholars/configs/tools/paper_download/default.yaml +124 -0
- aiagents4pharma/talk2scholars/tests/test_arxiv_downloader.py +478 -0
- aiagents4pharma/talk2scholars/tests/test_base_paper_downloader.py +620 -0
- aiagents4pharma/talk2scholars/tests/test_biorxiv_downloader.py +697 -0
- aiagents4pharma/talk2scholars/tests/test_medrxiv_downloader.py +534 -0
- aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +22 -12
- aiagents4pharma/talk2scholars/tests/test_paper_downloader.py +545 -0
- aiagents4pharma/talk2scholars/tests/test_pubmed_downloader.py +1067 -0
- aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +2 -4
- aiagents4pharma/talk2scholars/tools/paper_download/paper_downloader.py +457 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/__init__.py +20 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/arxiv_downloader.py +209 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/base_paper_downloader.py +343 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/biorxiv_downloader.py +321 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/medrxiv_downloader.py +198 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/pubmed_downloader.py +337 -0
- aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +97 -45
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +47 -29
- {aiagents4pharma-1.41.0.dist-info → aiagents4pharma-1.43.0.dist-info}/METADATA +30 -14
- {aiagents4pharma-1.41.0.dist-info → aiagents4pharma-1.43.0.dist-info}/RECORD +38 -30
- aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/default.yaml +0 -4
- aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/__init__.py +0 -3
- aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/default.yaml +0 -2
- aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/__init__.py +0 -3
- aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/default.yaml +0 -2
- aiagents4pharma/talk2scholars/tests/test_paper_download_biorxiv.py +0 -151
- aiagents4pharma/talk2scholars/tests/test_paper_download_medrxiv.py +0 -151
- aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +0 -249
- aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +0 -177
- aiagents4pharma/talk2scholars/tools/paper_download/download_biorxiv_input.py +0 -114
- aiagents4pharma/talk2scholars/tools/paper_download/download_medrxiv_input.py +0 -114
- /aiagents4pharma/talk2scholars/configs/tools/{download_arxiv_paper → paper_download}/__init__.py +0 -0
- {aiagents4pharma-1.41.0.dist-info → aiagents4pharma-1.43.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.41.0.dist-info → aiagents4pharma-1.43.0.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.41.0.dist-info → aiagents4pharma-1.43.0.dist-info}/top_level.txt +0 -0
@@ -1,438 +1,796 @@
|
|
1
|
-
# pylint: disable=wrong-import-position
|
2
1
|
#!/usr/bin/env python3
|
2
|
+
# pylint: skip-file
|
3
3
|
"""
|
4
|
-
|
5
|
-
|
4
|
+
Dynamic Cross-Platform PrimeKG Multimodal Data Loader for Milvus Database.
|
5
|
+
Automatically detects system capabilities and chooses appropriate libraries and configurations.
|
6
|
+
|
7
|
+
Supports:
|
8
|
+
- Windows, Linux, macOS
|
9
|
+
- CPU-only mode (pandas/numpy)
|
10
|
+
- NVIDIA GPU mode (cudf/cupy)
|
11
|
+
- Dynamic index selection based on hardware
|
12
|
+
- Automatic dependency installation
|
6
13
|
"""
|
7
14
|
|
8
|
-
import os
|
9
|
-
import sys
|
10
|
-
import subprocess
|
11
15
|
import glob
|
12
16
|
import logging
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
"pip install --extra-index-url=https://pypi.nvidia.com cudf-cu12",
|
19
|
-
"pip install --extra-index-url=https://pypi.nvidia.com dask-cudf-cu12",
|
20
|
-
"pip install pymilvus==2.5.11",
|
21
|
-
"pip install numpy==1.26.4",
|
22
|
-
"pip install pandas==2.1.3",
|
23
|
-
"pip install tqdm==4.67.1",
|
24
|
-
]
|
25
|
-
|
26
|
-
print("[DATA LOADER] Installing required packages...")
|
27
|
-
for package_cmd in packages:
|
28
|
-
print(f"[DATA LOADER] Running: {package_cmd}")
|
29
|
-
result = subprocess.run(package_cmd.split(), capture_output=True, text=True, check=True)
|
30
|
-
if result.returncode != 0:
|
31
|
-
print(f"[DATA LOADER] Error installing package: {result.stderr}")
|
32
|
-
sys.exit(1)
|
33
|
-
print("[DATA LOADER] All packages installed successfully!")
|
34
|
-
|
35
|
-
# Install packages first
|
36
|
-
install_packages()
|
37
|
-
|
38
|
-
try:
|
39
|
-
import cudf
|
40
|
-
import cupy as cp
|
41
|
-
except ImportError as e:
|
42
|
-
print("[DATA LOADER] cudf or cupy not found. Please ensure they are installed correctly.")
|
43
|
-
sys.exit(1)
|
44
|
-
|
45
|
-
from pymilvus import (
|
46
|
-
db,
|
47
|
-
connections,
|
48
|
-
FieldSchema,
|
49
|
-
CollectionSchema,
|
50
|
-
DataType,
|
51
|
-
Collection,
|
52
|
-
utility
|
53
|
-
)
|
54
|
-
from tqdm import tqdm
|
17
|
+
import os
|
18
|
+
import platform
|
19
|
+
import subprocess
|
20
|
+
import sys
|
21
|
+
from typing import Any, Dict, List, Optional, Union
|
55
22
|
|
56
23
|
# Configure logging
|
57
|
-
logging.basicConfig(level=logging.INFO, format=
|
24
|
+
logging.basicConfig(level=logging.INFO, format="[DATA LOADER] %(message)s")
|
58
25
|
logger = logging.getLogger(__name__)
|
59
26
|
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
27
|
+
|
28
|
+
class SystemDetector:
|
29
|
+
"""Detect system capabilities and choose appropriate libraries."""
|
30
|
+
|
31
|
+
def __init__(self):
|
32
|
+
self.os_type = platform.system().lower() # 'windows', 'linux', 'darwin'
|
33
|
+
self.architecture = platform.machine().lower() # 'x86_64', 'arm64', etc.
|
34
|
+
self.has_nvidia_gpu = self._detect_nvidia_gpu()
|
35
|
+
self.use_gpu = (
|
36
|
+
self.has_nvidia_gpu and self.os_type != "darwin"
|
37
|
+
) # No CUDA on macOS
|
38
|
+
|
39
|
+
logger.info("System Detection Results:")
|
40
|
+
logger.info(" OS: %s", self.os_type)
|
41
|
+
logger.info(" Architecture: %s", self.architecture)
|
42
|
+
logger.info(" NVIDIA GPU detected: %s", self.has_nvidia_gpu)
|
43
|
+
logger.info(" Will use GPU acceleration: %s", self.use_gpu)
|
44
|
+
|
45
|
+
def _detect_nvidia_gpu(self) -> bool:
|
46
|
+
"""Detect if NVIDIA GPU is available."""
|
47
|
+
try:
|
48
|
+
# Try nvidia-smi command
|
49
|
+
result = subprocess.run(
|
50
|
+
["nvidia-smi"], capture_output=True, text=True, timeout=10
|
51
|
+
)
|
52
|
+
return result.returncode == 0
|
53
|
+
except (
|
54
|
+
subprocess.TimeoutExpired,
|
55
|
+
FileNotFoundError,
|
56
|
+
subprocess.SubprocessError,
|
57
|
+
):
|
58
|
+
return False
|
59
|
+
|
60
|
+
def get_required_packages(self) -> List[str]:
|
61
|
+
"""Get list of packages to install based on system capabilities - matches original logic."""
|
62
|
+
if self.use_gpu and self.os_type == "linux":
|
63
|
+
# Exact package list from original script for GPU mode
|
64
|
+
packages = [
|
65
|
+
# "pip install --extra-index-url=https://pypi.nvidia.com cudf-cu12",
|
66
|
+
# "pip install --extra-index-url=https://pypi.nvidia.com dask-cudf-cu12",
|
67
|
+
"pip install pymilvus==2.5.11",
|
68
|
+
"pip install numpy==1.26.4",
|
69
|
+
"pip install pandas==2.1.3",
|
70
|
+
"pip install tqdm==4.67.1",
|
71
|
+
]
|
72
|
+
return packages
|
73
|
+
else:
|
74
|
+
# CPU-only packages
|
75
|
+
packages = [
|
76
|
+
"pip install pymilvus==2.5.11",
|
77
|
+
"pip install numpy==1.26.4",
|
78
|
+
"pip install pandas==2.1.3",
|
79
|
+
"pip install tqdm==4.67.1",
|
80
|
+
]
|
81
|
+
return packages
|
82
|
+
|
83
|
+
def install_packages(self):
|
84
|
+
"""Install required packages using original script's exact logic."""
|
85
|
+
packages = self.get_required_packages()
|
86
|
+
|
87
|
+
logger.info(
|
88
|
+
"Installing packages for %s system%s",
|
89
|
+
self.os_type,
|
90
|
+
" with GPU support" if self.use_gpu else "",
|
91
|
+
)
|
92
|
+
|
93
|
+
for package_cmd in packages:
|
94
|
+
logger.info("Running: %s", package_cmd)
|
95
|
+
try:
|
96
|
+
result = subprocess.run(
|
97
|
+
package_cmd.split(),
|
98
|
+
capture_output=True,
|
99
|
+
text=True,
|
100
|
+
check=True,
|
101
|
+
timeout=300,
|
102
|
+
)
|
103
|
+
if result.returncode != 0:
|
104
|
+
logger.error("Error installing package: %s", result.stderr)
|
105
|
+
if "cudf" in package_cmd or "dask-cudf" in package_cmd:
|
106
|
+
logger.warning(
|
107
|
+
"GPU package installation failed, falling back to CPU mode"
|
108
|
+
)
|
109
|
+
self.use_gpu = False
|
110
|
+
return self.install_packages() # Retry with CPU packages
|
111
|
+
else:
|
112
|
+
sys.exit(1)
|
113
|
+
else:
|
114
|
+
logger.info("Successfully installed: %s", package_cmd.split()[-1])
|
115
|
+
except subprocess.CalledProcessError as e:
|
116
|
+
logger.error("Failed to install %s: %s", package_cmd, e.stderr)
|
117
|
+
if "cudf" in package_cmd:
|
118
|
+
logger.warning(
|
119
|
+
"GPU package installation failed, falling back to CPU mode"
|
120
|
+
)
|
121
|
+
self.use_gpu = False
|
122
|
+
return self.install_packages() # Retry with CPU packages
|
123
|
+
else:
|
124
|
+
raise
|
125
|
+
except subprocess.TimeoutExpired:
|
126
|
+
logger.error("Installation timeout for package: %s", package_cmd)
|
127
|
+
raise
|
128
|
+
|
129
|
+
|
130
|
+
class DynamicDataLoader:
|
131
|
+
"""Dynamic data loader that adapts to system capabilities."""
|
132
|
+
|
64
133
|
def __init__(self, config: Dict[str, Any]):
|
65
|
-
"""Initialize
|
134
|
+
"""Initialize with system detection and dynamic library loading."""
|
66
135
|
self.config = config
|
67
|
-
self.
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
self.
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
136
|
+
self.detector = SystemDetector()
|
137
|
+
|
138
|
+
# Install packages if needed
|
139
|
+
if config.get("auto_install_packages", True):
|
140
|
+
self.detector.install_packages()
|
141
|
+
|
142
|
+
# Import libraries based on system capabilities
|
143
|
+
self._import_libraries()
|
144
|
+
|
145
|
+
# Configuration - exact original parameters
|
146
|
+
self.milvus_host = config.get("milvus_host", "localhost")
|
147
|
+
self.milvus_port = config.get("milvus_port", "19530")
|
148
|
+
self.milvus_user = config.get("milvus_user", "root")
|
149
|
+
self.milvus_password = config.get("milvus_password", "Milvus")
|
150
|
+
self.milvus_database = config.get("milvus_database", "t2kg_primekg")
|
151
|
+
self.data_dir = config.get("data_dir", "./data")
|
152
|
+
self.batch_size = config.get("batch_size", 500)
|
153
|
+
self.chunk_size = config.get("chunk_size", 5) # Original chunk_size parameter
|
154
|
+
|
155
|
+
# Dynamic settings based on hardware
|
156
|
+
self.use_gpu = self.detector.use_gpu
|
157
|
+
self.normalize_vectors = self.use_gpu # Only normalize for GPU (original logic)
|
158
|
+
self.vector_index_type = "GPU_CAGRA" if self.use_gpu else "HNSW"
|
159
|
+
self.metric_type = "IP" if self.use_gpu else "COSINE"
|
160
|
+
|
161
|
+
logger.info("Loader Configuration:")
|
162
|
+
logger.info(" Using GPU acceleration: %s", self.use_gpu)
|
163
|
+
logger.info(" Vector normalization: %s", self.normalize_vectors)
|
164
|
+
logger.info(" Vector index type: %s", self.vector_index_type)
|
165
|
+
logger.info(" Metric type: %s", self.metric_type)
|
166
|
+
logger.info(" Data directory: %s", self.data_dir)
|
167
|
+
logger.info(" Batch size: %s", self.batch_size)
|
168
|
+
logger.info(" Chunk size: %s", self.chunk_size)
|
169
|
+
|
170
|
+
def _import_libraries(self):
|
171
|
+
"""Dynamically import libraries - matches original script's import logic."""
|
172
|
+
# Always import base libraries
|
173
|
+
import numpy as np
|
174
|
+
import pandas as pd
|
175
|
+
from pymilvus import (
|
176
|
+
Collection,
|
177
|
+
CollectionSchema,
|
178
|
+
DataType,
|
179
|
+
FieldSchema,
|
180
|
+
connections,
|
181
|
+
db,
|
182
|
+
utility,
|
183
|
+
)
|
184
|
+
from tqdm import tqdm
|
185
|
+
|
186
|
+
self.pd = pd
|
187
|
+
self.np = np
|
188
|
+
self.tqdm = tqdm
|
189
|
+
self.pymilvus_modules = {
|
190
|
+
"db": db,
|
191
|
+
"connections": connections,
|
192
|
+
"FieldSchema": FieldSchema,
|
193
|
+
"CollectionSchema": CollectionSchema,
|
194
|
+
"DataType": DataType,
|
195
|
+
"Collection": Collection,
|
196
|
+
"utility": utility,
|
197
|
+
}
|
198
|
+
|
199
|
+
# Conditionally import GPU libraries - matches original error handling
|
200
|
+
if self.detector.use_gpu:
|
201
|
+
try:
|
202
|
+
import cudf # pyright: ignore
|
203
|
+
import cupy as cp # pyright: ignore
|
204
|
+
|
205
|
+
self.cudf = cudf
|
206
|
+
self.cp = cp
|
207
|
+
logger.info("Successfully imported GPU libraries (cudf, cupy)")
|
208
|
+
except ImportError as e:
|
209
|
+
logger.error(
|
210
|
+
"[DATA LOADER] cudf or cupy not found. Please ensure they are installed correctly."
|
211
|
+
)
|
212
|
+
logger.error("Import error: %s", str(e))
|
213
|
+
# Match original script's exit behavior for critical GPU import failure
|
214
|
+
if not os.getenv("FORCE_CPU", "false").lower() == "true":
|
215
|
+
logger.error(
|
216
|
+
"GPU libraries required but not available. Set FORCE_CPU=true to use CPU mode."
|
217
|
+
)
|
218
|
+
sys.exit(1)
|
219
|
+
else:
|
220
|
+
logger.warning("Falling back to CPU mode due to FORCE_CPU=true")
|
221
|
+
self.detector.use_gpu = False
|
222
|
+
self.use_gpu = False
|
223
|
+
|
224
|
+
def _read_dataframe(
|
225
|
+
self, file_path: str, columns: Optional[List[str]] = None
|
226
|
+
) -> Union["pd.DataFrame", "cudf.DataFrame"]: # type: ignore[reportUndefinedVariable] # noqa: F821
|
227
|
+
"""Read dataframe using appropriate library."""
|
228
|
+
if self.use_gpu:
|
229
|
+
return self.cudf.read_parquet(file_path, columns=columns)
|
230
|
+
else:
|
231
|
+
return self.pd.read_parquet(file_path, columns=columns)
|
232
|
+
|
233
|
+
def _concat_dataframes(
|
234
|
+
self, df_list: List, ignore_index: bool = True
|
235
|
+
) -> Union["pd.DataFrame", "cudf.DataFrame"]: # type: ignore[reportUndefinedVariable] # noqa: F821
|
236
|
+
"""Concatenate dataframes using appropriate library."""
|
237
|
+
if self.use_gpu:
|
238
|
+
return self.cudf.concat(df_list, ignore_index=ignore_index)
|
239
|
+
else:
|
240
|
+
return self.pd.concat(df_list, ignore_index=ignore_index)
|
241
|
+
|
242
|
+
def _normalize_matrix(self, matrix, axis: int = 1):
|
243
|
+
"""Normalize matrix using appropriate library."""
|
244
|
+
if not self.normalize_vectors:
|
245
|
+
return matrix
|
246
|
+
|
247
|
+
if self.use_gpu:
|
248
|
+
# Use cupy for GPU
|
249
|
+
matrix_cp = self.cp.asarray(matrix).astype(self.cp.float32)
|
250
|
+
norms = self.cp.linalg.norm(matrix_cp, axis=axis, keepdims=True)
|
251
|
+
return matrix_cp / norms
|
252
|
+
else:
|
253
|
+
# Use numpy for CPU (but we don't normalize for CPU/COSINE)
|
254
|
+
return matrix
|
255
|
+
|
256
|
+
def _extract_embeddings(self, df, column_name: str):
|
257
|
+
"""Extract embeddings and convert to appropriate format."""
|
258
|
+
if self.use_gpu:
|
259
|
+
# cuDF list extraction
|
260
|
+
emb_data = self.cp.asarray(df[column_name].list.leaves).astype(
|
261
|
+
self.cp.float32
|
262
|
+
)
|
263
|
+
return emb_data.reshape(df.shape[0], -1)
|
264
|
+
else:
|
265
|
+
# pandas extraction
|
266
|
+
emb_list = []
|
267
|
+
for emb in df[column_name]:
|
268
|
+
if isinstance(emb, list):
|
269
|
+
emb_list.append(emb)
|
270
|
+
else:
|
271
|
+
emb_list.append(emb.tolist() if hasattr(emb, "tolist") else emb)
|
272
|
+
return self.np.array(emb_list, dtype=self.np.float32)
|
273
|
+
|
274
|
+
def _to_list(self, data):
|
275
|
+
"""Convert data to list format for Milvus insertion."""
|
276
|
+
if self.use_gpu:
|
277
|
+
# For cuDF data, use to_arrow().to_pylist()
|
278
|
+
if hasattr(data, "to_arrow"):
|
279
|
+
return data.to_arrow().to_pylist()
|
280
|
+
elif hasattr(data, "tolist"):
|
281
|
+
# Fallback for cupy arrays
|
282
|
+
return data.tolist()
|
283
|
+
else:
|
284
|
+
return list(data)
|
285
|
+
else:
|
286
|
+
# For pandas/numpy data
|
287
|
+
if hasattr(data, "tolist"):
|
288
|
+
return data.tolist()
|
289
|
+
elif hasattr(data, "to_arrow"):
|
290
|
+
return data.to_arrow().to_pylist()
|
291
|
+
else:
|
292
|
+
return list(data)
|
87
293
|
|
88
294
|
def connect_to_milvus(self):
|
89
295
|
"""Connect to Milvus and setup database."""
|
90
296
|
logger.info("Connecting to Milvus at %s:%s", self.milvus_host, self.milvus_port)
|
91
297
|
|
92
|
-
connections.connect(
|
298
|
+
self.pymilvus_modules["connections"].connect(
|
93
299
|
alias="default",
|
94
300
|
host=self.milvus_host,
|
95
301
|
port=self.milvus_port,
|
96
302
|
user=self.milvus_user,
|
97
|
-
password=self.milvus_password
|
303
|
+
password=self.milvus_password,
|
98
304
|
)
|
99
305
|
|
100
306
|
# Check if database exists, create if it doesn't
|
101
|
-
if self.milvus_database not in db.list_database():
|
307
|
+
if self.milvus_database not in self.pymilvus_modules["db"].list_database():
|
102
308
|
logger.info("Creating database: %s", self.milvus_database)
|
103
|
-
db.create_database(self.milvus_database)
|
309
|
+
self.pymilvus_modules["db"].create_database(self.milvus_database)
|
104
310
|
|
105
311
|
# Switch to the desired database
|
106
|
-
db.using_database(self.milvus_database)
|
312
|
+
self.pymilvus_modules["db"].using_database(self.milvus_database)
|
107
313
|
logger.info("Using database: %s", self.milvus_database)
|
108
314
|
|
109
315
|
def load_graph_data(self):
|
110
|
-
"""Load the
|
316
|
+
"""Load the parquet files containing graph data."""
|
111
317
|
logger.info("Loading graph data from: %s", self.data_dir)
|
112
318
|
|
113
319
|
if not os.path.exists(self.data_dir):
|
114
320
|
raise FileNotFoundError(f"Data directory not found: {self.data_dir}")
|
115
321
|
|
116
|
-
# Load dataframes containing nodes and edges
|
117
|
-
# Loop over nodes and edges
|
118
322
|
graph = {}
|
119
323
|
for element in ["nodes", "edges"]:
|
120
|
-
# Make an empty dictionary for each folder
|
121
324
|
graph[element] = {}
|
122
325
|
for stage in ["enrichment", "embedding"]:
|
123
|
-
|
124
|
-
|
125
|
-
file_list = glob.glob(
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
326
|
+
logger.info("Processing %s %s", element, stage)
|
327
|
+
|
328
|
+
file_list = glob.glob(
|
329
|
+
os.path.join(self.data_dir, element, stage, "*.parquet.gzip")
|
330
|
+
)
|
331
|
+
logger.info("Found %d files for %s %s", len(file_list), element, stage)
|
332
|
+
|
333
|
+
if not file_list:
|
334
|
+
logger.warning("No files found for %s %s", element, stage)
|
335
|
+
continue
|
336
|
+
|
337
|
+
# For edges embedding, process in chunks due to size
|
133
338
|
if element == "edges" and stage == "embedding":
|
134
|
-
# For edges embedding, only read two columns:
|
135
|
-
# triplet_index and edge_emb
|
136
|
-
# Loop by chunks
|
137
339
|
chunk_size = self.chunk_size
|
138
340
|
graph[element][stage] = []
|
139
341
|
for i in range(0, len(file_list), chunk_size):
|
140
|
-
chunk_files = file_list[i:i+chunk_size]
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
342
|
+
chunk_files = file_list[i : i + chunk_size]
|
343
|
+
chunk_df_list = []
|
344
|
+
for f in chunk_files:
|
345
|
+
df = self._read_dataframe(
|
346
|
+
f, columns=["triplet_index", "edge_emb"]
|
347
|
+
)
|
348
|
+
chunk_df_list.append(df)
|
349
|
+
chunk_df = self._concat_dataframes(
|
350
|
+
chunk_df_list, ignore_index=True
|
351
|
+
)
|
145
352
|
graph[element][stage].append(chunk_df)
|
146
353
|
else:
|
147
|
-
# For
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
354
|
+
# For other combinations, read all files
|
355
|
+
df_list = []
|
356
|
+
for f in file_list:
|
357
|
+
df = self._read_dataframe(f)
|
358
|
+
df_list.append(df)
|
359
|
+
graph[element][stage] = self._concat_dataframes(
|
360
|
+
df_list, ignore_index=True
|
361
|
+
)
|
154
362
|
|
155
363
|
logger.info("Graph data loaded successfully")
|
156
364
|
return graph
|
157
365
|
|
158
|
-
def
|
366
|
+
def _get_embedding_dimension(self, df, column_name: str) -> int:
|
367
|
+
"""Get embedding dimension using original script's exact logic."""
|
368
|
+
first_emb = df.iloc[0][column_name]
|
369
|
+
if self.use_gpu:
|
370
|
+
# cuDF format - matches original: len(nodes_df.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])
|
371
|
+
return len(first_emb.to_arrow().to_pylist()[0])
|
372
|
+
else:
|
373
|
+
# pandas format
|
374
|
+
if isinstance(first_emb, list):
|
375
|
+
return len(first_emb)
|
376
|
+
else:
|
377
|
+
return len(
|
378
|
+
first_emb.tolist() if hasattr(first_emb, "tolist") else first_emb
|
379
|
+
)
|
380
|
+
|
381
|
+
def create_nodes_collection(self, nodes_df):
|
159
382
|
"""Create and populate the main nodes collection."""
|
160
383
|
logger.info("Creating main nodes collection...")
|
161
384
|
node_coll_name = f"{self.milvus_database}_nodes"
|
162
385
|
|
386
|
+
# Get embedding dimension
|
387
|
+
emb_dim = self._get_embedding_dimension(nodes_df, "desc_emb")
|
388
|
+
|
163
389
|
node_fields = [
|
164
|
-
FieldSchema(
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
390
|
+
self.pymilvus_modules["FieldSchema"](
|
391
|
+
name="node_index",
|
392
|
+
dtype=self.pymilvus_modules["DataType"].INT64,
|
393
|
+
is_primary=True,
|
394
|
+
),
|
395
|
+
self.pymilvus_modules["FieldSchema"](
|
396
|
+
name="node_id",
|
397
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
398
|
+
max_length=1024,
|
399
|
+
),
|
400
|
+
self.pymilvus_modules["FieldSchema"](
|
401
|
+
name="node_name",
|
402
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
403
|
+
max_length=1024,
|
404
|
+
enable_analyzer=True,
|
405
|
+
enable_match=True,
|
406
|
+
),
|
407
|
+
self.pymilvus_modules["FieldSchema"](
|
408
|
+
name="node_type",
|
409
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
410
|
+
max_length=1024,
|
411
|
+
enable_analyzer=True,
|
412
|
+
enable_match=True,
|
413
|
+
),
|
414
|
+
self.pymilvus_modules["FieldSchema"](
|
415
|
+
name="desc",
|
416
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
417
|
+
max_length=40960,
|
418
|
+
enable_analyzer=True,
|
419
|
+
enable_match=True,
|
420
|
+
),
|
421
|
+
self.pymilvus_modules["FieldSchema"](
|
422
|
+
name="desc_emb",
|
423
|
+
dtype=self.pymilvus_modules["DataType"].FLOAT_VECTOR,
|
424
|
+
dim=emb_dim,
|
425
|
+
),
|
188
426
|
]
|
189
|
-
|
190
|
-
|
427
|
+
|
428
|
+
schema = self.pymilvus_modules["CollectionSchema"](
|
429
|
+
fields=node_fields, description=f"Schema for collection {node_coll_name}"
|
430
|
+
)
|
191
431
|
|
192
432
|
# Create collection if it doesn't exist
|
193
|
-
if not utility.has_collection(node_coll_name):
|
194
|
-
collection = Collection(
|
433
|
+
if not self.pymilvus_modules["utility"].has_collection(node_coll_name):
|
434
|
+
collection = self.pymilvus_modules["Collection"](
|
435
|
+
name=node_coll_name, schema=schema
|
436
|
+
)
|
195
437
|
else:
|
196
|
-
collection = Collection(name=node_coll_name)
|
197
|
-
|
198
|
-
# Create indexes
|
199
|
-
collection.create_index(
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
438
|
+
collection = self.pymilvus_modules["Collection"](name=node_coll_name)
|
439
|
+
|
440
|
+
# Create indexes with dynamic parameters
|
441
|
+
collection.create_index(
|
442
|
+
field_name="node_index",
|
443
|
+
index_params={"index_type": "STL_SORT"},
|
444
|
+
index_name="node_index_index",
|
445
|
+
)
|
446
|
+
collection.create_index(
|
447
|
+
field_name="node_name",
|
448
|
+
index_params={"index_type": "INVERTED"},
|
449
|
+
index_name="node_name_index",
|
450
|
+
)
|
451
|
+
collection.create_index(
|
452
|
+
field_name="node_type",
|
453
|
+
index_params={"index_type": "INVERTED"},
|
454
|
+
index_name="node_type_index",
|
455
|
+
)
|
456
|
+
collection.create_index(
|
457
|
+
field_name="desc",
|
458
|
+
index_params={"index_type": "INVERTED"},
|
459
|
+
index_name="desc_index",
|
460
|
+
)
|
461
|
+
collection.create_index(
|
462
|
+
field_name="desc_emb",
|
463
|
+
index_params={
|
464
|
+
"index_type": self.vector_index_type,
|
465
|
+
"metric_type": self.metric_type,
|
466
|
+
},
|
467
|
+
index_name="desc_emb_index",
|
468
|
+
)
|
215
469
|
|
216
470
|
# Prepare and insert data
|
217
|
-
|
218
|
-
|
219
|
-
|
471
|
+
desc_emb_data = self._extract_embeddings(nodes_df, "desc_emb")
|
472
|
+
desc_emb_normalized = self._normalize_matrix(desc_emb_data, axis=1)
|
473
|
+
|
220
474
|
data = [
|
221
|
-
nodes_df["node_index"]
|
222
|
-
nodes_df["node_id"]
|
223
|
-
nodes_df["node_name"]
|
224
|
-
nodes_df["node_type"]
|
225
|
-
nodes_df["desc"]
|
226
|
-
|
475
|
+
self._to_list(nodes_df["node_index"]),
|
476
|
+
self._to_list(nodes_df["node_id"]),
|
477
|
+
self._to_list(nodes_df["node_name"]),
|
478
|
+
self._to_list(nodes_df["node_type"]),
|
479
|
+
self._to_list(nodes_df["desc"]),
|
480
|
+
self._to_list(desc_emb_normalized),
|
227
481
|
]
|
228
482
|
|
229
483
|
# Insert data in batches
|
230
484
|
total = len(data[0])
|
231
|
-
for i in tqdm(range(0, total, self.batch_size), desc="Inserting nodes"):
|
232
|
-
batch = [col[i:i+self.batch_size] for col in data]
|
485
|
+
for i in self.tqdm(range(0, total, self.batch_size), desc="Inserting nodes"):
|
486
|
+
batch = [col[i : i + self.batch_size] for col in data]
|
233
487
|
collection.insert(batch)
|
234
488
|
|
235
489
|
collection.flush()
|
236
|
-
logger.info(
|
490
|
+
logger.info(
|
491
|
+
"Nodes collection created with %d entities", collection.num_entities
|
492
|
+
)
|
237
493
|
|
238
|
-
def create_node_type_collections(self, nodes_df
|
494
|
+
def create_node_type_collections(self, nodes_df):
|
239
495
|
"""Create separate collections for each node type."""
|
240
496
|
logger.info("Creating node type-specific collections...")
|
241
497
|
|
242
|
-
for node_type, nodes_df_ in tqdm(
|
243
|
-
|
244
|
-
|
498
|
+
for node_type, nodes_df_ in self.tqdm(
|
499
|
+
nodes_df.groupby("node_type"), desc="Processing node types"
|
500
|
+
):
|
501
|
+
node_coll_name = (
|
502
|
+
f"{self.milvus_database}_nodes_{node_type.replace('/', '_')}"
|
503
|
+
)
|
504
|
+
|
505
|
+
# Get embedding dimensions
|
506
|
+
desc_dim = self._get_embedding_dimension(nodes_df_, "desc_emb")
|
507
|
+
feat_dim = self._get_embedding_dimension(nodes_df_, "feat_emb")
|
245
508
|
|
246
509
|
node_fields = [
|
247
|
-
FieldSchema(
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
FieldSchema(
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
510
|
+
self.pymilvus_modules["FieldSchema"](
|
511
|
+
name="node_index",
|
512
|
+
dtype=self.pymilvus_modules["DataType"].INT64,
|
513
|
+
is_primary=True,
|
514
|
+
auto_id=False,
|
515
|
+
),
|
516
|
+
self.pymilvus_modules["FieldSchema"](
|
517
|
+
name="node_id",
|
518
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
519
|
+
max_length=1024,
|
520
|
+
),
|
521
|
+
self.pymilvus_modules["FieldSchema"](
|
522
|
+
name="node_name",
|
523
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
524
|
+
max_length=1024,
|
525
|
+
enable_analyzer=True,
|
526
|
+
enable_match=True,
|
527
|
+
),
|
528
|
+
self.pymilvus_modules["FieldSchema"](
|
529
|
+
name="node_type",
|
530
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
531
|
+
max_length=1024,
|
532
|
+
enable_analyzer=True,
|
533
|
+
enable_match=True,
|
534
|
+
),
|
535
|
+
self.pymilvus_modules["FieldSchema"](
|
536
|
+
name="desc",
|
537
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
538
|
+
max_length=40960,
|
539
|
+
enable_analyzer=True,
|
540
|
+
enable_match=True,
|
541
|
+
),
|
542
|
+
self.pymilvus_modules["FieldSchema"](
|
543
|
+
name="desc_emb",
|
544
|
+
dtype=self.pymilvus_modules["DataType"].FLOAT_VECTOR,
|
545
|
+
dim=desc_dim,
|
546
|
+
),
|
547
|
+
self.pymilvus_modules["FieldSchema"](
|
548
|
+
name="feat",
|
549
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
550
|
+
max_length=40960,
|
551
|
+
enable_analyzer=True,
|
552
|
+
enable_match=True,
|
553
|
+
),
|
554
|
+
self.pymilvus_modules["FieldSchema"](
|
555
|
+
name="feat_emb",
|
556
|
+
dtype=self.pymilvus_modules["DataType"].FLOAT_VECTOR,
|
557
|
+
dim=feat_dim,
|
558
|
+
),
|
280
559
|
]
|
281
|
-
schema = CollectionSchema(fields=node_fields,
|
282
|
-
description=f"schema for collection {node_coll_name}")
|
283
560
|
|
284
|
-
|
285
|
-
|
561
|
+
schema = self.pymilvus_modules["CollectionSchema"](
|
562
|
+
fields=node_fields,
|
563
|
+
description=f"schema for collection {node_coll_name}",
|
564
|
+
)
|
565
|
+
|
566
|
+
if not self.pymilvus_modules["utility"].has_collection(node_coll_name):
|
567
|
+
collection = self.pymilvus_modules["Collection"](
|
568
|
+
name=node_coll_name, schema=schema
|
569
|
+
)
|
286
570
|
else:
|
287
|
-
collection = Collection(name=node_coll_name)
|
288
|
-
|
289
|
-
# Create indexes
|
290
|
-
collection.create_index(
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
571
|
+
collection = self.pymilvus_modules["Collection"](name=node_coll_name)
|
572
|
+
|
573
|
+
# Create indexes with dynamic parameters
|
574
|
+
collection.create_index(
|
575
|
+
field_name="node_index",
|
576
|
+
index_params={"index_type": "STL_SORT"},
|
577
|
+
index_name="node_index_index",
|
578
|
+
)
|
579
|
+
collection.create_index(
|
580
|
+
field_name="node_name",
|
581
|
+
index_params={"index_type": "INVERTED"},
|
582
|
+
index_name="node_name_index",
|
583
|
+
)
|
584
|
+
collection.create_index(
|
585
|
+
field_name="node_type",
|
586
|
+
index_params={"index_type": "INVERTED"},
|
587
|
+
index_name="node_type_index",
|
588
|
+
)
|
589
|
+
collection.create_index(
|
590
|
+
field_name="desc",
|
591
|
+
index_params={"index_type": "INVERTED"},
|
592
|
+
index_name="desc_index",
|
593
|
+
)
|
594
|
+
collection.create_index(
|
595
|
+
field_name="desc_emb",
|
596
|
+
index_params={
|
597
|
+
"index_type": self.vector_index_type,
|
598
|
+
"metric_type": self.metric_type,
|
599
|
+
},
|
600
|
+
index_name="desc_emb_index",
|
601
|
+
)
|
602
|
+
collection.create_index(
|
603
|
+
field_name="feat_emb",
|
604
|
+
index_params={
|
605
|
+
"index_type": self.vector_index_type,
|
606
|
+
"metric_type": self.metric_type,
|
607
|
+
},
|
608
|
+
index_name="feat_emb_index",
|
609
|
+
)
|
310
610
|
|
311
611
|
# Prepare data
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
612
|
+
desc_emb_data = self._extract_embeddings(nodes_df_, "desc_emb")
|
613
|
+
feat_emb_data = self._extract_embeddings(nodes_df_, "feat_emb")
|
614
|
+
|
615
|
+
desc_emb_normalized = self._normalize_matrix(desc_emb_data, axis=1)
|
616
|
+
feat_emb_normalized = self._normalize_matrix(feat_emb_data, axis=1)
|
617
|
+
|
318
618
|
data = [
|
319
|
-
nodes_df_["node_index"]
|
320
|
-
nodes_df_["node_id"]
|
321
|
-
nodes_df_["node_name"]
|
322
|
-
nodes_df_["node_type"]
|
323
|
-
nodes_df_["desc"]
|
324
|
-
|
325
|
-
nodes_df_["feat"]
|
326
|
-
|
619
|
+
self._to_list(nodes_df_["node_index"]),
|
620
|
+
self._to_list(nodes_df_["node_id"]),
|
621
|
+
self._to_list(nodes_df_["node_name"]),
|
622
|
+
self._to_list(nodes_df_["node_type"]),
|
623
|
+
self._to_list(nodes_df_["desc"]),
|
624
|
+
self._to_list(desc_emb_normalized),
|
625
|
+
self._to_list(nodes_df_["feat"]),
|
626
|
+
self._to_list(feat_emb_normalized),
|
327
627
|
]
|
328
628
|
|
329
629
|
# Insert data in batches
|
330
630
|
total_rows = len(data[0])
|
331
631
|
for i in range(0, total_rows, self.batch_size):
|
332
|
-
batch = [col[i:i + self.batch_size] for col in data]
|
632
|
+
batch = [col[i : i + self.batch_size] for col in data]
|
333
633
|
collection.insert(batch)
|
334
634
|
|
335
635
|
collection.flush()
|
336
|
-
logger.info(
|
337
|
-
|
636
|
+
logger.info(
|
637
|
+
"Collection %s created with %d entities",
|
638
|
+
node_coll_name,
|
639
|
+
collection.num_entities,
|
640
|
+
)
|
338
641
|
|
339
|
-
def create_edges_collection(self,
|
340
|
-
|
341
|
-
edges_embedding_df: List[cudf.DataFrame]):
|
342
|
-
"""Create and populate the edges collection."""
|
642
|
+
def create_edges_collection(self, edges_enrichment_df, edges_embedding_df: List):
|
643
|
+
"""Create and populate the edges collection - exact original logic."""
|
343
644
|
logger.info("Creating edges collection...")
|
344
645
|
|
345
646
|
edge_coll_name = f"{self.milvus_database}_edges"
|
346
647
|
|
648
|
+
# Get embedding dimension from first chunk - exact original logic
|
649
|
+
if self.use_gpu:
|
650
|
+
emb_dim = len(
|
651
|
+
edges_embedding_df[0].loc[0, "edge_emb"]
|
652
|
+
) # Original cudf access
|
653
|
+
else:
|
654
|
+
first_edge_emb = edges_embedding_df[0].iloc[0]["edge_emb"]
|
655
|
+
emb_dim = (
|
656
|
+
len(first_edge_emb)
|
657
|
+
if isinstance(first_edge_emb, list)
|
658
|
+
else len(first_edge_emb.tolist())
|
659
|
+
)
|
660
|
+
|
347
661
|
edge_fields = [
|
348
|
-
FieldSchema(
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
FieldSchema(
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
662
|
+
self.pymilvus_modules["FieldSchema"](
|
663
|
+
name="triplet_index",
|
664
|
+
dtype=self.pymilvus_modules["DataType"].INT64,
|
665
|
+
is_primary=True,
|
666
|
+
auto_id=False,
|
667
|
+
),
|
668
|
+
self.pymilvus_modules["FieldSchema"](
|
669
|
+
name="head_id",
|
670
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
671
|
+
max_length=1024,
|
672
|
+
),
|
673
|
+
self.pymilvus_modules["FieldSchema"](
|
674
|
+
name="head_index", dtype=self.pymilvus_modules["DataType"].INT64
|
675
|
+
),
|
676
|
+
self.pymilvus_modules["FieldSchema"](
|
677
|
+
name="tail_id",
|
678
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
679
|
+
max_length=1024,
|
680
|
+
),
|
681
|
+
self.pymilvus_modules["FieldSchema"](
|
682
|
+
name="tail_index", dtype=self.pymilvus_modules["DataType"].INT64
|
683
|
+
),
|
684
|
+
self.pymilvus_modules["FieldSchema"](
|
685
|
+
name="edge_type",
|
686
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
687
|
+
max_length=1024,
|
688
|
+
),
|
689
|
+
self.pymilvus_modules["FieldSchema"](
|
690
|
+
name="display_relation",
|
691
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
692
|
+
max_length=1024,
|
693
|
+
),
|
694
|
+
self.pymilvus_modules["FieldSchema"](
|
695
|
+
name="feat",
|
696
|
+
dtype=self.pymilvus_modules["DataType"].VARCHAR,
|
697
|
+
max_length=40960,
|
698
|
+
),
|
699
|
+
self.pymilvus_modules["FieldSchema"](
|
700
|
+
name="feat_emb",
|
701
|
+
dtype=self.pymilvus_modules["DataType"].FLOAT_VECTOR,
|
702
|
+
dim=emb_dim,
|
703
|
+
),
|
374
704
|
]
|
375
|
-
edge_schema = CollectionSchema(fields=edge_fields,
|
376
|
-
description="Schema for edges collection")
|
377
705
|
|
378
|
-
|
379
|
-
|
706
|
+
edge_schema = self.pymilvus_modules["CollectionSchema"](
|
707
|
+
fields=edge_fields, description="Schema for edges collection"
|
708
|
+
)
|
709
|
+
|
710
|
+
if not self.pymilvus_modules["utility"].has_collection(edge_coll_name):
|
711
|
+
collection = self.pymilvus_modules["Collection"](
|
712
|
+
name=edge_coll_name, schema=edge_schema
|
713
|
+
)
|
380
714
|
else:
|
381
|
-
collection = Collection(name=edge_coll_name)
|
382
|
-
|
383
|
-
# Create indexes
|
384
|
-
collection.create_index(
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
715
|
+
collection = self.pymilvus_modules["Collection"](name=edge_coll_name)
|
716
|
+
|
717
|
+
# Create indexes with dynamic parameters
|
718
|
+
collection.create_index(
|
719
|
+
field_name="triplet_index",
|
720
|
+
index_params={"index_type": "STL_SORT"},
|
721
|
+
index_name="triplet_index_index",
|
722
|
+
)
|
723
|
+
collection.create_index(
|
724
|
+
field_name="head_index",
|
725
|
+
index_params={"index_type": "STL_SORT"},
|
726
|
+
index_name="head_index_index",
|
727
|
+
)
|
728
|
+
collection.create_index(
|
729
|
+
field_name="tail_index",
|
730
|
+
index_params={"index_type": "STL_SORT"},
|
731
|
+
index_name="tail_index_index",
|
732
|
+
)
|
733
|
+
collection.create_index(
|
734
|
+
field_name="feat_emb",
|
735
|
+
index_params={
|
736
|
+
"index_type": self.vector_index_type,
|
737
|
+
"metric_type": self.metric_type,
|
738
|
+
},
|
739
|
+
index_name="feat_emb_index",
|
740
|
+
)
|
741
|
+
|
742
|
+
# Iterate over chunked edges embedding df - exact original logic
|
743
|
+
for edges_df in self.tqdm(edges_embedding_df, desc="Processing edge chunks"):
|
400
744
|
# Merge enrichment with embedding
|
401
745
|
merged_edges_df = edges_enrichment_df.merge(
|
402
|
-
edges_df[["triplet_index", "edge_emb"]],
|
403
|
-
on="triplet_index",
|
404
|
-
how="inner"
|
746
|
+
edges_df[["triplet_index", "edge_emb"]], on="triplet_index", how="inner"
|
405
747
|
)
|
406
748
|
|
407
|
-
# Prepare
|
408
|
-
|
409
|
-
|
410
|
-
|
749
|
+
# Prepare embeddings - exact original logic for GPU
|
750
|
+
if self.use_gpu:
|
751
|
+
edge_emb_cp = (
|
752
|
+
self.cp.asarray(merged_edges_df["edge_emb"].list.leaves)
|
753
|
+
.astype(self.cp.float32)
|
754
|
+
.reshape(merged_edges_df.shape[0], -1)
|
755
|
+
)
|
756
|
+
edge_emb_norm = self._normalize_matrix(edge_emb_cp, axis=1)
|
757
|
+
else:
|
758
|
+
edge_emb_data = self._extract_embeddings(merged_edges_df, "edge_emb")
|
759
|
+
edge_emb_norm = self._normalize_matrix(edge_emb_data, axis=1)
|
760
|
+
|
411
761
|
data = [
|
412
|
-
merged_edges_df["triplet_index"]
|
413
|
-
merged_edges_df["head_id"]
|
414
|
-
merged_edges_df["head_index"]
|
415
|
-
merged_edges_df["tail_id"]
|
416
|
-
merged_edges_df["tail_index"]
|
417
|
-
merged_edges_df["edge_type_str"]
|
418
|
-
merged_edges_df["display_relation"]
|
419
|
-
merged_edges_df["feat"]
|
420
|
-
|
762
|
+
self._to_list(merged_edges_df["triplet_index"]),
|
763
|
+
self._to_list(merged_edges_df["head_id"]),
|
764
|
+
self._to_list(merged_edges_df["head_index"]),
|
765
|
+
self._to_list(merged_edges_df["tail_id"]),
|
766
|
+
self._to_list(merged_edges_df["tail_index"]),
|
767
|
+
self._to_list(merged_edges_df["edge_type_str"]), # Original field name
|
768
|
+
self._to_list(merged_edges_df["display_relation"]),
|
769
|
+
self._to_list(merged_edges_df["feat"]),
|
770
|
+
self._to_list(edge_emb_norm),
|
421
771
|
]
|
422
772
|
|
423
773
|
# Insert data in batches
|
424
774
|
total = len(data[0])
|
425
|
-
for i in tqdm(
|
426
|
-
|
775
|
+
for i in self.tqdm(
|
776
|
+
range(0, total, self.batch_size), desc="Inserting edges"
|
777
|
+
):
|
778
|
+
batch_data = [d[i : i + self.batch_size] for d in data]
|
427
779
|
collection.insert(batch_data)
|
428
780
|
|
429
781
|
collection.flush()
|
430
|
-
logger.info(
|
782
|
+
logger.info(
|
783
|
+
"Edges collection created with %d entities", collection.num_entities
|
784
|
+
)
|
431
785
|
|
432
786
|
def run(self):
|
433
787
|
"""Main execution method."""
|
434
788
|
try:
|
435
|
-
logger.info("Starting Milvus data loading process...")
|
789
|
+
logger.info("Starting Dynamic Milvus data loading process...")
|
790
|
+
logger.info(
|
791
|
+
"System: %s %s", self.detector.os_type, self.detector.architecture
|
792
|
+
)
|
793
|
+
logger.info("GPU acceleration: %s", self.use_gpu)
|
436
794
|
|
437
795
|
# Connect to Milvus
|
438
796
|
self.connect_to_milvus()
|
@@ -443,66 +801,110 @@ class MilvusDataLoader:
|
|
443
801
|
# Prepare data
|
444
802
|
logger.info("Data Preparation started...")
|
445
803
|
# Get nodes enrichment and embedding dataframes
|
446
|
-
nodes_enrichment_df = graph[
|
447
|
-
nodes_embedding_df = graph[
|
804
|
+
nodes_enrichment_df = graph["nodes"]["enrichment"]
|
805
|
+
nodes_embedding_df = graph["nodes"]["embedding"]
|
448
806
|
|
449
807
|
# Get edges enrichment and embedding dataframes
|
450
|
-
edges_enrichment_df = graph[
|
451
|
-
|
452
|
-
edges_embedding_df = graph['edges']['embedding']
|
808
|
+
edges_enrichment_df = graph["edges"]["enrichment"]
|
809
|
+
edges_embedding_df = graph["edges"]["embedding"] # List of dataframes
|
453
810
|
|
454
|
-
# For nodes, we can directly merge enrichment and embedding
|
455
811
|
# Merge nodes enrichment and embedding dataframes
|
456
812
|
merged_nodes_df = nodes_enrichment_df.merge(
|
457
813
|
nodes_embedding_df[["node_id", "desc_emb", "feat_emb"]],
|
458
814
|
on="node_id",
|
459
|
-
how="left"
|
815
|
+
how="left",
|
460
816
|
)
|
461
817
|
|
462
818
|
# Create collections and load data
|
463
819
|
self.create_nodes_collection(merged_nodes_df)
|
464
820
|
self.create_node_type_collections(merged_nodes_df)
|
465
|
-
self.create_edges_collection(edges_enrichment_df,
|
466
|
-
edges_embedding_df)
|
821
|
+
self.create_edges_collection(edges_enrichment_df, edges_embedding_df)
|
467
822
|
|
468
823
|
# List all collections for verification
|
469
824
|
logger.info("Data loading completed successfully!")
|
470
825
|
logger.info("Created collections:")
|
471
|
-
for coll in utility.list_collections():
|
472
|
-
collection = Collection(name=coll)
|
826
|
+
for coll in self.pymilvus_modules["utility"].list_collections():
|
827
|
+
collection = self.pymilvus_modules["Collection"](name=coll)
|
473
828
|
logger.info(" %s: %d entities", coll, collection.num_entities)
|
474
829
|
|
475
830
|
except Exception as e:
|
476
831
|
logger.error("Error during data loading: %s", str(e))
|
832
|
+
import traceback
|
833
|
+
|
834
|
+
logger.error("Full traceback: %s", traceback.format_exc())
|
477
835
|
raise
|
478
836
|
|
479
837
|
|
480
838
|
def main():
|
481
|
-
"""Main function to run the data loader."""
|
839
|
+
"""Main function to run the dynamic data loader."""
|
482
840
|
# Resolve the fallback data path relative to this script's location
|
483
841
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
484
842
|
default_data_dir = os.path.join(script_dir, "tests/files/biobridge_multimodal/")
|
485
843
|
|
486
|
-
# Configuration
|
844
|
+
# Configuration with environment variable fallbacks - matches original exactly
|
487
845
|
config = {
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
846
|
+
"milvus_host": os.getenv("MILVUS_HOST", "localhost"),
|
847
|
+
"milvus_port": os.getenv("MILVUS_PORT", "19530"),
|
848
|
+
"milvus_user": os.getenv("MILVUS_USER", "root"),
|
849
|
+
"milvus_password": os.getenv("MILVUS_PASSWORD", "Milvus"),
|
850
|
+
"milvus_database": os.getenv("MILVUS_DATABASE", "t2kg_primekg"),
|
851
|
+
"data_dir": os.getenv("DATA_DIR", default_data_dir),
|
852
|
+
"batch_size": int(os.getenv("BATCH_SIZE", "500")),
|
853
|
+
"chunk_size": int(os.getenv("CHUNK_SIZE", "5")),
|
854
|
+
"auto_install_packages": os.getenv("AUTO_INSTALL_PACKAGES", "true").lower()
|
855
|
+
== "true",
|
496
856
|
}
|
497
857
|
|
498
|
-
#
|
499
|
-
|
500
|
-
|
501
|
-
|
858
|
+
# Override detection for testing/forcing specific modes
|
859
|
+
force_cpu = os.getenv("FORCE_CPU", "false").lower() == "true"
|
860
|
+
if force_cpu:
|
861
|
+
logger.info("FORCE_CPU environment variable set - forcing CPU mode")
|
502
862
|
|
503
|
-
#
|
504
|
-
|
505
|
-
|
863
|
+
# Print configuration for debugging - matches original format
|
864
|
+
logger.info("=== Dynamic Milvus Data Loader ===")
|
865
|
+
logger.info("Configuration:")
|
866
|
+
for key, value in config.items():
|
867
|
+
# Don't log sensitive information
|
868
|
+
if "password" in key.lower():
|
869
|
+
logger.info(" %s: %s", key, "*" * len(str(value)))
|
870
|
+
else:
|
871
|
+
logger.info(" %s: %s", key, value)
|
872
|
+
|
873
|
+
# Additional environment info
|
874
|
+
logger.info("Environment:")
|
875
|
+
logger.info(" Python version: %s", sys.version)
|
876
|
+
logger.info(" Platform: %s", platform.platform())
|
877
|
+
logger.info(" Force CPU mode: %s", force_cpu)
|
878
|
+
logger.info(" Script directory: %s", script_dir)
|
879
|
+
logger.info(" Default data directory: %s", default_data_dir)
|
880
|
+
|
881
|
+
try:
|
882
|
+
# Create and run dynamic data loader
|
883
|
+
loader = DynamicDataLoader(config)
|
884
|
+
|
885
|
+
# Override GPU detection if forced
|
886
|
+
if force_cpu:
|
887
|
+
loader.detector.use_gpu = False
|
888
|
+
loader.use_gpu = False
|
889
|
+
loader.normalize_vectors = False
|
890
|
+
loader.vector_index_type = "HNSW"
|
891
|
+
loader.metric_type = "COSINE"
|
892
|
+
logger.info("Forced CPU mode - updated loader settings")
|
893
|
+
|
894
|
+
# Run the data loading process
|
895
|
+
loader.run()
|
896
|
+
|
897
|
+
logger.info("=== Data Loading Completed Successfully ===")
|
898
|
+
|
899
|
+
except KeyboardInterrupt:
|
900
|
+
logger.info("Data loading interrupted by user")
|
901
|
+
sys.exit(1)
|
902
|
+
except Exception as e:
|
903
|
+
logger.error("Fatal error during data loading: %s", str(e))
|
904
|
+
import traceback
|
905
|
+
|
906
|
+
logger.error("Full traceback: %s", traceback.format_exc())
|
907
|
+
sys.exit(1)
|
506
908
|
|
507
909
|
|
508
910
|
if __name__ == "__main__":
|