qadence 1.10.2__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qadence/backends/horqrux/convert_ops.py +1 -1
- qadence/blocks/block_to_tensor.py +29 -32
- qadence/blocks/matrix.py +4 -0
- qadence/constructors/__init__.py +7 -1
- qadence/constructors/hamiltonians.py +96 -9
- qadence/mitigations/analog_zne.py +6 -2
- qadence/ml_tools/__init__.py +3 -2
- qadence/ml_tools/callbacks/callback.py +80 -50
- qadence/ml_tools/callbacks/callbackmanager.py +3 -2
- qadence/ml_tools/callbacks/writer_registry.py +3 -2
- qadence/ml_tools/config.py +66 -5
- qadence/ml_tools/constructors.py +9 -62
- qadence/ml_tools/data.py +4 -0
- qadence/ml_tools/information/__init__.py +3 -0
- qadence/ml_tools/information/information_content.py +339 -0
- qadence/ml_tools/models.py +69 -4
- qadence/ml_tools/optimize_step.py +1 -2
- qadence/ml_tools/train_utils/__init__.py +3 -1
- qadence/ml_tools/train_utils/accelerator.py +480 -0
- qadence/ml_tools/train_utils/config_manager.py +7 -7
- qadence/ml_tools/train_utils/distribution.py +209 -0
- qadence/ml_tools/train_utils/execution.py +421 -0
- qadence/ml_tools/trainer.py +291 -98
- qadence/operations/primitive.py +0 -4
- qadence/types.py +7 -11
- qadence/utils.py +45 -0
- {qadence-1.10.2.dist-info → qadence-1.11.0.dist-info}/METADATA +16 -13
- {qadence-1.10.2.dist-info → qadence-1.11.0.dist-info}/RECORD +30 -25
- {qadence-1.10.2.dist-info → qadence-1.11.0.dist-info}/WHEEL +0 -0
- {qadence-1.10.2.dist-info → qadence-1.11.0.dist-info}/licenses/LICENSE +0 -0
qadence/ml_tools/trainer.py
CHANGED
@@ -4,21 +4,20 @@ import copy
|
|
4
4
|
from itertools import islice
|
5
5
|
from logging import getLogger
|
6
6
|
from typing import Any, Callable, Iterable, cast
|
7
|
-
|
8
|
-
import torch
|
9
7
|
from nevergrad.optimization.base import Optimizer as NGOptimizer
|
10
|
-
|
11
|
-
from torch import
|
12
|
-
from torch import device as torch_device
|
13
|
-
from torch import dtype as torch_dtype
|
8
|
+
import torch
|
9
|
+
from torch import nn, optim
|
14
10
|
from torch.utils.data import DataLoader
|
11
|
+
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
|
15
12
|
|
16
13
|
from qadence.ml_tools.config import TrainConfig
|
17
|
-
from qadence.ml_tools.data import DictDataLoader, OptimizeResult
|
14
|
+
from qadence.ml_tools.data import DictDataLoader, OptimizeResult, data_to_device
|
15
|
+
from qadence.ml_tools.information import InformationContent
|
18
16
|
from qadence.ml_tools.optimize_step import optimize_step, update_ng_parameters
|
19
17
|
from qadence.ml_tools.stages import TrainingStage
|
20
18
|
|
21
19
|
from .train_utils.base_trainer import BaseTrainer
|
20
|
+
from .train_utils.accelerator import Accelerator
|
22
21
|
|
23
22
|
logger = getLogger("ml_tools")
|
24
23
|
|
@@ -37,11 +36,6 @@ class Trainer(BaseTrainer):
|
|
37
36
|
Attributes:
|
38
37
|
current_epoch (int): The current epoch number.
|
39
38
|
global_step (int): The global step across all epochs.
|
40
|
-
log_device (str): Device for logging, default is "cpu".
|
41
|
-
device (torch_device): Device used for computation.
|
42
|
-
dtype (torch_dtype | None): Data type used for computation.
|
43
|
-
data_dtype (torch_dtype | None): Data type for data.
|
44
|
-
Depends on the model's data type.
|
45
39
|
|
46
40
|
Inherited Attributes:
|
47
41
|
use_grad (bool): Indicates if gradients are used for optimization. Default is True.
|
@@ -239,8 +233,6 @@ class Trainer(BaseTrainer):
|
|
239
233
|
val_dataloader: DataLoader | DictDataLoader | None = None,
|
240
234
|
test_dataloader: DataLoader | DictDataLoader | None = None,
|
241
235
|
optimize_step: Callable = optimize_step,
|
242
|
-
device: torch_device | None = None,
|
243
|
-
dtype: torch_dtype | None = None,
|
244
236
|
max_batches: int | None = None,
|
245
237
|
):
|
246
238
|
"""
|
@@ -256,8 +248,6 @@ class Trainer(BaseTrainer):
|
|
256
248
|
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
|
257
249
|
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for test data.
|
258
250
|
optimize_step (Callable): Function to execute an optimization step.
|
259
|
-
device (torch_device): Device to use for computation.
|
260
|
-
dtype (torch_dtype): Data type for computation.
|
261
251
|
max_batches (int | None): Maximum number of batches to process per epoch.
|
262
252
|
This is only valid in case of finite TensorDataset dataloaders.
|
263
253
|
if max_batches is not None, the maximum number of batches used will
|
@@ -277,13 +267,21 @@ class Trainer(BaseTrainer):
|
|
277
267
|
)
|
278
268
|
self.current_epoch: int = 0
|
279
269
|
self.global_step: int = 0
|
280
|
-
self.
|
281
|
-
self.
|
282
|
-
|
283
|
-
|
284
|
-
self.
|
285
|
-
|
286
|
-
|
270
|
+
self._stop_training: torch.Tensor = torch.tensor(0, dtype=torch.int)
|
271
|
+
self.progress: Progress | None = None
|
272
|
+
|
273
|
+
# Integration with Accelerator:
|
274
|
+
self.accelerator = Accelerator(
|
275
|
+
backend=config.backend,
|
276
|
+
nprocs=config.nprocs,
|
277
|
+
compute_setup=config.compute_setup,
|
278
|
+
dtype=config.dtype,
|
279
|
+
log_setup=config.log_setup,
|
280
|
+
)
|
281
|
+
# Decorate the unbound Trainer.fit method with accelerator.distribute.
|
282
|
+
# We use __get__ to bind the decorated method to the current instance,
|
283
|
+
# ensuring that 'self' is passed only once when self.fit is called.
|
284
|
+
self.fit = self.accelerator.distribute(Trainer.fit).__get__(self, Trainer) # type: ignore[method-assign]
|
287
285
|
|
288
286
|
def fit(
|
289
287
|
self,
|
@@ -322,26 +320,30 @@ class Trainer(BaseTrainer):
|
|
322
320
|
The callback_manager.start_training takes care of loading checkpoint,
|
323
321
|
and setting up the writer.
|
324
322
|
"""
|
325
|
-
self.
|
326
|
-
|
323
|
+
self._stop_training = torch.tensor(
|
324
|
+
0, dtype=torch.int, device=self.accelerator.execution.device
|
325
|
+
)
|
326
|
+
# initialize config in the first process, and broadcast it to all processes
|
327
|
+
if self.accelerator.rank == 0:
|
328
|
+
self.config_manager.initialize_config()
|
329
|
+
self.config_manager = self.accelerator.broadcast(self.config_manager, src=0)
|
327
330
|
self.callback_manager.start_training(trainer=self)
|
328
331
|
|
329
|
-
#
|
330
|
-
|
331
|
-
self.
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
# Progress bar for training visualization
|
336
|
-
self.progress: Progress = Progress(
|
337
|
-
TextColumn("[progress.description]{task.description}"),
|
338
|
-
BarColumn(),
|
339
|
-
TaskProgressColumn(),
|
340
|
-
TimeRemainingColumn(elapsed_when_finished=True),
|
332
|
+
# Integration with Accelerator: prepare the model, optimizer, and dataloaders.
|
333
|
+
(self.model, self.optimizer, self.train_dataloader, self.val_dataloader) = (
|
334
|
+
self.accelerator.prepare(
|
335
|
+
self.model, self.optimizer, self.train_dataloader, self.val_dataloader
|
336
|
+
)
|
341
337
|
)
|
342
338
|
|
343
|
-
#
|
344
|
-
self.
|
339
|
+
# Progress bar for training visualization
|
340
|
+
if self.accelerator.world_size == 1:
|
341
|
+
self.progress = Progress(
|
342
|
+
TextColumn("[progress.description]{task.description}"),
|
343
|
+
BarColumn(),
|
344
|
+
TaskProgressColumn(),
|
345
|
+
TimeRemainingColumn(elapsed_when_finished=True),
|
346
|
+
)
|
345
347
|
|
346
348
|
# Run validation at the start if specified in the configuration
|
347
349
|
self.perform_val = self.config.val_every > 0
|
@@ -355,7 +357,11 @@ class Trainer(BaseTrainer):
|
|
355
357
|
@BaseTrainer.callback("train")
|
356
358
|
def _train(self) -> list[list[tuple[torch.Tensor, dict[str, Any]]]]:
|
357
359
|
"""
|
358
|
-
Runs the main training loop
|
360
|
+
Runs the main training loop over multiple epochs.
|
361
|
+
|
362
|
+
This method sets up the training process by performing any necessary pre-training
|
363
|
+
actions (via `on_train_start`), configuring progress tracking (if available), and then
|
364
|
+
iteratively calling `_train_epoch` to run through the epochs.
|
359
365
|
|
360
366
|
Returns:
|
361
367
|
list[list[tuple[torch.Tensor, dict[str, Any]]]]: Training loss
|
@@ -364,45 +370,97 @@ class Trainer(BaseTrainer):
|
|
364
370
|
Epochs -> Training Batches -> (loss, metrics)
|
365
371
|
"""
|
366
372
|
self.on_train_start()
|
367
|
-
|
368
|
-
|
373
|
+
epoch_start, epoch_end = (
|
374
|
+
self.global_step,
|
375
|
+
self.global_step + self.config_manager.config.max_iter + 1,
|
376
|
+
)
|
369
377
|
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
val_task = self.progress.add_task(
|
376
|
-
"Validation",
|
377
|
-
total=(self.config_manager.config.max_iter + 1) / self.config.val_every,
|
378
|
+
if self.accelerator.world_size == 1 and self.progress:
|
379
|
+
# Progress setup is only available for non-spawned training.
|
380
|
+
with self.progress:
|
381
|
+
train_task = self.progress.add_task(
|
382
|
+
"Training", total=self.config_manager.config.max_iter
|
378
383
|
)
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
if self.perform_val and self.current_epoch % self.config.val_every == 0:
|
392
|
-
self.on_val_epoch_start()
|
393
|
-
val_epoch_loss_metrics = self.run_validation(self.val_dataloader)
|
394
|
-
val_losses.append(val_epoch_loss_metrics)
|
395
|
-
self.on_val_epoch_end(val_epoch_loss_metrics)
|
396
|
-
self.progress.update(val_task, advance=1)
|
397
|
-
|
398
|
-
self.progress.update(train_task, advance=1)
|
399
|
-
except KeyboardInterrupt:
|
400
|
-
logger.info("Terminating training gracefully after the current iteration.")
|
401
|
-
break
|
384
|
+
if self.perform_val:
|
385
|
+
val_task = self.progress.add_task(
|
386
|
+
"Validation",
|
387
|
+
total=(self.config_manager.config.max_iter + 1) / self.config.val_every,
|
388
|
+
)
|
389
|
+
else:
|
390
|
+
val_task = None
|
391
|
+
train_losses, val_losses = self._train_epochs(
|
392
|
+
epoch_start, epoch_end, train_task, val_task
|
393
|
+
)
|
394
|
+
else:
|
395
|
+
train_losses, val_losses = self._train_epochs(epoch_start, epoch_end)
|
402
396
|
|
403
397
|
self.on_train_end(train_losses, val_losses)
|
404
398
|
return train_losses
|
405
399
|
|
400
|
+
def _train_epochs(
|
401
|
+
self,
|
402
|
+
epoch_start: int,
|
403
|
+
epoch_end: int,
|
404
|
+
train_task: int | None = None,
|
405
|
+
val_task: int | None = None,
|
406
|
+
) -> tuple[
|
407
|
+
list[list[tuple[torch.Tensor, dict[str, Any]]]],
|
408
|
+
list[list[tuple[torch.Tensor, dict[str, Any]]]],
|
409
|
+
]:
|
410
|
+
"""
|
411
|
+
Executes the training loop for a series of epochs.
|
412
|
+
|
413
|
+
Args:
|
414
|
+
epoch_start (int): The starting epoch index.
|
415
|
+
epoch_end (int): The ending epoch index (non-inclusive).
|
416
|
+
train_task (int | None, optional): The progress bar task ID for training updates.
|
417
|
+
If provided, the progress bar will be updated after each epoch. Defaults to None.
|
418
|
+
val_task (int | None, optional): The progress bar task ID for validation updates.
|
419
|
+
If provided and validation is enabled, the progress bar will be updated after each validation run.
|
420
|
+
Defaults to None.
|
421
|
+
|
422
|
+
Returns:
|
423
|
+
list[list[tuple[torch.Tensor, dict[str, Any]]]]: A tuple of
|
424
|
+
Training loss metrics for all epochs.
|
425
|
+
list -> list -> tuples
|
426
|
+
Epochs -> Training Batches -> (loss, metrics)
|
427
|
+
And Validation loss metrics for all epochs
|
428
|
+
list -> list -> tuples
|
429
|
+
Epochs -> Validation Batches -> (loss, metrics)
|
430
|
+
"""
|
431
|
+
train_losses = []
|
432
|
+
val_losses = []
|
433
|
+
|
434
|
+
# Iterate over the epochs
|
435
|
+
for epoch in range(epoch_start, epoch_end):
|
436
|
+
if not self.stop_training():
|
437
|
+
try:
|
438
|
+
self.current_epoch = epoch
|
439
|
+
self.on_train_epoch_start()
|
440
|
+
train_epoch_loss_metrics = self.run_training(self.train_dataloader)
|
441
|
+
train_losses.append(train_epoch_loss_metrics)
|
442
|
+
self.on_train_epoch_end(train_epoch_loss_metrics)
|
443
|
+
|
444
|
+
# Run validation periodically if specified
|
445
|
+
if self.perform_val and (epoch % self.config.val_every == 0):
|
446
|
+
self.on_val_epoch_start()
|
447
|
+
val_epoch_loss_metrics = self.run_validation(self.val_dataloader)
|
448
|
+
val_losses.append(val_epoch_loss_metrics)
|
449
|
+
self.on_val_epoch_end(val_epoch_loss_metrics)
|
450
|
+
if val_task is not None:
|
451
|
+
self.progress.update(val_task, advance=1) # type: ignore[union-attr]
|
452
|
+
|
453
|
+
if train_task is not None:
|
454
|
+
self.progress.update(train_task, advance=1) # type: ignore[union-attr]
|
455
|
+
except KeyboardInterrupt:
|
456
|
+
self._stop_training.fill_(1)
|
457
|
+
else:
|
458
|
+
if self.accelerator.rank == 0:
|
459
|
+
logger.info("Terminating training gracefully after the current iteration.")
|
460
|
+
self.accelerator.finalize()
|
461
|
+
break
|
462
|
+
return train_losses, val_losses
|
463
|
+
|
406
464
|
@BaseTrainer.callback("train_epoch")
|
407
465
|
def run_training(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]:
|
408
466
|
"""
|
@@ -418,12 +476,12 @@ class Trainer(BaseTrainer):
|
|
418
476
|
"""
|
419
477
|
self.model.train()
|
420
478
|
train_epoch_loss_metrics = []
|
421
|
-
# Quick Fix for iteration 0
|
422
|
-
self._reset_model_and_opt()
|
423
479
|
|
424
480
|
for batch in self._batch_iter(dataloader, self.num_training_batches):
|
425
481
|
self.on_train_batch_start(batch)
|
426
482
|
train_batch_loss_metrics = self.run_train_batch(batch)
|
483
|
+
if self.config.all_reduce_metrics:
|
484
|
+
train_batch_loss_metrics = self._aggregate_result(train_batch_loss_metrics)
|
427
485
|
train_epoch_loss_metrics.append(train_batch_loss_metrics)
|
428
486
|
self.on_train_batch_end(train_batch_loss_metrics)
|
429
487
|
|
@@ -457,8 +515,8 @@ class Trainer(BaseTrainer):
|
|
457
515
|
optimizer=self.optimizer,
|
458
516
|
loss_fn=self.loss_fn,
|
459
517
|
xs=batch,
|
460
|
-
device=self.device,
|
461
|
-
dtype=self.data_dtype,
|
518
|
+
device=self.accelerator.execution.device,
|
519
|
+
dtype=self.accelerator.execution.data_dtype,
|
462
520
|
)
|
463
521
|
else:
|
464
522
|
# Perform optimization using Nevergrad
|
@@ -472,7 +530,15 @@ class Trainer(BaseTrainer):
|
|
472
530
|
self.ng_params = ng_params
|
473
531
|
loss_metrics = loss, metrics
|
474
532
|
|
475
|
-
|
533
|
+
# --------------------- FIX: Post-Optimization Loss --------------------- #
|
534
|
+
# The loss/metrics are returned before the optimization step. To sync
|
535
|
+
# model state and current loss/metrics we calculate them again after optimization.
|
536
|
+
# This is not strictly necessary.
|
537
|
+
# TODO: Should be removed if loss can be logged at an unoptimized model state
|
538
|
+
with torch.no_grad():
|
539
|
+
post_update_loss_metrics = self.loss_fn(self.model, batch)
|
540
|
+
|
541
|
+
return self._modify_batch_end_loss_metrics(post_update_loss_metrics)
|
476
542
|
|
477
543
|
@BaseTrainer.callback("val_epoch")
|
478
544
|
def run_validation(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]:
|
@@ -493,6 +559,8 @@ class Trainer(BaseTrainer):
|
|
493
559
|
for batch in self._batch_iter(dataloader, self.num_validation_batches):
|
494
560
|
self.on_val_batch_start(batch)
|
495
561
|
val_batch_loss_metrics = self.run_val_batch(batch)
|
562
|
+
if self.config.all_reduce_metrics:
|
563
|
+
val_batch_loss_metrics = self._aggregate_result(val_batch_loss_metrics)
|
496
564
|
val_epoch_loss_metrics.append(val_batch_loss_metrics)
|
497
565
|
self.on_val_batch_end(val_batch_loss_metrics)
|
498
566
|
|
@@ -567,6 +635,9 @@ class Trainer(BaseTrainer):
|
|
567
635
|
"""
|
568
636
|
Yields batches from the provided dataloader.
|
569
637
|
|
638
|
+
The batch of data is also moved
|
639
|
+
to the correct device and dtype using accelerator.prepare.
|
640
|
+
|
570
641
|
Args:
|
571
642
|
dataloader ([DataLoader]): The dataloader to iterate over.
|
572
643
|
num_batches (int): The maximum number of batches to yield.
|
@@ -580,9 +651,7 @@ class Trainer(BaseTrainer):
|
|
580
651
|
yield None
|
581
652
|
else:
|
582
653
|
for batch in islice(dataloader, num_batches):
|
583
|
-
|
584
|
-
# batch = data_to_device(batch, device=self.device, dtype=self.data_dtype)
|
585
|
-
yield batch
|
654
|
+
yield self.accelerator.prepare_batch(batch)
|
586
655
|
|
587
656
|
def _modify_batch_end_loss_metrics(
|
588
657
|
self, loss_metrics: tuple[torch.Tensor, dict[str, Any]]
|
@@ -608,27 +677,43 @@ class Trainer(BaseTrainer):
|
|
608
677
|
return loss, updated_metrics
|
609
678
|
return loss_metrics
|
610
679
|
|
611
|
-
def
|
680
|
+
def _aggregate_result(
|
681
|
+
self, result: tuple[torch.Tensor, dict[str, Any]]
|
682
|
+
) -> tuple[torch.Tensor, dict[str, Any]]:
|
612
683
|
"""
|
613
|
-
|
684
|
+
Aggregates the loss and metrics using the Accelerator's all_reduce_dict method if aggregation is enabled.
|
685
|
+
|
686
|
+
Args:
|
687
|
+
result: (tuple[torch.Tensor, dict[str, Any]])
|
688
|
+
The result consisting of loss and metrics. For more details,
|
689
|
+
look at the signature of build_optimize_result.
|
614
690
|
|
615
|
-
|
616
|
-
|
691
|
+
Returns:
|
692
|
+
tuple[torch.Tensor, dict[str, Any]]: The aggregated loss and metrics.
|
693
|
+
"""
|
694
|
+
loss, metrics = result
|
695
|
+
if self.config.all_reduce_metrics:
|
696
|
+
reduced = self.accelerator.all_reduce_dict({"loss": loss, **metrics})
|
697
|
+
loss = reduced.pop("loss")
|
698
|
+
metrics = reduced
|
699
|
+
return loss, metrics
|
700
|
+
else:
|
701
|
+
return loss, metrics
|
617
702
|
|
618
|
-
|
619
|
-
before step of optimization
|
620
|
-
To align them with model/optimizer correctly, we checkpoint
|
621
|
-
the older copy of the model.
|
703
|
+
def stop_training(self) -> bool:
|
622
704
|
"""
|
705
|
+
Helper function to indicate if the training should be stopped.
|
623
706
|
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
self.
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
self.
|
707
|
+
We all_reduce the indicator across all processes to ensure all processes are stopped.
|
708
|
+
|
709
|
+
Notes:
|
710
|
+
self._stop_training indicator indicates if the training should be stopped.
|
711
|
+
0 is continue. 1 is stop.
|
712
|
+
"""
|
713
|
+
_stop_training = self.accelerator.all_reduce_dict(
|
714
|
+
{"indicator": self._stop_training}, op="max"
|
715
|
+
)
|
716
|
+
return bool(_stop_training["indicator"] > 0)
|
632
717
|
|
633
718
|
def build_optimize_result(
|
634
719
|
self,
|
@@ -709,5 +794,113 @@ class Trainer(BaseTrainer):
|
|
709
794
|
|
710
795
|
# Store the optimization result
|
711
796
|
self.opt_result = OptimizeResult(
|
712
|
-
self.current_epoch,
|
797
|
+
self.current_epoch,
|
798
|
+
self.model,
|
799
|
+
self.optimizer,
|
800
|
+
loss,
|
801
|
+
metrics,
|
802
|
+
rank=self.accelerator.rank,
|
803
|
+
device=self.accelerator.execution.device,
|
713
804
|
)
|
805
|
+
|
806
|
+
def get_ic_grad_bounds(
|
807
|
+
self,
|
808
|
+
eta: float,
|
809
|
+
epsilons: torch.Tensor,
|
810
|
+
variation_multiple: int = 20,
|
811
|
+
dataloader: DataLoader | DictDataLoader | None = None,
|
812
|
+
) -> tuple[float, float, float]:
|
813
|
+
"""
|
814
|
+
Calculate the bounds on the gradient norm of the loss using Information Content.
|
815
|
+
|
816
|
+
Args:
|
817
|
+
eta (float): The sensitivity IC.
|
818
|
+
epsilons (torch.Tensor): The epsilons to use as thresholds for discretization of the
|
819
|
+
finite derivatives.
|
820
|
+
variation_multiple (int): The number of sets of variational parameters to generate per
|
821
|
+
each variational parameter. The number of variational parameters required for the
|
822
|
+
statistical analysis scales linearly with the number of them present in the
|
823
|
+
model. This is that linear factor.
|
824
|
+
dataloader (DataLoader | DictDataLoader | None): The dataloader for training data. A
|
825
|
+
new dataloader can be provided, or the dataloader provided in the trainer will be
|
826
|
+
used. In case no dataloader is provided in either place, it assumes that the
|
827
|
+
model does not require any input data.
|
828
|
+
|
829
|
+
Returns:
|
830
|
+
tuple[float, float, float]: The max IC lower bound, max IC upper bound, and sensitivity
|
831
|
+
IC upper bound.
|
832
|
+
|
833
|
+
Examples:
|
834
|
+
```python
|
835
|
+
import torch
|
836
|
+
from torch.optim.adam import Adam
|
837
|
+
|
838
|
+
from qadence.constructors import ObservableConfig
|
839
|
+
from qadence.ml_tools.config import AnsatzConfig, FeatureMapConfig, TrainConfig
|
840
|
+
from qadence.ml_tools.data import to_dataloader
|
841
|
+
from qadence.ml_tools.models import QNN
|
842
|
+
from qadence.ml_tools.optimize_step import optimize_step
|
843
|
+
from qadence.ml_tools.trainer import Trainer
|
844
|
+
from qadence.operations.primitive import Z
|
845
|
+
|
846
|
+
fm_config = FeatureMapConfig(num_features=1)
|
847
|
+
ansatz_config = AnsatzConfig(depth=4)
|
848
|
+
obs_config = ObservableConfig(detuning=Z)
|
849
|
+
|
850
|
+
qnn = QNN.from_configs(
|
851
|
+
register=4,
|
852
|
+
obs_config=obs_config,
|
853
|
+
fm_config=fm_config,
|
854
|
+
ansatz_config=ansatz_config,
|
855
|
+
)
|
856
|
+
|
857
|
+
optimizer = Adam(qnn.parameters(), lr=0.001)
|
858
|
+
|
859
|
+
batch_size = 25
|
860
|
+
x = torch.linspace(0, 1, 32).reshape(-1, 1)
|
861
|
+
y = torch.sin(x)
|
862
|
+
train_loader = to_dataloader(x, y, batch_size=batch_size, infinite=True)
|
863
|
+
|
864
|
+
train_config = TrainConfig(max_iter=100)
|
865
|
+
|
866
|
+
trainer = Trainer(
|
867
|
+
model=qnn,
|
868
|
+
optimizer=optimizer,
|
869
|
+
config=train_config,
|
870
|
+
loss_fn="mse",
|
871
|
+
train_dataloader=train_loader,
|
872
|
+
optimize_step=optimize_step,
|
873
|
+
)
|
874
|
+
|
875
|
+
# Perform exploratory landscape analysis with Information Content
|
876
|
+
ic_sensitivity_threshold = 1e-4
|
877
|
+
epsilons = torch.logspace(-2, 2, 10)
|
878
|
+
|
879
|
+
max_ic_lower_bound, max_ic_upper_bound, sensitivity_ic_upper_bound = (
|
880
|
+
trainer.get_ic_grad_bounds(
|
881
|
+
eta=ic_sensitivity_threshold,
|
882
|
+
epsilons=epsilons,
|
883
|
+
)
|
884
|
+
)
|
885
|
+
|
886
|
+
# Resume training as usual...
|
887
|
+
|
888
|
+
trainer.fit(train_loader)
|
889
|
+
```
|
890
|
+
"""
|
891
|
+
if not self._use_grad:
|
892
|
+
logger.warning(
|
893
|
+
"Gradient norm bounds are only relevant when using a gradient based optimizer. \
|
894
|
+
Currently the trainer is set to use a gradient-free optimizer."
|
895
|
+
)
|
896
|
+
|
897
|
+
dataloader = dataloader if dataloader is not None else self.train_dataloader
|
898
|
+
|
899
|
+
batch = next(iter(self._batch_iter(dataloader, num_batches=1)))
|
900
|
+
|
901
|
+
ic = InformationContent(self.model, self.loss_fn, batch, epsilons)
|
902
|
+
|
903
|
+
max_ic_lower_bound, max_ic_upper_bound = ic.get_grad_norm_bounds_max_IC()
|
904
|
+
sensitivity_ic_upper_bound = ic.get_grad_norm_bounds_sensitivity_IC(eta)
|
905
|
+
|
906
|
+
return max_ic_lower_bound, max_ic_upper_bound, sensitivity_ic_upper_bound
|
qadence/operations/primitive.py
CHANGED
@@ -376,10 +376,6 @@ class SWAP(PrimitiveBlock):
|
|
376
376
|
def eigenvalues(self) -> Tensor:
|
377
377
|
return torch.tensor([-1, 1, 1, 1], dtype=cdouble)
|
378
378
|
|
379
|
-
@property
|
380
|
-
def n_qubits(self) -> int:
|
381
|
-
return 2
|
382
|
-
|
383
379
|
@property
|
384
380
|
def _block_title(self) -> str:
|
385
381
|
c, t = self.qubit_support
|
qadence/types.py
CHANGED
@@ -445,17 +445,6 @@ class InputDiffMode(StrEnum):
|
|
445
445
|
"""Central finite differencing."""
|
446
446
|
|
447
447
|
|
448
|
-
class ObservableTransform:
|
449
|
-
"""Observable transformation type."""
|
450
|
-
|
451
|
-
SCALE = "scale"
|
452
|
-
"""Use the given values as scale and shift."""
|
453
|
-
RANGE = "range"
|
454
|
-
"""Use the given values as min and max."""
|
455
|
-
NONE = "none"
|
456
|
-
"""No transformation."""
|
457
|
-
|
458
|
-
|
459
448
|
class ExperimentTrackingTool(StrEnum):
|
460
449
|
TENSORBOARD = "tensorboard"
|
461
450
|
"""Use the tensorboard experiment tracker."""
|
@@ -463,6 +452,13 @@ class ExperimentTrackingTool(StrEnum):
|
|
463
452
|
"""Use the ml-flow experiment tracker."""
|
464
453
|
|
465
454
|
|
455
|
+
class ExecutionType(StrEnum):
|
456
|
+
TORCHRUN = "torchrun"
|
457
|
+
"""Torchrun based distribution execution."""
|
458
|
+
DEFAULT = "default"
|
459
|
+
"""Default distribution execution."""
|
460
|
+
|
461
|
+
|
466
462
|
LoggablePlotFunction = Callable[[Module, int], tuple[str, Figure]]
|
467
463
|
|
468
464
|
|
qadence/utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import math
|
4
|
+
import re
|
4
5
|
from collections import Counter
|
5
6
|
from functools import partial
|
6
7
|
from logging import getLogger
|
@@ -13,6 +14,9 @@ from torch import Tensor, stack, vmap
|
|
13
14
|
from torch import complex as make_complex
|
14
15
|
from torch.linalg import eigvals
|
15
16
|
|
17
|
+
from rich.tree import Tree
|
18
|
+
|
19
|
+
from qadence.blocks import AbstractBlock
|
16
20
|
from qadence.types import Endianness, ResultType, TNumber
|
17
21
|
|
18
22
|
if TYPE_CHECKING:
|
@@ -290,3 +294,44 @@ def one_qubit_projector_matrix(state: str) -> Tensor:
|
|
290
294
|
|
291
295
|
P0 = partial(one_qubit_projector, "0")
|
292
296
|
P1 = partial(one_qubit_projector, "1")
|
297
|
+
|
298
|
+
|
299
|
+
def block_to_mathematical_expression(block: Tree | AbstractBlock) -> str:
|
300
|
+
"""Convert a block to a readable mathematical expression.
|
301
|
+
|
302
|
+
Useful for printing Observables as a mathematical expression.
|
303
|
+
|
304
|
+
Args:
|
305
|
+
block (Tree | AbstractBlock): A block, or its rich Tree representation.
|
306
|
+
|
307
|
+
Returns:
|
308
|
+
str: A mathematical expression.
|
309
|
+
"""
|
310
|
+
block_tree: Tree = block.__rich_tree__() if isinstance(block, AbstractBlock) else block
|
311
|
+
block_title = block_tree.label if isinstance(block_tree.label, str) else ""
|
312
|
+
if "AddBlock" in block_title:
|
313
|
+
block_title = " + ".join(
|
314
|
+
[block_to_mathematical_expression(block_child) for block_child in block_tree.children]
|
315
|
+
)
|
316
|
+
if "KronBlock" in block_title:
|
317
|
+
block_title = " ⊗ ".join(
|
318
|
+
[block_to_mathematical_expression(block_child) for block_child in block_tree.children]
|
319
|
+
)
|
320
|
+
if "mul" in block_title:
|
321
|
+
block_title = re.findall("\d+\.\d+", block_title)[0]
|
322
|
+
coeff = float(block_title)
|
323
|
+
if coeff == 0:
|
324
|
+
block_title = ""
|
325
|
+
elif coeff == 1:
|
326
|
+
block_title = block_to_mathematical_expression(block_tree.children[0])
|
327
|
+
else:
|
328
|
+
block_title += " * " + block_to_mathematical_expression(block_tree.children[0])
|
329
|
+
first_part = block_title[:3]
|
330
|
+
if first_part in [" + ", " ⊗ ", " * "]:
|
331
|
+
block_title = block_title[3:]
|
332
|
+
|
333
|
+
# if too many trees, add parentheses.
|
334
|
+
nb_children = len(block_tree.children)
|
335
|
+
if nb_children > 1:
|
336
|
+
block_title = "(" + block_title + ")"
|
337
|
+
return block_title
|