qadence 1.10.3__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qadence/blocks/block_to_tensor.py +21 -24
 - qadence/constructors/__init__.py +7 -1
 - qadence/constructors/hamiltonians.py +96 -9
 - qadence/mitigations/analog_zne.py +6 -2
 - qadence/ml_tools/__init__.py +2 -2
 - qadence/ml_tools/callbacks/callback.py +80 -50
 - qadence/ml_tools/callbacks/callbackmanager.py +3 -2
 - qadence/ml_tools/callbacks/writer_registry.py +3 -2
 - qadence/ml_tools/config.py +66 -5
 - qadence/ml_tools/constructors.py +9 -62
 - qadence/ml_tools/data.py +4 -0
 - qadence/ml_tools/models.py +69 -4
 - qadence/ml_tools/optimize_step.py +1 -2
 - qadence/ml_tools/train_utils/__init__.py +3 -1
 - qadence/ml_tools/train_utils/accelerator.py +480 -0
 - qadence/ml_tools/train_utils/config_manager.py +7 -7
 - qadence/ml_tools/train_utils/distribution.py +209 -0
 - qadence/ml_tools/train_utils/execution.py +421 -0
 - qadence/ml_tools/trainer.py +188 -100
 - qadence/types.py +7 -11
 - qadence/utils.py +45 -0
 - {qadence-1.10.3.dist-info → qadence-1.11.0.dist-info}/METADATA +14 -11
 - {qadence-1.10.3.dist-info → qadence-1.11.0.dist-info}/RECORD +25 -22
 - {qadence-1.10.3.dist-info → qadence-1.11.0.dist-info}/WHEEL +0 -0
 - {qadence-1.10.3.dist-info → qadence-1.11.0.dist-info}/licenses/LICENSE +0 -0
 
qadence/blocks/block_to_tensor.py
CHANGED

@@ -82,18 +82,20 @@ def _fill_identities(
     full_qubit_support = tuple(sorted(full_qubit_support))
     qubit_support = tuple(sorted(qubit_support))
     block_mat = block_mat.to(device)
-
+    identity_mat = IMAT.to(device)
     if diag_only:
-
+        block_mat = torch.diag(block_mat.squeeze(0))
+        identity_mat = torch.diag(identity_mat.squeeze(0))
+    mat = identity_mat if qubit_support[0] != full_qubit_support[0] else block_mat
     for i in full_qubit_support[1:]:
         if i == qubit_support[0]:
-            other =
+            other = block_mat
             if endianness == Endianness.LITTLE:
                 mat = torch.kron(other, mat)
             else:
                 mat = torch.kron(mat.contiguous(), other.contiguous())
         elif i not in qubit_support:
-            other =
+            other = identity_mat
             if endianness == Endianness.LITTLE:
                 mat = torch.kron(other.contiguous(), mat.contiguous())
             else:

@@ -264,13 +266,12 @@ def _gate_parameters(b: AbstractBlock, values: dict[str, torch.Tensor]) -> tuple

 def block_to_diagonal(
     block: AbstractBlock,
+    values: dict[str, TNumber | torch.Tensor] = dict(),
     qubit_support: tuple | list | None = None,
-    use_full_support: bool =
+    use_full_support: bool = False,
     endianness: Endianness = Endianness.BIG,
     device: torch.device = None,
 ) -> torch.Tensor:
-    if block.is_parametric:
-        raise TypeError("Sparse observables cant be parametric.")
     if not block._is_diag_pauli:
         raise TypeError("Sparse observables can only be used on paulis which are diagonal.")
     if qubit_support is None:

@@ -282,17 +283,16 @@ def block_to_diagonal(
     if isinstance(block, (ChainBlock, KronBlock)):
         v = torch.ones(2**nqubits, dtype=torch.cdouble)
         for b in block.blocks:
-            v *= block_to_diagonal(b, qubit_support)
+            v *= block_to_diagonal(b, values, qubit_support, device=device)
     if isinstance(block, AddBlock):
         t = torch.zeros(2**nqubits, dtype=torch.cdouble)
         for b in block.blocks:
-            t += block_to_diagonal(b, qubit_support)
+            t += block_to_diagonal(b, values, qubit_support, device=device)
         v = t
     elif isinstance(block, ScaleBlock):
-        _s = evaluate(block.scale,
-        _s = _s.detach()  # type: ignore[union-attr]
-        v = _s * block_to_diagonal(block.block, qubit_support)
-
+        _s = evaluate(block.scale, values, as_torch=True)  # type: ignore[attr-defined]
+        _s = _s.detach().squeeze(0)  # type: ignore[union-attr]
+        v = _s * block_to_diagonal(block.block, values, qubit_support, device=device)
     elif isinstance(block, PrimitiveBlock):
         v = _fill_identities(
             OPERATIONS_DICT[block.name],

@@ -300,6 +300,7 @@ def block_to_diagonal(
             qubit_support,  # type: ignore [arg-type]
             diag_only=True,
             endianness=endianness,
+            device=device,
         )
     return v

@@ -309,7 +310,7 @@ def block_to_tensor(
     block: AbstractBlock,
     values: dict[str, TNumber | torch.Tensor] = {},
     qubit_support: tuple | None = None,
-    use_full_support: bool =
+    use_full_support: bool = False,
     tensor_type: TensorType = TensorType.DENSE,
     endianness: Endianness = Endianness.BIG,
     device: torch.device = None,

@@ -339,18 +340,14 @@ def block_to_tensor(
     print(block_to_tensor(obs, tensor_type="SparseDiagonal"))
     ```
     """
+    from qadence.blocks import embedding

-
-
-    # as observables only do the matmul of the size of the qubit support.
-
+    (ps, embed) = embedding(block)
+    values = embed(ps, values)
     if tensor_type == TensorType.DENSE:
-        from qadence.blocks import embedding
-
-        (ps, embed) = embedding(block)
         return _block_to_tensor_embedded(
             block,
-
+            values,
             qubit_support,
             use_full_support,
             endianness=endianness,

@@ -358,7 +355,7 @@ def block_to_tensor(
         )

     elif tensor_type == TensorType.SPARSEDIAGONAL:
-        t = block_to_diagonal(block, endianness=endianness)
+        t = block_to_diagonal(block, values, endianness=endianness)
         indices, values, size = torch.nonzero(t), t[t != 0], len(t)
         indices = torch.stack((indices.flatten(), indices.flatten()))
         return torch.sparse_coo_tensor(indices, values, (size, size))

@@ -369,7 +366,7 @@ def _block_to_tensor_embedded(
     block: AbstractBlock,
     values: dict[str, TNumber | torch.Tensor] = {},
     qubit_support: tuple | None = None,
-    use_full_support: bool =
+    use_full_support: bool = False,
     endianness: Endianness = Endianness.BIG,
     device: torch.device = None,
 ) -> torch.Tensor:
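Taken together, these hunks thread a `values` dictionary through `block_to_diagonal`, drop the old `TypeError` for parametric blocks, and build the matrices on the requested `device`, so parametric diagonal observables can now be converted to sparse-diagonal form. A minimal usage sketch, not part of the diff (the observable and the parameter name `w` are illustrative):

    import torch
    from qadence.blocks.block_to_tensor import block_to_tensor
    from qadence.operations import Z
    from qadence.parameters import Parameter

    # Parametric, diagonal observable: w * (Z0 + Z1)
    obs = Parameter("w") * (Z(0) + Z(1))

    # With the new `values` argument, the sparse-diagonal path substitutes the
    # parameter before building the diagonal (this raised a TypeError in 1.10.3).
    sparse = block_to_tensor(
        obs,
        values={"w": torch.tensor(0.5)},
        tensor_type="SparseDiagonal",
    )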
    
qadence/constructors/__init__.py
CHANGED

@@ -17,6 +17,9 @@ from .hamiltonians import (
     ObservableConfig,
     total_magnetization,
     zz_hamiltonian,
+    total_magnetization_config,
+    zz_hamiltonian_config,
+    ising_hamiltonian_config,
 )

 from .rydberg_hea import rydberg_hea, rydberg_hea_layer

@@ -34,9 +37,12 @@ __all__ = [
     "iia",
     "hamiltonian_factory",
     "ising_hamiltonian",
-    "ObservableConfig",
     "total_magnetization",
     "zz_hamiltonian",
+    "ObservableConfig",
+    "total_magnetization_config",
+    "zz_hamiltonian_config",
+    "ising_hamiltonian_config",
     "qft",
     "daqc_transform",
     "rydberg_hea",
qadence/constructors/hamiltonians.py
CHANGED

@@ -7,11 +7,12 @@ from typing import Callable, List, Type, Union
 import numpy as np
 from torch import Tensor, double, ones, rand
 from typing_extensions import Any
+from qadence.parameters import Parameter

 from qadence.blocks import AbstractBlock, add, block_is_qubit_hamiltonian
-from qadence.operations import N, X, Y, Z
+from qadence.operations import N, X, Y, Z, H
 from qadence.register import Register
-from qadence.types import Interaction,
+from qadence.types import Interaction, TArray, TParameter

 logger = getLogger(__name__)

@@ -239,7 +240,30 @@ def is_numeric(x: Any) -> bool:

 @dataclass
 class ObservableConfig:
-
+    """ObservableConfig is a configuration class for defining the parameters of an observable Hamiltonian."""
+
+    interaction: Interaction | Callable | None = None
+    """
+    The type of interaction.
+
+    Available options from the Interaction enum are:
+            - Interaction.ZZ
+            - Interaction.NN
+            - Interaction.XY
+            - Interaction.XYZ
+
+    Alternatively, a custom interaction function can be defined.
+            Example:
+
+            def custom_int(i: int, j: int):
+                return X(i) @ X(j) + Y(i) @ Y(j)
+
+            n_qubits = 2
+
+            observable_config = ObservableConfig(interaction=custom_int, scale = 1.0, shift = 0.0)
+            observable = create_observable(register=4, config=observable_config)
+    """
+    detuning: TDetuning | None = None
     """
     Single qubit detuning of the observable Hamiltonian.

@@ -249,8 +273,6 @@ class ObservableConfig:
     """The scale by which to multiply the output of the observable."""
     shift: TParameter = 0.0
     """The shift to add to the output of the observable."""
-    transformation_type: ObservableTransform = ObservableTransform.NONE  # type: ignore[assignment]
-    """The type of transformation."""
     trainable_transform: bool | None = None
     """
     Whether to have a trainable transformation on the output of the observable.

@@ -261,8 +283,73 @@ class ObservableConfig:
     """

     def __post_init__(self) -> None:
+        if self.interaction is None and self.detuning is None:
+            raise ValueError(
+                "Please provide an interaction and/or detuning for the Observable Hamiltonian."
+            )
+
         if is_numeric(self.scale) and is_numeric(self.shift):
-            assert (
-
-
-
+            assert self.trainable_transform is None, (
+                "If scale and shift are numbers, trainable_transform must be None."
+                f"But got: {self.trainable_transform}"
+            )
+
+        # trasform the scale and shift into parameters
+        if self.trainable_transform is not None:
+            self.shift = Parameter(name=self.shift, trainable=self.trainable_transform)
+            self.scale = Parameter(name=self.scale, trainable=self.trainable_transform)
+        else:
+            self.shift = Parameter(self.shift)
+            self.scale = Parameter(self.scale)
+
+
+def total_magnetization_config(
+    scale: TParameter = 1.0,
+    shift: TParameter = 0.0,
+    trainable_transform: bool | None = None,
+) -> ObservableConfig:
+    return ObservableConfig(
+        detuning=Z,
+        scale=scale,
+        shift=shift,
+        trainable_transform=trainable_transform,
+    )
+
+
+def zz_hamiltonian_config(
+    scale: TParameter = 1.0,
+    shift: TParameter = 0.0,
+    trainable_transform: bool | None = None,
+) -> ObservableConfig:
+    return ObservableConfig(
+        interaction=Interaction.ZZ,
+        detuning=Z,
+        scale=scale,
+        shift=shift,
+        trainable_transform=trainable_transform,
+    )
+
+
+def ising_hamiltonian_config(
+    scale: TParameter = 1.0,
+    shift: TParameter = 0.0,
+    trainable_transform: bool | None = None,
+) -> ObservableConfig:
+
+    def ZZ_Z_hamiltonian(i: int, j: int) -> AbstractBlock:
+        result = Z(i) @ Z(j)
+
+        if i == 0:
+            result += Z(j)
+        elif i == 1 and j == 2:
+            result += Z(0)
+
+        return result
+
+    return ObservableConfig(
+        interaction=ZZ_Z_hamiltonian,
+        detuning=Z,
+        scale=scale,
+        shift=shift,
+        trainable_transform=trainable_transform,
+    )
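The reworked `__post_init__` above now validates that at least one of `interaction` or `detuning` is set and promotes `scale` and `shift` to qadence `Parameter`s, trainable ones when `trainable_transform=True`. A small sketch of that behaviour, assuming the imports shown in this file's hunks (the symbol names `alpha` and `beta` are illustrative):

    from qadence.constructors import ObservableConfig
    from qadence.operations import Z
    from qadence.types import Interaction

    # Neither interaction nor detuning -> __post_init__ raises ValueError.
    # ObservableConfig()

    # Symbolic scale/shift become trainable Parameters when trainable_transform=True.
    cfg = ObservableConfig(
        interaction=Interaction.ZZ,
        detuning=Z,
        scale="alpha",
        shift="beta",
        trainable_transform=True,
    )
    print(type(cfg.scale))  # a qadence Parameter after __post_init__, not a plain float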
qadence/mitigations/analog_zne.py
CHANGED

@@ -92,7 +92,9 @@ def pulse_experiment(
     )
     # Convert observable to Numpy types compatible with QuTip simulations.
     # Matrices are flipped to match QuTip conventions.
-    converted_observable = [
+    converted_observable = [
+        np.flip(block_to_tensor(obs, use_full_support=True).numpy()) for obs in observable
+    ]
     # Create ZNE datasets by looping over batches.
     for observable in converted_observable:
         # Get expectation values at the end of the time serie [0,t]

@@ -130,7 +132,9 @@ def noise_level_experiment(
     )
     # Convert observable to Numpy types compatible with QuTip simulations.
     # Matrices are flipped to match QuTip conventions.
-    converted_observable = [
+    converted_observable = [
+        np.flip(block_to_tensor(obs, use_full_support=True).numpy()) for obs in observable
+    ]
     # Create ZNE datasets by looping over batches.
     for observable in converted_observable:
         # Get expectation values at the end of the time serie [0,t]
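The functional point of both hunks is that the observables are converted with an explicit `use_full_support=True`, now that `block_to_tensor` defaults to `use_full_support=False` (see the signature hunks above). A rough sketch of the distinction, assuming `use_full_support` controls whether the matrix spans the whole register rather than only the block's qubit support (the operator below is illustrative):

    import numpy as np
    from qadence.blocks.block_to_tensor import block_to_tensor
    from qadence.operations import Z

    obs = Z(2)  # acts only on qubit 2

    support_only = block_to_tensor(obs)                           # new default: block's own support
    full_register = block_to_tensor(obs, use_full_support=True)   # full register, qubits 0..2

    # QuTip-compatible conversion, as done in both hunks above.
    converted = np.flip(full_register.numpy())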
    
qadence/ml_tools/__init__.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations

 from .callbacks.saveload import load_checkpoint, load_model, write_checkpoint
 from .config import AnsatzConfig, FeatureMapConfig, TrainConfig
-from .constructors import create_ansatz, create_fm_blocks,
+from .constructors import create_ansatz, create_fm_blocks, create_observable
 from .data import DictDataLoader, InfiniteTensorDataset, OptimizeResult, to_dataloader
 from .information import InformationContent
 from .models import QNN

@@ -19,7 +19,7 @@ __all__ = [
     "DictDataLoader",
     "FeatureMapConfig",
     "load_checkpoint",
-    "
+    "create_observable",
     "QNN",
     "TrainConfig",
     "OptimizeResult",
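With `create_observable` now exported from `qadence.ml_tools` and the `*_config` helpers from `qadence.constructors`, an observable can be built in two lines. A sketch following the call shown in the `ObservableConfig` docstring above (the register size is illustrative):

    from qadence.constructors import total_magnetization_config
    from qadence.ml_tools import create_observable

    # Total-magnetization observable (detuning=Z) scaled by 2, on a 4-qubit register.
    config = total_magnetization_config(scale=2.0, shift=0.0)
    observable = create_observable(register=4, config=config)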
         @@ -95,14 +95,36 @@ class Callback: 
     | 
|
| 
       95 
95 
     | 
    
         
             
                    self.callback: CallbackFunction | None = callback
         
     | 
| 
       96 
96 
     | 
    
         
             
                    self.on: str | TrainingStage = on
         
     | 
| 
       97 
97 
     | 
    
         
             
                    self.called_every: int = called_every
         
     | 
| 
       98 
     | 
    
         
            -
                    self.callback_condition =  
     | 
| 
      
 98 
     | 
    
         
            +
                    self.callback_condition = (
         
     | 
| 
      
 99 
     | 
    
         
            +
                        callback_condition if callback_condition else Callback.default_callback
         
     | 
| 
      
 100 
     | 
    
         
            +
                    )
         
     | 
| 
       99 
101 
     | 
    
         | 
| 
       100 
102 
     | 
    
         
             
                    if isinstance(modify_optimize_result, dict):
         
     | 
| 
       101 
     | 
    
         
            -
                        self.modify_optimize_result = (
         
     | 
| 
       102 
     | 
    
         
            -
                             
     | 
| 
      
 103 
     | 
    
         
            +
                        self.modify_optimize_result = lambda opt_res: Callback.modify_opt_res_dict(
         
     | 
| 
      
 104 
     | 
    
         
            +
                            opt_res, modify_optimize_result
         
     | 
| 
       103 
105 
     | 
    
         
             
                        )
         
     | 
| 
       104 
106 
     | 
    
         
             
                    else:
         
     | 
| 
       105 
     | 
    
         
            -
                        self.modify_optimize_result =  
     | 
| 
      
 107 
     | 
    
         
            +
                        self.modify_optimize_result = (
         
     | 
| 
      
 108 
     | 
    
         
            +
                            modify_optimize_result
         
     | 
| 
      
 109 
     | 
    
         
            +
                            if modify_optimize_result
         
     | 
| 
      
 110 
     | 
    
         
            +
                            else Callback.modify_opt_res_default
         
     | 
| 
      
 111 
     | 
    
         
            +
                        )
         
     | 
| 
      
 112 
     | 
    
         
            +
             
     | 
| 
      
 113 
     | 
    
         
            +
                @staticmethod
         
     | 
| 
      
 114 
     | 
    
         
            +
                def default_callback(_: Any) -> bool:
         
     | 
| 
      
 115 
     | 
    
         
            +
                    return True
         
     | 
| 
      
 116 
     | 
    
         
            +
             
     | 
| 
      
 117 
     | 
    
         
            +
                @staticmethod
         
     | 
| 
      
 118 
     | 
    
         
            +
                def modify_opt_res_dict(
         
     | 
| 
      
 119 
     | 
    
         
            +
                    opt_res: OptimizeResult,
         
     | 
| 
      
 120 
     | 
    
         
            +
                    modify_optimize_result: dict[str, Any] = {},
         
     | 
| 
      
 121 
     | 
    
         
            +
                ) -> OptimizeResult:
         
     | 
| 
      
 122 
     | 
    
         
            +
                    opt_res.extra.update(modify_optimize_result)
         
     | 
| 
      
 123 
     | 
    
         
            +
                    return opt_res
         
     | 
| 
      
 124 
     | 
    
         
            +
             
     | 
| 
      
 125 
     | 
    
         
            +
                @staticmethod
         
     | 
| 
      
 126 
     | 
    
         
            +
                def modify_opt_res_default(opt_res: OptimizeResult) -> OptimizeResult:
         
     | 
| 
      
 127 
     | 
    
         
            +
                    return opt_res
         
     | 
| 
       106 
128 
     | 
    
         | 
| 
       107 
129 
     | 
    
         
             
                @property
         
     | 
| 
       108 
130 
     | 
    
         
             
                def on(self) -> TrainingStage | str:
         
     | 
| 
         @@ -261,8 +283,9 @@ class WriteMetrics(Callback): 
     | 
|
| 
       261 
283 
     | 
    
         
             
                        config (TrainConfig): The configuration object.
         
     | 
| 
       262 
284 
     | 
    
         
             
                        writer (BaseWriter ): The writer object for logging.
         
     | 
| 
       263 
285 
     | 
    
         
             
                    """
         
     | 
| 
       264 
     | 
    
         
            -
                     
     | 
| 
       265 
     | 
    
         
            -
             
     | 
| 
      
 286 
     | 
    
         
            +
                    if trainer.accelerator.rank == 0:
         
     | 
| 
      
 287 
     | 
    
         
            +
                        opt_result = trainer.opt_result
         
     | 
| 
      
 288 
     | 
    
         
            +
                        writer.write(opt_result.iteration, opt_result.metrics)
         
     | 
| 
       266 
289 
     | 
    
         | 
| 
       267 
290 
     | 
    
         | 
| 
       268 
291 
     | 
    
         
             
            class PlotMetrics(Callback):
         
     | 
| 
         @@ -299,9 +322,10 @@ class PlotMetrics(Callback): 
     | 
|
| 
       299 
322 
     | 
    
         
             
                        config (TrainConfig): The configuration object.
         
     | 
| 
       300 
323 
     | 
    
         
             
                        writer (BaseWriter ): The writer object for logging.
         
     | 
| 
       301 
324 
     | 
    
         
             
                    """
         
     | 
| 
       302 
     | 
    
         
            -
                     
     | 
| 
       303 
     | 
    
         
            -
             
     | 
| 
       304 
     | 
    
         
            -
             
     | 
| 
      
 325 
     | 
    
         
            +
                    if trainer.accelerator.rank == 0:
         
     | 
| 
      
 326 
     | 
    
         
            +
                        opt_result = trainer.opt_result
         
     | 
| 
      
 327 
     | 
    
         
            +
                        plotting_functions = config.plotting_functions
         
     | 
| 
      
 328 
     | 
    
         
            +
                        writer.plot(trainer.model, opt_result.iteration, plotting_functions)
         
     | 
| 
       305 
329 
     | 
    
         | 
| 
       306 
330 
     | 
    
         | 
| 
       307 
331 
     | 
    
         
             
            class LogHyperparameters(Callback):
         
     | 
| 
         @@ -338,8 +362,9 @@ class LogHyperparameters(Callback): 
     | 
|
| 
       338 
362 
     | 
    
         
             
                        config (TrainConfig): The configuration object.
         
     | 
| 
       339 
363 
     | 
    
         
             
                        writer (BaseWriter ): The writer object for logging.
         
     | 
| 
       340 
364 
     | 
    
         
             
                    """
         
     | 
| 
       341 
     | 
    
         
            -
                     
     | 
| 
       342 
     | 
    
         
            -
             
     | 
| 
      
 365 
     | 
    
         
            +
                    if trainer.accelerator.rank == 0:
         
     | 
| 
      
 366 
     | 
    
         
            +
                        hyperparams = config.hyperparams
         
     | 
| 
      
 367 
     | 
    
         
            +
                        writer.log_hyperparams(hyperparams)
         
     | 
| 
       343 
368 
     | 
    
         | 
| 
       344 
369 
     | 
    
         | 
| 
       345 
370 
     | 
    
         
             
            class SaveCheckpoint(Callback):
         
     | 
| 
         @@ -376,11 +401,12 @@ class SaveCheckpoint(Callback): 
     | 
|
| 
       376 
401 
     | 
    
         
             
                        config (TrainConfig): The configuration object.
         
     | 
| 
       377 
402 
     | 
    
         
             
                        writer (BaseWriter ): The writer object for logging.
         
     | 
| 
       378 
403 
     | 
    
         
             
                    """
         
     | 
| 
       379 
     | 
    
         
            -
                     
     | 
| 
       380 
     | 
    
         
            -
             
     | 
| 
       381 
     | 
    
         
            -
             
     | 
| 
       382 
     | 
    
         
            -
             
     | 
| 
       383 
     | 
    
         
            -
             
     | 
| 
      
 404 
     | 
    
         
            +
                    if trainer.accelerator.rank == 0:
         
     | 
| 
      
 405 
     | 
    
         
            +
                        folder = config.log_folder
         
     | 
| 
      
 406 
     | 
    
         
            +
                        model = trainer.model
         
     | 
| 
      
 407 
     | 
    
         
            +
                        optimizer = trainer.optimizer
         
     | 
| 
      
 408 
     | 
    
         
            +
                        opt_result = trainer.opt_result
         
     | 
| 
      
 409 
     | 
    
         
            +
                        write_checkpoint(folder, model, optimizer, opt_result.iteration)
         
     | 
| 
       384 
410 
     | 
    
         | 
| 
       385 
411 
     | 
    
         | 
| 
       386 
412 
     | 
    
         
             
            class SaveBestCheckpoint(SaveCheckpoint):
         
     | 
| 
         @@ -404,17 +430,18 @@ class SaveBestCheckpoint(SaveCheckpoint): 
     | 
|
| 
       404 
430 
     | 
    
         
             
                        config (TrainConfig): The configuration object.
         
     | 
| 
       405 
431 
     | 
    
         
             
                        writer (BaseWriter ): The writer object for logging.
         
     | 
| 
       406 
432 
     | 
    
         
             
                    """
         
     | 
| 
       407 
     | 
    
         
            -
                     
     | 
| 
       408 
     | 
    
         
            -
                    if config.validation_criterion and config.validation_criterion(
         
     | 
| 
       409 
     | 
    
         
            -
                        opt_result.loss, self.best_loss, config.val_epsilon
         
     | 
| 
       410 
     | 
    
         
            -
                    ):
         
     | 
| 
       411 
     | 
    
         
            -
                        self.best_loss = opt_result.loss
         
     | 
| 
       412 
     | 
    
         
            -
             
     | 
| 
       413 
     | 
    
         
            -
                        folder = config.log_folder
         
     | 
| 
       414 
     | 
    
         
            -
                        model = trainer.model
         
     | 
| 
       415 
     | 
    
         
            -
                        optimizer = trainer.optimizer
         
     | 
| 
      
 433 
     | 
    
         
            +
                    if trainer.accelerator.rank == 0:
         
     | 
| 
       416 
434 
     | 
    
         
             
                        opt_result = trainer.opt_result
         
     | 
| 
       417 
     | 
    
         
            -
                         
     | 
| 
      
 435 
     | 
    
         
            +
                        if config.validation_criterion and config.validation_criterion(
         
     | 
| 
      
 436 
     | 
    
         
            +
                            opt_result.loss, self.best_loss, config.val_epsilon
         
     | 
| 
      
 437 
     | 
    
         
            +
                        ):
         
     | 
| 
      
 438 
     | 
    
         
            +
                            self.best_loss = opt_result.loss
         
     | 
| 
      
 439 
     | 
    
         
            +
             
     | 
| 
      
 440 
     | 
    
         
            +
                            folder = config.log_folder
         
     | 
| 
      
 441 
     | 
    
         
            +
                            model = trainer.model
         
     | 
| 
      
 442 
     | 
    
         
            +
                            optimizer = trainer.optimizer
         
     | 
| 
      
 443 
     | 
    
         
            +
                            opt_result = trainer.opt_result
         
     | 
| 
      
 444 
     | 
    
         
            +
                            write_checkpoint(folder, model, optimizer, "best")
         
     | 
| 
       418 
445 
     | 
    
         | 
| 
       419 
446 
     | 
    
         | 
| 
       420 
447 
     | 
    
         
             
            class LoadCheckpoint(Callback):
         
     | 
| 
         @@ -431,11 +458,12 @@ class LoadCheckpoint(Callback): 
     | 
|
| 
       431 
458 
     | 
    
         
             
                    Returns:
         
     | 
| 
       432 
459 
     | 
    
         
             
                        Any: The result of loading the checkpoint.
         
     | 
| 
       433 
460 
     | 
    
         
             
                    """
         
     | 
| 
       434 
     | 
    
         
            -
                     
     | 
| 
       435 
     | 
    
         
            -
             
     | 
| 
       436 
     | 
    
         
            -
             
     | 
| 
       437 
     | 
    
         
            -
             
     | 
| 
       438 
     | 
    
         
            -
             
     | 
| 
      
 461 
     | 
    
         
            +
                    if trainer.accelerator.rank == 0:
         
     | 
| 
      
 462 
     | 
    
         
            +
                        folder = config.log_folder
         
     | 
| 
      
 463 
     | 
    
         
            +
                        model = trainer.model
         
     | 
| 
      
 464 
     | 
    
         
            +
                        optimizer = trainer.optimizer
         
     | 
| 
      
 465 
     | 
    
         
            +
                        device = trainer.accelerator.execution.log_device
         
     | 
| 
      
 466 
     | 
    
         
            +
                        return load_checkpoint(folder, model, optimizer, device=device)
         
     | 
| 
       439 
467 
     | 
    
         | 
| 
       440 
468 
     | 
    
         | 
| 
       441 
469 
     | 
    
         
             
            class LogModelTracker(Callback):
         
     | 
| 
         @@ -449,10 +477,11 @@ class LogModelTracker(Callback): 
     | 
|
| 
       449 
477 
     | 
    
         
             
                        config (TrainConfig): The configuration object.
         
     | 
| 
       450 
478 
     | 
    
         
             
                        writer (BaseWriter ): The writer object for logging.
         
     | 
| 
       451 
479 
     | 
    
         
             
                    """
         
     | 
| 
       452 
     | 
    
         
            -
                     
     | 
| 
       453 
     | 
    
         
            -
             
     | 
| 
       454 
     | 
    
         
            -
                         
     | 
| 
       455 
     | 
    
         
            -
             
     | 
| 
      
 480 
     | 
    
         
            +
                    if trainer.accelerator.rank == 0:
         
     | 
| 
      
 481 
     | 
    
         
            +
                        model = trainer.model
         
     | 
| 
      
 482 
     | 
    
         
            +
                        writer.log_model(
         
     | 
| 
      
 483 
     | 
    
         
            +
                            model, trainer.train_dataloader, trainer.val_dataloader, trainer.test_dataloader
         
     | 
| 
      
 484 
     | 
    
         
            +
                        )
         
     | 
| 
       456 
485 
     | 
    
         | 
| 
       457 
486 
     | 
    
         | 
| 
       458 
487 
     | 
    
         
             
            class LRSchedulerStepDecay(Callback):
         
     | 
| 
         @@ -713,7 +742,7 @@ class EarlyStopping(Callback): 
     | 
|
| 
       713 
742 
     | 
    
         
             
                            f"EarlyStopping: No improvement in '{self.monitor}' for {self.patience} epochs. "
         
     | 
| 
       714 
743 
     | 
    
         
             
                            "Stopping training."
         
     | 
| 
       715 
744 
     | 
    
         
             
                        )
         
     | 
| 
       716 
     | 
    
         
            -
                        trainer. 
     | 
| 
      
 745 
     | 
    
         
            +
                        trainer._stop_training.fill_(1)
         
     | 
| 
       717 
746 
     | 
    
         | 
| 
       718 
747 
     | 
    
         | 
| 
       719 
748 
     | 
    
         
             
            class GradientMonitoring(Callback):
         
     | 
| 
          @@ -759,17 +788,18 @@ class GradientMonitoring(Callback):
             config (TrainConfig): The configuration object.
             writer (BaseWriter): The writer object for logging.
         """
-        gradient_stats = {}
-        for name, param in trainer.model.named_parameters():
-            if param.grad is not None:
-                grad = param.grad
-                gradient_stats.update(
-                    {
-                        name + "_mean": grad.mean().item(),
-                        name + "_std": grad.std().item(),
-                        name + "_max": grad.max().item(),
-                        name + "_min": grad.min().item(),
-                    }
-                )
-
-        writer.write(trainer.opt_result.iteration, gradient_stats)
+        if trainer.accelerator.rank == 0:
+            gradient_stats = {}
+            for name, param in trainer.model.named_parameters():
+                if param.grad is not None:
+                    grad = param.grad
+                    gradient_stats.update(
+                        {
+                            name + "_mean": grad.mean().item(),
+                            name + "_std": grad.std().item(),
+                            name + "_max": grad.max().item(),
+                            name + "_min": grad.min().item(),
+                        }
+                    )
+
+            writer.write(trainer.opt_result.iteration, gradient_stats)

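The statistics themselves are plain per-parameter reductions over `param.grad`. A self-contained toy version, using an illustrative `nn.Linear` model instead of a qadence model:

    import torch
    from torch import nn

    # Toy reproduction of the statistics collected by the callback above.
    model = nn.Linear(4, 2)
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()

    gradient_stats: dict[str, float] = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad = param.grad
            gradient_stats.update(
                {
                    name + "_mean": grad.mean().item(),
                    name + "_std": grad.std().item(),
                    name + "_max": grad.max().item(),
                    name + "_min": grad.min().item(),
                }
            )
    print(gradient_stats)
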
qadence/ml_tools/callbacks/callbackmanager.py
CHANGED

          @@ -201,7 +201,8 @@ class CallbacksManager:
                 logger.debug(f"Loaded model and optimizer from {self.config.log_folder}")
 
         # Setup writer
-        self.writer.open(self.config, iteration=trainer.global_step)
+        if trainer.accelerator.rank == 0:
+            self.writer.open(self.config, iteration=trainer.global_step)
 
     def end_training(self, trainer: Any) -> None:
         """

          @@ -210,5 +211,5 @@ class CallbacksManager:
         Args:
             trainer (Any): The training object managing the training process.
         """
-        if self.writer:
+        if trainer.accelerator.rank == 0 and self.writer:
             self.writer.close()

qadence/ml_tools/callbacks/writer_registry.py
CHANGED

          @@ -127,11 +127,12 @@ class BaseWriter(ABC):
 
         # Find the key in result.metrics that contains "loss" (case-insensitive)
         loss_key = next((k for k in result.metrics if "loss" in k.lower()), None)
+        initial = f"P {result.rank: >2}|{result.device: <7}| Iteration {result.iteration: >7}| "
         if loss_key:
             loss_value = result.metrics[loss_key]
-            msg = f"
+            msg = initial + f"{loss_key.title()}: {loss_value:.7f} -"
         else:
-            msg = f"
+            msg = initial + f"Loss: None -"
         msg += " ".join([f"{k}: {v:.7f}" for k, v in result.metrics.items() if k != loss_key])
         print(msg)

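The new `initial` prefix adds the process rank and device to every console line. With illustrative values for the fields read from the result object, the format produces:

    # Reproduction of the new console prefix with illustrative values.
    rank, device, iteration = 0, "cuda:0", 1200
    loss_key, loss_value = "train_loss", 0.0123456
    initial = f"P {rank: >2}|{device: <7}| Iteration {iteration: >7}| "
    msg = initial + f"{loss_key.title()}: {loss_value:.7f} -"
    print(msg)  # P  0|cuda:0 | Iteration    1200| Train_Loss: 0.0123456 -
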
    
qadence/ml_tools/config.py
CHANGED

          @@ -20,6 +20,7 @@ from qadence.types import (
     ReuploadScaling,
     Strategy,
 )
+from torch import dtype
 
 logger = getLogger(__file__)
 

          @@ -116,10 +117,9 @@ class TrainConfig:
     """The log folder for saving checkpoints and tensorboard logs.
 
     This stores the path where all logs and checkpoints are being saved
-    for this training session. `log_folder` takes precedence over `root_folder
-
-
-    will not be used.
+    for this training session. `log_folder` takes precedence over `root_folder`,
+    but it is ignored if `create_subfolders_per_run=True` (in which case, subfolders
+    will be spawned in the root folder).
     """
 
     checkpoint_best_only: bool = False

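A short usage sketch of the documented precedence; the folder paths and `max_iter` value are illustrative, and the field names follow the docstring above:

    from qadence.ml_tools import TrainConfig

    # `log_folder` wins over `root_folder` when both are given.
    config = TrainConfig(max_iter=500, log_folder="./logs/experiment_a")

    # With per-run subfolders enabled, `log_folder` is ignored and runs are
    # placed in subfolders of `root_folder` instead.
    config_multi_run = TrainConfig(
        max_iter=500, root_folder="./runs", create_subfolders_per_run=True
    )
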
          @@ -195,7 +195,7 @@ class TrainConfig:
     plots that are logged or saved at specified intervals.
     """
 
-    _subfolders: list = field(default_factory=list)
+    _subfolders: list[str] = field(default_factory=list)
     """List of subfolders used for logging different runs using the same config inside the.
 
     root folder.

          @@ -203,6 +203,67 @@ class TrainConfig:
     Each subfolder is of structure `<id>_<timestamp>_<PID>`.
     """
 
+    nprocs: int = 1
+    """
+    The number of processes to use for training when spawning subprocesses.
+
+    For effective parallel processing, set this to a value greater than 1.
+    - In case of Multi-GPU or Multi-Node-Multi-GPU setups, nprocs should be equal to
+    the total number of GPUs across all nodes (world size), or total number of GPU to be used.
+
+    If nprocs > 1, multiple processes will be spawned for training. The training framework will launch
+    additional processes (e.g., for distributed or parallel training).
+    - For CPU setup, this will launch a true parallel processes
+    - For GPU setup, this will launch a distributed training routine.
+    This uses the DistributedDataParallel framework from PyTorch.
+    """
+
+    compute_setup: str = "cpu"
+    """
+    Compute device setup; options are "auto", "gpu", or "cpu".
+
+    - "auto": Automatically uses GPU if available; otherwise, falls back to CPU.
+    - "gpu": Forces GPU usage, raising an error if no CUDA device is available.
+    - "cpu": Forces the use of CPU regardless of GPU availability.
+    """
+
+    backend: str = "gloo"
+    """
+    Backend used for distributed training communication.
+
+    The default is "gloo". Other options may include "nccl" - which is optimized for GPU-based training or "mpi",
+    depending on your system and requirements.
+    It should be one of the backends supported by `torch.distributed`. For further details, please look at
+    [torch backends](https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend)
+    """
+
+    log_setup: str = "cpu"
+    """
+    Logging device setup; options are "auto" or "cpu".
+
+    - "auto": Uses the same device for logging as for computation.
+    - "cpu": Forces logging to occur on the CPU. This can be useful to avoid potential conflicts with GPU processes.
+    """
+
+    dtype: dtype | None = None
+    """
+    Data type (precision) for computations.
+
+    Both model parameters, and dataset will be of the provided precision.
+
+    If not specified or None, the default torch precision (usually torch.float32) is used.
+    If provided dtype is torch.complex128, model parameters will be torch.complex128, and data parameters will be torch.float64
+    """
+
+    all_reduce_metrics: bool = False
+    """
+    Whether to aggregate metrics (e.g., loss, accuracy) across processes.
+
+    When True, metrics from different training processes are averaged to provide a consolidated metrics.
+    Note: Since aggregation requires synchronization/all_reduce operation, this can increase the
+     computation time significantly.
+    """
+
 
 @dataclass
 class FeatureMapConfig:

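Taken together, the new fields let the distributed setup be declared directly on `TrainConfig`. A hedged usage sketch with illustrative values, using only the fields introduced in the hunk above:

    import torch
    from qadence.ml_tools import TrainConfig

    # Two worker processes communicating over gloo; the compute device is
    # chosen automatically, logging stays on CPU, parameters and data use
    # float64, and metrics are averaged across processes (at the cost of an
    # extra all_reduce per logging step).
    config = TrainConfig(
        max_iter=1000,
        nprocs=2,
        compute_setup="auto",
        backend="gloo",
        log_setup="cpu",
        dtype=torch.float64,
        all_reduce_metrics=True,
    )
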