qadence-1.9.0-py3-none-any.whl → qadence-1.9.2-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
qadence/analog/device.py CHANGED
@@ -5,6 +5,8 @@ from dataclasses import dataclass, fields
5
5
  from qadence.analog import AddressingPattern
6
6
  from qadence.types import PI, DeviceType, Interaction
7
7
 
8
+ from .constants import C6_DICT
9
+
8
10
 
9
11
  @dataclass(frozen=True, eq=True)
10
12
  class RydbergDevice:
@@ -41,6 +43,11 @@ class RydbergDevice:
41
43
  type: DeviceType = DeviceType.IDEALIZED
42
44
  """DeviceType.IDEALIZED or REALISTIC to convert to the Pulser backend."""
43
45
 
46
+ @property
47
+ def coeff_ising(self) -> float:
48
+ """Value of C_6."""
49
+ return C6_DICT[self.rydberg_level]
50
+
44
51
  def __post_init__(self) -> None:
45
52
  # FIXME: Currently not supporting custom interaction functions.
46
53
  if self.interaction not in [Interaction.NN, Interaction.XY]:
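The new `coeff_ising` property looks the C6 interaction coefficient up from `C6_DICT` for the device's configured Rydberg level. A minimal sketch of reading it, assuming `RydbergDevice` is importable from `qadence.analog` and that all of its dataclass fields have defaults:

```python
# Illustrative only, not part of the diff. Assumes a default-constructible
# RydbergDevice exported from qadence.analog.
from qadence.analog import RydbergDevice

device = RydbergDevice()
# C6 coefficient for the device's rydberg_level, taken from C6_DICT.
print(device.rydberg_level, device.coeff_ising)
```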
qadence/backends/pulser/backend.py CHANGED
@@ -259,7 +259,7 @@ class Backend(BackendInterface):
259
259
  for i, param_values_el in enumerate(vals):
260
260
  sequence = self.assign_parameters(circuit, param_values_el)
261
261
  sim_result: CoherentResults = simulate_sequence(sequence, self.config, state)
262
- final_state = sim_result.get_final_state().data.toarray()
262
+ final_state = sim_result.get_final_state().data.to_array()
263
263
  batched_dm[i] = np.flip(final_state)
264
264
  return torch.from_numpy(batched_dm)
265
265
 
qadence/backends/pyqtorch/convert_ops.py CHANGED
@@ -264,7 +264,7 @@ def convert_block(
264
264
  duration=duration,
265
265
  solver=config.ode_solver,
266
266
  steps=config.n_steps_hevo,
267
- noise_operators=noise_operators,
267
+ noise=noise_operators if len(noise_operators) > 0 else None,
268
268
  )
269
269
  ]
270
270
 
@@ -351,22 +351,22 @@ def convert_block(
351
351
  )
352
352
 
353
353
 
354
- def convert_digital_noise(noise: NoiseHandler) -> pyq.noise.NoiseProtocol | None:
354
+ def convert_digital_noise(noise: NoiseHandler) -> pyq.noise.DigitalNoiseProtocol | None:
355
355
  """Convert the digital noise into pyqtorch NoiseProtocol.
356
356
 
357
357
  Args:
358
358
  noise (NoiseHandler): Noise to convert.
359
359
 
360
360
  Returns:
361
- pyq.noise.NoiseProtocol | None: Pyqtorch native noise protocol
361
+ pyq.noise.DigitalNoiseProtocol | None: Pyqtorch native noise protocol
362
362
  if there are any digital noise protocols.
363
363
  """
364
364
  digital_part = noise.filter(NoiseProtocol.DIGITAL)
365
365
  if digital_part is None:
366
366
  return None
367
- return pyq.noise.NoiseProtocol(
367
+ return pyq.noise.DigitalNoiseProtocol(
368
368
  [
369
- pyq.noise.NoiseProtocol(proto, option.get("error_probability"))
369
+ pyq.noise.DigitalNoiseProtocol(proto, option.get("error_probability"))
370
370
  for proto, option in zip(digital_part.protocol, digital_part.options)
371
371
  ]
372
372
  )
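The converter now emits pyqtorch's `DigitalNoiseProtocol`, matching the pyqtorch 1.7 naming used elsewhere in this release. A sketch of the same nested construction with hypothetical protocol names and error probabilities (the `"BitFlip"` string and the 0.1 value are placeholders, not taken from the diff):

```python
# Sketch only: hypothetical protocol/option values, assuming pyqtorch >= 1.7
# exposes DigitalNoiseProtocol under pyqtorch.noise as used above.
import pyqtorch as pyq

protocols = ["BitFlip"]                   # hypothetical digital protocol names
options = [{"error_probability": 0.1}]    # hypothetical per-protocol options

converted = pyq.noise.DigitalNoiseProtocol(
    [
        pyq.noise.DigitalNoiseProtocol(proto, option.get("error_probability"))
        for proto, option in zip(protocols, options)
    ]
)
```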
qadence/backends/utils.py CHANGED
@@ -110,9 +110,23 @@ def to_list_of_dicts(param_values: ParamDictType) -> list[ParamDictType]:
110
110
 
111
111
 
112
112
  def pyqify(state: Tensor, n_qubits: int = None) -> ArrayLike:
113
- """Convert a state of shape (batch_size, 2**n_qubits) to [2] * n_qubits + [batch_size]."""
113
+ """Convert a state of shape (batch_size, 2**n_qubits) to [2] * n_qubits + [batch_size].
114
+
115
+ Or set the batch_size of a density matrix as the last dimension for PyQTorch.
116
+ """
114
117
  if n_qubits is None:
115
118
  n_qubits = int(log2(state.shape[1]))
119
+ if isinstance(state, DensityMatrix):
120
+ if (
121
+ len(state.shape) != 3
122
+ or (state.shape[1] != 2**n_qubits)
123
+ or (state.shape[1] != state.shape[2])
124
+ ):
125
+ raise ValueError(
126
+ "The initial state must be composed of tensors/arrays of size "
127
+ f"(batch_size, 2**n_qubits, 2**n_qubits). Found: {state.shape = }."
128
+ )
129
+ return torch.einsum("kij->ijk", state)
116
130
  if len(state.shape) != 2 or (state.shape[1] != 2**n_qubits):
117
131
  raise ValueError(
118
132
  "The initial state must be composed of tensors/arrays of size "
qadence/engines/torch/differentiable_expectation.py CHANGED
@@ -49,7 +49,9 @@ class PSRExpectation(Function):
49
49
  if isinstance(expectation_values[0], list):
50
50
  exp_vals: list = []
51
51
  for expectation_value in expectation_values:
52
- res = list(map(lambda x: x.get_final_state().data.toarray(), expectation_value))
52
+ res = list(
53
+ map(lambda x: x.get_final_state().data.to_array(), expectation_value)
54
+ )
53
55
  exp_vals.append(torch.tensor(res))
54
56
  expectation_values = exp_vals
55
57
  return torch.stack(expectation_values)
qadence/ml_tools/callbacks/__init__.py CHANGED
@@ -2,9 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  from .callback import (
4
4
  Callback,
5
+ EarlyStopping,
6
+ GradientMonitoring,
5
7
  LoadCheckpoint,
6
8
  LogHyperparameters,
7
9
  LogModelTracker,
10
+ LRSchedulerCosineAnnealing,
11
+ LRSchedulerCyclic,
12
+ LRSchedulerStepDecay,
8
13
  PlotMetrics,
9
14
  PrintMetrics,
10
15
  SaveBestCheckpoint,
@@ -26,5 +31,10 @@ __all__ = [
26
31
  "SaveBestCheckpoint",
27
32
  "SaveCheckpoint",
28
33
  "WriteMetrics",
34
+ "GradientMonitoring",
35
+ "LRSchedulerStepDecay",
36
+ "LRSchedulerCyclic",
37
+ "LRSchedulerCosineAnnealing",
38
+ "EarlyStopping",
29
39
  "get_writer",
30
40
  ]
qadence/ml_tools/callbacks/callback.py CHANGED
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import math
4
+ from logging import getLogger
3
5
  from typing import Any, Callable
4
6
 
5
7
  from qadence.ml_tools.callbacks.saveload import load_checkpoint, write_checkpoint
@@ -12,6 +14,8 @@ from qadence.ml_tools.stages import TrainingStage
12
14
  CallbackFunction = Callable[..., Any]
13
15
  CallbackConditionFunction = Callable[..., bool]
14
16
 
17
+ logger = getLogger("ml_tools")
18
+
15
19
 
16
20
  class Callback:
17
21
  """Base class for defining various training callbacks.
@@ -258,7 +262,7 @@ class WriteMetrics(Callback):
258
262
  writer (BaseWriter ): The writer object for logging.
259
263
  """
260
264
  opt_result = trainer.opt_result
261
- writer.write(opt_result)
265
+ writer.write(opt_result.iteration, opt_result.metrics)
262
266
 
263
267
 
264
268
  class PlotMetrics(Callback):
@@ -449,3 +453,323 @@ class LogModelTracker(Callback):
449
453
  writer.log_model(
450
454
  model, trainer.train_dataloader, trainer.val_dataloader, trainer.test_dataloader
451
455
  )
456
+
457
+
458
+ class LRSchedulerStepDecay(Callback):
459
+ """
460
+ Reduces the learning rate by a factor at regular intervals.
461
+
462
+ This callback adjusts the learning rate by multiplying it with a decay factor
463
+ after a specified number of iterations. The learning rate is updated as:
464
+ lr = lr * gamma
465
+
466
+ Example Usage in `TrainConfig`:
467
+ To use `LRSchedulerStepDecay`, include it in the `callbacks` list when setting
468
+ up your `TrainConfig`:
469
+ ```python exec="on" source="material-block" result="json"
470
+ from qadence.ml_tools import TrainConfig
471
+ from qadence.ml_tools.callbacks import LRSchedulerStepDecay
472
+
473
+ # Create an instance of the LRSchedulerStepDecay callback
474
+ lr_step_decay = LRSchedulerStepDecay(on="train_epoch_end",
475
+ called_every=100,
476
+ gamma=0.5)
477
+
478
+ config = TrainConfig(
479
+ max_iter=10000,
480
+ # Print metrics every 1000 training epochs
481
+ print_every=1000,
482
+ # Add the custom callback
483
+ callbacks=[lr_step_decay]
484
+ )
485
+ ```
486
+ """
487
+
488
+ def __init__(self, on: str, called_every: int, gamma: float = 0.5):
489
+ """Initializes the LRSchedulerStepDecay callback.
490
+
491
+ Args:
492
+ on (str): The event to trigger the callback.
493
+ called_every (int): Frequency of callback calls in terms of iterations.
494
+ gamma (float, optional): The decay factor applied to the learning rate.
495
+ A value < 1 reduces the learning rate over time. Default is 0.5.
496
+ """
497
+ super().__init__(on=on, called_every=called_every)
498
+ self.gamma = gamma
499
+
500
+ def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
501
+ """
502
+ Runs the callback to apply step decay to the learning rate.
503
+
504
+ Args:
505
+ trainer (Any): The training object.
506
+ config (TrainConfig): The configuration object.
507
+ writer (BaseWriter): The writer object for logging.
508
+ """
509
+ for param_group in trainer.optimizer.param_groups:
510
+ param_group["lr"] *= self.gamma
511
+
512
+
513
+ class LRSchedulerCyclic(Callback):
514
+ """
515
+ Applies a cyclic learning rate schedule during training.
516
+
517
+ This callback oscillates the learning rate between a minimum (base_lr)
518
+ and a maximum (max_lr) over a defined cycle length (step_size). The learning
519
+ rate follows a triangular wave pattern.
520
+
521
+ Example Usage in `TrainConfig`:
522
+ To use `LRSchedulerCyclic`, include it in the `callbacks` list when setting
523
+ up your `TrainConfig`:
524
+ ```python exec="on" source="material-block" result="json"
525
+ from qadence.ml_tools import TrainConfig
526
+ from qadence.ml_tools.callbacks import LRSchedulerCyclic
527
+
528
+ # Create an instance of the LRSchedulerCyclic callback
529
+ lr_cyclic = LRSchedulerCyclic(on="train_batch_end",
530
+ called_every=1,
531
+ base_lr=0.001,
532
+ max_lr=0.01,
533
+ step_size=2000)
534
+
535
+ config = TrainConfig(
536
+ max_iter=10000,
537
+ # Print metrics every 1000 training epochs
538
+ print_every=1000,
539
+ # Add the custom callback
540
+ callbacks=[lr_cyclic]
541
+ )
542
+ ```
543
+ """
544
+
545
+ def __init__(self, on: str, called_every: int, base_lr: float, max_lr: float, step_size: int):
546
+ """Initializes the LRSchedulerCyclic callback.
547
+
548
+ Args:
549
+ on (str): The event to trigger the callback.
550
+ called_every (int): Frequency of callback calls in terms of iterations.
551
+ base_lr (float): The minimum learning rate.
552
+ max_lr (float): The maximum learning rate.
553
+ step_size (int): Number of iterations for half a cycle.
554
+ """
555
+ super().__init__(on=on, called_every=called_every)
556
+ self.base_lr = base_lr
557
+ self.max_lr = max_lr
558
+ self.step_size = step_size
559
+
560
+ def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
561
+ """
562
+ Adjusts the learning rate cyclically.
563
+
564
+ Args:
565
+ trainer (Any): The training object.
566
+ config (TrainConfig): The configuration object.
567
+ writer (BaseWriter): The writer object for logging.
568
+ """
569
+ cycle = trainer.opt_result.iteration // (2 * self.step_size)
570
+ x = abs(trainer.opt_result.iteration / self.step_size - 2 * cycle - 1)
571
+ scale = max(0, (1 - x))
572
+ new_lr = self.base_lr + (self.max_lr - self.base_lr) * scale
573
+ for param_group in trainer.optimizer.param_groups:
574
+ param_group["lr"] = new_lr
575
+
576
+
577
+ class LRSchedulerCosineAnnealing(Callback):
578
+ """
579
+ Applies cosine annealing to the learning rate during training.
580
+
581
+ This callback decreases the learning rate following a cosine curve,
582
+ starting from the initial learning rate and annealing to a minimum (min_lr).
583
+
584
+ Example Usage in `TrainConfig`:
585
+ To use `LRSchedulerCosineAnnealing`, include it in the `callbacks` list
586
+ when setting up your `TrainConfig`:
587
+ ```python exec="on" source="material-block" result="json"
588
+ from qadence.ml_tools import TrainConfig
589
+ from qadence.ml_tools.callbacks import LRSchedulerCosineAnnealing
590
+
591
+ # Create an instance of the LRSchedulerCosineAnnealing callback
592
+ lr_cosine = LRSchedulerCosineAnnealing(on="train_batch_end",
593
+ called_every=1,
594
+ t_max=5000,
595
+ min_lr=1e-6)
596
+
597
+ config = TrainConfig(
598
+ max_iter=10000,
599
+ # Print metrics every 1000 training epochs
600
+ print_every=1000,
601
+ # Add the custom callback
602
+ callbacks=[lr_cosine]
603
+ )
604
+ ```
605
+ """
606
+
607
+ def __init__(self, on: str, called_every: int, t_max: int, min_lr: float = 0.0):
608
+ """Initializes the LRSchedulerCosineAnnealing callback.
609
+
610
+ Args:
611
+ on (str): The event to trigger the callback.
612
+ called_every (int): Frequency of callback calls in terms of iterations.
613
+ t_max (int): The total number of iterations for one annealing cycle.
614
+ min_lr (float, optional): The minimum learning rate. Default is 0.0.
615
+ """
616
+ super().__init__(on=on, called_every=called_every)
617
+ self.t_max = t_max
618
+ self.min_lr = min_lr
619
+
620
+ def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
621
+ """
622
+ Adjusts the learning rate using cosine annealing.
623
+
624
+ Args:
625
+ trainer (Any): The training object.
626
+ config (TrainConfig): The configuration object.
627
+ writer (BaseWriter): The writer object for logging.
628
+ """
629
+ for param_group in trainer.optimizer.param_groups:
630
+ max_lr = param_group["lr"]
631
+ new_lr = (
632
+ self.min_lr
633
+ + (max_lr - self.min_lr)
634
+ * (1 + math.cos(math.pi * trainer.opt_result.iteration / self.t_max))
635
+ / 2
636
+ )
637
+ param_group["lr"] = new_lr
638
+
639
+
640
+ class EarlyStopping(Callback):
641
+ """
642
+ Stops training when a monitored metric has not improved for a specified number of epochs.
643
+
644
+ This callback monitors a specified metric (e.g., validation loss or accuracy). If the metric
645
+ does not improve for a given patience period, training is stopped.
646
+
647
+ Example Usage in `TrainConfig`:
648
+ To use `EarlyStopping`, include it in the `callbacks` list when setting up your `TrainConfig`:
649
+ ```python exec="on" source="material-block" result="json"
650
+ from qadence.ml_tools import TrainConfig
651
+ from qadence.ml_tools.callbacks import EarlyStopping
652
+
653
+ # Create an instance of the EarlyStopping callback
654
+ early_stopping = EarlyStopping(on="val_epoch_end",
655
+ called_every=1,
656
+ monitor="val_loss",
657
+ patience=5,
658
+ mode="min")
659
+
660
+ config = TrainConfig(
661
+ max_iter=10000,
662
+ print_every=1000,
663
+ callbacks=[early_stopping]
664
+ )
665
+ ```
666
+ """
667
+
668
+ def __init__(
669
+ self, on: str, called_every: int, monitor: str, patience: int = 5, mode: str = "min"
670
+ ):
671
+ """Initializes the EarlyStopping callback.
672
+
673
+ Args:
674
+ on (str): The event to trigger the callback (e.g., "val_epoch_end").
675
+ called_every (int): Frequency of callback calls in terms of iterations.
676
+ monitor (str): The metric to monitor (e.g., "val_loss" or "train_loss").
677
+ All metrics returned by optimize step are available to monitor.
678
+ Please add "val_" and "train_" strings at the start of the metric name.
679
+ patience (int, optional): Number of iterations to wait for improvement. Default is 5.
680
+ mode (str, optional): Whether to minimize ("min") or maximize ("max") the metric.
681
+ Default is "min".
682
+ """
683
+ super().__init__(on=on, called_every=called_every)
684
+ self.monitor = monitor
685
+ self.patience = patience
686
+ self.mode = mode
687
+ self.best_value = float("inf") if mode == "min" else -float("inf")
688
+ self.counter = 0
689
+
690
+ def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
691
+ """
692
+ Monitors the metric and stops training if no improvement is observed.
693
+
694
+ Args:
695
+ trainer (Any): The training object.
696
+ config (TrainConfig): The configuration object.
697
+ writer (BaseWriter): The writer object for logging.
698
+ """
699
+ current_value = trainer.opt_result.metrics.get(self.monitor)
700
+ if current_value is None:
701
+ raise ValueError(f"Metric '{self.monitor}' is not available in the trainer's metrics.")
702
+
703
+ if (self.mode == "min" and current_value < self.best_value) or (
704
+ self.mode == "max" and current_value > self.best_value
705
+ ):
706
+ self.best_value = current_value
707
+ self.counter = 0
708
+ else:
709
+ self.counter += 1
710
+
711
+ if self.counter >= self.patience:
712
+ logger.info(
713
+ f"EarlyStopping: No improvement in '{self.monitor}' for {self.patience} epochs. "
714
+ "Stopping training."
715
+ )
716
+ trainer.stop_training = True
717
+
718
+
719
+ class GradientMonitoring(Callback):
720
+ """
721
+ Logs gradient statistics (e.g., mean, standard deviation, max) during training.
722
+
723
+ This callback monitors and logs statistics about the gradients of the model parameters
724
+ to help debug or optimize the training process.
725
+
726
+ Example Usage in `TrainConfig`:
727
+ To use `GradientMonitoring`, include it in the `callbacks` list when
728
+ setting up your `TrainConfig`:
729
+ ```python exec="on" source="material-block" result="json"
730
+ from qadence.ml_tools import TrainConfig
731
+ from qadence.ml_tools.callbacks import GradientMonitoring
732
+
733
+ # Create an instance of the GradientMonitoring callback
734
+ gradient_monitoring = GradientMonitoring(on="train_batch_end", called_every=10)
735
+
736
+ config = TrainConfig(
737
+ max_iter=10000,
738
+ print_every=1000,
739
+ callbacks=[gradient_monitoring]
740
+ )
741
+ ```
742
+ """
743
+
744
+ def __init__(self, on: str, called_every: int = 1):
745
+ """Initializes the GradientMonitoring callback.
746
+
747
+ Args:
748
+ on (str): The event to trigger the callback (e.g., "train_batch_end").
749
+ called_every (int): Frequency of callback calls in terms of iterations.
750
+ """
751
+ super().__init__(on=on, called_every=called_every)
752
+
753
+ def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
754
+ """
755
+ Logs gradient statistics.
756
+
757
+ Args:
758
+ trainer (Any): The training object.
759
+ config (TrainConfig): The configuration object.
760
+ writer (BaseWriter): The writer object for logging.
761
+ """
762
+ gradient_stats = {}
763
+ for name, param in trainer.model.named_parameters():
764
+ if param.grad is not None:
765
+ grad = param.grad
766
+ gradient_stats.update(
767
+ {
768
+ name + "_mean": grad.mean().item(),
769
+ name + "_std": grad.std().item(),
770
+ name + "_max": grad.max().item(),
771
+ name + "_min": grad.min().item(),
772
+ }
773
+ )
774
+
775
+ writer.write(trainer.opt_result.iteration, gradient_stats)
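The three scheduler callbacks above only rewrite `param_group["lr"]`, so their per-call update rules can be checked in isolation. A standalone sketch of those formulas in plain Python, with `iteration` standing in for `trainer.opt_result.iteration`:

```python
import math

def step_decay(lr: float, gamma: float = 0.5) -> float:
    # LRSchedulerStepDecay: lr <- lr * gamma on every call.
    return lr * gamma

def cyclic(iteration: int, base_lr: float, max_lr: float, step_size: int) -> float:
    # LRSchedulerCyclic: triangular wave between base_lr and max_lr.
    cycle = iteration // (2 * step_size)
    x = abs(iteration / step_size - 2 * cycle - 1)
    return base_lr + (max_lr - base_lr) * max(0, 1 - x)

def cosine_annealing(iteration: int, max_lr: float, t_max: int, min_lr: float = 0.0) -> float:
    # LRSchedulerCosineAnnealing: cosine interpolation between max_lr and min_lr.
    return min_lr + (max_lr - min_lr) * (1 + math.cos(math.pi * iteration / t_max)) / 2

print(step_decay(0.01))                                          # 0.005
print(cyclic(1000, base_lr=0.001, max_lr=0.01, step_size=2000))  # 0.0055
print(cosine_annealing(2500, max_lr=0.01, t_max=5000))           # 0.005
```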
qadence/ml_tools/callbacks/writer_registry.py CHANGED
@@ -7,17 +7,14 @@ from types import ModuleType
7
7
  from typing import Any, Callable, Union
8
8
  from uuid import uuid4
9
9
 
10
- import mlflow
11
10
  from matplotlib.figure import Figure
12
- from mlflow.entities import Run
13
- from mlflow.models import infer_signature
14
11
  from torch import Tensor
15
12
  from torch.nn import Module
16
13
  from torch.utils.data import DataLoader
17
14
  from torch.utils.tensorboard import SummaryWriter
18
15
 
19
16
  from qadence.ml_tools.config import TrainConfig
20
- from qadence.ml_tools.data import OptimizeResult
17
+ from qadence.ml_tools.data import DictDataLoader, OptimizeResult
21
18
  from qadence.types import ExperimentTrackingTool
22
19
 
23
20
  logger = getLogger("ml_tools")
@@ -43,7 +40,7 @@ class BaseWriter(ABC):
43
40
  log_model(model, dataloader): Logs the model and any relevant information.
44
41
  """
45
42
 
46
- run: Run # [attr-defined]
43
+ run: Any # [attr-defined]
47
44
 
48
45
  @abstractmethod
49
46
  def open(self, config: TrainConfig, iteration: int | None = None) -> Any:
@@ -63,12 +60,14 @@ class BaseWriter(ABC):
63
60
  raise NotImplementedError("Writers must implement a close method.")
64
61
 
65
62
  @abstractmethod
66
- def write(self, result: OptimizeResult) -> None:
63
+ def write(self, iteration: int, metrics: dict) -> None:
67
64
  """
68
65
  Logs the results of the current iteration.
69
66
 
70
67
  Args:
71
- result (OptimizeResult): The optimization results to log.
68
+ iteration (int): The current training iteration.
69
+ metrics (dict): A dictionary of metrics to log, where keys are metric names
70
+ and values are the corresponding metric values.
72
71
  """
73
72
  raise NotImplementedError("Writers must implement a write method.")
74
73
 
@@ -104,18 +103,18 @@ class BaseWriter(ABC):
104
103
  def log_model(
105
104
  self,
106
105
  model: Module,
107
- train_dataloader: DataLoader | None = None,
108
- val_dataloader: DataLoader | None = None,
109
- test_dataloader: DataLoader | None = None,
106
+ train_dataloader: DataLoader | DictDataLoader | None = None,
107
+ val_dataloader: DataLoader | DictDataLoader | None = None,
108
+ test_dataloader: DataLoader | DictDataLoader | None = None,
110
109
  ) -> None:
111
110
  """
112
111
  Logs the model and associated data.
113
112
 
114
113
  Args:
115
114
  model (Module): The model to log.
116
- train_dataloader (DataLoader | None): DataLoader for training data.
117
- val_dataloader (DataLoader | None): DataLoader for validation data.
118
- test_dataloader (DataLoader | None): DataLoader for testing data.
115
+ train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
116
+ val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
117
+ test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
119
118
  """
120
119
  raise NotImplementedError("Writers must implement a log_model method.")
121
120
 
@@ -169,23 +168,22 @@ class TensorBoardWriter(BaseWriter):
169
168
  if self.writer:
170
169
  self.writer.close()
171
170
 
172
- def write(self, result: OptimizeResult) -> None:
171
+ def write(self, iteration: int, metrics: dict) -> None:
173
172
  """
174
173
  Logs the results of the current iteration to TensorBoard.
175
174
 
176
175
  Args:
177
- result (OptimizeResult): The optimization results to log.
176
+ iteration (int): The current training iteration.
177
+ metrics (dict): A dictionary of metrics to log, where keys are metric names
178
+ and values are the corresponding metric values.
178
179
  """
179
- # Not writing loss as loss is available in the metrics
180
- # if result.loss is not None:
181
- # self.writer.add_scalar("loss", float(result.loss), result.iteration)
182
180
  if self.writer:
183
- for key, value in result.metrics.items():
184
- self.writer.add_scalar(key, value, result.iteration)
181
+ for key, value in metrics.items():
182
+ self.writer.add_scalar(key, value, iteration)
185
183
  else:
186
184
  raise RuntimeError(
187
185
  "The writer is not initialized."
188
- "Please call the 'writer.open()' method before writing"
186
+ "Please call the 'writer.open()' method before writing."
189
187
  )
190
188
 
191
189
  def log_hyperparams(self, hyperparams: dict) -> None:
@@ -231,9 +229,9 @@ class TensorBoardWriter(BaseWriter):
231
229
  def log_model(
232
230
  self,
233
231
  model: Module,
234
- train_dataloader: DataLoader | None = None,
235
- val_dataloader: DataLoader | None = None,
236
- test_dataloader: DataLoader | None = None,
232
+ train_dataloader: DataLoader | DictDataLoader | None = None,
233
+ val_dataloader: DataLoader | DictDataLoader | None = None,
234
+ test_dataloader: DataLoader | DictDataLoader | None = None,
237
235
  ) -> None:
238
236
  """
239
237
  Logs the model.
@@ -242,9 +240,9 @@ class TensorBoardWriter(BaseWriter):
242
240
 
243
241
  Args:
244
242
  model (Module): The model to log.
245
- train_dataloader (DataLoader | None): DataLoader for training data.
246
- val_dataloader (DataLoader | None): DataLoader for validation data.
247
- test_dataloader (DataLoader | None): DataLoader for testing data.
243
+ train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
244
+ val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
245
+ test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
248
246
  """
249
247
  logger.warning("Model logging is not supported by tensorboard. No model will be logged.")
250
248
 
@@ -259,6 +257,14 @@ class MLFlowWriter(BaseWriter):
259
257
  """
260
258
 
261
259
  def __init__(self) -> None:
260
+ try:
261
+ from mlflow.entities import Run
262
+ except ImportError:
263
+ raise ImportError(
264
+ "mlflow is not installed. Please install qadence with the mlflow feature: "
265
+ "`pip install qadence[mlflow]`."
266
+ )
267
+
262
268
  self.run: Run
263
269
  self.mlflow: ModuleType
264
270
 
@@ -274,6 +280,8 @@ class MLFlowWriter(BaseWriter):
274
280
  Returns:
275
281
  mlflow: The MLflow module instance.
276
282
  """
283
+ import mlflow
284
+
277
285
  self.mlflow = mlflow
278
286
  tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "")
279
287
  experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", str(uuid4()))
@@ -298,22 +306,21 @@ class MLFlowWriter(BaseWriter):
298
306
  if self.run:
299
307
  self.mlflow.end_run()
300
308
 
301
- def write(self, result: OptimizeResult) -> None:
309
+ def write(self, iteration: int, metrics: dict) -> None:
302
310
  """
303
311
  Logs the results of the current iteration to MLflow.
304
312
 
305
313
  Args:
306
- result (OptimizeResult): The optimization results to log.
314
+ iteration (int): The current training iteration.
315
+ metrics (dict): A dictionary of metrics to log, where keys are metric names
316
+ and values are the corresponding metric values.
307
317
  """
308
- # Not writing loss as loss is available in the metrics
309
- # if result.loss is not None:
310
- # self.mlflow.log_metric("loss", float(result.loss), step=result.iteration)
311
318
  if self.mlflow:
312
- self.mlflow.log_metrics(result.metrics, step=result.iteration)
319
+ self.mlflow.log_metrics(metrics, step=iteration)
313
320
  else:
314
321
  raise RuntimeError(
315
322
  "The writer is not initialized."
316
- "Please call the 'writer.open()' method before writing"
323
+ "Please call the 'writer.open()' method before writing."
317
324
  )
318
325
 
319
326
  def log_hyperparams(self, hyperparams: dict) -> None:
@@ -356,17 +363,21 @@ class MLFlowWriter(BaseWriter):
356
363
  "Please call the 'writer.open()' method before writing"
357
364
  )
358
365
 
359
- def get_signature_from_dataloader(self, model: Module, dataloader: DataLoader | None) -> Any:
366
+ def get_signature_from_dataloader(
367
+ self, model: Module, dataloader: DataLoader | DictDataLoader | None
368
+ ) -> Any:
360
369
  """
361
370
  Infers the signature of the model based on the input data from the dataloader.
362
371
 
363
372
  Args:
364
373
  model (Module): The model to use for inference.
365
- dataloader (DataLoader | None): DataLoader for model inputs.
374
+ dataloader (DataLoader | DictDataLoader | None): DataLoader for model inputs.
366
375
 
367
376
  Returns:
368
377
  Optional[Any]: The inferred signature, if available.
369
378
  """
379
+ from mlflow.models import infer_signature
380
+
370
381
  if dataloader is None:
371
382
  return None
372
383
 
@@ -384,18 +395,18 @@ class MLFlowWriter(BaseWriter):
384
395
  def log_model(
385
396
  self,
386
397
  model: Module,
387
- train_dataloader: DataLoader | None = None,
388
- val_dataloader: DataLoader | None = None,
389
- test_dataloader: DataLoader | None = None,
398
+ train_dataloader: DataLoader | DictDataLoader | None = None,
399
+ val_dataloader: DataLoader | DictDataLoader | None = None,
400
+ test_dataloader: DataLoader | DictDataLoader | None = None,
390
401
  ) -> None:
391
402
  """
392
403
  Logs the model and its signature to MLflow using the provided data loaders.
393
404
 
394
405
  Args:
395
406
  model (Module): The model to log.
396
- train_dataloader (DataLoader | None): DataLoader for training data.
397
- val_dataloader (DataLoader | None): DataLoader for validation data.
398
- test_dataloader (DataLoader | None): DataLoader for testing data.
407
+ train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
408
+ val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
409
+ test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
399
410
  """
400
411
  if not self.mlflow:
401
412
  raise RuntimeError(
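With the new signature, writers receive the iteration number and a plain metrics dict instead of a full `OptimizeResult`. A minimal stand-in illustrating the contract (`PrintWriter` is hypothetical, not a qadence class; the shipped implementations are `TensorBoardWriter` and `MLFlowWriter`):

```python
# Hypothetical stand-in illustrating the write(iteration, metrics) contract.
class PrintWriter:
    def write(self, iteration: int, metrics: dict) -> None:
        for name, value in metrics.items():
            print(f"[{iteration}] {name} = {value}")

writer = PrintWriter()
writer.write(100, {"train_loss": 0.042, "val_loss": 0.051})
```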
qadence/ml_tools/train_utils/base_trainer.py CHANGED
@@ -8,11 +8,11 @@ import nevergrad as ng
8
8
  import torch
9
9
  from nevergrad.optimization.base import Optimizer as NGOptimizer
10
10
  from torch import nn, optim
11
- from torch.utils.data import DataLoader
11
+ from torch.utils.data import DataLoader, TensorDataset
12
12
 
13
13
  from qadence.ml_tools.callbacks import CallbacksManager
14
14
  from qadence.ml_tools.config import TrainConfig
15
- from qadence.ml_tools.data import InfiniteTensorDataset
15
+ from qadence.ml_tools.data import DictDataLoader
16
16
  from qadence.ml_tools.loss import get_loss_fn
17
17
  from qadence.ml_tools.optimize_step import optimize_step
18
18
  from qadence.ml_tools.parameters import get_parameters
@@ -42,9 +42,9 @@ class BaseTrainer:
42
42
  model (nn.Module): The neural network model.
43
43
  optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
44
44
  config (TrainConfig): The configuration settings for training.
45
- train_dataloader (DataLoader | None): DataLoader for training data.
46
- val_dataloader (DataLoader | None): DataLoader for validation data.
47
- test_dataloader (DataLoader | None): DataLoader for testing data.
45
+ train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
46
+ val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
47
+ test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
48
48
 
49
49
  optimize_step (Callable): Function for performing an optimization step.
50
50
  loss_fn (Callable | str): loss function to use. Default loss function
@@ -69,9 +69,9 @@ class BaseTrainer:
69
69
  config: TrainConfig,
70
70
  loss_fn: str | Callable = "mse",
71
71
  optimize_step: Callable = optimize_step,
72
- train_dataloader: DataLoader | None = None,
73
- val_dataloader: DataLoader | None = None,
74
- test_dataloader: DataLoader | None = None,
72
+ train_dataloader: DataLoader | DictDataLoader | None = None,
73
+ val_dataloader: DataLoader | DictDataLoader | None = None,
74
+ test_dataloader: DataLoader | DictDataLoader | None = None,
75
75
  max_batches: int | None = None,
76
76
  ):
77
77
  """
@@ -86,11 +86,11 @@ class BaseTrainer:
86
86
  str input to be specified to use a default loss function.
87
87
  currently supported loss functions: 'mse', 'cross_entropy'.
88
88
  If not specified, default mse loss will be used.
89
- train_dataloader (DataLoader | None): DataLoader for training data.
89
+ train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
90
90
  If the model does not need data to evaluate loss, no dataset
91
91
  should be provided.
92
- val_dataloader (DataLoader | None): DataLoader for validation data.
93
- test_dataloader (DataLoader | None): DataLoader for testing data.
92
+ val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
93
+ test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
94
94
  max_batches (int | None): Maximum number of batches to process per epoch.
95
95
  This is only valid in case of finite TensorDataset dataloaders.
96
96
  if max_batches is not None, the maximum number of batches used will
@@ -100,9 +100,9 @@ class BaseTrainer:
100
100
  self._model: nn.Module
101
101
  self._optimizer: optim.Optimizer | NGOptimizer | None
102
102
  self._config: TrainConfig
103
- self._train_dataloader: DataLoader | None = None
104
- self._val_dataloader: DataLoader | None = None
105
- self._test_dataloader: DataLoader | None = None
103
+ self._train_dataloader: DataLoader | DictDataLoader | None = None
104
+ self._val_dataloader: DataLoader | DictDataLoader | None = None
105
+ self._test_dataloader: DataLoader | DictDataLoader | None = None
106
106
 
107
107
  self.config = config
108
108
  self.model = model
@@ -311,7 +311,7 @@ class BaseTrainer:
311
311
  self.callback_manager = CallbacksManager(value)
312
312
  self.config_manager = ConfigManager(value)
313
313
 
314
- def _compute_num_batches(self, dataloader: DataLoader) -> int:
314
+ def _compute_num_batches(self, dataloader: DataLoader | DictDataLoader) -> int:
315
315
  """
316
316
  Computes the number of batches for the given DataLoader.
317
317
 
@@ -321,34 +321,41 @@ class BaseTrainer:
321
321
  """
322
322
  if dataloader is None:
323
323
  return 1
324
- dataset = dataloader.dataset
325
- if isinstance(dataset, InfiniteTensorDataset):
326
- return 1
324
+ if isinstance(dataloader, DictDataLoader):
325
+ dataloader_name, dataloader_value = list(dataloader.dataloaders.items())[0]
326
+ dataset = dataloader_value.dataset
327
+ batch_size = dataloader_value.batch_size
327
328
  else:
328
- n_batches = int(
329
- (dataset.tensors[0].size(0) + dataloader.batch_size - 1) // dataloader.batch_size
330
- )
329
+ dataset = dataloader.dataset
330
+ batch_size = dataloader.batch_size
331
+
332
+ if isinstance(dataset, TensorDataset):
333
+ n_batches = int((dataset.tensors[0].size(0) + batch_size - 1) // batch_size)
331
334
  return min(self.max_batches, n_batches) if self.max_batches is not None else n_batches
335
+ else:
336
+ return 1
332
337
 
333
- def _validate_dataloader(self, dataloader: DataLoader, dataloader_type: str) -> None:
338
+ def _validate_dataloader(
339
+ self, dataloader: DataLoader | DictDataLoader, dataloader_type: str
340
+ ) -> None:
334
341
  """
335
342
  Validates the type of the DataLoader and raises errors for unsupported types.
336
343
 
337
344
  Args:
338
- dataloader (DataLoader): The DataLoader to validate.
345
+ dataloader (DataLoader | DictDataLoader): The DataLoader to validate.
339
346
  dataloader_type (str): The type of DataLoader ("train", "val", or "test").
340
347
  """
341
348
  if dataloader is not None:
342
- if not isinstance(dataloader, DataLoader):
349
+ if not isinstance(dataloader, (DataLoader, DictDataLoader)):
343
350
  raise NotImplementedError(
344
351
  f"Unsupported dataloader type: {type(dataloader)}."
345
352
  "The dataloader must be an instance of DataLoader."
346
353
  )
347
354
  if dataloader_type == "val" and self.config.val_every > 0:
348
- if not isinstance(dataloader, DataLoader):
355
+ if not isinstance(dataloader, (DataLoader, DictDataLoader)):
349
356
  raise ValueError(
350
357
  "If `config.val_every` is provided as an integer > 0, validation_dataloader"
351
- "must be an instance of `DataLoader`."
358
+ "must be an instance of `DataLoader` or `DictDataLoader`."
352
359
  )
353
360
 
354
361
  @staticmethod
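`_compute_num_batches` and `_validate_dataloader` now also accept a `DictDataLoader`, and batch counting only applies when the underlying dataset is a finite `TensorDataset`. A hedged sketch of building such a loader, assuming `DictDataLoader` simply wraps a dict of named `DataLoader`s (which is what the `dataloader.dataloaders` access above implies):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from qadence.ml_tools.data import DictDataLoader

x = torch.rand(100, 1)
loaders = {
    "sine": DataLoader(TensorDataset(x, torch.sin(x)), batch_size=16),
    "cosine": DataLoader(TensorDataset(x, torch.cos(x)), batch_size=16),
}
ddl = DictDataLoader(loaders)
# With a finite TensorDataset of 100 samples and batch_size 16, the trainer
# would count ceil(100 / 16) = 7 batches from the first named loader.
```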
qadence/ml_tools/trainer.py CHANGED
@@ -14,7 +14,7 @@ from torch import dtype as torch_dtype
14
14
  from torch.utils.data import DataLoader
15
15
 
16
16
  from qadence.ml_tools.config import TrainConfig
17
- from qadence.ml_tools.data import OptimizeResult
17
+ from qadence.ml_tools.data import DictDataLoader, OptimizeResult
18
18
  from qadence.ml_tools.optimize_step import optimize_step, update_ng_parameters
19
19
  from qadence.ml_tools.stages import TrainingStage
20
20
 
@@ -49,9 +49,9 @@ class Trainer(BaseTrainer):
49
49
  model (nn.Module): The neural network model.
50
50
  optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
51
51
  config (TrainConfig): The configuration settings for training.
52
- train_dataloader (DataLoader | None): DataLoader for training data.
53
- val_dataloader (DataLoader | None): DataLoader for validation data.
54
- test_dataloader (DataLoader | None): DataLoader for testing data.
52
+ train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
53
+ val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
54
+ test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
55
55
 
56
56
  optimize_step (Callable): Function for performing an optimization step.
57
57
  loss_fn (Callable): loss function to use.
@@ -235,9 +235,9 @@ class Trainer(BaseTrainer):
235
235
  optimizer: optim.Optimizer | NGOptimizer | None,
236
236
  config: TrainConfig,
237
237
  loss_fn: str | Callable = "mse",
238
- train_dataloader: DataLoader | None = None,
239
- val_dataloader: DataLoader | None = None,
240
- test_dataloader: DataLoader | None = None,
238
+ train_dataloader: DataLoader | DictDataLoader | None = None,
239
+ val_dataloader: DataLoader | DictDataLoader | None = None,
240
+ test_dataloader: DataLoader | DictDataLoader | None = None,
241
241
  optimize_step: Callable = optimize_step,
242
242
  device: torch_device | None = None,
243
243
  dtype: torch_dtype | None = None,
@@ -252,9 +252,9 @@ class Trainer(BaseTrainer):
252
252
  config (TrainConfig): Training configuration object.
253
253
  loss_fn (str | Callable ): Loss function used for training.
254
254
  If not specified, default mse loss will be used.
255
- train_dataloader (DataLoader | None): DataLoader for training data.
256
- val_dataloader (DataLoader | None): DataLoader for validation data.
257
- test_dataloader (DataLoader | None): DataLoader for test data.
255
+ train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
256
+ val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
257
+ test_dataloader (DataLoader | DictDataLoader | None): DataLoader for test data.
258
258
  optimize_step (Callable): Function to execute an optimization step.
259
259
  device (torch_device): Device to use for computation.
260
260
  dtype (torch_dtype): Data type for computation.
@@ -281,11 +281,14 @@ class Trainer(BaseTrainer):
281
281
  self.device: torch_device | None = device
282
282
  self.dtype: torch_dtype | None = dtype
283
283
  self.data_dtype: torch_dtype | None = None
284
+ self.stop_training: bool = False
284
285
  if self.dtype:
285
286
  self.data_dtype = float64 if (self.dtype == complex128) else float32
286
287
 
287
288
  def fit(
288
- self, train_dataloader: DataLoader | None = None, val_dataloader: DataLoader | None = None
289
+ self,
290
+ train_dataloader: DataLoader | DictDataLoader | None = None,
291
+ val_dataloader: DataLoader | DictDataLoader | None = None,
289
292
  ) -> tuple[nn.Module, optim.Optimizer]:
290
293
  """
291
294
  Fits the model using the specified training configuration.
@@ -294,8 +297,8 @@ class Trainer(BaseTrainer):
294
297
  provided in the trainer will be used.
295
298
 
296
299
  Args:
297
- train_dataloader (DataLoader | None): DataLoader for training data.
298
- val_dataloader (DataLoader | None): DataLoader for validation data.
300
+ train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
301
+ val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
299
302
 
300
303
  Returns:
301
304
  tuple[nn.Module, optim.Optimizer]: The trained model and optimizer.
@@ -319,6 +322,7 @@ class Trainer(BaseTrainer):
319
322
  The callback_manager.start_training takes care of loading checkpoint,
320
323
  and setting up the writer.
321
324
  """
325
+ self.stop_training = False
322
326
  self.config_manager.initialize_config()
323
327
  self.callback_manager.start_training(trainer=self)
324
328
 
@@ -336,10 +340,8 @@ class Trainer(BaseTrainer):
336
340
  TimeRemainingColumn(elapsed_when_finished=True),
337
341
  )
338
342
 
339
- # Quick Fix for build_optimize_step
340
- # Please review run_train_batch for more details
341
- self.model_old = copy.deepcopy(self.model)
342
- self.optimizer_old = copy.deepcopy(self.optimizer)
343
+ # Quick Fix for iteration 0
344
+ self._reset_model_and_opt()
343
345
 
344
346
  # Run validation at the start if specified in the configuration
345
347
  self.perform_val = self.config.val_every > 0
@@ -377,25 +379,26 @@ class Trainer(BaseTrainer):
377
379
  for epoch in range(
378
380
  self.global_step, self.global_step + self.config_manager.config.max_iter + 1
379
381
  ):
380
- try:
381
- self.current_epoch = epoch
382
- self.on_train_epoch_start()
383
- train_epoch_loss_metrics = self.run_training(self.train_dataloader)
384
- train_losses.append(train_epoch_loss_metrics)
385
- self.on_train_epoch_end(train_epoch_loss_metrics)
386
-
387
- # Run validation periodically if specified
388
- if self.perform_val and self.current_epoch % self.config.val_every == 0:
389
- self.on_val_epoch_start()
390
- val_epoch_loss_metrics = self.run_validation(self.val_dataloader)
391
- val_losses.append(val_epoch_loss_metrics)
392
- self.on_val_epoch_end(val_epoch_loss_metrics)
393
- self.progress.update(val_task, advance=1)
394
-
395
- self.progress.update(train_task, advance=1)
396
- except KeyboardInterrupt:
397
- logger.info("Terminating training gracefully after the current iteration.")
398
- break
382
+ if not self.stop_training:
383
+ try:
384
+ self.current_epoch = epoch
385
+ self.on_train_epoch_start()
386
+ train_epoch_loss_metrics = self.run_training(self.train_dataloader)
387
+ train_losses.append(train_epoch_loss_metrics)
388
+ self.on_train_epoch_end(train_epoch_loss_metrics)
389
+
390
+ # Run validation periodically if specified
391
+ if self.perform_val and self.current_epoch % self.config.val_every == 0:
392
+ self.on_val_epoch_start()
393
+ val_epoch_loss_metrics = self.run_validation(self.val_dataloader)
394
+ val_losses.append(val_epoch_loss_metrics)
395
+ self.on_val_epoch_end(val_epoch_loss_metrics)
396
+ self.progress.update(val_task, advance=1)
397
+
398
+ self.progress.update(train_task, advance=1)
399
+ except KeyboardInterrupt:
400
+ logger.info("Terminating training gracefully after the current iteration.")
401
+ break
399
402
 
400
403
  self.on_train_end(train_losses, val_losses)
401
404
  return train_losses
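The epoch loop now honours `self.stop_training`, the flag `EarlyStopping` sets; any custom callback can use it the same way. A minimal sketch with a hypothetical loss threshold, assuming `train_loss` is among the logged metrics:

```python
from typing import Any

from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks import Callback


class StopBelowLoss(Callback):
    """Hypothetical callback: end training once train_loss drops below a threshold."""

    def __init__(self, on: str, called_every: int = 1, threshold: float = 1e-3):
        super().__init__(on=on, called_every=called_every)
        self.threshold = threshold

    def run_callback(self, trainer: Any, config: TrainConfig, writer: Any) -> None:
        loss = trainer.opt_result.metrics.get("train_loss")
        if loss is not None and loss < self.threshold:
            trainer.stop_training = True  # picked up by the epoch loop above


config = TrainConfig(max_iter=10000, callbacks=[StopBelowLoss(on="train_epoch_end")])
```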
@@ -415,16 +418,10 @@ class Trainer(BaseTrainer):
415
418
  """
416
419
  self.model.train()
417
420
  train_epoch_loss_metrics = []
418
- # Deep copy model and optimizer to maintain checkpoints
419
- # We do this because optimize step provides loss, metrics
420
- # before step of optimization
421
- # To align them with model/optimizer correctly, we checkpoint
422
- # the older copy of the model.
423
- # TODO: review optimize_step to provide iteration aligned model and loss.
424
- self.model_old = copy.deepcopy(self.model)
425
- self.optimizer_old = copy.deepcopy(self.optimizer)
421
+ # Quick Fix for iteration 0
422
+ self._reset_model_and_opt()
426
423
 
427
- for batch in self.batch_iter(dataloader, self.num_training_batches):
424
+ for batch in self._batch_iter(dataloader, self.num_training_batches):
428
425
  self.on_train_batch_start(batch)
429
426
  train_batch_loss_metrics = self.run_train_batch(batch)
430
427
  train_epoch_loss_metrics.append(train_batch_loss_metrics)
@@ -475,7 +472,7 @@ class Trainer(BaseTrainer):
475
472
  self.ng_params = ng_params
476
473
  loss_metrics = loss, metrics
477
474
 
478
- return self.modify_batch_end_loss_metrics(loss_metrics)
475
+ return self._modify_batch_end_loss_metrics(loss_metrics)
479
476
 
480
477
  @BaseTrainer.callback("val_epoch")
481
478
  def run_validation(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]:
@@ -493,7 +490,7 @@ class Trainer(BaseTrainer):
493
490
  self.model.eval()
494
491
  val_epoch_loss_metrics = []
495
492
 
496
- for batch in self.batch_iter(dataloader, self.num_validation_batches):
493
+ for batch in self._batch_iter(dataloader, self.num_validation_batches):
497
494
  self.on_val_batch_start(batch)
498
495
  val_batch_loss_metrics = self.run_val_batch(batch)
499
496
  val_epoch_loss_metrics.append(val_batch_loss_metrics)
@@ -514,7 +511,7 @@ class Trainer(BaseTrainer):
514
511
  """
515
512
  with torch.no_grad():
516
513
  loss_metrics = self.loss_fn(self.model, batch)
517
- return self.modify_batch_end_loss_metrics(loss_metrics)
514
+ return self._modify_batch_end_loss_metrics(loss_metrics)
518
515
 
519
516
  def test(self, test_dataloader: DataLoader = None) -> list[tuple[torch.Tensor, dict[str, Any]]]:
520
517
  """
@@ -537,7 +534,7 @@ class Trainer(BaseTrainer):
537
534
  self.model.eval()
538
535
  test_loss_metrics = []
539
536
 
540
- for batch in self.batch_iter(test_dataloader, self.num_training_batches):
537
+ for batch in self._batch_iter(test_dataloader, self.num_training_batches):
541
538
  self.on_test_batch_start(batch)
542
539
  loss_metrics = self.run_test_batch(batch)
543
540
  test_loss_metrics.append(loss_metrics)
@@ -560,11 +557,11 @@ class Trainer(BaseTrainer):
560
557
  """
561
558
  with torch.no_grad():
562
559
  loss_metrics = self.loss_fn(self.model, batch)
563
- return self.modify_batch_end_loss_metrics(loss_metrics)
560
+ return self._modify_batch_end_loss_metrics(loss_metrics)
564
561
 
565
- def batch_iter(
562
+ def _batch_iter(
566
563
  self,
567
- dataloader: DataLoader,
564
+ dataloader: DataLoader | DictDataLoader,
568
565
  num_batches: int,
569
566
  ) -> Iterable[tuple[torch.Tensor, ...] | None]:
570
567
  """
@@ -587,7 +584,7 @@ class Trainer(BaseTrainer):
587
584
  # batch = data_to_device(batch, device=self.device, dtype=self.data_dtype)
588
585
  yield batch
589
586
 
590
- def modify_batch_end_loss_metrics(
587
+ def _modify_batch_end_loss_metrics(
591
588
  self, loss_metrics: tuple[torch.Tensor, dict[str, Any]]
592
589
  ) -> tuple[torch.Tensor, dict[str, Any]]:
593
590
  """
@@ -611,6 +608,28 @@ class Trainer(BaseTrainer):
611
608
  return loss, updated_metrics
612
609
  return loss_metrics
613
610
 
611
+ def _reset_model_and_opt(self) -> None:
612
+ """
613
+ Save model_old and optimizer_old for epoch 0.
614
+
615
+ This allows us to create a copy of model
616
+ and optimizer before running the optimization.
617
+
618
+ We do this because optimize step provides loss, metrics
619
+ before step of optimization
620
+ To align them with model/optimizer correctly, we checkpoint
621
+ the older copy of the model.
622
+ """
623
+
624
+ # TODO: review optimize_step to provide iteration aligned model and loss.
625
+ try:
626
+ # Deep copy model and optimizer to maintain checkpoints
627
+ self.model_old = copy.deepcopy(self.model)
628
+ self.optimizer_old = copy.deepcopy(self.optimizer)
629
+ except Exception:
630
+ self.model_old = self.model
631
+ self.optimizer_old = self.optimizer
632
+
614
633
  def build_optimize_result(
615
634
  self,
616
635
  result: None
qadence/states.py CHANGED
@@ -6,6 +6,7 @@ from typing import List
6
6
 
7
7
  import torch
8
8
  from numpy.typing import ArrayLike
9
+ from pyqtorch.utils import DensityMatrix
9
10
  from torch import Tensor, concat
10
11
  from torch.distributions import Categorical, Distribution
11
12
 
@@ -37,6 +38,8 @@ __all__ = [
37
38
  "is_normalized",
38
39
  "rand_bitstring",
39
40
  "equivalent_state",
41
+ "DensityMatrix",
42
+ "density_mat",
40
43
  ]
41
44
 
42
45
  ATOL_64 = 1e-14 # 64 bit precision
@@ -319,6 +322,24 @@ def random_state(
319
322
  return state
320
323
 
321
324
 
325
+ # DENSITY MATRIX
326
+
327
+
328
+ def density_mat(state: Tensor) -> DensityMatrix:
329
+ """
330
+ Computes the density matrix from a pure state vector.
331
+
332
+ Arguments:
333
+ state: The pure state vector :math:`|\\psi\\rangle`.
334
+
335
+ Returns:
336
+ Tensor: The density matrix :math:`\\rho = |\psi \\rangle \\langle\\psi|`.
337
+ """
338
+ if isinstance(state, DensityMatrix):
339
+ return state
340
+ return DensityMatrix(torch.einsum("bi,bj->bij", (state, state.conj())))
341
+
342
+
322
343
  # BLOCKS
323
344
 
324
345
 
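`density_mat` builds the density matrix as a batched outer product, rho = |psi><psi|. A small worked check on the single-qubit uniform (plus) state, assuming `uniform_state` keeps its existing signature:

```python
from qadence.states import density_mat, uniform_state

psi = uniform_state(n_qubits=1)   # |+>: amplitudes (1/sqrt(2), 1/sqrt(2)), batch of 1
rho = density_mat(psi)            # rho = |psi><psi|
print(rho)                        # a 1x2x2 matrix with every entry equal to 0.5
```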
qadence/types.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
9
9
  import sympy
10
10
  from matplotlib.figure import Figure
11
11
  from numpy.typing import ArrayLike
12
- from pyqtorch.noise import NoiseType as DigitalNoise
12
+ from pyqtorch.noise import DigitalNoiseType as DigitalNoise
13
13
  from pyqtorch.noise.readout import WhiteNoise
14
14
  from pyqtorch.utils import DropoutMode, SolverType
15
15
  from torch import Tensor, pi
qadence-1.9.0.dist-info/METADATA → qadence-1.9.2.dist-info/METADATA CHANGED
@@ -1,9 +1,10 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: qadence
3
- Version: 1.9.0
3
+ Version: 1.9.2
4
4
  Summary: Pasqal interface for circuit-based quantum computing SDKs
5
5
  Author-email: Aleksander Wennersteen <aleksander.wennersteen@pasqal.com>, Gert-Jan Both <gert-jan.both@pasqal.com>, Niklas Heim <niklas.heim@pasqal.com>, Mario Dagrada <mario.dagrada@pasqal.com>, Vincent Elfving <vincent.elfving@pasqal.com>, Dominik Seitz <dominik.seitz@pasqal.com>, Roland Guichard <roland.guichard@pasqal.com>, "Joao P. Moutinho" <joao.moutinho@pasqal.com>, Vytautas Abramavicius <vytautas.abramavicius@pasqal.com>, Gergana Velikova <gergana.velikova@pasqal.com>, Eduardo Maschio <eduardo.maschio@pasqal.com>, Smit Chaudhary <smit.chaudhary@pasqal.com>, Ignacio Fernández Graña <ignacio.fernandez-grana@pasqal.com>, Charles Moussa <charles.moussa@pasqal.com>, Giorgio Tosti Balducci <giorgio.tosti-balducci@pasqal.com>, Daniele Cucurachi <daniele.cucurachi@pasqal.com>
6
6
  License: Apache 2.0
7
+ License-File: LICENSE
7
8
  Classifier: License :: OSI Approved :: Apache Software License
8
9
  Classifier: Programming Language :: Python
9
10
  Classifier: Programming Language :: Python :: 3
@@ -21,7 +22,7 @@ Requires-Dist: matplotlib
21
22
  Requires-Dist: nevergrad
22
23
  Requires-Dist: numpy
23
24
  Requires-Dist: openfermion
24
- Requires-Dist: pyqtorch==1.6.0
25
+ Requires-Dist: pyqtorch==1.7.0
25
26
  Requires-Dist: pyyaml
26
27
  Requires-Dist: rich
27
28
  Requires-Dist: scipy
@@ -53,9 +54,9 @@ Requires-Dist: mlflow; extra == 'mlflow'
53
54
  Provides-Extra: protocols
54
55
  Requires-Dist: qadence-protocols; extra == 'protocols'
55
56
  Provides-Extra: pulser
56
- Requires-Dist: pasqal-cloud==0.12.5; extra == 'pulser'
57
- Requires-Dist: pulser-core==1.1.1; extra == 'pulser'
58
- Requires-Dist: pulser-simulation==1.1.1; extra == 'pulser'
57
+ Requires-Dist: pasqal-cloud==0.12.6; extra == 'pulser'
58
+ Requires-Dist: pulser-core==1.2.0; extra == 'pulser'
59
+ Requires-Dist: pulser-simulation==1.2.0; extra == 'pulser'
59
60
  Provides-Extra: visualization
60
61
  Requires-Dist: graphviz; extra == 'visualization'
61
62
  Description-Content-Type: text/markdown
@@ -80,6 +81,8 @@ programs** with tunable qubit interactions and arbitrary register topologies rea
80
81
  [![Documentation](https://github.com/pasqal-io/qadence/actions/workflows/build_docs.yml/badge.svg)](https://pasqal-io.github.io/qadence/latest)
81
82
  [![Pypi](https://badge.fury.io/py/qadence.svg)](https://pypi.org/project/qadence/)
82
83
  [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
84
+ ![Coverage](https://img.shields.io/codecov/c/github/pasqal-io/qadence?style=flat-square)
85
+
83
86
 
84
87
  ## Feature highlights
85
88
 
qadence-1.9.0.dist-info/RECORD → qadence-1.9.2.dist-info/RECORD CHANGED
@@ -17,26 +17,26 @@ qadence/qubit_support.py,sha256=Nkn1Q01RVViTcggSIom7EFKdWpAuM4TMGwBZ5feCUxA,2120
17
17
  qadence/register.py,sha256=mwmvS6PcTY0F9cIhTUXG3NT73FIagfMCwVqYa4DrQrk,13001
18
18
  qadence/serial_expr_grammar.peg,sha256=z5ytL7do9kO8o4h-V5GrsDuLdso0KsRcMuIYURFfmAY,328
19
19
  qadence/serialization.py,sha256=qEET6Gu9u2aSibPve3bJrqDzK2_gO3RPDJjt4ZY8GbE,15596
20
- qadence/states.py,sha256=5QIOBBYs8e2uLFiMa8iMYZ-MvWIFEqkZAjNYx0SyYPI,14843
21
- qadence/types.py,sha256=NR1dnN4tKC3zL1KEzjs2p_IYVth7GctgXVbIJsuiUws,11992
20
+ qadence/states.py,sha256=OMFuPAmPTLfZYwefXMv82P96xp5aBDJpanmCNgkRO-o,15379
21
+ qadence/types.py,sha256=Jhd_qTI8X7R61LcueNfIsODLUFB7WfVHWiJpsQkrixs,11999
22
22
  qadence/utils.py,sha256=zb2j7wURfy8kazaS84r4t35vAeDpo4Tpm4HbmPH-kFA,9865
23
23
  qadence/analog/__init__.py,sha256=BCyS9R4KUjzUXN0Ax3b0eMo8ZAuSkGoJQVtZ4_pvAFs,279
24
24
  qadence/analog/addressing.py,sha256=GSt4heEmRkBmoQIgdgkTclEFxZY-jjuAd77_SsZtGdI,6513
25
25
  qadence/analog/constants.py,sha256=B2phQoN1ASL8CwM-Dsa1rbraYwGwwPSeiB3HbVe-MPA,1243
26
- qadence/analog/device.py,sha256=LK8rQYBiK_PWHezLfTL0Ig83yc5wQPmZ3rBZF-gWgYw,2416
26
+ qadence/analog/device.py,sha256=t7oGjiZhk28IG2C-SVkc0RNSlV1L4SXV-tkLNiSYFNM,2570
27
27
  qadence/analog/hamiltonian_terms.py,sha256=9LKidqqEMJTTdXeaxkxP_otTmcv9i4yeJ-JKCLOCK3Y,3421
28
28
  qadence/analog/parse_analog.py,sha256=9Y_LMdw4wCHH6YSkvHhs6PUNwzT14HS7cUGheNSmDQg,4168
29
29
  qadence/backends/__init__.py,sha256=ibm7wmZxuIoMYAQxgAx0MsfLYWOVHNWgLwyS1HjMuuI,215
30
30
  qadence/backends/api.py,sha256=NPrvtZQ4klUBabUWJ5hbTUCVoaoW9-sHVbiXxAnTt3A,2643
31
31
  qadence/backends/gpsr.py,sha256=HW5m6iHLq3hLHdJoU1q1i1laR0hBs7uCniXqrsFoNCI,5616
32
32
  qadence/backends/jax_utils.py,sha256=VfKhqCKknHDWZO21UFipWH_Lkiq175Z5GkP49gWjbyw,5038
33
- qadence/backends/utils.py,sha256=SdeAf6BDayc-W9hPI271PKtJ_ROOzRry9Jw72wUGmm4,8565
33
+ qadence/backends/utils.py,sha256=SSiMxZjaFS8e8sB6ZBLXPKuJNQGl93pRMy9hnI4oDrw,9104
34
34
  qadence/backends/horqrux/__init__.py,sha256=0OdVy6cq0oQggV48LO1WXdaZuSkDkz7OYNEPIkNAmfk,140
35
35
  qadence/backends/horqrux/backend.py,sha256=KNFFGN9dsgB9QKtNXiP3LyMY9DQ-7W7ScyE6k29fHJY,8842
36
36
  qadence/backends/horqrux/config.py,sha256=xz7JlUcwW_4JAbvProbSI9hA1SXZRRAN0Hr2bvmLzfg,892
37
37
  qadence/backends/horqrux/convert_ops.py,sha256=3uG3yLq5wjfrWzFHDs0HEnd8kER91ZHVX3HCpYjOdjk,8565
38
38
  qadence/backends/pulser/__init__.py,sha256=capQ-eHqwtOeLf4mWsI0BIseAHhiLGie5cFD4-iVhUo,116
39
- qadence/backends/pulser/backend.py,sha256=kNo_AAvOPNevdiMbZnQlEUbWCO-NXcwUtFWQHJ6VR2Q,14845
39
+ qadence/backends/pulser/backend.py,sha256=cI4IgijPpItNdDmLpKkJFas0X02wMiZd_XmVas41gEI,14846
40
40
  qadence/backends/pulser/channels.py,sha256=ZF0yEXUFHAmi3IdeXjzdTNGR5NzaRRFTiUpUGVg2sO4,329
41
41
  qadence/backends/pulser/cloud.py,sha256=0uUluvbFV9sOuCPraE-9uiVtC3Q8QaDY1IJMDi8grDM,2057
42
42
  qadence/backends/pulser/config.py,sha256=aoHDmtgq5i0Zryxenw_p3uARY0B1w-UaYvfqDmrWHM0,2175
@@ -47,7 +47,7 @@ qadence/backends/pulser/waveforms.py,sha256=0uz95b7rUaUUtN0tuHBZmJ0H6UBmfHST_59o
47
47
  qadence/backends/pyqtorch/__init__.py,sha256=0OdVy6cq0oQggV48LO1WXdaZuSkDkz7OYNEPIkNAmfk,140
48
48
  qadence/backends/pyqtorch/backend.py,sha256=Sjuof9b332w4gk9o8Rso2rgSHxskexfkIazRfxRD0Ng,11458
49
49
  qadence/backends/pyqtorch/config.py,sha256=sAxWVSkWvj6Lu0em1KJCDb6nfjqe8Dsxi7pyh6qYJpA,2387
50
- qadence/backends/pyqtorch/convert_ops.py,sha256=QVlXAqqBPtBt2NuFU9mvEm5i7FWXg3gWqoeizlH-1_s,13401
50
+ qadence/backends/pyqtorch/convert_ops.py,sha256=qG26-HmtUDaZO0KDnw2sbT3CRx_poS7eqJ3dn9wpWgc,13457
51
51
  qadence/blocks/__init__.py,sha256=H6jEA_CptkE-eoB4UfSbUiDszbxxhZwECV_TgoZWXoU,960
52
52
  qadence/blocks/abstract.py,sha256=DSQUE71rMyRBwAP--4Tx1WQC_LCXaNlftjd7goGyrpQ,12027
53
53
  qadence/blocks/analog.py,sha256=ymnnlSVoW1XL05ZvnnHCqRTHuOXIEY_7E9M0PNKJZy4,10812
@@ -86,7 +86,7 @@ qadence/engines/jax/differentiable_backend.py,sha256=FcSrzzjzb0zfXC0-4mUJ6UB-wGO
86
86
  qadence/engines/jax/differentiable_expectation.py,sha256=rn_l7IH-S4IvuAcyAIgyEuMZOIqswu5Nsfz0JffXjaE,3694
87
87
  qadence/engines/torch/__init__.py,sha256=iZFdD32ot0B0CVyC-f5dVViOBnqoalxa6M9Lj4WQuPE,160
88
88
  qadence/engines/torch/differentiable_backend.py,sha256=uQfyGg-25MAc0soK1FyvJ2FJakRuv5_5DOy7OPiZYg8,3567
89
- qadence/engines/torch/differentiable_expectation.py,sha256=AXb1nLG1WuJXUOH02OhRVKASXst12d1b8Y2eUMmPt5M,10304
89
+ qadence/engines/torch/differentiable_expectation.py,sha256=kc4WTos7d65DDmao6YSrpTM0rCBnpqhGK4xLHm_K4yk,10351
90
90
  qadence/exceptions/__init__.py,sha256=BU6vWrI9mshzr1aTPm1Ticr_o_42GjTrWI4OZXhThsI,203
91
91
  qadence/exceptions/exceptions.py,sha256=4j_VJpx2sZ2Mir5BJUWu4nwb131FY1ygO4q8-XlyfRc,190
92
92
  qadence/measurements/__init__.py,sha256=RIjG9tVJMqhNzyj7maZI250Um0KgHl2PizDcKJag-JU,161
@@ -108,17 +108,17 @@ qadence/ml_tools/optimize_step.py,sha256=wUnxfWy0c9rEKe41-26On1bPFBwmSYBF4WCGn76
108
108
  qadence/ml_tools/parameters.py,sha256=gew2Kq_5-RgRpaTvs8eauVhgo0sTqqDQEV6WHFEiLGM,1301
109
109
  qadence/ml_tools/stages.py,sha256=qW2phMIvQBLM3tn2UoGN-ePiBnZoNq5k844eHVnnn8Y,1407
110
110
  qadence/ml_tools/tensors.py,sha256=xZ9ZRzOqEaMgLUGWQf1najDmL6iLuN1ojCGVFs1Tm94,1337
111
- qadence/ml_tools/trainer.py,sha256=ic1UWu2lYJFSsg-FG-bcbge3pqDGocsZ2hH4Ln8uuLQ,26222
111
+ qadence/ml_tools/trainer.py,sha256=phKCr3-hmHTsKMoZ89z0U5KTZ_h7kaUX7w4WX7A0YH8,26990
112
112
  qadence/ml_tools/utils.py,sha256=PW8FyoV0mG_DtN1U8njTDV5qxZ0EK4mnFwMAsLBArfk,1410
113
- qadence/ml_tools/callbacks/__init__.py,sha256=XaUKmyQZaqxI0jvKnWCpIBgnX5y4Kczcbn2FRiomFu4,655
114
- qadence/ml_tools/callbacks/callback.py,sha256=F9tbXBBv3ZKTFbm0fGBZIZtTRO63jLazMk_oeL77dyE,16289
113
+ qadence/ml_tools/callbacks/__init__.py,sha256=pTdfjulDGNKca--9BgrdmMyvJSah_0spp929Th6RzC8,913
114
+ qadence/ml_tools/callbacks/callback.py,sha256=XoqTS1uLOkbh4FtKpDSXbUA5_LzjOAoVMaa2jYcYB3w,28800
115
115
  qadence/ml_tools/callbacks/callbackmanager.py,sha256=HwxgbqJi1GWYg2lgUqEyw9Y6a71YG_m5DmhpaeB6kLs,8007
116
116
  qadence/ml_tools/callbacks/saveload.py,sha256=2z8v1A3qIIPZuusEcSNqgYTnKGKkDj71KvY_atJvKnM,6015
117
- qadence/ml_tools/callbacks/writer_registry.py,sha256=Sl7OsBBzRCoOW5kQ1RMdAWS_y4TrIsGUI8XvA8JImJ4,14626
117
+ qadence/ml_tools/callbacks/writer_registry.py,sha256=_lPb4VvDHiiRNh2EaEKxOSslnJgBAImGw5SoVReg-Rs,15351
118
118
  qadence/ml_tools/loss/__init__.py,sha256=d_0FlisdmgLY0qL1PeaabbcWX1B42RBdm7220cfzSN4,247
119
119
  qadence/ml_tools/loss/loss.py,sha256=Bditg8nelMEpG4Yt0aopcAQz84xIc6O-AGUO2M0nqbM,2982
120
120
  qadence/ml_tools/train_utils/__init__.py,sha256=1A2FlFg7kn68R1fdRC73S8DzA9gkBW7whdNHjzH5UTA,235
121
- qadence/ml_tools/train_utils/base_trainer.py,sha256=7XrIV2qEV8qetYSH9Pg-RKqQUxGw8u7Xlz42yeVjh3Y,19864
121
+ qadence/ml_tools/train_utils/base_trainer.py,sha256=giOcBRMjgbq9sLjqck6MCWH8V1MCVBHarWuFrS-ahbw,20442
122
122
  qadence/ml_tools/train_utils/config_manager.py,sha256=dps94qfiwjhoY_aQp5RvQPd9zW_MIN2knw1UaDaYrKs,6896
123
123
  qadence/noise/__init__.py,sha256=tnChHv7FzOaV8C7O0P2l_gfjrpmHg8JaNhZprL33CP4,161
124
124
  qadence/noise/protocols.py,sha256=SPHJi5AiIOcz6U_iXY3ddVHk3cl9UHSDKk49eMTX2QM,8586
@@ -137,7 +137,7 @@ qadence/transpile/flatten.py,sha256=EdhSG5WyF56nbnxINNLqrHgY84MRM1YFjT3fR4aph5Q,
137
137
  qadence/transpile/invert.py,sha256=KAefHTG2AWr39aengVhXrzCtJPhrZC-ZnL6vYvmbnY0,4867
138
138
  qadence/transpile/noise.py,sha256=LDcDJtQGkgUPkL2t69gg6AScTb-p3J3SxCDZbYOu1L8,1668
139
139
  qadence/transpile/transpile.py,sha256=6MRRkk1OS279L1fwUQjazA6qlfpbd-T_EJMKT8hAhOU,2721
140
- qadence-1.9.0.dist-info/METADATA,sha256=HbwOPe7bOi1evqjHaPIN8R7xiR_SRF2r4RqQyatKGPo,9842
141
- qadence-1.9.0.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
142
- qadence-1.9.0.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
143
- qadence-1.9.0.dist-info/RECORD,,
140
+ qadence-1.9.2.dist-info/METADATA,sha256=JzJ9P6KRKQuAp8XeTW65OX5I6l9qc2aPMGGYNZczBpU,9954
141
+ qadence-1.9.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
142
+ qadence-1.9.2.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
143
+ qadence-1.9.2.dist-info/RECORD,,
qadence-1.9.0.dist-info/WHEEL → qadence-1.9.2.dist-info/WHEEL CHANGED
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.26.3
2
+ Generator: hatchling 1.27.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any