ddi-fw 0.0.217__py3-none-any.whl → 0.0.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddi_fw/datasets/core.py +1 -0
- ddi_fw/datasets/ddi_mdl/base.py +24 -8
- ddi_fw/datasets/mdf_sa_ddi/base.py +266 -55
- ddi_fw/ml/__init__.py +2 -1
- ddi_fw/ml/ml_helper.py +26 -30
- ddi_fw/ml/model_wrapper.py +0 -1
- ddi_fw/ml/tensorflow_wrapper.py +165 -89
- ddi_fw/ml/tracking_service.py +194 -0
- ddi_fw/pipeline/multi_pipeline.py +52 -32
- ddi_fw/pipeline/{multi_pipeline_v2.py → multi_pipeline_org.py} +25 -48
- ddi_fw/pipeline/pipeline.py +38 -96
- ddi_fw/utils/utils.py +51 -51
- {ddi_fw-0.0.217.dist-info → ddi_fw-0.0.219.dist-info}/METADATA +1 -1
- {ddi_fw-0.0.217.dist-info → ddi_fw-0.0.219.dist-info}/RECORD +16 -15
- {ddi_fw-0.0.217.dist-info → ddi_fw-0.0.219.dist-info}/WHEEL +0 -0
- {ddi_fw-0.0.217.dist-info → ddi_fw-0.0.219.dist-info}/top_level.txt +0 -0
ddi_fw/ml/tensorflow_wrapper.py
CHANGED
@@ -7,18 +7,16 @@ from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback
|
|
7
7
|
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
|
8
8
|
import numpy as np
|
9
9
|
from tensorflow.keras import Model
|
10
|
-
import mlflow
|
11
|
-
from mlflow.utils.autologging_utils import batch_metrics_logger
|
12
|
-
|
13
|
-
from mlflow.models import infer_signature
|
14
10
|
from ddi_fw.ml.evaluation_helper import Metrics, evaluate
|
15
11
|
|
16
12
|
# import tf2onnx
|
17
13
|
# import onnx
|
18
14
|
|
15
|
+
from ddi_fw.ml.tracking_service import TrackingService
|
19
16
|
import ddi_fw.utils as utils
|
20
17
|
import os
|
21
18
|
|
19
|
+
|
22
20
|
def convert_to_categorical(arr, num_classes):
|
23
21
|
"""
|
24
22
|
This function takes an array of labels and converts them to one-hot encoding
|
@@ -33,7 +31,7 @@ def convert_to_categorical(arr, num_classes):
|
|
33
31
|
- The one-hot encoded array if the original array was binary or label encoded
|
34
32
|
- The original array if it doesn't require any conversion
|
35
33
|
"""
|
36
|
-
|
34
|
+
|
37
35
|
try:
|
38
36
|
# First, check if the array is binary-encoded
|
39
37
|
if not utils.is_binary_encoded(arr):
|
@@ -45,7 +43,7 @@ def convert_to_categorical(arr, num_classes):
|
|
45
43
|
except Exception as e:
|
46
44
|
# If binary encoding check raises an error, print it and continue to label encoding check
|
47
45
|
print(f"Error while checking binary encoding: {e}")
|
48
|
-
|
46
|
+
|
49
47
|
try:
|
50
48
|
# Check if the array is label-encoded
|
51
49
|
if utils.is_label_encoded(arr):
|
@@ -56,21 +54,21 @@ def convert_to_categorical(arr, num_classes):
|
|
56
54
|
print(f"Error while checking label encoding: {e}")
|
57
55
|
# If the arr labels don't match any of the known encodings, raise an error
|
58
56
|
raise ValueError("Unknown label encoding format.")
|
59
|
-
|
57
|
+
|
60
58
|
# If no conversion was needed, return the original array
|
61
|
-
|
59
|
+
|
62
60
|
return arr
|
63
61
|
|
64
62
|
|
65
63
|
class TFModelWrapper(ModelWrapper):
|
66
64
|
|
67
|
-
def __init__(self, date, descriptor, model_func,
|
65
|
+
def __init__(self, date, descriptor, model_func, tracking_service: TrackingService | None = None, **kwargs):
|
68
66
|
super().__init__(date, descriptor, model_func, **kwargs)
|
69
67
|
self.batch_size = kwargs.get('batch_size', 128)
|
70
68
|
self.epochs = kwargs.get('epochs', 100)
|
71
|
-
self.
|
69
|
+
self.tracking_service = tracking_service
|
72
70
|
|
73
|
-
# TODO think different settings for num_classes
|
71
|
+
# TODO think different settings for num_classes
|
74
72
|
def fit_model(self, X_train, y_train, X_valid, y_valid):
|
75
73
|
self.kwargs['input_shape'] = self.train_data.shape
|
76
74
|
self.num_classes = len(np.unique(y_train, axis=0))
|
@@ -85,7 +83,11 @@ class TFModelWrapper(ModelWrapper):
|
|
85
83
|
)
|
86
84
|
early_stopping = EarlyStopping(
|
87
85
|
monitor='val_loss', patience=10, mode='auto')
|
88
|
-
|
86
|
+
if self.tracking_service:
|
87
|
+
custom_callback = CustomCallback(self.tracking_service)
|
88
|
+
callbacks=[early_stopping, checkpoint, custom_callback]
|
89
|
+
else:
|
90
|
+
callbacks=[early_stopping, checkpoint]
|
89
91
|
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
|
90
92
|
train_dataset = train_dataset.batch(batch_size=self.batch_size)
|
91
93
|
|
@@ -100,7 +102,7 @@ class TFModelWrapper(ModelWrapper):
|
|
100
102
|
train_dataset,
|
101
103
|
epochs=self.epochs,
|
102
104
|
# validation_data=val_dataset,
|
103
|
-
callbacks=
|
105
|
+
callbacks=callbacks
|
104
106
|
)
|
105
107
|
# ex
|
106
108
|
# history = model.fit(
|
@@ -123,38 +125,51 @@ class TFModelWrapper(ModelWrapper):
|
|
123
125
|
if self.train_idx_arr and self.val_idx_arr:
|
124
126
|
for i, (train_idx, val_idx) in enumerate(zip(self.train_idx_arr, self.val_idx_arr)):
|
125
127
|
print(f"Validation {i}")
|
128
|
+
X_train_cv = self.train_data[train_idx]
|
129
|
+
y_train_cv = self.train_label[train_idx]
|
130
|
+
X_valid_cv = self.train_data[val_idx]
|
131
|
+
y_valid_cv = self.train_label[val_idx]
|
126
132
|
|
127
|
-
|
128
|
-
with mlflow.start_run(run_name=f'Validation {i}', description='CV models', nested=True) as cv_fit:
|
129
|
-
X_train_cv = self.train_data[train_idx]
|
130
|
-
y_train_cv = self.train_label[train_idx]
|
131
|
-
X_valid_cv = self.train_data[val_idx]
|
132
|
-
y_valid_cv = self.train_label[val_idx]
|
133
|
-
model, checkpoint = self.fit_model(
|
134
|
-
X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
|
135
|
-
models[f'{self.descriptor}_validation_{i}'] = model
|
136
|
-
models_val_acc[f'{self.descriptor}_validation_{i}'] = checkpoint.best
|
137
|
-
else:
|
138
|
-
X_train_cv = self.train_data[train_idx]
|
139
|
-
y_train_cv = self.train_label[train_idx]
|
140
|
-
X_valid_cv = self.train_data[val_idx]
|
141
|
-
y_valid_cv = self.train_label[val_idx]
|
133
|
+
def fit_model_cv_func():
|
142
134
|
model, checkpoint = self.fit_model(
|
143
135
|
X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
|
144
|
-
|
145
|
-
|
136
|
+
return model, checkpoint
|
137
|
+
|
138
|
+
if self.tracking_service:
|
139
|
+
model, checkpoint = self.tracking_service.run(
|
140
|
+
run_name=f'Validation {i}', description='CV models', nested_run=True, func=fit_model_cv_func)
|
141
|
+
# with mlflow.start_run(run_name=f'Validation {i}', description='CV models', nested=True) as cv_fit:
|
142
|
+
|
143
|
+
# model, checkpoint = self.fit_model(
|
144
|
+
# X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
|
145
|
+
# models[f'{self.descriptor}_validation_{i}'] = model
|
146
|
+
# models_val_acc[f'{self.descriptor}_validation_{i}'] = checkpoint.best
|
147
|
+
else:
|
148
|
+
model, checkpoint = fit_model_cv_func()
|
149
|
+
# model, checkpoint = self.fit_model(
|
150
|
+
# X_train_cv, y_train_cv, X_valid_cv, y_valid_cv)
|
151
|
+
models[f'{self.descriptor}_validation_{i}'] = model
|
152
|
+
models_val_acc[f'{self.descriptor}_validation_{i}'] = checkpoint.best
|
146
153
|
else:
|
147
|
-
|
148
|
-
with mlflow.start_run(run_name=f'Training', description='Training', nested=True) as cv_fit:
|
149
|
-
model, checkpoint = self.fit_model(
|
150
|
-
self.train_data, self.train_label, None, None)
|
151
|
-
models[self.descriptor] = model
|
152
|
-
models_val_acc[self.descriptor] = checkpoint.best
|
153
|
-
else:
|
154
|
+
def fit_model_func():
|
154
155
|
model, checkpoint = self.fit_model(
|
155
156
|
self.train_data, self.train_label, None, None)
|
156
|
-
|
157
|
-
|
157
|
+
return model, checkpoint
|
158
|
+
|
159
|
+
if self.tracking_service:
|
160
|
+
model, checkpoint = self.tracking_service.run(
|
161
|
+
run_name=f'Training', description='Training', nested_run=True, func=fit_model_func)
|
162
|
+
# with mlflow.start_run(run_name=f'Training', description='Training', nested=True) as cv_fit:
|
163
|
+
# model, checkpoint = self.fit_model(
|
164
|
+
# self.train_data, self.train_label, None, None)
|
165
|
+
# models[self.descriptor] = model
|
166
|
+
# models_val_acc[self.descriptor] = checkpoint.best
|
167
|
+
else:
|
168
|
+
model, checkpoint = fit_model_func()
|
169
|
+
# models[self.descriptor] = model
|
170
|
+
# models_val_acc[self.descriptor] = checkpoint.best
|
171
|
+
models[self.descriptor] = model
|
172
|
+
models_val_acc[self.descriptor] = checkpoint.best
|
158
173
|
if models_val_acc == {}:
|
159
174
|
return model, None
|
160
175
|
best_model_key = max(models_val_acc, key=lambda k: models_val_acc[k])
|
@@ -172,62 +187,113 @@ class TFModelWrapper(ModelWrapper):
|
|
172
187
|
pred = self.best_model.predict(test_dataset)
|
173
188
|
return pred
|
174
189
|
|
190
|
+
# def fit_and_evaluate(self, print_detail=False) -> tuple[dict[str, Any], Metrics, Any]:
|
191
|
+
# if self.use_mlflow:
|
192
|
+
# with mlflow.start_run(run_name=self.descriptor, description="***", nested=True) as run:
|
193
|
+
# best_model, best_model_key = self.fit()
|
194
|
+
# self.best_model: Model = best_model
|
195
|
+
# pred = self.predict()
|
196
|
+
# actual = self.test_label
|
197
|
+
# # if not utils.is_binary_encoded(pred):
|
198
|
+
# # pred = tf.keras.utils.to_categorical(np.argmax(pred,axis=1), num_classes=self.num_classes)
|
199
|
+
# pred_as_cat = convert_to_categorical(pred, self.num_classes)
|
200
|
+
# actual_as_cat = convert_to_categorical(
|
201
|
+
# actual, self.num_classes)
|
202
|
+
|
203
|
+
# logs, metrics = evaluate(
|
204
|
+
# actual=actual_as_cat, pred=pred_as_cat, info=self.descriptor, print_detail=print_detail)
|
205
|
+
# metrics.format_float()
|
206
|
+
# mlflow.log_metrics(logs)
|
207
|
+
# mlflow.log_param('best_cv', best_model_key)
|
208
|
+
# utils.compress_and_save_data(
|
209
|
+
# metrics.__dict__, run.info.artifact_uri, f'{self.date}_metrics.gzip')
|
210
|
+
# mlflow.log_artifact(
|
211
|
+
# f'{run.info.artifact_uri}/{self.date}_metrics.gzip')
|
212
|
+
|
213
|
+
# return logs, metrics, pred
|
214
|
+
# else:
|
215
|
+
# best_model, best_model_key = self.fit()
|
216
|
+
# self.best_model = best_model
|
217
|
+
# pred = self.predict()
|
218
|
+
# actual = self.test_label
|
219
|
+
|
220
|
+
# pred_as_cat = convert_to_categorical(pred, self.num_classes)
|
221
|
+
# actual_as_cat = convert_to_categorical(actual, self.num_classes)
|
222
|
+
# logs, metrics = evaluate(
|
223
|
+
# actual=actual_as_cat, pred=pred_as_cat, info=self.descriptor)
|
224
|
+
# metrics.format_float()
|
225
|
+
# return logs, metrics, pred
|
226
|
+
|
175
227
|
def fit_and_evaluate(self, print_detail=False) -> tuple[dict[str, Any], Metrics, Any]:
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
logs, metrics = evaluate(
|
190
|
-
actual=actual_as_cat, pred=pred_as_cat, info=self.descriptor, print_detail=print_detail)
|
191
|
-
metrics.format_float()
|
192
|
-
mlflow.log_metrics(logs)
|
193
|
-
mlflow.log_param('best_cv', best_model_key)
|
194
|
-
utils.compress_and_save_data(
|
195
|
-
metrics.__dict__, run.info.artifact_uri, f'{self.date}_metrics.gzip')
|
196
|
-
mlflow.log_artifact(
|
197
|
-
f'{run.info.artifact_uri}/{self.date}_metrics.gzip')
|
198
|
-
|
199
|
-
return logs, metrics, pred
|
200
|
-
else:
|
228
|
+
"""
|
229
|
+
Fit the model, evaluate it, and log results using the tracking service.
|
230
|
+
|
231
|
+
Args:
|
232
|
+
print_detail (bool): Whether to print detailed evaluation logs.
|
233
|
+
|
234
|
+
Returns:
|
235
|
+
tuple: A tuple containing logs, metrics, and predictions.
|
236
|
+
"""
|
237
|
+
self.best_model: Model = None
|
238
|
+
def evaluate_and_log(artifact_uri=None):
|
239
|
+
# Fit the model
|
201
240
|
best_model, best_model_key = self.fit()
|
202
|
-
print(best_model_key)
|
203
241
|
self.best_model = best_model
|
242
|
+
|
243
|
+
# Make predictions
|
204
244
|
pred = self.predict()
|
205
245
|
actual = self.test_label
|
206
|
-
|
207
|
-
#
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
246
|
+
|
247
|
+
# Convert predictions and actual labels to categorical format
|
248
|
+
pred_as_cat = convert_to_categorical(pred, self.num_classes)
|
249
|
+
actual_as_cat = convert_to_categorical(actual, self.num_classes)
|
250
|
+
|
251
|
+
# Evaluate the model
|
212
252
|
logs, metrics = evaluate(
|
213
|
-
actual=
|
253
|
+
actual=actual_as_cat, pred=pred_as_cat, info=self.descriptor, print_detail=print_detail
|
254
|
+
)
|
214
255
|
metrics.format_float()
|
215
|
-
return logs, metrics, pred
|
216
256
|
|
257
|
+
if self.tracking_service:
|
258
|
+
# Log metrics and parameters
|
259
|
+
self.tracking_service.log_metrics(logs)
|
260
|
+
self.tracking_service.log_param('best_cv', best_model_key)
|
217
261
|
|
218
|
-
|
219
|
-
|
220
|
-
|
262
|
+
# Save metrics to the artifact URI if provided
|
263
|
+
if artifact_uri:
|
264
|
+
utils.compress_and_save_data(
|
265
|
+
metrics.__dict__, artifact_uri, f'{self.date}_metrics.gzip'
|
266
|
+
)
|
267
|
+
self.tracking_service.log_artifact(
|
268
|
+
f'{artifact_uri}/{self.date}_metrics.gzip'
|
269
|
+
)
|
270
|
+
|
271
|
+
return logs, metrics, pred
|
272
|
+
|
273
|
+
# Use the tracking service to run the evaluation
|
274
|
+
if self.tracking_service:
|
275
|
+
return self.tracking_service.run(
|
276
|
+
run_name=self.descriptor,
|
277
|
+
description="Fit and evaluate the model",
|
278
|
+
nested_run=True,
|
279
|
+
func=evaluate_and_log
|
280
|
+
)
|
281
|
+
else:
|
282
|
+
# If no tracking service is provided, run the evaluation directly
|
283
|
+
return evaluate_and_log()
|
221
284
|
|
222
285
|
|
223
286
|
class CustomCallback(Callback):
|
224
|
-
|
287
|
+
"""
|
288
|
+
Custom Keras callback for logging training metrics and model summary to MLflow.
|
289
|
+
"""
|
290
|
+
def __init__(self, tracking_service: TrackingService):
|
225
291
|
super().__init__()
|
226
|
-
self.
|
292
|
+
self.tracking_service = tracking_service
|
227
293
|
|
228
|
-
def _mlflow_log(self, func: Callable):
|
229
|
-
|
230
|
-
|
294
|
+
# def _mlflow_log(self, func: Callable):
|
295
|
+
# if self.use_mlflow:
|
296
|
+
# func()
|
231
297
|
|
232
298
|
def on_train_begin(self, logs=None):
|
233
299
|
if logs is None:
|
@@ -236,24 +302,32 @@ class CustomCallback(Callback):
|
|
236
302
|
raise TypeError("self.model must be an instance of Model")
|
237
303
|
|
238
304
|
keys = list(logs.keys())
|
239
|
-
|
305
|
+
|
306
|
+
self.tracking_service.log_param("train_begin_keys", keys)
|
307
|
+
# self._mlflow_log(lambda: mlflow.log_param("train_begin_keys", keys))
|
308
|
+
|
240
309
|
# config = self.model.optimizer.get_config()
|
241
310
|
config = self.model.get_config()
|
242
311
|
for attribute in config:
|
243
|
-
self.
|
244
|
-
"opt_" + attribute, config[attribute])
|
312
|
+
self.tracking_service.log_param(
|
313
|
+
"opt_" + attribute, config[attribute])
|
314
|
+
# self._mlflow_log(lambda: mlflow.log_param(
|
315
|
+
# "opt_" + attribute, config[attribute]))
|
245
316
|
|
246
317
|
sum_list = []
|
247
318
|
self.model.summary(print_fn=sum_list.append)
|
248
319
|
summary = "\n".join(sum_list)
|
249
|
-
self.
|
250
|
-
summary,
|
320
|
+
self.tracking_service.log_text(
|
321
|
+
summary, file_name="model_summary.txt")
|
322
|
+
# self._mlflow_log(lambda: mlflow.log_text(
|
323
|
+
# summary, artifact_file="model_summary.txt"))
|
251
324
|
|
252
325
|
def on_train_end(self, logs=None):
|
253
326
|
if logs is None:
|
254
327
|
logs = {}
|
255
328
|
print(logs)
|
256
|
-
self.
|
329
|
+
self.tracking_service.log_metrics(logs)
|
330
|
+
# self._mlflow_log(lambda: mlflow.log_metrics(logs))
|
257
331
|
|
258
332
|
def on_epoch_begin(self, epoch, logs=None):
|
259
333
|
if logs is None:
|
@@ -273,7 +347,8 @@ class CustomCallback(Callback):
|
|
273
347
|
def on_test_end(self, logs=None):
|
274
348
|
if logs is None:
|
275
349
|
logs = {}
|
276
|
-
self.
|
350
|
+
self.tracking_service.log_metrics(logs)
|
351
|
+
# self._mlflow_log(lambda: mlflow.log_metrics(logs))
|
277
352
|
print(logs)
|
278
353
|
|
279
354
|
def on_predict_begin(self, logs=None):
|
@@ -285,7 +360,8 @@ class CustomCallback(Callback):
|
|
285
360
|
if logs is None:
|
286
361
|
logs = {}
|
287
362
|
keys = list(logs.keys())
|
288
|
-
self.
|
363
|
+
self.tracking_service.log_metrics(logs)
|
364
|
+
# self._mlflow_log(lambda: mlflow.log_metrics(logs))
|
289
365
|
|
290
366
|
def on_train_batch_begin(self, batch, logs=None):
|
291
367
|
if logs is None:
|
@@ -0,0 +1,194 @@
|
|
1
|
+
import inspect
|
2
|
+
import os
|
3
|
+
from typing import Optional, Dict, Any
|
4
|
+
import logging
|
5
|
+
from urllib.parse import urlparse
|
6
|
+
import mlflow
|
7
|
+
from abc import ABC, abstractmethod
|
8
|
+
from typing import Callable, Optional, Dict, Any
|
9
|
+
|
10
|
+
|
11
|
+
def normalize_artifact_uri(artifact_uri: str) -> str:
    """
    Convert a local ``file:`` artifact URI into a filesystem path.

    Local artifact URIs look like ``file:///home/user/mlruns/0/<run>/artifacts``
    (POSIX) or ``file:///C:/mlruns/0/<run>/artifacts`` (Windows). Such URIs are
    turned into absolute local paths. Non-``file`` URIs (``s3://...``,
    ``mlflow-artifacts:/...``, ...) and ``file`` URIs naming a remote host are
    returned unchanged.

    Args:
        artifact_uri (str): The artifact URI to normalize.

    Returns:
        str: The absolute local path for a local ``file`` URI, otherwise the
        input unchanged.
    """
    from urllib.request import url2pathname

    parsed_uri = urlparse(artifact_uri)
    # Only local file URIs can be mapped onto this machine's filesystem.
    if parsed_uri.scheme != "file" or parsed_uri.netloc not in ("", "localhost"):
        return artifact_uri
    # url2pathname decodes percent-escapes and, on Windows, maps "/C:/x" to
    # "C:\\x". Stripping the leading "/" instead would make POSIX absolute
    # paths relative to the current working directory.
    return os.path.abspath(url2pathname(parsed_uri.path))
|
25
|
+
|
26
|
+
class Tracking(ABC):
    """
    Abstract interface for experiment-tracking backends.

    Concrete backends (e.g. ``MLFlowTracking``) implement experiment setup,
    run execution, and logging of text, parameters, metrics and artifacts.
    """

    def __init__(self, experiment_name: str, tracking_params: Optional[Dict[str, Any]] = None):
        """
        Initialize the tracking backend.

        Args:
            experiment_name (str): The name of the experiment.
            tracking_params (dict, optional): Backend-specific configuration,
                stored as ``self.tracking_params``. ``None`` becomes ``{}``.
        """
        self.experiment_name = experiment_name
        self.tracking_params = tracking_params or {}

    @abstractmethod
    def setup_experiment(self) -> None:
        """Set up the experiment in the tracking backend."""

    @abstractmethod
    def run(self, run_name: str, description: str, func: Callable, nested_run: bool = False) -> Any:
        """
        Execute ``func`` inside a tracked run and return its result.

        Args:
            run_name (str): Display name of the run.
            description (str): Free-text description of the run.
            func (Callable): The work to perform while the run is active.
            nested_run (bool): Whether the run should be nested inside the
                currently active run.
        """

    @abstractmethod
    def log_text(self, content: str, file_name: str) -> None:
        """Log ``content`` as a text artifact named ``file_name``."""

    @abstractmethod
    def log_param(self, key: str, value: Any) -> None:
        """Log a single parameter to the tracking backend."""

    @abstractmethod
    def log_params(self, params: Dict[str, Any]) -> None:
        """Log several parameters at once to the tracking backend."""

    @abstractmethod
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Log metrics (optionally at a given ``step``) to the tracking backend."""

    @abstractmethod
    def log_artifact(self, artifact_path: str) -> None:
        """Log the local file at ``artifact_path`` as an artifact."""
|
70
|
+
|
71
|
+
|
72
|
+
# Module-level logger; MLFlowTracking.setup_experiment reports experiment
# creation and experiment-tag updates through it.
logger = logging.getLogger(__name__)
|
73
|
+
|
74
|
+
|
75
|
+
class MLFlowTracking(Tracking):
    """
    MLflow implementation of the ``Tracking`` interface.

    ``tracking_params`` keys read by this class:
        - ``tracking_uri``: required by :meth:`setup_experiment`.
        - ``artifact_location``: used only when the experiment is created.
        - ``experiment_tags``: dict of tags applied to the experiment.
    """

    def __init__(self, experiment_name: str, tracking_params: Optional[Dict[str, Any]] = None):
        """
        Initialize the MLFlowTracking backend.

        Args:
            experiment_name (str): The name of the experiment.
            tracking_params (dict, optional): Parameters for MLflow tracking.
        """
        super().__init__(experiment_name, tracking_params)
        # Always define the attribute. Previously it was only set when
        # tracking_params was non-empty, so setup_experiment relied on a
        # conditionally-defined attribute. An explicit None means "no tags".
        self.experiment_tags: Dict[str, Any] = self.tracking_params.get("experiment_tags") or {}

    def setup_experiment(self) -> None:
        """
        Point MLflow at the configured tracking URI and activate the experiment.

        Creates the experiment (optionally at ``artifact_location``) if it does
        not exist, makes it the active experiment, and applies
        ``experiment_tags`` when any were given.

        Raises:
            ValueError: If ``tracking_uri`` is missing from ``tracking_params``.
        """
        tracking_uri = self.tracking_params.get("tracking_uri")
        if not tracking_uri:
            raise ValueError("Tracking URI must be specified for MLflow.")

        mlflow.set_tracking_uri(tracking_uri)

        if mlflow.get_experiment_by_name(self.experiment_name) is None:
            artifact_location = self.tracking_params.get("artifact_location")
            mlflow.create_experiment(self.experiment_name, artifact_location)
            logger.info("Created new MLflow experiment: %s", self.experiment_name)

        mlflow.set_experiment(self.experiment_name)

        if self.experiment_tags:
            mlflow.set_experiment_tags(self.experiment_tags)
            logger.info("Set tags for MLflow experiment '%s': %s",
                        self.experiment_name, self.experiment_tags)

    def run(self, run_name: str, description: str, func: Callable, nested_run: bool = False) -> Any:
        """
        Execute ``func`` inside an MLflow run and return its result.

        If ``func`` declares an ``artifact_uri`` parameter, it is called with
        the run's artifact URI passed through ``normalize_artifact_uri``. An
        empty string is passed if MLflow reports no artifact URI.

        Args:
            run_name (str): Name of the MLflow run.
            description (str): Description attached to the run.
            func (Callable): The work to execute while the run is active.
            nested_run (bool): Start the run nested under the active run.

        Returns:
            Whatever ``func`` returns.
        """
        wants_artifact_uri = "artifact_uri" in inspect.signature(func).parameters

        # The nested and non-nested cases differed only in this flag;
        # mlflow.start_run defaults to nested=False.
        with mlflow.start_run(run_name=run_name, description=description, nested=nested_run) as run:
            if wants_artifact_uri:
                artifact_uri = normalize_artifact_uri(run.info.artifact_uri) if run.info.artifact_uri else ""
                return func(artifact_uri=artifact_uri)
            return func()

    def log_text(self, content: str, file_name: str) -> None:
        """Log ``content`` to MLflow as a text artifact named ``file_name``."""
        mlflow.log_text(content, artifact_file=file_name)

    def log_param(self, key: str, value: Any) -> None:
        """Log a single parameter to MLflow."""
        mlflow.log_param(key, value)

    def log_params(self, params: Dict[str, Any]) -> None:
        """Log parameters to MLflow."""
        mlflow.log_params(params)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Log metrics to MLflow, optionally at ``step``."""
        mlflow.log_metrics(metrics, step=step)

    def log_artifact(self, artifact_path: str) -> None:
        """Log the local file at ``artifact_path`` to MLflow as an artifact."""
        mlflow.log_artifact(artifact_path)
|
146
|
+
|
147
|
+
|
148
|
+
class TrackingService:
    """
    Backend-agnostic facade over a ``Tracking`` implementation.

    The backend instance is created once in ``__init__``. Every public method
    forwards to it unchanged.
    """

    def __init__(self, experiment_name: str, backend: str = "mlflow", tracking_params: Optional[Dict[str, Any]] = None):
        """
        Initialize the TrackingService.

        Args:
            experiment_name (str): Name of the experiment runs are tracked under.
            backend (str): The tracking backend to use, matched
                case-insensitively. Only ``"mlflow"`` is currently supported.
            tracking_params (dict, optional): Parameters for the tracking backend.

        Raises:
            ValueError: If ``backend`` is not a supported backend.
        """
        self.experiment_name = experiment_name
        self.backend = backend.lower()
        self.tracking_params = tracking_params or {}
        self.tracking_instance = self._initialize_backend()

    def _initialize_backend(self) -> "Tracking":
        """Instantiate the tracking backend named by ``self.backend``."""
        if self.backend == "mlflow":
            return MLFlowTracking(self.experiment_name, self.tracking_params)
        raise ValueError(f"Unsupported tracking backend: {self.backend}")

    def setup(self) -> None:
        """Set up the experiment in the tracking backend."""
        self.tracking_instance.setup_experiment()

    def run(self, run_name: str, description: str, func: Callable, nested_run: bool = False) -> Any:
        """Run ``func`` inside a tracked run and return its result."""
        return self.tracking_instance.run(run_name, description, func, nested_run=nested_run)

    def log_text(self, content: str, file_name: str) -> None:
        """Log ``content`` as a text artifact named ``file_name``."""
        self.tracking_instance.log_text(content, file_name)

    def log_param(self, key: str, value: Any) -> None:
        """Log a parameter to the tracking backend."""
        self.tracking_instance.log_param(key, value)

    def log_params(self, params: Dict[str, Any]) -> None:
        """Log parameters to the tracking backend."""
        self.tracking_instance.log_params(params)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Log metrics to the tracking backend."""
        self.tracking_instance.log_metrics(metrics, step=step)

    def log_artifact(self, artifact_path: str) -> None:
        """Log an artifact to the tracking backend."""
        self.tracking_instance.log_artifact(artifact_path)
|