mcli_framework-7.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mcli-framework has been flagged as potentially problematic.
- mcli/app/chat_cmd.py +42 -0
- mcli/app/commands_cmd.py +226 -0
- mcli/app/completion_cmd.py +216 -0
- mcli/app/completion_helpers.py +288 -0
- mcli/app/cron_test_cmd.py +697 -0
- mcli/app/logs_cmd.py +419 -0
- mcli/app/main.py +492 -0
- mcli/app/model/model.py +1060 -0
- mcli/app/model_cmd.py +227 -0
- mcli/app/redis_cmd.py +269 -0
- mcli/app/video/video.py +1114 -0
- mcli/app/visual_cmd.py +303 -0
- mcli/chat/chat.py +2409 -0
- mcli/chat/command_rag.py +514 -0
- mcli/chat/enhanced_chat.py +652 -0
- mcli/chat/system_controller.py +1010 -0
- mcli/chat/system_integration.py +1016 -0
- mcli/cli.py +25 -0
- mcli/config.toml +20 -0
- mcli/lib/api/api.py +586 -0
- mcli/lib/api/daemon_client.py +203 -0
- mcli/lib/api/daemon_client_local.py +44 -0
- mcli/lib/api/daemon_decorator.py +217 -0
- mcli/lib/api/mcli_decorators.py +1032 -0
- mcli/lib/auth/auth.py +85 -0
- mcli/lib/auth/aws_manager.py +85 -0
- mcli/lib/auth/azure_manager.py +91 -0
- mcli/lib/auth/credential_manager.py +192 -0
- mcli/lib/auth/gcp_manager.py +93 -0
- mcli/lib/auth/key_manager.py +117 -0
- mcli/lib/auth/mcli_manager.py +93 -0
- mcli/lib/auth/token_manager.py +75 -0
- mcli/lib/auth/token_util.py +1011 -0
- mcli/lib/config/config.py +47 -0
- mcli/lib/discovery/__init__.py +1 -0
- mcli/lib/discovery/command_discovery.py +274 -0
- mcli/lib/erd/erd.py +1345 -0
- mcli/lib/erd/generate_graph.py +453 -0
- mcli/lib/files/files.py +76 -0
- mcli/lib/fs/fs.py +109 -0
- mcli/lib/lib.py +29 -0
- mcli/lib/logger/logger.py +611 -0
- mcli/lib/performance/optimizer.py +409 -0
- mcli/lib/performance/rust_bridge.py +502 -0
- mcli/lib/performance/uvloop_config.py +154 -0
- mcli/lib/pickles/pickles.py +50 -0
- mcli/lib/search/cached_vectorizer.py +479 -0
- mcli/lib/services/data_pipeline.py +460 -0
- mcli/lib/services/lsh_client.py +441 -0
- mcli/lib/services/redis_service.py +387 -0
- mcli/lib/shell/shell.py +137 -0
- mcli/lib/toml/toml.py +33 -0
- mcli/lib/ui/styling.py +47 -0
- mcli/lib/ui/visual_effects.py +634 -0
- mcli/lib/watcher/watcher.py +185 -0
- mcli/ml/api/app.py +215 -0
- mcli/ml/api/middleware.py +224 -0
- mcli/ml/api/routers/admin_router.py +12 -0
- mcli/ml/api/routers/auth_router.py +244 -0
- mcli/ml/api/routers/backtest_router.py +12 -0
- mcli/ml/api/routers/data_router.py +12 -0
- mcli/ml/api/routers/model_router.py +302 -0
- mcli/ml/api/routers/monitoring_router.py +12 -0
- mcli/ml/api/routers/portfolio_router.py +12 -0
- mcli/ml/api/routers/prediction_router.py +267 -0
- mcli/ml/api/routers/trade_router.py +12 -0
- mcli/ml/api/routers/websocket_router.py +76 -0
- mcli/ml/api/schemas.py +64 -0
- mcli/ml/auth/auth_manager.py +425 -0
- mcli/ml/auth/models.py +154 -0
- mcli/ml/auth/permissions.py +302 -0
- mcli/ml/backtesting/backtest_engine.py +502 -0
- mcli/ml/backtesting/performance_metrics.py +393 -0
- mcli/ml/cache.py +400 -0
- mcli/ml/cli/main.py +398 -0
- mcli/ml/config/settings.py +394 -0
- mcli/ml/configs/dvc_config.py +230 -0
- mcli/ml/configs/mlflow_config.py +131 -0
- mcli/ml/configs/mlops_manager.py +293 -0
- mcli/ml/dashboard/app.py +532 -0
- mcli/ml/dashboard/app_integrated.py +738 -0
- mcli/ml/dashboard/app_supabase.py +560 -0
- mcli/ml/dashboard/app_training.py +615 -0
- mcli/ml/dashboard/cli.py +51 -0
- mcli/ml/data_ingestion/api_connectors.py +501 -0
- mcli/ml/data_ingestion/data_pipeline.py +567 -0
- mcli/ml/data_ingestion/stream_processor.py +512 -0
- mcli/ml/database/migrations/env.py +94 -0
- mcli/ml/database/models.py +667 -0
- mcli/ml/database/session.py +200 -0
- mcli/ml/experimentation/ab_testing.py +845 -0
- mcli/ml/features/ensemble_features.py +607 -0
- mcli/ml/features/political_features.py +676 -0
- mcli/ml/features/recommendation_engine.py +809 -0
- mcli/ml/features/stock_features.py +573 -0
- mcli/ml/features/test_feature_engineering.py +346 -0
- mcli/ml/logging.py +85 -0
- mcli/ml/mlops/data_versioning.py +518 -0
- mcli/ml/mlops/experiment_tracker.py +377 -0
- mcli/ml/mlops/model_serving.py +481 -0
- mcli/ml/mlops/pipeline_orchestrator.py +614 -0
- mcli/ml/models/base_models.py +324 -0
- mcli/ml/models/ensemble_models.py +675 -0
- mcli/ml/models/recommendation_models.py +474 -0
- mcli/ml/models/test_models.py +487 -0
- mcli/ml/monitoring/drift_detection.py +676 -0
- mcli/ml/monitoring/metrics.py +45 -0
- mcli/ml/optimization/portfolio_optimizer.py +834 -0
- mcli/ml/preprocessing/data_cleaners.py +451 -0
- mcli/ml/preprocessing/feature_extractors.py +491 -0
- mcli/ml/preprocessing/ml_pipeline.py +382 -0
- mcli/ml/preprocessing/politician_trading_preprocessor.py +569 -0
- mcli/ml/preprocessing/test_preprocessing.py +294 -0
- mcli/ml/scripts/populate_sample_data.py +200 -0
- mcli/ml/tasks.py +400 -0
- mcli/ml/tests/test_integration.py +429 -0
- mcli/ml/tests/test_training_dashboard.py +387 -0
- mcli/public/oi/oi.py +15 -0
- mcli/public/public.py +4 -0
- mcli/self/self_cmd.py +1246 -0
- mcli/workflow/daemon/api_daemon.py +800 -0
- mcli/workflow/daemon/async_command_database.py +681 -0
- mcli/workflow/daemon/async_process_manager.py +591 -0
- mcli/workflow/daemon/client.py +530 -0
- mcli/workflow/daemon/commands.py +1196 -0
- mcli/workflow/daemon/daemon.py +905 -0
- mcli/workflow/daemon/daemon_api.py +59 -0
- mcli/workflow/daemon/enhanced_daemon.py +571 -0
- mcli/workflow/daemon/process_cli.py +244 -0
- mcli/workflow/daemon/process_manager.py +439 -0
- mcli/workflow/daemon/test_daemon.py +275 -0
- mcli/workflow/dashboard/dashboard_cmd.py +113 -0
- mcli/workflow/docker/docker.py +0 -0
- mcli/workflow/file/file.py +100 -0
- mcli/workflow/gcloud/config.toml +21 -0
- mcli/workflow/gcloud/gcloud.py +58 -0
- mcli/workflow/git_commit/ai_service.py +328 -0
- mcli/workflow/git_commit/commands.py +430 -0
- mcli/workflow/lsh_integration.py +355 -0
- mcli/workflow/model_service/client.py +594 -0
- mcli/workflow/model_service/download_and_run_efficient_models.py +288 -0
- mcli/workflow/model_service/lightweight_embedder.py +397 -0
- mcli/workflow/model_service/lightweight_model_server.py +714 -0
- mcli/workflow/model_service/lightweight_test.py +241 -0
- mcli/workflow/model_service/model_service.py +1955 -0
- mcli/workflow/model_service/ollama_efficient_runner.py +425 -0
- mcli/workflow/model_service/pdf_processor.py +386 -0
- mcli/workflow/model_service/test_efficient_runner.py +234 -0
- mcli/workflow/model_service/test_example.py +315 -0
- mcli/workflow/model_service/test_integration.py +131 -0
- mcli/workflow/model_service/test_new_features.py +149 -0
- mcli/workflow/openai/openai.py +99 -0
- mcli/workflow/politician_trading/commands.py +1790 -0
- mcli/workflow/politician_trading/config.py +134 -0
- mcli/workflow/politician_trading/connectivity.py +490 -0
- mcli/workflow/politician_trading/data_sources.py +395 -0
- mcli/workflow/politician_trading/database.py +410 -0
- mcli/workflow/politician_trading/demo.py +248 -0
- mcli/workflow/politician_trading/models.py +165 -0
- mcli/workflow/politician_trading/monitoring.py +413 -0
- mcli/workflow/politician_trading/scrapers.py +966 -0
- mcli/workflow/politician_trading/scrapers_california.py +412 -0
- mcli/workflow/politician_trading/scrapers_eu.py +377 -0
- mcli/workflow/politician_trading/scrapers_uk.py +350 -0
- mcli/workflow/politician_trading/scrapers_us_states.py +438 -0
- mcli/workflow/politician_trading/supabase_functions.py +354 -0
- mcli/workflow/politician_trading/workflow.py +852 -0
- mcli/workflow/registry/registry.py +180 -0
- mcli/workflow/repo/repo.py +223 -0
- mcli/workflow/scheduler/commands.py +493 -0
- mcli/workflow/scheduler/cron_parser.py +238 -0
- mcli/workflow/scheduler/job.py +182 -0
- mcli/workflow/scheduler/monitor.py +139 -0
- mcli/workflow/scheduler/persistence.py +324 -0
- mcli/workflow/scheduler/scheduler.py +679 -0
- mcli/workflow/sync/sync_cmd.py +437 -0
- mcli/workflow/sync/test_cmd.py +314 -0
- mcli/workflow/videos/videos.py +242 -0
- mcli/workflow/wakatime/wakatime.py +11 -0
- mcli/workflow/workflow.py +37 -0
- mcli_framework-7.0.0.dist-info/METADATA +479 -0
- mcli_framework-7.0.0.dist-info/RECORD +186 -0
- mcli_framework-7.0.0.dist-info/WHEEL +5 -0
- mcli_framework-7.0.0.dist-info/entry_points.txt +7 -0
- mcli_framework-7.0.0.dist-info/licenses/LICENSE +21 -0
- mcli_framework-7.0.0.dist-info/top_level.txt +1 -0
mcli/ml/mlops/experiment_tracker.py

@@ -0,0 +1,377 @@

"""MLflow experiment tracking and model registry"""

import mlflow
import mlflow.pytorch
import mlflow.sklearn
from mlflow.tracking import MlflowClient
from mlflow.models.signature import ModelSignature, infer_signature
import torch
import numpy as np
import pandas as pd
from typing import Dict, Any, Optional, List, Union
from dataclasses import dataclass
from pathlib import Path
import json
import logging
from datetime import datetime

logger = logging.getLogger(__name__)


@dataclass
class MLflowConfig:
    """Configuration for MLflow tracking"""
    tracking_uri: str = "sqlite:///mlruns.db"
    experiment_name: str = "politician-trading-predictions"
    artifact_location: Optional[str] = None
    registry_uri: Optional[str] = None
    tags: Dict[str, str] = None

    def __post_init__(self):
        if self.tags is None:
            self.tags = {
                "project": "politician-trading",
                "framework": "pytorch",
                "type": "stock-recommendation"
            }


@dataclass
class ExperimentRun:
    """Container for experiment run information"""
    run_id: str
    experiment_id: str
    run_name: str
    metrics: Dict[str, float]
    params: Dict[str, Any]
    artifacts: List[str]
    model_uri: Optional[str] = None
    status: str = "RUNNING"
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None


class ExperimentTracker:
    """MLflow experiment tracker for ML pipeline"""

    def __init__(self, config: MLflowConfig):
        self.config = config
        self.client = None
        self.current_run = None
        self.setup_mlflow()

    def setup_mlflow(self):
        """Initialize MLflow tracking"""
        mlflow.set_tracking_uri(self.config.tracking_uri)

        if self.config.registry_uri:
            mlflow.set_registry_uri(self.config.registry_uri)

        # Create or get experiment
        experiment = mlflow.get_experiment_by_name(self.config.experiment_name)
        if experiment is None:
            experiment_id = mlflow.create_experiment(
                self.config.experiment_name,
                artifact_location=self.config.artifact_location,
                tags=self.config.tags
            )
        else:
            experiment_id = experiment.experiment_id

        mlflow.set_experiment(self.config.experiment_name)
        self.client = MlflowClient()
        self.experiment_id = experiment_id

        logger.info(f"MLflow tracking initialized at {self.config.tracking_uri}")
        logger.info(f"Experiment: {self.config.experiment_name} (ID: {experiment_id})")

    def start_run(self, run_name: str, tags: Optional[Dict[str, str]] = None) -> ExperimentRun:
        """Start a new MLflow run"""
        if self.current_run:
            self.end_run()

        # Merge tags
        all_tags = {**self.config.tags}
        if tags:
            all_tags.update(tags)

        # Start run
        run = mlflow.start_run(run_name=run_name, tags=all_tags)

        self.current_run = ExperimentRun(
            run_id=run.info.run_id,
            experiment_id=run.info.experiment_id,
            run_name=run_name,
            metrics={},
            params={},
            artifacts=[],
            start_time=datetime.now()
        )

        logger.info(f"Started MLflow run: {run_name} (ID: {run.info.run_id})")
        return self.current_run

    def log_params(self, params: Dict[str, Any]):
        """Log parameters to current run"""
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        for key, value in params.items():
            # Convert complex types to strings
            if isinstance(value, (list, dict, tuple)):
                value = json.dumps(value)
            elif not isinstance(value, (str, int, float, bool)):
                value = str(value)

            mlflow.log_param(key, value)
            self.current_run.params[key] = value

        logger.debug(f"Logged {len(params)} parameters")

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Log metrics to current run"""
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        for key, value in metrics.items():
            mlflow.log_metric(key, value, step=step)
            self.current_run.metrics[key] = value

        logger.debug(f"Logged {len(metrics)} metrics at step {step}")

    def log_artifact(self, artifact_path: Union[str, Path], artifact_type: Optional[str] = None):
        """Log artifact to current run"""
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        artifact_path = Path(artifact_path)

        if artifact_path.is_file():
            mlflow.log_artifact(str(artifact_path))
        elif artifact_path.is_dir():
            mlflow.log_artifacts(str(artifact_path))
        else:
            raise ValueError(f"Artifact path does not exist: {artifact_path}")

        self.current_run.artifacts.append(str(artifact_path))
        logger.debug(f"Logged artifact: {artifact_path}")

    def log_model(self, model: Any, model_name: str,
                  input_example: Optional[Union[np.ndarray, pd.DataFrame]] = None,
                  signature: Optional[ModelSignature] = None,
                  conda_env: Optional[Dict] = None,
                  pip_requirements: Optional[List[str]] = None):
        """Log model to current run"""
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        # Infer signature if not provided
        if signature is None and input_example is not None:
            if isinstance(model, torch.nn.Module):
                model.eval()
                with torch.no_grad():
                    if isinstance(input_example, pd.DataFrame):
                        input_tensor = torch.FloatTensor(input_example.values)
                    else:
                        input_tensor = torch.FloatTensor(input_example)

                    output = model(input_tensor)
                    if isinstance(output, dict):
                        # Handle dictionary outputs
                        output_example = {k: v.numpy() for k, v in output.items()}
                    else:
                        output_example = output.numpy()

                signature = infer_signature(input_example, output_example)
            else:
                # For sklearn models
                output_example = model.predict(input_example)
                signature = infer_signature(input_example, output_example)

        # Log model based on type
        if isinstance(model, torch.nn.Module):
            mlflow.pytorch.log_model(
                model,
                model_name,
                signature=signature,
                input_example=input_example,
                conda_env=conda_env,
                pip_requirements=pip_requirements
            )
            framework = "pytorch"
        else:
            # Assume sklearn-compatible
            mlflow.sklearn.log_model(
                model,
                model_name,
                signature=signature,
                input_example=input_example,
                conda_env=conda_env,
                pip_requirements=pip_requirements
            )
            framework = "sklearn"

        self.current_run.model_uri = f"runs:/{self.current_run.run_id}/{model_name}"

        logger.info(f"Logged {framework} model: {model_name}")
        return self.current_run.model_uri

    def log_figure(self, figure, artifact_name: str):
        """Log matplotlib figure"""
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        mlflow.log_figure(figure, artifact_name)
        self.current_run.artifacts.append(artifact_name)
        logger.debug(f"Logged figure: {artifact_name}")

    def log_dict(self, dictionary: Dict, artifact_name: str):
        """Log dictionary as JSON artifact"""
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        mlflow.log_dict(dictionary, artifact_name)
        self.current_run.artifacts.append(artifact_name)
        logger.debug(f"Logged dictionary: {artifact_name}")

    def end_run(self, status: str = "FINISHED"):
        """End current MLflow run"""
        if not self.current_run:
            return

        self.current_run.status = status
        self.current_run.end_time = datetime.now()

        mlflow.end_run(status=status)

        duration = (self.current_run.end_time - self.current_run.start_time).total_seconds()
        logger.info(f"Ended MLflow run {self.current_run.run_name} "
                    f"(Duration: {duration:.2f}s, Status: {status})")

        current_run = self.current_run
        self.current_run = None
        return current_run

    def get_run(self, run_id: str) -> mlflow.entities.Run:
        """Get run by ID"""
        return self.client.get_run(run_id)

    def search_runs(self, filter_string: str = "",
                    max_results: int = 100) -> List[mlflow.entities.Run]:
        """Search for runs in experiment"""
        return self.client.search_runs(
            experiment_ids=[self.experiment_id],
            filter_string=filter_string,
            max_results=max_results
        )

    def compare_runs(self, run_ids: List[str],
                     metrics: Optional[List[str]] = None) -> pd.DataFrame:
        """Compare multiple runs"""
        runs_data = []

        for run_id in run_ids:
            run = self.get_run(run_id)
            run_data = {
                "run_id": run_id,
                "run_name": run.data.tags.get("mlflow.runName", ""),
                "status": run.info.status,
            }

            # Add params
            for key, value in run.data.params.items():
                run_data[f"param_{key}"] = value

            # Add metrics
            if metrics:
                for metric in metrics:
                    if metric in run.data.metrics:
                        run_data[f"metric_{metric}"] = run.data.metrics[metric]
            else:
                for key, value in run.data.metrics.items():
                    run_data[f"metric_{key}"] = value

            runs_data.append(run_data)

        return pd.DataFrame(runs_data)


class ModelRegistry:
    """MLflow model registry for model versioning and deployment"""

    def __init__(self, config: MLflowConfig):
        self.config = config
        self.client = MlflowClient()
        mlflow.set_tracking_uri(config.tracking_uri)

        if config.registry_uri:
            mlflow.set_registry_uri(config.registry_uri)

    def register_model(self, model_uri: str, model_name: str,
                       tags: Optional[Dict[str, str]] = None) -> str:
        """Register model in MLflow registry"""
        try:
            # Create registered model if it doesn't exist
            self.client.create_registered_model(
                model_name,
                tags=tags or {},
                description=f"Model for {model_name}"
            )
        except Exception as e:
            logger.debug(f"Model {model_name} already exists: {e}")

        # Register model version
        model_version = self.client.create_model_version(
            name=model_name,
            source=model_uri,
            run_id=model_uri.split("/")[1] if "runs:/" in model_uri else None,
            tags=tags or {}
        )

        logger.info(f"Registered model {model_name} version {model_version.version}")
        return f"models:/{model_name}/{model_version.version}"

    def transition_model_stage(self, model_name: str, version: int,
                               stage: str, archive_existing: bool = True):
        """Transition model version to new stage"""
        self.client.transition_model_version_stage(
            name=model_name,
            version=version,
            stage=stage,
            archive_existing_versions=archive_existing
        )

        logger.info(f"Transitioned {model_name} v{version} to {stage}")

    def load_model(self, model_name: str,
                   version: Optional[int] = None,
                   stage: Optional[str] = None) -> Any:
        """Load model from registry"""
        if version:
            model_uri = f"models:/{model_name}/{version}"
        elif stage:
            model_uri = f"models:/{model_name}/{stage}"
        else:
            model_uri = f"models:/{model_name}/latest"

        model = mlflow.pytorch.load_model(model_uri)
        logger.info(f"Loaded model from {model_uri}")
        return model

    def get_model_version(self, model_name: str, version: int):
        """Get specific model version details"""
        return self.client.get_model_version(model_name, version)

    def get_latest_versions(self, model_name: str,
                            stages: Optional[List[str]] = None):
        """Get latest model versions for given stages"""
        return self.client.get_latest_versions(model_name, stages=stages)

    def delete_model_version(self, model_name: str, version: int):
        """Delete model version"""
        self.client.delete_model_version(model_name, version)
        logger.info(f"Deleted {model_name} version {version}")

    def search_models(self, filter_string: str = "") -> List:
        """Search registered models"""
        return self.client.search_registered_models(filter_string=filter_string)
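For orientation, here is a minimal usage sketch of the ExperimentTracker and ModelRegistry classes added by this hunk. It is illustrative only and not part of the release: it assumes the hunk corresponds to mcli/ml/mlops/experiment_tracker.py (the one file in this diff adding 377 lines), that mlflow, torch, scikit-learn, numpy, and pandas are installed, and it uses a toy DataFrame and LogisticRegression purely as stand-ins.

# Illustrative only -- not shipped in the wheel. The import path assumes the
# hunk above is mcli/ml/mlops/experiment_tracker.py.
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression

from mcli.ml.mlops.experiment_tracker import MLflowConfig, ExperimentTracker, ModelRegistry

# Toy stand-in data and model.
X = pd.DataFrame(np.random.rand(64, 4), columns=[f"f{i}" for i in range(4)])
y = (X["f0"] > 0.5).astype(int)
model = LogisticRegression().fit(X, y)

config = MLflowConfig(tracking_uri="sqlite:///mlruns.db")
tracker = ExperimentTracker(config)

# Run lifecycle as defined above: start_run -> log_* -> log_model -> end_run.
tracker.start_run("baseline-logreg", tags={"stage": "dev"})
tracker.log_params({"C": 1.0, "penalty": "l2"})
tracker.log_metrics({"train_accuracy": float(model.score(X, y))}, step=1)
model_uri = tracker.log_model(model, "stock_recommender", input_example=X.head())
tracker.end_run()

# Register the logged model; register_model returns a "models:/<name>/<version>" URI.
registry = ModelRegistry(config)
registry.register_model(model_uri, "stock_recommender", tags={"stage": "dev"})

The sketch only exercises the happy path; if no run is active, the log_* methods raise ValueError per the guard clauses in the class definitions above.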