isa-model 0.1.0 (isa_model-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/mlflow_gateway/__init__.py +8 -0
- isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
- isa_model/deployment/unified_multimodal_client.py +341 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/triton_adapter.py +453 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +354 -0
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
- isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
- isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
- isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
- isa_model/inference/backends/__init__.py +53 -0
- isa_model/inference/backends/base_backend_client.py +26 -0
- isa_model/inference/backends/container_services.py +104 -0
- isa_model/inference/backends/local_services.py +72 -0
- isa_model/inference/backends/openai_client.py +130 -0
- isa_model/inference/backends/replicate_client.py +197 -0
- isa_model/inference/backends/third_party_services.py +239 -0
- isa_model/inference/backends/triton_client.py +97 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/client_sdk/__init__.py +0 -0
- isa_model/inference/client_sdk/client.py +134 -0
- isa_model/inference/client_sdk/client_data_std.py +34 -0
- isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +87 -0
- isa_model/inference/providers/replicate_provider.py +94 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +83 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/fish_speech/handler.py +215 -0
- isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
- isa_model/inference/services/audio/triton_speech_service.py +138 -0
- isa_model/inference/services/audio/whisper_service.py +186 -0
- isa_model/inference/services/audio/yyds_audio_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/base_tts_service.py +66 -0
- isa_model/inference/services/embedding/bge_service.py +183 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
- isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
- isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
- isa_model/inference/services/llm/__init__.py +16 -0
- isa_model/inference/services/llm/gemma_service.py +143 -0
- isa_model/inference/services/llm/llama_service.py +143 -0
- isa_model/inference/services/llm/ollama_llm_service.py +108 -0
- isa_model/inference/services/llm/openai_llm_service.py +129 -0
- isa_model/inference/services/llm/replicate_llm_service.py +179 -0
- isa_model/inference/services/llm/triton_llm_service.py +230 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/replicate_vision_service.py +241 -0
- isa_model/inference/services/vision/triton_vision_service.py +199 -0
- isa_model/inference/services/vision/yyds_vision_service.py +80 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.1.0.dist-info/METADATA +116 -0
- isa_model-0.1.0.dist-info/RECORD +117 -0
- isa_model-0.1.0.dist-info/WHEEL +5 -0
- isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
- isa_model-0.1.0.dist-info/top_level.txt +1 -0
isa_model/scripts/model_registry.py (new file)
@@ -0,0 +1,465 @@
"""
Model registry for managing model versions and stages.
"""

import os
import logging
from enum import Enum
from typing import Dict, List, Optional, Any, Union, Tuple
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities.model_registry import ModelVersion as MlflowModelVersion

logger = logging.getLogger(__name__)


class ModelStage(str, Enum):
    """Stages of a model in the registry."""

    STAGING = "Staging"
    PRODUCTION = "Production"
    ARCHIVED = "Archived"
    NONE = "None"


class ModelVersion:
    """
    Model version representation.

    This class provides a wrapper around MLflow's ModelVersion entity
    with additional functionality.
    """

    def __init__(self, mlflow_model_version: MlflowModelVersion):
        """
        Initialize a model version.

        Args:
            mlflow_model_version: MLflow model version entity
        """
        self.mlflow_version = mlflow_model_version

    @property
    def name(self) -> str:
        """Get the model name."""
        return self.mlflow_version.name

    @property
    def version(self) -> str:
        """Get the version number."""
        return self.mlflow_version.version

    @property
    def stage(self) -> ModelStage:
        """Get the model stage."""
        return ModelStage(self.mlflow_version.current_stage)

    @property
    def description(self) -> str:
        """Get the model description."""
        return self.mlflow_version.description or ""

    @property
    def source(self) -> str:
        """Get the model source path."""
        return self.mlflow_version.source

    @property
    def run_id(self) -> str:
        """Get the run ID that created this model."""
        return self.mlflow_version.run_id

    @property
    def tags(self) -> Dict[str, str]:
        """Get the model tags."""
        return self.mlflow_version.tags or {}

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the model version to a dictionary.

        Returns:
            Dictionary representation of the model version
        """
        return {
            "name": self.name,
            "version": self.version,
            "stage": self.stage.value,
            "description": self.description,
            "source": self.source,
            "run_id": self.run_id,
            "tags": self.tags
        }


class ModelRegistry:
    """
    Registry for managing models and their versions.

    This class provides methods to register models, transition between
    stages, and retrieve models from the registry.

    Example:
        ```python
        # Create model registry
        registry = ModelRegistry(
            tracking_uri="http://localhost:5000",
            registry_uri="http://localhost:5000"
        )

        # Register a model
        version = registry.register_model(
            name="llama-7b-finetuned",
            source="path/to/model",
            description="Llama 7B finetuned on custom data"
        )

        # Transition to staging
        registry.transition_model_version_stage(
            name="llama-7b-finetuned",
            version=version,
            stage=ModelStage.STAGING
        )

        # Get the latest staging model
        staging_model = registry.get_latest_model_version(
            name="llama-7b-finetuned",
            stage=ModelStage.STAGING
        )
        ```
    """

    def __init__(
        self,
        tracking_uri: Optional[str] = None,
        registry_uri: Optional[str] = None
    ):
        """
        Initialize the model registry.

        Args:
            tracking_uri: URI for MLflow tracking server
            registry_uri: URI for MLflow model registry
        """
        self.tracking_uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI", "")
        self.registry_uri = registry_uri or os.environ.get("MLFLOW_REGISTRY_URI", "")

        self._setup_mlflow()
        self.client = MlflowClient(tracking_uri=self.tracking_uri, registry_uri=self.registry_uri)

    def _setup_mlflow(self) -> None:
        """Set up MLflow configuration."""
        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)

        if self.registry_uri:
            mlflow.set_registry_uri(self.registry_uri)

    def create_registered_model(
        self,
        name: str,
        tags: Optional[Dict[str, str]] = None,
        description: Optional[str] = None
    ) -> bool:
        """
        Create a new registered model if it doesn't exist.

        Args:
            name: Name for the registered model
            tags: Tags for the registered model
            description: Description for the registered model

        Returns:
            True if successful, False otherwise
        """
        try:
            self.client.create_registered_model(
                name=name,
                tags=tags,
                description=description
            )
            logger.info(f"Created registered model: {name}")
            return True
        except mlflow.exceptions.MlflowException as e:
            if "already exists" in str(e):
                logger.info(f"Model {name} already exists, skipping creation")
                return True
            logger.error(f"Failed to create registered model: {e}")
            return False

    def register_model(
        self,
        name: str,
        source: str,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> Optional[str]:
        """
        Register a model with the registry.

        Args:
            name: Name for the registered model
            source: Source path of the model
            description: Description for the model version
            tags: Tags for the model version

        Returns:
            Version of the registered model or None if registration failed
        """
        # Ensure registered model exists
        self.create_registered_model(name)

        try:
            model_version = self.client.create_model_version(
                name=name,
                source=source,
                description=description,
                tags=tags
            )
            version = model_version.version
            logger.info(f"Registered model version: {name} v{version}")
            return version
        except mlflow.exceptions.MlflowException as e:
            logger.error(f"Failed to register model: {e}")
            return None

    def get_model_version(
        self,
        name: str,
        version: str
    ) -> Optional[ModelVersion]:
        """
        Get a specific model version.

        Args:
            name: Name of the registered model
            version: Version of the model

        Returns:
            ModelVersion entity or None if not found
        """
        try:
            mlflow_version = self.client.get_model_version(name=name, version=version)
            return ModelVersion(mlflow_version)
        except mlflow.exceptions.MlflowException as e:
            logger.error(f"Failed to get model version: {e}")
            return None

    def get_latest_model_version(
        self,
        name: str,
        stage: Optional[ModelStage] = None
    ) -> Optional[ModelVersion]:
        """
        Get the latest version of a model, optionally filtering by stage.

        Args:
            name: Name of the registered model
            stage: Stage to filter by

        Returns:
            Latest ModelVersion or None if not found
        """
        try:
            filter_string = f"name='{name}'"
            if stage:
                filter_string += f" AND stage='{stage.value}'"

            versions = self.client.search_model_versions(filter_string=filter_string)

            if not versions:
                logger.warning(f"No model versions found for {name}")
                return None

            # Sort by version number (newest first)
            sorted_versions = sorted(
                versions,
                key=lambda v: int(v.version),
                reverse=True
            )

            return ModelVersion(sorted_versions[0])
        except mlflow.exceptions.MlflowException as e:
            logger.error(f"Failed to get latest model version: {e}")
            return None

    def list_model_versions(
        self,
        name: str,
        stage: Optional[ModelStage] = None
    ) -> List[ModelVersion]:
        """
        List all versions of a model, optionally filtering by stage.

        Args:
            name: Name of the registered model
            stage: Stage to filter by

        Returns:
            List of ModelVersion entities
        """
        try:
            filter_string = f"name='{name}'"
            if stage:
                filter_string += f" AND stage='{stage.value}'"

            mlflow_versions = self.client.search_model_versions(filter_string=filter_string)

            return [ModelVersion(v) for v in mlflow_versions]
        except mlflow.exceptions.MlflowException as e:
            logger.error(f"Failed to list model versions: {e}")
            return []

    def list_registered_models(self) -> List[str]:
        """
        List all registered models in the registry.

        Returns:
            List of registered model names
        """
        try:
            models = self.client.list_registered_models()
            return [model.name for model in models]
        except mlflow.exceptions.MlflowException as e:
            logger.error(f"Failed to list registered models: {e}")
            return []

    def transition_model_version_stage(
        self,
        name: str,
        version: str,
        stage: ModelStage,
        archive_existing_versions: bool = False
    ) -> Optional[ModelVersion]:
        """
        Transition a model version to a different stage.

        Args:
            name: Name of the registered model
            version: Version of the model
            stage: Stage to transition to
            archive_existing_versions: Whether to archive existing versions in the target stage

        Returns:
            Updated ModelVersion or None if transition failed
        """
        try:
            mlflow_version = self.client.transition_model_version_stage(
                name=name,
                version=version,
                stage=stage.value,
                archive_existing_versions=archive_existing_versions
            )
            logger.info(f"Transitioned model {name} v{version} to {stage.value}")
            return ModelVersion(mlflow_version)
        except mlflow.exceptions.MlflowException as e:
            logger.error(f"Failed to transition model version: {e}")
            return None

    def update_model_version(
        self,
        name: str,
        version: str,
        description: Optional[str] = None
    ) -> Optional[ModelVersion]:
        """
        Update a model version's metadata.

        Args:
            name: Name of the registered model
            version: Version of the model
            description: New description for the model version

        Returns:
            Updated ModelVersion or None if update failed
        """
        try:
            mlflow_version = self.client.update_model_version(
                name=name,
                version=version,
                description=description
            )
            logger.info(f"Updated model version: {name} v{version}")
            return ModelVersion(mlflow_version)
        except mlflow.exceptions.MlflowException as e:
            logger.error(f"Failed to update model version: {e}")
            return None

    def set_model_version_tag(
        self,
        name: str,
        version: str,
        key: str,
        value: str
    ) -> bool:
        """
        Set a tag on a model version.

        Args:
            name: Name of the registered model
            version: Version of the model
            key: Tag key
            value: Tag value

        Returns:
            True if successful, False otherwise
        """
        try:
            self.client.set_model_version_tag(
                name=name,
                version=version,
                key=key,
                value=value
            )
            logger.debug(f"Set tag {key}={value} on model {name} v{version}")
            return True
        except mlflow.exceptions.MlflowException as e:
            logger.error(f"Failed to set model version tag: {e}")
            return False

    def delete_model_version(
        self,
        name: str,
        version: str
    ) -> bool:
        """
        Delete a model version.

        Args:
            name: Name of the registered model
            version: Version of the model

        Returns:
            True if successful, False otherwise
        """
        try:
            self.client.delete_model_version(
                name=name,
                version=version
            )
            logger.info(f"Deleted model version: {name} v{version}")
            return True
        except mlflow.exceptions.MlflowException as e:
            logger.error(f"Failed to delete model version: {e}")
            return False

    def delete_registered_model(
        self,
        name: str
    ) -> bool:
        """
        Delete a registered model and all its versions.

        Args:
            name: Name of the registered model

        Returns:
            True if successful, False otherwise
        """
        try:
            self.client.delete_registered_model(name=name)
            logger.info(f"Deleted registered model: {name}")
            return True
        except mlflow.exceptions.MlflowException as e:
            logger.error(f"Failed to delete registered model: {e}")
            return False
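For context, a minimal usage sketch of the registry API added in this file is shown below. It is not part of the package; the import path isa_model.scripts.model_registry is assumed from the file layout above, the model name reuses the example from the class docstring, and the tracking URI is a placeholder.

```python
# Sketch only: assumed import path and placeholder URI.
from isa_model.scripts.model_registry import ModelRegistry, ModelStage

registry = ModelRegistry(tracking_uri="http://localhost:5000")

# Register a model artifact and promote it to Production,
# archiving whatever version currently holds that stage.
version = registry.register_model(
    name="llama-7b-finetuned",
    source="path/to/model",
    description="Llama 7B finetuned on custom data",
)
if version is not None:
    registry.transition_model_version_stage(
        name="llama-7b-finetuned",
        version=version,
        stage=ModelStage.PRODUCTION,
        archive_existing_versions=True,
    )

# Fetch the latest Production version and inspect it as a dict.
prod = registry.get_latest_model_version("llama-7b-finetuned", stage=ModelStage.PRODUCTION)
if prod is not None:
    print(prod.to_dict())
```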
isa_model/scripts/start_mlflow.py (new file)
@@ -0,0 +1,95 @@
#!/usr/bin/env python
"""
Start an MLflow tracking server.

This script provides a simple way to start an MLflow tracking server
with configurable storage locations.
"""

import os
import argparse
import subprocess
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Start an MLflow tracking server")

    parser.add_argument(
        "--backend_store_uri",
        type=str,
        default="./mlruns",
        help="URI for the backend store (e.g., SQLite, MySQL, PostgreSQL)"
    )
    parser.add_argument(
        "--default_artifact_root",
        type=str,
        default="./mlartifacts",
        help="Directory or URI for storing artifacts"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Host to bind to"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=5000,
        help="Port to bind to"
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="Number of gunicorn workers"
    )

    return parser.parse_args()


def main():
    """Start the MLflow tracking server."""
    args = parse_args()

    # Create directories if they don't exist
    if args.backend_store_uri.startswith("./") or args.backend_store_uri.startswith("/"):
        os.makedirs(args.backend_store_uri, exist_ok=True)
        logger.info(f"Using backend store: {args.backend_store_uri}")

    if args.default_artifact_root.startswith("./") or args.default_artifact_root.startswith("/"):
        os.makedirs(args.default_artifact_root, exist_ok=True)
        logger.info(f"Using artifact root: {args.default_artifact_root}")

    # Build the MLflow command
    cmd = [
        "mlflow", "server",
        "--backend-store-uri", args.backend_store_uri,
        "--default-artifact-root", args.default_artifact_root,
        "--host", args.host,
        "--port", str(args.port),
        "--workers", str(args.workers)
    ]

    # Start the server
    logger.info(f"Starting MLflow server: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True)
    except KeyboardInterrupt:
        logger.info("MLflow server stopped by user")
    except Exception as e:
        logger.error(f"Error starting MLflow server: {e}")
        raise


if __name__ == "__main__":
    main()
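As a usage sketch (not part of the package), a client can pair the ModelRegistry from the previous file with a server launched by this script using its defaults (127.0.0.1:5000), by way of the MLFLOW_TRACKING_URI / MLFLOW_REGISTRY_URI environment variables that ModelRegistry falls back to. The import path is again assumed from the package layout; the script itself would be run first, e.g. directly as isa_model/scripts/start_mlflow.py.

```python
# Sketch: point the registry client at a server started by this script's defaults
# (mlflow server --host 127.0.0.1 --port 5000).
import os

from isa_model.scripts.model_registry import ModelRegistry  # assumed import path

os.environ.setdefault("MLFLOW_TRACKING_URI", "http://127.0.0.1:5000")
os.environ.setdefault("MLFLOW_REGISTRY_URI", "http://127.0.0.1:5000")

registry = ModelRegistry()  # picks up the env vars set above
print(registry.list_registered_models())
```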