datamint 2.3.4__py3-none-any.whl → 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datamint might be problematic.

@@ -0,0 +1,338 @@
+ from lightning.pytorch.callbacks import ModelCheckpoint
+ from pathlib import Path
+ from weakref import proxy
+ from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
+ from typing import Literal, Any
+ import inspect
+ import torch
+ from torch import nn
+ import lightning.pytorch as L
+ from datamint.mlflow.models import log_model_metadata, _get_MLFlowLogger
+ from datamint.mlflow.env_utils import ensure_mlflow_configured
+ import mlflow
+ import logging
+ from lightning.pytorch.loggers import MLFlowLogger
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def help_infer_signature(x):
+     """Recursively convert torch tensors into numpy arrays so MLflow can infer a signature."""
+     if isinstance(x, torch.Tensor):
+         return x.detach().cpu().numpy()
+     elif isinstance(x, dict):
+         return {k: v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v for k, v in x.items()}
+     elif isinstance(x, list):
+         return [v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v for v in x]
+     elif isinstance(x, tuple):
+         return tuple(v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v for v in x)
+
+     return x
+
+
+ class MLFlowModelCheckpoint(ModelCheckpoint):
+     def __init__(self, *args,
+                  register_model_name: str | None = None,
+                  register_model_on: Literal["train", "val", "test", "predict"] | None = None,
+                  code_paths: list[str] | None = None,
+                  log_model_at_end_only: bool = True,
+                  additional_metadata: dict[str, Any] | None = None,
+                  extra_pip_requirements: list[str] | None = None,
+                  **kwargs):
+         """
+         MLFlowModelCheckpoint is a custom callback for PyTorch Lightning that integrates with MLFlow to log and register models.
+
+         Args:
+             register_model_name (str | None): The name to register the model under in MLFlow. If None, the model will not be registered.
+             register_model_on (Literal["train", "val", "test", "predict"] | None): The stage at which to register the model. If None, the model will not be registered.
+             code_paths (list[str] | None): List of paths to Python files that should be included in the MLFlow model.
+             log_model_at_end_only (bool): If True, only log the model to MLFlow at the end of training instead of after every checkpoint save.
+             additional_metadata (dict[str, Any] | None): Additional metadata to log with the model as a JSON file.
+             extra_pip_requirements (list[str] | None): Additional pip requirements to include with the MLFlow model.
+             **kwargs: Keyword arguments for ModelCheckpoint.
+         """
+         # Ensure MLflow is configured when the callback is initialized
+         ensure_mlflow_configured()
+
+         super().__init__(*args, **kwargs)
+         if self.save_top_k > 1:
+             raise NotImplementedError("save_top_k > 1 is not supported. "
+                                       "Please use save_top_k=1 to save only the best model.")
+         if self.save_last is not None and self.save_top_k != 0 and self.monitor is not None:
+             raise NotImplementedError("save_last is not supported with monitor and save_top_k!=0. "
+                                       "Please use two separate callbacks: one for saving the last model "
+                                       "and another for saving the best model based on the monitor metric.")
+
+         if register_model_name is not None and register_model_on is None:
+             raise ValueError("If you provide a register_model_name, you must also provide a register_model_on.")
+         if register_model_on is not None and register_model_name is None:
+             raise ValueError("If you provide a register_model_on, you must also provide a register_model_name.")
+         if register_model_on not in ["train", "val", "test", "predict", None]:
+             raise ValueError("register_model_on must be one of train, val, test or predict.")
+
+         self.register_model_name = register_model_name
+         self.register_model_on = register_model_on
+         self.log_model_at_end_only = log_model_at_end_only
+         self._last_model_uri = None
+         self.last_saved_model_info = None
+         self._inferred_signature = None
+         self._input_example = None
+         self.code_paths = code_paths
+         self.additional_metadata = additional_metadata or {}
+         self.extra_pip_requirements = extra_pip_requirements or []
+
+     def _infer_params(self, model: nn.Module) -> tuple[dict, ...]:
+         """Extract metadata from the model's forward method signature.
+
+         Returns:
+             A tuple of dicts, each containing parameter metadata ordered by position.
+         """
+         forward_method = getattr(model.__class__, 'forward', None)
+
+         if forward_method is None:
+             return ()
+
+         try:
+             sig = inspect.signature(forward_method)
+             params_list = []
+
+             for param_name, param in sig.parameters.items():
+                 if param_name == 'self':
+                     continue
+
+                 param_info = {
+                     'name': param_name,
+                     'kind': param.kind.name,
+                     'annotation': param.annotation if param.annotation != inspect.Parameter.empty else None,
+                     'default': param.default if param.default != inspect.Parameter.empty else None,
+                 }
+                 params_list.append(param_info)
+
+             # Add the return annotation, if available, as the last element
+             return_annotation = sig.return_annotation
+             if return_annotation != inspect.Signature.empty:
+                 return_info = {'_return_annotation': str(return_annotation)}
+                 params_list.append(return_info)
+
+             return tuple(params_list)
+
+         except Exception as e:
+             _LOGGER.warning(f"Failed to infer forward method parameters: {e}")
+             return ()
+
+     def _save_checkpoint(self, trainer: L.Trainer, filepath: str) -> None:
+         _LOGGER.debug(f"Saving checkpoint to {filepath}...")
+         trainer.save_checkpoint(filepath, self.save_weights_only)
+
+         self._last_global_step_saved = trainer.global_step
+         self._last_checkpoint_saved = filepath
+
+         # notify loggers
+         if trainer.is_global_zero:
+             for logger in trainer.loggers:
+                 logger.after_save_checkpoint(proxy(self))
+                 if isinstance(logger, MLFlowLogger) and not self.log_model_at_end_only:
+                     _LOGGER.debug(f"_save_checkpoint: Logging model to MLFlow at {filepath}...")
+                     self.log_model_to_mlflow(trainer.model, run_id=logger.run_id)
+
+     def log_additional_metadata(self, logger: MLFlowLogger | L.Trainer,
+                                 additional_metadata: dict) -> None:
+         """Log additional metadata as a JSON file to the model artifact.
+
+         Args:
+             logger: The MLFlowLogger or Lightning Trainer instance to use for logging.
+             additional_metadata: A dictionary containing additional metadata to log.
+         """
+         self.additional_metadata = additional_metadata
+         if not self.additional_metadata:
+             return
+
+         if self.last_saved_model_info is None:
+             _LOGGER.warning("No model has been saved yet. Cannot log additional metadata.")
+             return
+
+         try:
+             log_model_metadata(metadata=self.additional_metadata,
+                                logger=logger,
+                                model_path=self.last_saved_model_info.artifact_path)
+         except Exception as e:
+             _LOGGER.warning(f"Failed to log additional metadata: {e}")
+
+     def log_model_to_mlflow(self,
+                             model: nn.Module,
+                             run_id: str | MLFlowLogger
+                             ) -> None:
+         """Log the model to MLflow."""
+         if isinstance(run_id, MLFlowLogger):
+             logger = run_id
+             if logger.run_id is None:
+                 raise ValueError("MLFlowLogger has no run_id. Cannot log model to MLFlow.")
+             run_id = logger.run_id
+
+         if self._last_checkpoint_saved is None or self._last_checkpoint_saved == '':
+             _LOGGER.warning("No checkpoint saved yet. Cannot log model to MLFlow.")
+             return
+
+         orig_device = next(model.parameters()).device
+         model = model.cpu()  # Ensure the model is on CPU for logging
+
+         requirements = list(self.extra_pip_requirements)
+         # check whether lightning is already in the requirements
+         if not any('lightning' in req.lower() for req in requirements):
+             requirements.append(f'lightning=={L.__version__}')
+
+         _LOGGER.debug(f"log_model_to_mlflow: Logging model to MLFlow at {self._last_checkpoint_saved}...")
+         modelinfo = mlflow.pytorch.log_model(
+             pytorch_model=model,
+             artifact_path=f'model/{Path(self._last_checkpoint_saved).stem}',
+             signature=self._inferred_signature,
+             run_id=run_id,
+             extra_pip_requirements=requirements,
+             code_paths=self.code_paths
+         )
+
+         model.to(device=orig_device)  # Move the model back to its original device
+         self._last_model_uri = modelinfo.model_uri
+         self.last_saved_model_info = modelinfo
+
+         # Log additional metadata after the model is saved
+         log_model_metadata(self.additional_metadata,
+                            model_path=modelinfo.artifact_path,
+                            run_id=run_id)
+
+     def _remove_checkpoint(self, trainer: L.Trainer, filepath: str) -> None:
+         super()._remove_checkpoint(trainer, filepath)
+         # remove the checkpoint from mlflow
+         if trainer.is_global_zero:
+             for logger in trainer.loggers:
+                 if isinstance(logger, MLFlowLogger):
+                     artifact_uri = logger.experiment.get_run(logger.run_id).info.artifact_uri
+                     rep = get_artifact_repository(artifact_uri)
+                     rep.delete_artifacts(f'model/{Path(filepath).stem}')
+
+     def register_model(self, trainer=None):
+         """Register the model in the MLFlow Model Registry."""
+         # mlflow_client = _get_MLFlowLogger(trainer)._mlflow_client
+         return mlflow.register_model(
+             model_uri=self._last_model_uri,
+             name=self.register_model_name,
+         )
+
+     def _update_signature(self, trainer):
+         if self._inferred_signature is None:
+             _LOGGER.warning("No signature found. Cannot update signature.")
+             return
+         if self._last_model_uri is None:
+             _LOGGER.warning("No model URI found. Cannot update signature.")
+             return
+
+         mllogger = _get_MLFlowLogger(trainer)
+         mlclient = mllogger._mlflow_client
+
+         # check that the model artifacts exist
+         for artifact_info in mlclient.list_artifacts(run_id=mllogger.run_id):
+             if artifact_info.path.startswith('model'):
+                 break
+         else:
+             _LOGGER.warning(f"Model URI {self._last_model_uri} does not exist. Cannot update signature.")
+             return
+         _LOGGER.debug(f"Updating signature for model URI: {self._last_model_uri}...")
+         # update the signature
+         mlflow.models.set_signature(
+             model_uri=self._last_model_uri,
+             signature=self._inferred_signature,
+         )
+
+     def __wrap_forward(self, pl_module: nn.Module):
+         original_forward = pl_module.forward
+
+         def wrapped_forward(x, *args, **kwargs):
+             x0 = help_infer_signature(x)
+             inferred_params = self._infer_params(pl_module)
+             if len(inferred_params) > 1:
+                 inferred_params = {param['name']: param['default']
+                                    for param in inferred_params[1:] if 'name' in param}
+             else:
+                 inferred_params = None
+
+             self._inferred_signature = mlflow.models.infer_signature(model_input=x0,
+                                                                      params=inferred_params)
+
+             # run once and restore the original forward
+             pl_module.forward = original_forward
+             out = original_forward(x, *args, **kwargs)
+
+             output_sig = mlflow.models.infer_signature(model_output=help_infer_signature(out))
+             self._inferred_signature.outputs = output_sig.outputs
+
+             return out
+
+         pl_module.forward = wrapped_forward
+
+     def on_train_start(self, trainer, pl_module):
+         self.__wrap_forward(pl_module)
+
+     def on_train_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+         super().on_train_end(trainer, pl_module)
+
+         if self.log_model_at_end_only and trainer.is_global_zero:
+             logger = _get_MLFlowLogger(trainer)
+             if logger is None:
+                 _LOGGER.warning("No MLFlowLogger found. Cannot log model to MLFlow.")
+             else:
+                 self.log_model_to_mlflow(trainer.model, run_id=logger.run_id)
+
+         self._update_signature(trainer)
+
+         if self.register_model_on == 'train':
+             self.register_model(trainer)
+
+     def _restore_model_uri(self, trainer: L.Trainer) -> None:
+         logger = _get_MLFlowLogger(trainer)
+         if logger is None:
+             _LOGGER.warning("No MLFlowLogger found. Cannot restore model URI.")
+             return
+         if trainer.ckpt_path is None:
+             return
+         extracted_run_id = Path(trainer.ckpt_path).parts[1]
+         if extracted_run_id != logger.run_id:
+             _LOGGER.warning(f"Run ID mismatch: {extracted_run_id} != {logger.run_id}." +
+                             " Check `run_id` parameter in MLFlowLogger.")
+         self._last_model_uri = f'runs:/{logger.run_id}/model/{Path(trainer.ckpt_path).stem}'
+         try:
+             self.last_saved_model_info = mlflow.models.get_model_info(self._last_model_uri)
+         except mlflow.exceptions.MlflowException as e:
+             _LOGGER.warning(f"Failed to get model info for URI {self._last_model_uri}: {e}")
+             self.last_saved_model_info = None
+
+     def on_test_start(self, trainer, pl_module):
+         self.__wrap_forward(pl_module)
+         self._restore_model_uri(trainer)
+         return super().on_test_start(trainer, pl_module)
+
+     def on_predict_start(self, trainer, pl_module):
+         self.__wrap_forward(pl_module)
+         self._restore_model_uri(trainer)
+         return super().on_predict_start(trainer, pl_module)
+
+     def on_test_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+         super().on_test_end(trainer, pl_module)
+
+         if self.register_model_on == 'test':
+             self._update_signature(trainer)
+             self.register_model(trainer)
+
+     def on_predict_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+         super().on_predict_end(trainer, pl_module)
+
+         if self.register_model_on == 'predict':
+             self._update_signature(trainer)
+             self.register_model(trainer)
+
+     def on_validation_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+         super().on_validation_end(trainer, pl_module)
+
+         if self.register_model_on == 'val':
+             self._update_signature(trainer)
+             self.register_model(trainer)
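
To show how the callback is meant to be wired up, here is a minimal usage sketch. The import path, model, and dataloader names are placeholders, not part of this diff:

```python
# Hypothetical usage sketch for MLFlowModelCheckpoint (import path assumed).
import lightning.pytorch as L
from lightning.pytorch.loggers import MLFlowLogger
from datamint.mlflow import MLFlowModelCheckpoint  # assumed import path

logger = MLFlowLogger(experiment_name="demo")
checkpoint_cb = MLFlowModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,                    # save_top_k > 1 raises NotImplementedError
    register_model_name="my-model",  # must be paired with register_model_on
    register_model_on="train",       # registration happens in on_train_end
    additional_metadata={"dataset": "v1"},
)
trainer = L.Trainer(logger=logger, callbacks=[checkpoint_cb], max_epochs=10)
trainer.fit(MyLitModule(), train_dataloaders=my_dataloader)  # placeholder names
```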
@@ -0,0 +1,94 @@
+ import logging
+ import json
+ import lightning as L
+ from lightning.pytorch.loggers import MLFlowLogger
+ import mlflow
+ import os
+ from tempfile import TemporaryDirectory
+ from torch import nn
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def download_model_metadata(model_uri: str) -> dict:
+     """Download the metadata.json artifact logged next to a model. Returns {} if unavailable."""
+     from mlflow.tracking.artifact_utils import get_artifact_repository
+
+     art_repo = get_artifact_repository(artifact_uri=model_uri)
+     try:
+         out_artifact_path = art_repo.download_artifacts(artifact_path='metadata.json')
+     except OSError as e:
+         _LOGGER.warning(f"Error downloading model metadata: {e}")
+         return {}
+
+     with open(out_artifact_path, 'r') as f:
+         metadata = json.load(f)
+     return metadata
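
A quick sketch of calling `download_model_metadata`; the model URI below is a placeholder:

```python
# Hypothetical call: fetch the metadata.json logged next to a model artifact.
metadata = download_model_metadata("runs:/<run_id>/model/epoch=3-step=400")
print(metadata)  # {} if the artifact has no metadata.json
```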
+
+
+ def _get_MLFlowLogger(trainer: L.Trainer) -> MLFlowLogger:
+     for logger in trainer.loggers:
+         if isinstance(logger, MLFlowLogger):
+             return logger
+     raise ValueError("No MLFlowLogger found in the trainer loggers.")
+
+
+ def log_model_metadata(metadata: dict,
+                        mlflow_model: mlflow.models.Model | None = None,
+                        logger: MLFlowLogger | L.Trainer | None = None,
+                        model_path: str | None = None,
+                        run_id: str | None = None,
+                        ) -> None:
+     """
+     Log additional metadata to the MLflow model.
+     One of the following parameter combinations must be provided:
+     1. `mlflow_model`
+     2. `logger` and `model_path`
+     3. `run_id` and `model_path`
+
+     Args:
+         metadata (dict): The metadata to log.
+         mlflow_model (mlflow.models.Model, optional): The MLflow model object. Defaults to None.
+         logger (MLFlowLogger or L.Trainer, optional): The MLFlow logger or Lightning Trainer instance. Defaults to None.
+         model_path (str, optional): The path where the model is stored in MLflow. Defaults to None.
+         run_id (str, optional): The run ID of the MLflow run. Defaults to None.
+     """
+
+     # Validate inputs
+     if mlflow_model is None and (logger is None or model_path is None) and (run_id is None or model_path is None):
+         raise ValueError(
+             "You must provide either `mlflow_model`, or both `logger` and `model_path`, "
+             "or both `run_id` and `model_path`."
+         )
+     # not both
+     if mlflow_model is not None and logger is not None:
+         raise ValueError("Only one of mlflow_model or logger can be provided.")
+
+     if logger is not None:
+         if isinstance(logger, L.Trainer):
+             # raises ValueError if the trainer has no MLFlowLogger
+             logger = _get_MLFlowLogger(logger)
+         run_id = logger.run_id
+         artifact_path = model_path
+         mlfclient = logger.experiment
+     elif mlflow_model is not None:
+         run_id = mlflow_model.run_id
+         artifact_path = mlflow_model.artifact_path
+         mlfclient = mlflow.client.MlflowClient()
+     elif run_id is not None and model_path is not None:
+         mlfclient = mlflow.client.MlflowClient()
+         artifact_path = model_path
+     else:
+         raise ValueError("Invalid logger or mlflow_model provided.")
+
+     with TemporaryDirectory() as tmpdir:
+         metadata_path = os.path.join(tmpdir, "metadata.json")
+         with open(metadata_path, "w") as f:
+             json.dump(metadata, f, indent=2)
+
+         mlfclient.log_artifact(
+             run_id=run_id,
+             local_path=metadata_path,
+             artifact_path=artifact_path,
+         )
+     _LOGGER.debug(f"Additional metadata logged to {artifact_path}/metadata.json")
@@ -0,0 +1,46 @@
+ from mlflow.store.tracking.rest_store import RestStore
+ from functools import partial
+ from .fluent import get_active_project_id
+ import json
+
+
+ class DatamintStore(RestStore):
+     """
+     DatamintStore is a subclass of RestStore that provides a tracking store
+     implementation for Datamint.
+     """
+
+     def __init__(self, store_uri: str, artifact_uri=None, force_valid=True):
+         # Ensure the MLflow environment is configured when the store is initialized
+         from datamint.mlflow.env_utils import setup_mlflow_environment
+         from mlflow.utils.credentials import get_default_host_creds
+         setup_mlflow_environment()
+
+         # The store is valid for Datamint URIs, or whenever force_valid is set
+         self.invalid = not (force_valid or store_uri.startswith('datamint://') or 'datamint.io' in store_uri)
+
+         store_uri = store_uri.split('datamint://', maxsplit=1)[-1]
+         get_host_creds = partial(get_default_host_creds, store_uri)
+         super().__init__(get_host_creds=get_host_creds)
+
+     def create_experiment(self, name, artifact_location=None, tags=None, project_id: str | None = None) -> str:
+         from mlflow.protos.service_pb2 import CreateExperiment
+         from mlflow.utils.proto_json_utils import message_to_json
+
+         if self.invalid:
+             return super().create_experiment(name, artifact_location, tags)
+         if project_id is None:
+             project_id = get_active_project_id()
+         tag_protos = [tag.to_proto() for tag in tags] if tags else []
+         req_body = message_to_json(
+             CreateExperiment(name=name, artifact_location=artifact_location, tags=tag_protos)
+         )
+
+         req_body = json.loads(req_body)
+         req_body["project_id"] = project_id  # FIXME: this should be in the proto
+         req_body = json.dumps(req_body)
+
+         response_proto = self._call_endpoint(CreateExperiment, req_body)
+         return response_proto.experiment_id
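
MLflow resolves tracking stores by URI scheme through the `mlflow.tracking_store` entry-point group; assuming the package registers DatamintStore under a `datamint` scheme (suggested by the `datamint://` handling above, but not shown in this diff), usage would look roughly like this:

```python
# Sketch under the assumption that DatamintStore is registered for the
# "datamint" URI scheme via the "mlflow.tracking_store" entry-point group,
# e.g. datamint = "datamint.mlflow.store:DatamintStore" (registration not
# shown in this diff). The host below is a placeholder.
import mlflow

mlflow.set_tracking_uri("datamint://staging.datamint.io")
exp_id = mlflow.create_experiment("my-experiment")  # routes through DatamintStore.create_experiment
```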
@@ -0,0 +1,27 @@
+ import sys
+ import os
+ from mlflow.tracking.default_experiment.abstract_context import DefaultExperimentProvider
+
+
+ class DatamintExperimentProvider(DefaultExperimentProvider):
+     _experiment_id = None
+
+     def in_context(self):
+         return True
+
+     def get_experiment_id(self):
+         from mlflow.tracking.client import MlflowClient
+
+         if DatamintExperimentProvider._experiment_id is not None:
+             return self._experiment_id
+         # Get the filename of the main source file
+         source_code_filename = os.path.basename(sys.argv[0])
+         mlflowclient = MlflowClient()
+         exp = mlflowclient.get_experiment_by_name(source_code_filename)
+         if exp is None:
+             experiment_id = mlflowclient.create_experiment(source_code_filename)
+         else:
+             experiment_id = exp.experiment_id
+         DatamintExperimentProvider._experiment_id = experiment_id
+
+         return experiment_id
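
Default-experiment providers are likewise discovered via an MLflow entry-point group (`mlflow.default_experiment_provider`, to the best of my knowledge); a sketch of the intended effect, with the registration shown as an assumption:

```python
# Sketch: assuming DatamintExperimentProvider is registered under the
# "mlflow.default_experiment_provider" entry-point group (not shown in this
# diff), a run started without an explicit experiment lands in an experiment
# named after the entry script, e.g. "train.py" for `python train.py`.
import mlflow

with mlflow.start_run():            # no mlflow.set_experiment(...) needed
    mlflow.log_metric("loss", 0.1)  # logged under the script-name experiment
```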
@@ -0,0 +1,78 @@
+ from typing import Optional
+ import threading
+ import logging
+ from datamint import Api
+ from datamint.exceptions import DatamintException
+ import os
+ from datamint.mlflow.env_vars import EnvVars
+ from datamint.mlflow.env_utils import ensure_mlflow_configured
+
+ _PROJECT_LOCK = threading.Lock()
+ _LOGGER = logging.getLogger(__name__)
+
+ _ACTIVE_PROJECT_ID: Optional[str] = None
+
+
+ def get_active_project_id() -> str | None:
+     """
+     Get the active project ID from the environment variables or the cached global value.
+     """
+     global _ACTIVE_PROJECT_ID
+
+     if _ACTIVE_PROJECT_ID is not None:
+         return _ACTIVE_PROJECT_ID
+     # Check if the environment variable is set
+     project_id = os.getenv(EnvVars.DATAMINT_PROJECT_ID.value)
+     if project_id is not None:
+         _ACTIVE_PROJECT_ID = project_id
+         return project_id
+     project_name = os.getenv(EnvVars.DATAMINT_PROJECT_NAME.value)
+     if project_name is not None:
+         project = _find_project_by_name(project_name)
+         if project is not None:
+             _ACTIVE_PROJECT_ID = project['id']
+             return _ACTIVE_PROJECT_ID
+
+     return None
+
+
+ def _find_project_by_name(project_name: str):
+     dt_client = Api(check_connection=False)
+     project = dt_client.projects.get_by_name(project_name)
+     if project is None:
+         raise DatamintException(f"Project with name '{project_name}' does not exist.")
+     return project
+
+
+ def set_project(project_name: Optional[str] = None, project_id: Optional[str] = None):
+     from mlflow.exceptions import MlflowException
+     global _ACTIVE_PROJECT_ID
+
+     # Ensure MLflow is properly configured before proceeding
+     ensure_mlflow_configured()
+
+     if project_name is None and project_id is None:
+         raise MlflowException("You must specify either a project name or a project id")
+
+     if project_name is not None and project_id is not None:
+         raise MlflowException("You cannot specify both a project name and a project id")
+
+     with _PROJECT_LOCK:
+         dt_client = Api(check_connection=False)
+         if project_id is None:
+             project = dt_client.projects.get_by_name(project_name)
+             if project is None:
+                 raise DatamintException(f"Project with name '{project_name}' does not exist.")
+             project_id = project.id
+         else:
+             project = dt_client.projects.get_by_id(project_id)
+             if project is None:
+                 raise DatamintException(f"Project with id '{project_id}' does not exist.")
+
+         _ACTIVE_PROJECT_ID = project_id
+
+         # Set the 'DATAMINT_PROJECT_ID' environment variable
+         # so that subprocesses can inherit it.
+         os.environ[EnvVars.DATAMINT_PROJECT_ID.value] = project_id
+
+     return project
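
A short sketch of the intended flow; the import path and project name are placeholders:

```python
# Hypothetical usage of set_project / get_active_project_id (import path assumed).
from datamint.mlflow.fluent import set_project, get_active_project_id

project = set_project(project_name="my-project")  # or set_project(project_id="...")
print(get_active_project_id())  # cached ID, also exported via the DATAMINT_PROJECT_ID env var
```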