ddi-fw 0.0.217__py3-none-any.whl → 0.0.218__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,18 +7,16 @@ from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback
  from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
  import numpy as np
  from tensorflow.keras import Model
- import mlflow
- from mlflow.utils.autologging_utils import batch_metrics_logger
-
- from mlflow.models import infer_signature
  from ddi_fw.ml.evaluation_helper import Metrics, evaluate

  # import tf2onnx
  # import onnx

+ from ddi_fw.ml.tracking_service import TrackingService
  import ddi_fw.utils as utils
  import os

+
  def convert_to_categorical(arr, num_classes):
      """
      This function takes an array of labels and converts them to one-hot encoding
@@ -33,7 +31,7 @@ def convert_to_categorical(arr, num_classes):
      - The one-hot encoded array if the original array was binary or label encoded
      - The original array if it doesn't require any conversion
      """
-
+
      try:
          # First, check if the array is binary-encoded
          if not utils.is_binary_encoded(arr):
@@ -45,7 +43,7 @@ def convert_to_categorical(arr, num_classes):
      except Exception as e:
          # If binary encoding check raises an error, print it and continue to label encoding check
          print(f"Error while checking binary encoding: {e}")
-
+
      try:
          # Check if the array is label-encoded
          if utils.is_label_encoded(arr):
@@ -56,21 +54,21 @@ def convert_to_categorical(arr, num_classes):
          print(f"Error while checking label encoding: {e}")
          # If the arr labels don't match any of the known encodings, raise an error
          raise ValueError("Unknown label encoding format.")
-
+
      # If no conversion was needed, return the original array
-
+
      return arr


  class TFModelWrapper(ModelWrapper):

-     def __init__(self, date, descriptor, model_func, use_mlflow=False, **kwargs):
+     def __init__(self, date, descriptor, model_func, tracking_service: TrackingService | None = None, **kwargs):
          super().__init__(date, descriptor, model_func, **kwargs)
          self.batch_size = kwargs.get('batch_size', 128)
          self.epochs = kwargs.get('epochs', 100)
-         self.use_mlflow = use_mlflow
+         self.tracking_service = tracking_service

-     # TODO think different settings for num_classes
+     # TODO think different settings for num_classes
      def fit_model(self, X_train, y_train, X_valid, y_valid):
          self.kwargs['input_shape'] = self.train_data.shape
          self.num_classes = len(np.unique(y_train, axis=0))
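
For reference, convert_to_categorical (touched above only for whitespace) accepts labels in either one-hot (binary-encoded) or label-encoded form and returns a one-hot array. A minimal sketch of the conversion it performs, using tf.keras.utils.to_categorical as the commented-out lines of the original source do (the sample labels are invented):

    import numpy as np
    import tensorflow as tf

    labels = np.array([0, 2, 1, 2])   # hypothetical label-encoded input
    one_hot = tf.keras.utils.to_categorical(labels, num_classes=3)
    print(one_hot)
    # [[1. 0. 0.]
    #  [0. 0. 1.]
    #  [0. 1. 0.]
    #  [0. 0. 1.]]
    # An array that is already one-hot (binary-encoded) is returned unchanged.
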
@@ -85,7 +83,11 @@ class TFModelWrapper(ModelWrapper):
          )
          early_stopping = EarlyStopping(
              monitor='val_loss', patience=10, mode='auto')
-         custom_callback = CustomCallback(self.use_mlflow)
+         if self.tracking_service:
+             custom_callback = CustomCallback(self.tracking_service)
+             callbacks=[early_stopping, checkpoint, custom_callback]
+         else:
+             callbacks=[early_stopping, checkpoint]
          train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
          train_dataset = train_dataset.batch(batch_size=self.batch_size)

@@ -100,7 +102,7 @@ class TFModelWrapper(ModelWrapper):
              train_dataset,
              epochs=self.epochs,
              # validation_data=val_dataset,
-             callbacks=[early_stopping, checkpoint, custom_callback]
+             callbacks=callbacks
          )
          # ex
          # history = model.fit(
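
The two hunks above replace the use_mlflow boolean with an injected, optional TrackingService and assemble the Keras callbacks list conditionally, so training no longer requires MLflow at all. A standalone sketch of the same injection pattern (Service and Trainer are illustrative stand-ins, not package classes):

    from typing import Optional

    class Service:                                # stand-in for TrackingService
        def log(self, msg: str) -> None:
            print(f"[tracked] {msg}")

    class Trainer:                                # stand-in for TFModelWrapper
        def __init__(self, service: Optional[Service] = None):
            self.service = service                # None disables tracking

        def fit(self) -> None:
            # Mirrors fit_model: the tracking callback only exists
            # when a service was injected.
            callbacks = [self.service.log] if self.service else []
            for cb in callbacks:
                cb("epoch done")

    Trainer(Service()).fit()   # prints "[tracked] epoch done"
    Trainer().fit()            # runs silently, no tracking dependency
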
@@ -123,38 +125,51 @@ class TFModelWrapper(ModelWrapper):
          if self.train_idx_arr and self.val_idx_arr:
              for i, (train_idx, val_idx) in enumerate(zip(self.train_idx_arr, self.val_idx_arr)):
                  print(f"Validation {i}")
+                 X_train_cv = self.train_data[train_idx]
+                 y_train_cv = self.train_label[train_idx]
+                 X_valid_cv = self.train_data[val_idx]
+                 y_valid_cv = self.train_label[val_idx]

-                 if self.use_mlflow:
-                     with mlflow.start_run(run_name=f'Validation {i}', description='CV models', nested=True) as cv_fit:
-                         X_train_cv = self.train_data[train_idx]
-                         y_train_cv = self.train_label[train_idx]
-                         X_valid_cv = self.train_data[val_idx]
-                         y_valid_cv = self.train_label[val_idx]
-                         model, checkpoint = self.fit_model(
-                             X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
-                         models[f'{self.descriptor}_validation_{i}'] = model
-                         models_val_acc[f'{self.descriptor}_validation_{i}'] = checkpoint.best
-                 else:
-                     X_train_cv = self.train_data[train_idx]
-                     y_train_cv = self.train_label[train_idx]
-                     X_valid_cv = self.train_data[val_idx]
-                     y_valid_cv = self.train_label[val_idx]
+                 def fit_model_cv_func():
                      model, checkpoint = self.fit_model(
                          X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
-                     models[f'{self.descriptor}_validation_{i}'] = model
-                     models_val_acc[f'{self.descriptor}_validation_{i}'] = checkpoint.best
+                     return model, checkpoint
+
+                 if self.tracking_service:
+                     model, checkpoint = self.tracking_service.run(
+                         run_name=f'Validation {i}', description='CV models', nested_run=True, func=fit_model_cv_func)
+                     # with mlflow.start_run(run_name=f'Validation {i}', description='CV models', nested=True) as cv_fit:
+
+                     #     model, checkpoint = self.fit_model(
+                     #         X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
+                     #     models[f'{self.descriptor}_validation_{i}'] = model
+                     #     models_val_acc[f'{self.descriptor}_validation_{i}'] = checkpoint.best
+                 else:
+                     model, checkpoint = fit_model_cv_func()
+                     # model, checkpoint = self.fit_model(
+                     #     X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
+                 models[f'{self.descriptor}_validation_{i}'] = model
+                 models_val_acc[f'{self.descriptor}_validation_{i}'] = checkpoint.best
          else:
-             if self.use_mlflow:
-                 with mlflow.start_run(run_name=f'Training', description='Training', nested=True) as cv_fit:
-                     model, checkpoint = self.fit_model(
-                         self.train_data, self.train_label, None, None)
-                     models[self.descriptor] = model
-                     models_val_acc[self.descriptor] = checkpoint.best
-             else:
+             def fit_model_func():
                  model, checkpoint = self.fit_model(
                      self.train_data, self.train_label, None, None)
-                 models[self.descriptor] = model
-                 models_val_acc[self.descriptor] = checkpoint.best
+                 return model, checkpoint
+
+             if self.tracking_service:
+                 model, checkpoint = self.tracking_service.run(
+                     run_name=f'Training', description='Training', nested_run=True, func=fit_model_func)
+                 # with mlflow.start_run(run_name=f'Training', description='Training', nested=True) as cv_fit:
+                 #     model, checkpoint = self.fit_model(
+                 #         self.train_data, self.train_label, None, None)
+                 #     models[self.descriptor] = model
+                 #     models_val_acc[self.descriptor] = checkpoint.best
+             else:
+                 model, checkpoint = fit_model_func()
+                 # models[self.descriptor] = model
+                 # models_val_acc[self.descriptor] = checkpoint.best
+             models[self.descriptor] = model
+             models_val_acc[self.descriptor] = checkpoint.best
          if models_val_acc == {}:
              return model, None
          best_model_key = max(models_val_acc, key=lambda k: models_val_acc[k])
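
The fit refactor above moves each training call into a zero-argument closure (fit_model_cv_func / fit_model_func) and hands it to tracking_service.run, so run management stays in the tracking layer and the no-tracking path simply calls the closure directly. A standalone sketch of that inversion of control (the run function below is a stand-in for TrackingService.run):

    from typing import Any, Callable

    def run(run_name: str, description: str, func: Callable[[], Any],
            nested_run: bool = False) -> Any:
        # Stand-in for TrackingService.run: open a run, execute the
        # closure inside it, and pass through its return value.
        print(f"start {run_name!r} ({description}), nested={nested_run}")
        try:
            return func()
        finally:
            print(f"end {run_name!r}")

    def fit_model_cv_func():
        return "model", 0.93        # stand-in for (model, checkpoint.best)

    model, best = run(run_name="Validation 0", description="CV models",
                      nested_run=True, func=fit_model_cv_func)
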
@@ -172,62 +187,113 @@ class TFModelWrapper(ModelWrapper):
          pred = self.best_model.predict(test_dataset)
          return pred

+     # def fit_and_evaluate(self, print_detail=False) -> tuple[dict[str, Any], Metrics, Any]:
+     #     if self.use_mlflow:
+     #         with mlflow.start_run(run_name=self.descriptor, description="***", nested=True) as run:
+     #             best_model, best_model_key = self.fit()
+     #             self.best_model: Model = best_model
+     #             pred = self.predict()
+     #             actual = self.test_label
+     #             # if not utils.is_binary_encoded(pred):
+     #             #     pred = tf.keras.utils.to_categorical(np.argmax(pred,axis=1), num_classes=self.num_classes)
+     #             pred_as_cat = convert_to_categorical(pred, self.num_classes)
+     #             actual_as_cat = convert_to_categorical(
+     #                 actual, self.num_classes)
+
+     #             logs, metrics = evaluate(
+     #                 actual=actual_as_cat, pred=pred_as_cat, info=self.descriptor, print_detail=print_detail)
+     #             metrics.format_float()
+     #             mlflow.log_metrics(logs)
+     #             mlflow.log_param('best_cv', best_model_key)
+     #             utils.compress_and_save_data(
+     #                 metrics.__dict__, run.info.artifact_uri, f'{self.date}_metrics.gzip')
+     #             mlflow.log_artifact(
+     #                 f'{run.info.artifact_uri}/{self.date}_metrics.gzip')
+
+     #             return logs, metrics, pred
+     #     else:
+     #         best_model, best_model_key = self.fit()
+     #         self.best_model = best_model
+     #         pred = self.predict()
+     #         actual = self.test_label
+
+     #         pred_as_cat = convert_to_categorical(pred, self.num_classes)
+     #         actual_as_cat = convert_to_categorical(actual, self.num_classes)
+     #         logs, metrics = evaluate(
+     #             actual=actual_as_cat, pred=pred_as_cat, info=self.descriptor)
+     #         metrics.format_float()
+     #         return logs, metrics, pred
+
      def fit_and_evaluate(self, print_detail=False) -> tuple[dict[str, Any], Metrics, Any]:
-         if self.use_mlflow:
-             with mlflow.start_run(run_name=self.descriptor, description="***", nested=True) as run:
-                 print(run.info.artifact_uri)
-                 best_model, best_model_key = self.fit()
-                 print(best_model_key)
-                 self.best_model: Model = best_model
-                 pred = self.predict()
-                 actual = self.test_label
-                 # if not utils.is_binary_encoded(pred):
-                 #     pred = tf.keras.utils.to_categorical(np.argmax(pred,axis=1), num_classes=self.num_classes)
-                 pred_as_cat= convert_to_categorical(pred, self.num_classes)
-                 actual_as_cat= convert_to_categorical(actual, self.num_classes)
-
-                 logs, metrics = evaluate(
-                     actual=actual_as_cat, pred=pred_as_cat, info=self.descriptor, print_detail=print_detail)
-                 metrics.format_float()
-                 mlflow.log_metrics(logs)
-                 mlflow.log_param('best_cv', best_model_key)
-                 utils.compress_and_save_data(
-                     metrics.__dict__, run.info.artifact_uri, f'{self.date}_metrics.gzip')
-                 mlflow.log_artifact(
-                     f'{run.info.artifact_uri}/{self.date}_metrics.gzip')
-
-                 return logs, metrics, pred
-         else:
+         """
+         Fit the model, evaluate it, and log results using the tracking service.
+
+         Args:
+             print_detail (bool): Whether to print detailed evaluation logs.
+
+         Returns:
+             tuple: A tuple containing logs, metrics, and predictions.
+         """
+         self.best_model: Model = None
+         def evaluate_and_log(artifact_uri=None):
+             # Fit the model
              best_model, best_model_key = self.fit()
-             print(best_model_key)
              self.best_model = best_model
+
+             # Make predictions
              pred = self.predict()
              actual = self.test_label
-             # if not utils.is_binary_encoded(pred):
-             #     pred = tf.keras.utils.to_categorical(np.argmax(pred,axis=1), num_classes=self.num_classes)
-             # if not utils.is_binary_encoded(actual):
-             #     actual = tf.keras.utils.to_categorical(actual, num_classes=self.num_classes)
-             pred= convert_to_categorical(pred, self.num_classes)
-             actual= convert_to_categorical(actual, self.num_classes)
+
+             # Convert predictions and actual labels to categorical format
+             pred_as_cat = convert_to_categorical(pred, self.num_classes)
+             actual_as_cat = convert_to_categorical(actual, self.num_classes)
+
+             # Evaluate the model
              logs, metrics = evaluate(
-                 actual=actual, pred=pred, info=self.descriptor)
+                 actual=actual_as_cat, pred=pred_as_cat, info=self.descriptor, print_detail=print_detail
+             )
              metrics.format_float()
-             return logs, metrics, pred

+             if self.tracking_service:
+                 # Log metrics and parameters
+                 self.tracking_service.log_metrics(logs)
+                 self.tracking_service.log_param('best_cv', best_model_key)

- """
- Custom Keras callback for logging training metrics and model summary to MLflow.
- """
+                 # Save metrics to the artifact URI if provided
+                 if artifact_uri:
+                     utils.compress_and_save_data(
+                         metrics.__dict__, artifact_uri, f'{self.date}_metrics.gzip'
+                     )
+                     self.tracking_service.log_artifact(
+                         f'{artifact_uri}/{self.date}_metrics.gzip'
+                     )
+
+             return logs, metrics, pred
+
+         # Use the tracking service to run the evaluation
+         if self.tracking_service:
+             return self.tracking_service.run(
+                 run_name=self.descriptor,
+                 description="Fit and evaluate the model",
+                 nested_run=True,
+                 func=evaluate_and_log
+             )
+         else:
+             # If no tracking service is provided, run the evaluation directly
+             return evaluate_and_log()


  class CustomCallback(Callback):
-     def __init__(self, use_mlflow: bool = True):
+     """
+     Custom Keras callback for logging training metrics and model summary to MLflow.
+     """
+     def __init__(self, tracking_service: TrackingService):
          super().__init__()
-         self.use_mlflow = use_mlflow
+         self.tracking_service = tracking_service

-     def _mlflow_log(self, func: Callable):
-         if self.use_mlflow:
-             func()
+     # def _mlflow_log(self, func: Callable):
+     #     if self.use_mlflow:
+     #         func()

      def on_train_begin(self, logs=None):
          if logs is None:
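
Note that evaluate_and_log declares artifact_uri=None: as the new tracking_service module below shows, MLFlowTracking.run inspects the callback's signature and passes artifact_uri only when the function accepts it. A standalone sketch of that handshake:

    import inspect
    from typing import Callable

    def run(func: Callable, artifact_uri: str = "/tmp/artifacts"):   # path is illustrative
        # Mirrors MLFlowTracking.run: forward artifact_uri only if the
        # callback's signature declares a parameter with that name.
        if "artifact_uri" in inspect.signature(func).parameters:
            return func(artifact_uri=artifact_uri)
        return func()

    def with_uri(artifact_uri=None):
        return f"saving metrics under {artifact_uri}"

    def without_uri():
        return "nothing to save"

    print(run(with_uri))     # saving metrics under /tmp/artifacts
    print(run(without_uri))  # nothing to save
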
@@ -236,24 +302,32 @@ class CustomCallback(Callback):
              raise TypeError("self.model must be an instance of Model")

          keys = list(logs.keys())
-         self._mlflow_log(lambda: mlflow.log_param("train_begin_keys", keys))
+
+         self.tracking_service.log_param("train_begin_keys", keys)
+         # self._mlflow_log(lambda: mlflow.log_param("train_begin_keys", keys))
+
          # config = self.model.optimizer.get_config()
          config = self.model.get_config()
          for attribute in config:
-             self._mlflow_log(lambda: mlflow.log_param(
-                 "opt_" + attribute, config[attribute]))
+             self.tracking_service.log_param(
+                 "opt_" + attribute, config[attribute])
+             # self._mlflow_log(lambda: mlflow.log_param(
+             #     "opt_" + attribute, config[attribute]))

          sum_list = []
          self.model.summary(print_fn=sum_list.append)
          summary = "\n".join(sum_list)
-         self._mlflow_log(lambda: mlflow.log_text(
-             summary, artifact_file="model_summary.txt"))
+         self.tracking_service.log_text(
+             summary, file_name="model_summary.txt")
+         # self._mlflow_log(lambda: mlflow.log_text(
+         #     summary, artifact_file="model_summary.txt"))

      def on_train_end(self, logs=None):
          if logs is None:
              logs = {}
          print(logs)
-         self._mlflow_log(lambda: mlflow.log_metrics(logs))
+         self.tracking_service.log_metrics(logs)
+         # self._mlflow_log(lambda: mlflow.log_metrics(logs))

      def on_epoch_begin(self, epoch, logs=None):
          if logs is None:
@@ -273,7 +347,8 @@ class CustomCallback(Callback):
      def on_test_end(self, logs=None):
          if logs is None:
              logs = {}
-         self._mlflow_log(lambda: mlflow.log_metrics(logs))
+         self.tracking_service.log_metrics(logs)
+         # self._mlflow_log(lambda: mlflow.log_metrics(logs))
          print(logs)

      def on_predict_begin(self, logs=None):
@@ -285,7 +360,8 @@ class CustomCallback(Callback):
          if logs is None:
              logs = {}
          keys = list(logs.keys())
-         self._mlflow_log(lambda: mlflow.log_metrics(logs))
+         self.tracking_service.log_metrics(logs)
+         # self._mlflow_log(lambda: mlflow.log_metrics(logs))

      def on_train_batch_begin(self, batch, logs=None):
          if logs is None:
ddi_fw/ml/tracking_service.py (new file)
@@ -0,0 +1,194 @@
+ import inspect
+ import os
+ from typing import Optional, Dict, Any
+ import logging
+ from urllib.parse import urlparse
+ import mlflow
+ from abc import ABC, abstractmethod
+ from typing import Callable, Optional, Dict, Any
+
+
+ def normalize_artifact_uri(artifact_uri: str) -> str:
+     """
+     Normalize the artifact URI to a standard file path.
+
+     Args:
+         artifact_uri (str): The artifact URI to normalize.
+
+     Returns:
+         str: The normalized file path.
+     """
+     if artifact_uri.startswith("file:///"):
+         parsed_uri = urlparse(artifact_uri)
+         return os.path.abspath(os.path.join(parsed_uri.path.lstrip('/')))
+     return artifact_uri
+
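
For file:/// URIs (as MLflow reports for local artifact stores), the helper strips the scheme and resolves an absolute filesystem path; anything else is returned untouched. A usage sketch (the URI is invented):

    import os
    from urllib.parse import urlparse

    uri = "file:///C:/mlruns/0/abc123/artifacts"   # hypothetical local MLflow URI
    parsed = urlparse(uri)
    # Same steps as normalize_artifact_uri: drop the scheme, strip the
    # leading slash, then resolve to an absolute path.
    path = os.path.abspath(parsed.path.lstrip('/'))
    print(path)   # e.g. C:\mlruns\0\abc123\artifacts on Windows

    # A non-file URI such as s3://bucket/path would be returned unchanged.
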
+ class Tracking(ABC):
+     def __init__(self, experiment_name: str, tracking_params: Optional[Dict[str, Any]] = None):
+         """
+         Initialize the tracking backend.
+
+         Args:
+             experiment_name (str): The name of the experiment.
+             experiment_tags (dict, optional): Tags for the experiment.
+         """
+         self.experiment_name = experiment_name
+         self.tracking_params = tracking_params or {}
+
+     @abstractmethod
+     def setup_experiment(self):
+         """Set up the experiment in the tracking backend."""
+         pass
+
+     @abstractmethod
+     def run(self, run_name: str, description:str, func: Callable, nested_run: bool = False):
+         """Run the experiment with the given function."""
+         pass
+
+     @abstractmethod
+     def log_text(self, content:str, file_name: str):
+         """Log parameters to the tracking backend."""
+         pass
+     @abstractmethod
+     def log_param(self, key:str, value: Any):
+         """Log parameters to the tracking backend."""
+         pass
+     @abstractmethod
+     def log_params(self, params: Dict[str, Any]):
+         """Log parameters to the tracking backend."""
+         pass
+
+     @abstractmethod
+     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+         """Log metrics to the tracking backend."""
+         pass
+
+     @abstractmethod
+     def log_artifact(self, artifact_path: str):
+         """Log an artifact to the tracking backend."""
+         pass
+
+
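
The Tracking ABC above is the extension point for alternative backends: the package only ships an MLflow implementation, but any subclass providing these seven methods fits. A hedged sketch of a hypothetical console-only backend (ConsoleTracking is not part of the package):

    from typing import Any, Callable, Dict, Optional
    from ddi_fw.ml.tracking_service import Tracking

    class ConsoleTracking(Tracking):
        """Hypothetical backend that prints instead of tracking."""

        def setup_experiment(self):
            print(f"experiment: {self.experiment_name}")

        def run(self, run_name: str, description: str, func: Callable,
                nested_run: bool = False):
            print(f"run: {run_name} ({description})")
            return func()

        def log_text(self, content: str, file_name: str):
            print(f"{file_name}: {content[:60]}")

        def log_param(self, key: str, value: Any):
            print(f"param {key}={value}")

        def log_params(self, params: Dict[str, Any]):
            for key, value in params.items():
                self.log_param(key, value)

        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
            print(f"metrics (step {step}): {metrics}")

        def log_artifact(self, artifact_path: str):
            print(f"artifact: {artifact_path}")

Wiring this in would also require extending TrackingService._initialize_backend, which currently accepts only "mlflow".
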
+ logger = logging.getLogger(__name__)
+
+
+ class MLFlowTracking(Tracking):
+     def __init__(self, experiment_name: str, tracking_params: Optional[Dict[str, Any]] = None):
+         """
+         Initialize the MLFlowTracking backend.
+
+         Args:
+             experiment_name (str): The name of the experiment.
+             tracking_params (dict, optional): Parameters for MLflow tracking.
+         """
+         super().__init__(experiment_name, tracking_params)
+         if tracking_params:
+             self.experiment_tags = tracking_params.get("experiment_tags", {})
+
+     def setup_experiment(self):
+         """Set up an experiment in MLflow."""
+         tracking_uri = self.tracking_params.get("tracking_uri")
+         if not tracking_uri:
+             raise ValueError("Tracking URI must be specified for MLflow.")
+
+         mlflow.set_tracking_uri(tracking_uri)
+
+         if mlflow.get_experiment_by_name(self.experiment_name) is None:
+             artifact_location = self.tracking_params.get("artifact_location")
+             mlflow.create_experiment(self.experiment_name, artifact_location)
+             logger.info(
+                 f"Created new MLflow experiment: {self.experiment_name}")
+
+         mlflow.set_experiment(self.experiment_name)
+
+         if self.experiment_tags:
+             mlflow.set_experiment_tags(self.experiment_tags)
+             logger.info(
+                 f"Set tags for MLflow experiment '{self.experiment_name}': {self.experiment_tags}")
+
+     def run(self, run_name: str, description:str, func: Callable, nested_run: bool = False):
+         """Run the experiment with the given function."""
+         func_signature = inspect.signature(func)
+
+         if nested_run:
+             with mlflow.start_run(run_name=run_name, description= description, nested=True) as run:
+                 if "artifact_uri" in func_signature.parameters:
+                     artifact_uri = normalize_artifact_uri(run.info.artifact_uri) if run.info.artifact_uri else ""
+                     return func(artifact_uri=artifact_uri)
+                 else:
+                     return func()
+         else:
+             with mlflow.start_run(run_name=run_name, description= description) as run:
+                 if "artifact_uri" in func_signature.parameters:
+                     artifact_uri = normalize_artifact_uri(run.info.artifact_uri) if run.info.artifact_uri else ""
+                     return func(artifact_uri=artifact_uri)
+                 else:
+                     return func()
+
+     def log_text(self, content: str, file_name: str):
+         mlflow.log_text(
+             content, artifact_file=file_name)
+
+     def log_param(self, key: str, value: Any):
+         mlflow.log_param(key, value)
+
+     def log_params(self, params: Dict[str, Any]):
+         """Log parameters to MLflow."""
+         mlflow.log_params(params)
+
+     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+         """Log metrics to MLflow."""
+         mlflow.log_metrics(metrics, step=step)
+
+     def log_artifact(self, artifact_path: str):
+         """Log an artifact to MLflow."""
+         mlflow.log_artifact(artifact_path)
+
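
setup_experiment reads three keys from tracking_params: "tracking_uri" (required), plus optional "artifact_location" and "experiment_tags". A minimal sketch of the expected shape (URI, path, and tag values are invented):

    from ddi_fw.ml.tracking_service import MLFlowTracking

    tracking_params = {
        "tracking_uri": "http://localhost:5000",        # hypothetical server
        "artifact_location": "/data/mlflow-artifacts",  # hypothetical path
        "experiment_tags": {"project": "ddi-fw"},
    }

    backend = MLFlowTracking("ddi-experiment", tracking_params)
    backend.setup_experiment()   # creates the experiment if missing, then sets tags
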
+
+ class TrackingService:
+     def __init__(self, experiment_name: str, backend: str = "mlflow", tracking_params: Optional[Dict[str, Any]] = None):
+         """
+         Initialize the TrackingService.
+
+         Args:
+             backend (str): The tracking backend to use (e.g., "mlflow").
+             tracking_params (dict, optional): Parameters for the tracking backend.
+         """
+         self.experiment_name = experiment_name
+         self.backend = backend.lower()
+         self.tracking_params = tracking_params or {}
+         self.tracking_instance = self._initialize_backend()
+
+     def _initialize_backend(self) -> Tracking:
+         """Initialize the appropriate tracking backend."""
+         if self.backend == "mlflow":
+             return MLFlowTracking(self.experiment_name, self.tracking_params)
+         else:
+             raise ValueError(f"Unsupported tracking backend: {self.backend}")
+
+     def setup(self):
+         """Set up the experiment in the tracking backend."""
+         self.tracking_instance.setup_experiment()
+
+     def run(self, run_name: str, description:str ,func: Callable, nested_run: bool = False) -> Any:
+         """Run the experiment with the given function."""
+         return self.tracking_instance.run(run_name, description , func, nested_run=nested_run)
+
+     def log_text(self, content: str, file_name: str):
+         self.tracking_instance.log_text(content, file_name)
+
+     def log_param(self, key: str, value: Any):
+         """Log a parameter to the tracking backend."""
+         self.tracking_instance.log_param(key, value)
+
+     def log_params(self, params: Dict[str, Any]):
+         """Log parameters to the tracking backend."""
+         self.tracking_instance.log_params(params)
+
+     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+         """Log metrics to the tracking backend."""
+         self.tracking_instance.log_metrics(metrics, step=step)
+
+     def log_artifact(self, artifact_path: str):
+         """Log an artifact to the tracking backend."""
+         self.tracking_instance.log_artifact(artifact_path)
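
Taken together with the TFModelWrapper changes in the first file, the intended wiring looks roughly like this (experiment name, URI, and the training function are placeholders; the TFModelWrapper call is sketched from its new signature):

    from ddi_fw.ml.tracking_service import TrackingService

    service = TrackingService(
        experiment_name="ddi-experiment",
        backend="mlflow",
        tracking_params={"tracking_uri": "http://localhost:5000"},
    )
    service.setup()

    def training_step():
        return {"accuracy": 0.91}       # placeholder training logic

    logs = service.run(run_name="Training", description="Training",
                       func=training_step, nested_run=False)
    service.log_metrics(logs)

    # The same service is injected where use_mlflow used to be passed, e.g.:
    #   wrapper = TFModelWrapper(date, descriptor, model_func,
    #                            tracking_service=service, batch_size=128)
    #   logs, metrics, pred = wrapper.fit_and_evaluate()
    # tracking_service=None trains and evaluates with no MLflow dependency.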