lumen-resources 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -51,7 +51,7 @@ from .lumen_config import LumenConfig, Region, Runtime
51
51
  from .lumen_config_validator import load_and_validate_config
52
52
  from .model_info import Metadata, ModelInfo, Runtimes, Source
53
53
  from .model_info_validator import load_and_validate_model_info
54
- from .result_schemas import OCRV1, EmbeddingV1, FaceV1, LabelsV1
54
+ from .result_schemas import OCRV1, EmbeddingV1, FaceV1, LabelsV1, TextGenerationV1
55
55
 
56
56
  __version__ = "0.1.0"
57
57
 
@@ -72,6 +72,7 @@ __all__ = [
72
72
  "EmbeddingV1",
73
73
  "LabelsV1",
74
74
  "OCRV1",
75
+ "TextGenerationV1",
75
76
  # Downloader
76
77
  "Downloader",
77
78
  "DownloadResult",
@@ -139,13 +139,18 @@ class Downloader:
139
139
 
140
140
  for alias, model_config in service_config.models.items():
141
141
  model_type = f"{service_name}:{alias}"
142
+ prefer_fp16 = False
143
+ if service_config.backend_settings:
144
+ prefer_fp16 = service_config.backend_settings.prefer_fp16 or False
142
145
 
143
146
  if self.verbose:
144
147
  print(f"\n📦 Processing {model_type.upper()}")
145
148
  print(f" Model: {model_config.model}")
146
149
  print(f" Runtime: {model_config.runtime.value}")
147
150
 
148
- result = self._download_model(model_type, model_config, force)
151
+ result = self._download_model(
152
+ model_type, model_config, force, prefer_fp16
153
+ )
149
154
  results[model_type] = result
150
155
 
151
156
  # Print result
@@ -163,7 +168,7 @@ class Downloader:
163
168
 
164
169
  return results
165
170
 
166
- def _get_runtime_patterns(self, runtime: Runtime) -> list[str]:
171
+ def _get_runtime_patterns(self, runtime: Runtime, pref_fp16: bool) -> list[str]:
167
172
  """Get file patterns to download based on runtime.
168
173
 
169
174
  Determines which file patterns to include in downloads based on the
@@ -171,12 +176,13 @@ class Downloader:
171
176
 
172
177
  Args:
173
178
  runtime: The model runtime (torch, onnx, rknn).
179
+ pref_fp16: Whether to prefer FP16 models over FP32.
174
180
 
175
181
  Returns:
176
182
  List of file glob patterns for the download.
177
183
 
178
184
  Example:
179
- >>> patterns = downloader._get_runtime_patterns(Runtime.torch)
185
+ >>> patterns = downloader._get_runtime_patterns(Runtime.torch, False)
180
186
  >>> print("model_info.json" in patterns)
181
187
  True
182
188
  """
@@ -203,14 +209,17 @@ class Downloader:
203
209
  elif runtime == Runtime.onnx:
204
210
  patterns.extend(
205
211
  [
206
- "*.onnx",
207
- "*.ort",
208
212
  "*vocab*",
209
213
  "*tokenizer*",
210
214
  "special_tokens_map.json",
211
215
  "preprocessor_config.json",
212
216
  ]
213
217
  )
218
+ # Only add one precision based on preference to save space
219
+ if pref_fp16:
220
+ patterns.extend(["*.fp16.onnx"])
221
+ else:
222
+ patterns.extend(["*.fp32.onnx"])
214
223
  elif runtime == Runtime.rknn:
215
224
  patterns.extend(
216
225
  [
@@ -225,18 +234,94 @@ class Downloader:
225
234
  return patterns
226
235
 
227
236
  def _download_model(
228
- self, model_type: str, model_config: ModelConfig, force: bool
237
+ self, model_type: str, model_config: ModelConfig, force: bool, pref_fp16: bool
238
+ ) -> DownloadResult:
239
+ """Download a single model with its runtime files using fallback strategy.
240
+
241
+ First attempts to download with the preferred precision (FP16/FP32),
242
+ and if that fails due to file mismatch, falls back to the other precision.
243
+ This ensures model availability while minimizing storage usage.
244
+
245
+ Args:
246
+ model_type: Identifier for the model (e.g., "clip:default").
247
+ model_config: Model configuration from LumenConfig.
248
+ force: Whether to force re-download even if already cached.
249
+ pref_fp16: Whether to prefer FP16 models over FP32.
250
+
251
+ Returns:
252
+ DownloadResult with success status, file paths, and error details.
253
+
254
+ Raises:
255
+ DownloadError: If platform download fails for both precisions.
256
+ ModelInfoError: If model_info.json is missing or invalid.
257
+ ValidationError: If model configuration is not supported.
258
+ """
259
+ # First attempt with preferred precision
260
+ preferred_patterns = self._get_runtime_patterns(model_config.runtime, pref_fp16)
261
+ fallback_patterns = self._get_runtime_patterns(
262
+ model_config.runtime, not pref_fp16
263
+ )
264
+
265
+ # Try preferred precision first
266
+ try:
267
+ return self._download_model_with_patterns(
268
+ model_type, model_config, force, preferred_patterns, pref_fp16
269
+ )
270
+ except DownloadError as e:
271
+ # Check if this is a "no matching files" error that warrants fallback
272
+ if (
273
+ self._should_fallback_download(str(e))
274
+ and model_config.runtime == Runtime.onnx
275
+ ):
276
+ precision = "FP16" if pref_fp16 else "FP32"
277
+ fallback_precision = "FP32" if pref_fp16 else "FP16"
278
+ if self.verbose:
279
+ print(
280
+ f" ⚠️ {precision} model not found, trying {fallback_precision}"
281
+ )
282
+
283
+ try:
284
+ return self._download_model_with_patterns(
285
+ model_type,
286
+ model_config,
287
+ force,
288
+ fallback_patterns,
289
+ not pref_fp16,
290
+ )
291
+ except DownloadError as fallback_error:
292
+ # If fallback also fails, report both errors
293
+ return DownloadResult(
294
+ model_type=model_type,
295
+ model_name=model_config.model,
296
+ runtime=model_config.runtime.value
297
+ if hasattr(model_config.runtime, "value")
298
+ else str(model_config.runtime),
299
+ success=False,
300
+ error=f"Failed to download with {precision}: {e}. "
301
+ f"Fallback with {fallback_precision} also failed: {fallback_error}",
302
+ )
303
+
304
+ # Non-fallbackable error or non-ONNX runtime, just report original error
305
+ raise
306
+
307
+ def _download_model_with_patterns(
308
+ self,
309
+ model_type: str,
310
+ model_config: ModelConfig,
311
+ force: bool,
312
+ patterns: list[str],
313
+ is_fp16: bool | None = None,
229
314
  ) -> DownloadResult:
230
- """Download a single model with its runtime files.
315
+ """Download a model with specific file patterns.
231
316
 
232
- Handles the complete download process for a single model including
233
- runtime files, metadata validation, and integrity checks. Performs
234
- rollback on failure by cleaning up downloaded files.
317
+ This is the core download method that handles the actual downloading
318
+ with a given set of file patterns.
235
319
 
236
320
  Args:
237
321
  model_type: Identifier for the model (e.g., "clip:default").
238
322
  model_config: Model configuration from LumenConfig.
239
323
  force: Whether to force re-download even if already cached.
324
+ patterns: File patterns to include in the download.
240
325
 
241
326
  Returns:
242
327
  DownloadResult with success status, file paths, and error details.
@@ -255,8 +340,6 @@ class Downloader:
255
340
  )
256
341
 
257
342
  try:
258
- # Phase 1: Prepare and download runtime + JSON only
259
- patterns = self._get_runtime_patterns(model_config.runtime)
260
343
  cache_dir = Path(self.config.metadata.cache_dir).expanduser()
261
344
 
262
345
  model_path = self.platform.download_model(
@@ -278,7 +361,10 @@ class Downloader:
278
361
  if model_config.dataset and model_info.datasets:
279
362
  dataset_files = model_info.datasets.get(model_config.dataset)
280
363
  if dataset_files:
281
- for file_rel in [dataset_files.labels, dataset_files.embeddings]:
364
+ for file_rel in [
365
+ dataset_files.labels,
366
+ dataset_files.embeddings,
367
+ ]:
282
368
  dataset_path = model_path / file_rel
283
369
  if not dataset_path.exists():
284
370
  # Download only the dataset file by its relative path
@@ -295,7 +381,9 @@ class Downloader:
295
381
  )
296
382
 
297
383
  # Final: File integrity validation
298
- missing = self._validate_files(model_path, model_info, model_config)
384
+ missing = self._validate_files(
385
+ model_path, model_info, model_config, is_fp16
386
+ )
299
387
  result.missing_files = missing
300
388
 
301
389
  if missing:
@@ -318,142 +406,151 @@ class Downloader:
318
406
 
319
407
  return result
320
408
 
409
+ def _should_fallback_download(self, error_message: str) -> bool:
410
+ """Determine if a download error should trigger fallback to another precision.
411
+
412
+ Args:
413
+ error_message: The error message from the download attempt.
414
+
415
+ Returns:
416
+ True if the error suggests we should try the other precision, False otherwise.
417
+ """
418
+ # Common patterns that indicate file matching issues
419
+ fallback_indicators = [
420
+ "No matching files found",
421
+ "No files matched the pattern",
422
+ "Cannot find any files matching",
423
+ "File pattern matched no files",
424
+ "No such file or directory", # Sometimes used for remote files
425
+ ]
426
+
427
+ error_lower = error_message.lower()
428
+ return any(
429
+ indicator.lower() in error_lower for indicator in fallback_indicators
430
+ )
431
+
321
432
  def _load_model_info(self, model_path: Path) -> ModelInfo:
322
433
  """Load and parse model_info.json using validator.
323
434
 
324
- Loads the model_info.json file from the model directory and validates
325
- it against the ModelInfo schema to ensure metadata integrity.
326
-
327
435
  Args:
328
- model_path: Path to the downloaded model directory.
436
+ model_path: Local path where model files are located.
329
437
 
330
438
  Returns:
331
- Validated ModelInfo object containing model metadata.
439
+ Parsed ModelInfo object.
332
440
 
333
441
  Raises:
334
- ModelInfoError: If model_info.json is missing or fails validation.
335
-
336
- Example:
337
- >>> model_info = downloader._load_model_info(Path("/models/clip_vit_b32"))
338
- >>> print(model_info.name)
339
- 'ViT-B-32'
442
+ ModelInfoError: If model_info.json is missing or invalid.
340
443
  """
341
- info_file = model_path / "model_info.json"
342
-
343
- if not info_file.exists():
344
- msg = (
345
- f"model_info.json not found in {model_path}. "
346
- "Repository must contain model metadata."
347
- )
348
- raise ModelInfoError(msg)
444
+ info_path = model_path / "model_info.json"
445
+ if not info_path.exists():
446
+ raise ModelInfoError(f"Missing model_info.json in {model_path}")
349
447
 
350
448
  try:
351
- return load_and_validate_model_info(info_file)
449
+ return load_and_validate_model_info(info_path)
352
450
  except Exception as e:
353
- raise ModelInfoError(f"Failed to load/validate model_info.json: {e}")
451
+ raise ModelInfoError(f"Failed to load model_info.json: {e}")
354
452
 
355
453
  def _validate_model_config(
356
454
  self, model_info: ModelInfo, model_config: ModelConfig
357
455
  ) -> None:
358
- """Validate that model supports the requested configuration.
456
+ """Validate model configuration against model_info.json.
359
457
 
360
- Checks if the model metadata indicates support for the requested
361
- runtime, device (for RKNN), and dataset configurations.
458
+ Checks that the requested runtime and dataset are supported
459
+ by the model metadata.
362
460
 
363
461
  Args:
364
- model_info: Validated model metadata from model_info.json.
365
- model_config: Requested model configuration from LumenConfig.
462
+ model_info: Parsed model information.
463
+ model_config: Model configuration to validate.
366
464
 
367
465
  Raises:
368
- ValidationError: If the requested configuration is not supported
369
- by the model according to its metadata.
370
-
371
- Example:
372
- >>> downloader._validate_model_config(model_info, model_config) # No exception
373
- >>> # If runtime is not supported, raises ValidationError
466
+ ValidationError: If configuration is not supported by the model.
374
467
  """
375
- runtime = model_config.runtime
376
- rknn_device = model_config.rknn_device
377
-
378
- # Check if runtime is available
379
- runtime_config = model_info.runtimes.get(runtime.value)
380
- if not runtime_config or not runtime_config.available:
468
+ # Validate runtime support
469
+ runtime_str = (
470
+ model_config.runtime.value
471
+ if hasattr(model_config.runtime, "value")
472
+ else str(model_config.runtime)
473
+ )
474
+ if runtime_str not in model_info.runtimes:
381
475
  raise ValidationError(
382
- f"Model {model_config.model} does not support runtime '{runtime.value}'"
476
+ f"Runtime {runtime_str} not supported by model {model_config.model}. "
477
+ f"Supported runtimes: {', '.join(model_info.runtimes)}"
383
478
  )
384
479
 
385
- # Special check for RKNN device support
386
- if runtime == Runtime.rknn and rknn_device:
387
- if not runtime_config.devices or rknn_device not in runtime_config.devices:
388
- msg = (
389
- f"Model {model_config.model} does not support "
390
- f"runtime '{runtime.value}' with device '{rknn_device}'"
391
- )
392
- raise ValidationError(msg)
393
-
394
480
  # Validate dataset if specified
395
481
  if model_config.dataset:
396
482
  if (
397
483
  not model_info.datasets
398
484
  or model_config.dataset not in model_info.datasets
399
485
  ):
400
- msg = (
401
- f"Dataset '{model_config.dataset}' not available for "
402
- f"model {model_config.model}"
486
+ raise ValidationError(
487
+ f"Dataset {model_config.dataset} not supported by model {model_config.model}. "
488
+ f"Available datasets: {', '.join(model_info.datasets.keys() if model_info.datasets else [])}"
403
489
  )
404
- raise ValidationError(msg)
490
+
491
+ # Validate RKNN device if RKNN runtime
492
+ if model_config.runtime == Runtime.rknn and not model_config.rknn_device:
493
+ raise ValidationError(
494
+ f"RKNN runtime requires rknn_device specification for model {model_config.model}"
495
+ )
405
496
 
406
497
  def _validate_files(
407
- self, model_path: Path, model_info: ModelInfo, model_config: ModelConfig
498
+ self,
499
+ model_path: Path,
500
+ model_info: ModelInfo,
501
+ model_config: ModelConfig,
502
+ is_fp16: bool | None = None,
408
503
  ) -> list[str]:
409
- """Validate that required model files exist.
504
+ """Validate that all required files are present after download.
410
505
 
411
- Checks that all required files for the specified runtime and device
412
- are present in the downloaded model directory. Also validates dataset
413
- files if specified in the configuration.
506
+ Checks model files, tokenizer files, and dataset files against
507
+ the model_info.json metadata based on the actual precision
508
+ downloaded.
414
509
 
415
510
  Args:
416
- model_path: Path to the downloaded model directory.
417
- model_info: Validated model metadata from model_info.json.
418
- model_config: Model configuration specifying runtime and dataset.
511
+ model_path: Local path where model files are located.
512
+ model_info: Parsed model information.
513
+ model_config: Model configuration to validate.
514
+ is_fp16: Whether FP16 files were downloaded (None for non-ONNX
515
+ runtimes).
419
516
 
420
517
  Returns:
421
- List of missing file paths relative to model directory.
422
- Empty list if all required files are present.
518
+ List of missing file paths. Empty list if all files present.
423
519
 
424
- Example:
425
- >>> missing = downloader._validate_files(model_path, model_info, model_config)
426
- >>> if not missing:
427
- ... print("All required files are present")
428
- ... else:
429
- ... print(f"Missing files: {missing}")
520
+ Raises:
521
+ ValidationError: If critical files are missing.
430
522
  """
431
523
  missing: list[str] = []
524
+ runtime_str = (
525
+ model_config.runtime.value
526
+ if hasattr(model_config.runtime, "value")
527
+ else str(model_config.runtime)
528
+ )
432
529
 
433
- # Get runtime configuration
434
- runtime_config = model_info.runtimes.get(model_config.runtime.value)
435
- if not runtime_config or not runtime_config.files:
436
- return missing
437
-
438
- # Get required files based on runtime type
439
- if model_config.runtime == Runtime.rknn and model_config.rknn_device:
440
- # RKNN files are organized by device
441
- if isinstance(runtime_config.files, dict):
442
- required_files = runtime_config.files.get(model_config.rknn_device, [])
443
- else:
444
- required_files = []
445
- else:
446
- # Other runtimes have simple list
530
+ # Check runtime files
531
+ runtime_config = model_info.runtimes.get(runtime_str)
532
+ if runtime_config and runtime_config.files:
447
533
  if isinstance(runtime_config.files, list):
448
- required_files = runtime_config.files
534
+ runtime_files = runtime_config.files
535
+
536
+ # For ONNX runtime, filter by precision if specified
537
+ if runtime_str == "onnx" and is_fp16 is not None:
538
+ precision_str = "fp16" if is_fp16 else "fp32"
539
+ runtime_files = [
540
+ f
541
+ for f in runtime_files
542
+ if not f.endswith((".fp16.onnx", ".fp32.onnx"))
543
+ or f.endswith(f".{precision_str}.onnx")
544
+ ]
545
+ elif isinstance(runtime_config.files, dict) and model_config.rknn_device:
546
+ # RKNN files are organized by device
547
+ runtime_files = runtime_config.files.get(model_config.rknn_device, [])
449
548
  else:
450
- required_files = []
549
+ runtime_files = []
451
550
 
452
- # Check each required file
453
- for file_path in required_files:
454
- full_path = model_path / file_path
455
- if not full_path.exists():
456
- missing.append(file_path)
551
+ for file_name in runtime_files:
552
+ if not (model_path / file_name).exists():
553
+ missing.append(file_name)
457
554
 
458
555
  # Check dataset files if specified
459
556
  if model_config.dataset and model_info.datasets: