datamint 2.3.5__py3-none-any.whl → 2.4.1__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of datamint might be problematic.
- datamint/api/base_api.py +42 -8
- datamint/api/client.py +2 -0
- datamint/api/endpoints/resources_api.py +37 -13
- datamint/apihandler/base_api_handler.py +0 -1
- datamint/apihandler/dto/annotation_dto.py +2 -0
- datamint/dataset/base_dataset.py +4 -0
- datamint/lightning/__init__.py +1 -0
- datamint/lightning/datamintdatamodule.py +103 -0
- datamint/mlflow/__init__.py +46 -0
- datamint/mlflow/artifact/__init__.py +1 -0
- datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
- datamint/mlflow/env_utils.py +109 -0
- datamint/mlflow/env_vars.py +5 -0
- datamint/mlflow/lightning/callbacks/__init__.py +1 -0
- datamint/mlflow/lightning/callbacks/modelcheckpoint.py +338 -0
- datamint/mlflow/models/__init__.py +94 -0
- datamint/mlflow/tracking/datamint_store.py +46 -0
- datamint/mlflow/tracking/default_experiment.py +27 -0
- datamint/mlflow/tracking/fluent.py +78 -0
- datamint-2.4.1.dist-info/METADATA +320 -0
- {datamint-2.3.5.dist-info → datamint-2.4.1.dist-info}/RECORD +23 -10
- datamint-2.4.1.dist-info/entry_points.txt +18 -0
- datamint-2.3.5.dist-info/METADATA +0 -125
- datamint-2.3.5.dist-info/entry_points.txt +0 -4
- {datamint-2.3.5.dist-info → datamint-2.4.1.dist-info}/WHEEL +0 -0
datamint/mlflow/lightning/callbacks/modelcheckpoint.py (new file):

```diff
@@ -0,0 +1,338 @@
+from lightning.pytorch.callbacks import ModelCheckpoint
+from pathlib import Path
+from weakref import proxy
+from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
+from typing import Literal, Any
+import inspect
+import torch
+from torch import nn
+import lightning.pytorch as L
+from datamint.mlflow.models import log_model_metadata, _get_MLFlowLogger
+from datamint.mlflow.env_utils import ensure_mlflow_configured
+import mlflow
+import logging
+from lightning.pytorch.loggers import MLFlowLogger
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def help_infer_signature(x):
+    if isinstance(x, torch.Tensor):
+        return x.detach().cpu().numpy()
+    elif isinstance(x, dict):
+        return {k: v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v for k, v in x.items()}
+    elif isinstance(x, list):
+        return [v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v for v in x]
+    elif isinstance(x, tuple):
+        return tuple(v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v for v in x)
+
+    return x
```
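This helper converts tensors (including tensors nested one level deep in dicts, lists, and tuples) to NumPy arrays, the form that `mlflow.models.infer_signature` accepts. A minimal sketch of its effect, assuming the module path shown in the file list above:

```python
import torch
from datamint.mlflow.lightning.callbacks.modelcheckpoint import help_infer_signature

batch = {"image": torch.rand(2, 3, 64, 64), "ids": ["a", "b"]}
converted = help_infer_signature(batch)
# converted["image"] is now a numpy.ndarray; non-tensor values pass through unchanged.
```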
modelcheckpoint.py (continued):

```diff
+class MLFlowModelCheckpoint(ModelCheckpoint):
+    def __init__(self, *args,
+                 register_model_name: str | None = None,
+                 register_model_on: Literal["train", "val", "test", "predict"] | None = None,
+                 code_paths: list[str] | None = None,
+                 log_model_at_end_only: bool = True,
+                 additional_metadata: dict[str, Any] | None = None,
+                 extra_pip_requirements: list[str] | None = None,
+                 **kwargs):
+        """
+        MLFlowModelCheckpoint is a custom callback for PyTorch Lightning that integrates with MLFlow to log and register models.
+
+        Args:
+            register_model_name (str | None): The name to register the model under in MLFlow. If None, the model will not be registered.
+            register_model_on (Literal["train", "val", "test", "predict"] | None): The stage at which to register the model. If None, the model will not be registered.
+            code_paths (list[str] | None): List of paths to Python files that should be included in the MLFlow model.
+            log_model_at_end_only (bool): If True, only log the model to MLFlow at the end of the training instead of after every checkpoint save.
+            additional_metadata (dict[str, Any] | None): Additional metadata to log with the model as a JSON file.
+            extra_pip_requirements (list[str] | None): Additional pip requirements to include with the MLFlow model.
+            **kwargs: Keyword arguments for ModelCheckpoint.
+        """
+        # Ensure MLflow is configured when the callback is initialized
+        ensure_mlflow_configured()
+
+        super().__init__(*args, **kwargs)
+        if self.save_top_k > 1:
+            raise NotImplementedError("save_top_k > 1 is not supported. "
+                                      "Please use save_top_k=1 to save only the best model.")
+        if self.save_last is not None and self.save_top_k != 0 and self.monitor is not None:
+            raise NotImplementedError("save_last is not supported with monitor and save_top_k!=0. "
+                                      "Please use two separate callbacks: one for saving the last model "
+                                      "and another for saving the best model based on the monitor metric.")
+
+        if register_model_name is not None and register_model_on is None:
+            raise ValueError("If you provide a register_model_name, you must also provide a register_model_on.")
+        if register_model_on is not None and register_model_name is None:
+            raise ValueError("If you provide a register_model_on, you must also provide a register_model_name.")
+        if register_model_on not in ["train", "val", "test", "predict", None]:
+            raise ValueError("register_model_on must be one of train, val, test or predict.")
+
+        self.register_model_name = register_model_name
+        self.register_model_on = register_model_on
+        self.log_model_at_end_only = log_model_at_end_only
+        self._last_model_uri = None
+        self.last_saved_model_info = None
+        self._inferred_signature = None
+        self._input_example = None
+        self.code_paths = code_paths
+        self.additional_metadata = additional_metadata or {}
+        self.extra_pip_requirements = extra_pip_requirements or []
```
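Since `save_last` combined with a monitored metric is rejected above, the error message points to a two-callback pattern. A sketch of that pattern, assuming Datamint/MLflow credentials are already configured (the metric and model names are placeholders):

```python
# Hypothetical names; standard ModelCheckpoint kwargs pass through **kwargs.
best_cb = MLFlowModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1,
                                register_model_name="my-model",
                                register_model_on="train")
last_cb = MLFlowModelCheckpoint(save_last=True, save_top_k=0)  # only keeps the latest checkpoint
```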
modelcheckpoint.py (continued):

```diff
+    def _infer_params(self, model: nn.Module) -> tuple[dict, ...]:
+        """Extract metadata from the model's forward method signature.
+
+        Returns:
+            A tuple of dicts, each containing parameter metadata ordered by position.
+        """
+        forward_method = getattr(model.__class__, 'forward', None)
+
+        if forward_method is None:
+            return ()
+
+        try:
+            sig = inspect.signature(forward_method)
+            params_list = []
+
+            for param_name, param in sig.parameters.items():
+                if param_name == 'self':
+                    continue
+
+                param_info = {
+                    'name': param_name,
+                    'kind': param.kind.name,
+                    'annotation': param.annotation if param.annotation != inspect.Parameter.empty else None,
+                    'default': param.default if param.default != inspect.Parameter.empty else None,
+                }
+                params_list.append(param_info)
+
+            # Add return annotation if available as the last element
+            return_annotation = sig.return_annotation
+            if return_annotation != inspect.Signature.empty:
+                return_info = {'_return_annotation': str(return_annotation)}
+                params_list.append(return_info)
+
+            return tuple(params_list)
+
+        except Exception as e:
+            _LOGGER.warning(f"Failed to infer forward method parameters: {e}")
+            return ()
```
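The shape of the metadata this produces can be reproduced directly with `inspect` on a toy module (a sketch, not part of the package):

```python
import inspect
from torch import nn

class Toy(nn.Module):
    def forward(self, x, scale: float = 1.0):
        return x * scale

sig = inspect.signature(Toy.forward)
print(list(sig.parameters))  # ['self', 'x', 'scale']
# _infer_params skips 'self' and records name/kind/annotation/default
# for 'x' and 'scale', plus the return annotation when one is present.
```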
modelcheckpoint.py (continued):

```diff
+    def _save_checkpoint(self, trainer: L.Trainer, filepath: str) -> None:
+        _LOGGER.debug(f"Saving checkpoint to {filepath}...")
+        trainer.save_checkpoint(filepath, self.save_weights_only)
+
+        self._last_global_step_saved = trainer.global_step
+        self._last_checkpoint_saved = filepath
+
+        # notify loggers
+        if trainer.is_global_zero:
+            for logger in trainer.loggers:
+                logger.after_save_checkpoint(proxy(self))
+                if isinstance(logger, MLFlowLogger) and not self.log_model_at_end_only:
+                    _LOGGER.debug(f"_save_checkpoint: Logging model to MLFlow at {filepath}...")
+                    self.log_model_to_mlflow(trainer.model, run_id=logger.run_id)
+
+    def log_additional_metadata(self, logger: MLFlowLogger | L.Trainer,
+                                additional_metadata: dict) -> None:
+        """Log additional metadata as a JSON file to the model artifact.
+
+        Args:
+            logger: The MLFlowLogger or Lightning Trainer instance to use for logging.
+            additional_metadata: A dictionary containing additional metadata to log.
+        """
+        self.additional_metadata = additional_metadata
+        if not self.additional_metadata:
+            return
+
+        if self.last_saved_model_info is None:
+            _LOGGER.warning("No model has been saved yet. Cannot log additional metadata.")
+            return
+
+        try:
+            log_model_metadata(metadata=self.additional_metadata,
+                               logger=logger,
+                               model_path=self.last_saved_model_info.artifact_path)
+        except Exception as e:
+            _LOGGER.warning(f"Failed to log additional metadata: {e}")
+
+    def log_model_to_mlflow(self,
+                            model: nn.Module,
+                            run_id: str | MLFlowLogger
+                            ) -> None:
+        """Log the model to MLflow."""
+        if isinstance(run_id, MLFlowLogger):
+            logger = run_id
+            if logger.run_id is None:
+                raise ValueError("MLFlowLogger has no run_id. Cannot log model to MLFlow.")
+            run_id = logger.run_id
+
+        if self._last_checkpoint_saved is None or self._last_checkpoint_saved == '':
+            _LOGGER.warning("No checkpoint saved yet. Cannot log model to MLFlow.")
+            return
+
+        orig_device = next(model.parameters()).device
+        model = model.cpu()  # Ensure the model is on CPU for logging
+
+        requirements = list(self.extra_pip_requirements)
+        # check if lightning is in the requirements
+        if not any('lightning' in req.lower() for req in requirements):
+            requirements.append(f'lightning=={L.__version__}')
+
+        _LOGGER.debug(f"log_model_to_mlflow: Logging model to MLFlow at {self._last_checkpoint_saved}...")
+        modelinfo = mlflow.pytorch.log_model(
+            pytorch_model=model,
+            artifact_path=f'model/{Path(self._last_checkpoint_saved).stem}',
+            signature=self._inferred_signature,
+            run_id=run_id,
+            extra_pip_requirements=requirements,
+            code_paths=self.code_paths
+        )
+
+        model.to(device=orig_device)  # Move the model back to its original device
+        self._last_model_uri = modelinfo.model_uri
+        self.last_saved_model_info = modelinfo
+
+        # Log additional metadata after the model is saved
+        log_model_metadata(self.additional_metadata,
+                           model_path=modelinfo.artifact_path,
+                           run_id=run_id)
+
+    def _remove_checkpoint(self, trainer: L.Trainer, filepath: str) -> None:
+        super()._remove_checkpoint(trainer, filepath)
+        # remove the checkpoint from mlflow
+        if trainer.is_global_zero:
+            for logger in trainer.loggers:
+                if isinstance(logger, MLFlowLogger):
+                    artifact_uri = logger.experiment.get_run(logger.run_id).info.artifact_uri
+                    rep = get_artifact_repository(artifact_uri)
+                    rep.delete_artifacts(f'model/{Path(filepath).stem}')
+
+    def register_model(self, trainer=None):
+        """Register the model in MLFlow Model Registry."""
+        # mlflow_client = _get_MLFlowLogger(trainer)._mlflow_client
+        return mlflow.register_model(
+            model_uri=self._last_model_uri,
+            name=self.register_model_name,
+        )
+
+    def _update_signature(self, trainer):
+        if self._inferred_signature is None:
+            _LOGGER.warning("No signature found. Cannot update signature.")
+            return
+        if self._last_model_uri is None:
+            _LOGGER.warning("No model URI found. Cannot update signature.")
+            return
+
+        mllogger = _get_MLFlowLogger(trainer)
+        mlclient = mllogger._mlflow_client
+
+        # check if the model exists
+        for artifact_info in mlclient.list_artifacts(run_id=mllogger.run_id):
+            if artifact_info.path.startswith('model'):
+                break
+        else:
+            _LOGGER.warning(f"Model URI {self._last_model_uri} does not exist. Cannot update signature.")
+            return
+        _LOGGER.debug(f"Updating signature for model URI: {self._last_model_uri}...")
+        # update the signature
+        mlflow.models.set_signature(
+            model_uri=self._last_model_uri,
+            signature=self._inferred_signature,
+        )
```
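Note that `log_model_to_mlflow` derives the artifact path from the checkpoint filename (`model/<checkpoint stem>`), so a logged model can be loaded back with a runs URI built the same way. A hedged sketch (the run id and checkpoint name are placeholders):

```python
from pathlib import Path
import mlflow.pytorch

run_id = "0123456789abcdef"          # hypothetical run id
ckpt_name = "epoch=4-step=500.ckpt"  # hypothetical checkpoint filename
model_uri = f"runs:/{run_id}/model/{Path(ckpt_name).stem}"
model = mlflow.pytorch.load_model(model_uri)
```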
modelcheckpoint.py (continued):

```diff
+    def __wrap_forward(self, pl_module: nn.Module):
+        original_forward = pl_module.forward
+
+        def wrapped_forward(x, *args, **kwargs):
+            x0 = help_infer_signature(x)
+            inferred_params = self._infer_params(pl_module)
+            if len(inferred_params) > 1:
+                inferred_params = {param['name']: param['default']
+                                   for param in inferred_params[1:] if 'name' in param}
+            else:
+                inferred_params = None
+
+            self._inferred_signature = mlflow.models.infer_signature(model_input=x0,
+                                                                     params=inferred_params)
+
+            # run once and get back to the original forward
+            pl_module.forward = original_forward
+            method = getattr(pl_module, 'forward')
+            out = method(x, *args, **kwargs)
+
+            output_sig = mlflow.models.infer_signature(model_output=help_infer_signature(out))
+            self._inferred_signature.outputs = output_sig.outputs
+
+            return out
+
+        pl_module.forward = wrapped_forward
+
+    def on_train_start(self, trainer, pl_module):
+        self.__wrap_forward(pl_module)
+
+    def on_train_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+        super().on_train_end(trainer, pl_module)
+
+        if self.log_model_at_end_only and trainer.is_global_zero:
+            logger = _get_MLFlowLogger(trainer)
+            if logger is None:
+                _LOGGER.warning("No MLFlowLogger found. Cannot log model to MLFlow.")
+            else:
+                self.log_model_to_mlflow(trainer.model, run_id=logger.run_id)
+
+        self._update_signature(trainer)
+
+        if self.register_model_on == 'train':
+            self.register_model(trainer)
+
+    def _restore_model_uri(self, trainer: L.Trainer) -> None:
+        logger = _get_MLFlowLogger(trainer)
+        if logger is None:
+            _LOGGER.warning("No MLFlowLogger found. Cannot restore model URI.")
+            return
+        if trainer.ckpt_path is None:
+            return
+        extracted_run_id = Path(trainer.ckpt_path).parts[1]
+        if extracted_run_id != logger.run_id:
+            _LOGGER.warning(f"Run ID mismatch: {extracted_run_id} != {logger.run_id}." +
+                            " Check `run_id` parameter in MLFlowLogger.")
+        self._last_model_uri = f'runs:/{logger.run_id}/model/{Path(trainer.ckpt_path).stem}'
+        try:
+            self.last_saved_model_info = mlflow.models.get_model_info(self._last_model_uri)
+        except mlflow.exceptions.MlflowException as e:
+            _LOGGER.warning(f"Failed to get model info for URI {self._last_model_uri}: {e}")
+            self.last_saved_model_info = None
+
+    def on_test_start(self, trainer, pl_module):
+        self.__wrap_forward(pl_module)
+        self._restore_model_uri(trainer)
+        return super().on_test_start(trainer, pl_module)
+
+    def on_predict_start(self, trainer, pl_module):
+        self.__wrap_forward(pl_module)
+        self._restore_model_uri(trainer)
+        return super().on_predict_start(trainer, pl_module)
+
+    def on_test_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+        super().on_test_end(trainer, pl_module)
+
+        if self.register_model_on == 'test':
+            self._update_signature(trainer)
+            self.register_model(trainer)
+
+    def on_predict_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+        super().on_predict_end(trainer, pl_module)
+
+        if self.register_model_on == 'predict':
+            self._update_signature(trainer)
+            self.register_model(trainer)
+
+    def on_validation_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+        super().on_validation_end(trainer, pl_module)
+
+        if self.register_model_on == 'val':
+            self._update_signature(trainer)
+            self.register_model(trainer)
```
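End to end, the callback plugs into a Lightning Trainer alongside MLFlowLogger. A self-contained sketch, assuming Datamint/MLflow credentials are configured (the experiment and registry names are placeholders):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as L
from lightning.pytorch.loggers import MLFlowLogger
from datamint.mlflow.lightning.callbacks.modelcheckpoint import MLFlowModelCheckpoint


class TinyRegressor(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", nn.functional.mse_loss(self(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


loader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=16)

logger = MLFlowLogger(experiment_name="demo-experiment")           # placeholder name
ckpt_cb = MLFlowModelCheckpoint(monitor="val_loss", mode="min",
                                register_model_name="demo-model",  # placeholder name
                                register_model_on="train")
trainer = L.Trainer(max_epochs=2, logger=logger, callbacks=[ckpt_cb])
trainer.fit(TinyRegressor(), loader, loader)
# on_train_end: the best checkpoint is logged to MLflow, its signature is
# filled in from a real forward pass, and the model is registered as "demo-model".
```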
datamint/mlflow/models/__init__.py (new file):

```diff
@@ -0,0 +1,94 @@
+import logging
+import json
+import lightning as L
+from lightning.pytorch.loggers import MLFlowLogger
+import mlflow
+import os
+from tempfile import TemporaryDirectory
+from torch import nn
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def download_model_metadata(model_uri: str) -> dict:
+    from mlflow.tracking.artifact_utils import get_artifact_repository
+
+    art_repo = get_artifact_repository(artifact_uri=model_uri)
+    try:
+        out_artifact_path = art_repo.download_artifacts(artifact_path='metadata.json')
+    except OSError as e:
+        _LOGGER.warning(f"Error downloading model metadata: {e}")
+        return {}
+
+    with open(out_artifact_path, 'r') as f:
+        metadata = json.load(f)
+    return metadata
+
+
+def _get_MLFlowLogger(trainer: L.Trainer) -> MLFlowLogger:
+    for logger in trainer.loggers:
+        if isinstance(logger, MLFlowLogger):
+            return logger
+    raise ValueError("No MLFlowLogger found in the trainer loggers.")
+
+
+def log_model_metadata(metadata: dict,
+                       mlflow_model: mlflow.models.Model | None = None,
+                       logger: MLFlowLogger | L.Trainer | None = None,
+                       model_path: str | None = None,
+                       run_id: str | None = None,
+                       ) -> None:
+    """
+    Log additional metadata to the MLflow model.
+    One of the following parameter combinations must be provided:
+    1. `mlflow_model`
+    2. `logger` and `model_path`
+    3. `run_id` and `model_path`
+
+    Args:
+        metadata (dict): The metadata to log.
+        mlflow_model (mlflow.models.Model, optional): The MLflow model object. Defaults to None.
+        logger (MLFlowLogger or L.Trainer, optional): The MLFlow logger or Lightning Trainer instance. Defaults to None.
+        model_path (str, optional): The path where the model is stored in MLflow. Defaults to None.
+        run_id (str, optional): The run ID of the MLflow run. Defaults to None.
+    """
+
+    # Validate inputs
+    if mlflow_model is None and (logger is None or model_path is None) and (run_id is None or model_path is None):
+        raise ValueError(
+            "You must provide either `mlflow_model`, or both `logger` and `model_path`, "
+            "or both `run_id` and `model_path`."
+        )
+    # not both
+    if mlflow_model is not None and logger is not None:
+        raise ValueError("Only one of mlflow_model or logger can be provided.")
+
+    if logger is not None:
+        if isinstance(logger, L.Trainer):
+            logger = _get_MLFlowLogger(logger)
+        if logger is None:
+            raise ValueError("MLFlowLogger not found in the Trainer's loggers.")
+        run_id = logger.run_id
+        artifact_path = model_path
+        mlfclient = logger.experiment
+    elif mlflow_model is not None:
+        run_id = mlflow_model.run_id
+        artifact_path = mlflow_model.artifact_path
+        mlfclient = mlflow.client.MlflowClient()
+    elif run_id is not None and model_path is not None:
+        mlfclient = mlflow.client.MlflowClient()
+        artifact_path = model_path
+    else:
+        raise ValueError("Invalid logger or mlflow_model provided.")
+
+    with TemporaryDirectory() as tmpdir:
+        metadata_path = os.path.join(tmpdir, "metadata.json")
+        with open(metadata_path, "w") as f:
+            json.dump(metadata, f, indent=2)
+
+        mlfclient.log_artifact(
+            run_id=run_id,
+            local_path=metadata_path,
+            artifact_path=artifact_path,
+        )
+    _LOGGER.debug(f"Additional metadata logged to {artifact_path}/metadata.json")
```
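A hedged round-trip sketch using combination 3 (`run_id` plus `model_path`); the identifiers are placeholders. `download_model_metadata` reads back the `metadata.json` written next to the model:

```python
from datamint.mlflow.models import log_model_metadata, download_model_metadata

run_id = "0123456789abcdef"   # hypothetical run id
model_path = "model/best"     # hypothetical artifact path of the logged model

log_model_metadata({"dataset": "train-v2", "threshold": 0.5},
                   run_id=run_id, model_path=model_path)

meta = download_model_metadata(f"runs:/{run_id}/{model_path}")
print(meta)  # {'dataset': 'train-v2', 'threshold': 0.5}
```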
datamint/mlflow/tracking/datamint_store.py (new file):

```diff
@@ -0,0 +1,46 @@
+from mlflow.store.tracking.rest_store import RestStore
+from functools import partial
+from .fluent import get_active_project_id
+import json
+
+
+class DatamintStore(RestStore):
+    """
+    DatamintStore is a subclass of RestStore that provides a tracking store
+    implementation for Datamint.
+    """
+
+    def __init__(self, store_uri: str, artifact_uri=None, force_valid=True):
+        # Ensure the MLflow environment is configured when the store is initialized
+        from datamint.mlflow.env_utils import setup_mlflow_environment
+        from mlflow.utils.credentials import get_default_host_creds
+        setup_mlflow_environment()
+
+        if store_uri.startswith('datamint://') or 'datamint.io' in store_uri or force_valid:
+            self.invalid = False
+        else:
+            self.invalid = True
+
+        store_uri = store_uri.split('datamint://', maxsplit=1)[-1]
+        get_host_creds = partial(get_default_host_creds, store_uri)
+        super().__init__(get_host_creds=get_host_creds)
+
+    def create_experiment(self, name, artifact_location=None, tags=None, project_id: str | None = None) -> str:
+        from mlflow.protos.service_pb2 import CreateExperiment
+        from mlflow.utils.proto_json_utils import message_to_json
+
+        if self.invalid:
+            return super().create_experiment(name, artifact_location, tags)
+        if project_id is None:
+            project_id = get_active_project_id()
+        tag_protos = [tag.to_proto() for tag in tags] if tags else []
+        req_body = message_to_json(
+            CreateExperiment(name=name, artifact_location=artifact_location, tags=tag_protos)
+        )
+
+        req_body = json.loads(req_body)
+        req_body["project_id"] = project_id  # FIXME: this should be in the proto
+        req_body = json.dumps(req_body)
+
+        response_proto = self._call_endpoint(CreateExperiment, req_body)
+        return response_proto.experiment_id
```
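The constructor strips a `datamint://` prefix from the tracking URI before handing it to RestStore. Assuming the new entry_points.txt in this release registers the store as an MLflow tracking-store plugin for that scheme (an assumption; the diff shown here does not include the entry points), pointing MLflow at a Datamint server could look like:

```python
import mlflow

# Hypothetical server URL; the "datamint://" scheme would resolve to DatamintStore.
mlflow.set_tracking_uri("datamint://app.datamint.io/mlflow")
mlflow.set_experiment("demo-experiment")  # created with the active project id attached
```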
datamint/mlflow/tracking/default_experiment.py (new file):

```diff
@@ -0,0 +1,27 @@
+import sys
+import os
+from mlflow.tracking.default_experiment.abstract_context import DefaultExperimentProvider
+
+
+class DatamintExperimentProvider(DefaultExperimentProvider):
+    _experiment_id = None
+
+    def in_context(self):
+        return True
+
+    def get_experiment_id(self):
+        from mlflow.tracking.client import MlflowClient
+
+        if DatamintExperimentProvider._experiment_id is not None:
+            return self._experiment_id
+        # Get the filename of the main source file
+        source_code_filename = os.path.basename(sys.argv[0])
+        mlflowclient = MlflowClient()
+        exp = mlflowclient.get_experiment_by_name(source_code_filename)
+        if exp is None:
+            experiment_id = mlflowclient.create_experiment(source_code_filename)
+        else:
+            experiment_id = exp.experiment_id
+        DatamintExperimentProvider._experiment_id = experiment_id
+
+        return experiment_id
```
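In other words, the default experiment name is taken from the entry-point script, so running `python train.py` creates (or reuses) an experiment named `train.py`. The derivation is just the basename of `sys.argv[0]`:

```python
import os
import sys

# e.g. for `python /home/user/train.py` this prints "train.py"
print(os.path.basename(sys.argv[0]))
```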
datamint/mlflow/tracking/fluent.py (new file):

```diff
@@ -0,0 +1,78 @@
+from typing import Optional
+import threading
+import logging
+from datamint import Api
+from datamint.exceptions import DatamintException
+import os
+from datamint.mlflow.env_vars import EnvVars
+from datamint.mlflow.env_utils import ensure_mlflow_configured
+
+_PROJECT_LOCK = threading.Lock()
+_LOGGER = logging.getLogger(__name__)
+
+_ACTIVE_PROJECT_ID: Optional[str] = None
+
+
+def get_active_project_id() -> str | None:
+    """
+    Get the active project ID from the module-level cache or the environment variables.
+    """
+    global _ACTIVE_PROJECT_ID
+
+    if _ACTIVE_PROJECT_ID is not None:
+        return _ACTIVE_PROJECT_ID
+    # Check if the environment variable is set
+    project_id = os.getenv(EnvVars.DATAMINT_PROJECT_ID.value)
+    if project_id is not None:
+        _ACTIVE_PROJECT_ID = project_id
+        return project_id
+    project_name = os.getenv(EnvVars.DATAMINT_PROJECT_NAME.value)
+    if project_name is not None:
+        project = _find_project_by_name(project_name)
+        if project is not None:
+            _ACTIVE_PROJECT_ID = project['id']
+            return _ACTIVE_PROJECT_ID
+
+    return None
+
+
+def _find_project_by_name(project_name: str):
+    dt_client = Api(check_connection=False)
+    project = dt_client.projects.get_by_name(project_name)
+    if project is None:
+        raise DatamintException(f"Project with name '{project_name}' does not exist.")
+    return project
+
+
+def set_project(project_name: Optional[str] = None, project_id: Optional[str] = None):
+    from mlflow.exceptions import MlflowException
+    global _ACTIVE_PROJECT_ID
+
+    # Ensure MLflow is properly configured before proceeding
+    ensure_mlflow_configured()
+
+    if project_name is None and project_id is None:
+        raise MlflowException("You must specify either a project name or a project id")
+
+    if project_name is not None and project_id is not None:
+        raise MlflowException("You cannot specify both a project name and a project id")
+
+    with _PROJECT_LOCK:
+        dt_client = Api(check_connection=False)
+        if project_id is None:
+            project = dt_client.projects.get_by_name(project_name)
+            if project is None:
+                raise DatamintException(f"Project with name '{project_name}' does not exist.")
+            project_id = project.id
+        else:
+            project = dt_client.projects.get_by_id(project_id)
+            if project is None:
+                raise DatamintException(f"Project with id '{project_id}' does not exist.")
+
+        _ACTIVE_PROJECT_ID = project_id
+
+        # Set the 'DATAMINT_PROJECT_ID' environment variable
+        # so that subprocesses can inherit it.
+        os.environ[EnvVars.DATAMINT_PROJECT_ID.value] = project_id
+
+    return project
```
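Since `DatamintStore.create_experiment` falls back to `get_active_project_id()`, a typical session selects the project before starting a run. A sketch, assuming Datamint credentials are configured (the project name is a placeholder):

```python
import mlflow
from datamint.mlflow.tracking.fluent import set_project

set_project(project_name="chest-xray-study")  # hypothetical project name
with mlflow.start_run():
    mlflow.log_metric("auc", 0.91)
# New experiments created in this session are attached to the selected project,
# and DATAMINT_PROJECT_ID is exported so child processes inherit the choice.
```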