lumen-resources 0.3.1__tar.gz → 0.4.0__tar.gz
This diff shows the changes between publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- {lumen_resources-0.3.1/src/lumen_resources.egg-info → lumen_resources-0.4.0}/PKG-INFO +1 -1
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/pyproject.toml +34 -4
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/__init__.py +2 -1
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/downloader.py +200 -103
- lumen_resources-0.4.0/src/lumen_resources/lumen_config.py +251 -0
- lumen_resources-0.4.0/src/lumen_resources/result_schemas/README.md +29 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/result_schemas/__init__.py +3 -3
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/result_schemas/embedding_v1.py +1 -1
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/result_schemas/face_v1.py +1 -1
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/result_schemas/labels_v1.py +1 -1
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/result_schemas/ocr_v1.py +1 -1
- lumen_resources-0.4.0/src/lumen_resources/result_schemas/text_generation_v1.py +89 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/schemas/config-schema.yaml +14 -5
- lumen_resources-0.4.0/src/lumen_resources/schemas/result_schemas/text_generation_v1.json +94 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0/src/lumen_resources.egg-info}/PKG-INFO +1 -1
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources.egg-info/SOURCES.txt +3 -2
- lumen_resources-0.3.1/src/lumen_resources/lumen_config.py +0 -459
- lumen_resources-0.3.1/src/lumen_resources/result_schemas/README.md +0 -14
- lumen_resources-0.3.1/uv.lock +0 -1320
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/.gitignore +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/README.md +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/docs/examples/clip_torch_cn.yaml +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/docs/examples/hub-service.yaml +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/docs/examples/model_info_template.json +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/docs/examples/single-service.yaml +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/setup.cfg +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/cli.py +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/exceptions.py +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/lumen_config_validator.py +1 -1
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/model_info.py +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/model_info_validator.py +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/platform.py +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/schemas/model_info-schema.json +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/schemas/result_schemas/embedding_v1.json +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/schemas/result_schemas/face_v1.json +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/schemas/result_schemas/labels_v1.json +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources/schemas/result_schemas/ocr_v1.json +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources.egg-info/dependency_links.txt +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources.egg-info/entry_points.txt +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources.egg-info/requires.txt +0 -0
- {lumen_resources-0.3.1 → lumen_resources-0.4.0}/src/lumen_resources.egg-info/top_level.txt +0 -0
|
@@ -33,10 +33,6 @@ config = ["pyyaml>=6.0.0"]
|
|
|
33
33
|
[project.scripts]
|
|
34
34
|
lumen-resources = "lumen_resources.cli:main"
|
|
35
35
|
|
|
36
|
-
[tool.ruff.lint]
|
|
37
|
-
select = ["E", "F", "W", "I", "N"]
|
|
38
|
-
ignore = []
|
|
39
|
-
|
|
40
36
|
[tool.pytest.ini_options]
|
|
41
37
|
testpaths = ["tests"]
|
|
42
38
|
python_files = ["test_*.py"]
|
|
@@ -55,3 +51,37 @@ where = ["src"]
|
|
|
55
51
|
|
|
56
52
|
[tool.setuptools.package-data]
|
|
57
53
|
"lumen_resources" = ["schemas/*.yaml", "schemas/*.json"]
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
[tool.ruff]
|
|
57
|
+
line-length = 88
|
|
58
|
+
target-version = "py310"
|
|
59
|
+
|
|
60
|
+
[tool.ruff.lint]
|
|
61
|
+
select = [
|
|
62
|
+
"E", # pycodestyle errors
|
|
63
|
+
"W", # pycodestyle warnings
|
|
64
|
+
"F", # pyflakes
|
|
65
|
+
"I", # isort
|
|
66
|
+
"B", # flake8-bugbear
|
|
67
|
+
"C4", # flake8-comprehensions
|
|
68
|
+
"UP", # pyupgrade
|
|
69
|
+
]
|
|
70
|
+
ignore = [
|
|
71
|
+
"E501", # line too long, handled by black
|
|
72
|
+
"B008", # do not perform function calls in argument defaults
|
|
73
|
+
"C901", # too complex
|
|
74
|
+
]
|
|
75
|
+
|
|
76
|
+
[tool.ruff.lint.per-file-ignores]
|
|
77
|
+
"proto/*" = [
|
|
78
|
+
"E402", # module level import not at top of file
|
|
79
|
+
"F401", # unused imports (generated files)
|
|
80
|
+
"F403", # star imports (generated files)
|
|
81
|
+
"F405", # undefined names from star imports (generated files)
|
|
82
|
+
]
|
|
83
|
+
|
|
84
|
+
[dependency-groups]
|
|
85
|
+
dev = [
|
|
86
|
+
"ruff>=0.14.2",
|
|
87
|
+
]
|
|
@@ -51,7 +51,7 @@ from .lumen_config import LumenConfig, Region, Runtime
|
|
|
51
51
|
from .lumen_config_validator import load_and_validate_config
|
|
52
52
|
from .model_info import Metadata, ModelInfo, Runtimes, Source
|
|
53
53
|
from .model_info_validator import load_and_validate_model_info
|
|
54
|
-
from .result_schemas import OCRV1, EmbeddingV1, FaceV1, LabelsV1
|
|
54
|
+
from .result_schemas import OCRV1, EmbeddingV1, FaceV1, LabelsV1, TextGenerationV1
|
|
55
55
|
|
|
56
56
|
__version__ = "0.1.0"
|
|
57
57
|
|
|
@@ -72,6 +72,7 @@ __all__ = [
|
|
|
72
72
|
"EmbeddingV1",
|
|
73
73
|
"LabelsV1",
|
|
74
74
|
"OCRV1",
|
|
75
|
+
"TextGenerationV1",
|
|
75
76
|
# Downloader
|
|
76
77
|
"Downloader",
|
|
77
78
|
"DownloadResult",
|
|
@@ -139,13 +139,18 @@ class Downloader:
|
|
|
139
139
|
|
|
140
140
|
for alias, model_config in service_config.models.items():
|
|
141
141
|
model_type = f"{service_name}:{alias}"
|
|
142
|
+
prefer_fp16 = False
|
|
143
|
+
if service_config.backend_settings:
|
|
144
|
+
prefer_fp16 = service_config.backend_settings.prefer_fp16 or False
|
|
142
145
|
|
|
143
146
|
if self.verbose:
|
|
144
147
|
print(f"\n📦 Processing {model_type.upper()}")
|
|
145
148
|
print(f" Model: {model_config.model}")
|
|
146
149
|
print(f" Runtime: {model_config.runtime.value}")
|
|
147
150
|
|
|
148
|
-
result = self._download_model(
|
|
151
|
+
result = self._download_model(
|
|
152
|
+
model_type, model_config, force, prefer_fp16
|
|
153
|
+
)
|
|
149
154
|
results[model_type] = result
|
|
150
155
|
|
|
151
156
|
# Print result
|
|
@@ -163,7 +168,7 @@ class Downloader:
|
|
|
163
168
|
|
|
164
169
|
return results
|
|
165
170
|
|
|
166
|
-
def _get_runtime_patterns(self, runtime: Runtime) -> list[str]:
|
|
171
|
+
def _get_runtime_patterns(self, runtime: Runtime, pref_fp16: bool) -> list[str]:
|
|
167
172
|
"""Get file patterns to download based on runtime.
|
|
168
173
|
|
|
169
174
|
Determines which file patterns to include in downloads based on the
|
|
@@ -171,12 +176,13 @@ class Downloader:
|
|
|
171
176
|
|
|
172
177
|
Args:
|
|
173
178
|
runtime: The model runtime (torch, onnx, rknn).
|
|
179
|
+
pref_fp16: Whether to prefer FP16 models over FP32.
|
|
174
180
|
|
|
175
181
|
Returns:
|
|
176
182
|
List of file glob patterns for the download.
|
|
177
183
|
|
|
178
184
|
Example:
|
|
179
|
-
>>> patterns = downloader._get_runtime_patterns(Runtime.torch)
|
|
185
|
+
>>> patterns = downloader._get_runtime_patterns(Runtime.torch, False)
|
|
180
186
|
>>> print("model_info.json" in patterns)
|
|
181
187
|
True
|
|
182
188
|
"""
|
|
@@ -203,14 +209,17 @@ class Downloader:
|
|
|
203
209
|
elif runtime == Runtime.onnx:
|
|
204
210
|
patterns.extend(
|
|
205
211
|
[
|
|
206
|
-
"*.onnx",
|
|
207
|
-
"*.ort",
|
|
208
212
|
"*vocab*",
|
|
209
213
|
"*tokenizer*",
|
|
210
214
|
"special_tokens_map.json",
|
|
211
215
|
"preprocessor_config.json",
|
|
212
216
|
]
|
|
213
217
|
)
|
|
218
|
+
# Only add one precision based on preference to save space
|
|
219
|
+
if pref_fp16:
|
|
220
|
+
patterns.extend(["*.fp16.onnx"])
|
|
221
|
+
else:
|
|
222
|
+
patterns.extend(["*.fp32.onnx"])
|
|
214
223
|
elif runtime == Runtime.rknn:
|
|
215
224
|
patterns.extend(
|
|
216
225
|
[
|
|
@@ -225,18 +234,94 @@ class Downloader:
|
|
|
225
234
|
return patterns
|
|
226
235
|
|
|
227
236
|
def _download_model(
|
|
228
|
-
self, model_type: str, model_config: ModelConfig, force: bool
|
|
237
|
+
self, model_type: str, model_config: ModelConfig, force: bool, pref_fp16: bool
|
|
238
|
+
) -> DownloadResult:
|
|
239
|
+
"""Download a single model with its runtime files using fallback strategy.
|
|
240
|
+
|
|
241
|
+
First attempts to download with the preferred precision (FP16/FP32),
|
|
242
|
+
and if that fails due to file mismatch, falls back to the other precision.
|
|
243
|
+
This ensures model availability while minimizing storage usage.
|
|
244
|
+
|
|
245
|
+
Args:
|
|
246
|
+
model_type: Identifier for the model (e.g., "clip:default").
|
|
247
|
+
model_config: Model configuration from LumenConfig.
|
|
248
|
+
force: Whether to force re-download even if already cached.
|
|
249
|
+
pref_fp16: Whether to prefer FP16 models over FP32.
|
|
250
|
+
|
|
251
|
+
Returns:
|
|
252
|
+
DownloadResult with success status, file paths, and error details.
|
|
253
|
+
|
|
254
|
+
Raises:
|
|
255
|
+
DownloadError: If platform download fails for both precisions.
|
|
256
|
+
ModelInfoError: If model_info.json is missing or invalid.
|
|
257
|
+
ValidationError: If model configuration is not supported.
|
|
258
|
+
"""
|
|
259
|
+
# First attempt with preferred precision
|
|
260
|
+
preferred_patterns = self._get_runtime_patterns(model_config.runtime, pref_fp16)
|
|
261
|
+
fallback_patterns = self._get_runtime_patterns(
|
|
262
|
+
model_config.runtime, not pref_fp16
|
|
263
|
+
)
|
|
264
|
+
|
|
265
|
+
# Try preferred precision first
|
|
266
|
+
try:
|
|
267
|
+
return self._download_model_with_patterns(
|
|
268
|
+
model_type, model_config, force, preferred_patterns, pref_fp16
|
|
269
|
+
)
|
|
270
|
+
except DownloadError as e:
|
|
271
|
+
# Check if this is a "no matching files" error that warrants fallback
|
|
272
|
+
if (
|
|
273
|
+
self._should_fallback_download(str(e))
|
|
274
|
+
and model_config.runtime == Runtime.onnx
|
|
275
|
+
):
|
|
276
|
+
precision = "FP16" if pref_fp16 else "FP32"
|
|
277
|
+
fallback_precision = "FP32" if pref_fp16 else "FP16"
|
|
278
|
+
if self.verbose:
|
|
279
|
+
print(
|
|
280
|
+
f" ⚠️ {precision} model not found, trying {fallback_precision}"
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
try:
|
|
284
|
+
return self._download_model_with_patterns(
|
|
285
|
+
model_type,
|
|
286
|
+
model_config,
|
|
287
|
+
force,
|
|
288
|
+
fallback_patterns,
|
|
289
|
+
not pref_fp16,
|
|
290
|
+
)
|
|
291
|
+
except DownloadError as fallback_error:
|
|
292
|
+
# If fallback also fails, report both errors
|
|
293
|
+
return DownloadResult(
|
|
294
|
+
model_type=model_type,
|
|
295
|
+
model_name=model_config.model,
|
|
296
|
+
runtime=model_config.runtime.value
|
|
297
|
+
if hasattr(model_config.runtime, "value")
|
|
298
|
+
else str(model_config.runtime),
|
|
299
|
+
success=False,
|
|
300
|
+
error=f"Failed to download with {precision}: {e}. "
|
|
301
|
+
f"Fallback with {fallback_precision} also failed: {fallback_error}",
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
# Non-fallbackable error or non-ONNX runtime, just report original error
|
|
305
|
+
raise
|
|
306
|
+
|
|
307
|
+
def _download_model_with_patterns(
|
|
308
|
+
self,
|
|
309
|
+
model_type: str,
|
|
310
|
+
model_config: ModelConfig,
|
|
311
|
+
force: bool,
|
|
312
|
+
patterns: list[str],
|
|
313
|
+
is_fp16: bool | None = None,
|
|
229
314
|
) -> DownloadResult:
|
|
230
|
-
"""Download a
|
|
315
|
+
"""Download a model with specific file patterns.
|
|
231
316
|
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
rollback on failure by cleaning up downloaded files.
|
|
317
|
+
This is the core download method that handles the actual downloading
|
|
318
|
+
with a given set of file patterns.
|
|
235
319
|
|
|
236
320
|
Args:
|
|
237
321
|
model_type: Identifier for the model (e.g., "clip:default").
|
|
238
322
|
model_config: Model configuration from LumenConfig.
|
|
239
323
|
force: Whether to force re-download even if already cached.
|
|
324
|
+
patterns: File patterns to include in the download.
|
|
240
325
|
|
|
241
326
|
Returns:
|
|
242
327
|
DownloadResult with success status, file paths, and error details.
|
|
@@ -255,8 +340,6 @@ class Downloader:
|
|
|
255
340
|
)
|
|
256
341
|
|
|
257
342
|
try:
|
|
258
|
-
# Phase 1: Prepare and download runtime + JSON only
|
|
259
|
-
patterns = self._get_runtime_patterns(model_config.runtime)
|
|
260
343
|
cache_dir = Path(self.config.metadata.cache_dir).expanduser()
|
|
261
344
|
|
|
262
345
|
model_path = self.platform.download_model(
|
|
@@ -278,7 +361,10 @@ class Downloader:
|
|
|
278
361
|
if model_config.dataset and model_info.datasets:
|
|
279
362
|
dataset_files = model_info.datasets.get(model_config.dataset)
|
|
280
363
|
if dataset_files:
|
|
281
|
-
for file_rel in [
|
|
364
|
+
for file_rel in [
|
|
365
|
+
dataset_files.labels,
|
|
366
|
+
dataset_files.embeddings,
|
|
367
|
+
]:
|
|
282
368
|
dataset_path = model_path / file_rel
|
|
283
369
|
if not dataset_path.exists():
|
|
284
370
|
# Download only the dataset file by its relative path
|
|
@@ -295,7 +381,9 @@ class Downloader:
|
|
|
295
381
|
)
|
|
296
382
|
|
|
297
383
|
# Final: File integrity validation
|
|
298
|
-
missing = self._validate_files(
|
|
384
|
+
missing = self._validate_files(
|
|
385
|
+
model_path, model_info, model_config, is_fp16
|
|
386
|
+
)
|
|
299
387
|
result.missing_files = missing
|
|
300
388
|
|
|
301
389
|
if missing:
|
|
@@ -318,142 +406,151 @@ class Downloader:
|
|
|
318
406
|
|
|
319
407
|
return result
|
|
320
408
|
|
|
409
|
+
def _should_fallback_download(self, error_message: str) -> bool:
|
|
410
|
+
"""Determine if a download error should trigger fallback to another precision.
|
|
411
|
+
|
|
412
|
+
Args:
|
|
413
|
+
error_message: The error message from the download attempt.
|
|
414
|
+
|
|
415
|
+
Returns:
|
|
416
|
+
True if the error suggests we should try the other precision, False otherwise.
|
|
417
|
+
"""
|
|
418
|
+
# Common patterns that indicate file matching issues
|
|
419
|
+
fallback_indicators = [
|
|
420
|
+
"No matching files found",
|
|
421
|
+
"No files matched the pattern",
|
|
422
|
+
"Cannot find any files matching",
|
|
423
|
+
"File pattern matched no files",
|
|
424
|
+
"No such file or directory", # Sometimes used for remote files
|
|
425
|
+
]
|
|
426
|
+
|
|
427
|
+
error_lower = error_message.lower()
|
|
428
|
+
return any(
|
|
429
|
+
indicator.lower() in error_lower for indicator in fallback_indicators
|
|
430
|
+
)
|
|
431
|
+
|
|
321
432
|
def _load_model_info(self, model_path: Path) -> ModelInfo:
|
|
322
433
|
"""Load and parse model_info.json using validator.
|
|
323
434
|
|
|
324
|
-
Loads the model_info.json file from the model directory and validates
|
|
325
|
-
it against the ModelInfo schema to ensure metadata integrity.
|
|
326
|
-
|
|
327
435
|
Args:
|
|
328
|
-
model_path:
|
|
436
|
+
model_path: Local path where model files are located.
|
|
329
437
|
|
|
330
438
|
Returns:
|
|
331
|
-
|
|
439
|
+
Parsed ModelInfo object.
|
|
332
440
|
|
|
333
441
|
Raises:
|
|
334
|
-
ModelInfoError: If model_info.json is missing or
|
|
335
|
-
|
|
336
|
-
Example:
|
|
337
|
-
>>> model_info = downloader._load_model_info(Path("/models/clip_vit_b32"))
|
|
338
|
-
>>> print(model_info.name)
|
|
339
|
-
'ViT-B-32'
|
|
442
|
+
ModelInfoError: If model_info.json is missing or invalid.
|
|
340
443
|
"""
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
msg = (
|
|
345
|
-
f"model_info.json not found in {model_path}. "
|
|
346
|
-
"Repository must contain model metadata."
|
|
347
|
-
)
|
|
348
|
-
raise ModelInfoError(msg)
|
|
444
|
+
info_path = model_path / "model_info.json"
|
|
445
|
+
if not info_path.exists():
|
|
446
|
+
raise ModelInfoError(f"Missing model_info.json in {model_path}")
|
|
349
447
|
|
|
350
448
|
try:
|
|
351
|
-
return load_and_validate_model_info(
|
|
449
|
+
return load_and_validate_model_info(info_path)
|
|
352
450
|
except Exception as e:
|
|
353
|
-
raise ModelInfoError(f"Failed to load
|
|
451
|
+
raise ModelInfoError(f"Failed to load model_info.json: {e}")
|
|
354
452
|
|
|
355
453
|
def _validate_model_config(
|
|
356
454
|
self, model_info: ModelInfo, model_config: ModelConfig
|
|
357
455
|
) -> None:
|
|
358
|
-
"""Validate
|
|
456
|
+
"""Validate model configuration against model_info.json.
|
|
359
457
|
|
|
360
|
-
Checks
|
|
361
|
-
|
|
458
|
+
Checks that the requested runtime and dataset are supported
|
|
459
|
+
by the model metadata.
|
|
362
460
|
|
|
363
461
|
Args:
|
|
364
|
-
model_info:
|
|
365
|
-
model_config:
|
|
462
|
+
model_info: Parsed model information.
|
|
463
|
+
model_config: Model configuration to validate.
|
|
366
464
|
|
|
367
465
|
Raises:
|
|
368
|
-
ValidationError: If
|
|
369
|
-
by the model according to its metadata.
|
|
370
|
-
|
|
371
|
-
Example:
|
|
372
|
-
>>> downloader._validate_model_config(model_info, model_config) # No exception
|
|
373
|
-
>>> # If runtime is not supported, raises ValidationError
|
|
466
|
+
ValidationError: If configuration is not supported by the model.
|
|
374
467
|
"""
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
468
|
+
# Validate runtime support
|
|
469
|
+
runtime_str = (
|
|
470
|
+
model_config.runtime.value
|
|
471
|
+
if hasattr(model_config.runtime, "value")
|
|
472
|
+
else str(model_config.runtime)
|
|
473
|
+
)
|
|
474
|
+
if runtime_str not in model_info.runtimes:
|
|
381
475
|
raise ValidationError(
|
|
382
|
-
f"
|
|
476
|
+
f"Runtime {runtime_str} not supported by model {model_config.model}. "
|
|
477
|
+
f"Supported runtimes: {', '.join(model_info.runtimes)}"
|
|
383
478
|
)
|
|
384
479
|
|
|
385
|
-
# Special check for RKNN device support
|
|
386
|
-
if runtime == Runtime.rknn and rknn_device:
|
|
387
|
-
if not runtime_config.devices or rknn_device not in runtime_config.devices:
|
|
388
|
-
msg = (
|
|
389
|
-
f"Model {model_config.model} does not support "
|
|
390
|
-
f"runtime '{runtime.value}' with device '{rknn_device}'"
|
|
391
|
-
)
|
|
392
|
-
raise ValidationError(msg)
|
|
393
|
-
|
|
394
480
|
# Validate dataset if specified
|
|
395
481
|
if model_config.dataset:
|
|
396
482
|
if (
|
|
397
483
|
not model_info.datasets
|
|
398
484
|
or model_config.dataset not in model_info.datasets
|
|
399
485
|
):
|
|
400
|
-
|
|
401
|
-
f"Dataset
|
|
402
|
-
f"
|
|
486
|
+
raise ValidationError(
|
|
487
|
+
f"Dataset {model_config.dataset} not supported by model {model_config.model}. "
|
|
488
|
+
f"Available datasets: {', '.join(model_info.datasets.keys() if model_info.datasets else [])}"
|
|
403
489
|
)
|
|
404
|
-
|
|
490
|
+
|
|
491
|
+
# Validate RKNN device if RKNN runtime
|
|
492
|
+
if model_config.runtime == Runtime.rknn and not model_config.rknn_device:
|
|
493
|
+
raise ValidationError(
|
|
494
|
+
f"RKNN runtime requires rknn_device specification for model {model_config.model}"
|
|
495
|
+
)
|
|
405
496
|
|
|
406
497
|
def _validate_files(
|
|
407
|
-
self,
|
|
498
|
+
self,
|
|
499
|
+
model_path: Path,
|
|
500
|
+
model_info: ModelInfo,
|
|
501
|
+
model_config: ModelConfig,
|
|
502
|
+
is_fp16: bool | None = None,
|
|
408
503
|
) -> list[str]:
|
|
409
|
-
"""Validate that required
|
|
504
|
+
"""Validate that all required files are present after download.
|
|
410
505
|
|
|
411
|
-
Checks
|
|
412
|
-
|
|
413
|
-
|
|
506
|
+
Checks model files, tokenizer files, and dataset files against
|
|
507
|
+
the model_info.json metadata based on the actual precision
|
|
508
|
+
downloaded.
|
|
414
509
|
|
|
415
510
|
Args:
|
|
416
|
-
model_path:
|
|
417
|
-
model_info:
|
|
418
|
-
model_config: Model configuration
|
|
511
|
+
model_path: Local path where model files are located.
|
|
512
|
+
model_info: Parsed model information.
|
|
513
|
+
model_config: Model configuration to validate.
|
|
514
|
+
is_fp16: Whether FP16 files were downloaded (None for non-ONNX
|
|
515
|
+
runtimes).
|
|
419
516
|
|
|
420
517
|
Returns:
|
|
421
|
-
List of missing file paths
|
|
422
|
-
Empty list if all required files are present.
|
|
518
|
+
List of missing file paths. Empty list if all files present.
|
|
423
519
|
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
>>> if not missing:
|
|
427
|
-
... print("All required files are present")
|
|
428
|
-
... else:
|
|
429
|
-
... print(f"Missing files: {missing}")
|
|
520
|
+
Raises:
|
|
521
|
+
ValidationError: If critical files are missing.
|
|
430
522
|
"""
|
|
431
523
|
missing: list[str] = []
|
|
524
|
+
runtime_str = (
|
|
525
|
+
model_config.runtime.value
|
|
526
|
+
if hasattr(model_config.runtime, "value")
|
|
527
|
+
else str(model_config.runtime)
|
|
528
|
+
)
|
|
432
529
|
|
|
433
|
-
#
|
|
434
|
-
runtime_config = model_info.runtimes.get(
|
|
435
|
-
if
|
|
436
|
-
return missing
|
|
437
|
-
|
|
438
|
-
# Get required files based on runtime type
|
|
439
|
-
if model_config.runtime == Runtime.rknn and model_config.rknn_device:
|
|
440
|
-
# RKNN files are organized by device
|
|
441
|
-
if isinstance(runtime_config.files, dict):
|
|
442
|
-
required_files = runtime_config.files.get(model_config.rknn_device, [])
|
|
443
|
-
else:
|
|
444
|
-
required_files = []
|
|
445
|
-
else:
|
|
446
|
-
# Other runtimes have simple list
|
|
530
|
+
# Check runtime files
|
|
531
|
+
runtime_config = model_info.runtimes.get(runtime_str)
|
|
532
|
+
if runtime_config and runtime_config.files:
|
|
447
533
|
if isinstance(runtime_config.files, list):
|
|
448
|
-
|
|
534
|
+
runtime_files = runtime_config.files
|
|
535
|
+
|
|
536
|
+
# For ONNX runtime, filter by precision if specified
|
|
537
|
+
if runtime_str == "onnx" and is_fp16 is not None:
|
|
538
|
+
precision_str = "fp16" if is_fp16 else "fp32"
|
|
539
|
+
runtime_files = [
|
|
540
|
+
f
|
|
541
|
+
for f in runtime_files
|
|
542
|
+
if not f.endswith((".fp16.onnx", ".fp32.onnx"))
|
|
543
|
+
or f.endswith(f".{precision_str}.onnx")
|
|
544
|
+
]
|
|
545
|
+
elif isinstance(runtime_config.files, dict) and model_config.rknn_device:
|
|
546
|
+
# RKNN files are organized by device
|
|
547
|
+
runtime_files = runtime_config.files.get(model_config.rknn_device, [])
|
|
449
548
|
else:
|
|
450
|
-
|
|
549
|
+
runtime_files = []
|
|
451
550
|
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
if not full_path.exists():
|
|
456
|
-
missing.append(file_path)
|
|
551
|
+
for file_name in runtime_files:
|
|
552
|
+
if not (model_path / file_name).exists():
|
|
553
|
+
missing.append(file_name)
|
|
457
554
|
|
|
458
555
|
# Check dataset files if specified
|
|
459
556
|
if model_config.dataset and model_info.datasets:
|