qadence 1.10.3__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -82,18 +82,20 @@ def _fill_identities(
82
82
  full_qubit_support = tuple(sorted(full_qubit_support))
83
83
  qubit_support = tuple(sorted(qubit_support))
84
84
  block_mat = block_mat.to(device)
85
- mat = IMAT.to(device) if qubit_support[0] != full_qubit_support[0] else block_mat
85
+ identity_mat = IMAT.to(device)
86
86
  if diag_only:
87
- mat = torch.diag(mat.squeeze(0))
87
+ block_mat = torch.diag(block_mat.squeeze(0))
88
+ identity_mat = torch.diag(identity_mat.squeeze(0))
89
+ mat = identity_mat if qubit_support[0] != full_qubit_support[0] else block_mat
88
90
  for i in full_qubit_support[1:]:
89
91
  if i == qubit_support[0]:
90
- other = torch.diag(block_mat.squeeze(0)) if diag_only else block_mat
92
+ other = block_mat
91
93
  if endianness == Endianness.LITTLE:
92
94
  mat = torch.kron(other, mat)
93
95
  else:
94
96
  mat = torch.kron(mat.contiguous(), other.contiguous())
95
97
  elif i not in qubit_support:
96
- other = torch.diag(IMAT.squeeze(0).to(device)) if diag_only else IMAT.to(device)
98
+ other = identity_mat
97
99
  if endianness == Endianness.LITTLE:
98
100
  mat = torch.kron(other.contiguous(), mat.contiguous())
99
101
  else:
@@ -264,13 +266,12 @@ def _gate_parameters(b: AbstractBlock, values: dict[str, torch.Tensor]) -> tuple
264
266
 
265
267
  def block_to_diagonal(
266
268
  block: AbstractBlock,
269
+ values: dict[str, TNumber | torch.Tensor] = dict(),
267
270
  qubit_support: tuple | list | None = None,
268
- use_full_support: bool = True,
271
+ use_full_support: bool = False,
269
272
  endianness: Endianness = Endianness.BIG,
270
273
  device: torch.device = None,
271
274
  ) -> torch.Tensor:
272
- if block.is_parametric:
273
- raise TypeError("Sparse observables cant be parametric.")
274
275
  if not block._is_diag_pauli:
275
276
  raise TypeError("Sparse observables can only be used on paulis which are diagonal.")
276
277
  if qubit_support is None:
@@ -282,17 +283,16 @@ def block_to_diagonal(
282
283
  if isinstance(block, (ChainBlock, KronBlock)):
283
284
  v = torch.ones(2**nqubits, dtype=torch.cdouble)
284
285
  for b in block.blocks:
285
- v *= block_to_diagonal(b, qubit_support)
286
+ v *= block_to_diagonal(b, values, qubit_support, device=device)
286
287
  if isinstance(block, AddBlock):
287
288
  t = torch.zeros(2**nqubits, dtype=torch.cdouble)
288
289
  for b in block.blocks:
289
- t += block_to_diagonal(b, qubit_support)
290
+ t += block_to_diagonal(b, values, qubit_support, device=device)
290
291
  v = t
291
292
  elif isinstance(block, ScaleBlock):
292
- _s = evaluate(block.scale, {}, as_torch=True) # type: ignore[attr-defined]
293
- _s = _s.detach() # type: ignore[union-attr]
294
- v = _s * block_to_diagonal(block.block, qubit_support)
295
-
293
+ _s = evaluate(block.scale, values, as_torch=True) # type: ignore[attr-defined]
294
+ _s = _s.detach().squeeze(0) # type: ignore[union-attr]
295
+ v = _s * block_to_diagonal(block.block, values, qubit_support, device=device)
296
296
  elif isinstance(block, PrimitiveBlock):
297
297
  v = _fill_identities(
298
298
  OPERATIONS_DICT[block.name],
@@ -300,6 +300,7 @@ def block_to_diagonal(
300
300
  qubit_support, # type: ignore [arg-type]
301
301
  diag_only=True,
302
302
  endianness=endianness,
303
+ device=device,
303
304
  )
304
305
  return v
305
306
 
@@ -309,7 +310,7 @@ def block_to_tensor(
309
310
  block: AbstractBlock,
310
311
  values: dict[str, TNumber | torch.Tensor] = {},
311
312
  qubit_support: tuple | None = None,
312
- use_full_support: bool = True,
313
+ use_full_support: bool = False,
313
314
  tensor_type: TensorType = TensorType.DENSE,
314
315
  endianness: Endianness = Endianness.BIG,
315
316
  device: torch.device = None,
@@ -339,18 +340,14 @@ def block_to_tensor(
339
340
  print(block_to_tensor(obs, tensor_type="SparseDiagonal"))
340
341
  ```
341
342
  """
343
+ from qadence.blocks import embedding
342
344
 
343
- # FIXME: default use_full_support to False. In general, it would
344
- # be more efficient to do that, and make sure that computations such
345
- # as observables only do the matmul of the size of the qubit support.
346
-
345
+ (ps, embed) = embedding(block)
346
+ values = embed(ps, values)
347
347
  if tensor_type == TensorType.DENSE:
348
- from qadence.blocks import embedding
349
-
350
- (ps, embed) = embedding(block)
351
348
  return _block_to_tensor_embedded(
352
349
  block,
353
- embed(ps, values),
350
+ values,
354
351
  qubit_support,
355
352
  use_full_support,
356
353
  endianness=endianness,
@@ -358,7 +355,7 @@ def block_to_tensor(
358
355
  )
359
356
 
360
357
  elif tensor_type == TensorType.SPARSEDIAGONAL:
361
- t = block_to_diagonal(block, endianness=endianness)
358
+ t = block_to_diagonal(block, values, endianness=endianness)
362
359
  indices, values, size = torch.nonzero(t), t[t != 0], len(t)
363
360
  indices = torch.stack((indices.flatten(), indices.flatten()))
364
361
  return torch.sparse_coo_tensor(indices, values, (size, size))
@@ -369,7 +366,7 @@ def _block_to_tensor_embedded(
369
366
  block: AbstractBlock,
370
367
  values: dict[str, TNumber | torch.Tensor] = {},
371
368
  qubit_support: tuple | None = None,
372
- use_full_support: bool = True,
369
+ use_full_support: bool = False,
373
370
  endianness: Endianness = Endianness.BIG,
374
371
  device: torch.device = None,
375
372
  ) -> torch.Tensor:
@@ -17,6 +17,9 @@ from .hamiltonians import (
17
17
  ObservableConfig,
18
18
  total_magnetization,
19
19
  zz_hamiltonian,
20
+ total_magnetization_config,
21
+ zz_hamiltonian_config,
22
+ ising_hamiltonian_config,
20
23
  )
21
24
 
22
25
  from .rydberg_hea import rydberg_hea, rydberg_hea_layer
@@ -34,9 +37,12 @@ __all__ = [
34
37
  "iia",
35
38
  "hamiltonian_factory",
36
39
  "ising_hamiltonian",
37
- "ObservableConfig",
38
40
  "total_magnetization",
39
41
  "zz_hamiltonian",
42
+ "ObservableConfig",
43
+ "total_magnetization_config",
44
+ "zz_hamiltonian_config",
45
+ "ising_hamiltonian_config",
40
46
  "qft",
41
47
  "daqc_transform",
42
48
  "rydberg_hea",
@@ -7,11 +7,12 @@ from typing import Callable, List, Type, Union
7
7
  import numpy as np
8
8
  from torch import Tensor, double, ones, rand
9
9
  from typing_extensions import Any
10
+ from qadence.parameters import Parameter
10
11
 
11
12
  from qadence.blocks import AbstractBlock, add, block_is_qubit_hamiltonian
12
- from qadence.operations import N, X, Y, Z
13
+ from qadence.operations import N, X, Y, Z, H
13
14
  from qadence.register import Register
14
- from qadence.types import Interaction, ObservableTransform, TArray, TParameter
15
+ from qadence.types import Interaction, TArray, TParameter
15
16
 
16
17
  logger = getLogger(__name__)
17
18
 
@@ -239,7 +240,30 @@ def is_numeric(x: Any) -> bool:
239
240
 
240
241
  @dataclass
241
242
  class ObservableConfig:
242
- detuning: TDetuning
243
+ """ObservableConfig is a configuration class for defining the parameters of an observable Hamiltonian."""
244
+
245
+ interaction: Interaction | Callable | None = None
246
+ """
247
+ The type of interaction.
248
+
249
+ Available options from the Interaction enum are:
250
+ - Interaction.ZZ
251
+ - Interaction.NN
252
+ - Interaction.XY
253
+ - Interaction.XYZ
254
+
255
+ Alternatively, a custom interaction function can be defined.
256
+ Example:
257
+
258
+ def custom_int(i: int, j: int):
259
+ return X(i) @ X(j) + Y(i) @ Y(j)
260
+
261
+ n_qubits = 2
262
+
263
+ observable_config = ObservableConfig(interaction=custom_int, scale = 1.0, shift = 0.0)
264
+ observable = create_observable(register=4, config=observable_config)
265
+ """
266
+ detuning: TDetuning | None = None
243
267
  """
244
268
  Single qubit detuning of the observable Hamiltonian.
245
269
 
@@ -249,8 +273,6 @@ class ObservableConfig:
249
273
  """The scale by which to multiply the output of the observable."""
250
274
  shift: TParameter = 0.0
251
275
  """The shift to add to the output of the observable."""
252
- transformation_type: ObservableTransform = ObservableTransform.NONE # type: ignore[assignment]
253
- """The type of transformation."""
254
276
  trainable_transform: bool | None = None
255
277
  """
256
278
  Whether to have a trainable transformation on the output of the observable.
@@ -261,8 +283,73 @@ class ObservableConfig:
261
283
  """
262
284
 
263
285
  def __post_init__(self) -> None:
286
+ if self.interaction is None and self.detuning is None:
287
+ raise ValueError(
288
+ "Please provide an interaction and/or detuning for the Observable Hamiltonian."
289
+ )
290
+
264
291
  if is_numeric(self.scale) and is_numeric(self.shift):
265
- assert (
266
- self.trainable_transform is None
267
- ), f"If scale and shift are numbers, trainable_transform must be None. \
268
- But got: {self.trainable_transform}"
292
+ assert self.trainable_transform is None, (
293
+ "If scale and shift are numbers, trainable_transform must be None."
294
+ f"But got: {self.trainable_transform}"
295
+ )
296
+
297
+ # trasform the scale and shift into parameters
298
+ if self.trainable_transform is not None:
299
+ self.shift = Parameter(name=self.shift, trainable=self.trainable_transform)
300
+ self.scale = Parameter(name=self.scale, trainable=self.trainable_transform)
301
+ else:
302
+ self.shift = Parameter(self.shift)
303
+ self.scale = Parameter(self.scale)
304
+
305
+
306
+ def total_magnetization_config(
307
+ scale: TParameter = 1.0,
308
+ shift: TParameter = 0.0,
309
+ trainable_transform: bool | None = None,
310
+ ) -> ObservableConfig:
311
+ return ObservableConfig(
312
+ detuning=Z,
313
+ scale=scale,
314
+ shift=shift,
315
+ trainable_transform=trainable_transform,
316
+ )
317
+
318
+
319
+ def zz_hamiltonian_config(
320
+ scale: TParameter = 1.0,
321
+ shift: TParameter = 0.0,
322
+ trainable_transform: bool | None = None,
323
+ ) -> ObservableConfig:
324
+ return ObservableConfig(
325
+ interaction=Interaction.ZZ,
326
+ detuning=Z,
327
+ scale=scale,
328
+ shift=shift,
329
+ trainable_transform=trainable_transform,
330
+ )
331
+
332
+
333
+ def ising_hamiltonian_config(
334
+ scale: TParameter = 1.0,
335
+ shift: TParameter = 0.0,
336
+ trainable_transform: bool | None = None,
337
+ ) -> ObservableConfig:
338
+
339
+ def ZZ_Z_hamiltonian(i: int, j: int) -> AbstractBlock:
340
+ result = Z(i) @ Z(j)
341
+
342
+ if i == 0:
343
+ result += Z(j)
344
+ elif i == 1 and j == 2:
345
+ result += Z(0)
346
+
347
+ return result
348
+
349
+ return ObservableConfig(
350
+ interaction=ZZ_Z_hamiltonian,
351
+ detuning=Z,
352
+ scale=scale,
353
+ shift=shift,
354
+ trainable_transform=trainable_transform,
355
+ )
@@ -92,7 +92,9 @@ def pulse_experiment(
92
92
  )
93
93
  # Convert observable to Numpy types compatible with QuTip simulations.
94
94
  # Matrices are flipped to match QuTip conventions.
95
- converted_observable = [np.flip(block_to_tensor(obs).numpy()) for obs in observable]
95
+ converted_observable = [
96
+ np.flip(block_to_tensor(obs, use_full_support=True).numpy()) for obs in observable
97
+ ]
96
98
  # Create ZNE datasets by looping over batches.
97
99
  for observable in converted_observable:
98
100
  # Get expectation values at the end of the time serie [0,t]
@@ -130,7 +132,9 @@ def noise_level_experiment(
130
132
  )
131
133
  # Convert observable to Numpy types compatible with QuTip simulations.
132
134
  # Matrices are flipped to match QuTip conventions.
133
- converted_observable = [np.flip(block_to_tensor(obs).numpy()) for obs in observable]
135
+ converted_observable = [
136
+ np.flip(block_to_tensor(obs, use_full_support=True).numpy()) for obs in observable
137
+ ]
134
138
  # Create ZNE datasets by looping over batches.
135
139
  for observable in converted_observable:
136
140
  # Get expectation values at the end of the time serie [0,t]
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from .callbacks.saveload import load_checkpoint, load_model, write_checkpoint
4
4
  from .config import AnsatzConfig, FeatureMapConfig, TrainConfig
5
- from .constructors import create_ansatz, create_fm_blocks, observable_from_config
5
+ from .constructors import create_ansatz, create_fm_blocks, create_observable
6
6
  from .data import DictDataLoader, InfiniteTensorDataset, OptimizeResult, to_dataloader
7
7
  from .information import InformationContent
8
8
  from .models import QNN
@@ -19,7 +19,7 @@ __all__ = [
19
19
  "DictDataLoader",
20
20
  "FeatureMapConfig",
21
21
  "load_checkpoint",
22
- "observable_from_config",
22
+ "create_observable",
23
23
  "QNN",
24
24
  "TrainConfig",
25
25
  "OptimizeResult",
@@ -95,14 +95,36 @@ class Callback:
95
95
  self.callback: CallbackFunction | None = callback
96
96
  self.on: str | TrainingStage = on
97
97
  self.called_every: int = called_every
98
- self.callback_condition = callback_condition or (lambda _: True)
98
+ self.callback_condition = (
99
+ callback_condition if callback_condition else Callback.default_callback
100
+ )
99
101
 
100
102
  if isinstance(modify_optimize_result, dict):
101
- self.modify_optimize_result = (
102
- lambda opt_res: opt_res.extra.update(modify_optimize_result) or opt_res
103
+ self.modify_optimize_result = lambda opt_res: Callback.modify_opt_res_dict(
104
+ opt_res, modify_optimize_result
103
105
  )
104
106
  else:
105
- self.modify_optimize_result = modify_optimize_result or (lambda opt_res: opt_res)
107
+ self.modify_optimize_result = (
108
+ modify_optimize_result
109
+ if modify_optimize_result
110
+ else Callback.modify_opt_res_default
111
+ )
112
+
113
+ @staticmethod
114
+ def default_callback(_: Any) -> bool:
115
+ return True
116
+
117
+ @staticmethod
118
+ def modify_opt_res_dict(
119
+ opt_res: OptimizeResult,
120
+ modify_optimize_result: dict[str, Any] = {},
121
+ ) -> OptimizeResult:
122
+ opt_res.extra.update(modify_optimize_result)
123
+ return opt_res
124
+
125
+ @staticmethod
126
+ def modify_opt_res_default(opt_res: OptimizeResult) -> OptimizeResult:
127
+ return opt_res
106
128
 
107
129
  @property
108
130
  def on(self) -> TrainingStage | str:
@@ -261,8 +283,9 @@ class WriteMetrics(Callback):
261
283
  config (TrainConfig): The configuration object.
262
284
  writer (BaseWriter ): The writer object for logging.
263
285
  """
264
- opt_result = trainer.opt_result
265
- writer.write(opt_result.iteration, opt_result.metrics)
286
+ if trainer.accelerator.rank == 0:
287
+ opt_result = trainer.opt_result
288
+ writer.write(opt_result.iteration, opt_result.metrics)
266
289
 
267
290
 
268
291
  class PlotMetrics(Callback):
@@ -299,9 +322,10 @@ class PlotMetrics(Callback):
299
322
  config (TrainConfig): The configuration object.
300
323
  writer (BaseWriter ): The writer object for logging.
301
324
  """
302
- opt_result = trainer.opt_result
303
- plotting_functions = config.plotting_functions
304
- writer.plot(trainer.model, opt_result.iteration, plotting_functions)
325
+ if trainer.accelerator.rank == 0:
326
+ opt_result = trainer.opt_result
327
+ plotting_functions = config.plotting_functions
328
+ writer.plot(trainer.model, opt_result.iteration, plotting_functions)
305
329
 
306
330
 
307
331
  class LogHyperparameters(Callback):
@@ -338,8 +362,9 @@ class LogHyperparameters(Callback):
338
362
  config (TrainConfig): The configuration object.
339
363
  writer (BaseWriter ): The writer object for logging.
340
364
  """
341
- hyperparams = config.hyperparams
342
- writer.log_hyperparams(hyperparams)
365
+ if trainer.accelerator.rank == 0:
366
+ hyperparams = config.hyperparams
367
+ writer.log_hyperparams(hyperparams)
343
368
 
344
369
 
345
370
  class SaveCheckpoint(Callback):
@@ -376,11 +401,12 @@ class SaveCheckpoint(Callback):
376
401
  config (TrainConfig): The configuration object.
377
402
  writer (BaseWriter ): The writer object for logging.
378
403
  """
379
- folder = config.log_folder
380
- model = trainer.model
381
- optimizer = trainer.optimizer
382
- opt_result = trainer.opt_result
383
- write_checkpoint(folder, model, optimizer, opt_result.iteration)
404
+ if trainer.accelerator.rank == 0:
405
+ folder = config.log_folder
406
+ model = trainer.model
407
+ optimizer = trainer.optimizer
408
+ opt_result = trainer.opt_result
409
+ write_checkpoint(folder, model, optimizer, opt_result.iteration)
384
410
 
385
411
 
386
412
  class SaveBestCheckpoint(SaveCheckpoint):
@@ -404,17 +430,18 @@ class SaveBestCheckpoint(SaveCheckpoint):
404
430
  config (TrainConfig): The configuration object.
405
431
  writer (BaseWriter ): The writer object for logging.
406
432
  """
407
- opt_result = trainer.opt_result
408
- if config.validation_criterion and config.validation_criterion(
409
- opt_result.loss, self.best_loss, config.val_epsilon
410
- ):
411
- self.best_loss = opt_result.loss
412
-
413
- folder = config.log_folder
414
- model = trainer.model
415
- optimizer = trainer.optimizer
433
+ if trainer.accelerator.rank == 0:
416
434
  opt_result = trainer.opt_result
417
- write_checkpoint(folder, model, optimizer, "best")
435
+ if config.validation_criterion and config.validation_criterion(
436
+ opt_result.loss, self.best_loss, config.val_epsilon
437
+ ):
438
+ self.best_loss = opt_result.loss
439
+
440
+ folder = config.log_folder
441
+ model = trainer.model
442
+ optimizer = trainer.optimizer
443
+ opt_result = trainer.opt_result
444
+ write_checkpoint(folder, model, optimizer, "best")
418
445
 
419
446
 
420
447
  class LoadCheckpoint(Callback):
@@ -431,11 +458,12 @@ class LoadCheckpoint(Callback):
431
458
  Returns:
432
459
  Any: The result of loading the checkpoint.
433
460
  """
434
- folder = config.log_folder
435
- model = trainer.model
436
- optimizer = trainer.optimizer
437
- device = trainer.log_device
438
- return load_checkpoint(folder, model, optimizer, device=device)
461
+ if trainer.accelerator.rank == 0:
462
+ folder = config.log_folder
463
+ model = trainer.model
464
+ optimizer = trainer.optimizer
465
+ device = trainer.accelerator.execution.log_device
466
+ return load_checkpoint(folder, model, optimizer, device=device)
439
467
 
440
468
 
441
469
  class LogModelTracker(Callback):
@@ -449,10 +477,11 @@ class LogModelTracker(Callback):
449
477
  config (TrainConfig): The configuration object.
450
478
  writer (BaseWriter ): The writer object for logging.
451
479
  """
452
- model = trainer.model
453
- writer.log_model(
454
- model, trainer.train_dataloader, trainer.val_dataloader, trainer.test_dataloader
455
- )
480
+ if trainer.accelerator.rank == 0:
481
+ model = trainer.model
482
+ writer.log_model(
483
+ model, trainer.train_dataloader, trainer.val_dataloader, trainer.test_dataloader
484
+ )
456
485
 
457
486
 
458
487
  class LRSchedulerStepDecay(Callback):
@@ -713,7 +742,7 @@ class EarlyStopping(Callback):
713
742
  f"EarlyStopping: No improvement in '{self.monitor}' for {self.patience} epochs. "
714
743
  "Stopping training."
715
744
  )
716
- trainer.stop_training = True
745
+ trainer._stop_training.fill_(1)
717
746
 
718
747
 
719
748
  class GradientMonitoring(Callback):
@@ -759,17 +788,18 @@ class GradientMonitoring(Callback):
759
788
  config (TrainConfig): The configuration object.
760
789
  writer (BaseWriter): The writer object for logging.
761
790
  """
762
- gradient_stats = {}
763
- for name, param in trainer.model.named_parameters():
764
- if param.grad is not None:
765
- grad = param.grad
766
- gradient_stats.update(
767
- {
768
- name + "_mean": grad.mean().item(),
769
- name + "_std": grad.std().item(),
770
- name + "_max": grad.max().item(),
771
- name + "_min": grad.min().item(),
772
- }
773
- )
774
-
775
- writer.write(trainer.opt_result.iteration, gradient_stats)
791
+ if trainer.accelerator.rank == 0:
792
+ gradient_stats = {}
793
+ for name, param in trainer.model.named_parameters():
794
+ if param.grad is not None:
795
+ grad = param.grad
796
+ gradient_stats.update(
797
+ {
798
+ name + "_mean": grad.mean().item(),
799
+ name + "_std": grad.std().item(),
800
+ name + "_max": grad.max().item(),
801
+ name + "_min": grad.min().item(),
802
+ }
803
+ )
804
+
805
+ writer.write(trainer.opt_result.iteration, gradient_stats)
@@ -201,7 +201,8 @@ class CallbacksManager:
201
201
  logger.debug(f"Loaded model and optimizer from {self.config.log_folder}")
202
202
 
203
203
  # Setup writer
204
- self.writer.open(self.config, iteration=trainer.global_step)
204
+ if trainer.accelerator.rank == 0:
205
+ self.writer.open(self.config, iteration=trainer.global_step)
205
206
 
206
207
  def end_training(self, trainer: Any) -> None:
207
208
  """
@@ -210,5 +211,5 @@ class CallbacksManager:
210
211
  Args:
211
212
  trainer (Any): The training object managing the training process.
212
213
  """
213
- if self.writer:
214
+ if trainer.accelerator.rank == 0 and self.writer:
214
215
  self.writer.close()
@@ -127,11 +127,12 @@ class BaseWriter(ABC):
127
127
 
128
128
  # Find the key in result.metrics that contains "loss" (case-insensitive)
129
129
  loss_key = next((k for k in result.metrics if "loss" in k.lower()), None)
130
+ initial = f"P {result.rank: >2}|{result.device: <7}| Iteration {result.iteration: >7}| "
130
131
  if loss_key:
131
132
  loss_value = result.metrics[loss_key]
132
- msg = f"Iteration {result.iteration: >7} | {loss_key.title()}: {loss_value:.7f} -"
133
+ msg = initial + f"{loss_key.title()}: {loss_value:.7f} -"
133
134
  else:
134
- msg = f"Iteration {result.iteration: >7} | Loss: None -"
135
+ msg = initial + f"Loss: None -"
135
136
  msg += " ".join([f"{k}: {v:.7f}" for k, v in result.metrics.items() if k != loss_key])
136
137
  print(msg)
137
138
 
@@ -20,6 +20,7 @@ from qadence.types import (
20
20
  ReuploadScaling,
21
21
  Strategy,
22
22
  )
23
+ from torch import dtype
23
24
 
24
25
  logger = getLogger(__file__)
25
26
 
@@ -116,10 +117,9 @@ class TrainConfig:
116
117
  """The log folder for saving checkpoints and tensorboard logs.
117
118
 
118
119
  This stores the path where all logs and checkpoints are being saved
119
- for this training session. `log_folder` takes precedence over `root_folder` and
120
- `create_subfolder_per_run` arguments. If the user specifies a log_folder,
121
- all checkpoints will be saved in this folder and `root_folder` argument
122
- will not be used.
120
+ for this training session. `log_folder` takes precedence over `root_folder`,
121
+ but it is ignored if `create_subfolders_per_run=True` (in which case, subfolders
122
+ will be spawned in the root folder).
123
123
  """
124
124
 
125
125
  checkpoint_best_only: bool = False
@@ -195,7 +195,7 @@ class TrainConfig:
195
195
  plots that are logged or saved at specified intervals.
196
196
  """
197
197
 
198
- _subfolders: list = field(default_factory=list)
198
+ _subfolders: list[str] = field(default_factory=list)
199
199
  """List of subfolders used for logging different runs using the same config inside the.
200
200
 
201
201
  root folder.
@@ -203,6 +203,67 @@ class TrainConfig:
203
203
  Each subfolder is of structure `<id>_<timestamp>_<PID>`.
204
204
  """
205
205
 
206
+ nprocs: int = 1
207
+ """
208
+ The number of processes to use for training when spawning subprocesses.
209
+
210
+ For effective parallel processing, set this to a value greater than 1.
211
+ - In case of Multi-GPU or Multi-Node-Multi-GPU setups, nprocs should be equal to
212
+ the total number of GPUs across all nodes (world size), or total number of GPU to be used.
213
+
214
+ If nprocs > 1, multiple processes will be spawned for training. The training framework will launch
215
+ additional processes (e.g., for distributed or parallel training).
216
+ - For CPU setup, this will launch a true parallel processes
217
+ - For GPU setup, this will launch a distributed training routine.
218
+ This uses the DistributedDataParallel framework from PyTorch.
219
+ """
220
+
221
+ compute_setup: str = "cpu"
222
+ """
223
+ Compute device setup; options are "auto", "gpu", or "cpu".
224
+
225
+ - "auto": Automatically uses GPU if available; otherwise, falls back to CPU.
226
+ - "gpu": Forces GPU usage, raising an error if no CUDA device is available.
227
+ - "cpu": Forces the use of CPU regardless of GPU availability.
228
+ """
229
+
230
+ backend: str = "gloo"
231
+ """
232
+ Backend used for distributed training communication.
233
+
234
+ The default is "gloo". Other options may include "nccl" - which is optimized for GPU-based training or "mpi",
235
+ depending on your system and requirements.
236
+ It should be one of the backends supported by `torch.distributed`. For further details, please look at
237
+ [torch backends](https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend)
238
+ """
239
+
240
+ log_setup: str = "cpu"
241
+ """
242
+ Logging device setup; options are "auto" or "cpu".
243
+
244
+ - "auto": Uses the same device for logging as for computation.
245
+ - "cpu": Forces logging to occur on the CPU. This can be useful to avoid potential conflicts with GPU processes.
246
+ """
247
+
248
+ dtype: dtype | None = None
249
+ """
250
+ Data type (precision) for computations.
251
+
252
+ Both model parameters, and dataset will be of the provided precision.
253
+
254
+ If not specified or None, the default torch precision (usually torch.float32) is used.
255
+ If provided dtype is torch.complex128, model parameters will be torch.complex128, and data parameters will be torch.float64
256
+ """
257
+
258
+ all_reduce_metrics: bool = False
259
+ """
260
+ Whether to aggregate metrics (e.g., loss, accuracy) across processes.
261
+
262
+ When True, metrics from different training processes are averaged to provide a consolidated metrics.
263
+ Note: Since aggregation requires synchronization/all_reduce operation, this can increase the
264
+ computation time significantly.
265
+ """
266
+
206
267
 
207
268
  @dataclass
208
269
  class FeatureMapConfig: