aiagents4pharma 1.41.0__py3-none-any.whl → 1.42.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
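The headline change in 1.42.0 replaces the GPU-only Milvus loader script with a dynamic, cross-platform loader that probes the host at runtime and picks libraries and index parameters to match. Condensed to its core, the selection logic shown in the diff below behaves roughly like this (an editorial sketch of the new behavior, not code shipped in the package):

import platform
import subprocess

def detect_nvidia_gpu() -> bool:
    # Probe used by the new SystemDetector: nvidia-smi exiting 0 means a usable GPU.
    try:
        result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=10)
        return result.returncode == 0
    except (OSError, subprocess.SubprocessError):
        return False

use_gpu = detect_nvidia_gpu() and platform.system().lower() != "darwin"  # no CUDA on macOS
vector_index_type = "GPU_CAGRA" if use_gpu else "HNSW"
metric_type = "IP" if use_gpu else "COSINE"
normalize_vectors = use_gpu  # embeddings are L2-normalized only in GPU mode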
@@ -1,438 +1,796 @@
1
- # pylint: disable=wrong-import-position
2
1
  #!/usr/bin/env python3
2
+ # pylint: skip-file
3
3
  """
4
- Script to load PrimeKG multimodal data into Milvus database.
5
- This script runs after Milvus container is ready and loads the .pkl file data.
4
+ Dynamic Cross-Platform PrimeKG Multimodal Data Loader for Milvus Database.
5
+ Automatically detects system capabilities and chooses appropriate libraries and configurations.
6
+
7
+ Supports:
8
+ - Windows, Linux, macOS
9
+ - CPU-only mode (pandas/numpy)
10
+ - NVIDIA GPU mode (cudf/cupy)
11
+ - Dynamic index selection based on hardware
12
+ - Automatic dependency installation
6
13
  """
7
14
 
8
- import os
9
- import sys
10
- import subprocess
11
15
  import glob
12
16
  import logging
13
- from typing import Dict, Any, List
14
-
15
- def install_packages():
16
- """Install required packages."""
17
- packages = [
18
- "pip install --extra-index-url=https://pypi.nvidia.com cudf-cu12",
19
- "pip install --extra-index-url=https://pypi.nvidia.com dask-cudf-cu12",
20
- "pip install pymilvus==2.5.11",
21
- "pip install numpy==1.26.4",
22
- "pip install pandas==2.1.3",
23
- "pip install tqdm==4.67.1",
24
- ]
25
-
26
- print("[DATA LOADER] Installing required packages...")
27
- for package_cmd in packages:
28
- print(f"[DATA LOADER] Running: {package_cmd}")
29
- result = subprocess.run(package_cmd.split(), capture_output=True, text=True, check=True)
30
- if result.returncode != 0:
31
- print(f"[DATA LOADER] Error installing package: {result.stderr}")
32
- sys.exit(1)
33
- print("[DATA LOADER] All packages installed successfully!")
34
-
35
- # Install packages first
36
- install_packages()
37
-
38
- try:
39
- import cudf
40
- import cupy as cp
41
- except ImportError as e:
42
- print("[DATA LOADER] cudf or cupy not found. Please ensure they are installed correctly.")
43
- sys.exit(1)
44
-
45
- from pymilvus import (
46
- db,
47
- connections,
48
- FieldSchema,
49
- CollectionSchema,
50
- DataType,
51
- Collection,
52
- utility
53
- )
54
- from tqdm import tqdm
17
+ import os
18
+ import platform
19
+ import subprocess
20
+ import sys
21
+ from typing import Any, Dict, List, Optional, Union
55
22
 
56
23
  # Configure logging
57
- logging.basicConfig(level=logging.INFO, format='[DATA LOADER] %(message)s')
24
+ logging.basicConfig(level=logging.INFO, format="[DATA LOADER] %(message)s")
58
25
  logger = logging.getLogger(__name__)
59
26
 
60
- class MilvusDataLoader:
61
- """
62
- Class to handle loading of BioBridge-PrimeKG multimodal data into Milvus.
63
- """
27
+
28
+ class SystemDetector:
29
+ """Detect system capabilities and choose appropriate libraries."""
30
+
31
+ def __init__(self):
32
+ self.os_type = platform.system().lower() # 'windows', 'linux', 'darwin'
33
+ self.architecture = platform.machine().lower() # 'x86_64', 'arm64', etc.
34
+ self.has_nvidia_gpu = self._detect_nvidia_gpu()
35
+ self.use_gpu = (
36
+ self.has_nvidia_gpu and self.os_type != "darwin"
37
+ ) # No CUDA on macOS
38
+
39
+ logger.info("System Detection Results:")
40
+ logger.info(" OS: %s", self.os_type)
41
+ logger.info(" Architecture: %s", self.architecture)
42
+ logger.info(" NVIDIA GPU detected: %s", self.has_nvidia_gpu)
43
+ logger.info(" Will use GPU acceleration: %s", self.use_gpu)
44
+
45
+ def _detect_nvidia_gpu(self) -> bool:
46
+ """Detect if NVIDIA GPU is available."""
47
+ try:
48
+ # Try nvidia-smi command
49
+ result = subprocess.run(
50
+ ["nvidia-smi"], capture_output=True, text=True, timeout=10
51
+ )
52
+ return result.returncode == 0
53
+ except (
54
+ subprocess.TimeoutExpired,
55
+ FileNotFoundError,
56
+ subprocess.SubprocessError,
57
+ ):
58
+ return False
59
+
60
+ def get_required_packages(self) -> List[str]:
61
+ """Get list of packages to install based on system capabilities - matches original logic."""
62
+ if self.use_gpu and self.os_type == "linux":
63
+ # GPU-mode package list from the original script (the cudf/dask-cudf installs are commented out)
64
+ packages = [
65
+ # "pip install --extra-index-url=https://pypi.nvidia.com cudf-cu12",
66
+ # "pip install --extra-index-url=https://pypi.nvidia.com dask-cudf-cu12",
67
+ "pip install pymilvus==2.5.11",
68
+ "pip install numpy==1.26.4",
69
+ "pip install pandas==2.1.3",
70
+ "pip install tqdm==4.67.1",
71
+ ]
72
+ return packages
73
+ else:
74
+ # CPU-only packages
75
+ packages = [
76
+ "pip install pymilvus==2.5.11",
77
+ "pip install numpy==1.26.4",
78
+ "pip install pandas==2.1.3",
79
+ "pip install tqdm==4.67.1",
80
+ ]
81
+ return packages
82
+
83
+ def install_packages(self):
84
+ """Install required packages using original script's exact logic."""
85
+ packages = self.get_required_packages()
86
+
87
+ logger.info(
88
+ "Installing packages for %s system%s",
89
+ self.os_type,
90
+ " with GPU support" if self.use_gpu else "",
91
+ )
92
+
93
+ for package_cmd in packages:
94
+ logger.info("Running: %s", package_cmd)
95
+ try:
96
+ result = subprocess.run(
97
+ package_cmd.split(),
98
+ capture_output=True,
99
+ text=True,
100
+ check=True,
101
+ timeout=300,
102
+ )
103
+ if result.returncode != 0:
104
+ logger.error("Error installing package: %s", result.stderr)
105
+ if "cudf" in package_cmd or "dask-cudf" in package_cmd:
106
+ logger.warning(
107
+ "GPU package installation failed, falling back to CPU mode"
108
+ )
109
+ self.use_gpu = False
110
+ return self.install_packages() # Retry with CPU packages
111
+ else:
112
+ sys.exit(1)
113
+ else:
114
+ logger.info("Successfully installed: %s", package_cmd.split()[-1])
115
+ except subprocess.CalledProcessError as e:
116
+ logger.error("Failed to install %s: %s", package_cmd, e.stderr)
117
+ if "cudf" in package_cmd:
118
+ logger.warning(
119
+ "GPU package installation failed, falling back to CPU mode"
120
+ )
121
+ self.use_gpu = False
122
+ return self.install_packages() # Retry with CPU packages
123
+ else:
124
+ raise
125
+ except subprocess.TimeoutExpired:
126
+ logger.error("Installation timeout for package: %s", package_cmd)
127
+ raise
128
+
129
+
130
+ class DynamicDataLoader:
131
+ """Dynamic data loader that adapts to system capabilities."""
132
+
64
133
  def __init__(self, config: Dict[str, Any]):
65
- """Initialize the MilvusDataLoader with configuration parameters."""
134
+ """Initialize with system detection and dynamic library loading."""
66
135
  self.config = config
67
- self.milvus_host = config.get('milvus_host', 'localhost')
68
- self.milvus_port = config.get('milvus_port', '19530')
69
- self.milvus_user = config.get('milvus_user', 'root')
70
- self.milvus_password = config.get('milvus_password', 'Milvus')
71
- self.milvus_database = config.get('milvus_database', 't2kg_primekg')
72
- self.data_dir = config.get('data_dir',
73
- 'tests/files/biobridge_multimodal/')
74
- self.batch_size = config.get('batch_size', 500)
75
- self.chunk_size = config.get('chunk_size', 5)
76
-
77
- def normalize_matrix(self, m, axis=1):
78
- """Normalize each row of a 2D matrix using CuPy."""
79
- norms = cp.linalg.norm(m, axis=axis, keepdims=True)
80
- return m / norms
81
-
82
- def normalize_vector(self, v):
83
- """Normalize a vector using CuPy."""
84
- v = cp.asarray(v)
85
- norm = cp.linalg.norm(v)
86
- return v / norm
136
+ self.detector = SystemDetector()
137
+
138
+ # Install packages if needed
139
+ if config.get("auto_install_packages", True):
140
+ self.detector.install_packages()
141
+
142
+ # Import libraries based on system capabilities
143
+ self._import_libraries()
144
+
145
+ # Configuration - parameters carried over from the original script
146
+ self.milvus_host = config.get("milvus_host", "localhost")
147
+ self.milvus_port = config.get("milvus_port", "19530")
148
+ self.milvus_user = config.get("milvus_user", "root")
149
+ self.milvus_password = config.get("milvus_password", "Milvus")
150
+ self.milvus_database = config.get("milvus_database", "t2kg_primekg")
151
+ self.data_dir = config.get("data_dir", "./data")
152
+ self.batch_size = config.get("batch_size", 500)
153
+ self.chunk_size = config.get("chunk_size", 5) # Original chunk_size parameter
154
+
155
+ # Dynamic settings based on hardware
156
+ self.use_gpu = self.detector.use_gpu
157
+ self.normalize_vectors = self.use_gpu  # normalize only in GPU mode, where IP on unit vectors matches cosine
158
+ self.vector_index_type = "GPU_CAGRA" if self.use_gpu else "HNSW"
159
+ self.metric_type = "IP" if self.use_gpu else "COSINE"
160
+
161
+ logger.info("Loader Configuration:")
162
+ logger.info(" Using GPU acceleration: %s", self.use_gpu)
163
+ logger.info(" Vector normalization: %s", self.normalize_vectors)
164
+ logger.info(" Vector index type: %s", self.vector_index_type)
165
+ logger.info(" Metric type: %s", self.metric_type)
166
+ logger.info(" Data directory: %s", self.data_dir)
167
+ logger.info(" Batch size: %s", self.batch_size)
168
+ logger.info(" Chunk size: %s", self.chunk_size)
169
+
170
+ def _import_libraries(self):
171
+ """Dynamically import libraries - matches original script's import logic."""
172
+ # Always import base libraries
173
+ import numpy as np
174
+ import pandas as pd
175
+ from pymilvus import (
176
+ Collection,
177
+ CollectionSchema,
178
+ DataType,
179
+ FieldSchema,
180
+ connections,
181
+ db,
182
+ utility,
183
+ )
184
+ from tqdm import tqdm
185
+
186
+ self.pd = pd
187
+ self.np = np
188
+ self.tqdm = tqdm
189
+ self.pymilvus_modules = {
190
+ "db": db,
191
+ "connections": connections,
192
+ "FieldSchema": FieldSchema,
193
+ "CollectionSchema": CollectionSchema,
194
+ "DataType": DataType,
195
+ "Collection": Collection,
196
+ "utility": utility,
197
+ }
198
+
199
+ # Conditionally import GPU libraries - matches original error handling
200
+ if self.detector.use_gpu:
201
+ try:
202
+ import cudf # pyright: ignore
203
+ import cupy as cp # pyright: ignore
204
+
205
+ self.cudf = cudf
206
+ self.cp = cp
207
+ logger.info("Successfully imported GPU libraries (cudf, cupy)")
208
+ except ImportError as e:
209
+ logger.error(
210
+ "[DATA LOADER] cudf or cupy not found. Please ensure they are installed correctly."
211
+ )
212
+ logger.error("Import error: %s", str(e))
213
+ # Match original script's exit behavior for critical GPU import failure
214
+ if os.getenv("FORCE_CPU", "false").lower() != "true":
215
+ logger.error(
216
+ "GPU libraries required but not available. Set FORCE_CPU=true to use CPU mode."
217
+ )
218
+ sys.exit(1)
219
+ else:
220
+ logger.warning("Falling back to CPU mode due to FORCE_CPU=true")
221
+ self.detector.use_gpu = False
222
+ self.use_gpu = False
223
+
224
+ def _read_dataframe(
225
+ self, file_path: str, columns: Optional[List[str]] = None
226
+ ) -> Union["pd.DataFrame", "cudf.DataFrame"]: # type: ignore[reportUndefinedVariable] # noqa: F821
227
+ """Read dataframe using appropriate library."""
228
+ if self.use_gpu:
229
+ return self.cudf.read_parquet(file_path, columns=columns)
230
+ else:
231
+ return self.pd.read_parquet(file_path, columns=columns)
232
+
233
+ def _concat_dataframes(
234
+ self, df_list: List, ignore_index: bool = True
235
+ ) -> Union["pd.DataFrame", "cudf.DataFrame"]: # type: ignore[reportUndefinedVariable] # noqa: F821
236
+ """Concatenate dataframes using appropriate library."""
237
+ if self.use_gpu:
238
+ return self.cudf.concat(df_list, ignore_index=ignore_index)
239
+ else:
240
+ return self.pd.concat(df_list, ignore_index=ignore_index)
241
+
242
+ def _normalize_matrix(self, matrix, axis: int = 1):
243
+ """Normalize matrix using appropriate library."""
244
+ if not self.normalize_vectors:
245
+ return matrix
246
+
247
+ if self.use_gpu:
248
+ # Use cupy for GPU
249
+ matrix_cp = self.cp.asarray(matrix).astype(self.cp.float32)
250
+ norms = self.cp.linalg.norm(matrix_cp, axis=axis, keepdims=True)
251
+ return matrix_cp / norms
252
+ else:
253
+ # CPU mode: return the matrix unchanged; the COSINE metric needs no pre-normalization
254
+ return matrix
255
+
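# A note on the normalization above (illustrative, not part of the package): GPU mode
# L2-normalizes embeddings and indexes them with the IP (inner product) metric, while
# CPU mode skips normalization and uses COSINE. On unit vectors the two metrics agree,
# so both modes rank search results the same way:
import numpy as np
a = np.array([3.0, 4.0])
b = np.array([1.0, 2.0])
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
ip_on_unit = (a / np.linalg.norm(a)) @ (b / np.linalg.norm(b))
assert np.isclose(cosine, ip_on_unit)  # both ≈ 0.9839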
256
+ def _extract_embeddings(self, df, column_name: str):
257
+ """Extract embeddings and convert to appropriate format."""
258
+ if self.use_gpu:
259
+ # cuDF list extraction
260
+ emb_data = self.cp.asarray(df[column_name].list.leaves).astype(
261
+ self.cp.float32
262
+ )
263
+ return emb_data.reshape(df.shape[0], -1)
264
+ else:
265
+ # pandas extraction
266
+ emb_list = []
267
+ for emb in df[column_name]:
268
+ if isinstance(emb, list):
269
+ emb_list.append(emb)
270
+ else:
271
+ emb_list.append(emb.tolist() if hasattr(emb, "tolist") else emb)
272
+ return self.np.array(emb_list, dtype=self.np.float32)
273
+
274
+ def _to_list(self, data):
275
+ """Convert data to list format for Milvus insertion."""
276
+ if self.use_gpu:
277
+ # For cuDF data, use to_arrow().to_pylist()
278
+ if hasattr(data, "to_arrow"):
279
+ return data.to_arrow().to_pylist()
280
+ elif hasattr(data, "tolist"):
281
+ # Fallback for cupy arrays
282
+ return data.tolist()
283
+ else:
284
+ return list(data)
285
+ else:
286
+ # For pandas/numpy data
287
+ if hasattr(data, "tolist"):
288
+ return data.tolist()
289
+ elif hasattr(data, "to_arrow"):
290
+ return data.to_arrow().to_pylist()
291
+ else:
292
+ return list(data)
87
293
 
88
294
  def connect_to_milvus(self):
89
295
  """Connect to Milvus and setup database."""
90
296
  logger.info("Connecting to Milvus at %s:%s", self.milvus_host, self.milvus_port)
91
297
 
92
- connections.connect(
298
+ self.pymilvus_modules["connections"].connect(
93
299
  alias="default",
94
300
  host=self.milvus_host,
95
301
  port=self.milvus_port,
96
302
  user=self.milvus_user,
97
- password=self.milvus_password
303
+ password=self.milvus_password,
98
304
  )
99
305
 
100
306
  # Check if database exists, create if it doesn't
101
- if self.milvus_database not in db.list_database():
307
+ if self.milvus_database not in self.pymilvus_modules["db"].list_database():
102
308
  logger.info("Creating database: %s", self.milvus_database)
103
- db.create_database(self.milvus_database)
309
+ self.pymilvus_modules["db"].create_database(self.milvus_database)
104
310
 
105
311
  # Switch to the desired database
106
- db.using_database(self.milvus_database)
312
+ self.pymilvus_modules["db"].using_database(self.milvus_database)
107
313
  logger.info("Using database: %s", self.milvus_database)
108
314
 
109
315
  def load_graph_data(self):
110
- """Load the pickle file containing graph data."""
316
+ """Load the parquet files containing graph data."""
111
317
  logger.info("Loading graph data from: %s", self.data_dir)
112
318
 
113
319
  if not os.path.exists(self.data_dir):
114
320
  raise FileNotFoundError(f"Data directory not found: {self.data_dir}")
115
321
 
116
- # Load dataframes containing nodes and edges
117
- # Loop over nodes and edges
118
322
  graph = {}
119
323
  for element in ["nodes", "edges"]:
120
- # Make an empty dictionary for each folder
121
324
  graph[element] = {}
122
325
  for stage in ["enrichment", "embedding"]:
123
- print(element, stage)
124
- # Create the file pattern for the current subfolder
125
- file_list = glob.glob(os.path.join(self.data_dir,
126
- element,
127
- stage,
128
- '*.parquet.gzip'))
129
- print(file_list)
130
- # Read and concatenate all dataframes in the folder
131
- # Except the edges embedding, which is too large to read in one go
132
- # We are using a chunk size to read the edges embedding in smaller parts instead
326
+ logger.info("Processing %s %s", element, stage)
327
+
328
+ file_list = glob.glob(
329
+ os.path.join(self.data_dir, element, stage, "*.parquet.gzip")
330
+ )
331
+ logger.info("Found %d files for %s %s", len(file_list), element, stage)
332
+
333
+ if not file_list:
334
+ logger.warning("No files found for %s %s", element, stage)
335
+ continue
336
+
337
+ # For edges embedding, process in chunks due to size
133
338
  if element == "edges" and stage == "embedding":
134
- # For edges embedding, only read two columns:
135
- # triplet_index and edge_emb
136
- # Loop by chunks
137
339
  chunk_size = self.chunk_size
138
340
  graph[element][stage] = []
139
341
  for i in range(0, len(file_list), chunk_size):
140
- chunk_files = file_list[i:i+chunk_size]
141
- chunk_df = cudf.concat([
142
- cudf.read_parquet(f, columns=["triplet_index", "edge_emb"])
143
- for f in chunk_files
144
- ], ignore_index=True)
342
+ chunk_files = file_list[i : i + chunk_size]
343
+ chunk_df_list = []
344
+ for f in chunk_files:
345
+ df = self._read_dataframe(
346
+ f, columns=["triplet_index", "edge_emb"]
347
+ )
348
+ chunk_df_list.append(df)
349
+ chunk_df = self._concat_dataframes(
350
+ chunk_df_list, ignore_index=True
351
+ )
145
352
  graph[element][stage].append(chunk_df)
146
353
  else:
147
- # For nodes and edges enrichment,
148
- # read and concatenate all dataframes in the folder
149
- # This includes the nodes embedding,
150
- # which is small enough to read in one go
151
- graph[element][stage] = cudf.concat([
152
- cudf.read_parquet(f) for f in file_list
153
- ], ignore_index=True)
354
+ # For other combinations, read all files
355
+ df_list = []
356
+ for f in file_list:
357
+ df = self._read_dataframe(f)
358
+ df_list.append(df)
359
+ graph[element][stage] = self._concat_dataframes(
360
+ df_list, ignore_index=True
361
+ )
154
362
 
155
363
  logger.info("Graph data loaded successfully")
156
364
  return graph
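# For orientation (inferred from the glob patterns above, names illustrative):
# load_graph_data expects the data directory to be laid out as
#   <data_dir>/nodes/enrichment/*.parquet.gzip
#   <data_dir>/nodes/embedding/*.parquet.gzip
#   <data_dir>/edges/enrichment/*.parquet.gzip
#   <data_dir>/edges/embedding/*.parquet.gzip  (large; read in chunks of chunk_size files)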
157
365
 
158
- def create_nodes_collection(self, nodes_df: cudf.DataFrame):
366
+ def _get_embedding_dimension(self, df, column_name: str) -> int:
367
+ """Get embedding dimension using original script's exact logic."""
368
+ first_emb = df.iloc[0][column_name]
369
+ if self.use_gpu:
370
+ # cuDF format - matches original: len(nodes_df.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])
371
+ return len(first_emb.to_arrow().to_pylist()[0])
372
+ else:
373
+ # pandas format
374
+ if isinstance(first_emb, list):
375
+ return len(first_emb)
376
+ else:
377
+ return len(
378
+ first_emb.tolist() if hasattr(first_emb, "tolist") else first_emb
379
+ )
380
+
381
+ def create_nodes_collection(self, nodes_df):
159
382
  """Create and populate the main nodes collection."""
160
383
  logger.info("Creating main nodes collection...")
161
384
  node_coll_name = f"{self.milvus_database}_nodes"
162
385
 
386
+ # Get embedding dimension
387
+ emb_dim = self._get_embedding_dimension(nodes_df, "desc_emb")
388
+
163
389
  node_fields = [
164
- FieldSchema(name="node_index",
165
- dtype=DataType.INT64,
166
- is_primary=True),
167
- FieldSchema(name="node_id",
168
- dtype=DataType.VARCHAR,
169
- max_length=1024),
170
- FieldSchema(name="node_name",
171
- dtype=DataType.VARCHAR,
172
- max_length=1024,
173
- enable_analyzer=True,
174
- enable_match=True),
175
- FieldSchema(name="node_type",
176
- dtype=DataType.VARCHAR,
177
- max_length=1024,
178
- enable_analyzer=True,
179
- enable_match=True),
180
- FieldSchema(name="desc",
181
- dtype=DataType.VARCHAR,
182
- max_length=40960,
183
- enable_analyzer=True,
184
- enable_match=True),
185
- FieldSchema(name="desc_emb",
186
- dtype=DataType.FLOAT_VECTOR,
187
- dim=len(nodes_df.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])),
390
+ self.pymilvus_modules["FieldSchema"](
391
+ name="node_index",
392
+ dtype=self.pymilvus_modules["DataType"].INT64,
393
+ is_primary=True,
394
+ ),
395
+ self.pymilvus_modules["FieldSchema"](
396
+ name="node_id",
397
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
398
+ max_length=1024,
399
+ ),
400
+ self.pymilvus_modules["FieldSchema"](
401
+ name="node_name",
402
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
403
+ max_length=1024,
404
+ enable_analyzer=True,
405
+ enable_match=True,
406
+ ),
407
+ self.pymilvus_modules["FieldSchema"](
408
+ name="node_type",
409
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
410
+ max_length=1024,
411
+ enable_analyzer=True,
412
+ enable_match=True,
413
+ ),
414
+ self.pymilvus_modules["FieldSchema"](
415
+ name="desc",
416
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
417
+ max_length=40960,
418
+ enable_analyzer=True,
419
+ enable_match=True,
420
+ ),
421
+ self.pymilvus_modules["FieldSchema"](
422
+ name="desc_emb",
423
+ dtype=self.pymilvus_modules["DataType"].FLOAT_VECTOR,
424
+ dim=emb_dim,
425
+ ),
188
426
  ]
189
- schema = CollectionSchema(fields=node_fields,
190
- description=f"Schema for collection {node_coll_name}")
427
+
428
+ schema = self.pymilvus_modules["CollectionSchema"](
429
+ fields=node_fields, description=f"Schema for collection {node_coll_name}"
430
+ )
191
431
 
192
432
  # Create collection if it doesn't exist
193
- if not utility.has_collection(node_coll_name):
194
- collection = Collection(name=node_coll_name, schema=schema)
433
+ if not self.pymilvus_modules["utility"].has_collection(node_coll_name):
434
+ collection = self.pymilvus_modules["Collection"](
435
+ name=node_coll_name, schema=schema
436
+ )
195
437
  else:
196
- collection = Collection(name=node_coll_name)
197
-
198
- # Create indexes
199
- collection.create_index(field_name="node_index",
200
- index_params={"index_type": "STL_SORT"},
201
- index_name="node_index_index")
202
- collection.create_index(field_name="node_name",
203
- index_params={"index_type": "INVERTED"},
204
- index_name="node_name_index")
205
- collection.create_index(field_name="node_type",
206
- index_params={"index_type": "INVERTED"},
207
- index_name="node_type_index")
208
- collection.create_index(field_name="desc",
209
- index_params={"index_type": "INVERTED"},
210
- index_name="desc_index")
211
- collection.create_index(field_name="desc_emb",
212
- index_params={"index_type": "GPU_CAGRA",
213
- "metric_type": "IP"},
214
- index_name="desc_emb_index")
438
+ collection = self.pymilvus_modules["Collection"](name=node_coll_name)
439
+
440
+ # Create indexes with dynamic parameters
441
+ collection.create_index(
442
+ field_name="node_index",
443
+ index_params={"index_type": "STL_SORT"},
444
+ index_name="node_index_index",
445
+ )
446
+ collection.create_index(
447
+ field_name="node_name",
448
+ index_params={"index_type": "INVERTED"},
449
+ index_name="node_name_index",
450
+ )
451
+ collection.create_index(
452
+ field_name="node_type",
453
+ index_params={"index_type": "INVERTED"},
454
+ index_name="node_type_index",
455
+ )
456
+ collection.create_index(
457
+ field_name="desc",
458
+ index_params={"index_type": "INVERTED"},
459
+ index_name="desc_index",
460
+ )
461
+ collection.create_index(
462
+ field_name="desc_emb",
463
+ index_params={
464
+ "index_type": self.vector_index_type,
465
+ "metric_type": self.metric_type,
466
+ },
467
+ index_name="desc_emb_index",
468
+ )
215
469
 
216
470
  # Prepare and insert data
217
- desc_emb_norm = cp.asarray(nodes_df["desc_emb"].list.leaves).astype(cp.float32).\
218
- reshape(nodes_df.shape[0], -1)
219
- desc_emb_norm = self.normalize_matrix(desc_emb_norm, axis=1)
471
+ desc_emb_data = self._extract_embeddings(nodes_df, "desc_emb")
472
+ desc_emb_normalized = self._normalize_matrix(desc_emb_data, axis=1)
473
+
220
474
  data = [
221
- nodes_df["node_index"].to_arrow().to_pylist(),
222
- nodes_df["node_id"].to_arrow().to_pylist(),
223
- nodes_df["node_name"].to_arrow().to_pylist(),
224
- nodes_df["node_type"].to_arrow().to_pylist(),
225
- nodes_df["desc"].to_arrow().to_pylist(),
226
- desc_emb_norm.tolist(), # Use normalized embeddings
475
+ self._to_list(nodes_df["node_index"]),
476
+ self._to_list(nodes_df["node_id"]),
477
+ self._to_list(nodes_df["node_name"]),
478
+ self._to_list(nodes_df["node_type"]),
479
+ self._to_list(nodes_df["desc"]),
480
+ self._to_list(desc_emb_normalized),
227
481
  ]
228
482
 
229
483
  # Insert data in batches
230
484
  total = len(data[0])
231
- for i in tqdm(range(0, total, self.batch_size), desc="Inserting nodes"):
232
- batch = [col[i:i+self.batch_size] for col in data]
485
+ for i in self.tqdm(range(0, total, self.batch_size), desc="Inserting nodes"):
486
+ batch = [col[i : i + self.batch_size] for col in data]
233
487
  collection.insert(batch)
234
488
 
235
489
  collection.flush()
236
- logger.info("Nodes collection created with %d entities", collection.num_entities)
490
+ logger.info(
491
+ "Nodes collection created with %d entities", collection.num_entities
492
+ )
237
493
 
238
- def create_node_type_collections(self, nodes_df: cudf.DataFrame):
494
+ def create_node_type_collections(self, nodes_df):
239
495
  """Create separate collections for each node type."""
240
496
  logger.info("Creating node type-specific collections...")
241
497
 
242
- for node_type, nodes_df_ in tqdm(nodes_df.groupby('node_type'),
243
- desc="Processing node types"):
244
- node_coll_name = f"{self.milvus_database}_nodes_{node_type.replace('/', '_')}"
498
+ for node_type, nodes_df_ in self.tqdm(
499
+ nodes_df.groupby("node_type"), desc="Processing node types"
500
+ ):
501
+ node_coll_name = (
502
+ f"{self.milvus_database}_nodes_{node_type.replace('/', '_')}"
503
+ )
504
+
505
+ # Get embedding dimensions
506
+ desc_dim = self._get_embedding_dimension(nodes_df_, "desc_emb")
507
+ feat_dim = self._get_embedding_dimension(nodes_df_, "feat_emb")
245
508
 
246
509
  node_fields = [
247
- FieldSchema(name="node_index",
248
- dtype=DataType.INT64,
249
- is_primary=True,
250
- auto_id=False),
251
- FieldSchema(name="node_id",
252
- dtype=DataType.VARCHAR,
253
- max_length=1024),
254
- FieldSchema(name="node_name",
255
- dtype=DataType.VARCHAR,
256
- max_length=1024,
257
- enable_analyzer=True,
258
- enable_match=True),
259
- FieldSchema(name="node_type",
260
- dtype=DataType.VARCHAR,
261
- max_length=1024,
262
- enable_analyzer=True,
263
- enable_match=True),
264
- FieldSchema(name="desc",
265
- dtype=DataType.VARCHAR,
266
- max_length=40960,
267
- enable_analyzer=True,
268
- enable_match=True),
269
- FieldSchema(name="desc_emb",
270
- dtype=DataType.FLOAT_VECTOR,
271
- dim=len(nodes_df_.iloc[0]['desc_emb'].to_arrow().to_pylist()[0])),
272
- FieldSchema(name="feat",
273
- dtype=DataType.VARCHAR,
274
- max_length=40960,
275
- enable_analyzer=True,
276
- enable_match=True),
277
- FieldSchema(name="feat_emb",
278
- dtype=DataType.FLOAT_VECTOR,
279
- dim=len(nodes_df_.iloc[0]['feat_emb'].to_arrow().to_pylist()[0])),
510
+ self.pymilvus_modules["FieldSchema"](
511
+ name="node_index",
512
+ dtype=self.pymilvus_modules["DataType"].INT64,
513
+ is_primary=True,
514
+ auto_id=False,
515
+ ),
516
+ self.pymilvus_modules["FieldSchema"](
517
+ name="node_id",
518
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
519
+ max_length=1024,
520
+ ),
521
+ self.pymilvus_modules["FieldSchema"](
522
+ name="node_name",
523
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
524
+ max_length=1024,
525
+ enable_analyzer=True,
526
+ enable_match=True,
527
+ ),
528
+ self.pymilvus_modules["FieldSchema"](
529
+ name="node_type",
530
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
531
+ max_length=1024,
532
+ enable_analyzer=True,
533
+ enable_match=True,
534
+ ),
535
+ self.pymilvus_modules["FieldSchema"](
536
+ name="desc",
537
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
538
+ max_length=40960,
539
+ enable_analyzer=True,
540
+ enable_match=True,
541
+ ),
542
+ self.pymilvus_modules["FieldSchema"](
543
+ name="desc_emb",
544
+ dtype=self.pymilvus_modules["DataType"].FLOAT_VECTOR,
545
+ dim=desc_dim,
546
+ ),
547
+ self.pymilvus_modules["FieldSchema"](
548
+ name="feat",
549
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
550
+ max_length=40960,
551
+ enable_analyzer=True,
552
+ enable_match=True,
553
+ ),
554
+ self.pymilvus_modules["FieldSchema"](
555
+ name="feat_emb",
556
+ dtype=self.pymilvus_modules["DataType"].FLOAT_VECTOR,
557
+ dim=feat_dim,
558
+ ),
280
559
  ]
281
- schema = CollectionSchema(fields=node_fields,
282
- description=f"schema for collection {node_coll_name}")
283
560
 
284
- if not utility.has_collection(node_coll_name):
285
- collection = Collection(name=node_coll_name, schema=schema)
561
+ schema = self.pymilvus_modules["CollectionSchema"](
562
+ fields=node_fields,
563
+ description=f"schema for collection {node_coll_name}",
564
+ )
565
+
566
+ if not self.pymilvus_modules["utility"].has_collection(node_coll_name):
567
+ collection = self.pymilvus_modules["Collection"](
568
+ name=node_coll_name, schema=schema
569
+ )
286
570
  else:
287
- collection = Collection(name=node_coll_name)
288
-
289
- # Create indexes
290
- collection.create_index(field_name="node_index",
291
- index_params={"index_type": "STL_SORT"},
292
- index_name="node_index_index")
293
- collection.create_index(field_name="node_name",
294
- index_params={"index_type": "INVERTED"},
295
- index_name="node_name_index")
296
- collection.create_index(field_name="node_type",
297
- index_params={"index_type": "INVERTED"},
298
- index_name="node_type_index")
299
- collection.create_index(field_name="desc",
300
- index_params={"index_type": "INVERTED"},
301
- index_name="desc_index")
302
- collection.create_index(field_name="desc_emb",
303
- index_params={"index_type": "GPU_CAGRA",
304
- "metric_type": "IP"},
305
- index_name="desc_emb_index")
306
- collection.create_index(field_name="feat_emb",
307
- index_params={"index_type": "GPU_CAGRA",
308
- "metric_type": "IP"},
309
- index_name="feat_emb_index")
571
+ collection = self.pymilvus_modules["Collection"](name=node_coll_name)
572
+
573
+ # Create indexes with dynamic parameters
574
+ collection.create_index(
575
+ field_name="node_index",
576
+ index_params={"index_type": "STL_SORT"},
577
+ index_name="node_index_index",
578
+ )
579
+ collection.create_index(
580
+ field_name="node_name",
581
+ index_params={"index_type": "INVERTED"},
582
+ index_name="node_name_index",
583
+ )
584
+ collection.create_index(
585
+ field_name="node_type",
586
+ index_params={"index_type": "INVERTED"},
587
+ index_name="node_type_index",
588
+ )
589
+ collection.create_index(
590
+ field_name="desc",
591
+ index_params={"index_type": "INVERTED"},
592
+ index_name="desc_index",
593
+ )
594
+ collection.create_index(
595
+ field_name="desc_emb",
596
+ index_params={
597
+ "index_type": self.vector_index_type,
598
+ "metric_type": self.metric_type,
599
+ },
600
+ index_name="desc_emb_index",
601
+ )
602
+ collection.create_index(
603
+ field_name="feat_emb",
604
+ index_params={
605
+ "index_type": self.vector_index_type,
606
+ "metric_type": self.metric_type,
607
+ },
608
+ index_name="feat_emb_index",
609
+ )
310
610
 
311
611
  # Prepare data
312
- desc_emb_norm = cp.asarray(nodes_df_["desc_emb"].list.leaves).astype(cp.float32).\
313
- reshape(nodes_df_.shape[0], -1)
314
- desc_emb_norm = self.normalize_matrix(desc_emb_norm, axis=1)
315
- feat_emb_norm = cp.asarray(nodes_df_["feat_emb"].list.leaves).astype(cp.float32).\
316
- reshape(nodes_df_.shape[0], -1)
317
- feat_emb_norm = self.normalize_matrix(feat_emb_norm, axis=1)
612
+ desc_emb_data = self._extract_embeddings(nodes_df_, "desc_emb")
613
+ feat_emb_data = self._extract_embeddings(nodes_df_, "feat_emb")
614
+
615
+ desc_emb_normalized = self._normalize_matrix(desc_emb_data, axis=1)
616
+ feat_emb_normalized = self._normalize_matrix(feat_emb_data, axis=1)
617
+
318
618
  data = [
319
- nodes_df_["node_index"].to_arrow().to_pylist(),
320
- nodes_df_["node_id"].to_arrow().to_pylist(),
321
- nodes_df_["node_name"].to_arrow().to_pylist(),
322
- nodes_df_["node_type"].to_arrow().to_pylist(),
323
- nodes_df_["desc"].to_arrow().to_pylist(),
324
- desc_emb_norm.tolist(), # Use normalized embeddings
325
- nodes_df_["feat"].to_arrow().to_pylist(),
326
- feat_emb_norm.tolist(), # Use normalized embeddings
619
+ self._to_list(nodes_df_["node_index"]),
620
+ self._to_list(nodes_df_["node_id"]),
621
+ self._to_list(nodes_df_["node_name"]),
622
+ self._to_list(nodes_df_["node_type"]),
623
+ self._to_list(nodes_df_["desc"]),
624
+ self._to_list(desc_emb_normalized),
625
+ self._to_list(nodes_df_["feat"]),
626
+ self._to_list(feat_emb_normalized),
327
627
  ]
328
628
 
329
629
  # Insert data in batches
330
630
  total_rows = len(data[0])
331
631
  for i in range(0, total_rows, self.batch_size):
332
- batch = [col[i:i + self.batch_size] for col in data]
632
+ batch = [col[i : i + self.batch_size] for col in data]
333
633
  collection.insert(batch)
334
634
 
335
635
  collection.flush()
336
- logger.info("Collection %s created with %d entities",
337
- node_coll_name, collection.num_entities)
636
+ logger.info(
637
+ "Collection %s created with %d entities",
638
+ node_coll_name,
639
+ collection.num_entities,
640
+ )
338
641
 
339
- def create_edges_collection(self,
340
- edges_enrichment_df: cudf.DataFrame,
341
- edges_embedding_df: List[cudf.DataFrame]):
342
- """Create and populate the edges collection."""
642
+ def create_edges_collection(self, edges_enrichment_df, edges_embedding_df: List):
643
+ """Create and populate the edges collection - exact original logic."""
343
644
  logger.info("Creating edges collection...")
344
645
 
345
646
  edge_coll_name = f"{self.milvus_database}_edges"
346
647
 
648
+ # Get embedding dimension from first chunk - exact original logic
649
+ if self.use_gpu:
650
+ emb_dim = len(
651
+ edges_embedding_df[0].loc[0, "edge_emb"]
652
+ ) # Original cudf access
653
+ else:
654
+ first_edge_emb = edges_embedding_df[0].iloc[0]["edge_emb"]
655
+ emb_dim = (
656
+ len(first_edge_emb)
657
+ if isinstance(first_edge_emb, list)
658
+ else len(first_edge_emb.tolist())
659
+ )
660
+
347
661
  edge_fields = [
348
- FieldSchema(name="triplet_index",
349
- dtype=DataType.INT64,
350
- is_primary=True,
351
- auto_id=False),
352
- FieldSchema(name="head_id",
353
- dtype=DataType.VARCHAR,
354
- max_length=1024),
355
- FieldSchema(name="head_index",
356
- dtype=DataType.INT64),
357
- FieldSchema(name="tail_id",
358
- dtype=DataType.VARCHAR,
359
- max_length=1024),
360
- FieldSchema(name="tail_index",
361
- dtype=DataType.INT64),
362
- FieldSchema(name="edge_type",
363
- dtype=DataType.VARCHAR,
364
- max_length=1024),
365
- FieldSchema(name="display_relation",
366
- dtype=DataType.VARCHAR,
367
- max_length=1024),
368
- FieldSchema(name="feat",
369
- dtype=DataType.VARCHAR,
370
- max_length=40960),
371
- FieldSchema(name="feat_emb",
372
- dtype=DataType.FLOAT_VECTOR,
373
- dim=len(edges_embedding_df[0].loc[0, 'edge_emb'])),
662
+ self.pymilvus_modules["FieldSchema"](
663
+ name="triplet_index",
664
+ dtype=self.pymilvus_modules["DataType"].INT64,
665
+ is_primary=True,
666
+ auto_id=False,
667
+ ),
668
+ self.pymilvus_modules["FieldSchema"](
669
+ name="head_id",
670
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
671
+ max_length=1024,
672
+ ),
673
+ self.pymilvus_modules["FieldSchema"](
674
+ name="head_index", dtype=self.pymilvus_modules["DataType"].INT64
675
+ ),
676
+ self.pymilvus_modules["FieldSchema"](
677
+ name="tail_id",
678
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
679
+ max_length=1024,
680
+ ),
681
+ self.pymilvus_modules["FieldSchema"](
682
+ name="tail_index", dtype=self.pymilvus_modules["DataType"].INT64
683
+ ),
684
+ self.pymilvus_modules["FieldSchema"](
685
+ name="edge_type",
686
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
687
+ max_length=1024,
688
+ ),
689
+ self.pymilvus_modules["FieldSchema"](
690
+ name="display_relation",
691
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
692
+ max_length=1024,
693
+ ),
694
+ self.pymilvus_modules["FieldSchema"](
695
+ name="feat",
696
+ dtype=self.pymilvus_modules["DataType"].VARCHAR,
697
+ max_length=40960,
698
+ ),
699
+ self.pymilvus_modules["FieldSchema"](
700
+ name="feat_emb",
701
+ dtype=self.pymilvus_modules["DataType"].FLOAT_VECTOR,
702
+ dim=emb_dim,
703
+ ),
374
704
  ]
375
- edge_schema = CollectionSchema(fields=edge_fields,
376
- description="Schema for edges collection")
377
705
 
378
- if not utility.has_collection(edge_coll_name):
379
- collection = Collection(name=edge_coll_name, schema=edge_schema)
706
+ edge_schema = self.pymilvus_modules["CollectionSchema"](
707
+ fields=edge_fields, description="Schema for edges collection"
708
+ )
709
+
710
+ if not self.pymilvus_modules["utility"].has_collection(edge_coll_name):
711
+ collection = self.pymilvus_modules["Collection"](
712
+ name=edge_coll_name, schema=edge_schema
713
+ )
380
714
  else:
381
- collection = Collection(name=edge_coll_name)
382
-
383
- # Create indexes
384
- collection.create_index(field_name="triplet_index",
385
- index_params={"index_type": "STL_SORT"},
386
- index_name="triplet_index_index")
387
- collection.create_index(field_name="head_index",
388
- index_params={"index_type": "STL_SORT"},
389
- index_name="head_index_index")
390
- collection.create_index(field_name="tail_index",
391
- index_params={"index_type": "STL_SORT"},
392
- index_name="tail_index_index")
393
- collection.create_index(field_name="feat_emb",
394
- index_params={"index_type": "GPU_CAGRA",
395
- "metric_type": "IP"},
396
- index_name="feat_emb_index")
397
-
398
- # Iterate over chunked edges embedding df
399
- for edges_df in tqdm(edges_embedding_df):
715
+ collection = self.pymilvus_modules["Collection"](name=edge_coll_name)
716
+
717
+ # Create indexes with dynamic parameters
718
+ collection.create_index(
719
+ field_name="triplet_index",
720
+ index_params={"index_type": "STL_SORT"},
721
+ index_name="triplet_index_index",
722
+ )
723
+ collection.create_index(
724
+ field_name="head_index",
725
+ index_params={"index_type": "STL_SORT"},
726
+ index_name="head_index_index",
727
+ )
728
+ collection.create_index(
729
+ field_name="tail_index",
730
+ index_params={"index_type": "STL_SORT"},
731
+ index_name="tail_index_index",
732
+ )
733
+ collection.create_index(
734
+ field_name="feat_emb",
735
+ index_params={
736
+ "index_type": self.vector_index_type,
737
+ "metric_type": self.metric_type,
738
+ },
739
+ index_name="feat_emb_index",
740
+ )
741
+
742
+ # Iterate over chunked edges embedding df - exact original logic
743
+ for edges_df in self.tqdm(edges_embedding_df, desc="Processing edge chunks"):
400
744
  # Merge enrichment with embedding
401
745
  merged_edges_df = edges_enrichment_df.merge(
402
- edges_df[["triplet_index", "edge_emb"]],
403
- on="triplet_index",
404
- how="inner"
746
+ edges_df[["triplet_index", "edge_emb"]], on="triplet_index", how="inner"
405
747
  )
406
748
 
407
- # Prepare data
408
- edge_emb_cp = cp.asarray(merged_edges_df["edge_emb"].list.leaves).astype(cp.float32).\
409
- reshape(merged_edges_df.shape[0], -1)
410
- edge_emb_norm = self.normalize_matrix(edge_emb_cp, axis=1)
749
+ # Prepare embeddings - exact original logic for GPU
750
+ if self.use_gpu:
751
+ edge_emb_cp = (
752
+ self.cp.asarray(merged_edges_df["edge_emb"].list.leaves)
753
+ .astype(self.cp.float32)
754
+ .reshape(merged_edges_df.shape[0], -1)
755
+ )
756
+ edge_emb_norm = self._normalize_matrix(edge_emb_cp, axis=1)
757
+ else:
758
+ edge_emb_data = self._extract_embeddings(merged_edges_df, "edge_emb")
759
+ edge_emb_norm = self._normalize_matrix(edge_emb_data, axis=1)
760
+
411
761
  data = [
412
- merged_edges_df["triplet_index"].to_arrow().to_pylist(),
413
- merged_edges_df["head_id"].to_arrow().to_pylist(),
414
- merged_edges_df["head_index"].to_arrow().to_pylist(),
415
- merged_edges_df["tail_id"].to_arrow().to_pylist(),
416
- merged_edges_df["tail_index"].to_arrow().to_pylist(),
417
- merged_edges_df["edge_type_str"].to_arrow().to_pylist(),
418
- merged_edges_df["display_relation"].to_arrow().to_pylist(),
419
- merged_edges_df["feat"].to_arrow().to_pylist(),
420
- edge_emb_norm.tolist(), # Use normalized embeddings
762
+ self._to_list(merged_edges_df["triplet_index"]),
763
+ self._to_list(merged_edges_df["head_id"]),
764
+ self._to_list(merged_edges_df["head_index"]),
765
+ self._to_list(merged_edges_df["tail_id"]),
766
+ self._to_list(merged_edges_df["tail_index"]),
767
+ self._to_list(merged_edges_df["edge_type_str"]), # Original field name
768
+ self._to_list(merged_edges_df["display_relation"]),
769
+ self._to_list(merged_edges_df["feat"]),
770
+ self._to_list(edge_emb_norm),
421
771
  ]
422
772
 
423
773
  # Insert data in batches
424
774
  total = len(data[0])
425
- for i in tqdm(range(0, total, self.batch_size), desc="Inserting edges"):
426
- batch_data = [d[i:i+self.batch_size] for d in data]
775
+ for i in self.tqdm(
776
+ range(0, total, self.batch_size), desc="Inserting edges"
777
+ ):
778
+ batch_data = [d[i : i + self.batch_size] for d in data]
427
779
  collection.insert(batch_data)
428
780
 
429
781
  collection.flush()
430
- logger.info("Edges collection created with %d entities", collection.num_entities)
782
+ logger.info(
783
+ "Edges collection created with %d entities", collection.num_entities
784
+ )
431
785
 
432
786
  def run(self):
433
787
  """Main execution method."""
434
788
  try:
435
- logger.info("Starting Milvus data loading process...")
789
+ logger.info("Starting Dynamic Milvus data loading process...")
790
+ logger.info(
791
+ "System: %s %s", self.detector.os_type, self.detector.architecture
792
+ )
793
+ logger.info("GPU acceleration: %s", self.use_gpu)
436
794
 
437
795
  # Connect to Milvus
438
796
  self.connect_to_milvus()
@@ -443,66 +801,110 @@ class MilvusDataLoader:
443
801
  # Prepare data
444
802
  logger.info("Data Preparation started...")
445
803
  # Get nodes enrichment and embedding dataframes
446
- nodes_enrichment_df = graph['nodes']['enrichment']
447
- nodes_embedding_df = graph['nodes']['embedding']
804
+ nodes_enrichment_df = graph["nodes"]["enrichment"]
805
+ nodes_embedding_df = graph["nodes"]["embedding"]
448
806
 
449
807
  # Get edges enrichment and embedding dataframes
450
- edges_enrichment_df = graph['edges']['enrichment']
451
- # !!consisted of a list of dataframes!!
452
- edges_embedding_df = graph['edges']['embedding']
808
+ edges_enrichment_df = graph["edges"]["enrichment"]
809
+ edges_embedding_df = graph["edges"]["embedding"] # List of dataframes
453
810
 
454
- # For nodes, we can directly merge enrichment and embedding
455
811
  # Merge nodes enrichment and embedding dataframes
456
812
  merged_nodes_df = nodes_enrichment_df.merge(
457
813
  nodes_embedding_df[["node_id", "desc_emb", "feat_emb"]],
458
814
  on="node_id",
459
- how="left"
815
+ how="left",
460
816
  )
461
817
 
462
818
  # Create collections and load data
463
819
  self.create_nodes_collection(merged_nodes_df)
464
820
  self.create_node_type_collections(merged_nodes_df)
465
- self.create_edges_collection(edges_enrichment_df,
466
- edges_embedding_df)
821
+ self.create_edges_collection(edges_enrichment_df, edges_embedding_df)
467
822
 
468
823
  # List all collections for verification
469
824
  logger.info("Data loading completed successfully!")
470
825
  logger.info("Created collections:")
471
- for coll in utility.list_collections():
472
- collection = Collection(name=coll)
826
+ for coll in self.pymilvus_modules["utility"].list_collections():
827
+ collection = self.pymilvus_modules["Collection"](name=coll)
473
828
  logger.info(" %s: %d entities", coll, collection.num_entities)
474
829
 
475
830
  except Exception as e:
476
831
  logger.error("Error during data loading: %s", str(e))
832
+ import traceback
833
+
834
+ logger.error("Full traceback: %s", traceback.format_exc())
477
835
  raise
478
836
 
479
837
 
480
838
  def main():
481
- """Main function to run the data loader."""
839
+ """Main function to run the dynamic data loader."""
482
840
  # Resolve the fallback data path relative to this script's location
483
841
  script_dir = os.path.dirname(os.path.abspath(__file__))
484
842
  default_data_dir = os.path.join(script_dir, "tests/files/biobridge_multimodal/")
485
843
 
486
- # Configuration
844
+ # Configuration with environment variable fallbacks (original keys plus auto_install_packages)
487
845
  config = {
488
- 'milvus_host': os.getenv('MILVUS_HOST', 'localhost'),
489
- 'milvus_port': os.getenv('MILVUS_PORT', '19530'),
490
- 'milvus_user': os.getenv('MILVUS_USER', 'root'),
491
- 'milvus_password': os.getenv('MILVUS_PASSWORD', 'Milvus'),
492
- 'milvus_database': os.getenv('MILVUS_DATABASE', 't2kg_primekg'),
493
- 'data_dir': os.getenv('DATA_DIR', default_data_dir),
494
- 'batch_size': int(os.getenv('BATCH_SIZE', '500')),
495
- 'chunk_size': int(os.getenv('CHUNK_SIZE', '5')),
846
+ "milvus_host": os.getenv("MILVUS_HOST", "localhost"),
847
+ "milvus_port": os.getenv("MILVUS_PORT", "19530"),
848
+ "milvus_user": os.getenv("MILVUS_USER", "root"),
849
+ "milvus_password": os.getenv("MILVUS_PASSWORD", "Milvus"),
850
+ "milvus_database": os.getenv("MILVUS_DATABASE", "t2kg_primekg"),
851
+ "data_dir": os.getenv("DATA_DIR", default_data_dir),
852
+ "batch_size": int(os.getenv("BATCH_SIZE", "500")),
853
+ "chunk_size": int(os.getenv("CHUNK_SIZE", "5")),
854
+ "auto_install_packages": os.getenv("AUTO_INSTALL_PACKAGES", "true").lower()
855
+ == "true",
496
856
  }
497
857
 
498
- # Print configuration for debugging
499
- print("[DATA LOADER] Configuration:")
500
- for key, value in config.items():
501
- print(f"[DATA LOADER] {key}: {value}")
858
+ # Override detection for testing/forcing specific modes
859
+ force_cpu = os.getenv("FORCE_CPU", "false").lower() == "true"
860
+ if force_cpu:
861
+ logger.info("FORCE_CPU environment variable set - forcing CPU mode")
502
862
 
503
- # Create and run data loader
504
- loader = MilvusDataLoader(config)
505
- loader.run()
863
+ # Log configuration for debugging (password values are masked)
864
+ logger.info("=== Dynamic Milvus Data Loader ===")
865
+ logger.info("Configuration:")
866
+ for key, value in config.items():
867
+ # Don't log sensitive information
868
+ if "password" in key.lower():
869
+ logger.info(" %s: %s", key, "*" * len(str(value)))
870
+ else:
871
+ logger.info(" %s: %s", key, value)
872
+
873
+ # Additional environment info
874
+ logger.info("Environment:")
875
+ logger.info(" Python version: %s", sys.version)
876
+ logger.info(" Platform: %s", platform.platform())
877
+ logger.info(" Force CPU mode: %s", force_cpu)
878
+ logger.info(" Script directory: %s", script_dir)
879
+ logger.info(" Default data directory: %s", default_data_dir)
880
+
881
+ try:
882
+ # Create and run dynamic data loader
883
+ loader = DynamicDataLoader(config)
884
+
885
+ # Override GPU detection if forced
886
+ if force_cpu:
887
+ loader.detector.use_gpu = False
888
+ loader.use_gpu = False
889
+ loader.normalize_vectors = False
890
+ loader.vector_index_type = "HNSW"
891
+ loader.metric_type = "COSINE"
892
+ logger.info("Forced CPU mode - updated loader settings")
893
+
894
+ # Run the data loading process
895
+ loader.run()
896
+
897
+ logger.info("=== Data Loading Completed Successfully ===")
898
+
899
+ except KeyboardInterrupt:
900
+ logger.info("Data loading interrupted by user")
901
+ sys.exit(1)
902
+ except Exception as e:
903
+ logger.error("Fatal error during data loading: %s", str(e))
904
+ import traceback
905
+
906
+ logger.error("Full traceback: %s", traceback.format_exc())
907
+ sys.exit(1)
506
908
 
507
909
 
508
910
  if __name__ == "__main__":
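Finally, a sketch of driving the loader end to end. The environment variables are the ones read in main() above; the values and the direct call to main() are illustrative, since the diff does not show how the container invokes the script:

import os

os.environ["MILVUS_HOST"] = "localhost"        # illustrative values throughout
os.environ["MILVUS_DATABASE"] = "t2kg_primekg"
os.environ["DATA_DIR"] = "/data/biobridge_multimodal"
os.environ["FORCE_CPU"] = "true"               # force HNSW/COSINE even if a GPU is present
os.environ["AUTO_INSTALL_PACKAGES"] = "false"  # skip pip installs in a prepared environment

main()  # builds the config, applies the FORCE_CPU override, and runs DynamicDataLoader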