qadence 1.7.0__py3-none-any.whl → 1.7.2__py3-none-any.whl
- qadence/backends/api.py +47 -60
- qadence/backends/gpsr.py +1 -0
- qadence/backends/pyqtorch/backend.py +1 -2
- qadence/backends/pyqtorch/config.py +5 -0
- qadence/backends/pyqtorch/convert_ops.py +79 -8
- qadence/backends/utils.py +17 -3
- qadence/blocks/abstract.py +7 -0
- qadence/blocks/embedding.py +17 -12
- qadence/execution.py +11 -3
- qadence/extensions.py +65 -34
- qadence/ml_tools/config.py +9 -3
- qadence/ml_tools/constructors.py +53 -31
- qadence/ml_tools/models.py +51 -17
- qadence/ml_tools/printing.py +5 -2
- qadence/ml_tools/saveload.py +36 -12
- qadence/ml_tools/train_grad.py +45 -7
- qadence/model.py +164 -2
- qadence/operations/ham_evo.py +10 -0
- qadence/parameters.py +10 -1
- qadence/register.py +98 -22
- qadence/types.py +2 -0
- qadence/utils.py +2 -8
- {qadence-1.7.0.dist-info → qadence-1.7.2.dist-info}/METADATA +7 -6
- {qadence-1.7.0.dist-info → qadence-1.7.2.dist-info}/RECORD +26 -26
- {qadence-1.7.0.dist-info → qadence-1.7.2.dist-info}/WHEEL +1 -1
- {qadence-1.7.0.dist-info → qadence-1.7.2.dist-info}/licenses/LICENSE +0 -0
qadence/ml_tools/constructors.py
CHANGED
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
 import numpy as np
+from sympy import Basic
 
+from qadence.backend import BackendConfiguration
 from qadence.blocks import chain, kron
 from qadence.blocks.abstract import AbstractBlock
 from qadence.blocks.composite import ChainBlock, KronBlock
@@ -18,11 +20,16 @@ from qadence.constructors import (
 )
 from qadence.constructors.ansatze import hea_digital, hea_sDAQC
 from qadence.constructors.hamiltonians import ObservableConfig, TDetuning
-from qadence.
+from qadence.measurements import Measurements
+from qadence.noise import Noise
+from qadence.operations import CNOT, RX, RY, I, N, Z
 from qadence.parameters import Parameter
 from qadence.register import Register
 from qadence.types import (
     AnsatzType,
+    BackendName,
+    DiffMode,
+    InputDiffMode,
     Interaction,
     MultivariateStrategy,
     ObservableTransform,
@@ -721,54 +728,69 @@ def create_observable(
 
 def build_qnn_from_configs(
     register: int | Register,
-    fm_config: FeatureMapConfig,
-    ansatz_config: AnsatzConfig,
     observable_config: ObservableConfig | list[ObservableConfig],
+    fm_config: FeatureMapConfig = FeatureMapConfig(),
+    ansatz_config: AnsatzConfig = AnsatzConfig(),
+    backend: BackendName = BackendName.PYQTORCH,
+    diff_mode: DiffMode = DiffMode.AD,
+    measurement: Measurements | None = None,
+    noise: Noise | None = None,
+    configuration: BackendConfiguration | dict | None = None,
+    input_diff_mode: InputDiffMode | str = InputDiffMode.AD,
 ) -> QNN:
     """
     Build a QNN model.
 
     Args:
         register (int | Register): Number of qubits or a register object.
+        observable_config (ObservableConfig | list[ObservableConfig]): Observable configuration(s).
         fm_config (FeatureMapConfig): Feature map configuration.
         ansatz_config (AnsatzConfig): Ansatz configuration.
-
+        backend (BackendName): The chosen quantum backend.
+        diff_mode (DiffMode): The differentiation engine to use. Choices are
+            'gpsr' or 'ad'.
+        measurement (Measurements): Optional measurement protocol. If None,
+            use exact expectation value with a statevector simulator.
+        noise (Noise): A noise model to use.
+        configuration (BackendConfiguration | dict): Optional backend configuration.
+        input_diff_mode (InputDiffMode): The differentiation mode for the input tensor.
 
     Returns:
         QNN: A QNN model.
     """
-
-
-        register=register,
-        fm_blocks=fm_blocks,
-        ansatz_config=ansatz_config,
-    )
+    blocks: list[AbstractBlock] = []
+    inputs: list[Basic | str] | None = None
 
-
+    if fm_config.num_features > 0:
+        fm_blocks = create_fm_blocks(register=register, config=fm_config)
+        full_fm = _interleave_ansatz_in_fm(
+            register=register,
+            fm_blocks=fm_blocks,
+            ansatz_config=ansatz_config,
+        )
+        inputs = fm_config.inputs
+        blocks.append(full_fm)
 
-
-        # equal superposition of all states. This needs to be here only for rydberg
-        # feature map and only as long as the feature map is not updated to include
-        # a driving term in the Hamiltonian.
+    blocks.append(create_ansatz(register=register, config=ansatz_config))
 
-
-        num_qubits = register if isinstance(register, int) else register.n_qubits
-        mixing_block = kron(*[H(i) for i in range(num_qubits)])
-        full_fm = chain(mixing_block, full_fm)
+    circ = QuantumCircuit(register, *blocks)
 
-
-        register,
-
-
+    observable: AbstractBlock | list[AbstractBlock] = (
+        [observable_from_config(register=register, config=cfg) for cfg in observable_config]
+        if isinstance(observable_config, list)
+        else observable_from_config(register=register, config=observable_config)
     )
 
-
-
-
-
-
-
-
-
+    ufa = QNN(
+        circ,
+        observable,
+        inputs=inputs,
+        backend=backend,
+        diff_mode=diff_mode,
+        measurement=measurement,
+        noise=noise,
+        configuration=configuration,
+        input_diff_mode=input_diff_mode,
+    )
 
     return ufa
qadence/ml_tools/models.py
CHANGED
@@ -14,6 +14,7 @@ from qadence.blocks.abstract import AbstractBlock
 from qadence.circuit import QuantumCircuit
 from qadence.measurements import Measurements
 from qadence.mitigations import Mitigations
+from qadence.ml_tools.config import AnsatzConfig, FeatureMapConfig
 from qadence.model import QuantumModel
 from qadence.noise import Noise
 from qadence.register import Register
@@ -61,7 +62,7 @@ def derivative(ufa: torch.nn.Module, x: Tensor, derivative_indices: tuple[int, ...]
     obs_config = ObservableConfig(detuning=Z)
 
     f = QNN.from_configs(
-        register=3, fm_config=fm_config, ansatz_config=ansatz_config,
+        register=3, obs_config=obs_config, fm_config=fm_config, ansatz_config=ansatz_config,
     )
     inputs = torch.rand(5,3,requires_grad=True)
 
@@ -211,21 +212,41 @@ class QNN(QuantumModel):
     def from_configs(
         cls,
         register: int | Register,
-        fm_config: Any,
-        ansatz_config: Any,
         obs_config: Any,
+        fm_config: Any = FeatureMapConfig(),
+        ansatz_config: Any = AnsatzConfig(),
+        backend: BackendName = BackendName.PYQTORCH,
+        diff_mode: DiffMode = DiffMode.AD,
+        measurement: Measurements | None = None,
+        noise: Noise | None = None,
+        configuration: BackendConfiguration | dict | None = None,
+        input_diff_mode: InputDiffMode | str = InputDiffMode.AD,
     ) -> QNN:
         """Create a QNN from a set of configurations.
 
         Args:
-            register: The number of qubits or a register object.
-
-
-
+            register (int | Register): The number of qubits or a register object.
+            obs_config (list[ObservableConfig] | ObservableConfig): The configuration(s)
+                for the observable(s).
+            fm_config (FeatureMapConfig): The configuration for the feature map.
+                Defaults to no feature encoding block.
+            ansatz_config (AnsatzConfig): The configuration for the ansatz.
+                Defaults to a single layer of hardware efficient ansatz.
+            backend (BackendName): The chosen quantum backend.
+            diff_mode (DiffMode): The differentiation engine to use. Choices are
+                'gpsr' or 'ad'.
+            measurement (Measurements): Optional measurement protocol. If None,
+                use exact expectation value with a statevector simulator.
+            noise (Noise): A noise model to use.
+            configuration (BackendConfiguration | dict): Optional backend configuration.
+            input_diff_mode (InputDiffMode): The differentiation mode for the input tensor.
 
         Returns:
             A QNN object.
 
+        Raises:
+            ValueError: If the observable configuration is not provided.
+
         Example:
         ```python exec="on" source="material-block" result="json"
         import torch
@@ -234,10 +255,17 @@ class QNN(QuantumModel):
         from qadence.constructors import ObservableConfig
         from qadence.operations import Z
         from qadence.types import (
-            AnsatzType,
+            AnsatzType, BackendName, BasisSet, ObservableTransform, ReuploadScaling, Strategy
         )
 
         register = 4
+        obs_config = ObservableConfig(
+            detuning=Z,
+            scale=5.0,
+            shift=0.0,
+            transformation_type=ObservableTransform.SCALE,
+            trainable_transform=None,
+        )
         fm_config = FeatureMapConfig(
             num_features=2,
             inputs=["x", "y"],
@@ -253,15 +281,10 @@ class QNN(QuantumModel):
             ansatz_type=AnsatzType.HEA,
             ansatz_strategy=Strategy.DIGITAL,
         )
-        obs_config = ObservableConfig(
-            detuning=Z,
-            scale=5.0,
-            shift=0.0,
-            transformation_type=ObservableTransform.SCALE,
-            trainable_transform=None,
-        )
 
-        qnn = QNN.from_configs(
+        qnn = QNN.from_configs(
+            register, obs_config, fm_config, ansatz_config, backend=BackendName.PYQTORCH
+        )
 
         x = torch.rand(2, 2)
         y = qnn(x)
@@ -270,7 +293,18 @@ class QNN(QuantumModel):
         """
        from .constructors import build_qnn_from_configs

-        return build_qnn_from_configs(
+        return build_qnn_from_configs(
+            register=register,
+            observable_config=obs_config,
+            fm_config=fm_config,
+            ansatz_config=ansatz_config,
+            backend=backend,
+            diff_mode=diff_mode,
+            measurement=measurement,
+            noise=noise,
+            configuration=configuration,
+            input_diff_mode=input_diff_mode,
+        )
 
     def forward(
         self,
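Note the argument-order change in from_configs: obs_config is now the second positional parameter and fm_config / ansatz_config are optional, so positional calls written against 1.7.0 (register, fm_config, ansatz_config, obs_config) no longer line up. A keyword-based sketch that reads the same under both signatures (the FeatureMapConfig values are illustrative):

from qadence import QNN
from qadence.constructors import ObservableConfig
from qadence.ml_tools.config import AnsatzConfig, FeatureMapConfig
from qadence.operations import Z

obs_config = ObservableConfig(detuning=Z)
fm_config = FeatureMapConfig(num_features=1, inputs=["x"])  # illustrative feature map
ansatz_config = AnsatzConfig()  # a single HEA layer by default, per the docstring above

# Keyword arguments avoid the positional reshuffle introduced in 1.7.2.
qnn = QNN.from_configs(
    register=3,
    obs_config=obs_config,
    fm_config=fm_config,
    ansatz_config=ansatz_config,
)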
qadence/ml_tools/printing.py
CHANGED
@@ -11,8 +11,11 @@ def print_metrics(loss: float | None, metrics: dict, iteration: int) -> None:
     print(msg)
 
 
-def write_tensorboard(
-    writer
+def write_tensorboard(
+    writer: SummaryWriter, loss: float = None, metrics: dict = {}, iteration: int = 0
+) -> None:
+    if loss is not None:
+        writer.add_scalar("loss", loss, iteration)
     for key, arg in metrics.items():
         writer.add_scalar(key, arg, iteration)
 
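The reworked signature makes the loss optional and defaults metrics to an empty dict, which is what lets the training loop log validation-only metrics with loss=None. A usage sketch against the signature above (the log directory and metric values are illustrative):

from torch.utils.tensorboard import SummaryWriter

from qadence.ml_tools.printing import write_tensorboard

writer = SummaryWriter(log_dir="./runs/example")

# Training step: log the loss scalar plus any extra metrics.
write_tensorboard(writer, loss=0.42, metrics={"mse": 0.10}, iteration=10)

# Validation step: loss=None skips the "loss" scalar and writes only the metrics dict.
write_tensorboard(writer, None, {"val_loss": 0.37}, 10)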
qadence/ml_tools/saveload.py
CHANGED
@@ -14,7 +14,7 @@ from torch.optim import Optimizer
 logger = getLogger(__name__)
 
 
-def get_latest_checkpoint_name(folder: Path, type: str) -> Path:
+def get_latest_checkpoint_name(folder: Path, type: str, device: str | torch.device = "cpu") -> Path:
     file = Path("")
     files = [f for f in os.listdir(folder) if f.endswith(".pt") and type in f]
     if len(files) == 0:
@@ -22,12 +22,18 @@ def get_latest_checkpoint_name(folder: Path, type: str) -> Path:
     if len(files) == 1:
         file = Path(files[0])
     else:
-
+        device = str(device).split(":")[0]
+        pattern = re.compile(f".*_(\d+)_device_{device}.pt$")
+        legacy_pattern = re.compile(".*_(\d+).pt$")
         max_index = -1
         for f in files:
+            legacy_match = legacy_pattern.search(f)
             match = pattern.search(f)
-            if match:
-
+            if match or legacy_match:
+                if legacy_match:
+                    logger.warn(f"Found checkpoint(s) in legacy format: {f}.")
+                    match = legacy_match
+                index_str = match.group(1).replace("_", "")  # type: ignore [union-attr]
                 index = int(index_str)
                 if index > max_index:
                     max_index = index
@@ -41,19 +47,23 @@ def load_checkpoint(
     optimizer: Optimizer | NGOptimizer,
     model_ckpt_name: str | Path = "",
     opt_ckpt_name: str | Path = "",
+    device: str | torch.device = "cpu",
 ) -> tuple[Module, Optimizer | NGOptimizer, int]:
     if isinstance(folder, str):
         folder = Path(folder)
     if not folder.exists():
         folder.mkdir(parents=True)
         return model, optimizer, 0
-    model, iter = load_model(folder, model, model_ckpt_name)
-    optimizer = load_optimizer(folder, optimizer, opt_ckpt_name)
+    model, iter = load_model(folder, model, model_ckpt_name, device)
+    optimizer = load_optimizer(folder, optimizer, opt_ckpt_name, device)
     return model, optimizer, iter
 
 
 def write_checkpoint(
-    folder: Path,
+    folder: Path,
+    model: Module,
+    optimizer: Optimizer | NGOptimizer,
+    iteration: int | str,
 ) -> None:
     from qadence import QuantumModel
 
@@ -63,8 +73,12 @@ def write_checkpoint(
     try:
         # We extract the device from the pyqtorch native circuit
         device = str(model.device).split(":")[0]  # in case of using several CUDA devices
-    except Exception:
-
+    except Exception as e:
+        msg = (
+            f"Unable to identify in which device the QuantumModel is stored due to {e}."
+            "Setting device to None"
+        )
+        logger.warning(msg)
 
     iteration_substring = f"{iteration:03n}" if isinstance(iteration, int) else iteration
     model_checkpoint_name: str = (
@@ -102,13 +116,18 @@ def write_checkpoint(
 
 
 def load_model(
-    folder: Path,
+    folder: Path,
+    model: Module,
+    model_ckpt_name: str | Path = "",
+    device: str | torch.device = "cpu",
+    *args: Any,
+    **kwargs: Any,
 ) -> tuple[Module, int]:
     from qadence import QNN, QuantumModel
 
     iteration = 0
     if model_ckpt_name == "":
-        model_ckpt_name = get_latest_checkpoint_name(folder, "model")
+        model_ckpt_name = get_latest_checkpoint_name(folder, "model", device)
 
     try:
         iteration, model_dict = torch.load(folder / model_ckpt_name, *args, **kwargs)
@@ -116,6 +135,10 @@ def load_model(
             model._from_dict(model_dict, as_torch=True)
         elif isinstance(model, Module):
             model.load_state_dict(model_dict, strict=True)
+        # Load model to a specific gpu device if specified
+        pattern = re.compile("cuda:\d+$")
+        if pattern.search(str(device)):
+            model.to(device)
 
     except Exception as e:
         msg = f"Unable to load state dict due to {e}.\
@@ -128,9 +151,10 @@ def load_optimizer(
     folder: Path,
     optimizer: Optimizer | NGOptimizer,
     opt_ckpt_name: str | Path = "",
+    device: str | torch.device = "cpu",
 ) -> Optimizer | NGOptimizer:
     if opt_ckpt_name == "":
-        opt_ckpt_name = get_latest_checkpoint_name(folder, "opt")
+        opt_ckpt_name = get_latest_checkpoint_name(folder, "opt", device)
     if os.path.isfile(folder / opt_ckpt_name):
         if isinstance(optimizer, Optimizer):
             (_, OptType, optimizer_state) = torch.load(folder / opt_ckpt_name)
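Checkpoint file names are now device-qualified (the pattern above matches names ending in _<iteration>_device_<device>.pt, while the legacy _<iteration>.pt form is still recognised with a warning), and the loaders accept a device so the restored model can be placed directly on a CUDA device. A resume sketch using the signatures above; the plain torch module and the folder path are illustrative, and QuantumModel checkpoints go through the same functions:

from pathlib import Path

import torch

from qadence.ml_tools.saveload import load_checkpoint, write_checkpoint

model = torch.nn.Linear(2, 1)  # stand-in Module; a QuantumModel/QNN is the usual case
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
folder = Path("./checkpoints")

# File names now embed the iteration and the device, e.g. "..._100_device_cpu.pt".
write_checkpoint(folder, model, optimizer, iteration=100)

# Resume: the device argument selects the matching "_device_..." checkpoint and,
# for "cuda:<n>" devices, also moves the restored model there.
model, optimizer, start_iter = load_checkpoint(folder, model, optimizer, device="cpu")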
qadence/ml_tools/train_grad.py
CHANGED
@@ -110,8 +110,11 @@ def train(
     """
     # load available checkpoint
     init_iter = 0
+    log_device = "cpu" if device is None else device
     if config.folder:
-        model, optimizer, init_iter = load_checkpoint(
+        model, optimizer, init_iter = load_checkpoint(
+            config.folder, model, optimizer, device=log_device
+        )
         logger.debug(f"Loaded model and optimizer from {config.folder}")
 
     # Move model to device before optimizer is loaded
@@ -150,12 +153,32 @@ def train(
     data_dtype = float64 if dtype == complex128 else float32
 
     best_val_loss = math.inf
+
     with progress:
         dl_iter = iter(dataloader) if dataloader is not None else None
-
-
+
+        # Initial validation evaluation
+        try:
+            if perform_val:
+                dl_iter_val = iter(val_dataloader) if val_dataloader is not None else None
+                xs = next(dl_iter_val)
+                xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
+                best_val_loss, metrics = loss_fn(model, xs_to_device)
+
+                metrics["val_loss"] = best_val_loss
+                write_tensorboard(writer, None, metrics, init_iter)
+
+            if config.folder:
+                if config.checkpoint_best_only:
+                    write_checkpoint(config.folder, model, optimizer, iteration="best")
+                else:
+                    write_checkpoint(config.folder, model, optimizer, init_iter)
+
+        except KeyboardInterrupt:
+            logger.info("Terminating training gracefully after the current iteration.")
 
         # outer epoch loop
+        init_iter += 1
         for iteration in progress.track(range(init_iter, init_iter + config.max_iter)):
             try:
                 # in case there is not data needed by the model
@@ -189,10 +212,13 @@ def train(
                 )
 
             if iteration % config.print_every == 0 and config.verbose:
-
+                # Note that the loss returned by optimize_step
+                # is the value before doing the training step
+                # which is printed accordingly by the previous iteration number
+                print_metrics(loss, metrics, iteration - 1)
 
             if iteration % config.write_every == 0:
-                write_tensorboard(writer, loss, metrics, iteration)
+                write_tensorboard(writer, loss, metrics, iteration - 1)
 
             if perform_val:
                 if iteration % config.val_every == 0:
@@ -204,7 +230,7 @@ def train(
                     if config.folder and config.checkpoint_best_only:
                         write_checkpoint(config.folder, model, optimizer, iteration="best")
                     metrics["val_loss"] = val_loss
-                    write_tensorboard(writer,
+                    write_tensorboard(writer, None, metrics, iteration)
 
             if config.folder:
                 if iteration % config.checkpoint_every == 0 and not config.checkpoint_best_only:
@@ -214,7 +240,19 @@ def train(
                 logger.info("Terminating training gracefully after the current iteration.")
                 break
 
-
+        # Handling printing the last training loss
+        # as optimize_step does not give the loss value at the last iteration
+        try:
+            xs = next(dl_iter) if dataloader is not None else None  # type: ignore[arg-type]
+            xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
+            loss, metrics = loss_fn(model, xs_to_device)
+            if iteration % config.print_every == 0 and config.verbose:
+                print_metrics(loss, metrics, iteration)
+
+        except KeyboardInterrupt:
+            logger.info("Terminating training gracefully after the current iteration.")
+
+    # Final printing, writing and checkpointing
     if config.folder and not config.checkpoint_best_only:
         write_checkpoint(config.folder, model, optimizer, iteration)
     write_tensorboard(writer, loss, metrics, iteration)
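Both the new initial validation pass and the extra evaluation after the loop call the user-supplied loss function directly as loss_fn(model, batch) and expect a (loss, metrics) pair back, with the metrics dict fed to print_metrics and write_tensorboard. A minimal sketch of a loss function matching that contract (the MSE choice and batch layout are illustrative):

import torch


def loss_fn(
    model: torch.nn.Module, batch: tuple[torch.Tensor, torch.Tensor]
) -> tuple[torch.Tensor, dict]:
    # The training loop hands the (device-mapped) batch straight to loss_fn and expects
    # a scalar loss plus a dict of scalar metrics to print and write to tensorboard.
    x, y = batch
    loss = torch.nn.functional.mse_loss(model(x), y)
    return loss, {"mse": loss.detach().item()}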