qadence 1.10.2__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,21 +4,20 @@ import copy
  from itertools import islice
  from logging import getLogger
  from typing import Any, Callable, Iterable, cast
-
- import torch
  from nevergrad.optimization.base import Optimizer as NGOptimizer
- from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
- from torch import complex128, float32, float64, nn, optim
- from torch import device as torch_device
- from torch import dtype as torch_dtype
+ import torch
+ from torch import nn, optim
  from torch.utils.data import DataLoader
+ from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn

  from qadence.ml_tools.config import TrainConfig
- from qadence.ml_tools.data import DictDataLoader, OptimizeResult
+ from qadence.ml_tools.data import DictDataLoader, OptimizeResult, data_to_device
+ from qadence.ml_tools.information import InformationContent
  from qadence.ml_tools.optimize_step import optimize_step, update_ng_parameters
  from qadence.ml_tools.stages import TrainingStage

  from .train_utils.base_trainer import BaseTrainer
+ from .train_utils.accelerator import Accelerator

  logger = getLogger("ml_tools")

@@ -37,11 +36,6 @@ class Trainer(BaseTrainer):
  Attributes:
  current_epoch (int): The current epoch number.
  global_step (int): The global step across all epochs.
- log_device (str): Device for logging, default is "cpu".
- device (torch_device): Device used for computation.
- dtype (torch_dtype | None): Data type used for computation.
- data_dtype (torch_dtype | None): Data type for data.
- Depends on the model's data type.

  Inherited Attributes:
  use_grad (bool): Indicates if gradients are used for optimization. Default is True.
@@ -239,8 +233,6 @@ class Trainer(BaseTrainer):
  val_dataloader: DataLoader | DictDataLoader | None = None,
  test_dataloader: DataLoader | DictDataLoader | None = None,
  optimize_step: Callable = optimize_step,
- device: torch_device | None = None,
- dtype: torch_dtype | None = None,
  max_batches: int | None = None,
  ):
  """
@@ -256,8 +248,6 @@ class Trainer(BaseTrainer):
  val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
  test_dataloader (DataLoader | DictDataLoader | None): DataLoader for test data.
  optimize_step (Callable): Function to execute an optimization step.
- device (torch_device): Device to use for computation.
- dtype (torch_dtype): Data type for computation.
  max_batches (int | None): Maximum number of batches to process per epoch.
  This is only valid in case of finite TensorDataset dataloaders.
  if max_batches is not None, the maximum number of batches used will
@@ -277,13 +267,21 @@ class Trainer(BaseTrainer):
  )
  self.current_epoch: int = 0
  self.global_step: int = 0
- self.log_device: str = "cpu" if device is None else device
- self.device: torch_device | None = device
- self.dtype: torch_dtype | None = dtype
- self.data_dtype: torch_dtype | None = None
- self.stop_training: bool = False
- if self.dtype:
- self.data_dtype = float64 if (self.dtype == complex128) else float32
+ self._stop_training: torch.Tensor = torch.tensor(0, dtype=torch.int)
+ self.progress: Progress | None = None
+
+ # Integration with Accelerator:
+ self.accelerator = Accelerator(
+ backend=config.backend,
+ nprocs=config.nprocs,
+ compute_setup=config.compute_setup,
+ dtype=config.dtype,
+ log_setup=config.log_setup,
+ )
+ # Decorate the unbound Trainer.fit method with accelerator.distribute.
+ # We use __get__ to bind the decorated method to the current instance,
+ # ensuring that 'self' is passed only once when self.fit is called.
+ self.fit = self.accelerator.distribute(Trainer.fit).__get__(self, Trainer) # type: ignore[method-assign]

  def fit(
  self,
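The `__get__` re-binding above is easy to misread, so here is a minimal, self-contained sketch of the same trick with made-up names (not qadence API): a wrapper around the *unbound* method is bound back onto the instance so that `self` is supplied exactly once per call.

```python
# Sketch of the __get__ re-binding trick (hypothetical Runner/logged, not qadence API).
class Runner:
    def step(self, x: int) -> int:
        return 2 * x

def logged(fn):
    # Stand-in for accelerator.distribute: wraps the unbound function.
    def wrapper(self, x):
        print(f"calling {fn.__name__}({x})")
        return fn(self, x)
    return wrapper

r = Runner()
# Bind the decorated unbound function to the instance, mirroring
# `self.fit = self.accelerator.distribute(Trainer.fit).__get__(self, Trainer)`.
r.step = logged(Runner.step).__get__(r, Runner)
print(r.step(3))  # prints "calling step(3)", then 6
```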
@@ -322,26 +320,30 @@ class Trainer(BaseTrainer):
  The callback_manager.start_training takes care of loading checkpoint,
  and setting up the writer.
  """
- self.stop_training = False
- self.config_manager.initialize_config()
+ self._stop_training = torch.tensor(
+ 0, dtype=torch.int, device=self.accelerator.execution.device
+ )
+ # initialize config in the first process, and broadcast it to all processes
+ if self.accelerator.rank == 0:
+ self.config_manager.initialize_config()
+ self.config_manager = self.accelerator.broadcast(self.config_manager, src=0)
  self.callback_manager.start_training(trainer=self)

- # Move model to device
- if isinstance(self.model, nn.DataParallel):
- self.model = self.model.module.to(device=self.device, dtype=self.dtype)
- else:
- self.model = self.model.to(device=self.device, dtype=self.dtype)
-
- # Progress bar for training visualization
- self.progress: Progress = Progress(
- TextColumn("[progress.description]{task.description}"),
- BarColumn(),
- TaskProgressColumn(),
- TimeRemainingColumn(elapsed_when_finished=True),
+ # Integration with Accelerator: prepare the model, optimizer, and dataloaders.
+ (self.model, self.optimizer, self.train_dataloader, self.val_dataloader) = (
+ self.accelerator.prepare(
+ self.model, self.optimizer, self.train_dataloader, self.val_dataloader
+ )
  )

- # Quick Fix for iteration 0
- self._reset_model_and_opt()
+ # Progress bar for training visualization
+ if self.accelerator.world_size == 1:
+ self.progress = Progress(
+ TextColumn("[progress.description]{task.description}"),
+ BarColumn(),
+ TaskProgressColumn(),
+ TimeRemainingColumn(elapsed_when_finished=True),
+ )

  # Run validation at the start if specified in the configuration
  self.perform_val = self.config.val_every > 0
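For context on the rank-0 initialise-then-broadcast step in `fit` above, a rough equivalent written directly against `torch.distributed` looks like this; `Accelerator.broadcast` presumably wraps something along these lines, and the helper name below is made up.

```python
import torch.distributed as dist

def broadcast_from_rank0(obj, rank: int):
    """Hypothetical helper: rank 0 supplies the object, every rank returns the same copy."""
    # Assumes an initialised process group (e.g. torchrun + dist.init_process_group).
    container = [obj if rank == 0 else None]
    dist.broadcast_object_list(container, src=0)
    return container[0]
```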
@@ -355,7 +357,11 @@ class Trainer(BaseTrainer):
  @BaseTrainer.callback("train")
  def _train(self) -> list[list[tuple[torch.Tensor, dict[str, Any]]]]:
  """
- Runs the main training loop, iterating over epochs.
+ Runs the main training loop over multiple epochs.
+
+ This method sets up the training process by performing any necessary pre-training
+ actions (via `on_train_start`), configuring progress tracking (if available), and then
+ iteratively calling `_train_epochs` to run through the epochs.

  Returns:
  list[list[tuple[torch.Tensor, dict[str, Any]]]]: Training loss
@@ -364,45 +370,97 @@ class Trainer(BaseTrainer):
  Epochs -> Training Batches -> (loss, metrics)
  """
  self.on_train_start()
- train_losses = []
- val_losses = []
+ epoch_start, epoch_end = (
+ self.global_step,
+ self.global_step + self.config_manager.config.max_iter + 1,
+ )

- with self.progress:
- train_task = self.progress.add_task(
- "Training", total=self.config_manager.config.max_iter
- )
- if self.perform_val:
- val_task = self.progress.add_task(
- "Validation",
- total=(self.config_manager.config.max_iter + 1) / self.config.val_every,
+ if self.accelerator.world_size == 1 and self.progress:
+ # Progress setup is only available for non-spawned training.
+ with self.progress:
+ train_task = self.progress.add_task(
+ "Training", total=self.config_manager.config.max_iter
  )
- for epoch in range(
- self.global_step, self.global_step + self.config_manager.config.max_iter + 1
- ):
- if not self.stop_training:
- try:
- self.current_epoch = epoch
- self.on_train_epoch_start()
- train_epoch_loss_metrics = self.run_training(self.train_dataloader)
- train_losses.append(train_epoch_loss_metrics)
- self.on_train_epoch_end(train_epoch_loss_metrics)
-
- # Run validation periodically if specified
- if self.perform_val and self.current_epoch % self.config.val_every == 0:
- self.on_val_epoch_start()
- val_epoch_loss_metrics = self.run_validation(self.val_dataloader)
- val_losses.append(val_epoch_loss_metrics)
- self.on_val_epoch_end(val_epoch_loss_metrics)
- self.progress.update(val_task, advance=1)
-
- self.progress.update(train_task, advance=1)
- except KeyboardInterrupt:
- logger.info("Terminating training gracefully after the current iteration.")
- break
+ if self.perform_val:
+ val_task = self.progress.add_task(
+ "Validation",
+ total=(self.config_manager.config.max_iter + 1) / self.config.val_every,
+ )
+ else:
+ val_task = None
+ train_losses, val_losses = self._train_epochs(
+ epoch_start, epoch_end, train_task, val_task
+ )
+ else:
+ train_losses, val_losses = self._train_epochs(epoch_start, epoch_end)

  self.on_train_end(train_losses, val_losses)
  return train_losses

+ def _train_epochs(
+ self,
+ epoch_start: int,
+ epoch_end: int,
+ train_task: int | None = None,
+ val_task: int | None = None,
+ ) -> tuple[
+ list[list[tuple[torch.Tensor, dict[str, Any]]]],
+ list[list[tuple[torch.Tensor, dict[str, Any]]]],
+ ]:
+ """
+ Executes the training loop for a series of epochs.
+
+ Args:
+ epoch_start (int): The starting epoch index.
+ epoch_end (int): The ending epoch index (non-inclusive).
+ train_task (int | None, optional): The progress bar task ID for training updates.
+ If provided, the progress bar will be updated after each epoch. Defaults to None.
+ val_task (int | None, optional): The progress bar task ID for validation updates.
+ If provided and validation is enabled, the progress bar will be updated after each validation run.
+ Defaults to None.
+
+ Returns:
+ tuple[list, list]: A tuple of training loss metrics and validation loss metrics
+ for all epochs, each structured as
+ list -> list -> tuples
+ Epochs -> Batches -> (loss, metrics)
+ """
+ train_losses = []
+ val_losses = []
+
+ # Iterate over the epochs
+ for epoch in range(epoch_start, epoch_end):
+ if not self.stop_training():
+ try:
+ self.current_epoch = epoch
+ self.on_train_epoch_start()
+ train_epoch_loss_metrics = self.run_training(self.train_dataloader)
+ train_losses.append(train_epoch_loss_metrics)
+ self.on_train_epoch_end(train_epoch_loss_metrics)
+
+ # Run validation periodically if specified
+ if self.perform_val and (epoch % self.config.val_every == 0):
+ self.on_val_epoch_start()
+ val_epoch_loss_metrics = self.run_validation(self.val_dataloader)
+ val_losses.append(val_epoch_loss_metrics)
+ self.on_val_epoch_end(val_epoch_loss_metrics)
+ if val_task is not None:
+ self.progress.update(val_task, advance=1) # type: ignore[union-attr]
+
+ if train_task is not None:
+ self.progress.update(train_task, advance=1) # type: ignore[union-attr]
+ except KeyboardInterrupt:
+ self._stop_training.fill_(1)
+ else:
+ if self.accelerator.rank == 0:
+ logger.info("Terminating training gracefully after the current iteration.")
+ self.accelerator.finalize()
+ break
+ return train_losses, val_losses
+
  @BaseTrainer.callback("train_epoch")
  def run_training(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]:
  """
@@ -418,12 +476,12 @@ class Trainer(BaseTrainer):
  """
  self.model.train()
  train_epoch_loss_metrics = []
- # Quick Fix for iteration 0
- self._reset_model_and_opt()

  for batch in self._batch_iter(dataloader, self.num_training_batches):
  self.on_train_batch_start(batch)
  train_batch_loss_metrics = self.run_train_batch(batch)
+ if self.config.all_reduce_metrics:
+ train_batch_loss_metrics = self._aggregate_result(train_batch_loss_metrics)
  train_epoch_loss_metrics.append(train_batch_loss_metrics)
  self.on_train_batch_end(train_batch_loss_metrics)

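The `all_reduce_metrics` branch above averages each rank's loss and metrics so that logged values reflect all processes. A bare-bones sketch of such an aggregation with `torch.distributed` (qadence's `Accelerator.all_reduce_dict` is assumed to behave roughly like this):

```python
import torch
import torch.distributed as dist

def all_reduce_mean(values: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # Average each tensor across all ranks (assumes an initialised process group).
    world_size = dist.get_world_size()
    reduced = {}
    for name, value in values.items():
        synced = value.detach().clone()
        dist.all_reduce(synced, op=dist.ReduceOp.SUM)
        reduced[name] = synced / world_size
    return reduced
```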
@@ -457,8 +515,8 @@ class Trainer(BaseTrainer):
  optimizer=self.optimizer,
  loss_fn=self.loss_fn,
  xs=batch,
- device=self.device,
- dtype=self.data_dtype,
+ device=self.accelerator.execution.device,
+ dtype=self.accelerator.execution.data_dtype,
  )
  else:
  # Perform optimization using Nevergrad
@@ -472,7 +530,15 @@ class Trainer(BaseTrainer):
  self.ng_params = ng_params
  loss_metrics = loss, metrics

- return self._modify_batch_end_loss_metrics(loss_metrics)
+ # --------------------- FIX: Post-Optimization Loss --------------------- #
+ # The loss/metrics returned by optimize_step correspond to the model state *before*
+ # the optimizer step. To keep the logged loss/metrics in sync with the updated model
+ # state, we compute them again after the optimization. This is not strictly necessary.
+ # TODO: Should be removed if loss can be logged at an unoptimized model state
+ with torch.no_grad():
+ post_update_loss_metrics = self.loss_fn(self.model, batch)
+
+ return self._modify_batch_end_loss_metrics(post_update_loss_metrics)

  @BaseTrainer.callback("val_epoch")
  def run_validation(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]:
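The post-optimization fix above simply re-evaluates the loss on the already-updated parameters so that the logged value matches the current model state. In generic PyTorch terms the pattern is (a sketch, not the qadence `optimize_step` itself):

```python
import torch

def train_step_with_synced_loss(model, optimizer, loss_fn, batch):
    # Regular step: this loss reflects the parameters *before* the update.
    optimizer.zero_grad()
    loss, metrics = loss_fn(model, batch)
    loss.backward()
    optimizer.step()
    # Re-evaluate after the update so the logged loss matches the new parameters.
    with torch.no_grad():
        post_update_loss, post_update_metrics = loss_fn(model, batch)
    return post_update_loss, post_update_metrics
```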
@@ -493,6 +559,8 @@ class Trainer(BaseTrainer):
  for batch in self._batch_iter(dataloader, self.num_validation_batches):
  self.on_val_batch_start(batch)
  val_batch_loss_metrics = self.run_val_batch(batch)
+ if self.config.all_reduce_metrics:
+ val_batch_loss_metrics = self._aggregate_result(val_batch_loss_metrics)
  val_epoch_loss_metrics.append(val_batch_loss_metrics)
  self.on_val_batch_end(val_batch_loss_metrics)

@@ -567,6 +635,9 @@ class Trainer(BaseTrainer):
  """
  Yields batches from the provided dataloader.

+ Each batch is also moved to the correct
+ device and dtype using accelerator.prepare_batch.
+
  Args:
  dataloader ([DataLoader]): The dataloader to iterate over.
  num_batches (int): The maximum number of batches to yield.
@@ -580,9 +651,7 @@ class Trainer(BaseTrainer):
  yield None
  else:
  for batch in islice(dataloader, num_batches):
- # batch is moved to device inside optimize step
- # batch = data_to_device(batch, device=self.device, dtype=self.data_dtype)
- yield batch
+ yield self.accelerator.prepare_batch(batch)

  def _modify_batch_end_loss_metrics(
  self, loss_metrics: tuple[torch.Tensor, dict[str, Any]]
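`accelerator.prepare_batch` replaces the previously commented-out `data_to_device` call: every yielded batch is moved to the execution device (and, for floating-point tensors, dtype) before it reaches the optimize step. A simplified stand-in for nested batches might look like this (sketch only, not the actual qadence helper):

```python
import torch

def move_batch(batch, device: torch.device, dtype: torch.dtype | None = None):
    # Recursively move tensors inside tuples/lists/dicts to the target device/dtype.
    if isinstance(batch, torch.Tensor):
        moved = batch.to(device=device)
        # Only cast floating-point tensors; integer tensors (e.g. labels) keep their dtype.
        if dtype is not None and moved.is_floating_point():
            moved = moved.to(dtype=dtype)
        return moved
    if isinstance(batch, (tuple, list)):
        return type(batch)(move_batch(item, device, dtype) for item in batch)
    if isinstance(batch, dict):
        return {key: move_batch(value, device, dtype) for key, value in batch.items()}
    return batch
```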
@@ -608,27 +677,43 @@ class Trainer(BaseTrainer):
  return loss, updated_metrics
  return loss_metrics

- def _reset_model_and_opt(self) -> None:
+ def _aggregate_result(
+ self, result: tuple[torch.Tensor, dict[str, Any]]
+ ) -> tuple[torch.Tensor, dict[str, Any]]:
  """
- Save model_old and optimizer_old for epoch 0.
+ Aggregates the loss and metrics using the Accelerator's all_reduce_dict method if aggregation is enabled.
+
+ Args:
+ result: (tuple[torch.Tensor, dict[str, Any]])
+ The result consisting of loss and metrics. For more details,
+ see the signature of build_optimize_result.

- This allows us to create a copy of model
- and optimizer before running the optimization.
+ Returns:
+ tuple[torch.Tensor, dict[str, Any]]: The aggregated loss and metrics.
+ """
+ loss, metrics = result
+ if self.config.all_reduce_metrics:
+ reduced = self.accelerator.all_reduce_dict({"loss": loss, **metrics})
+ loss = reduced.pop("loss")
+ metrics = reduced
+ return loss, metrics
+ else:
+ return loss, metrics

- We do this because optimize step provides loss, metrics
- before step of optimization
- To align them with model/optimizer correctly, we checkpoint
- the older copy of the model.
+ def stop_training(self) -> bool:
  """
+ Helper function to indicate if the training should be stopped.

- # TODO: review optimize_step to provide iteration aligned model and loss.
- try:
- # Deep copy model and optimizer to maintain checkpoints
- self.model_old = copy.deepcopy(self.model)
- self.optimizer_old = copy.deepcopy(self.optimizer)
- except Exception:
- self.model_old = self.model
- self.optimizer_old = self.optimizer
+ We all_reduce the indicator across all processes to ensure all processes are stopped.
+
+ Notes:
+ The self._stop_training indicator shows whether training should be stopped:
+ 0 means continue, 1 means stop.
+ """
+ _stop_training = self.accelerator.all_reduce_dict(
+ {"indicator": self._stop_training}, op="max"
+ )
+ return bool(_stop_training["indicator"] > 0)

  def build_optimize_result(
  self,
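The `stop_training` helper above turns a local interrupt into a global one: each rank holds a 0/1 tensor, `KeyboardInterrupt` sets it to 1 locally, and a MAX all-reduce guarantees that every rank observes the stop request on its next check. Stripped of the `Accelerator` wrapper, the pattern is roughly (a sketch; assumes an initialised process group):

```python
import torch
import torch.distributed as dist

class StopFlag:
    def __init__(self, device: torch.device):
        self.flag = torch.tensor(0, dtype=torch.int, device=device)

    def request_stop(self) -> None:
        # e.g. called from an `except KeyboardInterrupt` block on one rank
        self.flag.fill_(1)

    def should_stop(self) -> bool:
        synced = self.flag.clone()
        dist.all_reduce(synced, op=dist.ReduceOp.MAX)  # any rank's 1 wins
        return bool(synced.item() > 0)
```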
@@ -709,5 +794,113 @@ class Trainer(BaseTrainer):

  # Store the optimization result
  self.opt_result = OptimizeResult(
- self.current_epoch, self.model_old, self.optimizer_old, loss, metrics
+ self.current_epoch,
+ self.model,
+ self.optimizer,
+ loss,
+ metrics,
+ rank=self.accelerator.rank,
+ device=self.accelerator.execution.device,
  )
+
+ def get_ic_grad_bounds(
+ self,
+ eta: float,
+ epsilons: torch.Tensor,
+ variation_multiple: int = 20,
+ dataloader: DataLoader | DictDataLoader | None = None,
+ ) -> tuple[float, float, float]:
+ """
+ Calculate the bounds on the gradient norm of the loss using Information Content.
+
+ Args:
+ eta (float): The sensitivity IC.
+ epsilons (torch.Tensor): The epsilons to use as thresholds for the discretization of the
+ finite derivatives.
+ variation_multiple (int): The number of sets of variational parameters to generate per
+ variational parameter. The number of parameter sets required for the
+ statistical analysis scales linearly with the number of variational parameters
+ in the model; this is that linear factor.
+ dataloader (DataLoader | DictDataLoader | None): The dataloader for training data. A
+ new dataloader can be provided, or the dataloader already set on the trainer will be
+ used. If no dataloader is provided in either place, it is assumed that the
+ model does not require any input data.
+
+ Returns:
+ tuple[float, float, float]: The max IC lower bound, max IC upper bound, and sensitivity
+ IC upper bound.
+
+ Examples:
+ ```python
+ import torch
+ from torch.optim.adam import Adam
+
+ from qadence.constructors import ObservableConfig
+ from qadence.ml_tools.config import AnsatzConfig, FeatureMapConfig, TrainConfig
+ from qadence.ml_tools.data import to_dataloader
+ from qadence.ml_tools.models import QNN
+ from qadence.ml_tools.optimize_step import optimize_step
+ from qadence.ml_tools.trainer import Trainer
+ from qadence.operations.primitive import Z
+
+ fm_config = FeatureMapConfig(num_features=1)
+ ansatz_config = AnsatzConfig(depth=4)
+ obs_config = ObservableConfig(detuning=Z)
+
+ qnn = QNN.from_configs(
+ register=4,
+ obs_config=obs_config,
+ fm_config=fm_config,
+ ansatz_config=ansatz_config,
+ )
+
+ optimizer = Adam(qnn.parameters(), lr=0.001)
+
+ batch_size = 25
+ x = torch.linspace(0, 1, 32).reshape(-1, 1)
+ y = torch.sin(x)
+ train_loader = to_dataloader(x, y, batch_size=batch_size, infinite=True)
+
+ train_config = TrainConfig(max_iter=100)
+
+ trainer = Trainer(
+ model=qnn,
+ optimizer=optimizer,
+ config=train_config,
+ loss_fn="mse",
+ train_dataloader=train_loader,
+ optimize_step=optimize_step,
+ )
+
+ # Perform exploratory landscape analysis with Information Content
+ ic_sensitivity_threshold = 1e-4
+ epsilons = torch.logspace(-2, 2, 10)
+
+ max_ic_lower_bound, max_ic_upper_bound, sensitivity_ic_upper_bound = (
+ trainer.get_ic_grad_bounds(
+ eta=ic_sensitivity_threshold,
+ epsilons=epsilons,
+ )
+ )
+
+ # Resume training as usual...
+
+ trainer.fit(train_loader)
+ ```
+ """
+ if not self._use_grad:
+ logger.warning(
+ "Gradient norm bounds are only relevant when using a gradient based optimizer. \
+ Currently the trainer is set to use a gradient-free optimizer."
+ )
+
+ dataloader = dataloader if dataloader is not None else self.train_dataloader
+
+ batch = next(iter(self._batch_iter(dataloader, num_batches=1)))
+
+ ic = InformationContent(self.model, self.loss_fn, batch, epsilons)
+
+ max_ic_lower_bound, max_ic_upper_bound = ic.get_grad_norm_bounds_max_IC()
+ sensitivity_ic_upper_bound = ic.get_grad_norm_bounds_sensitivity_IC(eta)
+
+ return max_ic_lower_bound, max_ic_upper_bound, sensitivity_ic_upper_bound
@@ -376,10 +376,6 @@ class SWAP(PrimitiveBlock):
  def eigenvalues(self) -> Tensor:
  return torch.tensor([-1, 1, 1, 1], dtype=cdouble)

- @property
- def n_qubits(self) -> int:
- return 2
-
  @property
  def _block_title(self) -> str:
  c, t = self.qubit_support
qadence/types.py CHANGED
@@ -445,17 +445,6 @@ class InputDiffMode(StrEnum):
  """Central finite differencing."""


- class ObservableTransform:
- """Observable transformation type."""
-
- SCALE = "scale"
- """Use the given values as scale and shift."""
- RANGE = "range"
- """Use the given values as min and max."""
- NONE = "none"
- """No transformation."""
-
-
  class ExperimentTrackingTool(StrEnum):
  TENSORBOARD = "tensorboard"
  """Use the tensorboard experiment tracker."""
@@ -463,6 +452,13 @@ class ExperimentTrackingTool(StrEnum):
  """Use the ml-flow experiment tracker."""


+ class ExecutionType(StrEnum):
+ TORCHRUN = "torchrun"
+ """Torchrun-based distributed execution."""
+ DEFAULT = "default"
+ """Default distributed execution."""
+
+
  LoggablePlotFunction = Callable[[Module, int], tuple[str, Figure]]


qadence/utils.py CHANGED
@@ -1,6 +1,7 @@
  from __future__ import annotations

  import math
+ import re
  from collections import Counter
  from functools import partial
  from logging import getLogger
@@ -13,6 +14,9 @@ from torch import Tensor, stack, vmap
  from torch import complex as make_complex
  from torch.linalg import eigvals

+ from rich.tree import Tree
+
+ from qadence.blocks import AbstractBlock
  from qadence.types import Endianness, ResultType, TNumber

  if TYPE_CHECKING:
@@ -290,3 +294,44 @@ def one_qubit_projector_matrix(state: str) -> Tensor:

  P0 = partial(one_qubit_projector, "0")
  P1 = partial(one_qubit_projector, "1")
+
+
+ def block_to_mathematical_expression(block: Tree | AbstractBlock) -> str:
+ """Convert a block to a readable mathematical expression.
+
+ Useful for printing Observables as a mathematical expression.
+
+ Args:
+ block (Tree | AbstractBlock): The block, or its rich Tree representation.
+
+ Returns:
+ str: A mathematical expression.
+ """
+ block_tree: Tree = block.__rich_tree__() if isinstance(block, AbstractBlock) else block
+ block_title = block_tree.label if isinstance(block_tree.label, str) else ""
+ if "AddBlock" in block_title:
+ block_title = " + ".join(
+ [block_to_mathematical_expression(block_child) for block_child in block_tree.children]
+ )
+ if "KronBlock" in block_title:
+ block_title = " ⊗ ".join(
+ [block_to_mathematical_expression(block_child) for block_child in block_tree.children]
+ )
+ if "mul" in block_title:
+ block_title = re.findall("\d+\.\d+", block_title)[0]
+ coeff = float(block_title)
+ if coeff == 0:
+ block_title = ""
+ elif coeff == 1:
+ block_title = block_to_mathematical_expression(block_tree.children[0])
+ else:
+ block_title += " * " + block_to_mathematical_expression(block_tree.children[0])
+ first_part = block_title[:3]
+ if first_part in [" + ", " ⊗ ", " * "]:
+ block_title = block_title[3:]
+
+ # if there are multiple children, add parentheses.
+ nb_children = len(block_tree.children)
+ if nb_children > 1:
+ block_title = "(" + block_title + ")"
+ return block_title
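A small usage sketch for the new helper; the observable construction uses qadence's public `add`/`kron`/`Z` API, and the exact output string (in particular how coefficients are rendered) may differ slightly.

```python
from qadence import Z, add, kron
from qadence.utils import block_to_mathematical_expression

# 0.5 * (Z(0) ⊗ Z(1)) + Z(2), printed as a readable expression.
observable = add(0.5 * kron(Z(0), Z(1)), Z(2))
print(block_to_mathematical_expression(observable))
# expected to look roughly like "(0.500 * (Z(0) ⊗ Z(1)) + Z(2))"
```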