qadence 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. qadence/__init__.py +33 -5
  2. qadence/backend.py +2 -2
  3. qadence/backends/adjoint.py +8 -4
  4. qadence/backends/braket/backend.py +3 -2
  5. qadence/backends/braket/config.py +2 -2
  6. qadence/backends/gpsr.py +1 -1
  7. qadence/backends/horqrux/backend.py +4 -0
  8. qadence/backends/horqrux/config.py +2 -2
  9. qadence/backends/pulser/backend.py +71 -45
  10. qadence/backends/pulser/config.py +0 -28
  11. qadence/backends/pulser/pulses.py +2 -2
  12. qadence/backends/pyqtorch/backend.py +3 -2
  13. qadence/backends/pyqtorch/config.py +2 -2
  14. qadence/blocks/block_to_tensor.py +4 -4
  15. qadence/blocks/matrix.py +2 -2
  16. qadence/blocks/utils.py +2 -2
  17. qadence/circuit.py +5 -2
  18. qadence/constructors/__init__.py +1 -10
  19. qadence/constructors/ansatze.py +1 -65
  20. qadence/constructors/daqc/daqc.py +3 -2
  21. qadence/constructors/daqc/gen_parser.py +3 -2
  22. qadence/constructors/daqc/utils.py +3 -3
  23. qadence/constructors/feature_maps.py +2 -90
  24. qadence/constructors/hamiltonians.py +2 -6
  25. qadence/constructors/rydberg_feature_maps.py +2 -2
  26. qadence/decompose.py +2 -2
  27. qadence/engines/torch/differentiable_expectation.py +7 -0
  28. qadence/extensions.py +4 -15
  29. qadence/log_config.yaml +24 -0
  30. qadence/logger.py +9 -27
  31. qadence/ml_tools/models.py +10 -2
  32. qadence/ml_tools/saveload.py +14 -5
  33. qadence/ml_tools/train_grad.py +3 -3
  34. qadence/ml_tools/train_no_grad.py +2 -2
  35. qadence/models/quantum_model.py +13 -6
  36. qadence/noise/readout.py +2 -3
  37. qadence/operations/__init__.py +0 -2
  38. qadence/operations/analog.py +2 -12
  39. qadence/operations/control_ops.py +3 -2
  40. qadence/operations/ham_evo.py +2 -2
  41. qadence/operations/parametric.py +3 -2
  42. qadence/operations/primitive.py +2 -2
  43. qadence/parameters.py +2 -2
  44. qadence/serial_expr_grammar.peg +11 -0
  45. qadence/serialization.py +192 -67
  46. qadence/transpile/block.py +2 -2
  47. qadence/types.py +2 -2
  48. qadence/utils.py +10 -4
  49. {qadence-1.5.1.dist-info → qadence-1.6.0.dist-info}/METADATA +45 -36
  50. {qadence-1.5.1.dist-info → qadence-1.6.0.dist-info}/RECORD +52 -50
  51. {qadence-1.5.1.dist-info → qadence-1.6.0.dist-info}/WHEEL +1 -1
  52. {qadence-1.5.1.dist-info → qadence-1.6.0.dist-info}/licenses/LICENSE +0 -0
qadence/models/quantum_model.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import os
4
4
  from collections import Counter, OrderedDict
5
5
  from dataclasses import asdict
6
+ from logging import getLogger
6
7
  from pathlib import Path
7
8
  from typing import Any, Callable, Optional, Sequence
8
9
 
@@ -21,14 +22,13 @@ from qadence.blocks.abstract import AbstractBlock
21
22
  from qadence.blocks.utils import chain, unique_parameters
22
23
  from qadence.circuit import QuantumCircuit
23
24
  from qadence.engines.differentiable_backend import DifferentiableBackend
24
- from qadence.logger import get_logger
25
25
  from qadence.measurements import Measurements
26
26
  from qadence.mitigations import Mitigations
27
27
  from qadence.noise import Noise
28
28
  from qadence.parameters import Parameter
29
29
  from qadence.types import DiffMode, Endianness
30
30
 
31
- logger = get_logger(__name__)
31
+ logger = getLogger(__name__)
32
32
 
33
33
 
34
34
  class QuantumModel(nn.Module):
@@ -44,6 +44,7 @@ class QuantumModel(nn.Module):
44
44
  _params: nn.ParameterDict
45
45
  _circuit: ConvertedCircuit
46
46
  _observable: list[ConvertedObservable] | None
47
+ logger.debug("Initialised")
47
48
 
48
49
  def __init__(
49
50
  self,
@@ -185,8 +186,6 @@ class QuantumModel(nn.Module):
185
186
  params = self.embedding_fn(self._params, values)
186
187
  if noise is None:
187
188
  noise = self._noise
188
- else:
189
- self._noise = noise
190
189
  if mitigation is None:
191
190
  mitigation = self._mitigation
192
191
  return self.backend.sample(
@@ -316,7 +315,7 @@ class QuantumModel(nn.Module):
316
315
  try:
317
316
  torch.save(self._to_dict(save_params), folder / Path(file_name))
318
317
  except Exception as e:
319
- print(f"Unable to write QuantumModel to disk due to {e}")
318
+ logger.error(f"Unable to write QuantumModel to disk due to {e}")
320
319
 
321
320
  @classmethod
322
321
  def load(
@@ -333,7 +332,7 @@ class QuantumModel(nn.Module):
333
332
  try:
334
333
  qm_pt = torch.load(file_path, map_location=map_location)
335
334
  except Exception as e:
336
- print(f"Unable to load QuantumModel due to {e}")
335
+ logger.error(f"Unable to load QuantumModel due to {e}")
337
336
  return cls._from_dict(qm_pt, as_torch)
338
337
 
339
338
  def assign_parameters(self, values: dict[str, Tensor]) -> Any:
@@ -365,3 +364,11 @@ class QuantumModel(nn.Module):
365
364
  except Exception as e:
366
365
  logger.warning(f"Unable to move {self} to {args}, {kwargs} due to {e}.")
367
366
  return self
367
+
368
+ @property
369
+ def device(self) -> torch.device:
370
+ return (
371
+ self._circuit.native.device
372
+ if self.backend.backend.name == "pyqtorch" # type: ignore[union-attr]
373
+ else torch.device("cpu")
374
+ )
qadence/noise/readout.py CHANGED
@@ -2,14 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  from collections import Counter
4
4
  from enum import Enum
5
+ from logging import getLogger
5
6
 
6
7
  import torch
7
8
  from torch import Tensor
8
9
  from torch.distributions import normal, poisson, uniform
9
10
 
10
- from qadence.logger import get_logger
11
-
12
- logger = get_logger(__name__)
11
+ logger = getLogger(__name__)
13
12
 
14
13
 
15
14
  class WhiteNoise(Enum):
qadence/operations/__init__.py CHANGED
@@ -10,7 +10,6 @@ from .analog import (
10
10
  AnalogSWAP,
11
11
  ConstantAnalogRotation,
12
12
  entangle,
13
- wait,
14
13
  )
15
14
  from .control_ops import (
16
15
  CNOT,
@@ -89,7 +88,6 @@ __all__ = [
89
88
  "CSWAP",
90
89
  "MCPHASE",
91
90
  "Toffoli",
92
- "wait",
93
91
  "entangle",
94
92
  "AnalogEntanglement",
95
93
  "AnalogInteraction",
qadence/operations/analog.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
+ from logging import getLogger
4
5
  from typing import Any, Tuple
5
6
 
6
7
  import numpy as np
@@ -19,7 +20,6 @@ from qadence.blocks.utils import (
19
20
  add, # noqa
20
21
  kron,
21
22
  )
22
- from qadence.logger import get_logger
23
23
  from qadence.parameters import (
24
24
  Parameter,
25
25
  ParamMap,
@@ -29,7 +29,7 @@ from qadence.types import PI, OpName, TNumber, TParameter
29
29
  from .ham_evo import HamEvo
30
30
  from .primitive import I, X, Z
31
31
 
32
- logger = get_logger(__name__)
32
+ logger = getLogger(__name__)
33
33
 
34
34
 
35
35
  class AnalogSWAP(HamEvo):
@@ -84,16 +84,6 @@ def AnalogInteraction(
84
84
  return InteractionBlock(parameters=ps, qubit_support=q, add_pattern=add_pattern)
85
85
 
86
86
 
87
- # FIXME: Remove in v1.5.0
88
- def wait(
89
- duration: TNumber | sympy.Basic,
90
- qubit_support: str | QubitSupport | tuple = "global",
91
- add_pattern: bool = True,
92
- ) -> InteractionBlock:
93
- logger.warning("The alias `wait` is deprecated, please use `AnalogInteraction`")
94
- return AnalogInteraction(duration, qubit_support, add_pattern)
95
-
96
-
97
87
  # FIXME: clarify the usage of this gate, rename more formally, and implement in PyQ
98
88
  @dataclass(eq=False, repr=False)
99
89
  class AnalogEntanglement(AnalogBlock):
qadence/operations/control_ops.py CHANGED
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
4
+
3
5
  import sympy
4
6
  import torch
5
7
  from rich.console import Console, RenderableType
@@ -16,7 +18,6 @@ from qadence.blocks.utils import (
16
18
  chain,
17
19
  kron,
18
20
  )
19
- from qadence.logger import get_logger
20
21
  from qadence.parameters import (
21
22
  Parameter,
22
23
  evaluate,
@@ -26,7 +27,7 @@ from qadence.types import OpName, TNumber, TParameter
26
27
  from .parametric import PHASE, RX, RY, RZ
27
28
  from .primitive import SWAP, I, N, X, Y, Z
28
29
 
29
- logger = get_logger(__name__)
30
+ logger = getLogger(__name__)
30
31
 
31
32
 
32
33
  class CNOT(ControlBlock):
qadence/operations/ham_evo.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from copy import deepcopy
4
4
  from functools import cached_property
5
+ from logging import getLogger
5
6
  from typing import Any, Union
6
7
 
7
8
  import numpy as np
@@ -22,7 +23,6 @@ from qadence.blocks.utils import (
22
23
  expressions,
23
24
  )
24
25
  from qadence.decompose import lie_trotter_suzuki
25
- from qadence.logger import get_logger
26
26
  from qadence.parameters import (
27
27
  Parameter,
28
28
  ParamMap,
@@ -32,7 +32,7 @@ from qadence.parameters import (
32
32
  from qadence.types import LTSOrder, OpName, TGenerator, TParameter
33
33
  from qadence.utils import eigenvalues
34
34
 
35
- logger = get_logger(__name__)
35
+ logger = getLogger(__name__)
36
36
 
37
37
 
38
38
  class HamEvo(TimeEvolutionBlock):
qadence/operations/parametric.py CHANGED
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
4
+
3
5
  import numpy as np
4
6
  import sympy
5
7
  import torch
@@ -13,7 +15,6 @@ from qadence.blocks.utils import (
13
15
  add, # noqa
14
16
  chain,
15
17
  )
16
- from qadence.logger import get_logger
17
18
  from qadence.parameters import (
18
19
  Parameter,
19
20
  ParamMap,
@@ -23,7 +24,7 @@ from qadence.types import OpName, TNumber, TParameter
23
24
 
24
25
  from .primitive import I, X, Y, Z
25
26
 
26
- logger = get_logger(__name__)
27
+ logger = getLogger(__name__)
27
28
 
28
29
 
29
30
  class PHASE(ParametricBlock):
qadence/operations/primitive.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
3
4
  from typing import Union
4
5
 
5
6
  import numpy as np
@@ -17,13 +18,12 @@ from qadence.blocks.utils import (
17
18
  chain,
18
19
  kron,
19
20
  )
20
- from qadence.logger import get_logger
21
21
  from qadence.parameters import (
22
22
  Parameter,
23
23
  )
24
24
  from qadence.types import OpName, TNumber
25
25
 
26
- logger = get_logger(__name__)
26
+ logger = getLogger(__name__)
27
27
 
28
28
 
29
29
  class X(PrimitiveBlock):
qadence/parameters.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
3
4
  from typing import Any, ItemsView, KeysView, ValuesView, get_args
4
5
  from uuid import uuid4
5
6
 
@@ -12,13 +13,12 @@ from sympy.physics.quantum.dagger import Dagger
12
13
  from sympytorch import SymPyModule as torchSympyModule
13
14
  from torch import Tensor, heaviside, no_grad, rand, tensor
14
15
 
15
- from qadence.logger import get_logger
16
16
  from qadence.types import DifferentiableExpression, Engine, TNumber
17
17
 
18
18
  # Modules to be automatically added to the qadence namespace
19
19
  __all__ = ["FeatureParameter", "Parameter", "VariationalParameter"]
20
20
 
21
- logger = get_logger(__file__)
21
+ logger = getLogger(__name__)
22
22
 
23
23
  dagger_expression = Dagger
24
24
 
qadence/serial_expr_grammar.peg ADDED
@@ -0,0 +1,11 @@
1
+ Program = expr EOF
2
+ expr = id '(' args ')'
3
+ args = (value / expr) (',' (value / expr))*
4
+ value = int / float_str / str / pair
5
+ pair = param '=' (int / bool / str)
6
+ param = r"[a-z][a-zA-Z]*"
7
+ id = r"[a-zA-Z][a-zA-Z_0-9]*"
8
+ bool = 'True' / 'False'
9
+ int = r"-?(0|[1-9][0-9]*)"
10
+ float_str = r"'-?(0|\d*(\.\d*))'"
11
+ str = r"'[a-zA-Z0-9\_\.]+'"
qadence/serialization.py CHANGED
@@ -2,21 +2,26 @@ from __future__ import annotations
2
2
 
3
3
  import json
4
4
  import os
5
+ import sys
6
+ from dataclasses import dataclass
7
+ from dataclasses import field as dataclass_field
8
+ from functools import lru_cache
9
+ from logging import getLogger
5
10
  from pathlib import Path
6
- from typing import Any, get_args
11
+ from typing import Any, Callable, get_args
7
12
  from typing import Union as TypingUnion
8
13
 
9
14
  import torch
15
+ from arpeggio import NoMatch
16
+ from arpeggio.cleanpeg import ParserPEG
10
17
  from sympy import *
11
- from sympy import Basic, Expr, srepr
18
+ from sympy import core, srepr
12
19
 
13
- from qadence import QuantumCircuit, operations
20
+ from qadence import QuantumCircuit, operations, parameters
14
21
  from qadence import blocks as qadenceblocks
15
22
  from qadence.blocks import AbstractBlock
16
23
  from qadence.blocks.utils import tag
17
- from qadence.logger import get_logger
18
- from qadence.ml_tools.models import TransformedModule
19
- from qadence.models import QNN, QuantumModel
24
+ from qadence.models import QuantumModel
20
25
  from qadence.parameters import Parameter
21
26
  from qadence.register import Register
22
27
  from qadence.types import SerializationFormat
@@ -25,7 +30,7 @@ from qadence.types import SerializationFormat
25
30
  __all__ = ["deserialize", "load", "save", "serialize"]
26
31
 
27
32
 
28
- logger = get_logger(__name__)
33
+ logger = getLogger(__name__)
29
34
 
30
35
 
31
36
  def file_extension(file: Path | str) -> str:
@@ -43,27 +48,172 @@ SUPPORTED_OBJECTS = [
43
48
  AbstractBlock,
44
49
  QuantumCircuit,
45
50
  QuantumModel,
46
- QNN,
47
- TransformedModule,
48
51
  Register,
49
- Basic,
52
+ core.Basic,
50
53
  torch.nn.Module,
51
54
  ]
52
55
  SUPPORTED_TYPES = TypingUnion[
53
56
  AbstractBlock,
54
57
  QuantumCircuit,
55
58
  QuantumModel,
56
- QNN,
57
- TransformedModule,
58
59
  Register,
59
- Basic,
60
+ core.Basic,
60
61
  torch.nn.Module,
61
62
  ]
62
63
 
63
-
64
64
  ALL_BLOCK_NAMES = [
65
65
  n for n in dir(qadenceblocks) if not (n.startswith("__") and n.endswith("__"))
66
66
  ] + [n for n in dir(operations) if not (n.startswith("__") and n.endswith("__"))]
67
+ SYMPY_EXPRS = [n for n in dir(core) if not (n.startswith("__") and n.endswith("__"))]
68
+ QADENCE_PARAMETERS = [n for n in dir(parameters) if not (n.startswith("__") and n.endswith("__"))]
69
+
70
+
71
+ THIS_PATH = Path(__file__).parent
72
+ GRAMMAR_FILE = THIS_PATH / "serial_expr_grammar.peg"
73
+
74
+
75
+ @lru_cache
76
+ def _parser_fn() -> ParserPEG:
77
+ with open(GRAMMAR_FILE, "r") as f:
78
+ grammar = f.read()
79
+ return ParserPEG(grammar, "Program")
80
+
81
+
82
+ _parsing_serialize_expr = _parser_fn()
83
+
84
+
85
+ def parse_expr_fn(code: str) -> bool:
86
+ """
87
+ A parsing expressions function that checks whether a given code is valid on.
88
+
89
+ the parsing grammar. The grammar is defined to be compatible with `sympy`
90
+ expressions, such as `Float('-0.33261030434342942', precision=53)`, while
91
+ avoiding code injection such as `2*3` or `__import__('os').system('ls -la')`.
92
+
93
+ Args:
94
+ code (str): code to be parsed and checked.
95
+
96
+ Returns:
97
+ Boolean indicating whether the code matches the defined grammar or not.
98
+ """
99
+
100
+ parser = _parsing_serialize_expr
101
+ try:
102
+ parser.parse(code)
103
+ except NoMatch:
104
+ return False
105
+ else:
106
+ return True
107
+
108
+
109
+ @dataclass
110
+ class SerializationModel:
111
+ """
112
+ A serialization model class to serialize data from `QuantumModel`s,.
113
+
114
+ `torch.nn.Module` and similar structures. The data included in the
115
+ serialization logic includes: the `AbstractBlock` and its children
116
+ classes, `QuantumCircuit`, `Register`, and `sympy` expressions
117
+ (including `Parameter` class from `qadence.parameters`).
118
+
119
+ A children class must define the `value` attribute type and how to
120
+ handle it, since it is the main property for the class to be used
121
+ by the serialization process. For instance:
122
+
123
+ ```python
124
+ @dataclass
125
+ class QuantumCircuitSerialization(SerializationModel):
126
+ value: QuantumCircuit = dataclass_field(init=False)
127
+
128
+ def __post_init__(self) -> None:
129
+ self.value = (
130
+ QuantumCircuit._from_dict(self.d)
131
+ if isinstance(self.d, dict)
132
+ else self.d
133
+ )
134
+ ```
135
+ """
136
+
137
+ d: dict = dataclass_field(default_factory=dict)
138
+ value: Any = dataclass_field(init=False)
139
+
140
+
141
+ @dataclass
142
+ class BlockTypeSerialization(SerializationModel):
143
+ value: AbstractBlock = dataclass_field(init=False)
144
+
145
+ def __post_init__(self) -> None:
146
+ block = (
147
+ getattr(operations, self.d["type"])
148
+ if hasattr(operations, self.d["type"])
149
+ else getattr(qadenceblocks, self.d["type"])
150
+ )._from_dict(self.d)
151
+ if self.d["tag"] is not None:
152
+ block = tag(block, self.d["tag"])
153
+ self.value = block
154
+
155
+
156
+ @dataclass
157
+ class QuantumCircuitSerialization(SerializationModel):
158
+ value: QuantumCircuit = dataclass_field(init=False)
159
+
160
+ def __post_init__(self) -> None:
161
+ self.value = QuantumCircuit._from_dict(self.d) if isinstance(self.d, dict) else self.d
162
+
163
+
164
+ @dataclass
165
+ class RegisterSerialization(SerializationModel):
166
+ value: Register = dataclass_field(init=False)
167
+
168
+ def __post_init__(self) -> None:
169
+ self.value = Register._from_dict(self.d)
170
+
171
+
172
+ @dataclass
173
+ class ModelSerialization(SerializationModel):
174
+ as_torch: bool = False
175
+ value: torch.nn.Module = dataclass_field(init=False)
176
+
177
+ def __post_init__(self) -> None:
178
+ module_name = list(self.d.keys())[0]
179
+ obj = globals().get(module_name, None)
180
+ if obj is None:
181
+ obj = self._resolve_module(module_name)
182
+ if hasattr(obj, "_from_dict"):
183
+ self.value = obj._from_dict(self.d, self.as_torch)
184
+ elif hasattr(obj, "load_state_dict"):
185
+ self.value = obj.load_state_dict(self.d[module_name])
186
+ else:
187
+ msg = (
188
+ f"Unable to deserialize object '{module_name}'. "
189
+ f"Supported types are {SUPPORTED_OBJECTS}."
190
+ )
191
+ logger.error(TypeError(msg))
192
+ raise TypeError(msg)
193
+
194
+ @staticmethod
195
+ def _resolve_module(module: str) -> Any:
196
+ for loaded_module in sys.modules.keys():
197
+ if "qadence" in loaded_module:
198
+ obj = getattr(sys.modules[loaded_module], module, None)
199
+ if obj:
200
+ return obj
201
+ raise ValueError(f"Couldn't resolve module '{module}'.")
202
+
203
+
204
+ @dataclass
205
+ class ExpressionSerialization(SerializationModel):
206
+ value: str | core.Expr | float = dataclass_field(init=False)
207
+
208
+ def __post_init__(self) -> None:
209
+ if parse_expr_fn(self.d["expression"]):
210
+ expr = eval(self.d["expression"])
211
+ if hasattr(expr, "free_symbols"):
212
+ for s in expr.free_symbols:
213
+ s.value = float(self.d["symbols"][s.name]["value"])
214
+ self.value = expr
215
+ else:
216
+ raise ValueError(f"Invalid expression: {self.d['expression']}")
67
217
 
68
218
 
69
219
  def save_pt(d: dict, file_path: str | Path) -> None:
@@ -94,11 +244,11 @@ def serialize(obj: SUPPORTED_TYPES, save_params: bool = False) -> dict:
94
244
  """
95
245
  Supported Types:
96
246
 
97
- AbstractBlock | QuantumCircuit | QuantumModel | TransformedModule | Register | Module
247
+ AbstractBlock | QuantumCircuit | QuantumModel | torch.nn.Module | Register | Module
98
248
  Serializes a qadence object to a dictionary.
99
249
 
100
250
  Arguments:
101
- obj (AbstractBlock | QuantumCircuit | QuantumModel | Register | Module):
251
+ obj (AbstractBlock | QuantumCircuit | QuantumModel | Register | torch.nn.Module):
102
252
  Returns:
103
253
  A dict.
104
254
 
@@ -132,21 +282,28 @@ def serialize(obj: SUPPORTED_TYPES, save_params: bool = False) -> dict:
132
282
  """
133
283
  if not isinstance(obj, get_args(SUPPORTED_TYPES)):
134
284
  logger.error(TypeError(f"Serialization of object type {type(obj)} not supported."))
135
- d: dict = {}
285
+
286
+ d: dict = dict()
136
287
  try:
137
- if isinstance(obj, Expr):
138
- symb_dict = {}
288
+ if isinstance(obj, core.Expr):
289
+ symb_dict = dict()
139
290
  expr_dict = {"name": str(obj), "expression": srepr(obj)}
140
- symbs: set[Parameter | Basic] = obj.free_symbols
291
+ symbs: set[Parameter | core.Basic] = obj.free_symbols
141
292
  if symbs:
142
293
  symb_dict = {"symbols": {str(s): s._to_dict() for s in symbs}}
143
294
  d = {**expr_dict, **symb_dict}
144
- elif isinstance(obj, (QuantumModel, QNN, TransformedModule)):
145
- d = obj._to_dict(save_params)
146
- elif isinstance(obj, torch.nn.Module):
147
- d = {type(obj).__name__: obj.state_dict()}
148
295
  else:
149
- d = obj._to_dict()
296
+ if hasattr(obj, "_to_dict"):
297
+ model_to_dict: Callable = obj._to_dict
298
+ d = (
299
+ model_to_dict(save_params)
300
+ if isinstance(obj, torch.nn.Module)
301
+ else model_to_dict()
302
+ )
303
+ elif hasattr(obj, "state_dict"):
304
+ d = {type(obj).__name__: obj.state_dict()}
305
+ else:
306
+ raise ValueError(f"Cannot serialize object {obj}.")
150
307
  except Exception as e:
151
308
  logger.error(f"Serialization of object {obj} failed due to {e}")
152
309
  return d
@@ -156,13 +313,14 @@ def deserialize(d: dict, as_torch: bool = False) -> SUPPORTED_TYPES:
156
313
  """
157
314
  Supported Types:
158
315
 
159
- AbstractBlock | QuantumCircuit | QuantumModel | TransformedModule | Register | Module
316
+ AbstractBlock | QuantumCircuit | QuantumModel | Register | torch.nn.Module
160
317
  Deserializes a dict to one of the supported types.
161
318
 
162
319
  Arguments:
163
320
  d (dict): A dict containing a serialized object.
321
+ as_torch (bool): Whether to transform to torch for the deserialized object.
164
322
  Returns:
165
- AbstractBlock, QuantumCircuit, QuantumModel, TransformedModule, Register, Module.
323
+ AbstractBlock, QuantumCircuit, QuantumModel, Register, torch.nn.Module.
166
324
 
167
325
  Examples:
168
326
  ```python exec="on" source="material-block" result="json"
@@ -192,51 +350,18 @@ def deserialize(d: dict, as_torch: bool = False) -> SUPPORTED_TYPES:
192
350
  assert torch.isclose(qm.expectation({}), qm_deserialized.expectation({}))
193
351
  ```
194
352
  """
195
- obj: Any
353
+ obj: SerializationModel
196
354
  if d.get("expression"):
197
- expr = eval(d["expression"])
198
- if hasattr(expr, "free_symbols"):
199
- for symb in expr.free_symbols:
200
- symb.value = float(d["symbols"][symb.name]["value"])
201
- obj = expr
202
- elif d.get("QuantumModel"):
203
- obj = QuantumModel._from_dict(d, as_torch)
204
- elif d.get("QNN"):
205
- obj = QNN._from_dict(d, as_torch)
206
- elif d.get("TransformedModule"):
207
- obj = TransformedModule._from_dict(d, as_torch)
355
+ obj = ExpressionSerialization(d)
208
356
  elif d.get("block") and d.get("register"):
209
- obj = QuantumCircuit._from_dict(d)
357
+ obj = QuantumCircuitSerialization(d)
210
358
  elif d.get("graph"):
211
- obj = Register._from_dict(d)
359
+ obj = RegisterSerialization(d)
212
360
  elif d.get("type"):
213
- if d["type"] in ALL_BLOCK_NAMES:
214
- block: AbstractBlock = (
215
- getattr(operations, d["type"])._from_dict(d)
216
- if hasattr(operations, d["type"])
217
- else getattr(qadenceblocks, d["type"])._from_dict(d)
218
- )
219
- if d["tag"] is not None:
220
- block = tag(block, d["tag"])
221
- obj = block
361
+ obj = BlockTypeSerialization(d)
222
362
  else:
223
- import warnings
224
-
225
- msg = warnings.warn(
226
- "In order to load a custom torch.nn.Module, make sure its imported in the namespace."
227
- )
228
- try:
229
- module_name = list(d.keys())[0]
230
- obj = getattr(globals(), module_name)
231
- obj.load_state_dict(d[module_name])
232
- except Exception as e:
233
- logger.error(
234
- TypeError(
235
- f"{msg}. Unable to deserialize object due to {e}.\
236
- Supported objects are: {SUPPORTED_OBJECTS}"
237
- )
238
- )
239
- return obj
363
+ obj = ModelSerialization(d, as_torch=as_torch)
364
+ return obj.value
240
365
 
241
366
 
242
367
  def save(
qadence/transpile/block.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from copy import deepcopy
4
4
  from functools import singledispatch
5
+ from logging import getLogger
5
6
  from typing import Callable, Iterable, Type
6
7
 
7
8
  import sympy
@@ -26,11 +27,10 @@ from qadence.blocks.utils import (
26
27
  _construct,
27
28
  parameters,
28
29
  )
29
- from qadence.logger import get_logger
30
30
  from qadence.operations import SWAP, I
31
31
  from qadence.parameters import Parameter
32
32
 
33
- logger = get_logger(__name__)
33
+ logger = getLogger(__name__)
34
34
 
35
35
 
36
36
  def repeat(
qadence/types.py CHANGED
@@ -9,8 +9,8 @@ import sympy
9
9
  from numpy.typing import ArrayLike
10
10
  from torch import Tensor, pi
11
11
 
12
- TNumber = Union[int, float, complex]
13
- """Union of python number types."""
12
+ TNumber = Union[int, float, complex, np.int64, np.float64]
13
+ """Union of python and numpy numeric types."""
14
14
 
15
15
  TDrawColor = Tuple[float, float, float, float]
16
16