ml-analytics-tools 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml_analytics/__init__.py +53 -0
- ml_analytics/aws_auth.py +169 -0
- ml_analytics/cli.py +58 -0
- ml_analytics/data_connector.py +2615 -0
- ml_analytics/gsheet_connector.py +1646 -0
- ml_analytics/model_manager.py +1208 -0
- ml_analytics/model_tools.py +990 -0
- ml_analytics/s3_connector.py +1381 -0
- ml_analytics/slack_connector.py +637 -0
- ml_analytics/tunnel_manager.py +277 -0
- ml_analytics/utils.py +673 -0
- ml_analytics_tools-0.2.0.dist-info/METADATA +231 -0
- ml_analytics_tools-0.2.0.dist-info/RECORD +17 -0
- ml_analytics_tools-0.2.0.dist-info/WHEEL +5 -0
- ml_analytics_tools-0.2.0.dist-info/entry_points.txt +4 -0
- ml_analytics_tools-0.2.0.dist-info/licenses/LICENSE +21 -0
- ml_analytics_tools-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1208 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A module for managing MLflow model lifecycle including registration, logging, and deletion.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any, Literal
|
|
7
|
+
|
|
8
|
+
import mlflow
|
|
9
|
+
from mlflow.exceptions import RestException
|
|
10
|
+
from mlflow.models import infer_signature
|
|
11
|
+
from mlflow.server.auth.client import AuthServiceClient
|
|
12
|
+
from mlflow.tracking import MlflowClient
|
|
13
|
+
|
|
14
|
+
from .utils import get_credential_value, get_logger, log_and_raise_error
|
|
15
|
+
|
|
16
|
+
# Closed set of permission-level names accepted as string literals.
# NOTE(review): presumably mirrors the MLflow auth server's permission names — confirm.
PermissionLevel = Literal[
    "READ",
    "EDIT",
    "MANAGE",
    "NO_PERMISSIONS",
]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ModelManager:
|
|
20
|
+
"""
|
|
21
|
+
A class to manage MLflow model lifecycle including registration, logging, and deletion.
|
|
22
|
+
|
|
23
|
+
Attributes
|
|
24
|
+
----------
|
|
25
|
+
client : MlflowClient
|
|
26
|
+
The MLflow client instance for interacting with the MLflow API.
|
|
27
|
+
model_name : str
|
|
28
|
+
The name of the model to be registered.
|
|
29
|
+
task : str
|
|
30
|
+
The task the model is designed for (classification, regression, etc.).
|
|
31
|
+
project : Optional[str]
|
|
32
|
+
The project associated with the model. Default is None.
|
|
33
|
+
description : Optional[str]
|
|
34
|
+
A description of the model. Default is None.
|
|
35
|
+
team : Optional[str]
|
|
36
|
+
The team responsible for the model. Default is None.
|
|
37
|
+
user : Optional[str]
|
|
38
|
+
The user ID associated with the model. Default is None.
|
|
39
|
+
tracking_uri : Optional[str]
|
|
40
|
+
The MLflow tracking URI. Default is None.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
def __init__(
    self,
    *,
    model_name: str,
    task: str | None = None,
    project: str | None = None,
    description: str | None = None,
    team: str | None = None,
    user: str | None = None,
    tracking_uri: str | None = None,
    workspace: str | None = None,
    create_registered_model: bool = False,
    start_initial_run: bool = False,
    run_name: str | None = None,
):
    """
    Initialize a ModelManager instance.

    Side effects: writes MLflow-related variables into ``os.environ`` (overwriting
    existing values), sets MLflow's global tracking URI (and workspace, if one is
    configured), sets up the experiment named ``model_name``, and optionally creates
    the registered model and starts a run.

    Parameters
    ----------
    model_name : str
        The name of the model; also used as the MLflow experiment name.
    task : Optional[str]
        The task the model is designed for (classification, regression, etc.).
    project : Optional[str]
        The project associated with the model.
    description : Optional[str]
        A description of the model.
    team : Optional[str]
        The team responsible for the model.
    user : Optional[str]
        The user ID associated with the model.
    tracking_uri : Optional[str]
        The MLflow tracking URI. If None, uses MLFLOW_TRACKING_URI env var,
        falling back to MLflow's configured default.
    workspace : Optional[str]
        The MLflow workspace name. If None, uses MLFLOW_WORKSPACE env var.
        If neither is set, workspace is not configured (for servers without workspace support).
    create_registered_model : bool, default=False
        Whether to create the registered model (or update its tags if it already
        exists). When False, only the experiment is set up. ``user`` must be set, and
        creating a brand-new registered model also requires ``team``, ``project`` and ``task``.
    start_initial_run : bool, default=False
        Whether to start an MLflow run upon initialization.
    run_name : Optional[str]
        The name for the initial MLflow run. If None, a default name will be used.
    """

    self._logger = get_logger("Model Manager")

    # Export tracking credentials obtained via get_credential_value so the MLflow client can
    # authenticate; a credential that cannot be resolved is skipped.
    for env_name in ("MLFLOW_TRACKING_USERNAME", "MLFLOW_TRACKING_PASSWORD"):
        try:
            os.environ[env_name] = get_credential_value(env_name)
        except Exception:
            self._logger.debug("%s is not configured; continuing without it.", env_name)

    # Artifact transfer tuning (sizes in bytes, timeouts in seconds). These assignments
    # overwrite any values already present in the process environment.
    os.environ["MLFLOW_ENABLE_PROXY_MULTIPART_UPLOAD"] = "true"
    os.environ["MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE"] = "15728640"  # 15 MiB
    os.environ["MLFLOW_MULTIPART_UPLOAD_MINIMUM_FILE_SIZE"] = "15728640"  # 15 MiB
    os.environ["MLFLOW_MULTIPART_DOWNLOAD_MINIMUM_FILE_SIZE"] = "524288000"  # 500 MiB
    os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] = "120"
    os.environ["MLFLOW_ARTIFACT_UPLOAD_DOWNLOAD_TIMEOUT"] = "300"

    # Set MLflow tracking URI: explicit param > env var > MLflow default.
    self.tracking_uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI") or mlflow.get_tracking_uri()
    mlflow.set_tracking_uri(self.tracking_uri)
    self._logger.debug(f"MLflow tracking URI set to: {self.tracking_uri}")

    # Set MLflow workspace: explicit param > env var; skip if not set (server may not support workspaces)
    self.workspace = workspace or os.environ.get("MLFLOW_WORKSPACE")
    if self.workspace:
        mlflow.set_workspace(self.workspace)
        self._logger.debug(f"MLflow workspace set to: {self.workspace}")
    else:
        self._logger.debug("No workspace configured, using server default.")

    # Permission management is optional: keep going without it if the auth client fails.
    try:
        self.auth_client = AuthServiceClient(self.tracking_uri)
    except Exception as e:
        self._logger.error(f"Failed to initialize AuthServiceClient: {e}. Permission management will not work.")
        self.auth_client = None

    # The client picks up the tracking URI set globally above.
    self.client = MlflowClient()
    self.model_name = model_name
    self.team = team
    self.user = user
    self.project = project
    self.task = task
    self.description = description
    self.run_id = None
    self.experiment_id = None
    self.model_uri = None
    self.start_initial_run = start_initial_run

    # Only metadata values that were actually provided become tags.
    tags = {
        k: v
        for k, v in {"team": self.team, "user": self.user, "project": self.project, "task": self.task}.items()
        if v is not None
    }

    if create_registered_model:
        self._create_registered_model(tags)

    self._setup_experiment()

    # self.run_id stays None unless a run is started here.
    if self.start_initial_run:
        self.start_run(run_name)
|
|
150
|
+
|
|
151
|
+
def start_run(self, run_name: str | None = None) -> None:
    """
    Start a new MLflow run in ``self.experiment_id`` unless an MLflow run is already active.

    ``self.run_id`` is first synchronised with MLflow. If MLflow already has an active
    run, whether or not this instance started it, that run is adopted: its ID is
    stored, a message is logged, and no new run is started. Otherwise a new run is
    started and its ID is stored in ``self.run_id``.

    Parameters
    ----------
    run_name : Optional[str]
        Name for the new run. Ignored when an already-active run is adopted.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` if the experiment cannot be
    activated or the run cannot be started; ``self.run_id`` is reset to None first.
    """
    # Query MLflow once; its active run decides whether a new run is needed.
    active_run = mlflow.active_run()
    self.run_id = active_run.info.run_id if active_run else None

    if self.run_id is not None:
        # self.run_id was just read from MLflow's active run, so the two always agree here.
        self._logger.info(
            f"Run {self.run_id} is already considered active by this ModelManager instance. "
            "To start a new run, you can use start_new_run()."
        )
        return

    try:
        mlflow.set_experiment(experiment_id=self.experiment_id)
        new_run = mlflow.start_run(run_name=run_name, experiment_id=self.experiment_id)
        self.run_id = new_run.info.run_id
        log_message = f"Started new MLflow run with ID: {self.run_id} for experiment ID: {self.experiment_id}"
        if run_name:
            log_message += f" with name: {run_name}"
        self._logger.info(log_message + ".")
    except Exception as e:
        # Reset before reporting: log_and_raise_error presumably raises, so a reset
        # placed after it would never run.
        self.run_id = None
        log_and_raise_error(self._logger, f"Cannot start a new MLflow run: {e}")
|
|
184
|
+
|
|
185
|
+
def end_run(self, status: str = "FINISHED") -> None:
    """
    End the MLflow run tracked by this instance, then forget its ID.

    The run is ended only when it is MLflow's currently active run. If another run
    is active, or no run is active at all, nothing is ended and a message is logged.
    ``self.run_id`` is reset to None in every case.

    Parameters
    ----------
    status : str, default="FINISHED"
        Terminal status passed to ``mlflow.end_run``.
    """
    managed_id = self.run_id
    if not managed_id:
        self._logger.info("No run ID associated with this ModelManager instance to end.")
    else:
        active = mlflow.active_run()
        if not active:
            self._logger.info(
                f"Attempted to end ModelManager's run {managed_id}, but there is no "
                "active MLflow run. Assuming it was already ended."
            )
        elif active.info.run_id != managed_id:
            # A different run is active; leave it alone rather than ending someone else's run.
            self._logger.warning(
                f"Attempted to end ModelManager's run {managed_id}, but the current "
                f"MLflow active run is {active.info.run_id}. "
                f"ModelManager's run {managed_id} was not ended by this call."
            )
        else:
            try:
                mlflow.end_run(status=status)
                self._logger.info(f"MLflow run {managed_id} ended with status: {status}.")
            except Exception as err:
                # Failure to end the run is logged, not raised.
                self._logger.error(f"Error ending MLflow run {managed_id}: {err}")

    self.run_id = None
|
|
210
|
+
|
|
211
|
+
def start_new_run(self, run_name: str | None = None) -> None:
|
|
212
|
+
"""
|
|
213
|
+
Ensures any previous run managed by this instance is ended, then starts a new MLflow run.
|
|
214
|
+
|
|
215
|
+
Parameters
|
|
216
|
+
----------
|
|
217
|
+
run_name : str, optional
|
|
218
|
+
The name for the new MLflow run. If None, a default name might be used by MLflow.
|
|
219
|
+
"""
|
|
220
|
+
if self.run_id is not None:
|
|
221
|
+
self.end_run()
|
|
222
|
+
|
|
223
|
+
# After end_run, self.run_id is None, so start_run will proceed to create a new run.
|
|
224
|
+
self.start_run(run_name)
|
|
225
|
+
|
|
226
|
+
def _create_registered_model(self, tags: dict[str, str]) -> None:
    """
    Create the registered model ``self.model_name``, or update its tags if it already exists.

    Parameters
    ----------
    tags : dict[str, str]
        Tags to associate with the model. ``self.user`` is added under ``"user"``;
        the caller's dict is not modified.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` in these cases:
    ``self.user`` is None; the model must be created but any of
    team/user/project/task is missing; or creating the model or updating its tags fails.
    """
    # Work on a copy so adding the user tag does not mutate the caller's dict.
    model_tags = dict(tags)
    if self.user is None:
        log_and_raise_error(self._logger, "Define user to create instance of ModelManager.")
    else:
        model_tags["user"] = self.user

    try:
        model_instance = self.client.get_registered_model(self.model_name)
    except Exception:
        # Any lookup failure is treated as "the model does not exist yet", so create it.
        if not all([self.team, self.user, self.project, self.task]):
            log_and_raise_error(self._logger, "All tags must be provided (team, user, project, task).")
        try:
            self.client.create_registered_model(name=self.model_name, tags=model_tags, description=self.description)
            self._logger.info(f"Model '{self.model_name}' created successfully.")
        except Exception as e:
            log_and_raise_error(self._logger, f"Error creating model '{self.model_name}': {e}")
        return

    if model_instance:
        self._logger.info(f"Model '{self.model_name}' already exists.")
        # Tag updates live outside the lookup ``try``. A failure here (e.g. missing
        # permissions) is then reported as a tag-update error, instead of falling
        # through to a create attempt for a model that already exists.
        try:
            for key, value in model_tags.items():
                self.client.set_registered_model_tag(name=self.model_name, key=key, value=value)
        except Exception as e:
            log_and_raise_error(self._logger, f"Error updating tags for model '{self.model_name}': {e}")
|
|
257
|
+
|
|
258
|
+
def _setup_experiment(self) -> None:
    """
    Create (or reuse) the MLflow experiment named ``self.model_name`` and make it active.

    The method performs these steps:
    - Sets ``self.experiment_id``.
    - Applies the non-None ``project``/``team``/``user``/``task`` values as experiment
      tags; a failure there only logs a warning.
    - Calls ``mlflow.set_experiment`` with the experiment name.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` when the experiment can
    neither be created nor found by name.
    """
    try:
        experiment_id = self.client.create_experiment(self.model_name)
        self._logger.info(f"Created new experiment '{self.model_name}' with ID: {experiment_id}")
        self.experiment_id = experiment_id
    except Exception as e:
        # Creation fails e.g. when the experiment already exists; fall back to a lookup.
        experiment = mlflow.get_experiment_by_name(self.model_name)
        # Check before dereferencing: a missing experiment must surface the original
        # creation error, not an AttributeError on None.
        if experiment:
            self.experiment_id = experiment.experiment_id
            self._logger.info(f"Using existing experiment '{self.model_name}' with ID: {experiment.experiment_id}")
        else:
            log_and_raise_error(self._logger, f"Error setting up experiment '{self.model_name}': {e}")

    try:
        if any((self.project, self.team, self.user, self.task)):
            experiment_tags = {
                key: value
                for key, value in {
                    "project": self.project,
                    "team": self.team,
                    "user": self.user,
                    "task": self.task,
                }.items()
                if value is not None
            }
            for key, value in experiment_tags.items():
                self.client.set_experiment_tag(experiment_id=self.experiment_id, key=key, value=value)
    except Exception as e:
        # Tagging an experiment owned by someone else may be rejected; that is not fatal.
        self._logger.warning(f"Error setting experiment tags: {e}. Check if you are the owner of this experiment.")

    # Set the active experiment
    mlflow.set_experiment(self.model_name)
|
|
287
|
+
|
|
288
|
+
def log_model(
    self,
    *,
    model=None,
    input_data=None,
    predictions=None,
    flavor="sklearn",
    register_model=True,
    description=None,
    tags=None,
    python_model=None,
    name="model",
    **kwargs,
):
    """
    Log the model to MLflow using a dynamic flavor, including support for the 'pyfunc' flavor.

    The run managed by this instance must be MLflow's currently active run.

    Parameters
    ----------
    model : object, optional
        The trained model to log (used for every flavor except 'pyfunc').
    input_data : pd.DataFrame or array-like, optional
        Example inference input. Used together with ``predictions`` to infer the
        signature, and as ``input_example``. Ignored unless ``predictions`` is also given.
    predictions : np.ndarray, optional
        The predictions made by the model (used for signature).
    flavor : str, default="sklearn"
        Name of the ``mlflow.<flavor>`` module whose ``log_model`` is called; 'pyfunc'
        is handled explicitly.
    register_model : bool, default=True
        Whether to register the logged model under ``self.model_name``.
    description : str, optional
        Description stored on the run as the ``mlflow.note.content`` tag. When the
        model is registered, it is also stored on the new model version.
    tags : dict, optional
        Tags set on the run. When the model is registered, they are also set on the
        new model version.
    python_model : mlflow.pyfunc.PythonModel, optional
        The PythonModel instance to log (used for the 'pyfunc' flavor).
    name : str, default='model'
        The name under which the model will be logged.
    **kwargs :
        Additional keyword arguments forwarded to the flavor-specific log_model function.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` in these cases:
    - this instance has no run;
    - its run is not MLflow's active run;
    - the flavor or its ``log_model`` is missing;
    - logging or registration fails.
    """
    active_mlflow_run = mlflow.active_run()
    if not self.run_id:
        log_and_raise_error(
            self._logger,
            "No MLflow run is currently managed by this ModelManager instance. "
            "Call start_new_run() or ensure __init__ completed a run start.",
        )
    if not active_mlflow_run or active_mlflow_run.info.run_id != self.run_id:
        log_and_raise_error(
            self._logger,
            f"The MLflow run managed by this instance ({self.run_id}) is not the "
            f"currently active MLflow run ({active_mlflow_run.info.run_id if active_mlflow_run else 'None'}). "
            "Please ensure the correct run is active or start a new run.",
        )

    try:
        signature = None
        input_example = None
        if input_data is not None and predictions is not None:
            signature_input = input_data.copy()
            # Cast integer columns to float64 before inferring the signature (presumably so
            # the schema tolerates missing values at inference time). Only DataFrame-like
            # inputs expose ``select_dtypes``; other inputs are used unchanged.
            if hasattr(signature_input, "select_dtypes"):
                int_cols = signature_input.select_dtypes(include=["int32", "int64"]).columns
                if len(int_cols):
                    signature_input[int_cols] = signature_input[int_cols].astype("float64")
            signature = infer_signature(signature_input, predictions)
            input_example = signature_input
        registered_model_name = self.model_name if register_model else None

        if tags:
            mlflow.set_tags(tags)
        if description:
            # Use the per-call description (previously self.description was written here,
            # which could be None or unrelated to this logging call).
            mlflow.set_tag("mlflow.note.content", description)

        if flavor == "pyfunc":
            from mlflow import pyfunc

            pyfunc.log_model(
                python_model=python_model,
                name=name,
                registered_model_name=registered_model_name,
                signature=signature,
                input_example=input_example,
                **kwargs,
            )
            self._logger.info(f"Model '{self.model_name}' logged successfully using mlflow.pyfunc.")
        else:
            flavor_module = getattr(mlflow, flavor, None)
            if flavor_module is None:
                log_and_raise_error(self._logger, f"MLflow flavor '{flavor}' is not available.")

            log_model_func = getattr(flavor_module, "log_model", None)
            if log_model_func is None:
                log_and_raise_error(self._logger, f"'log_model' not found in mlflow.{flavor} module.")

            log_model_func(
                model,
                name=name,
                registered_model_name=registered_model_name,
                signature=signature,
                input_example=input_example,
                **kwargs,
            )
            self._logger.info(f"Model '{self.model_name}' logged successfully using mlflow.{flavor}.")

        if register_model:
            self._add_model_version_metadata(description=description, tags=tags, user=self.user)
    except Exception as e:
        log_and_raise_error(self._logger, f"Error logging model '{self.model_name}': {str(e)}")
|
|
396
|
+
|
|
397
|
+
def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
    """
    Upload one local file or directory as an artifact of the run managed by this instance.

    Parameters
    ----------
    local_path : str
        Local file or directory to upload.
    artifact_path : Optional[str]
        Destination sub-path inside the run's artifact root. None means the root itself.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` when this instance has no run,
    when that run is not MLflow's active run, or when the upload fails.
    """
    current = mlflow.active_run()
    if not self.run_id:
        log_and_raise_error(
            self._logger,
            "No MLflow run is currently managed by this ModelManager instance. "
            "Call start_new_run() or ensure __init__ completed a run start.",
        )
    runs_match = bool(current) and current.info.run_id == self.run_id
    if not runs_match:
        shown_id = current.info.run_id if current else "None"
        log_and_raise_error(
            self._logger,
            f"The MLflow run managed by this instance ({self.run_id}) is not the "
            f"currently active MLflow run ({shown_id}). "
            "Please ensure the correct run is active or start a new run.",
        )

    try:
        mlflow.log_artifact(local_path, artifact_path=artifact_path)
        self._logger.debug(f"Logged artifact '{local_path}'.")
    except Exception as exc:
        log_and_raise_error(self._logger, f"Error logging artifact '{local_path}' to run '{self.run_id}': {exc}")
|
|
429
|
+
|
|
430
|
+
def log_artifacts(self, local_dir: str, artifact_path: str | None = None) -> None:
    """
    Upload every file under a local directory as artifacts of the run managed by this instance.

    Parameters
    ----------
    local_dir : str
        Local directory whose contents are uploaded.
    artifact_path : Optional[str]
        Destination sub-path inside the run's artifact root. None means the root itself.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` when this instance has no run,
    when that run is not MLflow's active run, or when the upload fails.
    """
    current = mlflow.active_run()
    if not self.run_id:
        log_and_raise_error(
            self._logger,
            "No MLflow run is currently managed by this ModelManager instance. "
            "Call start_new_run() or ensure __init__ completed a run start.",
        )
    runs_match = bool(current) and current.info.run_id == self.run_id
    if not runs_match:
        shown_id = current.info.run_id if current else "None"
        log_and_raise_error(
            self._logger,
            f"The MLflow run managed by this instance ({self.run_id}) is not the "
            f"currently active MLflow run ({shown_id}). "
            "Please ensure the correct run is active or start a new run.",
        )

    try:
        mlflow.log_artifacts(local_dir, artifact_path=artifact_path)
        self._logger.debug(f"Logged artifacts from '{local_dir}'.")
    except Exception as exc:
        log_and_raise_error(self._logger, f"Error logging artifacts from '{local_dir}' to run '{self.run_id}': {exc}")
|
|
462
|
+
|
|
463
|
+
def log_metric(self, key: str, value: float, step: int | None = None) -> None:
    """
    Record one metric value on the run managed by this instance.

    Parameters
    ----------
    key : str
        Metric name.
    value : float
        Metric value.
    step : Optional[int]
        Step index passed through to ``mlflow.log_metric``.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` when this instance has no run,
    when that run is not MLflow's active run, or when logging the metric fails.
    """
    current = mlflow.active_run()
    if not self.run_id:
        log_and_raise_error(
            self._logger,
            "No MLflow run is currently managed by this ModelManager instance. "
            "Call start_new_run() or ensure __init__ completed a run start.",
        )
    runs_match = bool(current) and current.info.run_id == self.run_id
    if not runs_match:
        shown_id = current.info.run_id if current else "None"
        log_and_raise_error(
            self._logger,
            f"The MLflow run managed by this instance ({self.run_id}) is not the "
            f"currently active MLflow run ({shown_id}). "
            "Please ensure the correct run is active or start a new run.",
        )
    try:
        mlflow.log_metric(key, value, step=step)
        self._logger.debug(f"Logged metric '{key}': {value}.")
    except Exception as exc:
        log_and_raise_error(self._logger, f"Error logging metric '{key}' to run '{self.run_id}': {exc}")
|
|
495
|
+
|
|
496
|
+
def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None:
    """
    Record several metric values at once on the run managed by this instance.

    Parameters
    ----------
    metrics : dict[str, float]
        Mapping of metric name to value.
    step : Optional[int]
        Step index passed through to ``mlflow.log_metrics``.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` when this instance has no run,
    when that run is not MLflow's active run, or when logging the metrics fails.
    """
    current = mlflow.active_run()
    if not self.run_id:
        log_and_raise_error(
            self._logger,
            "No MLflow run is currently managed by this ModelManager instance. "
            "Call start_new_run() or ensure __init__ completed a run start.",
        )
    runs_match = bool(current) and current.info.run_id == self.run_id
    if not runs_match:
        shown_id = current.info.run_id if current else "None"
        log_and_raise_error(
            self._logger,
            f"The MLflow run managed by this instance ({self.run_id}) is not the "
            f"currently active MLflow run ({shown_id}). "
            "Please ensure the correct run is active or start a new run.",
        )
    try:
        mlflow.log_metrics(metrics, step=step)
        self._logger.debug(f"Logged {len(metrics)} metrics.")
    except Exception as exc:
        log_and_raise_error(self._logger, f"Error logging metrics to run '{self.run_id}': {exc}")
|
|
526
|
+
|
|
527
|
+
def log_param(self, key: str, value: Any) -> None:
    """
    Logs a single parameter to the run managed by this instance.

    Parameters
    ----------
    key : str
        Parameter key.
    value : Any
        Parameter value, passed unchanged to ``mlflow.log_param``.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` when this instance has no run,
    when that run is not MLflow's active run, or when logging the parameter fails.
    """
    # Annotation uses typing.Any; the builtin function ``any`` is not a type.
    active_mlflow_run = mlflow.active_run()
    if not self.run_id:
        log_and_raise_error(
            self._logger,
            "No MLflow run is currently managed by this ModelManager instance. "
            "Call start_new_run() or ensure __init__ completed a run start.",
        )
    if not active_mlflow_run or active_mlflow_run.info.run_id != self.run_id:
        log_and_raise_error(
            self._logger,
            f"The MLflow run managed by this instance ({self.run_id}) is not the "
            f"currently active MLflow run ({active_mlflow_run.info.run_id if active_mlflow_run else 'None'}). "
            "Please ensure the correct run is active or start a new run.",
        )
    try:
        mlflow.log_param(key, value)
        self._logger.debug(f"Logged param '{key}'.")
    except Exception as e:
        log_and_raise_error(self._logger, f"Error logging parameter '{key}' to run '{self.run_id}': {e}")
|
|
557
|
+
|
|
558
|
+
def log_params(self, params: dict[str, Any]) -> None:
    """
    Logs multiple parameters to the run managed by this instance.

    Parameters
    ----------
    params : dict[str, Any]
        Mapping of parameter key to value, passed unchanged to ``mlflow.log_params``.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` when this instance has no run,
    when that run is not MLflow's active run, or when logging the parameters fails.
    """
    # Annotation uses typing.Any; the builtin function ``any`` is not a type.
    active_mlflow_run = mlflow.active_run()
    if not self.run_id:
        log_and_raise_error(
            self._logger,
            "No MLflow run is currently managed by this ModelManager instance. "
            "Call start_new_run() or ensure __init__ completed a run start.",
        )
    if not active_mlflow_run or active_mlflow_run.info.run_id != self.run_id:
        log_and_raise_error(
            self._logger,
            f"The MLflow run managed by this instance ({self.run_id}) is not the "
            f"currently active MLflow run ({active_mlflow_run.info.run_id if active_mlflow_run else 'None'}). "
            "Please ensure the correct run is active or start a new run.",
        )
    try:
        mlflow.log_params(params)
        self._logger.debug(f"Logged {len(params)} params.")
    except Exception as e:
        log_and_raise_error(self._logger, f"Error logging parameters to run '{self.run_id}': {e}")
|
|
586
|
+
|
|
587
|
+
def set_tag(self, key: str, value: Any) -> None:
    """
    Sets a single tag on the run managed by this instance.

    Parameters
    ----------
    key : str
        Tag key.
    value : Any
        Tag value, passed unchanged to ``mlflow.set_tag``.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` when this instance has no run,
    when that run is not MLflow's active run, or when setting the tag fails.
    """
    # Annotation uses typing.Any; the builtin function ``any`` is not a type.
    active_mlflow_run = mlflow.active_run()
    if not self.run_id:
        log_and_raise_error(
            self._logger,
            "No MLflow run is currently managed by this ModelManager instance. "
            "Call start_new_run() or ensure __init__ completed a run start.",
        )
    if not active_mlflow_run or active_mlflow_run.info.run_id != self.run_id:
        log_and_raise_error(
            self._logger,
            f"The MLflow run managed by this instance ({self.run_id}) is not the "
            f"currently active MLflow run ({active_mlflow_run.info.run_id if active_mlflow_run else 'None'}). "
            "Please ensure the correct run is active or start a new run.",
        )
    try:
        mlflow.set_tag(key, value)
        self._logger.debug(f"Set tag '{key}'.")
    except Exception as e:
        log_and_raise_error(self._logger, f"Error setting tag '{key}' on run '{self.run_id}': {e}")
|
|
617
|
+
|
|
618
|
+
def set_tags(self, tags: dict[str, Any]) -> None:
    """
    Sets multiple tags on the run managed by this instance.

    Parameters
    ----------
    tags : dict[str, Any]
        Mapping of tag key to value, passed unchanged to ``mlflow.set_tags``.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` when this instance has no run,
    when that run is not MLflow's active run, or when setting the tags fails.
    """
    # Annotation uses typing.Any; the builtin function ``any`` is not a type.
    active_mlflow_run = mlflow.active_run()
    if not self.run_id:
        log_and_raise_error(
            self._logger,
            "No MLflow run is currently managed by this ModelManager instance. "
            "Call start_new_run() or ensure __init__ completed a run start.",
        )
    if not active_mlflow_run or active_mlflow_run.info.run_id != self.run_id:
        log_and_raise_error(
            self._logger,
            f"The MLflow run managed by this instance ({self.run_id}) is not the "
            f"currently active MLflow run ({active_mlflow_run.info.run_id if active_mlflow_run else 'None'}). "
            "Please ensure the correct run is active or start a new run.",
        )
    try:
        mlflow.set_tags(tags)
        self._logger.debug(f"Set {len(tags)} tags.")
    except Exception as e:
        log_and_raise_error(self._logger, f"Error setting tags on run '{self.run_id}': {e}")
|
|
646
|
+
|
|
647
|
+
def _add_model_version_metadata(
    self, description: str | None = None, tags: dict[str, str] | None = None, user: str | None = None
) -> int | None:
    """
    Attach a description, tags and a user tag to the model version created by ``self.run_id``.

    Parameters
    ----------
    description : str, optional
        The description to set on the model version; skipped when empty/None.
    tags : dict[str, str], optional
        Key-value pairs to add as tags to the model version. The dict is not modified.
    user : str, optional
        If given, stored as the ``"user"`` tag (overriding any ``"user"`` in ``tags``).

    Returns
    -------
    Optional[int]
        The version number that was updated, or None if no version of
        ``self.model_name`` was produced by ``self.run_id``.

    Raises
    ------
    Errors are reported through ``log_and_raise_error`` when ``self.run_id`` is None
    or when searching or updating the model version fails.
    """
    try:
        if self.run_id is None:
            log_and_raise_error(self._logger, "No model run_id!")

        versions = self.client.search_model_versions(f"name='{self.model_name}'")
        # A single run can register several versions (e.g. log_model called twice);
        # the newest one is the version that was just created.
        matching_versions = [int(mv.version) for mv in versions if mv.run_id == self.run_id]
        if not matching_versions:
            self._logger.warning(f"No matching version found for run ID: {self.run_id}")
            return None
        target_version = max(matching_versions)

        if description:
            self.client.update_model_version(
                name=self.model_name, version=str(target_version), description=description
            )

        # Combine any provided tags with user (if present), without mutating the caller's dict.
        final_tags = dict(tags) if tags else {}
        if user is not None:
            final_tags["user"] = user

        for key, value in final_tags.items():
            self.client.set_model_version_tag(
                name=self.model_name, version=str(target_version), key=key, value=value
            )

        self._logger.debug(f"Updated model '{self.model_name}' version {target_version}.")
        # Return the version as documented (previously the success path returned None).
        return target_version
    except Exception as e:
        log_and_raise_error(self._logger, f"Error updating model '{self.model_name}': {e}")
|
|
702
|
+
|
|
703
|
+
def get_latest_model_version(self) -> int | None:
|
|
704
|
+
"""
|
|
705
|
+
Get the latest version of the registered model.
|
|
706
|
+
|
|
707
|
+
Returns
|
|
708
|
+
-------
|
|
709
|
+
Optional[int]
|
|
710
|
+
The latest version number.
|
|
711
|
+
"""
|
|
712
|
+
try:
|
|
713
|
+
all_versions = self.client.search_model_versions(f"name='{self.model_name}'")
|
|
714
|
+
if all_versions:
|
|
715
|
+
# Find the version with the highest version number
|
|
716
|
+
latest_version = max(all_versions, key=lambda v: int(v.version))
|
|
717
|
+
# Return just the version number as an integer
|
|
718
|
+
return int(latest_version.version)
|
|
719
|
+
else:
|
|
720
|
+
self._logger.info(f"No versions found for model '{self.model_name}'.")
|
|
721
|
+
except Exception as e:
|
|
722
|
+
log_and_raise_error(self._logger, f"Error getting latest version for model '{self.model_name}': {e}")
|
|
723
|
+
|
|
724
|
+
def get_model_uri(self, version: int | None = None, alias: str | None = None) -> str:
|
|
725
|
+
"""
|
|
726
|
+
Get the URI of a registered model by version or alias.
|
|
727
|
+
|
|
728
|
+
Parameters
|
|
729
|
+
----------
|
|
730
|
+
version : int, optional
|
|
731
|
+
The version of the model to retrieve the URI for.
|
|
732
|
+
alias : str, optional
|
|
733
|
+
The alias of the model to retrieve the URI for.
|
|
734
|
+
|
|
735
|
+
Returns
|
|
736
|
+
-------
|
|
737
|
+
str
|
|
738
|
+
The URI of the registered model.
|
|
739
|
+
|
|
740
|
+
Raises
|
|
741
|
+
------
|
|
742
|
+
ValueError
|
|
743
|
+
If neither version nor alias is provided.
|
|
744
|
+
"""
|
|
745
|
+
if version is None and alias is None:
|
|
746
|
+
log_and_raise_error(self._logger, "Either 'version' or 'alias' must be provided to get the model URI.")
|
|
747
|
+
|
|
748
|
+
try:
|
|
749
|
+
if alias is not None:
|
|
750
|
+
version = self.client.get_model_version_by_alias(name=self.model_name, alias=alias).version
|
|
751
|
+
|
|
752
|
+
model_uri = f"models:/{self.model_name}/{version}"
|
|
753
|
+
return model_uri
|
|
754
|
+
except Exception as e:
|
|
755
|
+
log_and_raise_error(
|
|
756
|
+
self._logger,
|
|
757
|
+
f"Error retrieving model URI for model '{self.model_name}', version '{version}', alias '{alias}': {e}",
|
|
758
|
+
)
|
|
759
|
+
|
|
760
|
+
def set_model_alias(self, *, version: int | None = None, alias: str) -> None:
|
|
761
|
+
"""
|
|
762
|
+
Set an alias for a specific model version.
|
|
763
|
+
Aliases enable you to refer to a model version by name rather than numeric version.
|
|
764
|
+
Requires MLflow 2.0 or newer.
|
|
765
|
+
|
|
766
|
+
Parameters
|
|
767
|
+
----------
|
|
768
|
+
version : int
|
|
769
|
+
The version of the model to which the alias should be assigned.
|
|
770
|
+
alias : str
|
|
771
|
+
The alias name (e.g., "latest", "candidate", "production-ready").
|
|
772
|
+
"""
|
|
773
|
+
try:
|
|
774
|
+
if version is None:
|
|
775
|
+
version = self.get_latest_model_version()
|
|
776
|
+
|
|
777
|
+
self.client.set_registered_model_alias(name=self.model_name, version=str(version), alias=alias)
|
|
778
|
+
|
|
779
|
+
self._logger.info(f"Set alias '{alias}' for version {version} of model '{self.model_name}'.")
|
|
780
|
+
except Exception as e:
|
|
781
|
+
log_and_raise_error(
|
|
782
|
+
self._logger, f"Error setting alias '{alias}' for model '{self.model_name}', version {version}: {e}"
|
|
783
|
+
)
|
|
784
|
+
|
|
785
|
+
def register_model(
    self, run_name: str, experiment_id: str | None = None, description: str | None = None, tags: dict | None = None
) -> None:
    """
    Register the model logged by a named run in the model registry.

    The run is found by its ``mlflow.runName`` tag. If several runs share the name, the first
    result returned by ``search_runs`` is used. The artifact at ``runs:/<run_id>/model`` is
    registered under ``self.model_name``.

    Parameters
    ----------
    run_name : str
        The name of the run whose model should be registered.
    experiment_id : str, optional
        Experiment ID to search for the run. If None, uses ``self.experiment_id``.
    description : str, optional
        Description to set on the newly created model version. Not set when None.
    tags : dict, optional
        Tags to attach to the new model version. The caller's dict is not modified. A
        ``"user"`` tag holding ``self.user`` is always added and overrides any given "user" key.

    Raises
    ------
    Exception
        Raised through ``log_and_raise_error`` when ``self.user`` is unset, no matching run
        exists, or registration fails.
    """
    if self.user is None:
        log_and_raise_error(self._logger, "Define user to register a model.")

    # Build a new dict so the caller's `tags` argument is never mutated.
    version_tags = {**(tags or {}), "user": self.user}

    if experiment_id is None:
        experiment_id = self.experiment_id

    runs = self.client.search_runs(
        experiment_ids=[experiment_id], filter_string=f"tags.mlflow.runName = '{run_name}'"
    )

    # Fail early when no run carries the requested name.
    if not runs:
        log_and_raise_error(
            self._logger, f"No runs found with name '{run_name}' in experiment ID '{experiment_id}'."
        )

    run_id = runs[0].info.run_id

    try:
        model_version = mlflow.register_model(
            model_uri=f"runs:/{run_id}/model",
            name=self.model_name,
            tags=version_tags,
        )
        if description is not None:
            # Target the version just created rather than re-querying the "latest" one, which
            # could be a version registered concurrently by someone else. The call is skipped
            # for None because MlflowClient rejects updates that carry no new field values.
            self.client.update_model_version(
                name=self.model_name, version=str(model_version.version), description=description
            )
        self._logger.info(f"Model '{self.model_name}' registered successfully from run '{run_name}'.")
    except Exception as e:
        log_and_raise_error(self._logger, f"Error registering model '{self.model_name}' from run '{run_name}': {e}")
|
|
836
|
+
|
|
837
|
+
def load_model(
    self,
    model_uri: str | None = None,
    version: int | None = None,
    alias: str | None = None,
    flavor: str = "sklearn",
    **kwargs,
):
    """
    Load a model by URI, version, or alias with any MLflow flavor (including 'pyfunc').

    Parameters
    ----------
    model_uri : str, optional
        URI of the model to load. When given, ``version`` and ``alias`` are not consulted.
    version : int, optional
        Registered version to load (used only when ``model_uri`` is None).
    alias : str, optional
        Registered alias to load (used only when ``model_uri`` is None; wins over ``version``).
    flavor : str, optional
        Name of the ``mlflow.<flavor>`` module whose ``load_model`` is used. Default 'sklearn'.
    **kwargs :
        Forwarded unchanged to the flavor's ``load_model`` function.

    Returns
    -------
    model
        Whatever the flavor's ``load_model`` returns.

    Raises
    ------
    ValueError
        If none of model_uri, version, or alias is provided (raised through
        ``log_and_raise_error``). Loader failures are also reported through it.
    """
    if model_uri is None:
        if alias is None and version is None:
            log_and_raise_error(
                self._logger, "Either 'model_uri', 'version', or 'alias' must be provided to load the model."
            )
        # Translate version/alias into a models:/ URI.
        model_uri = self.get_model_uri(version=version, alias=alias)

    try:
        if flavor == "pyfunc":
            from mlflow import pyfunc

            loader = pyfunc.load_model
        else:
            # Look up mlflow.<flavor> dynamically so any installed flavor can be used.
            flavor_module = getattr(mlflow, flavor, None)
            if flavor_module is None:
                log_and_raise_error(self._logger, f"MLflow flavor '{flavor}' is not available.")

            loader = getattr(flavor_module, "load_model", None)
            if loader is None:
                log_and_raise_error(self._logger, f"'load_model' not found in mlflow.{flavor} module.")

        model = loader(model_uri, **kwargs)
        self._logger.info(f"Successfully loaded model from URI '{model_uri}' using flavor '{flavor}'.")
        return model
    except Exception as e:
        log_and_raise_error(self._logger, f"Error loading model from URI '{model_uri}': {e}")
|
|
901
|
+
|
|
902
|
+
def _resolve_logged_model_uri(self, run_id: str, experiment_id: str) -> str:
|
|
903
|
+
"""
|
|
904
|
+
Resolve the model artifact URI for a run. For MLflow 3.x runs, returns the LoggedModel's
|
|
905
|
+
artifact_location directly to avoid a proxy hang in RunsArtifactRepository. Falls back to
|
|
906
|
+
runs:/{run_id}/model for MLflow 2.x runs.
|
|
907
|
+
"""
|
|
908
|
+
try:
|
|
909
|
+
page_token = None
|
|
910
|
+
while True:
|
|
911
|
+
page = self.client.search_logged_models(
|
|
912
|
+
experiment_ids=[experiment_id],
|
|
913
|
+
filter_string="name = 'model'",
|
|
914
|
+
page_token=page_token,
|
|
915
|
+
)
|
|
916
|
+
for lm in page:
|
|
917
|
+
if lm.source_run_id == run_id:
|
|
918
|
+
self._logger.info(f"Loading model from LoggedModel URI (run: {run_id[:8]}...).")
|
|
919
|
+
return lm.artifact_location
|
|
920
|
+
if not page.token:
|
|
921
|
+
break
|
|
922
|
+
page_token = page.token
|
|
923
|
+
except Exception as e:
|
|
924
|
+
self._logger.warning(f"Could not resolve LoggedModel URI for run '{run_id}': {e}. Falling back.")
|
|
925
|
+
|
|
926
|
+
return f"runs:/{run_id}/model"
|
|
927
|
+
|
|
928
|
+
def load_model_from_experiment(self, run_name: str, experiment_name: str = None, flavor: str = "sklearn", **kwargs):
    """
    Load the model logged by a named run inside an experiment, with any MLflow flavor.

    Parameters
    ----------
    run_name : str
        Value of the run's ``mlflow.runName`` tag. If several runs match, the first result
        returned by ``search_runs`` is used.
    experiment_name : str, optional
        Experiment to search. When None, ``self.model_name`` is used as the experiment name.
    flavor : str, optional
        The flavor of the model to retrieve. Default is 'sklearn'.
    **kwargs :
        Forwarded unchanged to ``load_model`` and from there to the flavor's loader.

    Returns
    -------
    model
        The model loaded from the specified experiment and run.
    """
    experiment_name = self.model_name if experiment_name is None else experiment_name

    try:
        experiment = self.client.get_experiment_by_name(experiment_name)
        if not experiment:
            log_and_raise_error(self._logger, f"Experiment '{experiment_name}' not found.")
        exp_id = experiment.experiment_id

        matches = self.client.search_runs(
            experiment_ids=[exp_id], filter_string=f"tags.mlflow.runName = '{run_name}'"
        )
        if not matches:
            log_and_raise_error(self._logger, f"Run '{run_name}' not found in experiment '{experiment_name}'.")

        # Prefer the LoggedModel location (MLflow 3.x), falling back to runs:/<id>/model.
        resolved_uri = self._resolve_logged_model_uri(run_id=matches[0].info.run_id, experiment_id=exp_id)
        return self.load_model(model_uri=resolved_uri, flavor=flavor, **kwargs)
    except Exception as e:
        log_and_raise_error(
            self._logger, f"Error loading model from experiment '{experiment_name}', run '{run_name}': {e}"
        )
|
|
969
|
+
|
|
970
|
+
def load_latest_model(self, flavor="sklearn", **kwargs):
    """
    Load the registered model at its highest version number, with any MLflow flavor.

    Parameters
    ----------
    flavor : str, optional
        The flavor of the model to retrieve. Default is 'sklearn'.
    **kwargs :
        Forwarded unchanged to ``load_model`` and from there to the flavor's loader.

    Returns
    -------
    model
        The model at the latest version, or None when the model has no registered versions.

    Raises
    ------
    Exception
        Raised through ``log_and_raise_error`` if loading the resolved version fails.
    """
    version = self.get_latest_model_version()
    if version is None:
        # get_latest_model_version logs when no versions are found; nothing to load.
        return None

    model_uri = f"models:/{self.model_name}/{version}"
    try:
        return self.load_model(model_uri=model_uri, flavor=flavor, **kwargs)
    except Exception as e:
        log_and_raise_error(
            self._logger, f"Error loading latest version {version} of model '{self.model_name}': {e}"
        )
|
|
993
|
+
|
|
994
|
+
def delete_model(self, model_uri: str | None = None, version: int | None = None, alias: str | None = None) -> None:
|
|
995
|
+
"""
|
|
996
|
+
Delete a specific version of this model from the MLflow model registry.
|
|
997
|
+
|
|
998
|
+
Parameters
|
|
999
|
+
----------
|
|
1000
|
+
model_uri : str, optional
|
|
1001
|
+
The URI of the model to delete.
|
|
1002
|
+
version : int, optional
|
|
1003
|
+
The version number of the model to delete.
|
|
1004
|
+
alias : str, optional
|
|
1005
|
+
The alias of the model to delete.
|
|
1006
|
+
|
|
1007
|
+
Raises
|
|
1008
|
+
------
|
|
1009
|
+
Exception
|
|
1010
|
+
If an error occurs while deleting the model version.
|
|
1011
|
+
"""
|
|
1012
|
+
if model_uri is None:
|
|
1013
|
+
if version is None and alias is None:
|
|
1014
|
+
log_and_raise_error(
|
|
1015
|
+
self._logger, "Either 'model_uri', 'version', or 'alias' must be provided to delete the model."
|
|
1016
|
+
)
|
|
1017
|
+
|
|
1018
|
+
# Get model URI from version or alias
|
|
1019
|
+
model_uri = self.get_model_uri(version=version, alias=alias)
|
|
1020
|
+
if version is not None and alias is not None:
|
|
1021
|
+
log_and_raise_error(self._logger, "Cannot delete model version and alias at the same time.")
|
|
1022
|
+
try:
|
|
1023
|
+
if alias is not None:
|
|
1024
|
+
version = self.client.get_model_version_by_alias(name=self.model_name, alias=alias).version
|
|
1025
|
+
if version is None:
|
|
1026
|
+
log_and_raise_error(self._logger, f"Version {version} not found for model '{self.model_name}'.")
|
|
1027
|
+
|
|
1028
|
+
self.client.delete_model_version(name=self.model_name, version=str(version))
|
|
1029
|
+
self._logger.info(f"Deleted version {version} of model '{self.model_name}'.")
|
|
1030
|
+
except Exception as e:
|
|
1031
|
+
log_and_raise_error(self._logger, f"Error deleting model '{self.model_name}' version {version}: {e}")
|
|
1032
|
+
|
|
1033
|
+
def grant_experiment_permission(self, username: str, permission: PermissionLevel) -> None:
    """
    Grant or update a user's permission on the current experiment.

    The existing permission is updated first. If the auth server rejects the update with a
    ``RestException``, a new permission is created instead.

    Parameters
    ----------
    username : str
        The user whose permission is set.
    permission : {"READ", "EDIT", "MANAGE", "NO_PERMISSIONS"}
        The permission level to apply.

    Raises
    ------
    Exception
        Raised through ``log_and_raise_error`` when the auth client or experiment ID is missing,
        the permission level is invalid, or the update/create calls fail.
    """
    if not self.auth_client:
        log_and_raise_error(self._logger, "AuthServiceClient not initialized. Cannot manage permissions.")
    if not self.experiment_id:
        log_and_raise_error(self._logger, "Experiment ID not set. Cannot manage permissions.")

    valid_permissions = ["READ", "EDIT", "MANAGE", "NO_PERMISSIONS"]

    if permission not in valid_permissions:
        log_and_raise_error(
            self._logger, f"Invalid permission level '{permission}'. Must be one of {valid_permissions}"
        )

    try:
        self.auth_client.update_experiment_permission(self.experiment_id, username, permission)
        self._logger.info(
            f"Successfully updated permission for user '{username}' on experiment "
            f"'{self.experiment_id}' to '{permission}'."
        )
    except RestException:
        # Presumably the user has no permission on this experiment yet, so create one.
        # An exception raised inside this handler is NOT caught by the sibling
        # `except Exception` clause below, so the fallback needs its own handler.
        try:
            self.auth_client.create_experiment_permission(self.experiment_id, username, permission)
        except Exception as e_create:
            log_and_raise_error(
                self._logger,
                f"Failed to create '{permission}' permission for user '{username}' "
                f"on experiment '{self.experiment_id}': {e_create}",
            )
        self._logger.info(
            f"Successfully created '{permission}' permission for user '{username}' "
            f"on experiment '{self.experiment_id}'."
        )
    except Exception as e_unexpected:
        log_and_raise_error(
            self._logger, f"An unexpected error occurred while setting experiment permission: {e_unexpected}"
        )
|
|
1065
|
+
|
|
1066
|
+
def grant_registered_model_permission(self, username: str, permission: PermissionLevel) -> None:
    """
    Grant or update a user's permission on the current registered model.

    The existing permission is updated first. If the auth server rejects the update with a
    ``RestException``, a new permission is created instead.

    Parameters
    ----------
    username : str
        The user whose permission is set.
    permission : {"READ", "EDIT", "MANAGE", "NO_PERMISSIONS"}
        The permission level to apply to ``self.model_name``.

    Raises
    ------
    Exception
        Raised through ``log_and_raise_error`` when the auth client is missing, the permission
        level is invalid, or the update/create calls fail.
    """
    if not self.auth_client:
        log_and_raise_error(self._logger, "AuthServiceClient not initialized. Cannot manage permissions.")

    valid_permissions = ["READ", "EDIT", "MANAGE", "NO_PERMISSIONS"]
    if permission not in valid_permissions:
        log_and_raise_error(
            self._logger, f"Invalid permission level '{permission}'. Must be one of {valid_permissions}"
        )
    try:
        self.auth_client.update_registered_model_permission(self.model_name, username, permission)
        self._logger.info(
            f"Successfully updated permission for user '{username}' on registered model "
            f"'{self.model_name}' to '{permission}'."
        )
    except RestException:
        # Presumably the user has no permission on this model yet, so create one.
        # An exception raised inside this handler is NOT caught by the sibling
        # `except Exception` clause below, so the fallback needs its own handler.
        try:
            self.auth_client.create_registered_model_permission(self.model_name, username, permission)
        except Exception as e_create:
            log_and_raise_error(
                self._logger,
                f"Failed to create '{permission}' permission for user '{username}' "
                f"on registered model '{self.model_name}': {e_create}",
            )
        self._logger.info(
            f"Successfully created '{permission}' permission for user '{username}' "
            f"on registered model '{self.model_name}'."
        )
    except Exception as e_unexpected:
        log_and_raise_error(
            self._logger, f"An unexpected error occurred while setting registered model permission: {e_unexpected}"
        )
|
|
1094
|
+
|
|
1095
|
+
def load_artifact_from_model_version(self, version: int, artifact_path: str) -> str:
    """
    Download one artifact belonging to a registered model version.

    Parameters
    ----------
    version : int
        The registered version to read from.
    artifact_path : str
        Path of the artifact relative to the model version's ``models:/`` URI.

    Returns
    -------
    str
        The local filesystem path returned by ``mlflow.artifacts.download_artifacts``.
    """
    try:
        source_uri = "/".join((f"models:/{self.model_name}", str(version), artifact_path))
        downloaded = mlflow.artifacts.download_artifacts(artifact_uri=source_uri)
        self._logger.info(f"Downloaded artifact '{artifact_path}' from model version {version} to '{downloaded}'.")
        return downloaded
    except Exception as e:
        log_and_raise_error(
            self._logger, f"Error downloading artifact '{artifact_path}' from model version {version}: {e}"
        )
|
|
1120
|
+
|
|
1121
|
+
def load_artifact_from_run(self, run_name: str, artifact_path: str, experiment_name: str = None) -> str:
    """
    Download one artifact from the run with the given name.

    Parameters
    ----------
    run_name : str
        Value of the run's ``mlflow.runName`` tag. If several runs match, the first result
        returned by ``search_runs`` is used.
    artifact_path : str
        Path of the artifact relative to the run's artifact root.
    experiment_name : str, optional
        Experiment to search. When None, ``self.model_name`` is used.

    Returns
    -------
    str
        The local filesystem path returned by ``mlflow.artifacts.download_artifacts``.
    """
    experiment_name = self.model_name if experiment_name is None else experiment_name
    try:
        exp = self.client.get_experiment_by_name(experiment_name)
        if not exp:
            log_and_raise_error(self._logger, f"Experiment '{experiment_name}' not found.")
        matching_runs = self.client.search_runs(
            experiment_ids=[exp.experiment_id], filter_string=f"tags.mlflow.runName = '{run_name}'"
        )
        if not matching_runs:
            log_and_raise_error(self._logger, f"Run '{run_name}' not found in experiment '{experiment_name}'.")
        source_uri = f"runs:/{matching_runs[0].info.run_id}/{artifact_path}"
        downloaded = mlflow.artifacts.download_artifacts(artifact_uri=source_uri)
        self._logger.info(f"Downloaded artifact '{artifact_path}' from run '{run_name}' to '{downloaded}'.")
        return downloaded
    except Exception as e:
        log_and_raise_error(
            self._logger, f"Error downloading artifact '{artifact_path}' from run '{run_name}': {e}"
        )
|
|
1159
|
+
|
|
1160
|
+
def get_run_data(self, version: int | str | None = None, alias: str | None = None):
|
|
1161
|
+
"""
|
|
1162
|
+
Get run data from a model version using either version number or alias.
|
|
1163
|
+
|
|
1164
|
+
Parameters
|
|
1165
|
+
----------
|
|
1166
|
+
version : int, str, or None
|
|
1167
|
+
The version number of the model. Can be an integer, string, or "latest".
|
|
1168
|
+
alias : str, optional
|
|
1169
|
+
The alias of the model version (e.g., "production", "staging").
|
|
1170
|
+
|
|
1171
|
+
Returns
|
|
1172
|
+
-------
|
|
1173
|
+
mlflow.entities.Run
|
|
1174
|
+
The MLflow run object containing all run data (params, metrics, tags, etc.).
|
|
1175
|
+
|
|
1176
|
+
Raises
|
|
1177
|
+
------
|
|
1178
|
+
ValueError
|
|
1179
|
+
If neither version nor alias is provided, or if the version/alias is not found.
|
|
1180
|
+
"""
|
|
1181
|
+
if version is None and alias is None:
|
|
1182
|
+
log_and_raise_error(self._logger, "Either 'version' or 'alias' must be provided to get run data.")
|
|
1183
|
+
|
|
1184
|
+
try:
|
|
1185
|
+
# Handle version resolution
|
|
1186
|
+
if alias is not None:
|
|
1187
|
+
model_version = self.client.get_model_version_by_alias(name=self.model_name, alias=alias)
|
|
1188
|
+
elif version == "latest":
|
|
1189
|
+
latest_version = self.get_latest_model_version()
|
|
1190
|
+
if latest_version is None:
|
|
1191
|
+
log_and_raise_error(self._logger, f"No versions found for model '{self.model_name}'.")
|
|
1192
|
+
model_version = self.client.get_model_version(name=self.model_name, version=str(latest_version))
|
|
1193
|
+
else:
|
|
1194
|
+
model_version = self.client.get_model_version(name=self.model_name, version=str(version))
|
|
1195
|
+
|
|
1196
|
+
run = self.client.get_run(model_version.run_id)
|
|
1197
|
+
|
|
1198
|
+
self._logger.info(
|
|
1199
|
+
f"Successfully retrieved run data for model '{self.model_name}' version {model_version.version} "
|
|
1200
|
+
f"(run_id: {model_version.run_id})."
|
|
1201
|
+
)
|
|
1202
|
+
return run
|
|
1203
|
+
|
|
1204
|
+
except Exception as e:
|
|
1205
|
+
log_and_raise_error(
|
|
1206
|
+
self._logger,
|
|
1207
|
+
f"Error retrieving run data for model '{self.model_name}' (version: {version}, alias: {alias}): {e}",
|
|
1208
|
+
)
|