datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. datamint/__init__.py +1 -3
  2. datamint/api/__init__.py +0 -3
  3. datamint/api/base_api.py +286 -54
  4. datamint/api/client.py +76 -13
  5. datamint/api/endpoints/__init__.py +2 -2
  6. datamint/api/endpoints/annotations_api.py +186 -28
  7. datamint/api/endpoints/deploy_model_api.py +78 -0
  8. datamint/api/endpoints/models_api.py +1 -0
  9. datamint/api/endpoints/projects_api.py +38 -7
  10. datamint/api/endpoints/resources_api.py +227 -100
  11. datamint/api/entity_base_api.py +66 -7
  12. datamint/apihandler/base_api_handler.py +0 -1
  13. datamint/apihandler/dto/annotation_dto.py +2 -0
  14. datamint/client_cmd_tools/datamint_config.py +0 -1
  15. datamint/client_cmd_tools/datamint_upload.py +3 -1
  16. datamint/configs.py +11 -7
  17. datamint/dataset/base_dataset.py +24 -4
  18. datamint/dataset/dataset.py +1 -1
  19. datamint/entities/__init__.py +1 -1
  20. datamint/entities/annotations/__init__.py +13 -0
  21. datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
  22. datamint/entities/annotations/image_classification.py +12 -0
  23. datamint/entities/annotations/image_segmentation.py +252 -0
  24. datamint/entities/annotations/volume_segmentation.py +273 -0
  25. datamint/entities/base_entity.py +100 -6
  26. datamint/entities/cache_manager.py +129 -15
  27. datamint/entities/datasetinfo.py +60 -65
  28. datamint/entities/deployjob.py +18 -0
  29. datamint/entities/project.py +39 -0
  30. datamint/entities/resource.py +310 -46
  31. datamint/lightning/__init__.py +1 -0
  32. datamint/lightning/datamintdatamodule.py +103 -0
  33. datamint/mlflow/__init__.py +65 -0
  34. datamint/mlflow/artifact/__init__.py +1 -0
  35. datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
  36. datamint/mlflow/env_utils.py +131 -0
  37. datamint/mlflow/env_vars.py +5 -0
  38. datamint/mlflow/flavors/__init__.py +17 -0
  39. datamint/mlflow/flavors/datamint_flavor.py +150 -0
  40. datamint/mlflow/flavors/model.py +877 -0
  41. datamint/mlflow/lightning/callbacks/__init__.py +1 -0
  42. datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
  43. datamint/mlflow/models/__init__.py +93 -0
  44. datamint/mlflow/tracking/datamint_store.py +76 -0
  45. datamint/mlflow/tracking/default_experiment.py +27 -0
  46. datamint/mlflow/tracking/fluent.py +91 -0
  47. datamint/utils/env.py +27 -0
  48. datamint/utils/visualization.py +21 -13
  49. datamint-2.9.0.dist-info/METADATA +220 -0
  50. datamint-2.9.0.dist-info/RECORD +73 -0
  51. {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
  52. datamint-2.9.0.dist-info/entry_points.txt +18 -0
  53. datamint/apihandler/exp_api_handler.py +0 -204
  54. datamint/experiment/__init__.py +0 -1
  55. datamint/experiment/_patcher.py +0 -570
  56. datamint/experiment/experiment.py +0 -1049
  57. datamint-2.3.3.dist-info/METADATA +0 -125
  58. datamint-2.3.3.dist-info/RECORD +0 -54
  59. datamint-2.3.3.dist-info/entry_points.txt +0 -4
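The hunks below reproduce the new MLflow integration modules in full (files 41-45 above): the Lightning checkpoint callback, the model-metadata helpers, the Datamint tracking store, and the default-experiment provider.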
datamint/mlflow/lightning/callbacks/__init__.py
@@ -0,0 +1 @@
+ from .modelcheckpoint import MLFlowModelCheckpoint
datamint/mlflow/lightning/callbacks/modelcheckpoint.py
@@ -0,0 +1,410 @@
+ from lightning.pytorch.callbacks import ModelCheckpoint
+ from pathlib import Path
+ from weakref import proxy
+ from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
+ from typing import Literal, Any
+ import inspect
+ from torch import nn
+ import lightning.pytorch as L
+ from datamint.mlflow.models import log_model_metadata, _get_MLFlowLogger
+ from datamint.mlflow.env_utils import ensure_mlflow_configured
+ import mlflow.models
+ import mlflow.exceptions
+ import mlflow.pytorch
+ import logging
+ import json
+ import hashlib
+ from lightning.pytorch.loggers import MLFlowLogger
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def help_infer_signature(x):
+     import torch
+     if isinstance(x, torch.Tensor):
+         return x.detach().cpu().numpy()
+     elif isinstance(x, dict):
+         return {k: v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v for k, v in x.items()}
+     elif isinstance(x, list):
+         return [v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v for v in x]
+     elif isinstance(x, tuple):
+         return tuple(v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v for v in x)
+
+     return x
+
+
+ class MLFlowModelCheckpoint(ModelCheckpoint):
+     def __init__(self, *args,
+                  register_model_name: str | None = None,
+                  register_model_on: Literal["train", "val", "test", "predict"] = 'test',
+                  code_paths: list[str] | None = None,
+                  log_model_at_end_only: bool = True,
+                  additional_metadata: dict[str, Any] | None = None,
+                  extra_pip_requirements: list[str] | None = None,
+                  **kwargs):
+         """
+         MLFlowModelCheckpoint is a custom callback for PyTorch Lightning that integrates with MLFlow to log and register models.
+
+         Args:
+             register_model_name (str | None): The name to register the model under in MLFlow. If None, the model will not be registered.
+             register_model_on (Literal["train", "val", "test", "predict"]): The stage at whose end the model is registered.
+             code_paths (list[str] | None): List of paths to Python files that should be included in the MLFlow model.
+             log_model_at_end_only (bool): If True, only log the model to MLFlow at the end of training instead of after every checkpoint save.
+             additional_metadata (dict[str, Any] | None): Additional metadata to log with the model as a JSON file.
+             extra_pip_requirements (list[str] | None): Additional pip requirements to include with the MLFlow model.
+             **kwargs: Keyword arguments for ModelCheckpoint.
+         """
+         # Ensure MLflow is configured when the callback is initialized
+         ensure_mlflow_configured()
+
+         super().__init__(*args, **kwargs)
+         if self.save_top_k > 1:
+             raise NotImplementedError("save_top_k > 1 is not supported. "
+                                       "Please use save_top_k=1 to save only the best model.")
+         if self.save_last is not None and self.save_top_k != 0 and self.monitor is not None:
+             raise NotImplementedError("save_last is not supported with monitor and save_top_k!=0. "
+                                       "Please use two separate callbacks: one for saving the last model "
+                                       "and another for saving the best model based on the monitor metric.")
+
+         if register_model_name is not None and register_model_on is None:
+             raise ValueError("If you provide a register_model_name, you must also provide a register_model_on.")
+         if register_model_on not in ["train", "val", "test", "predict"]:
+             raise ValueError("register_model_on must be one of train, val, test or predict.")
+
+         self.register_model_name = register_model_name
+         self.register_model_on = register_model_on
+         self.registered_model_info = None
+         self.log_model_at_end_only = log_model_at_end_only
+         self._last_model_uri = None
+         self.last_saved_model_info = None
+         self._inferred_signature = None
+         self._input_example = None
+         self.code_paths = code_paths
+         self.additional_metadata = additional_metadata or {}
+         self.extra_pip_requirements = extra_pip_requirements or []
+         self._last_registered_state_hash: str | None = None
+         self._has_been_trained: bool = False
+
+     def _compute_registration_state_hash(self) -> str:
+         """Compute a hash representing the current model state for registration comparison.
+
+         Returns:
+             A hash string of the current state.
+         """
+         state_dict = {
+             'checkpoint_path': str(self._last_checkpoint_saved),
+             'global_step': self._last_global_step_saved,
+             'signature': str(self._inferred_signature) if self._inferred_signature else None,
+             'model_uri': self._last_model_uri,
+         }
+
+         state_str = json.dumps(state_dict, sort_keys=True)
+         return hashlib.md5(state_str.encode('utf-8')).hexdigest()
+
+     def _should_register_model(self) -> bool:
+         """Determine if the model should be registered.
+
+         Returns:
+             True if the model should be registered, False otherwise.
+         """
+
+         if self._last_model_uri is None:
+             _LOGGER.warning("No model URI available. Cannot register model.")
+             return False
+
+         # If never registered before, register
+         if self._last_registered_state_hash is None:
+             return True
+
+         # If model was retrained, register
+         if self._has_been_trained:
+             return True
+
+         # If state changed (signature, checkpoint, etc.), register
+         current_state_hash = self._compute_registration_state_hash()
+         if current_state_hash != self._last_registered_state_hash:
+             return True
+
+         _LOGGER.info("Model already registered with same configuration. Skipping registration.")
+         return False
+
+     def _infer_params(self, model: nn.Module) -> tuple[dict, ...]:
+         """Extract metadata from the model's forward method signature.
+
+         Returns:
+             A tuple of dicts, each containing parameter metadata ordered by position.
+         """
+         forward_method = getattr(model.__class__, 'forward', None)
+
+         if forward_method is None:
+             return ()
+
+         try:
+             sig = inspect.signature(forward_method)
+             params_list = []
+
+             for param_name, param in sig.parameters.items():
+                 if param_name == 'self':
+                     continue
+
+                 param_info = {
+                     'name': param_name,
+                     'kind': param.kind.name,
+                     'annotation': param.annotation if param.annotation != inspect.Parameter.empty else None,
+                     'default': param.default if param.default != inspect.Parameter.empty else None,
+                 }
+                 params_list.append(param_info)
+
+             # Add return annotation, if available, as the last element
+             return_annotation = sig.return_annotation
+             if return_annotation != inspect.Signature.empty:
+                 return_info = {'_return_annotation': str(return_annotation)}
+                 params_list.append(return_info)
+
+             return tuple(params_list)
+
+         except Exception as e:
+             _LOGGER.warning(f"Failed to infer forward method parameters: {e}")
+             return ()
+
+     def _save_checkpoint(self, trainer: L.Trainer, filepath: str) -> None:
+         trainer.save_checkpoint(filepath, self.save_weights_only)
+
+         self._last_global_step_saved = trainer.global_step
+         self._last_checkpoint_saved = filepath
+
+         # notify loggers
+         if trainer.is_global_zero:
+             for logger in trainer.loggers:
+                 logger.after_save_checkpoint(proxy(self))
+                 if isinstance(logger, MLFlowLogger) and not self.log_model_at_end_only:
+                     self.log_model_to_mlflow(trainer.model, run_id=logger.run_id)
+
+     def log_additional_metadata(self, logger: MLFlowLogger | L.Trainer,
+                                 additional_metadata: dict) -> None:
+         """Log additional metadata as a JSON file to the model artifact.
+
+         Args:
+             logger: The MLFlowLogger or Lightning Trainer instance to use for logging.
+             additional_metadata: A dictionary containing additional metadata to log.
+         """
+         self.additional_metadata = additional_metadata
+         if not self.additional_metadata:
+             return
+
+         if self.last_saved_model_info is None:
+             _LOGGER.warning("No model has been saved yet. Cannot log additional metadata.")
+             return
+
+         try:
+             log_model_metadata(metadata=self.additional_metadata,
+                                logger=logger,
+                                model_path=self.last_saved_model_info.artifact_path)
+         except Exception as e:
+             _LOGGER.warning(f"Failed to log additional metadata: {e}")
+
+     def log_model_to_mlflow(self,
+                             model: nn.Module,
+                             run_id: str | MLFlowLogger
+                             ) -> None:
+         """Log the model to MLflow."""
+         if isinstance(run_id, MLFlowLogger):
+             logger = run_id
+             if logger.run_id is None:
+                 raise ValueError("MLFlowLogger has no run_id. Cannot log model to MLFlow.")
+             run_id = logger.run_id
+
+         if self._last_checkpoint_saved is None or self._last_checkpoint_saved == '':
+             _LOGGER.warning("No checkpoint saved yet. Cannot log model to MLFlow.")
+             return
+
+         orig_device = next(model.parameters()).device
+         model = model.cpu()  # Ensure the model is on CPU for logging
+
+         requirements = list(self.extra_pip_requirements)
+         # check if lightning is in the requirements
+         if not any('lightning' in req.lower() for req in requirements):
+             requirements.append(f'lightning=={L.__version__}')
+
+         modelinfo = mlflow.pytorch.log_model(
+             pytorch_model=model,
+             name=Path(self._last_checkpoint_saved).stem,
+             signature=self._inferred_signature,
+             run_id=run_id,
+             extra_pip_requirements=requirements,
+             code_paths=self.code_paths
+         )
+
+         model.to(device=orig_device)  # Move the model back to its original device
+         self._last_model_uri = modelinfo.model_uri
+         self.last_saved_model_info = modelinfo
+
+         # Log additional metadata after the model is saved
+         log_model_metadata(self.additional_metadata,
+                            model_path=modelinfo.artifact_path,
+                            run_id=run_id)
+
+     def _remove_checkpoint(self, trainer: L.Trainer, filepath: str) -> None:
+         super()._remove_checkpoint(trainer, filepath)
+         # remove the checkpoint from mlflow
+         if trainer.is_global_zero:
+             for logger in trainer.loggers:
+                 if isinstance(logger, MLFlowLogger):
+                     artifact_uri = logger.experiment.get_run(logger.run_id).info.artifact_uri
+                     rep = get_artifact_repository(artifact_uri)
+                     rep.delete_artifacts(f'model/{Path(filepath).stem}')
+
+     def register_model(self, trainer=None):
+         """Register the model in the MLFlow Model Registry."""
+         if not self._should_register_model():
+             return self.registered_model_info
+
+         # mlflow_client = _get_MLFlowLogger(trainer)._mlflow_client
+         self.registered_model_info = mlflow.register_model(
+             model_uri=self._last_model_uri,
+             name=self.register_model_name,
+         )
+
+         # Update the registered state hash after successful registration
+         self._last_registered_state_hash = self._compute_registration_state_hash()
+         self._has_been_trained = False  # Reset training flag after registration
+
+         _LOGGER.info(f"Model registered as '{self.register_model_name}' "
+                      f"version {self.registered_model_info.version}")
+
+         return self.registered_model_info
+
+     def _update_signature(self, trainer):
+         if self._inferred_signature is None:
+             _LOGGER.warning("No signature found. Cannot update signature.")
+             return
+         if self._last_model_uri is None:
+             _LOGGER.warning("No model URI found. Cannot update signature.")
+             return
+
+         # update the signature
+         try:
+             mlflow.models.set_signature(
+                 model_uri=self._last_model_uri,
+                 signature=self._inferred_signature,
+             )
+         except mlflow.exceptions.MlflowException as e:
+             _LOGGER.warning(f"Failed to update model signature. Check if model actually exists. {e}")
+
+     def __wrap_forward(self, pl_module: nn.Module):
+         original_forward = pl_module.forward
+
+         def wrapped_forward(x, *args, **kwargs):
+             x0 = help_infer_signature(x)
+             inferred_params = self._infer_params(pl_module)
+             if len(inferred_params) > 1:
+                 inferred_params = {param['name']: param['default']
+                                    for param in inferred_params[1:] if 'name' in param}
+             else:
+                 inferred_params = None
+
+             self._inferred_signature = mlflow.models.infer_signature(model_input=x0,
+                                                                      params=inferred_params)
+
+             # run once and get back to the original forward
+             pl_module.forward = original_forward
+             method = getattr(pl_module, 'forward')
+             out = method(x, *args, **kwargs)
+
+             output_sig = mlflow.models.infer_signature(model_output=help_infer_signature(out))
+             self._inferred_signature.outputs = output_sig.outputs
+
+             return out
+
+         pl_module.forward = wrapped_forward
+
+     def on_train_start(self, trainer, pl_module):
+         self._has_been_trained = True
+         self.__wrap_forward(pl_module)
+         logger = _get_MLFlowLogger(trainer)
+         if logger._tracking_uri.startswith('file:'):
+             _LOGGER.error("MLFlowLogger tracking URI is a local file path. "
+                           "Model registration will likely fail if using MLflow Model Registry.")
+         if logger.experiment_id is not None:
+             mlflow.set_experiment(experiment_id=logger.experiment_id)
+         super().on_train_start(trainer, pl_module)
+
+     def on_train_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+         super().on_train_end(trainer, pl_module)
+
+         if self.log_model_at_end_only and trainer.is_global_zero:
+             logger = _get_MLFlowLogger(trainer)
+             if logger is None:
+                 _LOGGER.warning("No MLFlowLogger found. Cannot log model to MLFlow.")
+             else:
+                 self.log_model_to_mlflow(trainer.model, run_id=logger.run_id)
+
+             self._update_signature(trainer)
+
+         if self.register_model_on == 'train' and self.register_model_name:
+             self.register_model(trainer)
+
+     def _restore_model_uri(self, trainer: L.Trainer) -> None:
+         """Restore the last model URI from the trainer's checkpoint path."""
+         logger = _get_MLFlowLogger(trainer)
+         self._last_model_uri = None
+         self.last_saved_model_info = None
+         if logger is None:
+             _LOGGER.warning("No MLFlowLogger found. Cannot restore model URI.")
+             return
+         if trainer.ckpt_path is None:
+             return
+         if logger.run_id is None:
+             _LOGGER.warning("MLFlowLogger has no run_id. Cannot restore model URI.")
+             return
+         if logger.run_id not in str(trainer.ckpt_path):
+             _LOGGER.warning("Run ID mismatch between checkpoint path and MLFlowLogger. "
+                             "Check the `run_id` parameter in MLFlowLogger.")
+             return
+         retrieved_logged_models = mlflow.search_logged_models(
+             filter_string=f"name = '{Path(trainer.ckpt_path).stem[:256]}' AND source_run_id='{logger.run_id[:64]}'",
+             order_by=[{"field_name": "last_updated_timestamp", "ascending": False}],
+             output_format="list"
+         )
+         if not retrieved_logged_models:
+             _LOGGER.warning(f"No logged model found for checkpoint {trainer.ckpt_path}.")
+             return
+         # get the most recent one
+         self._last_model_uri = retrieved_logged_models[0].model_uri
+         try:
+             self.last_saved_model_info = mlflow.models.get_model_info(self._last_model_uri)
+         except mlflow.exceptions.MlflowException as e:
+             _LOGGER.warning(f"Failed to get model info for URI {self._last_model_uri}: {e}")
+             self.last_saved_model_info = None
+
+     def on_test_start(self, trainer, pl_module):
+         self.__wrap_forward(pl_module)
+         self._restore_model_uri(trainer)
+         return super().on_test_start(trainer, pl_module)
+
+     def on_predict_start(self, trainer, pl_module):
+         self.__wrap_forward(pl_module)
+         self._restore_model_uri(trainer)
+         return super().on_predict_start(trainer, pl_module)
+
+     def on_test_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+         super().on_test_end(trainer, pl_module)
+
+         if self.register_model_on == 'test' and self.register_model_name:
+             self._update_signature(trainer)
+             self.register_model(trainer)
+
+     def on_predict_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+         super().on_predict_end(trainer, pl_module)
+
+         if self.register_model_on == 'predict' and self.register_model_name:
+             self._update_signature(trainer)
+             self.register_model(trainer)
+
+     def on_validation_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+         super().on_validation_end(trainer, pl_module)
+
+         if self.register_model_on == 'val' and self.register_model_name:
+             self._update_signature(trainer)
+             self.register_model(trainer)
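
To make the callback's contract concrete, here is a minimal usage sketch. It assumes a standard Lightning training script; MyModel and MyDataModule are hypothetical placeholders, and the experiment and model names are illustrative only.

import lightning.pytorch as L
from lightning.pytorch.loggers import MLFlowLogger
from datamint.mlflow.lightning.callbacks import MLFlowModelCheckpoint

logger = MLFlowLogger(experiment_name="demo-experiment")
ckpt_cb = MLFlowModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,                      # values > 1 raise NotImplementedError
    register_model_name="demo-model",  # None skips Model Registry registration
    register_model_on="test",          # register at the end of trainer.test()
    additional_metadata={"task": "classification"},
)

trainer = L.Trainer(logger=logger, callbacks=[ckpt_cb], max_epochs=10)
trainer.fit(MyModel(), datamodule=MyDataModule())          # checkpoints are logged to the MLflow run
trainer.test(ckpt_path="best", datamodule=MyDataModule())  # best model is registered here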
datamint/mlflow/models/__init__.py
@@ -0,0 +1,93 @@
+ import logging
+ import json
+ import lightning as L
+ from lightning.pytorch.loggers import MLFlowLogger
+ import mlflow
+ import os
+ from tempfile import TemporaryDirectory
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def download_model_metadata(model_uri: str) -> dict:
+     from mlflow.tracking.artifact_utils import get_artifact_repository
+
+     art_repo = get_artifact_repository(artifact_uri=model_uri)
+     try:
+         out_artifact_path = art_repo.download_artifacts(artifact_path='metadata.json')
+     except OSError as e:
+         _LOGGER.warning(f"Error downloading model metadata: {e}")
+         return {}
+
+     with open(out_artifact_path, 'r') as f:
+         metadata = json.load(f)
+     return metadata
+
+
+ def _get_MLFlowLogger(trainer: L.Trainer) -> MLFlowLogger:
+     for logger in trainer.loggers:
+         if isinstance(logger, MLFlowLogger):
+             return logger
+     raise ValueError("No MLFlowLogger found in the trainer loggers.")
+
+
+ def log_model_metadata(metadata: dict,
+                        mlflow_model: mlflow.models.Model | None = None,
+                        logger: MLFlowLogger | L.Trainer | None = None,
+                        model_path: str | None = None,
+                        run_id: str | None = None,
+                        ) -> None:
+     """
+     Log additional metadata to the MLflow model.
+     One of the following parameter combinations must be provided:
+     1. `mlflow_model`
+     2. `logger` and `model_path`
+     3. `run_id` and `model_path`
+
+     Args:
+         metadata (dict): The metadata to log.
+         mlflow_model (mlflow.models.Model, optional): The MLflow model object. Defaults to None.
+         logger (MLFlowLogger or L.Trainer, optional): The MLFlow logger or Lightning Trainer instance. Defaults to None.
+         model_path (str, optional): The path where the model is stored in MLflow. Defaults to None.
+         run_id (str, optional): The run ID of the MLflow run. Defaults to None.
+     """
+
+     # Validate inputs
+     if mlflow_model is None and (logger is None or model_path is None) and (run_id is None or model_path is None):
+         raise ValueError(
+             "You must provide either `mlflow_model`, or both `logger` and `model_path`, "
+             "or both `run_id` and `model_path`."
+         )
+     # not both
+     if mlflow_model is not None and logger is not None:
+         raise ValueError("Only one of mlflow_model or logger can be provided.")
+
+     if logger is not None:
+         if isinstance(logger, L.Trainer):
+             logger = _get_MLFlowLogger(logger)
+             if logger is None:
+                 raise ValueError("MLFlowLogger not found in the Trainer's loggers.")
+         run_id = logger.run_id
+         artifact_path = model_path
+         mlfclient = logger.experiment
+     elif mlflow_model is not None:
+         run_id = mlflow_model.run_id
+         artifact_path = mlflow_model.artifact_path
+         mlfclient = mlflow.client.MlflowClient()
+     elif run_id is not None and model_path is not None:
+         mlfclient = mlflow.client.MlflowClient()
+         artifact_path = model_path
+     else:
+         raise ValueError("Invalid logger or mlflow_model provided.")
+
+
83
+ with TemporaryDirectory() as tmpdir:
84
+ metadata_path = os.path.join(tmpdir, "metadata.json")
85
+ with open(metadata_path, "w") as f:
86
+ json.dump(metadata, f, indent=2)
87
+
88
+ mlfclient.log_artifact(
89
+ run_id=run_id,
90
+ local_path=metadata_path,
91
+ artifact_path=artifact_path,
92
+ )
93
+ _LOGGER.debug(f"Additional metadata logged to {artifact_path}/metadata.json")
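
A brief sketch of the metadata helpers above, using parameter combination 3 (`run_id` and `model_path`); the run ID and artifact path here are hypothetical placeholders:

from datamint.mlflow.models import log_model_metadata, download_model_metadata

# Attach a metadata.json next to a model already logged under the artifact path "model".
log_model_metadata(
    metadata={"dataset": "chest-xray-v2", "val_dice": 0.91},
    run_id="0123456789abcdef",  # hypothetical MLflow run ID
    model_path="model",
)

# Read it back later from the model URI.
metadata = download_model_metadata("runs:/0123456789abcdef/model")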
datamint/mlflow/tracking/datamint_store.py
@@ -0,0 +1,76 @@
+ from mlflow.store.tracking.rest_store import RestStore
+ from mlflow.exceptions import MlflowException
+ from mlflow.utils.proto_json_utils import message_to_json
+ from functools import partial
+ import json
+ from typing_extensions import override
+
+
+ class DatamintStore(RestStore):
+     """
+     DatamintStore is a subclass of RestStore that provides a tracking store
+     implementation for Datamint.
+     """
+
+     def __init__(self, store_uri: str, artifact_uri=None, force_valid=True):
+         # Ensure the MLflow environment is configured when the store is initialized
+         from datamint.mlflow.env_utils import setup_mlflow_environment
+         from mlflow.utils.credentials import get_default_host_creds
+         setup_mlflow_environment()
+
+         self.invalid = not (force_valid
+                             or store_uri.startswith('datamint://')
+                             or 'datamint.io' in store_uri)
+
+         store_uri = store_uri.split('datamint://', maxsplit=1)[-1]
+         get_host_creds = partial(get_default_host_creds, store_uri)
+         super().__init__(get_host_creds=get_host_creds)
+
+     def create_experiment(self, name, artifact_location=None, tags=None, project_id: str | None = None) -> str:
+         from mlflow.protos.service_pb2 import CreateExperiment
+         from datamint.mlflow.tracking.fluent import get_active_project_id
+
+         if self.invalid:
+             return super().create_experiment(name, artifact_location, tags)
+         if project_id is None:
+             project_id = get_active_project_id()
+         tag_protos = [tag.to_proto() for tag in tags] if tags else []
+         req_body = message_to_json(
+             CreateExperiment(name=name, artifact_location=artifact_location, tags=tag_protos)
+         )
+
+         req_body = json.loads(req_body)
+         req_body["project_id"] = project_id  # FIXME: this should be in the proto
+         req_body = json.dumps(req_body)
+
+         response_proto = self._call_endpoint(CreateExperiment, req_body)
+         return response_proto.experiment_id
+
+     @override
+     def get_experiment_by_name(self, experiment_name, project_id: str | None = None):
+         from datamint.mlflow.tracking.fluent import get_active_project_id
+         from mlflow.protos.service_pb2 import GetExperimentByName
+         from mlflow.entities import Experiment
+         from mlflow.protos import databricks_pb2
+
+         if self.invalid:
+             return super().get_experiment_by_name(experiment_name)
+         if project_id is None:
+             project_id = get_active_project_id()
+         try:
+             req_body = message_to_json(GetExperimentByName(experiment_name=experiment_name))
+             if project_id:
+                 body = json.loads(req_body)
+                 body["project_id"] = project_id
+                 req_body = json.dumps(body)
+
+             response_proto = self._call_endpoint(GetExperimentByName, req_body)
+             return Experiment.from_proto(response_proto.experiment)
+         except MlflowException as e:
+             if e.error_code == databricks_pb2.ErrorCode.Name(
+                 databricks_pb2.RESOURCE_DOES_NOT_EXIST
+             ):
+                 return None
+             else:
+                 raise
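
Presumably the wheel's new entry_points.txt registers this class as an MLflow tracking-store plugin for the `datamint://` scheme; the diff does not show that file's contents, so the wiring below is an assumption, and the host is a hypothetical placeholder. If so, pointing MLflow at a Datamint host routes experiment calls through DatamintStore:

import mlflow

# __init__ strips the "datamint://" prefix before building REST
# credentials for the remaining URI.
mlflow.set_tracking_uri("datamint://api.datamint.io")

# Resolved via DatamintStore.get_experiment_by_name / create_experiment,
# which inject the active Datamint project_id into the request body.
mlflow.set_experiment("demo-experiment")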
datamint/mlflow/tracking/default_experiment.py
@@ -0,0 +1,27 @@
+ import sys
+ import os
+ from mlflow.tracking.default_experiment.abstract_context import DefaultExperimentProvider
+
+
+ class DatamintExperimentProvider(DefaultExperimentProvider):
+     _experiment_id = None
+
+     def in_context(self):
+         return True
+
+     def get_experiment_id(self):
+         from mlflow.tracking.client import MlflowClient
+
+         if DatamintExperimentProvider._experiment_id is not None:
+             return self._experiment_id
+         # Get the filename of the main source file
+         source_code_filename = os.path.basename(sys.argv[0])
+         mlflowclient = MlflowClient()
+         exp = mlflowclient.get_experiment_by_name(source_code_filename)
+         if exp is None:
+             experiment_id = mlflowclient.create_experiment(source_code_filename)
+         else:
+             experiment_id = exp.experiment_id
+         DatamintExperimentProvider._experiment_id = experiment_id
+
+         return experiment_id
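
This provider supplies MLflow's default experiment when none is set explicitly: it names the experiment after the entry script and caches the resolved ID at class level for the rest of the process. Assuming the provider is registered with MLflow (DefaultExperimentProvider implementations are discovered through plugin entry points), the effect looks like this:

import mlflow

# Launched as `python train.py` with no mlflow.set_experiment(...) call:
# the run lands in an experiment named "train.py", created on first use.
with mlflow.start_run():
    mlflow.log_metric("loss", 0.1)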