qadence 1.7.8__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- qadence/__init__.py +1 -1
- qadence/analog/device.py +1 -1
- qadence/analog/parse_analog.py +1 -2
- qadence/backend.py +3 -3
- qadence/backends/gpsr.py +8 -2
- qadence/backends/horqrux/backend.py +3 -3
- qadence/backends/pulser/backend.py +21 -38
- qadence/backends/pulser/convert_ops.py +2 -2
- qadence/backends/pyqtorch/backend.py +85 -10
- qadence/backends/pyqtorch/config.py +10 -3
- qadence/backends/pyqtorch/convert_ops.py +245 -233
- qadence/backends/utils.py +9 -1
- qadence/blocks/abstract.py +1 -1
- qadence/blocks/embedding.py +21 -11
- qadence/blocks/matrix.py +3 -1
- qadence/blocks/primitive.py +37 -11
- qadence/circuit.py +1 -1
- qadence/constructors/__init__.py +2 -1
- qadence/constructors/ansatze.py +176 -0
- qadence/engines/differentiable_backend.py +3 -3
- qadence/engines/jax/differentiable_backend.py +2 -2
- qadence/engines/jax/differentiable_expectation.py +2 -2
- qadence/engines/torch/differentiable_backend.py +2 -2
- qadence/engines/torch/differentiable_expectation.py +2 -2
- qadence/execution.py +14 -16
- qadence/extensions.py +1 -1
- qadence/log_config.yaml +10 -0
- qadence/measurements/shadow.py +101 -133
- qadence/measurements/tomography.py +2 -2
- qadence/measurements/utils.py +4 -4
- qadence/mitigations/analog_zne.py +8 -7
- qadence/mitigations/protocols.py +2 -2
- qadence/mitigations/readout.py +14 -5
- qadence/ml_tools/__init__.py +4 -8
- qadence/ml_tools/callbacks/__init__.py +30 -0
- qadence/ml_tools/callbacks/callback.py +451 -0
- qadence/ml_tools/callbacks/callbackmanager.py +214 -0
- qadence/ml_tools/{saveload.py → callbacks/saveload.py} +11 -11
- qadence/ml_tools/callbacks/writer_registry.py +430 -0
- qadence/ml_tools/config.py +132 -258
- qadence/ml_tools/constructors.py +2 -2
- qadence/ml_tools/data.py +7 -3
- qadence/ml_tools/loss/__init__.py +10 -0
- qadence/ml_tools/loss/loss.py +87 -0
- qadence/ml_tools/models.py +7 -7
- qadence/ml_tools/optimize_step.py +45 -10
- qadence/ml_tools/stages.py +46 -0
- qadence/ml_tools/train_utils/__init__.py +7 -0
- qadence/ml_tools/train_utils/base_trainer.py +548 -0
- qadence/ml_tools/train_utils/config_manager.py +184 -0
- qadence/ml_tools/trainer.py +692 -0
- qadence/model.py +6 -6
- qadence/noise/__init__.py +2 -2
- qadence/noise/protocols.py +188 -36
- qadence/operations/control_ops.py +37 -22
- qadence/operations/ham_evo.py +88 -26
- qadence/operations/parametric.py +32 -10
- qadence/operations/primitive.py +61 -29
- qadence/overlap.py +0 -6
- qadence/parameters.py +3 -2
- qadence/transpile/__init__.py +2 -1
- qadence/transpile/noise.py +53 -0
- qadence/types.py +39 -3
- {qadence-1.7.8.dist-info → qadence-1.9.0.dist-info}/METADATA +5 -9
- {qadence-1.7.8.dist-info → qadence-1.9.0.dist-info}/RECORD +67 -63
- {qadence-1.7.8.dist-info → qadence-1.9.0.dist-info}/WHEEL +1 -1
- qadence/backends/braket/__init__.py +0 -4
- qadence/backends/braket/backend.py +0 -234
- qadence/backends/braket/config.py +0 -22
- qadence/backends/braket/convert_ops.py +0 -116
- qadence/ml_tools/printing.py +0 -153
- qadence/ml_tools/train_grad.py +0 -395
- qadence/ml_tools/train_no_grad.py +0 -199
- qadence/noise/readout.py +0 -218
- {qadence-1.7.8.dist-info → qadence-1.9.0.dist-info}/licenses/LICENSE +0 -0
qadence/ml_tools/{saveload.py → callbacks/saveload.py}

@@ -11,7 +11,7 @@ from nevergrad.optimization.base import Optimizer as NGOptimizer
 from torch.nn import Module
 from torch.optim import Optimizer
 
-logger = getLogger(__name__)
+logger = getLogger("ml_tools")
 
 
 def get_latest_checkpoint_name(folder: Path, type: str, device: str | torch.device = "cpu") -> Path:
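The switch from a module-level logger to the shared "ml_tools" logger means these messages are now routed through the logging configuration qadence ships in qadence/log_config.yaml (extended by 10 lines in this release, per the file list above). A user can therefore tune the verbosity of all ml_tools output in one place; a minimal sketch using only the standard library (the logger name comes from the diff above, everything else is illustrative):

```python
import logging

# Grab the shared logger used across qadence.ml_tools and silence
# everything below WARNING for this session.
logging.getLogger("ml_tools").setLevel(logging.WARNING)
```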
@@ -19,6 +19,7 @@ def get_latest_checkpoint_name(folder: Path, type: str, device: str | torch.devi
     files = [f for f in os.listdir(folder) if f.endswith(".pt") and type in f]
     if len(files) == 0:
         logger.error(f"Directory {folder} does not contain any {type} checkpoints.")
+        pass
     if len(files) == 1:
         file = Path(files[0])
     else:
@@ -66,8 +67,7 @@ def write_checkpoint(
     iteration: int | str,
 ) -> None:
     from qadence import QuantumModel
-
-    from .models import QNN
+    from qadence.ml_tools.models import QNN
 
     device = None
     try:
@@ -79,10 +79,8 @@ def write_checkpoint(
         )
         device = str(device).split(":")[0]  # in case of using several CUDA devices
     except Exception as e:
-        msg = (
-            f"Unable to identify in which device the QuantumModel is stored due to {e}. "
-            "Setting device to None"
-        )
+        msg = f"""Unable to identify in which device the QuantumModel is stored due to {e}.
+        Setting device to None"""
         logger.warning(msg)
 
     iteration_substring = f"{iteration:03n}" if isinstance(iteration, int) else iteration
@@ -135,7 +133,9 @@ def load_model(
     model_ckpt_name = get_latest_checkpoint_name(folder, "model", device)
 
     try:
-        iteration, model_dict = torch.load(folder / model_ckpt_name, *args, **kwargs)
+        iteration, model_dict = torch.load(
+            folder / model_ckpt_name, weights_only=False, *args, **kwargs
+        )
         if isinstance(model, (QuantumModel, QNN)):
            model.load_params_from_dict(model_dict)
         elif isinstance(model, Module):
@@ -146,8 +146,8 @@ def load_model(
         model.to(device)
 
     except Exception as e:
-        msg = f"Unable to load state dict due to {e}. \
-        No corresponding pre-trained model found."
+        msg = f"""Unable to load state dict due to {e}.
+        No corresponding pre-trained model found."""
         logger.warning(msg)
     return model, iteration
 
@@ -162,7 +162,7 @@ def load_optimizer(
     opt_ckpt_name = get_latest_checkpoint_name(folder, "opt", device)
     if os.path.isfile(folder / opt_ckpt_name):
         if isinstance(optimizer, Optimizer):
-            (_, OptType, optimizer_state) = torch.load(folder / opt_ckpt_name)
+            (_, OptType, optimizer_state) = torch.load(folder / opt_ckpt_name, weights_only=False)
             if isinstance(optimizer, OptType):
                 optimizer.load_state_dict(optimizer_state)
 
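Both torch.load changes above pin weights_only=False explicitly. PyTorch 2.6 flipped the default of weights_only to True, and these checkpoints are not plain state dicts: they are pickled tuples carrying the iteration counter, the model dict, and the optimizer class, which a weights-only unpickler would reject. A minimal sketch of the round trip these helpers rely on (file name and contents are illustrative, not the exact checkpoint layout):

```python
import torch
from torch.nn import Linear
from torch.optim import Adam

model = Linear(2, 1)
optimizer = Adam(model.parameters())

# Checkpoints here bundle arbitrary Python objects, not just tensors.
torch.save((10, type(optimizer), optimizer.state_dict()), "opt_ckpt.pt")

# Under PyTorch >= 2.6 the default weights_only=True refuses such pickles,
# so the loader must opt out explicitly.
iteration, OptType, opt_state = torch.load("opt_ckpt.pt", weights_only=False)
```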
qadence/ml_tools/callbacks/writer_registry.py (new file)

@@ -0,0 +1,430 @@
+from __future__ import annotations
+
+import os
+from abc import ABC, abstractmethod
+from logging import getLogger
+from types import ModuleType
+from typing import Any, Callable, Union
+from uuid import uuid4
+
+import mlflow
+from matplotlib.figure import Figure
+from mlflow.entities import Run
+from mlflow.models import infer_signature
+from torch import Tensor
+from torch.nn import Module
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+from qadence.ml_tools.config import TrainConfig
+from qadence.ml_tools.data import OptimizeResult
+from qadence.types import ExperimentTrackingTool
+
+logger = getLogger("ml_tools")
+
+# Type aliases
+PlottingFunction = Callable[[Module, int], tuple[str, Figure]]
+InputData = Union[Tensor, dict[str, Tensor]]
+
+
+class BaseWriter(ABC):
+    """
+    Abstract base class for experiment tracking writers.
+
+    Methods:
+        open(config, iteration=None): Opens the writer and sets up the logging
+            environment.
+        close(): Closes the writer and finalizes any ongoing logging processes.
+        print_metrics(result): Prints metrics and loss in a formatted manner.
+        write(result): Writes the optimization results to the tracking tool.
+        log_hyperparams(hyperparams): Logs the hyperparameters to the tracking tool.
+        plot(model, iteration, plotting_functions): Logs model plots using provided
+            plotting functions.
+        log_model(model, dataloader): Logs the model and any relevant information.
+    """
+
+    run: Run  # [attr-defined]
+
+    @abstractmethod
+    def open(self, config: TrainConfig, iteration: int | None = None) -> Any:
+        """
+        Opens the writer and prepares it for logging.
+
+        Args:
+            config: Configuration object containing settings for logging.
+            iteration (int, optional): The iteration step to start logging from.
+                Defaults to None.
+        """
+        raise NotImplementedError("Writers must implement an open method.")
+
+    @abstractmethod
+    def close(self) -> None:
+        """Closes the writer and finalizes logging."""
+        raise NotImplementedError("Writers must implement a close method.")
+
+    @abstractmethod
+    def write(self, result: OptimizeResult) -> None:
+        """
+        Logs the results of the current iteration.
+
+        Args:
+            result (OptimizeResult): The optimization results to log.
+        """
+        raise NotImplementedError("Writers must implement a write method.")
+
+    @abstractmethod
+    def log_hyperparams(self, hyperparams: dict) -> None:
+        """
+        Logs hyperparameters.
+
+        Args:
+            hyperparams (dict): A dictionary of hyperparameters to log.
+        """
+        raise NotImplementedError("Writers must implement a log_hyperparams method.")
+
+    @abstractmethod
+    def plot(
+        self,
+        model: Module,
+        iteration: int,
+        plotting_functions: tuple[PlottingFunction, ...],
+    ) -> None:
+        """
+        Logs plots of the model using provided plotting functions.
+
+        Args:
+            model (Module): The model to plot.
+            iteration (int): The current iteration number.
+            plotting_functions (tuple[PlottingFunction, ...]): Functions used to
+                generate plots.
+        """
+        raise NotImplementedError("Writers must implement a plot method.")
+
+    @abstractmethod
+    def log_model(
+        self,
+        model: Module,
+        train_dataloader: DataLoader | None = None,
+        val_dataloader: DataLoader | None = None,
+        test_dataloader: DataLoader | None = None,
+    ) -> None:
+        """
+        Logs the model and associated data.
+
+        Args:
+            model (Module): The model to log.
+            train_dataloader (DataLoader | None): DataLoader for training data.
+            val_dataloader (DataLoader | None): DataLoader for validation data.
+            test_dataloader (DataLoader | None): DataLoader for testing data.
+        """
+        raise NotImplementedError("Writers must implement a log_model method.")
+
+    def print_metrics(self, result: OptimizeResult) -> None:
+        """Prints the metrics and loss in a readable format.
+
+        Args:
+            result (OptimizeResult): The optimization results to display.
+        """
+
+        # Find the key in result.metrics that contains "loss" (case-insensitive)
+        loss_key = next((k for k in result.metrics if "loss" in k.lower()), None)
+        if loss_key:
+            loss_value = result.metrics[loss_key]
+            msg = f"Iteration {result.iteration: >7} | {loss_key.title()}: {loss_value:.7f} -"
+        else:
+            msg = f"Iteration {result.iteration: >7} | Loss: None -"
+        msg += " ".join([f"{k}: {v:.7f}" for k, v in result.metrics.items() if k != loss_key])
+        print(msg)
+
+
+class TensorBoardWriter(BaseWriter):
+    """Writer for logging to TensorBoard.
+
+    Attributes:
+        writer (SummaryWriter): The TensorBoard SummaryWriter instance.
+    """
+
+    def __init__(self) -> None:
+        self.writer = None
+
+    def open(self, config: TrainConfig, iteration: int | None = None) -> SummaryWriter:
+        """
+        Opens the TensorBoard writer.
+
+        Args:
+            config: Configuration object containing settings for logging.
+            iteration (int, optional): The iteration step to start logging from.
+                Defaults to None.
+
+        Returns:
+            SummaryWriter: The initialized TensorBoard writer.
+        """
+        log_dir = str(config.log_folder)
+        purge_step = iteration if isinstance(iteration, int) else None
+        self.writer = SummaryWriter(log_dir=log_dir, purge_step=purge_step)
+        return self.writer
+
+    def close(self) -> None:
+        """Closes the TensorBoard writer."""
+        if self.writer:
+            self.writer.close()
+
+    def write(self, result: OptimizeResult) -> None:
+        """
+        Logs the results of the current iteration to TensorBoard.
+
+        Args:
+            result (OptimizeResult): The optimization results to log.
+        """
+        # Not writing loss as loss is available in the metrics
+        # if result.loss is not None:
+        #     self.writer.add_scalar("loss", float(result.loss), result.iteration)
+        if self.writer:
+            for key, value in result.metrics.items():
+                self.writer.add_scalar(key, value, result.iteration)
+        else:
+            raise RuntimeError(
+                "The writer is not initialized."
+                "Please call the 'writer.open()' method before writing"
+            )
+
+    def log_hyperparams(self, hyperparams: dict) -> None:
+        """
+        Logs hyperparameters to TensorBoard.
+
+        Args:
+            hyperparams (dict): A dictionary of hyperparameters to log.
+        """
+        if self.writer:
+            self.writer.add_hparams(hyperparams, {})
+        else:
+            raise RuntimeError(
+                "The writer is not initialized."
+                "Please call the 'writer.open()' method before writing"
+            )
+
+    def plot(
+        self,
+        model: Module,
+        iteration: int,
+        plotting_functions: tuple[PlottingFunction, ...],
+    ) -> None:
+        """
+        Logs plots of the model using provided plotting functions.
+
+        Args:
+            model (Module): The model to plot.
+            iteration (int): The current iteration number.
+            plotting_functions (tuple[PlottingFunction, ...]): Functions used
+                to generate plots.
+        """
+        if self.writer:
+            for pf in plotting_functions:
+                descr, fig = pf(model, iteration)
+                self.writer.add_figure(descr, fig, global_step=iteration)
+        else:
+            raise RuntimeError(
+                "The writer is not initialized."
+                "Please call the 'writer.open()' method before writing"
+            )
+
+    def log_model(
+        self,
+        model: Module,
+        train_dataloader: DataLoader | None = None,
+        val_dataloader: DataLoader | None = None,
+        test_dataloader: DataLoader | None = None,
+    ) -> None:
+        """
+        Logs the model.
+
+        Currently not supported by TensorBoard.
+
+        Args:
+            model (Module): The model to log.
+            train_dataloader (DataLoader | None): DataLoader for training data.
+            val_dataloader (DataLoader | None): DataLoader for validation data.
+            test_dataloader (DataLoader | None): DataLoader for testing data.
+        """
+        logger.warning("Model logging is not supported by tensorboard. No model will be logged.")
+
+
+class MLFlowWriter(BaseWriter):
+    """
+    Writer for logging to MLflow.
+
+    Attributes:
+        run: The active MLflow run.
+        mlflow: The MLflow module.
+    """
+
+    def __init__(self) -> None:
+        self.run: Run
+        self.mlflow: ModuleType
+
+    def open(self, config: TrainConfig, iteration: int | None = None) -> ModuleType | None:
+        """
+        Opens the MLflow writer and initializes an MLflow run.
+
+        Args:
+            config: Configuration object containing settings for logging.
+            iteration (int, optional): The iteration step to start logging from.
+                Defaults to None.
+
+        Returns:
+            mlflow: The MLflow module instance.
+        """
+        self.mlflow = mlflow
+        tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "")
+        experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", str(uuid4()))
+        run_name = os.getenv("MLFLOW_RUN_NAME", str(uuid4()))
+
+        if self.mlflow:
+            self.mlflow.set_tracking_uri(tracking_uri)
+
+            # Create or get the experiment
+            exp_filter_string = f"name = '{experiment_name}'"
+            experiments = self.mlflow.search_experiments(filter_string=exp_filter_string)
+            if not experiments:
+                self.mlflow.create_experiment(name=experiment_name)
+
+            self.mlflow.set_experiment(experiment_name)
+            self.run = self.mlflow.start_run(run_name=run_name, nested=False)
+
+        return self.mlflow
+
+    def close(self) -> None:
+        """Closes the MLflow run."""
+        if self.run:
+            self.mlflow.end_run()
+
+    def write(self, result: OptimizeResult) -> None:
+        """
+        Logs the results of the current iteration to MLflow.
+
+        Args:
+            result (OptimizeResult): The optimization results to log.
+        """
+        # Not writing loss as loss is available in the metrics
+        # if result.loss is not None:
+        #     self.mlflow.log_metric("loss", float(result.loss), step=result.iteration)
+        if self.mlflow:
+            self.mlflow.log_metrics(result.metrics, step=result.iteration)
+        else:
+            raise RuntimeError(
+                "The writer is not initialized."
+                "Please call the 'writer.open()' method before writing"
+            )
+
+    def log_hyperparams(self, hyperparams: dict) -> None:
+        """
+        Logs hyperparameters to MLflow.
+
+        Args:
+            hyperparams (dict): A dictionary of hyperparameters to log.
+        """
+        if self.mlflow:
+            self.mlflow.log_params(hyperparams)
+        else:
+            raise RuntimeError(
+                "The writer is not initialized."
+                "Please call the 'writer.open()' method before writing"
+            )
+
+    def plot(
+        self,
+        model: Module,
+        iteration: int,
+        plotting_functions: tuple[PlottingFunction, ...],
+    ) -> None:
+        """
+        Logs plots of the model using provided plotting functions.
+
+        Args:
+            model (Module): The model to plot.
+            iteration (int): The current iteration number.
+            plotting_functions (tuple[PlottingFunction, ...]): Functions used
+                to generate plots.
+        """
+        if self.mlflow:
+            for pf in plotting_functions:
+                descr, fig = pf(model, iteration)
+                self.mlflow.log_figure(fig, descr)
+        else:
+            raise RuntimeError(
+                "The writer is not initialized."
+                "Please call the 'writer.open()' method before writing"
+            )
+
+    def get_signature_from_dataloader(self, model: Module, dataloader: DataLoader | None) -> Any:
+        """
+        Infers the signature of the model based on the input data from the dataloader.
+
+        Args:
+            model (Module): The model to use for inference.
+            dataloader (DataLoader | None): DataLoader for model inputs.
+
+        Returns:
+            Optional[Any]: The inferred signature, if available.
+        """
+        if dataloader is None:
+            return None
+
+        xs: InputData
+        xs, *_ = next(iter(dataloader))
+        preds = model(xs)
+
+        if isinstance(xs, Tensor):
+            xs = xs.detach().cpu().numpy()
+            preds = preds.detach().cpu().numpy()
+            return infer_signature(xs, preds)
+
+        return None
+
+    def log_model(
+        self,
+        model: Module,
+        train_dataloader: DataLoader | None = None,
+        val_dataloader: DataLoader | None = None,
+        test_dataloader: DataLoader | None = None,
+    ) -> None:
+        """
+        Logs the model and its signature to MLflow using the provided data loaders.
+
+        Args:
+            model (Module): The model to log.
+            train_dataloader (DataLoader | None): DataLoader for training data.
+            val_dataloader (DataLoader | None): DataLoader for validation data.
+            test_dataloader (DataLoader | None): DataLoader for testing data.
+        """
+        if not self.mlflow:
+            raise RuntimeError(
+                "The writer is not initialized."
+                "Please call the 'writer.open()' method before writing"
+            )
+
+        signatures = self.get_signature_from_dataloader(model, train_dataloader)
+        self.mlflow.pytorch.log_model(model, artifact_path="model", signature=signatures)
+
+
+# Writer registry
+WRITER_REGISTRY = {
+    ExperimentTrackingTool.TENSORBOARD: TensorBoardWriter,
+    ExperimentTrackingTool.MLFLOW: MLFlowWriter,
+}
+
+
+def get_writer(tracking_tool: ExperimentTrackingTool) -> BaseWriter:
+    """Factory method to get the appropriate writer based on the tracking tool.
+
+    Args:
+        tracking_tool (ExperimentTrackingTool): The experiment tracking tool to use.
+
+    Returns:
+        BaseWriter: An instance of the appropriate writer.
+    """
+    writer_class = WRITER_REGISTRY.get(tracking_tool)
+    if writer_class:
+        return writer_class()
+    else:
+        raise ValueError(f"Unsupported tracking tool: {tracking_tool}")
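This registry appears to absorb the tracking responsibilities of the removed qadence/ml_tools/printing.py: callers look writers up by ExperimentTrackingTool instead of branching on the tool themselves. A usage sketch, assuming TrainConfig's defaults are sufficient for logging and deferring to qadence/ml_tools/data.py for the exact OptimizeResult fields:

```python
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks.writer_registry import get_writer
from qadence.types import ExperimentTrackingTool

# Look up a writer class in WRITER_REGISTRY and instantiate it;
# an unsupported tool raises ValueError.
writer = get_writer(ExperimentTrackingTool.TENSORBOARD)

config = TrainConfig(max_iter=100)  # assumed: defaults are enough here
writer.open(config)  # creates the SummaryWriter under config.log_folder
try:
    ...  # training loop producing OptimizeResult objects, each passed to writer.write(result)
finally:
    writer.close()
```

The same loop switches to MLflow by passing ExperimentTrackingTool.MLFLOW, since both writers share the BaseWriter interface.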