mcli-framework 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcli-framework might be problematic. Click here for more details.

Files changed (186):
  1. mcli/app/chat_cmd.py +42 -0
  2. mcli/app/commands_cmd.py +226 -0
  3. mcli/app/completion_cmd.py +216 -0
  4. mcli/app/completion_helpers.py +288 -0
  5. mcli/app/cron_test_cmd.py +697 -0
  6. mcli/app/logs_cmd.py +419 -0
  7. mcli/app/main.py +492 -0
  8. mcli/app/model/model.py +1060 -0
  9. mcli/app/model_cmd.py +227 -0
  10. mcli/app/redis_cmd.py +269 -0
  11. mcli/app/video/video.py +1114 -0
  12. mcli/app/visual_cmd.py +303 -0
  13. mcli/chat/chat.py +2409 -0
  14. mcli/chat/command_rag.py +514 -0
  15. mcli/chat/enhanced_chat.py +652 -0
  16. mcli/chat/system_controller.py +1010 -0
  17. mcli/chat/system_integration.py +1016 -0
  18. mcli/cli.py +25 -0
  19. mcli/config.toml +20 -0
  20. mcli/lib/api/api.py +586 -0
  21. mcli/lib/api/daemon_client.py +203 -0
  22. mcli/lib/api/daemon_client_local.py +44 -0
  23. mcli/lib/api/daemon_decorator.py +217 -0
  24. mcli/lib/api/mcli_decorators.py +1032 -0
  25. mcli/lib/auth/auth.py +85 -0
  26. mcli/lib/auth/aws_manager.py +85 -0
  27. mcli/lib/auth/azure_manager.py +91 -0
  28. mcli/lib/auth/credential_manager.py +192 -0
  29. mcli/lib/auth/gcp_manager.py +93 -0
  30. mcli/lib/auth/key_manager.py +117 -0
  31. mcli/lib/auth/mcli_manager.py +93 -0
  32. mcli/lib/auth/token_manager.py +75 -0
  33. mcli/lib/auth/token_util.py +1011 -0
  34. mcli/lib/config/config.py +47 -0
  35. mcli/lib/discovery/__init__.py +1 -0
  36. mcli/lib/discovery/command_discovery.py +274 -0
  37. mcli/lib/erd/erd.py +1345 -0
  38. mcli/lib/erd/generate_graph.py +453 -0
  39. mcli/lib/files/files.py +76 -0
  40. mcli/lib/fs/fs.py +109 -0
  41. mcli/lib/lib.py +29 -0
  42. mcli/lib/logger/logger.py +611 -0
  43. mcli/lib/performance/optimizer.py +409 -0
  44. mcli/lib/performance/rust_bridge.py +502 -0
  45. mcli/lib/performance/uvloop_config.py +154 -0
  46. mcli/lib/pickles/pickles.py +50 -0
  47. mcli/lib/search/cached_vectorizer.py +479 -0
  48. mcli/lib/services/data_pipeline.py +460 -0
  49. mcli/lib/services/lsh_client.py +441 -0
  50. mcli/lib/services/redis_service.py +387 -0
  51. mcli/lib/shell/shell.py +137 -0
  52. mcli/lib/toml/toml.py +33 -0
  53. mcli/lib/ui/styling.py +47 -0
  54. mcli/lib/ui/visual_effects.py +634 -0
  55. mcli/lib/watcher/watcher.py +185 -0
  56. mcli/ml/api/app.py +215 -0
  57. mcli/ml/api/middleware.py +224 -0
  58. mcli/ml/api/routers/admin_router.py +12 -0
  59. mcli/ml/api/routers/auth_router.py +244 -0
  60. mcli/ml/api/routers/backtest_router.py +12 -0
  61. mcli/ml/api/routers/data_router.py +12 -0
  62. mcli/ml/api/routers/model_router.py +302 -0
  63. mcli/ml/api/routers/monitoring_router.py +12 -0
  64. mcli/ml/api/routers/portfolio_router.py +12 -0
  65. mcli/ml/api/routers/prediction_router.py +267 -0
  66. mcli/ml/api/routers/trade_router.py +12 -0
  67. mcli/ml/api/routers/websocket_router.py +76 -0
  68. mcli/ml/api/schemas.py +64 -0
  69. mcli/ml/auth/auth_manager.py +425 -0
  70. mcli/ml/auth/models.py +154 -0
  71. mcli/ml/auth/permissions.py +302 -0
  72. mcli/ml/backtesting/backtest_engine.py +502 -0
  73. mcli/ml/backtesting/performance_metrics.py +393 -0
  74. mcli/ml/cache.py +400 -0
  75. mcli/ml/cli/main.py +398 -0
  76. mcli/ml/config/settings.py +394 -0
  77. mcli/ml/configs/dvc_config.py +230 -0
  78. mcli/ml/configs/mlflow_config.py +131 -0
  79. mcli/ml/configs/mlops_manager.py +293 -0
  80. mcli/ml/dashboard/app.py +532 -0
  81. mcli/ml/dashboard/app_integrated.py +738 -0
  82. mcli/ml/dashboard/app_supabase.py +560 -0
  83. mcli/ml/dashboard/app_training.py +615 -0
  84. mcli/ml/dashboard/cli.py +51 -0
  85. mcli/ml/data_ingestion/api_connectors.py +501 -0
  86. mcli/ml/data_ingestion/data_pipeline.py +567 -0
  87. mcli/ml/data_ingestion/stream_processor.py +512 -0
  88. mcli/ml/database/migrations/env.py +94 -0
  89. mcli/ml/database/models.py +667 -0
  90. mcli/ml/database/session.py +200 -0
  91. mcli/ml/experimentation/ab_testing.py +845 -0
  92. mcli/ml/features/ensemble_features.py +607 -0
  93. mcli/ml/features/political_features.py +676 -0
  94. mcli/ml/features/recommendation_engine.py +809 -0
  95. mcli/ml/features/stock_features.py +573 -0
  96. mcli/ml/features/test_feature_engineering.py +346 -0
  97. mcli/ml/logging.py +85 -0
  98. mcli/ml/mlops/data_versioning.py +518 -0
  99. mcli/ml/mlops/experiment_tracker.py +377 -0
  100. mcli/ml/mlops/model_serving.py +481 -0
  101. mcli/ml/mlops/pipeline_orchestrator.py +614 -0
  102. mcli/ml/models/base_models.py +324 -0
  103. mcli/ml/models/ensemble_models.py +675 -0
  104. mcli/ml/models/recommendation_models.py +474 -0
  105. mcli/ml/models/test_models.py +487 -0
  106. mcli/ml/monitoring/drift_detection.py +676 -0
  107. mcli/ml/monitoring/metrics.py +45 -0
  108. mcli/ml/optimization/portfolio_optimizer.py +834 -0
  109. mcli/ml/preprocessing/data_cleaners.py +451 -0
  110. mcli/ml/preprocessing/feature_extractors.py +491 -0
  111. mcli/ml/preprocessing/ml_pipeline.py +382 -0
  112. mcli/ml/preprocessing/politician_trading_preprocessor.py +569 -0
  113. mcli/ml/preprocessing/test_preprocessing.py +294 -0
  114. mcli/ml/scripts/populate_sample_data.py +200 -0
  115. mcli/ml/tasks.py +400 -0
  116. mcli/ml/tests/test_integration.py +429 -0
  117. mcli/ml/tests/test_training_dashboard.py +387 -0
  118. mcli/public/oi/oi.py +15 -0
  119. mcli/public/public.py +4 -0
  120. mcli/self/self_cmd.py +1246 -0
  121. mcli/workflow/daemon/api_daemon.py +800 -0
  122. mcli/workflow/daemon/async_command_database.py +681 -0
  123. mcli/workflow/daemon/async_process_manager.py +591 -0
  124. mcli/workflow/daemon/client.py +530 -0
  125. mcli/workflow/daemon/commands.py +1196 -0
  126. mcli/workflow/daemon/daemon.py +905 -0
  127. mcli/workflow/daemon/daemon_api.py +59 -0
  128. mcli/workflow/daemon/enhanced_daemon.py +571 -0
  129. mcli/workflow/daemon/process_cli.py +244 -0
  130. mcli/workflow/daemon/process_manager.py +439 -0
  131. mcli/workflow/daemon/test_daemon.py +275 -0
  132. mcli/workflow/dashboard/dashboard_cmd.py +113 -0
  133. mcli/workflow/docker/docker.py +0 -0
  134. mcli/workflow/file/file.py +100 -0
  135. mcli/workflow/gcloud/config.toml +21 -0
  136. mcli/workflow/gcloud/gcloud.py +58 -0
  137. mcli/workflow/git_commit/ai_service.py +328 -0
  138. mcli/workflow/git_commit/commands.py +430 -0
  139. mcli/workflow/lsh_integration.py +355 -0
  140. mcli/workflow/model_service/client.py +594 -0
  141. mcli/workflow/model_service/download_and_run_efficient_models.py +288 -0
  142. mcli/workflow/model_service/lightweight_embedder.py +397 -0
  143. mcli/workflow/model_service/lightweight_model_server.py +714 -0
  144. mcli/workflow/model_service/lightweight_test.py +241 -0
  145. mcli/workflow/model_service/model_service.py +1955 -0
  146. mcli/workflow/model_service/ollama_efficient_runner.py +425 -0
  147. mcli/workflow/model_service/pdf_processor.py +386 -0
  148. mcli/workflow/model_service/test_efficient_runner.py +234 -0
  149. mcli/workflow/model_service/test_example.py +315 -0
  150. mcli/workflow/model_service/test_integration.py +131 -0
  151. mcli/workflow/model_service/test_new_features.py +149 -0
  152. mcli/workflow/openai/openai.py +99 -0
  153. mcli/workflow/politician_trading/commands.py +1790 -0
  154. mcli/workflow/politician_trading/config.py +134 -0
  155. mcli/workflow/politician_trading/connectivity.py +490 -0
  156. mcli/workflow/politician_trading/data_sources.py +395 -0
  157. mcli/workflow/politician_trading/database.py +410 -0
  158. mcli/workflow/politician_trading/demo.py +248 -0
  159. mcli/workflow/politician_trading/models.py +165 -0
  160. mcli/workflow/politician_trading/monitoring.py +413 -0
  161. mcli/workflow/politician_trading/scrapers.py +966 -0
  162. mcli/workflow/politician_trading/scrapers_california.py +412 -0
  163. mcli/workflow/politician_trading/scrapers_eu.py +377 -0
  164. mcli/workflow/politician_trading/scrapers_uk.py +350 -0
  165. mcli/workflow/politician_trading/scrapers_us_states.py +438 -0
  166. mcli/workflow/politician_trading/supabase_functions.py +354 -0
  167. mcli/workflow/politician_trading/workflow.py +852 -0
  168. mcli/workflow/registry/registry.py +180 -0
  169. mcli/workflow/repo/repo.py +223 -0
  170. mcli/workflow/scheduler/commands.py +493 -0
  171. mcli/workflow/scheduler/cron_parser.py +238 -0
  172. mcli/workflow/scheduler/job.py +182 -0
  173. mcli/workflow/scheduler/monitor.py +139 -0
  174. mcli/workflow/scheduler/persistence.py +324 -0
  175. mcli/workflow/scheduler/scheduler.py +679 -0
  176. mcli/workflow/sync/sync_cmd.py +437 -0
  177. mcli/workflow/sync/test_cmd.py +314 -0
  178. mcli/workflow/videos/videos.py +242 -0
  179. mcli/workflow/wakatime/wakatime.py +11 -0
  180. mcli/workflow/workflow.py +37 -0
  181. mcli_framework-7.0.0.dist-info/METADATA +479 -0
  182. mcli_framework-7.0.0.dist-info/RECORD +186 -0
  183. mcli_framework-7.0.0.dist-info/WHEEL +5 -0
  184. mcli_framework-7.0.0.dist-info/entry_points.txt +7 -0
  185. mcli_framework-7.0.0.dist-info/licenses/LICENSE +21 -0
  186. mcli_framework-7.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,377 @@
1
+ """MLflow experiment tracking and model registry"""
2
+
3
+ import mlflow
4
+ import mlflow.pytorch
5
+ import mlflow.sklearn
6
+ from mlflow.tracking import MlflowClient
7
+ from mlflow.models.signature import ModelSignature, infer_signature
8
+ import torch
9
+ import numpy as np
10
+ import pandas as pd
11
+ from typing import Dict, Any, Optional, List, Union
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+ import json
15
+ import logging
16
+ from datetime import datetime
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
@dataclass
class MLflowConfig:
    """Configuration for MLflow tracking.

    Attributes:
        tracking_uri: Tracking backend URI; defaults to a local SQLite file
            ``mlruns.db`` in the working directory.
        experiment_name: Name of the MLflow experiment runs are logged under.
        artifact_location: Optional artifact root for the experiment.
        registry_uri: Optional model-registry URI; only applied when set.
        tags: Experiment/run tags. ``None`` (the default) is replaced in
            ``__post_init__`` by the project's default tag set.
    """

    tracking_uri: str = "sqlite:///mlruns.db"
    experiment_name: str = "politician-trading-predictions"
    artifact_location: Optional[str] = None
    registry_uri: Optional[str] = None
    # Optional because None is a sentinel meaning "use the defaults below".
    # Building the dict in __post_init__ (rather than as a class-level default)
    # gives every instance its own dict, so instances never share tag state.
    tags: Optional[Dict[str, str]] = None

    def __post_init__(self):
        """Fill in the default tag set when no tags were supplied."""
        if self.tags is None:
            self.tags = {
                "project": "politician-trading",
                "framework": "pytorch",
                "type": "stock-recommendation",
            }
37
+
38
+
39
@dataclass
class ExperimentRun:
    """Local record describing one MLflow run.

    Stores the run's identifiers together with copies of whatever params,
    metrics and artifact names were logged for it, plus lifecycle data
    (status, optional model URI, start/end timestamps).
    """

    # --- identity ---------------------------------------------------------
    run_id: str
    experiment_id: str
    run_name: str

    # --- logged content (keyed by MLflow param / metric name) -------------
    metrics: Dict[str, float]
    params: Dict[str, Any]
    artifacts: List[str]  # names/paths of artifacts recorded for the run

    # --- lifecycle ----------------------------------------------------------
    model_uri: Optional[str] = None  # set once a model has been logged
    status: str = "RUNNING"  # MLflow run status string
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
52
+
53
+
54
class ExperimentTracker:
    """MLflow experiment tracker for the ML pipeline.

    Configures MLflow from an ``MLflowConfig`` and tracks at most one active
    run at a time. Everything logged through the tracker (params, metrics,
    artifact names, model URI) is mirrored into the ``ExperimentRun`` held in
    ``current_run``.
    """

    def __init__(self, config: MLflowConfig):
        """Store ``config`` and connect to MLflow via ``setup_mlflow``."""
        self.config = config
        self.client = None
        self.current_run = None
        self.experiment_id = None
        self.setup_mlflow()

    def setup_mlflow(self):
        """Configure MLflow URIs and select the experiment, creating it if absent.

        Side effects: sets the process-wide MLflow tracking URI (and registry
        URI when configured) and the active experiment. Also sets
        ``self.client`` and ``self.experiment_id``.
        """
        mlflow.set_tracking_uri(self.config.tracking_uri)

        if self.config.registry_uri:
            mlflow.set_registry_uri(self.config.registry_uri)

        # Create or get the experiment. artifact_location and tags are only
        # applied when the experiment is created here.
        experiment = mlflow.get_experiment_by_name(self.config.experiment_name)
        if experiment is None:
            experiment_id = mlflow.create_experiment(
                self.config.experiment_name,
                artifact_location=self.config.artifact_location,
                tags=self.config.tags,
            )
        else:
            experiment_id = experiment.experiment_id

        mlflow.set_experiment(self.config.experiment_name)
        # Pass the URIs explicitly so the client is bound to this config rather
        # than to whatever global MLflow settings are active at the time.
        self.client = MlflowClient(
            tracking_uri=self.config.tracking_uri,
            registry_uri=self.config.registry_uri,
        )
        self.experiment_id = experiment_id

        logger.info(f"MLflow tracking initialized at {self.config.tracking_uri}")
        logger.info(f"Experiment: {self.config.experiment_name} (ID: {experiment_id})")

    def start_run(self, run_name: str, tags: Optional[Dict[str, str]] = None) -> ExperimentRun:
        """Start a new MLflow run in this tracker's experiment.

        Any run already tracked by this instance is ended first (status
        ``FINISHED``).

        Args:
            run_name: Human-readable name for the run.
            tags: Extra run tags, merged over ``config.tags`` (caller wins).

        Returns:
            The new ``ExperimentRun``, also stored as ``self.current_run``.
        """
        if self.current_run:
            self.end_run()

        # Merge tags: config defaults first, then per-run overrides.
        all_tags = {**(self.config.tags or {}), **(tags or {})}

        # Pin the run to this tracker's experiment so it cannot land in another
        # experiment if the global active experiment was changed elsewhere.
        run = mlflow.start_run(
            run_name=run_name,
            experiment_id=self.experiment_id,
            tags=all_tags,
        )

        self.current_run = ExperimentRun(
            run_id=run.info.run_id,
            experiment_id=run.info.experiment_id,
            run_name=run_name,
            metrics={},
            params={},
            artifacts=[],
            start_time=datetime.now(),
        )

        logger.info(f"Started MLflow run: {run_name} (ID: {run.info.run_id})")
        return self.current_run

    def log_params(self, params: Dict[str, Any]):
        """Log parameters to the current run.

        Lists, dicts and tuples are JSON-encoded. Any other value that is not
        str/int/float/bool is converted with ``str()``.

        Raises:
            ValueError: If no run has been started.
        """
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        converted: Dict[str, Any] = {}
        for key, value in params.items():
            if isinstance(value, (list, dict, tuple)):
                # default=str stops a container holding e.g. a Path or a numpy
                # integer from aborting the whole call with TypeError.
                value = json.dumps(value, default=str)
            elif not isinstance(value, (str, int, float, bool)):
                value = str(value)
            converted[key] = value

        if converted:
            # One batched call instead of a round-trip per parameter.
            mlflow.log_params(converted)
        self.current_run.params.update(converted)

        logger.debug(f"Logged {len(params)} parameters")

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Log metrics (optionally at ``step``) to the current run.

        Raises:
            ValueError: If no run has been started.
        """
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        metrics = dict(metrics)
        if metrics:
            # One batched call instead of a round-trip per metric.
            mlflow.log_metrics(metrics, step=step)
        # Only the latest value per key is mirrored locally.
        self.current_run.metrics.update(metrics)

        logger.debug(f"Logged {len(metrics)} metrics at step {step}")

    def log_artifact(self, artifact_path: Union[str, Path], artifact_type: Optional[str] = None):
        """Log a file or the contents of a directory as run artifacts.

        Args:
            artifact_path: Local file or directory to upload.
            artifact_type: Accepted for API compatibility; currently unused.

        Raises:
            ValueError: If no run is active or the path does not exist.
        """
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        artifact_path = Path(artifact_path)

        if artifact_path.is_file():
            mlflow.log_artifact(str(artifact_path))
        elif artifact_path.is_dir():
            mlflow.log_artifacts(str(artifact_path))
        else:
            raise ValueError(f"Artifact path does not exist: {artifact_path}")

        self.current_run.artifacts.append(str(artifact_path))
        logger.debug(f"Logged artifact: {artifact_path}")

    @staticmethod
    def _torch_output_example(model: "torch.nn.Module",
                              input_example: Union[np.ndarray, pd.DataFrame]) -> Any:
        """Run ``model`` once on ``input_example``; return NumPy output(s) for signature inference.

        The model's train/eval mode is restored afterwards, so logging a model
        mid-training does not silently leave it in eval mode.
        """
        values = input_example.values if isinstance(input_example, pd.DataFrame) else input_example
        # Put the input on the model's device (CPU if the model has no parameters).
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        input_tensor = torch.as_tensor(np.asarray(values, dtype=np.float32), device=device)

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                output = model(input_tensor)
        finally:
            model.train(was_training)

        def to_numpy(tensor):
            # .cpu() is required before .numpy() for tensors on an accelerator.
            return tensor.detach().cpu().numpy()

        if isinstance(output, dict):
            # Handle dictionary outputs (e.g. multi-head models).
            return {k: to_numpy(v) for k, v in output.items()}
        return to_numpy(output)

    def log_model(self, model: Any, model_name: str,
                  input_example: Optional[Union[np.ndarray, pd.DataFrame]] = None,
                  signature: Optional[ModelSignature] = None,
                  conda_env: Optional[Dict] = None,
                  pip_requirements: Optional[List[str]] = None) -> str:
        """Log ``model`` to the current run with the PyTorch or sklearn flavor.

        ``torch.nn.Module`` instances use ``mlflow.pytorch``. Anything else is
        assumed to be sklearn-compatible (it must expose ``predict`` if a
        signature has to be inferred). When ``signature`` is omitted but
        ``input_example`` is given, the signature is inferred from one forward
        pass / ``predict`` call.

        Returns:
            The logged model's URI, also stored on ``current_run.model_uri``.

        Raises:
            ValueError: If no run has been started.
        """
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        is_torch = isinstance(model, torch.nn.Module)

        # Infer signature if not provided
        if signature is None and input_example is not None:
            if is_torch:
                output_example = self._torch_output_example(model, input_example)
            else:
                # For sklearn-compatible models
                output_example = model.predict(input_example)
            signature = infer_signature(input_example, output_example)

        if is_torch:
            flavor, framework = mlflow.pytorch, "pytorch"
        else:
            # Assume sklearn-compatible
            flavor, framework = mlflow.sklearn, "sklearn"

        model_info = flavor.log_model(
            model,
            model_name,
            signature=signature,
            input_example=input_example,
            conda_env=conda_env,
            pip_requirements=pip_requirements,
        )

        # Prefer the URI MLflow reports for the logged model. Fall back to the
        # runs:/ form if the returned object does not expose one.
        self.current_run.model_uri = (
            getattr(model_info, "model_uri", None)
            or f"runs:/{self.current_run.run_id}/{model_name}"
        )

        logger.info(f"Logged {framework} model: {model_name}")
        return self.current_run.model_uri

    def log_figure(self, figure, artifact_name: str):
        """Log a matplotlib figure as the artifact ``artifact_name``.

        Raises:
            ValueError: If no run has been started.
        """
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        mlflow.log_figure(figure, artifact_name)
        self.current_run.artifacts.append(artifact_name)
        logger.debug(f"Logged figure: {artifact_name}")

    def log_dict(self, dictionary: Dict, artifact_name: str):
        """Log ``dictionary`` as the artifact ``artifact_name`` via ``mlflow.log_dict``.

        Raises:
            ValueError: If no run has been started.
        """
        if not self.current_run:
            raise ValueError("No active MLflow run. Call start_run() first.")

        mlflow.log_dict(dictionary, artifact_name)
        self.current_run.artifacts.append(artifact_name)
        logger.debug(f"Logged dictionary: {artifact_name}")

    def end_run(self, status: str = "FINISHED") -> Optional[ExperimentRun]:
        """End the current run with ``status``.

        Returns:
            The finished ``ExperimentRun``, or ``None`` if no run was active.
        """
        if not self.current_run:
            return None

        self.current_run.status = status
        self.current_run.end_time = datetime.now()

        mlflow.end_run(status=status)

        duration = (self.current_run.end_time - self.current_run.start_time).total_seconds()
        logger.info(f"Ended MLflow run {self.current_run.run_name} "
                    f"(Duration: {duration:.2f}s, Status: {status})")

        finished_run = self.current_run
        self.current_run = None
        return finished_run

    def get_run(self, run_id: str) -> mlflow.entities.Run:
        """Fetch a run by ID from the tracking backend."""
        return self.client.get_run(run_id)

    def search_runs(self, filter_string: str = "",
                    max_results: int = 100) -> List[mlflow.entities.Run]:
        """Search runs in this tracker's experiment using an MLflow filter string."""
        return self.client.search_runs(
            experiment_ids=[self.experiment_id],
            filter_string=filter_string,
            max_results=max_results,
        )

    def compare_runs(self, run_ids: List[str],
                     metrics: Optional[List[str]] = None) -> pd.DataFrame:
        """Build a one-row-per-run comparison table.

        Columns: ``run_id``, ``run_name``, ``status``, one ``param_<name>``
        column per param, and ``metric_<name>`` columns. When ``metrics`` is
        given, only those metrics are included; otherwise all are. Runs lacking
        a column get NaN (pandas fills missing keys).
        """
        runs_data = []

        for run_id in run_ids:
            run = self.get_run(run_id)
            run_data = {
                "run_id": run_id,
                "run_name": run.data.tags.get("mlflow.runName", ""),
                "status": run.info.status,
            }

            # Add params
            for key, value in run.data.params.items():
                run_data[f"param_{key}"] = value

            # Add metrics (all, or only the requested ones that exist)
            if metrics:
                for metric in metrics:
                    if metric in run.data.metrics:
                        run_data[f"metric_{metric}"] = run.data.metrics[metric]
            else:
                for key, value in run.data.metrics.items():
                    run_data[f"metric_{key}"] = value

            runs_data.append(run_data)

        return pd.DataFrame(runs_data)
297
+
298
+
299
class ModelRegistry:
    """MLflow model registry for model versioning and deployment."""

    def __init__(self, config: MLflowConfig):
        """Configure MLflow URIs from ``config`` and create a registry client."""
        self.config = config
        # Set the URIs *before* creating the client. MlflowClient resolves its
        # tracking/registry URIs when it is constructed, so it must not be
        # created while MLflow still points at its default backend. The URIs
        # are also passed explicitly for the same reason.
        mlflow.set_tracking_uri(config.tracking_uri)

        if config.registry_uri:
            mlflow.set_registry_uri(config.registry_uri)

        self.client = MlflowClient(
            tracking_uri=config.tracking_uri,
            registry_uri=config.registry_uri,
        )

    def register_model(self, model_uri: str, model_name: str,
                       tags: Optional[Dict[str, str]] = None) -> str:
        """Register ``model_uri`` as a new version of the registered model ``model_name``.

        The registered model is created on first use. If it already exists,
        that step is skipped. Any other registry error propagates.

        Returns:
            A ``models:/<name>/<version>`` URI for the new version.
        """
        from mlflow.exceptions import MlflowException

        tags = tags or {}
        try:
            # Create registered model if it doesn't exist
            self.client.create_registered_model(
                model_name,
                tags=tags,
                description=f"Model for {model_name}",
            )
        except MlflowException as e:
            # Only "already exists" is expected here. Anything else (auth,
            # connectivity, invalid name) should surface now rather than as a
            # less helpful failure in create_model_version below.
            if e.error_code not in {"RESOURCE_ALREADY_EXISTS", "ALREADY_EXISTS"}:
                raise
            logger.debug(f"Model {model_name} already exists: {e}")

        # "runs:/<run_id>/<artifact_path>" -> "<run_id>"; other URI schemes have no run.
        run_id = None
        if model_uri.startswith("runs:/"):
            run_id = model_uri[len("runs:/"):].split("/", 1)[0]

        # Register model version
        model_version = self.client.create_model_version(
            name=model_name,
            source=model_uri,
            run_id=run_id,
            tags=tags,
        )

        logger.info(f"Registered model {model_name} version {model_version.version}")
        return f"models:/{model_name}/{model_version.version}"

    def transition_model_stage(self, model_name: str, version: int,
                               stage: str, archive_existing: bool = True):
        """Move ``model_name`` version ``version`` to ``stage``.

        When ``archive_existing`` is true, versions already in ``stage`` are
        archived.
        """
        self.client.transition_model_version_stage(
            name=model_name,
            version=version,
            stage=stage,
            archive_existing_versions=archive_existing,
        )

        logger.info(f"Transitioned {model_name} v{version} to {stage}")

    def load_model(self, model_name: str,
                   version: Optional[int] = None,
                   stage: Optional[str] = None) -> Any:
        """Load a registered model with the PyTorch flavor.

        Resolution order: explicit ``version``, then ``stage``, otherwise the
        ``latest`` version. Only models logged with ``mlflow.pytorch`` can be
        loaded this way.
        """
        if version:
            model_uri = f"models:/{model_name}/{version}"
        elif stage:
            model_uri = f"models:/{model_name}/{stage}"
        else:
            model_uri = f"models:/{model_name}/latest"

        model = mlflow.pytorch.load_model(model_uri)
        logger.info(f"Loaded model from {model_uri}")
        return model

    def get_model_version(self, model_name: str, version: int):
        """Return registry details for one version of ``model_name``."""
        return self.client.get_model_version(model_name, version)

    def get_latest_versions(self, model_name: str,
                            stages: Optional[List[str]] = None):
        """Return the latest version of ``model_name`` for each of ``stages``."""
        return self.client.get_latest_versions(model_name, stages=stages)

    def delete_model_version(self, model_name: str, version: int):
        """Delete one version of ``model_name`` from the registry."""
        self.client.delete_model_version(model_name, version)
        logger.info(f"Deleted {model_name} version {version}")

    def search_models(self, filter_string: str = "") -> List:
        """Search registered models using an MLflow filter string."""
        return self.client.search_registered_models(filter_string=filter_string)