isa-model 0.3.9__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/client.py +732 -565
- isa_model/core/cache/redis_cache.py +401 -0
- isa_model/core/config/config_manager.py +53 -10
- isa_model/core/config.py +1 -1
- isa_model/core/database/__init__.py +1 -0
- isa_model/core/database/migrations.py +277 -0
- isa_model/core/database/supabase_client.py +123 -0
- isa_model/core/models/__init__.py +37 -0
- isa_model/core/models/model_billing_tracker.py +60 -88
- isa_model/core/models/model_manager.py +36 -18
- isa_model/core/models/model_repo.py +44 -38
- isa_model/core/models/model_statistics_tracker.py +234 -0
- isa_model/core/models/model_storage.py +0 -1
- isa_model/core/models/model_version_manager.py +959 -0
- isa_model/core/pricing_manager.py +2 -249
- isa_model/core/resilience/circuit_breaker.py +366 -0
- isa_model/core/security/secrets.py +358 -0
- isa_model/core/services/__init__.py +2 -4
- isa_model/core/services/intelligent_model_selector.py +101 -370
- isa_model/core/storage/hf_storage.py +1 -1
- isa_model/core/types.py +7 -0
- isa_model/deployment/cloud/modal/isa_audio_chatTTS_service.py +520 -0
- isa_model/deployment/cloud/modal/isa_audio_fish_service.py +0 -0
- isa_model/deployment/cloud/modal/isa_audio_openvoice_service.py +758 -0
- isa_model/deployment/cloud/modal/isa_audio_service_v2.py +1044 -0
- isa_model/deployment/cloud/modal/isa_embed_rerank_service.py +296 -0
- isa_model/deployment/cloud/modal/isa_video_hunyuan_service.py +423 -0
- isa_model/deployment/cloud/modal/isa_vision_ocr_service.py +519 -0
- isa_model/deployment/cloud/modal/isa_vision_qwen25_service.py +709 -0
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +467 -323
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +607 -180
- isa_model/deployment/cloud/modal/isa_vision_ui_service_optimized.py +660 -0
- isa_model/deployment/core/deployment_manager.py +6 -4
- isa_model/deployment/services/auto_hf_modal_deployer.py +894 -0
- isa_model/eval/benchmarks/__init__.py +27 -0
- isa_model/eval/benchmarks/multimodal_datasets.py +460 -0
- isa_model/eval/benchmarks.py +244 -12
- isa_model/eval/evaluators/__init__.py +8 -2
- isa_model/eval/evaluators/audio_evaluator.py +727 -0
- isa_model/eval/evaluators/embedding_evaluator.py +742 -0
- isa_model/eval/evaluators/vision_evaluator.py +564 -0
- isa_model/eval/example_evaluation.py +395 -0
- isa_model/eval/factory.py +272 -5
- isa_model/eval/isa_benchmarks.py +700 -0
- isa_model/eval/isa_integration.py +582 -0
- isa_model/eval/metrics.py +159 -6
- isa_model/eval/tests/unit/test_basic.py +396 -0
- isa_model/inference/ai_factory.py +44 -8
- isa_model/inference/services/audio/__init__.py +21 -0
- isa_model/inference/services/audio/base_realtime_service.py +225 -0
- isa_model/inference/services/audio/isa_tts_service.py +0 -0
- isa_model/inference/services/audio/openai_realtime_service.py +320 -124
- isa_model/inference/services/audio/openai_stt_service.py +32 -6
- isa_model/inference/services/base_service.py +17 -1
- isa_model/inference/services/embedding/__init__.py +13 -0
- isa_model/inference/services/embedding/base_embed_service.py +111 -8
- isa_model/inference/services/embedding/isa_embed_service.py +305 -0
- isa_model/inference/services/embedding/openai_embed_service.py +2 -4
- isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
- isa_model/inference/services/img/__init__.py +2 -2
- isa_model/inference/services/img/base_image_gen_service.py +24 -7
- isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
- isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
- isa_model/inference/services/img/services/replicate_flux.py +226 -0
- isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
- isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
- isa_model/inference/services/img/tests/test_img_client.py +297 -0
- isa_model/inference/services/llm/base_llm_service.py +30 -6
- isa_model/inference/services/llm/helpers/llm_adapter.py +63 -9
- isa_model/inference/services/llm/ollama_llm_service.py +2 -1
- isa_model/inference/services/llm/openai_llm_service.py +652 -55
- isa_model/inference/services/llm/yyds_llm_service.py +2 -1
- isa_model/inference/services/vision/__init__.py +5 -5
- isa_model/inference/services/vision/base_vision_service.py +118 -185
- isa_model/inference/services/vision/helpers/image_utils.py +11 -5
- isa_model/inference/services/vision/isa_vision_service.py +573 -0
- isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
- isa_model/serving/api/fastapi_server.py +88 -16
- isa_model/serving/api/middleware/auth.py +311 -0
- isa_model/serving/api/middleware/security.py +278 -0
- isa_model/serving/api/routes/analytics.py +486 -0
- isa_model/serving/api/routes/deployments.py +339 -0
- isa_model/serving/api/routes/evaluations.py +579 -0
- isa_model/serving/api/routes/logs.py +430 -0
- isa_model/serving/api/routes/settings.py +582 -0
- isa_model/serving/api/routes/unified.py +324 -165
- isa_model/serving/api/startup.py +304 -0
- isa_model/serving/modal_proxy_server.py +249 -0
- isa_model/training/__init__.py +100 -6
- isa_model/training/core/__init__.py +4 -1
- isa_model/training/examples/intelligent_training_example.py +281 -0
- isa_model/training/intelligent/__init__.py +25 -0
- isa_model/training/intelligent/decision_engine.py +643 -0
- isa_model/training/intelligent/intelligent_factory.py +888 -0
- isa_model/training/intelligent/knowledge_base.py +751 -0
- isa_model/training/intelligent/resource_optimizer.py +839 -0
- isa_model/training/intelligent/task_classifier.py +576 -0
- isa_model/training/storage/__init__.py +24 -0
- isa_model/training/storage/core_integration.py +439 -0
- isa_model/training/storage/training_repository.py +552 -0
- isa_model/training/storage/training_storage.py +628 -0
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/METADATA +13 -1
- isa_model-0.4.0.dist-info/RECORD +182 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
- isa_model/deployment/cloud/modal/register_models.py +0 -321
- isa_model/inference/adapter/unified_api.py +0 -248
- isa_model/inference/services/helpers/stacked_config.py +0 -148
- isa_model/inference/services/img/flux_professional_service.py +0 -603
- isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/others/table_transformer_service.py +0 -61
- isa_model/inference/services/vision/doc_analysis_service.py +0 -640
- isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/vision/ui_analysis_service.py +0 -823
- isa_model/scripts/inference_tracker.py +0 -283
- isa_model/scripts/mlflow_manager.py +0 -379
- isa_model/scripts/model_registry.py +0 -465
- isa_model/scripts/register_models.py +0 -370
- isa_model/scripts/register_models_with_embeddings.py +0 -510
- isa_model/scripts/start_mlflow.py +0 -95
- isa_model/scripts/training_tracker.py +0 -257
- isa_model-0.3.9.dist-info/RECORD +0 -138
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/WHEEL +0 -0
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/top_level.txt +0 -0
isa_model/scripts/inference_tracker.py
@@ -1,283 +0,0 @@
-"""
-MLflow tracker for inference workflows.
-"""
-
-import os
-import json
-import time
-import logging
-from typing import Dict, List, Optional, Any, Union
-from contextlib import contextmanager
-
-from .mlflow_manager import MLflowManager, ExperimentType
-from .model_registry import ModelRegistry, ModelStage, ModelVersion
-
-
-logger = logging.getLogger(__name__)
-
-
-class InferenceTracker:
-    """
-    Tracker for model inference workflows.
-
-    This class provides utilities to track model inference using MLflow,
-    including performance metrics and input/output logging.
-
-    Example:
-        ```python
-        # Initialize tracker
-        tracker = InferenceTracker(
-            tracking_uri="http://localhost:5000"
-        )
-
-        # Get model from registry
-        model_version = tracker.get_production_model("llama-7b")
-
-        # Track inference
-        with tracker.track_inference(
-            model_name="llama-7b",
-            model_version=model_version.version
-        ):
-            # Start timer
-            start_time = time.time()
-
-            # Generate text
-            output = model.generate(prompt)
-
-            # Log inference
-            tracker.log_inference(
-                input=prompt,
-                output=output,
-                latency_ms=(time.time() - start_time) * 1000
-            )
-        ```
-    """
-
-    def __init__(
-        self,
-        tracking_uri: Optional[str] = None,
-        artifact_uri: Optional[str] = None,
-        registry_uri: Optional[str] = None
-    ):
-        """
-        Initialize the inference tracker.
-
-        Args:
-            tracking_uri: URI for MLflow tracking server
-            artifact_uri: URI for MLflow artifacts
-            registry_uri: URI for MLflow model registry
-        """
-        self.mlflow_manager = MLflowManager(
-            tracking_uri=tracking_uri,
-            artifact_uri=artifact_uri,
-            registry_uri=registry_uri
-        )
-        self.model_registry = ModelRegistry(
-            tracking_uri=tracking_uri,
-            registry_uri=registry_uri
-        )
-        self.current_run_info = {}
-        self.inference_samples = []
-
-    def get_production_model(self, model_name: str) -> Optional[ModelVersion]:
-        """
-        Get the production version of a model.
-
-        Args:
-            model_name: Name of the model
-
-        Returns:
-            Production ModelVersion or None if not found
-        """
-        return self.model_registry.get_latest_model_version(
-            name=model_name,
-            stage=ModelStage.PRODUCTION
-        )
-
-    def get_staging_model(self, model_name: str) -> Optional[ModelVersion]:
-        """
-        Get the staging version of a model.
-
-        Args:
-            model_name: Name of the model
-
-        Returns:
-            Staging ModelVersion or None if not found
-        """
-        return self.model_registry.get_latest_model_version(
-            name=model_name,
-            stage=ModelStage.STAGING
-        )
-
-    @contextmanager
-    def track_inference(
-        self,
-        model_name: str,
-        model_version: Optional[str] = None,
-        batch_size: Optional[int] = None,
-        tags: Optional[Dict[str, str]] = None
-    ):
-        """
-        Track model inference with MLflow.
-
-        Args:
-            model_name: Name of the model
-            model_version: Version of the model
-            batch_size: Batch size for inference
-            tags: Tags for the run
-
-        Yields:
-            Dictionary with run information
-        """
-        run_info = {
-            "model_name": model_name,
-            "model_version": model_version,
-            "batch_size": batch_size,
-            "start_time": time.time(),
-            "metrics": {}
-        }
-
-        # Prepare tags
-        if tags is None:
-            tags = {}
-
-        tags["model_name"] = model_name
-        if model_version:
-            tags["model_version"] = model_version
-
-        if batch_size:
-            tags["batch_size"] = str(batch_size)
-
-        # Start the MLflow run
-        with self.mlflow_manager.start_run(
-            experiment_type=ExperimentType.INFERENCE,
-            model_name=model_name,
-            tags=tags
-        ) as run:
-            run_info["run_id"] = run.info.run_id
-            run_info["experiment_id"] = run.info.experiment_id
-
-            # Reset inference samples
-            self.inference_samples = []
-
-            self.current_run_info = run_info
-            try:
-                yield run_info
-
-                # Calculate and log summary metrics
-                self._log_summary_metrics()
-
-                # Save inference samples
-                if self.inference_samples:
-                    self._save_inference_samples()
-
-            finally:
-                run_info["end_time"] = time.time()
-                run_info["duration"] = run_info["end_time"] - run_info["start_time"]
-
-                # Log duration
-                self.mlflow_manager.log_metrics({
-                    "duration_seconds": run_info["duration"]
-                })
-
-                self.current_run_info = {}
-
-    def log_inference(
-        self,
-        input: str,
-        output: str,
-        latency_ms: Optional[float] = None,
-        token_count: Optional[int] = None,
-        tokens_per_second: Optional[float] = None,
-        metadata: Optional[Dict[str, Any]] = None
-    ) -> None:
-        """
-        Log an inference sample.
-
-        Args:
-            input: Input prompt
-            output: Generated output
-            latency_ms: Latency in milliseconds
-            token_count: Number of tokens generated
-            tokens_per_second: Tokens per second
-            metadata: Additional metadata
-        """
-        if not self.current_run_info:
-            logger.warning("No active run. Inference will not be logged.")
-            return
-
-        sample = {
-            "input": input,
-            "output": output,
-            "timestamp": time.time()
-        }
-
-        if latency_ms is not None:
-            sample["latency_ms"] = latency_ms
-
-        if token_count is not None:
-            sample["token_count"] = token_count
-
-        if tokens_per_second is not None:
-            sample["tokens_per_second"] = tokens_per_second
-
-        if metadata:
-            sample["metadata"] = metadata
-
-        self.inference_samples.append(sample)
-
-        # Log individual metrics
-        metrics = {}
-        if latency_ms is not None:
-            metrics["latency_ms"] = latency_ms
-
-        if token_count is not None:
-            metrics["token_count"] = token_count
-
-        if tokens_per_second is not None:
-            metrics["tokens_per_second"] = tokens_per_second
-
-        if metrics:
-            self.mlflow_manager.log_metrics(metrics)
-
-    def _log_summary_metrics(self) -> None:
-        """Log summary metrics based on all inference samples."""
-        if not self.inference_samples:
-            return
-
-        latencies = [s.get("latency_ms") for s in self.inference_samples if "latency_ms" in s]
-        token_counts = [s.get("token_count") for s in self.inference_samples if "token_count" in s]
-        tokens_per_second = [s.get("tokens_per_second") for s in self.inference_samples if "tokens_per_second" in s]
-
-        metrics = {
-            "inference_count": len(self.inference_samples)
-        }
-
-        if latencies:
-            metrics["avg_latency_ms"] = sum(latencies) / len(latencies)
-            metrics["min_latency_ms"] = min(latencies)
-            metrics["max_latency_ms"] = max(latencies)
-
-        if token_counts:
-            metrics["avg_token_count"] = sum(token_counts) / len(token_counts)
-            metrics["total_tokens"] = sum(token_counts)
-
-        if tokens_per_second:
-            metrics["avg_tokens_per_second"] = sum(tokens_per_second) / len(tokens_per_second)
-
-        self.mlflow_manager.log_metrics(metrics)
-
-    def _save_inference_samples(self) -> None:
-        """Save inference samples as an artifact."""
-        import tempfile
-
-        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as f:
-            json.dump(self.inference_samples, f, indent=2)
-            temp_path = f.name
-
-        self.mlflow_manager.log_artifact(temp_path, "inference_samples.json")
-
-        try:
-            os.remove(temp_path)
-        except:
-            pass
isa_model/scripts/mlflow_manager.py
@@ -1,379 +0,0 @@
-"""
-MLflow manager for experiment tracking and model management.
-"""
-
-import os
-import logging
-from enum import Enum
-from typing import Dict, List, Optional, Any, Union
-import mlflow
-from mlflow.tracking import MlflowClient
-
-logger = logging.getLogger(__name__)
-
-
-class ExperimentType(str, Enum):
-    """Types of experiments that can be tracked."""
-
-    TRAINING = "training"
-    FINETUNING = "finetuning"
-    REINFORCEMENT_LEARNING = "rl"
-    INFERENCE = "inference"
-    EVALUATION = "evaluation"
-
-
-class MLflowManager:
-    """
-    Manager class for MLflow operations.
-
-    This class provides methods to set up MLflow, track experiments,
-    log metrics, and manage models.
-
-    Example:
-        ```python
-        # Initialize MLflow manager
-        mlflow_manager = MLflowManager(
-            tracking_uri="http://localhost:5000",
-            artifact_uri="s3://bucket/artifacts"
-        )
-
-        # Set up experiment and start run
-        with mlflow_manager.start_run(
-            experiment_type=ExperimentType.FINETUNING,
-            model_name="llama-7b"
-        ) as run:
-            # Log parameters
-            mlflow_manager.log_params({
-                "learning_rate": 2e-5,
-                "batch_size": 8
-            })
-
-            # Train model...
-
-            # Log metrics
-            mlflow_manager.log_metrics({
-                "accuracy": 0.95,
-                "loss": 0.02
-            })
-
-            # Log model
-            mlflow_manager.log_model(
-                model_path="/path/to/model",
-                name="finetuned-llama-7b"
-            )
-        ```
-    """
-
-    def __init__(
-        self,
-        tracking_uri: Optional[str] = None,
-        artifact_uri: Optional[str] = None,
-        registry_uri: Optional[str] = None
-    ):
-        """
-        Initialize the MLflow manager.
-
-        Args:
-            tracking_uri: URI for MLflow tracking server
-            artifact_uri: URI for MLflow artifacts
-            registry_uri: URI for MLflow model registry
-        """
-        self.tracking_uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI", "")
-        self.artifact_uri = artifact_uri or os.environ.get("MLFLOW_ARTIFACT_URI", "")
-        self.registry_uri = registry_uri or os.environ.get("MLFLOW_REGISTRY_URI", "")
-
-        self._setup_mlflow()
-        self.client = MlflowClient(tracking_uri=self.tracking_uri, registry_uri=self.registry_uri)
-        self.active_run = None
-
-    def _setup_mlflow(self) -> None:
-        """Set up MLflow configuration."""
-        if self.tracking_uri:
-            mlflow.set_tracking_uri(self.tracking_uri)
-            logger.info(f"Set MLflow tracking URI to {self.tracking_uri}")
-
-        if self.registry_uri:
-            mlflow.set_registry_uri(self.registry_uri)
-            logger.info(f"Set MLflow registry URI to {self.registry_uri}")
-
-    def create_experiment(
-        self,
-        experiment_type: ExperimentType,
-        model_name: str,
-        tags: Optional[Dict[str, str]] = None
-    ) -> str:
-        """
-        Create a new experiment if it doesn't exist.
-
-        Args:
-            experiment_type: Type of experiment
-            model_name: Name of the model
-            tags: Tags for the experiment
-
-        Returns:
-            ID of the experiment
-        """
-        experiment_name = f"{model_name}_{experiment_type.value}"
-
-        # Get experiment if exists, create if not
-        experiment = mlflow.get_experiment_by_name(experiment_name)
-        if experiment is None:
-            experiment_id = mlflow.create_experiment(
-                name=experiment_name,
-                artifact_location=self.artifact_uri if self.artifact_uri else None,
-                tags=tags
-            )
-            logger.info(f"Created new experiment: {experiment_name} (ID: {experiment_id})")
-        else:
-            experiment_id = experiment.experiment_id
-            logger.info(f"Using existing experiment: {experiment_name} (ID: {experiment_id})")
-
-        return experiment_id
-
-    def start_run(
-        self,
-        experiment_type: ExperimentType,
-        model_name: str,
-        run_name: Optional[str] = None,
-        tags: Optional[Dict[str, str]] = None,
-        nested: bool = False
-    ) -> mlflow.ActiveRun:
-        """
-        Start a new MLflow run.
-
-        Args:
-            experiment_type: Type of experiment
-            model_name: Name of the model
-            run_name: Name for the run
-            tags: Tags for the run
-            nested: Whether this is a nested run
-
-        Returns:
-            MLflow active run context
-        """
-        experiment_id = self.create_experiment(experiment_type, model_name)
-
-        if not run_name:
-            import datetime
-            timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
-            run_name = f"{model_name}_{experiment_type.value}_{timestamp}"
-
-        self.active_run = mlflow.start_run(
-            experiment_id=experiment_id,
-            run_name=run_name,
-            tags=tags,
-            nested=nested
-        )
-
-        logger.info(f"Started MLflow run: {run_name} (ID: {self.active_run.info.run_id})")
-        return self.active_run
-
-    def end_run(self) -> None:
-        """End the current MLflow run."""
-        if mlflow.active_run():
-            run_id = mlflow.active_run().info.run_id
-            mlflow.end_run()
-            logger.info(f"Ended MLflow run: {run_id}")
-        self.active_run = None
-
-    def log_params(self, params: Dict[str, Any]) -> None:
-        """
-        Log parameters to the current run.
-
-        Args:
-            params: Dictionary of parameters to log
-        """
-        if not mlflow.active_run():
-            logger.warning("No active run. Parameters will not be logged.")
-            return
-
-        mlflow.log_params(params)
-        logger.debug(f"Logged parameters: {params}")
-
-    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
-        """
-        Log metrics to the current run.
-
-        Args:
-            metrics: Dictionary of metrics to log
-            step: Step value for the metrics
-        """
-        if not mlflow.active_run():
-            logger.warning("No active run. Metrics will not be logged.")
-            return
-
-        mlflow.log_metrics(metrics, step=step)
-        logger.debug(f"Logged metrics: {metrics}")
-
-    def log_model(
-        self,
-        model_path: str,
-        name: str,
-        flavor: str = "pyfunc",
-        **kwargs
-    ) -> str:
-        """
-        Log a model to MLflow.
-
-        Args:
-            model_path: Path to the model
-            name: Name for the logged model
-            flavor: MLflow model flavor
-            **kwargs: Additional arguments for model logging
-
-        Returns:
-            Path where the model is logged
-        """
-        if not mlflow.active_run():
-            logger.warning("No active run. Model will not be logged.")
-            return ""
-
-        log_func = getattr(mlflow, f"log_{flavor}")
-        if not log_func:
-            logger.warning(f"Unsupported model flavor: {flavor}. Using pyfunc instead.")
-            log_func = mlflow.pyfunc.log_model
-
-        artifact_path = f"models/{name}"
-        logged_model = log_func(
-            artifact_path=artifact_path,
-            path=model_path,
-            **kwargs
-        )
-
-        logger.info(f"Logged model: {name} at {artifact_path}")
-        return artifact_path
-
-    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None) -> None:
-        """
-        Log an artifact to MLflow.
-
-        Args:
-            local_path: Local path to the artifact
-            artifact_path: Path for the artifact in MLflow
-        """
-        if not mlflow.active_run():
-            logger.warning("No active run. Artifact will not be logged.")
-            return
-
-        mlflow.log_artifact(local_path, artifact_path)
-        logger.debug(f"Logged artifact: {local_path} to {artifact_path or 'root'}")
-
-    def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None) -> None:
-        """
-        Log multiple artifacts to MLflow.
-
-        Args:
-            local_dir: Local directory containing artifacts
-            artifact_path: Path for the artifacts in MLflow
-        """
-        if not mlflow.active_run():
-            logger.warning("No active run. Artifacts will not be logged.")
-            return
-
-        mlflow.log_artifacts(local_dir, artifact_path)
-        logger.debug(f"Logged artifacts from directory: {local_dir} to {artifact_path or 'root'}")
-
-    def get_run(self, run_id: str) -> Optional[mlflow.entities.Run]:
-        """
-        Get a run by ID.
-
-        Args:
-            run_id: ID of the run
-
-        Returns:
-            MLflow run entity or None if not found
-        """
-        try:
-            return self.client.get_run(run_id)
-        except mlflow.exceptions.MlflowException as e:
-            logger.error(f"Failed to get run {run_id}: {e}")
-            return None
-
-    def search_runs(
-        self,
-        experiment_ids: List[str],
-        filter_string: Optional[str] = None,
-        max_results: int = 100
-    ) -> List[mlflow.entities.Run]:
-        """
-        Search for runs in the given experiments.
-
-        Args:
-            experiment_ids: List of experiment IDs
-            filter_string: Filter string for the search
-            max_results: Maximum number of results to return
-
-        Returns:
-            List of MLflow run entities
-        """
-        try:
-            return self.client.search_runs(
-                experiment_ids=experiment_ids,
-                filter_string=filter_string,
-                max_results=max_results
-            )
-        except mlflow.exceptions.MlflowException as e:
-            logger.error(f"Failed to search runs: {e}")
-            return []
-
-    def get_experiment_id_by_name(self, experiment_name: str) -> Optional[str]:
-        """
-        Get experiment ID by name.
-
-        Args:
-            experiment_name: Name of the experiment
-
-        Returns:
-            Experiment ID or None if not found
-        """
-        experiment = mlflow.get_experiment_by_name(experiment_name)
-        if experiment:
-            return experiment.experiment_id
-        return None
-
-    def set_tracking_tag(self, key: str, value: str) -> None:
-        """
-        Set a tag for the current run.
-
-        Args:
-            key: Tag key
-            value: Tag value
-        """
-        if not mlflow.active_run():
-            logger.warning("No active run. Tag will not be set.")
-            return
-
-        mlflow.set_tag(key, value)
-        logger.debug(f"Set tag: {key}={value}")
-
-    def create_model_version(
-        self,
-        name: str,
-        source: str,
-        description: Optional[str] = None,
-        tags: Optional[Dict[str, str]] = None
-    ) -> Optional[str]:
-        """
-        Create a new model version in the registry.
-
-        Args:
-            name: Name of the registered model
-            source: Source path of the model
-            description: Description for the model version
-            tags: Tags for the model version
-
-        Returns:
-            Version of the created model or None if creation failed
-        """
-        try:
-            version = self.client.create_model_version(
-                name=name,
-                source=source,
-                description=description,
-                tags=tags
-            )
-            logger.info(f"Created model version: {name} v{version.version}")
-            return version.version
-        except mlflow.exceptions.MlflowException as e:
-            logger.error(f"Failed to create model version: {e}")
-            return None