qadence-1.8.0-py3-none-any.whl → qadence-1.9.0-py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- qadence/__init__.py +1 -1
- qadence/analog/parse_analog.py +1 -2
- qadence/backends/gpsr.py +8 -2
- qadence/backends/pulser/backend.py +7 -23
- qadence/backends/pyqtorch/backend.py +80 -5
- qadence/backends/pyqtorch/config.py +10 -3
- qadence/backends/pyqtorch/convert_ops.py +63 -2
- qadence/blocks/primitive.py +1 -0
- qadence/execution.py +0 -2
- qadence/log_config.yaml +10 -0
- qadence/measurements/shadow.py +97 -128
- qadence/measurements/utils.py +2 -2
- qadence/mitigations/readout.py +12 -6
- qadence/ml_tools/__init__.py +4 -8
- qadence/ml_tools/callbacks/__init__.py +30 -0
- qadence/ml_tools/callbacks/callback.py +451 -0
- qadence/ml_tools/callbacks/callbackmanager.py +214 -0
- qadence/ml_tools/{saveload.py → callbacks/saveload.py} +11 -11
- qadence/ml_tools/callbacks/writer_registry.py +430 -0
- qadence/ml_tools/config.py +132 -258
- qadence/ml_tools/data.py +7 -3
- qadence/ml_tools/loss/__init__.py +10 -0
- qadence/ml_tools/loss/loss.py +87 -0
- qadence/ml_tools/optimize_step.py +45 -10
- qadence/ml_tools/stages.py +46 -0
- qadence/ml_tools/train_utils/__init__.py +7 -0
- qadence/ml_tools/train_utils/base_trainer.py +548 -0
- qadence/ml_tools/train_utils/config_manager.py +184 -0
- qadence/ml_tools/trainer.py +692 -0
- qadence/model.py +1 -1
- qadence/noise/__init__.py +2 -2
- qadence/noise/protocols.py +18 -53
- qadence/operations/ham_evo.py +87 -26
- qadence/transpile/noise.py +12 -5
- qadence/types.py +15 -3
- {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/METADATA +3 -4
- {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/RECORD +39 -32
- {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/WHEEL +1 -1
- qadence/ml_tools/printing.py +0 -154
- qadence/ml_tools/train_grad.py +0 -395
- qadence/ml_tools/train_no_grad.py +0 -199
- qadence/noise/readout.py +0 -218
- {qadence-1.8.0.dist-info → qadence-1.9.0.dist-info}/licenses/LICENSE +0 -0
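The headline change in 1.9.0 is the rewrite of `qadence/ml_tools`: the monolithic `train_grad.py`/`train_no_grad.py` loops and the `printing.py` tracker helpers (both deleted below) give way to a `Trainer` class (`trainer.py`), a callback framework (`callbacks/`), and dedicated loss/stage modules. As rough orientation only, here is a minimal sketch of a 1.8.x-style training script ported to the new entry point; the `Trainer` constructor and `fit()` signature are inferred from the new file layout and should be checked against the 1.9.0 documentation:

```python
import torch
from qadence.ml_tools import TrainConfig, to_dataloader
# Assumed import path for the new class; see qadence/ml_tools/trainer.py in 1.9.0.
from qadence.ml_tools.trainer import Trainer

model = torch.nn.Linear(1, 1)  # placeholder model; any torch.nn.Module works
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]:
    x, y = data[0], data[1]
    return torch.nn.functional.mse_loss(model(x), y), {}

x = torch.linspace(0, 1, 25).reshape(-1, 1)
data = to_dataloader(x, torch.sin(x), batch_size=25, infinite=True)
config = TrainConfig(max_iter=100)

# 1.8.x:  train_with_grad(model, data, optimizer, config, loss_fn=loss_fn)
# 1.9.0 (assumed): the Trainer object owns the loop, checkpointing and callbacks.
trainer = Trainer(model=model, optimizer=optimizer, config=config, loss_fn=loss_fn)
trainer.fit(train_dataloader=data)
```

The behaviours that were hard-wired into the old loop (printing, metric writing, plotting, checkpointing) now live in `callbacks/callback.py` and are orchestrated by `callbacks/callbackmanager.py`, per the file list above.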
qadence/ml_tools/printing.py
DELETED
@@ -1,154 +0,0 @@
-from __future__ import annotations
-
-from logging import getLogger
-from typing import Any, Callable, Union
-
-from matplotlib.figure import Figure
-from torch import Tensor
-from torch.nn import Module
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-
-from qadence.ml_tools.data import DictDataLoader
-from qadence.types import ExperimentTrackingTool
-
-logger = getLogger(__name__)
-
-PlottingFunction = Callable[[Module, int], tuple[str, Figure]]
-InputData = Union[Tensor, dict[str, Tensor]]
-
-
-def print_metrics(loss: float | None, metrics: dict, iteration: int) -> None:
-    msg = " ".join(
-        [f"Iteration {iteration: >7} | Loss: {loss:.7f} -"]
-        + [f"{k}: {v.item():.7f}" for k, v in metrics.items()]
-    )
-    print(msg)
-
-
-def write_tensorboard(
-    writer: SummaryWriter, loss: float = None, metrics: dict | None = None, iteration: int = 0
-) -> None:
-    metrics = metrics or dict()
-    if loss is not None:
-        writer.add_scalar("loss", loss, iteration)
-    for key, arg in metrics.items():
-        writer.add_scalar(key, arg, iteration)
-
-
-def log_hyperparams_tensorboard(writer: SummaryWriter, hyperparams: dict, metrics: dict) -> None:
-    writer.add_hparams(hyperparams, metrics)
-
-
-def plot_tensorboard(
-    writer: SummaryWriter,
-    model: Module,
-    iteration: int,
-    plotting_functions: tuple[PlottingFunction],
-) -> None:
-    for pf in plotting_functions:
-        descr, fig = pf(model, iteration)
-        writer.add_figure(descr, fig, global_step=iteration)
-
-
-def log_model_tensorboard(
-    writer: SummaryWriter,
-    model: Module,
-    dataloader: Union[None, DataLoader, DictDataLoader],
-) -> None:
-    logger.warning("Model logging is not supported by tensorboard. No model will be logged.")
-
-
-def write_mlflow(writer: Any, loss: float | None, metrics: dict, iteration: int) -> None:
-    writer.log_metrics({"loss": float(loss)}, step=iteration)  # type: ignore
-    writer.log_metrics(metrics, step=iteration)  # logs the single metrics
-
-
-def log_hyperparams_mlflow(writer: Any, hyperparams: dict, metrics: dict) -> None:
-    writer.log_params(hyperparams)  # type: ignore
-
-
-def plot_mlflow(
-    writer: Any,
-    model: Module,
-    iteration: int,
-    plotting_functions: tuple[PlottingFunction],
-) -> None:
-    for pf in plotting_functions:
-        descr, fig = pf(model, iteration)
-        writer.log_figure(fig, descr)
-
-
-def log_model_mlflow(
-    writer: Any, model: Module, dataloader: DataLoader | DictDataLoader | None
-) -> None:
-    signature = None
-    if dataloader is not None:
-        xs: InputData
-        xs, *_ = next(iter(dataloader))
-        preds = model(xs)
-        if isinstance(xs, Tensor):
-            xs = xs.numpy()
-            preds = preds.detach().numpy()
-        elif isinstance(xs, dict):
-            for key, val in xs.items():
-                xs[key] = val.numpy()
-            for key, val in preds.items():
-                preds[key] = val.detach.numpy()
-
-        try:
-            from mlflow.models import infer_signature
-
-            signature = infer_signature(xs, preds)
-        except ImportError:
-            logger.warning(
-                "An MLFlow specific function has been called but MLFlow failed to import."
-                "Please install MLFlow or adjust your code."
-            )
-
-    writer.pytorch.log_model(model, artifact_path="model", signature=signature)
-
-
-TRACKER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = {
-    ExperimentTrackingTool.TENSORBOARD: write_tensorboard,
-    ExperimentTrackingTool.MLFLOW: write_mlflow,
-}
-
-LOGGER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = {
-    ExperimentTrackingTool.TENSORBOARD: log_hyperparams_tensorboard,
-    ExperimentTrackingTool.MLFLOW: log_hyperparams_mlflow,
-}
-
-PLOTTER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = {
-    ExperimentTrackingTool.TENSORBOARD: plot_tensorboard,
-    ExperimentTrackingTool.MLFLOW: plot_mlflow,
-}
-
-MODEL_LOGGER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = {
-    ExperimentTrackingTool.TENSORBOARD: log_model_tensorboard,
-    ExperimentTrackingTool.MLFLOW: log_model_mlflow,
-}
-
-
-def write_tracker(
-    *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
-) -> None:
-    return TRACKER_MAPPING[tracking_tool](*args)
-
-
-def log_tracker(
-    *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
-) -> None:
-    return LOGGER_MAPPING[tracking_tool](*args)
-
-
-def plot_tracker(
-    *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
-) -> None:
-    return PLOTTER_MAPPING[tracking_tool](*args)
-
-
-def log_model_tracker(
-    *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
-) -> None:
-    return MODEL_LOGGER_MAPPING[tracking_tool](*args)
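These removed helpers were thin dispatchers: the `*_MAPPING` dicts routed a generic call to the TensorBoard or MLflow implementation based on an `ExperimentTrackingTool` value. A minimal usage sketch against qadence 1.8.x (the module no longer exists in 1.9.0), using only the functions defined above:

```python
# qadence 1.8.x only; qadence/ml_tools/printing.py is removed in 1.9.0.
from torch.utils.tensorboard import SummaryWriter
from qadence.ml_tools.printing import log_tracker, write_tracker
from qadence.types import ExperimentTrackingTool

writer = SummaryWriter("/tmp/example_run")
tool = ExperimentTrackingTool.TENSORBOARD

# Positional args pass through unchanged to the mapped backend function,
# here write_tensorboard(writer, loss, metrics, iteration).
write_tracker(writer, 0.42, {"acc": 0.9}, 10, tracking_tool=tool)
# Hyperparameters route to log_hyperparams_tensorboard -> writer.add_hparams(...).
log_tracker(writer, {"lr": 0.1}, {"final_loss": 0.42}, tracking_tool=tool)
writer.close()
```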
qadence/ml_tools/train_grad.py
DELETED
@@ -1,395 +0,0 @@
-from __future__ import annotations
-
-import importlib
-import math
-from logging import getLogger
-from typing import Any, Callable, Union
-
-from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
-from torch import Tensor, complex128, float32, float64
-from torch import device as torch_device
-from torch import dtype as torch_dtype
-from torch.nn import DataParallel, Module
-from torch.optim import Optimizer
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-
-from qadence.ml_tools.config import Callback, TrainConfig, run_callbacks
-from qadence.ml_tools.data import DictDataLoader, OptimizeResult, data_to_device
-from qadence.ml_tools.optimize_step import optimize_step
-from qadence.ml_tools.printing import (
-    log_model_tracker,
-    log_tracker,
-    plot_tracker,
-    print_metrics,
-    write_tracker,
-)
-from qadence.ml_tools.saveload import load_checkpoint, write_checkpoint
-from qadence.types import ExperimentTrackingTool
-
-logger = getLogger(__name__)
-
-
-def train(
-    model: Module,
-    dataloader: Union[None, DataLoader, DictDataLoader],
-    optimizer: Optimizer,
-    config: TrainConfig,
-    loss_fn: Callable,
-    device: torch_device = None,
-    optimize_step: Callable = optimize_step,
-    dtype: torch_dtype = None,
-) -> tuple[Module, Optimizer]:
-    """Runs the training loop with gradient-based optimizer.
-
-    Assumes that `loss_fn` returns a tuple of (loss,
-    metrics: dict), where `metrics` is a dict of scalars. Loss and metrics are
-    written to tensorboard. Checkpoints are written every
-    `config.checkpoint_every` steps (and after the last training step). If a
-    checkpoint is found at `config.folder` we resume training from there. The
-    tensorboard logs can be viewed via `tensorboard --logdir /path/to/folder`.
-
-    Args:
-        model: The model to train.
-        dataloader: dataloader of different types. If None, no data is required by
-            the model
-        optimizer: The optimizer to use.
-        config: `TrainConfig` with additional training options.
-        loss_fn: Loss function returning (loss: float, metrics: dict[str, float], ...)
-        device: String defining device to train on, pass 'cuda' for GPU.
-        optimize_step: Customizable optimization callback which is called at every iteration.=
-            The function must have the signature `optimize_step(model,
-            optimizer, loss_fn, xs, device="cpu")`.
-        dtype: The dtype to use for the data.
-
-    Example:
-    ```python exec="on" source="material-block"
-    from pathlib import Path
-    import torch
-    from itertools import count
-    from qadence import Parameter, QuantumCircuit, Z
-    from qadence import hamiltonian_factory, hea, feature_map, chain
-    from qadence import QNN
-    from qadence.ml_tools import TrainConfig, train_with_grad, to_dataloader
-
-    n_qubits = 2
-    fm = feature_map(n_qubits)
-    ansatz = hea(n_qubits=n_qubits, depth=3)
-    observable = hamiltonian_factory(n_qubits, detuning = Z)
-    circuit = QuantumCircuit(n_qubits, fm, ansatz)
-
-    model = QNN(circuit, observable, backend="pyqtorch", diff_mode="ad")
-    batch_size = 1
-    input_values = {"phi": torch.rand(batch_size, requires_grad=True)}
-    pred = model(input_values)
-
-    ## lets prepare the train routine
-
-    cnt = count()
-    criterion = torch.nn.MSELoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
-
-    def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]:
-        next(cnt)
-        x, y = data[0], data[1]
-        out = model(x)
-        loss = criterion(out, y)
-        return loss, {}
-
-    tmp_path = Path("/tmp")
-    n_epochs = 5
-    batch_size = 25
-    config = TrainConfig(
-        folder=tmp_path,
-        max_iter=n_epochs,
-        checkpoint_every=100,
-        write_every=100,
-    )
-    x = torch.linspace(0, 1, batch_size).reshape(-1, 1)
-    y = torch.sin(x)
-    data = to_dataloader(x, y, batch_size=batch_size, infinite=True)
-    train_with_grad(model, data, optimizer, config, loss_fn=loss_fn)
-    ```
-    """
-    # load available checkpoint
-    init_iter = 0
-    log_device = "cpu" if device is None else device
-    if config.folder:
-        model, optimizer, init_iter = load_checkpoint(
-            config.folder, model, optimizer, device=log_device
-        )
-        logger.debug(f"Loaded model and optimizer from {config.folder}")
-
-    # Move model to device before optimizer is loaded
-    if isinstance(model, DataParallel):
-        model = model.module.to(device=device, dtype=dtype)
-    else:
-        model = model.to(device=device, dtype=dtype)
-    # initialize tracking tool
-    if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
-        writer = SummaryWriter(config.folder, purge_step=init_iter)
-    else:
-        writer = importlib.import_module("mlflow")
-
-    perform_val = isinstance(config.val_every, int)
-    if perform_val:
-        if not isinstance(dataloader, DictDataLoader):
-            raise ValueError(
-                "If `config.val_every` is provided as an integer, dataloader must"
-                "be an instance of `DictDataLoader`."
-            )
-        iter_keys = dataloader.dataloaders.keys()
-        if "train" not in iter_keys or "val" not in iter_keys:
-            raise ValueError(
-                "If `config.val_every` is provided as an integer, the dictdataloader"
-                "must have `train` and `val` keys to access the respective dataloaders."
-            )
-        val_dataloader = dataloader.dataloaders["val"]
-        dataloader = dataloader.dataloaders["train"]
-
-    ## Training
-    progress = Progress(
-        TextColumn("[progress.description]{task.description}"),
-        BarColumn(),
-        TaskProgressColumn(),
-        TimeRemainingColumn(elapsed_when_finished=True),
-    )
-    data_dtype = None
-    if dtype:
-        data_dtype = float64 if dtype == complex128 else float32
-
-    best_val_loss = math.inf
-
-    if not ((dataloader is None) or isinstance(dataloader, (DictDataLoader, DataLoader))):
-        raise NotImplementedError(
-            f"Unsupported dataloader type: {type(dataloader)}. "
-            "You can use e.g. `qadence.ml_tools.to_dataloader` to build a dataloader."
-        )
-
-    def next_loss_iter(dl_iter: Union[None, DataLoader, DictDataLoader]) -> Any:
-        """Get loss on the next batch of a dataloader.
-
-        loaded on device if not None.
-
-        Args:
-            dl_iter (Union[None, DataLoader, DictDataLoader]): Dataloader.
-
-        Returns:
-            Any: Loss value
-        """
-        xs = next(dl_iter) if dl_iter is not None else None
-        xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
-        return loss_fn(model, xs_to_device)
-
-    # populate callbacks with already available internal functions
-    # printing, writing and plotting
-    callbacks = config.callbacks
-
-    # printing
-    if config.verbose and config.print_every > 0:
-        # Note that the loss returned by optimize_step
-        # is the value before doing the training step
-        # which is printed accordingly by the previous iteration number
-        callbacks += [
-            Callback(
-                lambda opt_res: print_metrics(opt_res.loss, opt_res.metrics, opt_res.iteration - 1),
-                called_every=config.print_every,
-            )
-        ]
-
-    # plotting
-    callbacks += [
-        Callback(
-            lambda opt_res: plot_tracker(
-                writer,
-                opt_res.model,
-                opt_res.iteration,
-                config.plotting_functions,
-                tracking_tool=config.tracking_tool,
-            ),
-            called_every=config.plot_every,
-            call_before_opt=True,
-        )
-    ]
-
-    # writing metrics
-    # we specify two writers,
-    # to write at evaluation time and before evaluation
-    callbacks += [
-        Callback(
-            lambda opt_res: write_tracker(
-                writer,
-                opt_res.loss,
-                opt_res.metrics,
-                opt_res.iteration - 1,  # loss returned be optimized_step is at -1
-                tracking_tool=config.tracking_tool,
-            ),
-            called_every=config.write_every,
-            call_end_epoch=True,
-        ),
-        Callback(
-            lambda opt_res: write_tracker(
-                writer,
-                opt_res.loss,
-                opt_res.metrics,
-                opt_res.iteration,  # after_opt we match the right loss function
-                tracking_tool=config.tracking_tool,
-            ),
-            called_every=config.write_every,
-            call_end_epoch=False,
-            call_after_opt=True,
-        ),
-    ]
-    if perform_val:
-        callbacks += [
-            Callback(
-                lambda opt_res: write_tracker(
-                    writer,
-                    None,
-                    opt_res.metrics,
-                    opt_res.iteration,
-                    tracking_tool=config.tracking_tool,
-                ),
-                called_every=config.write_every,
-                call_before_opt=True,
-                call_during_eval=True,
-            )
-        ]
-
-    # checkpointing
-    if config.folder and config.checkpoint_every > 0 and not config.checkpoint_best_only:
-        callbacks += [
-            Callback(
-                lambda opt_res: write_checkpoint(
-                    config.folder,  # type: ignore[arg-type]
-                    opt_res.model,
-                    opt_res.optimizer,
-                    opt_res.iteration,
-                ),
-                called_every=config.checkpoint_every,
-                call_before_opt=False,
-                call_after_opt=True,
-            )
-        ]
-
-    if config.folder and config.checkpoint_best_only:
-        callbacks += [
-            Callback(
-                lambda opt_res: write_checkpoint(
-                    config.folder,  # type: ignore[arg-type]
-                    opt_res.model,
-                    opt_res.optimizer,
-                    "best",
-                ),
-                called_every=config.checkpoint_every,
-                call_before_opt=True,
-                call_after_opt=True,
-                call_during_eval=True,
-            )
-        ]
-
-    callbacks_before_opt = [
-        callback
-        for callback in callbacks
-        if callback.call_before_opt and not callback.call_during_eval
-    ]
-    callbacks_before_opt_eval = [
-        callback for callback in callbacks if callback.call_before_opt and callback.call_during_eval
-    ]
-
-    with progress:
-        dl_iter = iter(dataloader) if dataloader is not None else None
-
-        # Initial validation evaluation
-        try:
-            opt_result = OptimizeResult(init_iter, model, optimizer)
-            if perform_val:
-                dl_iter_val = iter(val_dataloader) if val_dataloader is not None else None
-                best_val_loss, metrics, *_ = next_loss_iter(dl_iter_val)
-                metrics["val_loss"] = best_val_loss
-                opt_result.metrics = metrics
-                run_callbacks(callbacks_before_opt_eval, opt_result)
-
-            run_callbacks(callbacks_before_opt, opt_result)
-
-        except KeyboardInterrupt:
-            logger.info("Terminating training gracefully after the current iteration.")
-
-        # outer epoch loop
-        init_iter += 1
-        callbacks_end_epoch = [
-            callback
-            for callback in callbacks
-            if callback.call_end_epoch and not callback.call_during_eval
-        ]
-        callbacks_end_epoch_eval = [
-            callback
-            for callback in callbacks
-            if callback.call_end_epoch and callback.call_during_eval
-        ]
-        for iteration in progress.track(range(init_iter, init_iter + config.max_iter)):
-            try:
-                # in case there is not data needed by the model
-                # this is the case, for example, of quantum models
-                # which do not have classical input data (e.g. chemistry)
-                loss, metrics = optimize_step(
-                    model=model,
-                    optimizer=optimizer,
-                    loss_fn=loss_fn,
-                    xs=None if dataloader is None else next(dl_iter),  # type: ignore[arg-type]
-                    device=device,
-                    dtype=data_dtype,
-                )
-                if isinstance(loss, Tensor):
-                    loss = loss.item()
-                opt_result = OptimizeResult(iteration, model, optimizer, loss, metrics)
-                run_callbacks(callbacks_end_epoch, opt_result)
-
-                if perform_val:
-                    if iteration % config.val_every == 0:
-                        val_loss, *_ = next_loss_iter(dl_iter_val)
-                        if config.validation_criterion(val_loss, best_val_loss, config.val_epsilon):  # type: ignore[misc]
-                            best_val_loss = val_loss
-                        metrics["val_loss"] = val_loss
-                        opt_result.metrics = metrics
-
-                        run_callbacks(callbacks_end_epoch_eval, opt_result)
-
-            except KeyboardInterrupt:
-                logger.info("Terminating training gracefully after the current iteration.")
-                break
-
-        # For handling printing/writing the last training loss
-        # as optimize_step does not give the loss value at the last iteration
-        try:
-            loss, metrics, *_ = next_loss_iter(dl_iter)
-            if isinstance(loss, Tensor):
-                loss = loss.item()
-            if perform_val:
-                # reputting val_loss as already evaluated before
-                metrics["val_loss"] = val_loss
-            print_metrics(loss, metrics, iteration)
-
-        except KeyboardInterrupt:
-            logger.info("Terminating training gracefully after the current iteration.")
-
-    # Final callbacks, by default checkpointing and writing
-    opt_result = OptimizeResult(iteration, model, optimizer, loss, metrics)
-    callbacks_after_opt = [callback for callback in callbacks if callback.call_after_opt]
-    run_callbacks(callbacks_after_opt, opt_result, is_last_iteration=True)
-
-    # writing hyperparameters
-    if config.hyperparams:
-        log_tracker(writer, config.hyperparams, metrics, tracking_tool=config.tracking_tool)
-
-    # logging the model
-    if config.log_model:
-        log_model_tracker(writer, model, dataloader, tracking_tool=config.tracking_tool)
-
-    # close tracker
-    if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
-        writer.close()
-    elif config.tracking_tool == ExperimentTrackingTool.MLFLOW:
-        writer.end_run()
-
-    return model, optimizer
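One contract in this deleted loop worth noting: when `config.val_every` is an integer, `train` requires a `DictDataLoader` with exactly the keys `train` and `val`, and splits it itself. A minimal sketch of satisfying that contract on 1.8.x, assuming `DictDataLoader` can be built directly from a `dataloaders` dict as the attribute access above implies:

```python
import torch
from qadence.ml_tools import TrainConfig, to_dataloader
from qadence.ml_tools.data import DictDataLoader

x = torch.linspace(0, 1, 100).reshape(-1, 1)
# Infinite loaders, since the loop calls next() on them every iteration.
train_dl = to_dataloader(x[:80], torch.sin(x[:80]), batch_size=20, infinite=True)
val_dl = to_dataloader(x[80:], torch.sin(x[80:]), batch_size=20, infinite=True)

# Keys must be exactly "train" and "val", otherwise train() raises ValueError.
data = DictDataLoader({"train": train_dl, "val": val_dl})

# An integer val_every switches the validation path on.
config = TrainConfig(max_iter=100, val_every=10)
# train(model, data, optimizer, config, loss_fn=loss_fn)  # 1.8.x entry point
```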