qadence 1.9.0__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qadence/ml_tools/callbacks/writer_registry.py +36 -25
- qadence/ml_tools/train_utils/base_trainer.py +33 -26
- qadence/ml_tools/trainer.py +50 -34
- {qadence-1.9.0.dist-info → qadence-1.9.1.dist-info}/METADATA +1 -1
- {qadence-1.9.0.dist-info → qadence-1.9.1.dist-info}/RECORD +7 -7
- {qadence-1.9.0.dist-info → qadence-1.9.1.dist-info}/WHEEL +0 -0
- {qadence-1.9.0.dist-info → qadence-1.9.1.dist-info}/licenses/LICENSE +0 -0
qadence/ml_tools/callbacks/writer_registry.py
CHANGED
@@ -7,17 +7,14 @@ from types import ModuleType
|
|
7
7
|
from typing import Any, Callable, Union
|
8
8
|
from uuid import uuid4
|
9
9
|
|
10
|
-
import mlflow
|
11
10
|
from matplotlib.figure import Figure
|
12
|
-
from mlflow.entities import Run
|
13
|
-
from mlflow.models import infer_signature
|
14
11
|
from torch import Tensor
|
15
12
|
from torch.nn import Module
|
16
13
|
from torch.utils.data import DataLoader
|
17
14
|
from torch.utils.tensorboard import SummaryWriter
|
18
15
|
|
19
16
|
from qadence.ml_tools.config import TrainConfig
|
20
|
-
from qadence.ml_tools.data import OptimizeResult
|
17
|
+
from qadence.ml_tools.data import DictDataLoader, OptimizeResult
|
21
18
|
from qadence.types import ExperimentTrackingTool
|
22
19
|
|
23
20
|
logger = getLogger("ml_tools")
|
@@ -43,7 +40,7 @@ class BaseWriter(ABC):
|
|
43
40
|
log_model(model, dataloader): Logs the model and any relevant information.
|
44
41
|
"""
|
45
42
|
|
46
|
-
run:
|
43
|
+
run: Any # [attr-defined]
|
47
44
|
|
48
45
|
@abstractmethod
|
49
46
|
def open(self, config: TrainConfig, iteration: int | None = None) -> Any:
|
@@ -104,18 +101,18 @@ class BaseWriter(ABC):
|
|
104
101
|
def log_model(
|
105
102
|
self,
|
106
103
|
model: Module,
|
107
|
-
train_dataloader: DataLoader | None = None,
|
108
|
-
val_dataloader: DataLoader | None = None,
|
109
|
-
test_dataloader: DataLoader | None = None,
|
104
|
+
train_dataloader: DataLoader | DictDataLoader | None = None,
|
105
|
+
val_dataloader: DataLoader | DictDataLoader | None = None,
|
106
|
+
test_dataloader: DataLoader | DictDataLoader | None = None,
|
110
107
|
) -> None:
|
111
108
|
"""
|
112
109
|
Logs the model and associated data.
|
113
110
|
|
114
111
|
Args:
|
115
112
|
model (Module): The model to log.
|
116
|
-
train_dataloader (DataLoader | None): DataLoader for training data.
|
117
|
-
val_dataloader (DataLoader | None): DataLoader for validation data.
|
118
|
-
test_dataloader (DataLoader | None): DataLoader for testing data.
|
113
|
+
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
|
114
|
+
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
|
115
|
+
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
|
119
116
|
"""
|
120
117
|
raise NotImplementedError("Writers must implement a log_model method.")
|
121
118
|
|
@@ -231,9 +228,9 @@ class TensorBoardWriter(BaseWriter):
|
|
231
228
|
def log_model(
|
232
229
|
self,
|
233
230
|
model: Module,
|
234
|
-
train_dataloader: DataLoader | None = None,
|
235
|
-
val_dataloader: DataLoader | None = None,
|
236
|
-
test_dataloader: DataLoader | None = None,
|
231
|
+
train_dataloader: DataLoader | DictDataLoader | None = None,
|
232
|
+
val_dataloader: DataLoader | DictDataLoader | None = None,
|
233
|
+
test_dataloader: DataLoader | DictDataLoader | None = None,
|
237
234
|
) -> None:
|
238
235
|
"""
|
239
236
|
Logs the model.
|
@@ -242,9 +239,9 @@ class TensorBoardWriter(BaseWriter):
|
|
242
239
|
|
243
240
|
Args:
|
244
241
|
model (Module): The model to log.
|
245
|
-
train_dataloader (DataLoader | None): DataLoader for training data.
|
246
|
-
val_dataloader (DataLoader | None): DataLoader for validation data.
|
247
|
-
test_dataloader (DataLoader | None): DataLoader for testing data.
|
242
|
+
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
|
243
|
+
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
|
244
|
+
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
|
248
245
|
"""
|
249
246
|
logger.warning("Model logging is not supported by tensorboard. No model will be logged.")
|
250
247
|
|
@@ -259,6 +256,14 @@ class MLFlowWriter(BaseWriter):
|
|
259
256
|
"""
|
260
257
|
|
261
258
|
def __init__(self) -> None:
|
259
|
+
try:
|
260
|
+
from mlflow.entities import Run
|
261
|
+
except ImportError:
|
262
|
+
raise ImportError(
|
263
|
+
"mlflow is not installed. Please install qadence with the mlflow feature: "
|
264
|
+
"`pip install qadence[mlflow]`."
|
265
|
+
)
|
266
|
+
|
262
267
|
self.run: Run
|
263
268
|
self.mlflow: ModuleType
|
264
269
|
|
@@ -274,6 +279,8 @@ class MLFlowWriter(BaseWriter):
|
|
274
279
|
Returns:
|
275
280
|
mlflow: The MLflow module instance.
|
276
281
|
"""
|
282
|
+
import mlflow
|
283
|
+
|
277
284
|
self.mlflow = mlflow
|
278
285
|
tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "")
|
279
286
|
experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", str(uuid4()))
|
@@ -356,17 +363,21 @@ class MLFlowWriter(BaseWriter):
|
|
356
363
|
"Please call the 'writer.open()' method before writing"
|
357
364
|
)
|
358
365
|
|
359
|
-
def get_signature_from_dataloader(
|
366
|
+
def get_signature_from_dataloader(
|
367
|
+
self, model: Module, dataloader: DataLoader | DictDataLoader | None
|
368
|
+
) -> Any:
|
360
369
|
"""
|
361
370
|
Infers the signature of the model based on the input data from the dataloader.
|
362
371
|
|
363
372
|
Args:
|
364
373
|
model (Module): The model to use for inference.
|
365
|
-
dataloader (DataLoader | None): DataLoader for model inputs.
|
374
|
+
dataloader (DataLoader | DictDataLoader | None): DataLoader for model inputs.
|
366
375
|
|
367
376
|
Returns:
|
368
377
|
Optional[Any]: The inferred signature, if available.
|
369
378
|
"""
|
379
|
+
from mlflow.models import infer_signature
|
380
|
+
|
370
381
|
if dataloader is None:
|
371
382
|
return None
|
372
383
|
|
@@ -384,18 +395,18 @@ class MLFlowWriter(BaseWriter):
|
|
384
395
|
def log_model(
|
385
396
|
self,
|
386
397
|
model: Module,
|
387
|
-
train_dataloader: DataLoader | None = None,
|
388
|
-
val_dataloader: DataLoader | None = None,
|
389
|
-
test_dataloader: DataLoader | None = None,
|
398
|
+
train_dataloader: DataLoader | DictDataLoader | None = None,
|
399
|
+
val_dataloader: DataLoader | DictDataLoader | None = None,
|
400
|
+
test_dataloader: DataLoader | DictDataLoader | None = None,
|
390
401
|
) -> None:
|
391
402
|
"""
|
392
403
|
Logs the model and its signature to MLflow using the provided data loaders.
|
393
404
|
|
394
405
|
Args:
|
395
406
|
model (Module): The model to log.
|
396
|
-
train_dataloader (DataLoader | None): DataLoader for training data.
|
397
|
-
val_dataloader (DataLoader | None): DataLoader for validation data.
|
398
|
-
test_dataloader (DataLoader | None): DataLoader for testing data.
|
407
|
+
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
|
408
|
+
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
|
409
|
+
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
|
399
410
|
"""
|
400
411
|
if not self.mlflow:
|
401
412
|
raise RuntimeError(
|
qadence/ml_tools/train_utils/base_trainer.py
CHANGED
@@ -8,11 +8,11 @@ import nevergrad as ng
|
|
8
8
|
import torch
|
9
9
|
from nevergrad.optimization.base import Optimizer as NGOptimizer
|
10
10
|
from torch import nn, optim
|
11
|
-
from torch.utils.data import DataLoader
|
11
|
+
from torch.utils.data import DataLoader, TensorDataset
|
12
12
|
|
13
13
|
from qadence.ml_tools.callbacks import CallbacksManager
|
14
14
|
from qadence.ml_tools.config import TrainConfig
|
15
|
-
from qadence.ml_tools.data import
|
15
|
+
from qadence.ml_tools.data import DictDataLoader
|
16
16
|
from qadence.ml_tools.loss import get_loss_fn
|
17
17
|
from qadence.ml_tools.optimize_step import optimize_step
|
18
18
|
from qadence.ml_tools.parameters import get_parameters
|
@@ -42,9 +42,9 @@ class BaseTrainer:
|
|
42
42
|
model (nn.Module): The neural network model.
|
43
43
|
optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
|
44
44
|
config (TrainConfig): The configuration settings for training.
|
45
|
-
train_dataloader (
|
46
|
-
val_dataloader (
|
47
|
-
test_dataloader (
|
45
|
+
train_dataloader (Dataloader | DictDataLoader | None): DataLoader for training data.
|
46
|
+
val_dataloader (Dataloader | DictDataLoader | None): DataLoader for validation data.
|
47
|
+
test_dataloader (Dataloader | DictDataLoader | None): DataLoader for testing data.
|
48
48
|
|
49
49
|
optimize_step (Callable): Function for performing an optimization step.
|
50
50
|
loss_fn (Callable | str ]): loss function to use. Default loss function
|
@@ -69,9 +69,9 @@ class BaseTrainer:
|
|
69
69
|
config: TrainConfig,
|
70
70
|
loss_fn: str | Callable = "mse",
|
71
71
|
optimize_step: Callable = optimize_step,
|
72
|
-
train_dataloader: DataLoader | None = None,
|
73
|
-
val_dataloader: DataLoader | None = None,
|
74
|
-
test_dataloader: DataLoader | None = None,
|
72
|
+
train_dataloader: DataLoader | DictDataLoader | None = None,
|
73
|
+
val_dataloader: DataLoader | DictDataLoader | None = None,
|
74
|
+
test_dataloader: DataLoader | DictDataLoader | None = None,
|
75
75
|
max_batches: int | None = None,
|
76
76
|
):
|
77
77
|
"""
|
@@ -86,11 +86,11 @@ class BaseTrainer:
|
|
86
86
|
str input to be specified to use a default loss function.
|
87
87
|
currently supported loss functions: 'mse', 'cross_entropy'.
|
88
88
|
If not specified, default mse loss will be used.
|
89
|
-
train_dataloader (
|
89
|
+
train_dataloader (Dataloader | DictDataLoader | None): DataLoader for training data.
|
90
90
|
If the model does not need data to evaluate loss, no dataset
|
91
91
|
should be provided.
|
92
|
-
val_dataloader (
|
93
|
-
test_dataloader (
|
92
|
+
val_dataloader (Dataloader | DictDataLoader | None): DataLoader for validation data.
|
93
|
+
test_dataloader (Dataloader | DictDataLoader | None): DataLoader for testing data.
|
94
94
|
max_batches (int | None): Maximum number of batches to process per epoch.
|
95
95
|
This is only valid in case of finite TensorDataset dataloaders.
|
96
96
|
if max_batches is not None, the maximum number of batches used will
|
@@ -100,9 +100,9 @@ class BaseTrainer:
|
|
100
100
|
self._model: nn.Module
|
101
101
|
self._optimizer: optim.Optimizer | NGOptimizer | None
|
102
102
|
self._config: TrainConfig
|
103
|
-
self._train_dataloader: DataLoader | None = None
|
104
|
-
self._val_dataloader: DataLoader | None = None
|
105
|
-
self._test_dataloader: DataLoader | None = None
|
103
|
+
self._train_dataloader: DataLoader | DictDataLoader | None = None
|
104
|
+
self._val_dataloader: DataLoader | DictDataLoader | None = None
|
105
|
+
self._test_dataloader: DataLoader | DictDataLoader | None = None
|
106
106
|
|
107
107
|
self.config = config
|
108
108
|
self.model = model
|
@@ -311,7 +311,7 @@ class BaseTrainer:
|
|
311
311
|
self.callback_manager = CallbacksManager(value)
|
312
312
|
self.config_manager = ConfigManager(value)
|
313
313
|
|
314
|
-
def _compute_num_batches(self, dataloader: DataLoader) -> int:
|
314
|
+
def _compute_num_batches(self, dataloader: DataLoader | DictDataLoader) -> int:
|
315
315
|
"""
|
316
316
|
Computes the number of batches for the given DataLoader.
|
317
317
|
|
@@ -321,34 +321,41 @@ class BaseTrainer:
|
|
321
321
|
"""
|
322
322
|
if dataloader is None:
|
323
323
|
return 1
|
324
|
-
|
325
|
-
|
326
|
-
|
324
|
+
if isinstance(dataloader, DictDataLoader):
|
325
|
+
dataloader_name, dataloader_value = list(dataloader.dataloaders.items())[0]
|
326
|
+
dataset = dataloader_value.dataset
|
327
|
+
batch_size = dataloader_value.batch_size
|
327
328
|
else:
|
328
|
-
|
329
|
-
|
330
|
-
|
329
|
+
dataset = dataloader.dataset
|
330
|
+
batch_size = dataloader.batch_size
|
331
|
+
|
332
|
+
if isinstance(dataset, TensorDataset):
|
333
|
+
n_batches = int((dataset.tensors[0].size(0) + batch_size - 1) // batch_size)
|
331
334
|
return min(self.max_batches, n_batches) if self.max_batches is not None else n_batches
|
335
|
+
else:
|
336
|
+
return 1
|
332
337
|
|
333
|
-
def _validate_dataloader(
|
338
|
+
def _validate_dataloader(
|
339
|
+
self, dataloader: DataLoader | DictDataLoader, dataloader_type: str
|
340
|
+
) -> None:
|
334
341
|
"""
|
335
342
|
Validates the type of the DataLoader and raises errors for unsupported types.
|
336
343
|
|
337
344
|
Args:
|
338
|
-
dataloader (DataLoader): The DataLoader to validate.
|
345
|
+
dataloader (DataLoader | DictDataLoader): The DataLoader to validate.
|
339
346
|
dataloader_type (str): The type of DataLoader ("train", "val", or "test").
|
340
347
|
"""
|
341
348
|
if dataloader is not None:
|
342
|
-
if not isinstance(dataloader, DataLoader):
|
349
|
+
if not isinstance(dataloader, (DataLoader, DictDataLoader)):
|
343
350
|
raise NotImplementedError(
|
344
351
|
f"Unsupported dataloader type: {type(dataloader)}."
|
345
352
|
"The dataloader must be an instance of DataLoader."
|
346
353
|
)
|
347
354
|
if dataloader_type == "val" and self.config.val_every > 0:
|
348
|
-
if not isinstance(dataloader, DataLoader):
|
355
|
+
if not isinstance(dataloader, (DataLoader, DictDataLoader)):
|
349
356
|
raise ValueError(
|
350
357
|
"If `config.val_every` is provided as an integer > 0, validation_dataloader"
|
351
|
-
"must be an instance of `DataLoader`."
|
358
|
+
"must be an instance of `DataLoader` or `DictDataLoader`."
|
352
359
|
)
|
353
360
|
|
354
361
|
@staticmethod
|
qadence/ml_tools/trainer.py
CHANGED
@@ -14,7 +14,7 @@ from torch import dtype as torch_dtype
|
|
14
14
|
from torch.utils.data import DataLoader
|
15
15
|
|
16
16
|
from qadence.ml_tools.config import TrainConfig
|
17
|
-
from qadence.ml_tools.data import OptimizeResult
|
17
|
+
from qadence.ml_tools.data import DictDataLoader, OptimizeResult
|
18
18
|
from qadence.ml_tools.optimize_step import optimize_step, update_ng_parameters
|
19
19
|
from qadence.ml_tools.stages import TrainingStage
|
20
20
|
|
@@ -49,9 +49,9 @@ class Trainer(BaseTrainer):
|
|
49
49
|
model (nn.Module): The neural network model.
|
50
50
|
optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
|
51
51
|
config (TrainConfig): The configuration settings for training.
|
52
|
-
train_dataloader (DataLoader | None): DataLoader for training data.
|
53
|
-
val_dataloader (DataLoader | None): DataLoader for validation data.
|
54
|
-
test_dataloader (DataLoader | None): DataLoader for testing data.
|
52
|
+
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
|
53
|
+
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
|
54
|
+
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
|
55
55
|
|
56
56
|
optimize_step (Callable): Function for performing an optimization step.
|
57
57
|
loss_fn (Callable): loss function to use.
|
@@ -235,9 +235,9 @@ class Trainer(BaseTrainer):
|
|
235
235
|
optimizer: optim.Optimizer | NGOptimizer | None,
|
236
236
|
config: TrainConfig,
|
237
237
|
loss_fn: str | Callable = "mse",
|
238
|
-
train_dataloader: DataLoader | None = None,
|
239
|
-
val_dataloader: DataLoader | None = None,
|
240
|
-
test_dataloader: DataLoader | None = None,
|
238
|
+
train_dataloader: DataLoader | DictDataLoader | None = None,
|
239
|
+
val_dataloader: DataLoader | DictDataLoader | None = None,
|
240
|
+
test_dataloader: DataLoader | DictDataLoader | None = None,
|
241
241
|
optimize_step: Callable = optimize_step,
|
242
242
|
device: torch_device | None = None,
|
243
243
|
dtype: torch_dtype | None = None,
|
@@ -252,9 +252,9 @@ class Trainer(BaseTrainer):
|
|
252
252
|
config (TrainConfig): Training configuration object.
|
253
253
|
loss_fn (str | Callable ): Loss function used for training.
|
254
254
|
If not specified, default mse loss will be used.
|
255
|
-
train_dataloader (DataLoader | None): DataLoader for training data.
|
256
|
-
val_dataloader (DataLoader | None): DataLoader for validation data.
|
257
|
-
test_dataloader (DataLoader | None): DataLoader for test data.
|
255
|
+
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
|
256
|
+
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
|
257
|
+
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for test data.
|
258
258
|
optimize_step (Callable): Function to execute an optimization step.
|
259
259
|
device (torch_device): Device to use for computation.
|
260
260
|
dtype (torch_dtype): Data type for computation.
|
@@ -285,7 +285,9 @@ class Trainer(BaseTrainer):
|
|
285
285
|
self.data_dtype = float64 if (self.dtype == complex128) else float32
|
286
286
|
|
287
287
|
def fit(
|
288
|
-
self,
|
288
|
+
self,
|
289
|
+
train_dataloader: DataLoader | DictDataLoader | None = None,
|
290
|
+
val_dataloader: DataLoader | DictDataLoader | None = None,
|
289
291
|
) -> tuple[nn.Module, optim.Optimizer]:
|
290
292
|
"""
|
291
293
|
Fits the model using the specified training configuration.
|
@@ -294,8 +296,8 @@ class Trainer(BaseTrainer):
|
|
294
296
|
provided in the trainer will be used.
|
295
297
|
|
296
298
|
Args:
|
297
|
-
train_dataloader (DataLoader | None): DataLoader for training data.
|
298
|
-
val_dataloader (DataLoader | None): DataLoader for validation data.
|
299
|
+
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
|
300
|
+
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
|
299
301
|
|
300
302
|
Returns:
|
301
303
|
tuple[nn.Module, optim.Optimizer]: The trained model and optimizer.
|
@@ -336,10 +338,8 @@ class Trainer(BaseTrainer):
|
|
336
338
|
TimeRemainingColumn(elapsed_when_finished=True),
|
337
339
|
)
|
338
340
|
|
339
|
-
# Quick Fix for
|
340
|
-
|
341
|
-
self.model_old = copy.deepcopy(self.model)
|
342
|
-
self.optimizer_old = copy.deepcopy(self.optimizer)
|
341
|
+
# Quick Fix for iteration 0
|
342
|
+
self._reset_model_and_opt()
|
343
343
|
|
344
344
|
# Run validation at the start if specified in the configuration
|
345
345
|
self.perform_val = self.config.val_every > 0
|
@@ -415,16 +415,10 @@ class Trainer(BaseTrainer):
|
|
415
415
|
"""
|
416
416
|
self.model.train()
|
417
417
|
train_epoch_loss_metrics = []
|
418
|
-
#
|
419
|
-
|
420
|
-
# before step of optimization
|
421
|
-
# To align them with model/optimizer correctly, we checkpoint
|
422
|
-
# the older copy of the model.
|
423
|
-
# TODO: review optimize_step to provide iteration aligned model and loss.
|
424
|
-
self.model_old = copy.deepcopy(self.model)
|
425
|
-
self.optimizer_old = copy.deepcopy(self.optimizer)
|
418
|
+
# Quick Fix for iteration 0
|
419
|
+
self._reset_model_and_opt()
|
426
420
|
|
427
|
-
for batch in self.
|
421
|
+
for batch in self._batch_iter(dataloader, self.num_training_batches):
|
428
422
|
self.on_train_batch_start(batch)
|
429
423
|
train_batch_loss_metrics = self.run_train_batch(batch)
|
430
424
|
train_epoch_loss_metrics.append(train_batch_loss_metrics)
|
@@ -475,7 +469,7 @@ class Trainer(BaseTrainer):
|
|
475
469
|
self.ng_params = ng_params
|
476
470
|
loss_metrics = loss, metrics
|
477
471
|
|
478
|
-
return self.
|
472
|
+
return self._modify_batch_end_loss_metrics(loss_metrics)
|
479
473
|
|
480
474
|
@BaseTrainer.callback("val_epoch")
|
481
475
|
def run_validation(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]:
|
@@ -493,7 +487,7 @@ class Trainer(BaseTrainer):
|
|
493
487
|
self.model.eval()
|
494
488
|
val_epoch_loss_metrics = []
|
495
489
|
|
496
|
-
for batch in self.
|
490
|
+
for batch in self._batch_iter(dataloader, self.num_validation_batches):
|
497
491
|
self.on_val_batch_start(batch)
|
498
492
|
val_batch_loss_metrics = self.run_val_batch(batch)
|
499
493
|
val_epoch_loss_metrics.append(val_batch_loss_metrics)
|
@@ -514,7 +508,7 @@ class Trainer(BaseTrainer):
|
|
514
508
|
"""
|
515
509
|
with torch.no_grad():
|
516
510
|
loss_metrics = self.loss_fn(self.model, batch)
|
517
|
-
return self.
|
511
|
+
return self._modify_batch_end_loss_metrics(loss_metrics)
|
518
512
|
|
519
513
|
def test(self, test_dataloader: DataLoader = None) -> list[tuple[torch.Tensor, dict[str, Any]]]:
|
520
514
|
"""
|
@@ -537,7 +531,7 @@ class Trainer(BaseTrainer):
|
|
537
531
|
self.model.eval()
|
538
532
|
test_loss_metrics = []
|
539
533
|
|
540
|
-
for batch in self.
|
534
|
+
for batch in self._batch_iter(test_dataloader, self.num_training_batches):
|
541
535
|
self.on_test_batch_start(batch)
|
542
536
|
loss_metrics = self.run_test_batch(batch)
|
543
537
|
test_loss_metrics.append(loss_metrics)
|
@@ -560,11 +554,11 @@ class Trainer(BaseTrainer):
|
|
560
554
|
"""
|
561
555
|
with torch.no_grad():
|
562
556
|
loss_metrics = self.loss_fn(self.model, batch)
|
563
|
-
return self.
|
557
|
+
return self._modify_batch_end_loss_metrics(loss_metrics)
|
564
558
|
|
565
|
-
def
|
559
|
+
def _batch_iter(
|
566
560
|
self,
|
567
|
-
dataloader: DataLoader,
|
561
|
+
dataloader: DataLoader | DictDataLoader,
|
568
562
|
num_batches: int,
|
569
563
|
) -> Iterable[tuple[torch.Tensor, ...] | None]:
|
570
564
|
"""
|
@@ -587,7 +581,7 @@ class Trainer(BaseTrainer):
|
|
587
581
|
# batch = data_to_device(batch, device=self.device, dtype=self.data_dtype)
|
588
582
|
yield batch
|
589
583
|
|
590
|
-
def
|
584
|
+
def _modify_batch_end_loss_metrics(
|
591
585
|
self, loss_metrics: tuple[torch.Tensor, dict[str, Any]]
|
592
586
|
) -> tuple[torch.Tensor, dict[str, Any]]:
|
593
587
|
"""
|
@@ -611,6 +605,28 @@ class Trainer(BaseTrainer):
|
|
611
605
|
return loss, updated_metrics
|
612
606
|
return loss_metrics
|
613
607
|
|
608
|
+
def _reset_model_and_opt(self) -> None:
|
609
|
+
"""
|
610
|
+
Save model_old and optimizer_old for epoch 0.
|
611
|
+
|
612
|
+
This allows us to create a copy of model
|
613
|
+
and optimizer before running the optimization.
|
614
|
+
|
615
|
+
We do this because optimize step provides loss, metrics
|
616
|
+
before step of optimization
|
617
|
+
To align them with model/optimizer correctly, we checkpoint
|
618
|
+
the older copy of the model.
|
619
|
+
"""
|
620
|
+
|
621
|
+
# TODO: review optimize_step to provide iteration aligned model and loss.
|
622
|
+
try:
|
623
|
+
# Deep copy model and optimizer to maintain checkpoints
|
624
|
+
self.model_old = copy.deepcopy(self.model)
|
625
|
+
self.optimizer_old = copy.deepcopy(self.optimizer)
|
626
|
+
except Exception:
|
627
|
+
self.model_old = self.model
|
628
|
+
self.optimizer_old = self.optimizer
|
629
|
+
|
614
630
|
def build_optimize_result(
|
615
631
|
self,
|
616
632
|
result: None
|
{qadence-1.9.0.dist-info → qadence-1.9.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: qadence
|
3
|
-
Version: 1.9.0
|
3
|
+
Version: 1.9.1
|
4
4
|
Summary: Pasqal interface for circuit-based quantum computing SDKs
|
5
5
|
Author-email: Aleksander Wennersteen <aleksander.wennersteen@pasqal.com>, Gert-Jan Both <gert-jan.both@pasqal.com>, Niklas Heim <niklas.heim@pasqal.com>, Mario Dagrada <mario.dagrada@pasqal.com>, Vincent Elfving <vincent.elfving@pasqal.com>, Dominik Seitz <dominik.seitz@pasqal.com>, Roland Guichard <roland.guichard@pasqal.com>, "Joao P. Moutinho" <joao.moutinho@pasqal.com>, Vytautas Abramavicius <vytautas.abramavicius@pasqal.com>, Gergana Velikova <gergana.velikova@pasqal.com>, Eduardo Maschio <eduardo.maschio@pasqal.com>, Smit Chaudhary <smit.chaudhary@pasqal.com>, Ignacio Fernández Graña <ignacio.fernandez-grana@pasqal.com>, Charles Moussa <charles.moussa@pasqal.com>, Giorgio Tosti Balducci <giorgio.tosti-balducci@pasqal.com>, Daniele Cucurachi <daniele.cucurachi@pasqal.com>
|
6
6
|
License: Apache 2.0
|
{qadence-1.9.0.dist-info → qadence-1.9.1.dist-info}/RECORD
CHANGED
@@ -108,17 +108,17 @@ qadence/ml_tools/optimize_step.py,sha256=wUnxfWy0c9rEKe41-26On1bPFBwmSYBF4WCGn76
|
|
108
108
|
qadence/ml_tools/parameters.py,sha256=gew2Kq_5-RgRpaTvs8eauVhgo0sTqqDQEV6WHFEiLGM,1301
|
109
109
|
qadence/ml_tools/stages.py,sha256=qW2phMIvQBLM3tn2UoGN-ePiBnZoNq5k844eHVnnn8Y,1407
|
110
110
|
qadence/ml_tools/tensors.py,sha256=xZ9ZRzOqEaMgLUGWQf1najDmL6iLuN1ojCGVFs1Tm94,1337
|
111
|
-
qadence/ml_tools/trainer.py,sha256=
|
111
|
+
qadence/ml_tools/trainer.py,sha256=u9Mxv9WwRlYScLozT1Qltf1tNYLAUgn3oiz2E8bLpx0,26803
|
112
112
|
qadence/ml_tools/utils.py,sha256=PW8FyoV0mG_DtN1U8njTDV5qxZ0EK4mnFwMAsLBArfk,1410
|
113
113
|
qadence/ml_tools/callbacks/__init__.py,sha256=XaUKmyQZaqxI0jvKnWCpIBgnX5y4Kczcbn2FRiomFu4,655
|
114
114
|
qadence/ml_tools/callbacks/callback.py,sha256=F9tbXBBv3ZKTFbm0fGBZIZtTRO63jLazMk_oeL77dyE,16289
|
115
115
|
qadence/ml_tools/callbacks/callbackmanager.py,sha256=HwxgbqJi1GWYg2lgUqEyw9Y6a71YG_m5DmhpaeB6kLs,8007
|
116
116
|
qadence/ml_tools/callbacks/saveload.py,sha256=2z8v1A3qIIPZuusEcSNqgYTnKGKkDj71KvY_atJvKnM,6015
|
117
|
-
qadence/ml_tools/callbacks/writer_registry.py,sha256=
|
117
|
+
qadence/ml_tools/callbacks/writer_registry.py,sha256=FVM13j1-mv1Qt-v2QgkRFSB_uQ1bmezr5v6UKfeh3as,15264
|
118
118
|
qadence/ml_tools/loss/__init__.py,sha256=d_0FlisdmgLY0qL1PeaabbcWX1B42RBdm7220cfzSN4,247
|
119
119
|
qadence/ml_tools/loss/loss.py,sha256=Bditg8nelMEpG4Yt0aopcAQz84xIc6O-AGUO2M0nqbM,2982
|
120
120
|
qadence/ml_tools/train_utils/__init__.py,sha256=1A2FlFg7kn68R1fdRC73S8DzA9gkBW7whdNHjzH5UTA,235
|
121
|
-
qadence/ml_tools/train_utils/base_trainer.py,sha256=
|
121
|
+
qadence/ml_tools/train_utils/base_trainer.py,sha256=giOcBRMjgbq9sLjqck6MCWH8V1MCVBHarWuFrS-ahbw,20442
|
122
122
|
qadence/ml_tools/train_utils/config_manager.py,sha256=dps94qfiwjhoY_aQp5RvQPd9zW_MIN2knw1UaDaYrKs,6896
|
123
123
|
qadence/noise/__init__.py,sha256=tnChHv7FzOaV8C7O0P2l_gfjrpmHg8JaNhZprL33CP4,161
|
124
124
|
qadence/noise/protocols.py,sha256=SPHJi5AiIOcz6U_iXY3ddVHk3cl9UHSDKk49eMTX2QM,8586
|
@@ -137,7 +137,7 @@ qadence/transpile/flatten.py,sha256=EdhSG5WyF56nbnxINNLqrHgY84MRM1YFjT3fR4aph5Q,
|
|
137
137
|
qadence/transpile/invert.py,sha256=KAefHTG2AWr39aengVhXrzCtJPhrZC-ZnL6vYvmbnY0,4867
|
138
138
|
qadence/transpile/noise.py,sha256=LDcDJtQGkgUPkL2t69gg6AScTb-p3J3SxCDZbYOu1L8,1668
|
139
139
|
qadence/transpile/transpile.py,sha256=6MRRkk1OS279L1fwUQjazA6qlfpbd-T_EJMKT8hAhOU,2721
|
140
|
-
qadence-1.9.
|
141
|
-
qadence-1.9.
|
142
|
-
qadence-1.9.
|
143
|
-
qadence-1.9.
|
140
|
+
qadence-1.9.1.dist-info/METADATA,sha256=LR5dgYAA874bZyIRlUpSxHmjPY3Ww-Yos7p_LLZtNT0,9842
|
141
|
+
qadence-1.9.1.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
|
142
|
+
qadence-1.9.1.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
143
|
+
qadence-1.9.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|