qadence 1.5.2__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51):
  1. qadence/__init__.py +33 -5
  2. qadence/backend.py +2 -2
  3. qadence/backends/adjoint.py +8 -4
  4. qadence/backends/braket/backend.py +3 -2
  5. qadence/backends/braket/config.py +2 -2
  6. qadence/backends/gpsr.py +1 -1
  7. qadence/backends/horqrux/backend.py +4 -0
  8. qadence/backends/horqrux/config.py +2 -2
  9. qadence/backends/pulser/backend.py +71 -45
  10. qadence/backends/pulser/config.py +0 -28
  11. qadence/backends/pulser/pulses.py +2 -2
  12. qadence/backends/pyqtorch/backend.py +3 -2
  13. qadence/backends/pyqtorch/config.py +2 -2
  14. qadence/blocks/block_to_tensor.py +4 -4
  15. qadence/blocks/matrix.py +2 -2
  16. qadence/blocks/utils.py +2 -2
  17. qadence/circuit.py +5 -2
  18. qadence/constructors/__init__.py +1 -10
  19. qadence/constructors/ansatze.py +1 -65
  20. qadence/constructors/daqc/daqc.py +3 -2
  21. qadence/constructors/daqc/gen_parser.py +3 -2
  22. qadence/constructors/daqc/utils.py +3 -3
  23. qadence/constructors/feature_maps.py +2 -90
  24. qadence/constructors/hamiltonians.py +2 -6
  25. qadence/constructors/rydberg_feature_maps.py +2 -2
  26. qadence/decompose.py +2 -2
  27. qadence/engines/torch/differentiable_expectation.py +7 -0
  28. qadence/extensions.py +4 -15
  29. qadence/log_config.yaml +24 -0
  30. qadence/logger.py +9 -27
  31. qadence/ml_tools/models.py +10 -2
  32. qadence/ml_tools/saveload.py +14 -5
  33. qadence/ml_tools/train_grad.py +3 -3
  34. qadence/ml_tools/train_no_grad.py +2 -2
  35. qadence/models/quantum_model.py +13 -6
  36. qadence/noise/readout.py +2 -3
  37. qadence/operations/__init__.py +0 -2
  38. qadence/operations/analog.py +2 -12
  39. qadence/operations/control_ops.py +3 -2
  40. qadence/operations/ham_evo.py +2 -2
  41. qadence/operations/parametric.py +3 -2
  42. qadence/operations/primitive.py +2 -2
  43. qadence/parameters.py +2 -2
  44. qadence/serialization.py +2 -2
  45. qadence/transpile/block.py +2 -2
  46. qadence/types.py +2 -2
  47. qadence/utils.py +2 -2
  48. {qadence-1.5.2.dist-info → qadence-1.6.0.dist-info}/METADATA +16 -8
  49. {qadence-1.5.2.dist-info → qadence-1.6.0.dist-info}/RECORD +51 -50
  50. {qadence-1.5.2.dist-info → qadence-1.6.0.dist-info}/WHEEL +0 -0
  51. {qadence-1.5.2.dist-info → qadence-1.6.0.dist-info}/licenses/LICENSE +0 -0
qadence/constructors/__init__.py CHANGED
@@ -2,13 +2,10 @@
2
2
 
3
3
  from .feature_maps import (
4
4
  feature_map,
5
- chebyshev_feature_map,
6
- fourier_feature_map,
7
- tower_feature_map,
8
5
  exp_fourier_feature_map,
9
6
  )
10
7
 
11
- from .ansatze import hea, build_qnn
8
+ from .ansatze import hea
12
9
 
13
10
  from .iia import identity_initialized_ansatz
14
11
 
@@ -17,7 +14,6 @@ from .daqc import daqc_transform
17
14
  from .hamiltonians import (
18
15
  hamiltonian_factory,
19
16
  ising_hamiltonian,
20
- single_z,
21
17
  total_magnetization,
22
18
  zz_hamiltonian,
23
19
  )
@@ -30,16 +26,11 @@ from .qft import qft
30
26
  # Modules to be automatically added to the qadence namespace
31
27
  __all__ = [
32
28
  "feature_map",
33
- "chebyshev_feature_map",
34
- "fourier_feature_map",
35
- "tower_feature_map",
36
29
  "exp_fourier_feature_map",
37
30
  "hea",
38
31
  "identity_initialized_ansatz",
39
- "build_qnn",
40
32
  "hamiltonian_factory",
41
33
  "ising_hamiltonian",
42
- "single_z",
43
34
  "total_magnetization",
44
35
  "zz_hamiltonian",
45
36
  "qft",
qadence/constructors/ansatze.py CHANGED
@@ -1,15 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import itertools
4
- import warnings
5
- from typing import Any, Optional, Type, Union
4
+ from typing import Any, Type, Union
6
5
 
7
6
  from qadence.blocks import AbstractBlock, block_is_qubit_hamiltonian, chain, kron, tag
8
7
  from qadence.operations import CNOT, CPHASE, CRX, CRY, CRZ, CZ, RX, RY, HamEvo
9
8
  from qadence.types import Interaction, Strategy
10
9
 
11
10
  from .hamiltonians import hamiltonian_factory
12
- from .utils import build_idx_fms
13
11
 
14
12
  DigitalEntanglers = Union[CNOT, CZ, CRZ, CRY, CRX]
15
13
 
@@ -322,65 +320,3 @@ def hea_bDAQC(*args: Any, **kwargs: Any) -> Any:
322
320
 
323
321
  def hea_analog(*args: Any, **kwargs: Any) -> Any:
324
322
  raise NotImplementedError
325
-
326
-
327
- #########
328
- ## QNN ##
329
- #########
330
-
331
-
332
- # FIXME: Remove in v1.5.0
333
- def build_qnn(
334
- n_qubits: int,
335
- n_features: int,
336
- depth: int = None,
337
- ansatz: Optional[AbstractBlock] = None,
338
- fm_pauli: Type[RY] = RY,
339
- spectrum: str = "simple",
340
- basis: str = "fourier",
341
- fm_strategy: str = "parallel",
342
- ) -> list[AbstractBlock]:
343
- """Helper function to build a qadence QNN quantum circuit.
344
-
345
- Args:
346
- n_qubits (int): The number of qubits.
347
- n_features (int): The number of input dimensions.
348
- depth (int): The depth of the ansatz.
349
- ansatz (Optional[AbstractBlock]): An optional argument to pass a custom qadence ansatz.
350
- fm_pauli (str): The type of Pauli gate for the feature map. Must be one of 'RX',
351
- 'RY', or 'RZ'.
352
- spectrum (str): The desired spectrum of the feature map generator. The options simple,
353
- tower and exponential produce a spectrum with linear, quadratic and exponential
354
- eigenvalues with respect to the number of qubits.
355
- basis (str): The encoding function. The options fourier and chebyshev correspond to Φ(x)=x
356
- and arcos(x) respectively.
357
- fm_strategy (str): The feature map encoding strategy. If "parallel", the features
358
- are encoded in one block of rotation gates, with each feature given
359
- an equal number of qubits. If "serial", the features are encoded
360
- sequentially, with a HEA block between.
361
-
362
- Returns:
363
- A list of Abstract blocks to be used for constructing a quantum circuit
364
- """
365
-
366
- warnings.warn("Function build_qnn is deprecated and will be removed in v1.5.0.", FutureWarning)
367
-
368
- depth = n_qubits if depth is None else depth
369
-
370
- idx_fms = build_idx_fms(basis, fm_pauli, fm_strategy, n_features, n_qubits, spectrum)
371
-
372
- if fm_strategy == "parallel":
373
- _fm = kron(*idx_fms)
374
- fm = tag(_fm, tag="FM")
375
-
376
- elif fm_strategy == "serial":
377
- fm_components: list[AbstractBlock] = []
378
- for j, fm_idx in enumerate(idx_fms[:-1]):
379
- fm_idx = tag(fm_idx, tag=f"FM{j}") # type: ignore[assignment]
380
- fm_component = (fm_idx, hea(n_qubits, 1, f"theta_{j}"))
381
- fm_components.extend(fm_component)
382
- fm_components.append(tag(idx_fms[-1], tag=f"FM{len(idx_fms) - 1}"))
383
- fm = chain(*fm_components) # type: ignore[assignment]
384
-
385
- ansatz = hea(n_qubits, depth=depth) if ansatz is None else ansatz
386
- return [fm, ansatz]
qadence/constructors/daqc/daqc.py CHANGED
@@ -1,18 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
4
+
3
5
  import torch
4
6
 
5
7
  from qadence.blocks import AbstractBlock, add, chain, kron
6
8
  from qadence.blocks.utils import block_is_qubit_hamiltonian
7
9
  from qadence.constructors.hamiltonians import hamiltonian_factory
8
- from qadence.logger import get_logger
9
10
  from qadence.operations import HamEvo, I, N, X
10
11
  from qadence.types import GenDAQC, Interaction, Strategy
11
12
 
12
13
  from .gen_parser import _check_compatibility, _parse_generator
13
14
  from .utils import _build_matrix_M, _ix_map
14
15
 
15
- logger = get_logger(__name__)
16
+ logger = getLogger(__name__)
16
17
 
17
18
 
18
19
  def daqc_transform(
qadence/constructors/daqc/gen_parser.py CHANGED
@@ -1,17 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
4
+
3
5
  import torch
4
6
 
5
7
  from qadence.blocks import AbstractBlock, KronBlock
6
8
  from qadence.blocks.utils import unroll_block_with_scaling
7
- from qadence.logger import get_logger
8
9
  from qadence.operations import N, Z
9
10
  from qadence.parameters import Parameter, evaluate
10
11
  from qadence.types import GenDAQC
11
12
 
12
13
  from .utils import _ix_map
13
14
 
14
- logger = get_logger(__name__)
15
+ logger = getLogger(__name__)
15
16
 
16
17
 
17
18
  def _parse_generator(
qadence/constructors/daqc/utils.py CHANGED
@@ -1,10 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- import torch
3
+ from logging import getLogger
4
4
 
5
- from qadence.logger import get_logger
5
+ import torch
6
6
 
7
- logger = get_logger(__name__)
7
+ logger = getLogger(__name__)
8
8
 
9
9
 
10
10
  def _k_d(a: int, b: int) -> int:
qadence/constructors/feature_maps.py CHANGED
@@ -1,19 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
- import warnings
4
3
  from collections.abc import Callable
4
+ from logging import getLogger
5
5
  from math import isclose
6
6
  from typing import Union
7
7
 
8
8
  from sympy import Basic, acos
9
9
 
10
10
  from qadence.blocks import AbstractBlock, KronBlock, chain, kron, tag
11
- from qadence.logger import get_logger
12
11
  from qadence.operations import PHASE, RX, RY, RZ, H
13
12
  from qadence.parameters import FeatureParameter, Parameter, VariationalParameter
14
13
  from qadence.types import PI, BasisSet, ReuploadScaling, TParameter
15
14
 
16
- logger = get_logger(__name__)
15
+ logger = getLogger(__name__)
17
16
 
18
17
  ROTATIONS = [RX, RY, RZ, PHASE]
19
18
  RotationTypes = type[Union[RX, RY, RZ, PHASE]]
@@ -35,28 +34,6 @@ RS_FUNC_DICT = {
35
34
  }
36
35
 
37
36
 
38
- # FIXME: Remove in v1.5.0
39
- def backwards_compatibility(
40
- fm_type: BasisSet | Callable | str,
41
- reupload_scaling: ReuploadScaling | Callable | str,
42
- ) -> tuple:
43
- if fm_type in ("fourier", "chebyshev", "tower"):
44
- logger.warning(
45
- "Selecting `fm_type` as 'fourier', 'chebyshev' or 'tower' is deprecated. "
46
- "Please use the respective enumerations: 'fm_type = BasisSet.FOURIER', "
47
- "'fm_type = BasisSet.CHEBYSHEV' or 'reupload_scaling = ReuploadScaling.TOWER'."
48
- )
49
- if fm_type == "fourier":
50
- fm_type = BasisSet.FOURIER
51
- elif fm_type == "chebyshev":
52
- fm_type = BasisSet.CHEBYSHEV
53
- elif fm_type == "tower":
54
- fm_type = BasisSet.CHEBYSHEV
55
- reupload_scaling = ReuploadScaling.TOWER
56
-
57
- return fm_type, reupload_scaling
58
-
59
-
60
37
  def fm_parameter_scaling(
61
38
  fm_type: BasisSet | Callable | str,
62
39
  param: Parameter | str = "phi",
@@ -195,9 +172,6 @@ def feature_map(
195
172
  f"Please provide one from {[rot.__name__ for rot in ROTATIONS]}."
196
173
  )
197
174
 
198
- # Backwards compatibility
199
- fm_type, reupload_scaling = backwards_compatibility(fm_type, reupload_scaling)
200
-
201
175
  scaled_fparam = fm_parameter_scaling(
202
176
  fm_type, param, feature_range=feature_range, target_range=target_range
203
177
  )
@@ -225,68 +199,6 @@ def feature_map(
225
199
  return fm
226
200
 
227
201
 
228
- # FIXME: Remove in v1.5.0
229
- def fourier_feature_map(
230
- n_qubits: int, support: tuple[int, ...] = None, param: str = "phi", op: RotationTypes = RX
231
- ) -> AbstractBlock:
232
- """Construct a Fourier feature map.
233
-
234
- Args:
235
- n_qubits: number of qubits across which the FM is created
236
- param: The base name for the feature `Parameter`
237
- """
238
- warnings.warn(
239
- "Function 'fourier_feature_map' is deprecated. Please use 'feature_map' directly.",
240
- FutureWarning,
241
- )
242
- fm = feature_map(n_qubits, support=support, param=param, op=op, fm_type=BasisSet.FOURIER)
243
- return fm
244
-
245
-
246
- # FIXME: Remove in v1.5.0
247
- def chebyshev_feature_map(
248
- n_qubits: int, support: tuple[int, ...] = None, param: str = "phi", op: RotationTypes = RX
249
- ) -> AbstractBlock:
250
- """Construct a Chebyshev feature map.
251
-
252
- Args:
253
- n_qubits: number of qubits across which the FM is created
254
- support (Iterable[int]): The qubit support
255
- param: The base name for the feature `Parameter`
256
- """
257
- warnings.warn(
258
- "Function 'chebyshev_feature_map' is deprecated. Please use 'feature_map' directly.",
259
- FutureWarning,
260
- )
261
- fm = feature_map(n_qubits, support=support, param=param, op=op, fm_type=BasisSet.CHEBYSHEV)
262
- return fm
263
-
264
-
265
- # FIXME: Remove in v1.5.0
266
- def tower_feature_map(
267
- n_qubits: int, support: tuple[int, ...] = None, param: str = "phi", op: RotationTypes = RX
268
- ) -> AbstractBlock:
269
- """Construct a Chebyshev tower feature map.
270
-
271
- Args:
272
- n_qubits: number of qubits across which the FM is created
273
- param: The base name for the feature `Parameter`
274
- """
275
- warnings.warn(
276
- "Function 'tower_feature_map' is deprecated. Please use feature_map directly.",
277
- FutureWarning,
278
- )
279
- fm = feature_map(
280
- n_qubits,
281
- support=support,
282
- param=param,
283
- op=op,
284
- fm_type=BasisSet.CHEBYSHEV,
285
- reupload_scaling=ReuploadScaling.TOWER,
286
- )
287
- return fm
288
-
289
-
290
202
  def exp_fourier_feature_map(
291
203
  n_qubits: int,
292
204
  support: tuple[int, ...] = None,
qadence/constructors/hamiltonians.py CHANGED
@@ -1,17 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
3
4
  from typing import Callable, List, Type, Union
4
5
 
5
6
  import numpy as np
6
7
  from torch import Tensor, double, ones, rand
7
8
 
8
9
  from qadence.blocks import AbstractBlock, add, block_is_qubit_hamiltonian
9
- from qadence.logger import get_logger
10
10
  from qadence.operations import N, X, Y, Z
11
11
  from qadence.register import Register
12
12
  from qadence.types import Interaction, TArray
13
13
 
14
- logger = get_logger(__name__)
14
+ logger = getLogger(__name__)
15
15
 
16
16
 
17
17
  def interaction_zz(i: int, j: int) -> AbstractBlock:
@@ -207,10 +207,6 @@ def total_magnetization(n_qubits: int, z_terms: np.ndarray | list | None = None)
207
207
  return hamiltonian_factory(n_qubits, detuning=Z, detuning_strength=z_terms)
208
208
 
209
209
 
210
- def single_z(qubit: int = 0, z_coefficient: float = 1.0) -> AbstractBlock:
211
- return Z(qubit) * z_coefficient
212
-
213
-
214
210
  def zz_hamiltonian(
215
211
  n_qubits: int,
216
212
  z_terms: np.ndarray | None = None,
qadence/constructors/rydberg_feature_maps.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
3
4
  from typing import Callable
4
5
 
5
6
  import numpy as np
@@ -7,12 +8,11 @@ from sympy import Basic
7
8
 
8
9
  from qadence.blocks import AnalogBlock, KronBlock, kron
9
10
  from qadence.constructors.feature_maps import fm_parameter_func, fm_parameter_scaling
10
- from qadence.logger import get_logger
11
11
  from qadence.operations import AnalogRot, AnalogRX, AnalogRY, AnalogRZ
12
12
  from qadence.parameters import FeatureParameter, Parameter, VariationalParameter
13
13
  from qadence.types import PI, BasisSet, ReuploadScaling, TParameter
14
14
 
15
- logger = get_logger(__file__)
15
+ logger = getLogger(__name__)
16
16
 
17
17
  AnalogRotationTypes = [AnalogRX, AnalogRY, AnalogRZ]
18
18
 
qadence/decompose.py CHANGED
@@ -2,20 +2,20 @@ from __future__ import annotations
2
2
 
3
3
  import itertools
4
4
  from enum import Enum
5
+ from logging import getLogger
5
6
  from typing import Any, List, Tuple, Union
6
7
 
7
8
  import sympy
8
9
 
9
10
  from qadence.blocks import AbstractBlock
10
11
  from qadence.blocks.utils import get_pauli_blocks, unroll_block_with_scaling
11
- from qadence.logger import get_logger
12
12
  from qadence.parameters import Parameter, evaluate
13
13
 
14
14
  # from qadence.types import TNumber, TParameter
15
15
  from qadence.types import PI
16
16
  from qadence.types import LTSOrder as Order
17
17
 
18
- logger = get_logger(__name__)
18
+ logger = getLogger(__name__)
19
19
 
20
20
  # flatten a doubly-nested list
21
21
  flatten = lambda a: list(itertools.chain(*a)) # noqa: E731
qadence/engines/torch/differentiable_expectation.py CHANGED
@@ -45,6 +45,13 @@ class PSRExpectation(Function):
45
45
  expectation_values = expectation_fn(param_values=param_dict(param_keys, param_values)) # type: ignore[call-arg] # noqa: E501
46
46
  # Stack batches of expectations if so.
47
47
  if isinstance(expectation_values, list):
48
+ # Check for first element being a list in case of noisy simulations in Pulser.
49
+ if isinstance(expectation_values[0], list):
50
+ exp_vals: list = []
51
+ for expectation_value in expectation_values:
52
+ res = list(map(lambda x: x.get_final_state().data.toarray(), expectation_value))
53
+ exp_vals.append(torch.tensor(res))
54
+ expectation_values = exp_vals
48
55
  return torch.stack(expectation_values)
49
56
  else:
50
57
  return expectation_values
qadence/extensions.py CHANGED
@@ -1,16 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import importlib
4
+ from logging import getLogger
4
5
  from string import Template
5
6
 
6
7
  from qadence.backend import Backend
7
8
  from qadence.blocks.abstract import TAbstractBlock
8
- from qadence.logger import get_logger
9
9
  from qadence.types import BackendName, DiffMode, Engine
10
10
 
11
11
  backends_namespace = Template("qadence.backends.$name")
12
12
 
13
- logger = get_logger(__name__)
13
+ logger = getLogger(__name__)
14
14
 
15
15
 
16
16
  def _available_engines() -> dict:
@@ -24,7 +24,7 @@ def _available_engines() -> dict:
24
24
  res[engine] = DifferentiableBackendCls
25
25
  except (ImportError, ModuleNotFoundError):
26
26
  pass
27
- logger.info(f"Found engines: {res.keys()}")
27
+ logger.debug(f"Found engines: {res.keys()}")
28
28
  return res
29
29
 
30
30
 
@@ -39,7 +39,7 @@ def _available_backends() -> dict:
39
39
  res[backend] = BackendCls
40
40
  except (ImportError, ModuleNotFoundError):
41
41
  pass
42
- logger.info(f"Found backends: {res.keys()}")
42
+ logger.debug(f"Found backends: {res.keys()}")
43
43
  return res
44
44
 
45
45
 
@@ -77,16 +77,6 @@ def _validate_diff_mode(backend: Backend, diff_mode: DiffMode) -> None:
77
77
  raise TypeError(f"Backend {backend.name} does not support diff_mode {DiffMode.ADJOINT}.")
78
78
 
79
79
 
80
- def _validate_backend_config(backend: Backend) -> None:
81
- if backend.config.use_gradient_checkpointing:
82
- # FIXME: Remove in v1.5.0
83
- msg = "use_gradient_checkpointing is deprecated."
84
- import warnings
85
-
86
- warnings.warn(msg, UserWarning)
87
- logger.warn(msg)
88
-
89
-
90
80
  def _set_backend_config(backend: Backend, diff_mode: DiffMode) -> None:
91
81
  """Fallback function for native Qadence backends if extensions is not present.
92
82
 
@@ -96,7 +86,6 @@ def _set_backend_config(backend: Backend, diff_mode: DiffMode) -> None:
96
86
  """
97
87
 
98
88
  _validate_diff_mode(backend, diff_mode)
99
- _validate_backend_config(backend)
100
89
 
101
90
  # (1) When using PSR with any backend or (2) we use the backends Pulser or Braket,
102
91
  # we have to use gate-level parameters
qadence/log_config.yaml ADDED
@@ -0,0 +1,24 @@
1
+ version: 1
2
+ disable_existing_loggers: false
3
+ formatters:
4
+ base:
5
+ format: "%(levelname) -5s %(asctime)s - %(name)s: %(message)s"
6
+ datefmt: "%Y-%m-%d %H:%M:%S"
7
+ handlers:
8
+ console:
9
+ class: logging.StreamHandler
10
+ formatter: base
11
+ stream: ext://sys.stderr
12
+ loggers:
13
+ qadence:
14
+ level: INFO
15
+ handlers: [console]
16
+ propagate: yes
17
+ pyqtorch:
18
+ level: INFO
19
+ handlers: [console]
20
+ propagate: yes
21
+ script:
22
+ level: INFO
23
+ handlers: [console]
24
+ propagate: yes
qadence/logger.py CHANGED
@@ -1,35 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- import os
5
- import sys
6
-
7
- logging_levels = {
8
- "DEBUG": logging.DEBUG,
9
- "INFO": logging.INFO,
10
- "WARNING": logging.WARNING,
11
- "ERROR": logging.ERROR,
12
- "CRITICAL": logging.CRITICAL,
13
- }
14
-
15
- LOG_STREAM_HANDLER = sys.stdout
16
-
17
- DEFAULT_LOGGING_LEVEL = logging.INFO
18
-
19
- # FIXME: introduce a better handling of the configuration
20
- LOGGING_LEVEL = os.environ.get("LOGGING_LEVEL", "warning").upper()
4
+ from warnings import warn
21
5
 
22
6
 
23
7
  def get_logger(name: str) -> logging.Logger:
24
- logger: logging.Logger = logging.getLogger(name)
25
-
26
- level = logging_levels.get(LOGGING_LEVEL, DEFAULT_LOGGING_LEVEL)
27
- logger.setLevel(level)
8
+ warn(
9
+ '"get_logger" will be deprected soon.\
10
+ Please use "get_script_logger" instead.',
11
+ DeprecationWarning,
12
+ )
13
+ return logging.getLogger(name)
28
14
 
29
- formatter = logging.Formatter("%(levelname) -5s %(asctime)s: %(message)s", "%Y-%m-%d %H:%M:%S")
30
- # formatter = logging.Formatter(LOG_FORMAT)
31
- sh = logging.StreamHandler(LOG_STREAM_HANDLER)
32
- sh.setFormatter(formatter)
33
- logger.addHandler(sh)
34
15
 
35
- return logger
16
+ def get_script_logger(name: str = "") -> logging.Logger:
17
+ return logging.getLogger(f"script.{name}")
qadence/ml_tools/models.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
3
4
  from typing import Any, Counter, List
4
5
 
5
6
  import numpy as np
@@ -8,14 +9,13 @@ from torch import Tensor
8
9
  from torch.nn import Parameter as TorchParam
9
10
 
10
11
  from qadence.backend import ConvertedObservable
11
- from qadence.logger import get_logger
12
12
  from qadence.measurements import Measurements
13
13
  from qadence.ml_tools import promote_to_tensor
14
14
  from qadence.models import QNN, QuantumModel
15
15
  from qadence.noise import Noise
16
16
  from qadence.utils import Endianness
17
17
 
18
- logger = get_logger(__name__)
18
+ logger = getLogger(__name__)
19
19
 
20
20
 
21
21
  def _set_fixed_operation(
@@ -310,3 +310,11 @@ class TransformedModule(torch.nn.Module):
310
310
  except Exception as e:
311
311
  logger.warning(f"Unable to move {self} to {args}, {kwargs} due to {e}.")
312
312
  return self
313
+
314
+ @property
315
+ def device(self) -> torch.device:
316
+ return (
317
+ self.model.device
318
+ if isinstance(self.model, QuantumModel)
319
+ else self._input_scaling.device
320
+ )
qadence/ml_tools/saveload.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import os
4
4
  import re
5
+ from logging import getLogger
5
6
  from pathlib import Path
6
7
  from typing import Any
7
8
 
@@ -10,9 +11,7 @@ from nevergrad.optimization.base import Optimizer as NGOptimizer
10
11
  from torch.nn import Module
11
12
  from torch.optim import Optimizer
12
13
 
13
- from qadence.logger import get_logger
14
-
15
- logger = get_logger(__name__)
14
+ logger = getLogger(__name__)
16
15
 
17
16
 
18
17
  def get_latest_checkpoint_name(folder: Path, type: str) -> Path:
@@ -59,8 +58,18 @@ def write_checkpoint(
59
58
  from qadence.ml_tools.models import TransformedModule
60
59
  from qadence.models import QNN, QuantumModel
61
60
 
62
- model_checkpoint_name: str = f"model_{type(model).__name__}_ckpt_" + f"{iteration:03n}" + ".pt"
63
- opt_checkpoint_name: str = f"opt_{type(optimizer).__name__}_ckpt_" + f"{iteration:03n}" + ".pt"
61
+ device = None
62
+ try:
63
+ # We extract the device from the pyqtorch native circuit
64
+ device = str(model.device).split(":")[0] # in case of using several CUDA devices
65
+ except Exception:
66
+ pass
67
+ model_checkpoint_name: str = (
68
+ f"model_{type(model).__name__}_ckpt_" + f"{iteration:03n}" + f"_device_{device}" + ".pt"
69
+ )
70
+ opt_checkpoint_name: str = (
71
+ f"opt_{type(optimizer).__name__}_ckpt_" + f"{iteration:03n}" + f"_device_{device}" + ".pt"
72
+ )
64
73
  try:
65
74
  d = (
66
75
  model._to_dict(save_params=True)
qadence/ml_tools/train_grad.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
3
4
  from typing import Callable, Union
4
5
 
5
6
  from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
@@ -11,14 +12,13 @@ from torch.optim import Optimizer
11
12
  from torch.utils.data import DataLoader
12
13
  from torch.utils.tensorboard import SummaryWriter
13
14
 
14
- from qadence.logger import get_logger
15
15
  from qadence.ml_tools.config import TrainConfig
16
16
  from qadence.ml_tools.data import DictDataLoader
17
17
  from qadence.ml_tools.optimize_step import optimize_step
18
18
  from qadence.ml_tools.printing import print_metrics, write_tensorboard
19
19
  from qadence.ml_tools.saveload import load_checkpoint, write_checkpoint
20
20
 
21
- logger = get_logger(__name__)
21
+ logger = getLogger(__name__)
22
22
 
23
23
 
24
24
  def train(
@@ -182,7 +182,7 @@ def train(
182
182
  write_checkpoint(config.folder, model, optimizer, iteration)
183
183
 
184
184
  except KeyboardInterrupt:
185
- print("Terminating training gracefully after the current iteration.")
185
+ logger.info("Terminating training gracefully after the current iteration.")
186
186
  break
187
187
 
188
188
  # Final writing and checkpointing
qadence/ml_tools/train_no_grad.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
3
4
  from typing import Callable
4
5
 
5
6
  import nevergrad as ng
@@ -10,7 +11,6 @@ from torch.nn import Module
10
11
  from torch.utils.data import DataLoader
11
12
  from torch.utils.tensorboard import SummaryWriter
12
13
 
13
- from qadence.logger import get_logger
14
14
  from qadence.ml_tools.config import TrainConfig
15
15
  from qadence.ml_tools.data import DictDataLoader
16
16
  from qadence.ml_tools.parameters import get_parameters, set_parameters
@@ -18,7 +18,7 @@ from qadence.ml_tools.printing import print_metrics, write_tensorboard
18
18
  from qadence.ml_tools.saveload import load_checkpoint, write_checkpoint
19
19
  from qadence.ml_tools.tensors import promote_to_tensor
20
20
 
21
- logger = get_logger(__name__)
21
+ logger = getLogger(__name__)
22
22
 
23
23
 
24
24
  def train(