isa_model-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
  12. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
  13. isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
  14. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
  15. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
  16. isa_model/inference/__init__.py +11 -0
  17. isa_model/inference/adapter/unified_api.py +248 -0
  18. isa_model/inference/ai_factory.py +359 -0
  19. isa_model/inference/base.py +46 -0
  20. isa_model/inference/providers/__init__.py +19 -0
  21. isa_model/inference/providers/base_provider.py +30 -0
  22. isa_model/inference/providers/model_cache_manager.py +341 -0
  23. isa_model/inference/providers/ollama_provider.py +73 -0
  24. isa_model/inference/providers/openai_provider.py +101 -0
  25. isa_model/inference/providers/replicate_provider.py +107 -0
  26. isa_model/inference/providers/triton_provider.py +439 -0
  27. isa_model/inference/services/__init__.py +14 -0
  28. isa_model/inference/services/audio/base_stt_service.py +91 -0
  29. isa_model/inference/services/audio/base_tts_service.py +136 -0
  30. isa_model/inference/services/audio/openai_tts_service.py +71 -0
  31. isa_model/inference/services/base_service.py +106 -0
  32. isa_model/inference/services/embedding/ollama_embed_service.py +97 -0
  33. isa_model/inference/services/embedding/openai_embed_service.py +0 -0
  34. isa_model/inference/services/llm/__init__.py +12 -0
  35. isa_model/inference/services/llm/base_llm_service.py +134 -0
  36. isa_model/inference/services/llm/ollama_llm_service.py +99 -0
  37. isa_model/inference/services/llm/openai_llm_service.py +138 -0
  38. isa_model/inference/services/others/table_transformer_service.py +61 -0
  39. isa_model/inference/services/vision/__init__.py +12 -0
  40. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  41. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  42. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  43. isa_model/inference/services/vision/openai_vision_service.py +80 -0
  44. isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
  45. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  46. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  47. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  48. isa_model/scripts/inference_tracker.py +283 -0
  49. isa_model/scripts/mlflow_manager.py +379 -0
  50. isa_model/scripts/model_registry.py +465 -0
  51. isa_model/scripts/start_mlflow.py +95 -0
  52. isa_model/scripts/training_tracker.py +257 -0
  53. isa_model/training/engine/llama_factory/__init__.py +39 -0
  54. isa_model/training/engine/llama_factory/config.py +115 -0
  55. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  56. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  57. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  58. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  59. isa_model/training/engine/llama_factory/factory.py +331 -0
  60. isa_model/training/engine/llama_factory/rl.py +254 -0
  61. isa_model/training/engine/llama_factory/trainer.py +171 -0
  62. isa_model/training/image_model/configs/create_config.py +37 -0
  63. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  64. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  65. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  66. isa_model/training/image_model/prepare_upload.py +17 -0
  67. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  68. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  69. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  70. isa_model/training/image_model/train/train.py +42 -0
  71. isa_model/training/image_model/train/train_flux.py +41 -0
  72. isa_model/training/image_model/train/train_lora.py +57 -0
  73. isa_model/training/image_model/train_main.py +25 -0
  74. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  75. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  76. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  77. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  78. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  79. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  80. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  81. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  82. isa_model-0.0.1.dist-info/METADATA +327 -0
  83. isa_model-0.0.1.dist-info/RECORD +86 -0
  84. isa_model-0.0.1.dist-info/WHEEL +5 -0
  85. isa_model-0.0.1.dist-info/licenses/LICENSE +21 -0
  86. isa_model-0.0.1.dist-info/top_level.txt +1 -0
isa_model/scripts/inference_tracker.py
@@ -0,0 +1,283 @@
+ """
+ MLflow tracker for inference workflows.
+ """
+
+ import os
+ import json
+ import time
+ import logging
+ from typing import Dict, List, Optional, Any, Union
+ from contextlib import contextmanager
+
+ from .mlflow_manager import MLflowManager, ExperimentType
+ from .model_registry import ModelRegistry, ModelStage, ModelVersion
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class InferenceTracker:
+     """
+     Tracker for model inference workflows.
+
+     This class provides utilities to track model inference using MLflow,
+     including performance metrics and input/output logging.
+
+     Example:
+         ```python
+         # Initialize tracker
+         tracker = InferenceTracker(
+             tracking_uri="http://localhost:5000"
+         )
+
+         # Get model from registry
+         model_version = tracker.get_production_model("llama-7b")
+
+         # Track inference
+         with tracker.track_inference(
+             model_name="llama-7b",
+             model_version=model_version.version
+         ):
+             # Start timer
+             start_time = time.time()
+
+             # Generate text
+             output = model.generate(prompt)
+
+             # Log inference
+             tracker.log_inference(
+                 input=prompt,
+                 output=output,
+                 latency_ms=(time.time() - start_time) * 1000
+             )
+         ```
+     """
+
+     def __init__(
+         self,
+         tracking_uri: Optional[str] = None,
+         artifact_uri: Optional[str] = None,
+         registry_uri: Optional[str] = None
+     ):
+         """
+         Initialize the inference tracker.
+
+         Args:
+             tracking_uri: URI for MLflow tracking server
+             artifact_uri: URI for MLflow artifacts
+             registry_uri: URI for MLflow model registry
+         """
+         self.mlflow_manager = MLflowManager(
+             tracking_uri=tracking_uri,
+             artifact_uri=artifact_uri,
+             registry_uri=registry_uri
+         )
+         self.model_registry = ModelRegistry(
+             tracking_uri=tracking_uri,
+             registry_uri=registry_uri
+         )
+         self.current_run_info = {}
+         self.inference_samples = []
+
+     def get_production_model(self, model_name: str) -> Optional[ModelVersion]:
+         """
+         Get the production version of a model.
+
+         Args:
+             model_name: Name of the model
+
+         Returns:
+             Production ModelVersion or None if not found
+         """
+         return self.model_registry.get_latest_model_version(
+             name=model_name,
+             stage=ModelStage.PRODUCTION
+         )
+
+     def get_staging_model(self, model_name: str) -> Optional[ModelVersion]:
+         """
+         Get the staging version of a model.
+
+         Args:
+             model_name: Name of the model
+
+         Returns:
+             Staging ModelVersion or None if not found
+         """
+         return self.model_registry.get_latest_model_version(
+             name=model_name,
+             stage=ModelStage.STAGING
+         )
+
+     @contextmanager
+     def track_inference(
+         self,
+         model_name: str,
+         model_version: Optional[str] = None,
+         batch_size: Optional[int] = None,
+         tags: Optional[Dict[str, str]] = None
+     ):
+         """
+         Track model inference with MLflow.
+
+         Args:
+             model_name: Name of the model
+             model_version: Version of the model
+             batch_size: Batch size for inference
+             tags: Tags for the run
+
+         Yields:
+             Dictionary with run information
+         """
+         run_info = {
+             "model_name": model_name,
+             "model_version": model_version,
+             "batch_size": batch_size,
+             "start_time": time.time(),
+             "metrics": {}
+         }
+
+         # Prepare tags
+         if tags is None:
+             tags = {}
+
+         tags["model_name"] = model_name
+         if model_version:
+             tags["model_version"] = model_version
+
+         if batch_size:
+             tags["batch_size"] = str(batch_size)
+
+         # Start the MLflow run
+         with self.mlflow_manager.start_run(
+             experiment_type=ExperimentType.INFERENCE,
+             model_name=model_name,
+             tags=tags
+         ) as run:
+             run_info["run_id"] = run.info.run_id
+             run_info["experiment_id"] = run.info.experiment_id
+
+             # Reset inference samples
+             self.inference_samples = []
+
+             self.current_run_info = run_info
+             try:
+                 yield run_info
+
+                 # Calculate and log summary metrics
+                 self._log_summary_metrics()
+
+                 # Save inference samples
+                 if self.inference_samples:
+                     self._save_inference_samples()
+
+             finally:
+                 run_info["end_time"] = time.time()
+                 run_info["duration"] = run_info["end_time"] - run_info["start_time"]
+
+                 # Log duration
+                 self.mlflow_manager.log_metrics({
+                     "duration_seconds": run_info["duration"]
+                 })
+
+                 self.current_run_info = {}
+
+     def log_inference(
+         self,
+         input: str,
+         output: str,
+         latency_ms: Optional[float] = None,
+         token_count: Optional[int] = None,
+         tokens_per_second: Optional[float] = None,
+         metadata: Optional[Dict[str, Any]] = None
+     ) -> None:
+         """
+         Log an inference sample.
+
+         Args:
+             input: Input prompt
+             output: Generated output
+             latency_ms: Latency in milliseconds
+             token_count: Number of tokens generated
+             tokens_per_second: Tokens per second
+             metadata: Additional metadata
+         """
+         if not self.current_run_info:
+             logger.warning("No active run. Inference will not be logged.")
+             return
+
+         sample = {
+             "input": input,
+             "output": output,
+             "timestamp": time.time()
+         }
+
+         if latency_ms is not None:
+             sample["latency_ms"] = latency_ms
+
+         if token_count is not None:
+             sample["token_count"] = token_count
+
+         if tokens_per_second is not None:
+             sample["tokens_per_second"] = tokens_per_second
+
+         if metadata:
+             sample["metadata"] = metadata
+
+         self.inference_samples.append(sample)
+
+         # Log individual metrics
+         metrics = {}
+         if latency_ms is not None:
+             metrics["latency_ms"] = latency_ms
+
+         if token_count is not None:
+             metrics["token_count"] = token_count
+
+         if tokens_per_second is not None:
+             metrics["tokens_per_second"] = tokens_per_second
+
+         if metrics:
+             self.mlflow_manager.log_metrics(metrics)
+
+     def _log_summary_metrics(self) -> None:
+         """Log summary metrics based on all inference samples."""
+         if not self.inference_samples:
+             return
+
+         latencies = [s["latency_ms"] for s in self.inference_samples if "latency_ms" in s]
+         token_counts = [s["token_count"] for s in self.inference_samples if "token_count" in s]
+         tokens_per_second = [s["tokens_per_second"] for s in self.inference_samples if "tokens_per_second" in s]
+
+         metrics = {
+             "inference_count": len(self.inference_samples)
+         }
+
+         if latencies:
+             metrics["avg_latency_ms"] = sum(latencies) / len(latencies)
+             metrics["min_latency_ms"] = min(latencies)
+             metrics["max_latency_ms"] = max(latencies)
+
+         if token_counts:
+             metrics["avg_token_count"] = sum(token_counts) / len(token_counts)
+             metrics["total_tokens"] = sum(token_counts)
+
+         if tokens_per_second:
+             metrics["avg_tokens_per_second"] = sum(tokens_per_second) / len(tokens_per_second)
+
+         self.mlflow_manager.log_metrics(metrics)
+
+     def _save_inference_samples(self) -> None:
+         """Save inference samples as an artifact."""
+         import tempfile
+
+         # Write to a fixed-name file inside a temp directory: log_artifact's
+         # second argument is a run-relative *directory*, so the file itself
+         # must carry the desired artifact name.
+         temp_dir = tempfile.mkdtemp()
+         temp_path = os.path.join(temp_dir, "inference_samples.json")
+         with open(temp_path, "w") as f:
+             json.dump(self.inference_samples, f, indent=2)
+
+         self.mlflow_manager.log_artifact(temp_path, "inference_samples")
+
+         try:
+             os.remove(temp_path)
+             os.rmdir(temp_dir)
+         except OSError:
+             pass
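Taken together, `track_inference` and `log_inference` are meant to be used as a pair: the context manager opens the MLflow run and, on exit, logs `duration_seconds` plus the summary statistics from `_log_summary_metrics`. Below is a minimal usage sketch; the `fake_generate` stub, the prompt, and the whitespace token count are illustrative placeholders, and a tracking server is assumed to be reachable at the given URI:

```python
import time

from isa_model.scripts.inference_tracker import InferenceTracker

def fake_generate(prompt: str) -> str:
    # Stand-in for a real model call, e.g. model.generate(prompt).
    return prompt + " ... generated text"

tracker = InferenceTracker(tracking_uri="http://localhost:5000")
prompt = "Summarize the release notes."

with tracker.track_inference(model_name="llama-7b", batch_size=1):
    start = time.time()
    output = fake_generate(prompt)
    latency_ms = (time.time() - start) * 1000.0
    token_count = len(output.split())  # crude whitespace-based token estimate
    tracker.log_inference(
        input=prompt,
        output=output,
        latency_ms=latency_ms,
        token_count=token_count,
        tokens_per_second=token_count / (latency_ms / 1000.0) if latency_ms else None,
    )
# On exit, the run records duration_seconds, the per-call metrics,
# avg/min/max latency, total_tokens, and the inference samples artifact.
```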
isa_model/scripts/mlflow_manager.py
@@ -0,0 +1,379 @@
+ """
+ MLflow manager for experiment tracking and model management.
+ """
+
+ import os
+ import logging
+ from enum import Enum
+ from typing import Dict, List, Optional, Any, Union
+ import mlflow
+ from mlflow.tracking import MlflowClient
+
+ logger = logging.getLogger(__name__)
+
+
+ class ExperimentType(str, Enum):
+     """Types of experiments that can be tracked."""
+
+     TRAINING = "training"
+     FINETUNING = "finetuning"
+     REINFORCEMENT_LEARNING = "rl"
+     INFERENCE = "inference"
+     EVALUATION = "evaluation"
+
+
+ class MLflowManager:
+     """
+     Manager class for MLflow operations.
+
+     This class provides methods to set up MLflow, track experiments,
+     log metrics, and manage models.
+
+     Example:
+         ```python
+         # Initialize MLflow manager
+         mlflow_manager = MLflowManager(
+             tracking_uri="http://localhost:5000",
+             artifact_uri="s3://bucket/artifacts"
+         )
+
+         # Set up experiment and start run
+         with mlflow_manager.start_run(
+             experiment_type=ExperimentType.FINETUNING,
+             model_name="llama-7b"
+         ) as run:
+             # Log parameters
+             mlflow_manager.log_params({
+                 "learning_rate": 2e-5,
+                 "batch_size": 8
+             })
+
+             # Train model...
+
+             # Log metrics
+             mlflow_manager.log_metrics({
+                 "accuracy": 0.95,
+                 "loss": 0.02
+             })
+
+             # Log model
+             mlflow_manager.log_model(
+                 model_path="/path/to/model",
+                 name="finetuned-llama-7b"
+             )
+         ```
+     """
+
+     def __init__(
+         self,
+         tracking_uri: Optional[str] = None,
+         artifact_uri: Optional[str] = None,
+         registry_uri: Optional[str] = None
+     ):
+         """
+         Initialize the MLflow manager.
+
+         Args:
+             tracking_uri: URI for MLflow tracking server
+             artifact_uri: URI for MLflow artifacts
+             registry_uri: URI for MLflow model registry
+         """
+         self.tracking_uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI", "")
+         self.artifact_uri = artifact_uri or os.environ.get("MLFLOW_ARTIFACT_URI", "")
+         self.registry_uri = registry_uri or os.environ.get("MLFLOW_REGISTRY_URI", "")
+
+         self._setup_mlflow()
+         self.client = MlflowClient(tracking_uri=self.tracking_uri, registry_uri=self.registry_uri)
+         self.active_run = None
+
+     def _setup_mlflow(self) -> None:
+         """Set up MLflow configuration."""
+         if self.tracking_uri:
+             mlflow.set_tracking_uri(self.tracking_uri)
+             logger.info(f"Set MLflow tracking URI to {self.tracking_uri}")
+
+         if self.registry_uri:
+             mlflow.set_registry_uri(self.registry_uri)
+             logger.info(f"Set MLflow registry URI to {self.registry_uri}")
+
+     def create_experiment(
+         self,
+         experiment_type: ExperimentType,
+         model_name: str,
+         tags: Optional[Dict[str, str]] = None
+     ) -> str:
+         """
+         Create a new experiment if it doesn't exist.
+
+         Args:
+             experiment_type: Type of experiment
+             model_name: Name of the model
+             tags: Tags for the experiment
+
+         Returns:
+             ID of the experiment
+         """
+         experiment_name = f"{model_name}_{experiment_type.value}"
+
+         # Get the experiment if it exists, create it if not
+         experiment = mlflow.get_experiment_by_name(experiment_name)
+         if experiment is None:
+             experiment_id = mlflow.create_experiment(
+                 name=experiment_name,
+                 artifact_location=self.artifact_uri if self.artifact_uri else None,
+                 tags=tags
+             )
+             logger.info(f"Created new experiment: {experiment_name} (ID: {experiment_id})")
+         else:
+             experiment_id = experiment.experiment_id
+             logger.info(f"Using existing experiment: {experiment_name} (ID: {experiment_id})")
+
+         return experiment_id
+
+     def start_run(
+         self,
+         experiment_type: ExperimentType,
+         model_name: str,
+         run_name: Optional[str] = None,
+         tags: Optional[Dict[str, str]] = None,
+         nested: bool = False
+     ) -> mlflow.ActiveRun:
+         """
+         Start a new MLflow run.
+
+         Args:
+             experiment_type: Type of experiment
+             model_name: Name of the model
+             run_name: Name for the run
+             tags: Tags for the run
+             nested: Whether this is a nested run
+
+         Returns:
+             MLflow active run context
+         """
+         experiment_id = self.create_experiment(experiment_type, model_name)
+
+         if not run_name:
+             import datetime
+             timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+             run_name = f"{model_name}_{experiment_type.value}_{timestamp}"
+
+         self.active_run = mlflow.start_run(
+             experiment_id=experiment_id,
+             run_name=run_name,
+             tags=tags,
+             nested=nested
+         )
+
+         logger.info(f"Started MLflow run: {run_name} (ID: {self.active_run.info.run_id})")
+         return self.active_run
+
+     def end_run(self) -> None:
+         """End the current MLflow run."""
+         if mlflow.active_run():
+             run_id = mlflow.active_run().info.run_id
+             mlflow.end_run()
+             logger.info(f"Ended MLflow run: {run_id}")
+         self.active_run = None
+
+     def log_params(self, params: Dict[str, Any]) -> None:
+         """
+         Log parameters to the current run.
+
+         Args:
+             params: Dictionary of parameters to log
+         """
+         if not mlflow.active_run():
+             logger.warning("No active run. Parameters will not be logged.")
+             return
+
+         mlflow.log_params(params)
+         logger.debug(f"Logged parameters: {params}")
+
+     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+         """
+         Log metrics to the current run.
+
+         Args:
+             metrics: Dictionary of metrics to log
+             step: Step value for the metrics
+         """
+         if not mlflow.active_run():
+             logger.warning("No active run. Metrics will not be logged.")
+             return
+
+         mlflow.log_metrics(metrics, step=step)
+         logger.debug(f"Logged metrics: {metrics}")
+
+     def log_model(
+         self,
+         model_path: str,
+         name: str,
+         flavor: str = "pyfunc",
+         **kwargs
+     ) -> str:
+         """
+         Log a model to MLflow.
+
+         Args:
+             model_path: Path to the model
+             name: Name for the logged model
+             flavor: MLflow model flavor
+             **kwargs: Additional arguments for model logging
+
+         Returns:
+             Path where the model is logged
+         """
+         if not mlflow.active_run():
+             logger.warning("No active run. Model will not be logged.")
+             return ""
+
231
+ log_func = getattr(mlflow, f"log_{flavor}")
232
+ if not log_func:
233
+ logger.warning(f"Unsupported model flavor: {flavor}. Using pyfunc instead.")
234
+ log_func = mlflow.pyfunc.log_model
235
+
236
+ artifact_path = f"models/{name}"
237
+ logged_model = log_func(
238
+ artifact_path=artifact_path,
239
+ path=model_path,
240
+ **kwargs
241
+ )
242
+
243
+ logger.info(f"Logged model: {name} at {artifact_path}")
244
+ return artifact_path
245
+
246
+ def log_artifact(self, local_path: str, artifact_path: Optional[str] = None) -> None:
247
+ """
248
+ Log an artifact to MLflow.
249
+
250
+ Args:
251
+ local_path: Local path to the artifact
252
+ artifact_path: Path for the artifact in MLflow
253
+ """
254
+ if not mlflow.active_run():
255
+ logger.warning("No active run. Artifact will not be logged.")
256
+ return
257
+
258
+ mlflow.log_artifact(local_path, artifact_path)
259
+ logger.debug(f"Logged artifact: {local_path} to {artifact_path or 'root'}")
260
+
261
+ def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None) -> None:
262
+ """
263
+ Log multiple artifacts to MLflow.
264
+
265
+ Args:
266
+ local_dir: Local directory containing artifacts
267
+ artifact_path: Path for the artifacts in MLflow
268
+ """
269
+ if not mlflow.active_run():
270
+ logger.warning("No active run. Artifacts will not be logged.")
271
+ return
272
+
273
+ mlflow.log_artifacts(local_dir, artifact_path)
274
+ logger.debug(f"Logged artifacts from directory: {local_dir} to {artifact_path or 'root'}")
275
+
276
+ def get_run(self, run_id: str) -> Optional[mlflow.entities.Run]:
277
+ """
278
+ Get a run by ID.
279
+
280
+ Args:
281
+ run_id: ID of the run
282
+
283
+ Returns:
284
+ MLflow run entity or None if not found
285
+ """
286
+ try:
287
+ return self.client.get_run(run_id)
288
+ except mlflow.exceptions.MlflowException as e:
289
+ logger.error(f"Failed to get run {run_id}: {e}")
290
+ return None
291
+
292
+ def search_runs(
293
+ self,
294
+ experiment_ids: List[str],
295
+ filter_string: Optional[str] = None,
296
+ max_results: int = 100
297
+ ) -> List[mlflow.entities.Run]:
298
+ """
299
+ Search for runs in the given experiments.
300
+
301
+ Args:
302
+ experiment_ids: List of experiment IDs
303
+ filter_string: Filter string for the search
304
+ max_results: Maximum number of results to return
305
+
306
+ Returns:
307
+ List of MLflow run entities
308
+ """
309
+ try:
310
+ return self.client.search_runs(
311
+ experiment_ids=experiment_ids,
312
+ filter_string=filter_string,
313
+ max_results=max_results
314
+ )
315
+ except mlflow.exceptions.MlflowException as e:
316
+ logger.error(f"Failed to search runs: {e}")
317
+ return []
318
+
319
+ def get_experiment_id_by_name(self, experiment_name: str) -> Optional[str]:
320
+ """
321
+ Get experiment ID by name.
322
+
323
+ Args:
324
+ experiment_name: Name of the experiment
325
+
326
+ Returns:
327
+ Experiment ID or None if not found
328
+ """
329
+ experiment = mlflow.get_experiment_by_name(experiment_name)
330
+ if experiment:
331
+ return experiment.experiment_id
332
+ return None
333
+
334
+ def set_tracking_tag(self, key: str, value: str) -> None:
335
+ """
336
+ Set a tag for the current run.
337
+
338
+ Args:
339
+ key: Tag key
340
+ value: Tag value
341
+ """
342
+ if not mlflow.active_run():
343
+ logger.warning("No active run. Tag will not be set.")
344
+ return
345
+
346
+ mlflow.set_tag(key, value)
347
+ logger.debug(f"Set tag: {key}={value}")
348
+
349
+ def create_model_version(
350
+ self,
351
+ name: str,
352
+ source: str,
353
+ description: Optional[str] = None,
354
+ tags: Optional[Dict[str, str]] = None
355
+ ) -> Optional[str]:
356
+ """
357
+ Create a new model version in the registry.
358
+
359
+ Args:
360
+ name: Name of the registered model
361
+ source: Source path of the model
362
+ description: Description for the model version
363
+ tags: Tags for the model version
364
+
365
+ Returns:
366
+ Version of the created model or None if creation failed
367
+ """
368
+ try:
369
+ version = self.client.create_model_version(
370
+ name=name,
371
+ source=source,
372
+ description=description,
373
+ tags=tags
374
+ )
375
+ logger.info(f"Created model version: {name} v{version.version}")
376
+ return version.version
377
+ except mlflow.exceptions.MlflowException as e:
378
+ logger.error(f"Failed to create model version: {e}")
379
+ return None
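As a companion sketch, here is one way the manager might be exercised end to end. It assumes a tracking server is reachable at the given URI (for example one started with `mlflow server --host 127.0.0.1 --port 5000`, or via the package's isa_model/scripts/start_mlflow.py) and that a registered model named "llama-7b" already exists in the registry; the run-relative `source` URI is illustrative:

```python
from isa_model.scripts.mlflow_manager import MLflowManager, ExperimentType

manager = MLflowManager(tracking_uri="http://localhost:5000")

# start_run returns an mlflow.ActiveRun, so it works as a context manager;
# the run is ended automatically when the block exits.
with manager.start_run(
    experiment_type=ExperimentType.EVALUATION,
    model_name="llama-7b",
    tags={"suite": "smoke"},
) as run:
    manager.log_params({"temperature": 0.7, "top_p": 0.9})
    manager.log_metrics({"exact_match": 0.81}, step=1)

# Register artifacts from the finished run as a new model version. The
# registered model "llama-7b" must already exist; create_model_version
# logs the error and returns None if registration fails.
version = manager.create_model_version(
    name="llama-7b",
    source=f"runs:/{run.info.run_id}/models/llama-7b",
    description="Evaluated checkpoint",
)
print(version)
```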