qilisdk 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. qilisdk/__init__.py +11 -2
  2. qilisdk/__init__.pyi +2 -3
  3. qilisdk/_logging.py +135 -0
  4. qilisdk/_optionals.py +5 -7
  5. qilisdk/analog/__init__.py +3 -18
  6. qilisdk/analog/exceptions.py +2 -4
  7. qilisdk/analog/hamiltonian.py +455 -110
  8. qilisdk/analog/linear_schedule.py +121 -0
  9. qilisdk/analog/schedule.py +275 -79
  10. qilisdk/{extras → backends}/__init__.py +9 -4
  11. qilisdk/{common/model.py → backends/__init__.pyi} +3 -1
  12. qilisdk/backends/backend.py +117 -0
  13. qilisdk/{extras/cuda → backends}/cuda_backend.py +152 -159
  14. qilisdk/backends/qutip_backend.py +473 -0
  15. qilisdk/core/__init__.py +63 -0
  16. qilisdk/{common → core}/algorithm.py +2 -1
  17. qilisdk/{extras/qaas/qaas_settings.py → core/exceptions.py} +12 -6
  18. qilisdk/core/model.py +1034 -0
  19. qilisdk/core/parameterizable.py +75 -0
  20. qilisdk/core/qtensor.py +666 -0
  21. qilisdk/{common → core}/result.py +2 -1
  22. qilisdk/core/variables.py +1969 -0
  23. qilisdk/cost_functions/__init__.py +18 -0
  24. qilisdk/cost_functions/cost_function.py +77 -0
  25. qilisdk/cost_functions/model_cost_function.py +145 -0
  26. qilisdk/cost_functions/observable_cost_function.py +109 -0
  27. qilisdk/digital/__init__.py +3 -22
  28. qilisdk/digital/ansatz.py +200 -160
  29. qilisdk/digital/circuit.py +81 -9
  30. qilisdk/digital/exceptions.py +12 -6
  31. qilisdk/digital/gates.py +229 -86
  32. qilisdk/{extras/qaas/qaas_analog_result.py → functionals/__init__.py} +14 -5
  33. qilisdk/functionals/functional.py +39 -0
  34. qilisdk/{common/backend.py → functionals/functional_result.py} +3 -1
  35. qilisdk/functionals/sampling.py +81 -0
  36. qilisdk/functionals/sampling_result.py +92 -0
  37. qilisdk/functionals/time_evolution.py +98 -0
  38. qilisdk/functionals/time_evolution_result.py +84 -0
  39. qilisdk/functionals/variational_program.py +80 -0
  40. qilisdk/functionals/variational_program_result.py +69 -0
  41. qilisdk/logging_config.yaml +16 -0
  42. qilisdk/{common → optimizers}/__init__.py +1 -1
  43. qilisdk/optimizers/optimizer.py +39 -0
  44. qilisdk/{common → optimizers}/optimizer_result.py +3 -12
  45. qilisdk/{common/optimizer.py → optimizers/scipy_optimizer.py} +10 -28
  46. qilisdk/settings.py +78 -0
  47. qilisdk/speqtrum/__init__.py +41 -0
  48. qilisdk/{extras → speqtrum}/__init__.pyi +3 -3
  49. qilisdk/speqtrum/experiments/__init__.py +25 -0
  50. qilisdk/speqtrum/experiments/experiment_functional.py +124 -0
  51. qilisdk/speqtrum/experiments/experiment_result.py +231 -0
  52. qilisdk/{extras/qaas → speqtrum}/keyring.py +8 -4
  53. qilisdk/speqtrum/speqtrum.py +587 -0
  54. qilisdk/speqtrum/speqtrum_models.py +467 -0
  55. qilisdk/utils/__init__.py +0 -14
  56. qilisdk/utils/openqasm2.py +1 -1
  57. qilisdk/utils/serialization.py +1 -1
  58. qilisdk/utils/visualization/PlusJakartaSans-SemiBold.ttf +0 -0
  59. qilisdk/utils/visualization/__init__.py +24 -0
  60. qilisdk/utils/visualization/circuit_renderers.py +781 -0
  61. qilisdk/utils/visualization/schedule_renderers.py +166 -0
  62. qilisdk/utils/visualization/style.py +154 -0
  63. qilisdk/utils/visualization/themes.py +76 -0
  64. qilisdk/yaml.py +126 -0
  65. {qilisdk-0.1.4.dist-info → qilisdk-0.1.6.dist-info}/METADATA +186 -140
  66. qilisdk-0.1.6.dist-info/RECORD +69 -0
  67. qilisdk/analog/algorithms.py +0 -111
  68. qilisdk/analog/analog_backend.py +0 -43
  69. qilisdk/analog/analog_result.py +0 -114
  70. qilisdk/analog/quantum_objects.py +0 -596
  71. qilisdk/digital/digital_algorithm.py +0 -20
  72. qilisdk/digital/digital_backend.py +0 -90
  73. qilisdk/digital/digital_result.py +0 -145
  74. qilisdk/digital/vqe.py +0 -166
  75. qilisdk/extras/cuda/__init__.py +0 -13
  76. qilisdk/extras/cuda/cuda_analog_result.py +0 -19
  77. qilisdk/extras/cuda/cuda_digital_result.py +0 -19
  78. qilisdk/extras/qaas/__init__.py +0 -13
  79. qilisdk/extras/qaas/models.py +0 -132
  80. qilisdk/extras/qaas/qaas_backend.py +0 -255
  81. qilisdk/extras/qaas/qaas_digital_result.py +0 -20
  82. qilisdk/extras/qaas/qaas_time_evolution_result.py +0 -20
  83. qilisdk/extras/qaas/qaas_vqe_result.py +0 -20
  84. qilisdk-0.1.4.dist-info/RECORD +0 -51
  85. {qilisdk-0.1.4.dist-info → qilisdk-0.1.6.dist-info}/WHEEL +0 -0
  86. {qilisdk-0.1.4.dist-info → qilisdk-0.1.6.dist-info}/licenses/LICENCE +0 -0
@@ -0,0 +1,1969 @@
1
+ # Copyright 2025 Qilimanjaro Quantum Tech
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import copy
18
+ import re
19
+ from abc import ABC, abstractmethod
20
+ from enum import Enum
21
+ from typing import TYPE_CHECKING, Iterator, Mapping, Sequence, TypeVar
22
+
23
+ import numpy as np
24
+ from loguru import logger
25
+
26
+ from qilisdk.core.exceptions import EvaluationError, InvalidBoundsError, NotSupportedOperation, OutOfBoundsException
27
+ from qilisdk.yaml import yaml
28
+
29
+ if TYPE_CHECKING:
30
+ from ruamel.yaml.nodes import ScalarNode
31
+ from ruamel.yaml.representer import RoundTripRepresenter
32
+
33
+ Number = int | float | complex
34
+ RealNumber = int | float
35
+ GenericVar = TypeVar("GenericVar", bound="Variable")
36
+ CONST_KEY = "_const_"
37
+ MAX_INT = np.iinfo(np.int64).max
38
+ MIN_INT = np.iinfo(np.int64).min
39
+ LARGE_BOUND = 100
40
+
41
+
42
def LT(lhs: RealNumber | BaseVariable | Term, rhs: RealNumber | BaseVariable | Term) -> ComparisonTerm:
    """Build the comparison ``lhs < rhs``.

    Args:
        lhs (RealNumber | BaseVariable | Term): left-hand side of the comparison.
        rhs (RealNumber | BaseVariable | Term): right-hand side of the comparison.

    Returns:
        ComparisonTerm: a comparison term representing ``lhs < rhs``.
    """
    operation = ComparisonOperation.LT
    return ComparisonTerm(operation=operation, lhs=lhs, rhs=rhs)


# Long-form alias of LT.
LessThan = LT
56
+
57
+
58
def LEQ(lhs: RealNumber | BaseVariable | Term, rhs: RealNumber | BaseVariable | Term) -> ComparisonTerm:
    """Build the comparison ``lhs <= rhs``.

    Args:
        lhs (RealNumber | BaseVariable | Term): left-hand side of the comparison.
        rhs (RealNumber | BaseVariable | Term): right-hand side of the comparison.

    Returns:
        ComparisonTerm: a comparison term representing ``lhs <= rhs``.
    """
    operation = ComparisonOperation.LEQ
    return ComparisonTerm(operation=operation, lhs=lhs, rhs=rhs)


# Long-form alias of LEQ.
LessThanOrEqual = LEQ
72
+
73
+
74
def EQ(lhs: RealNumber | BaseVariable | Term, rhs: RealNumber | BaseVariable | Term) -> ComparisonTerm:
    """Build the comparison ``lhs == rhs``.

    Args:
        lhs (RealNumber | BaseVariable | Term): left-hand side of the comparison.
        rhs (RealNumber | BaseVariable | Term): right-hand side of the comparison.

    Returns:
        ComparisonTerm: a comparison term representing ``lhs == rhs``.
    """
    operation = ComparisonOperation.EQ
    return ComparisonTerm(operation=operation, lhs=lhs, rhs=rhs)


# Long-form alias of EQ.
Equal = EQ
88
+
89
+
90
def NEQ(lhs: RealNumber | BaseVariable | Term, rhs: RealNumber | BaseVariable | Term) -> ComparisonTerm:
    """Build the comparison ``lhs != rhs``.

    Args:
        lhs (RealNumber | BaseVariable | Term): left-hand side of the comparison.
        rhs (RealNumber | BaseVariable | Term): right-hand side of the comparison.

    Returns:
        ComparisonTerm: a comparison term representing ``lhs != rhs``.
    """
    operation = ComparisonOperation.NEQ
    return ComparisonTerm(operation=operation, lhs=lhs, rhs=rhs)


# Long-form alias of NEQ.
NotEqual = NEQ
104
+
105
+
106
def GT(lhs: RealNumber | BaseVariable | Term, rhs: RealNumber | BaseVariable | Term) -> ComparisonTerm:
    """Build the comparison ``lhs > rhs``.

    Args:
        lhs (RealNumber | BaseVariable | Term): left-hand side of the comparison.
        rhs (RealNumber | BaseVariable | Term): right-hand side of the comparison.

    Returns:
        ComparisonTerm: a comparison term representing ``lhs > rhs``.
    """
    operation = ComparisonOperation.GT
    return ComparisonTerm(operation=operation, lhs=lhs, rhs=rhs)


# Long-form alias of GT.
GreaterThan = GT
120
+
121
+
122
def GEQ(lhs: RealNumber | BaseVariable | Term, rhs: RealNumber | BaseVariable | Term) -> ComparisonTerm:
    """Build the comparison ``lhs >= rhs``.

    Args:
        lhs (RealNumber | BaseVariable | Term): left-hand side of the comparison.
        rhs (RealNumber | BaseVariable | Term): right-hand side of the comparison.

    Returns:
        ComparisonTerm: a comparison term representing ``lhs >= rhs``.
    """
    operation = ComparisonOperation.GEQ
    return ComparisonTerm(operation=operation, lhs=lhs, rhs=rhs)


# Long-form alias of GEQ.
GreaterThanOrEqual = GEQ
136
+
137
+
138
def _extract_number(label: str) -> int:
    """Return the index written in parentheses at the very end of ``label``.

    The encodings in this module name their binary variables ``<label>(<index>)``;
    this recovers ``<index>`` so the variables can be ordered by it.

    Args:
        label (str): a variable label, expected to end with ``(<digits>)``.

    Returns:
        int: the trailing index, or 0 when the label does not end with ``(<digits>)``.
    """
    found = re.search(r"\((\d+)\)$", label)
    if found is None:
        return 0
    return int(found.group(1))
152
+
153
+
154
def _float_if_real(value: Number) -> Number:
    """Collapse a complex number with a zero imaginary part to its real part.

    Args:
        value (Number): the number to normalise.

    Returns:
        Number: ``value.real`` when ``value`` is a ``complex`` whose imaginary part is zero;
        otherwise ``value`` unchanged (ints and floats pass straight through).
    """
    is_real_complex = not isinstance(value, RealNumber) and isinstance(value, complex) and value.imag == 0
    return value.real if is_real_complex else value
160
+
161
+
162
def _assert_real(value: Number) -> RealNumber:
    """Return ``value`` as a real number, rejecting genuinely complex input.

    Args:
        value (Number): the number to check; complex values with zero imaginary part are accepted.

    Raises:
        ValueError: if ``value`` has a non-zero imaginary part (or is not numeric).

    Returns:
        RealNumber: the real-valued form of ``value``.
    """
    normalised = _float_if_real(value)
    if not isinstance(normalised, RealNumber):
        raise ValueError(f"Only Real values are allowed but {normalised} was provided.")
    return normalised
167
+
168
+
169
+ @yaml.register_class
170
+ class Domain(str, Enum):
171
+ INTEGER = "Integer Domain"
172
+ POSITIVE_INTEGER = "Positive Integer Domain"
173
+ REAL = "Real Domain"
174
+ BINARY = "Binary Domain"
175
+ SPIN = "Spin Domain"
176
+
177
+ def check_value(self, value: Number) -> bool:
178
+ """checks if the provided value is valid for a given domain
179
+
180
+ Args:
181
+ value (int | float): the value to be evaluated.
182
+
183
+ Returns:
184
+ bool: True if the value provided is valid, False otherwise.
185
+ """
186
+ if self == Domain.BINARY:
187
+ return isinstance(value, Number) and value in {0, 1}
188
+ if self == Domain.SPIN:
189
+ return isinstance(value, Number) and value in {-1, 1}
190
+ if self == Domain.REAL:
191
+ return isinstance(value, (int, float))
192
+ if self == Domain.INTEGER:
193
+ return isinstance(value, int)
194
+ if self == Domain.POSITIVE_INTEGER:
195
+ return isinstance(value, int) and value >= 0
196
+ return False
197
+
198
+ def min(self) -> float:
199
+ """
200
+ Returns:
201
+ float: the minimum value allowed of a given domain.
202
+ """
203
+ if self in {Domain.BINARY, Domain.POSITIVE_INTEGER}:
204
+ return 0
205
+ if self == Domain.SPIN:
206
+ return -1
207
+ if self == Domain.INTEGER:
208
+ return MIN_INT
209
+ return -1e30
210
+
211
+ def max(self) -> float:
212
+ """
213
+ Returns:
214
+ float: the maximum value allowed for a given domain.
215
+ """
216
+ if self in {Domain.BINARY, Domain.SPIN}:
217
+ return 1
218
+ if self in {Domain.POSITIVE_INTEGER, Domain.INTEGER}:
219
+ return MAX_INT
220
+ return 1e30
221
+
222
+ @classmethod
223
+ def to_yaml(cls, representer: RoundTripRepresenter, node: Domain) -> ScalarNode:
224
+ """
225
+ Method to be called automatically during YAML serialization.
226
+
227
+ Returns:
228
+ ScalarNode: The YAML scalar node representing the Domain.
229
+ """
230
+ return representer.represent_scalar("!Domain", f"{node.value}")
231
+
232
+ @classmethod
233
+ def from_yaml(cls, _, node: ScalarNode) -> Domain:
234
+ """
235
+ Method to be called automatically during YAML deserialization.
236
+
237
+ Returns:
238
+ Domain: The Domain instance created from the YAML node value.
239
+ """
240
+ return cls(node.value)
241
+
242
+
243
@yaml.register_class
class Operation(str, Enum):
    """Arithmetic operation used to combine terms; each value is the operator symbol."""

    MUL = "*"
    ADD = "+"
    DIV = "/"
    SUB = "-"

    @classmethod
    def to_yaml(cls, representer: RoundTripRepresenter, node: Operation) -> ScalarNode:
        """Serialise ``node`` as an ``!Operation``-tagged YAML scalar holding its symbol.

        Args:
            representer (RoundTripRepresenter): the ruamel.yaml representer in use.
            node (Operation): the member to serialise.

        Returns:
            ScalarNode: the scalar node produced by ``representer``.
        """
        symbol = str(node.value)
        return representer.represent_scalar("!Operation", symbol)

    @classmethod
    def from_yaml(cls, _, node: ScalarNode) -> Operation:
        """Rebuild the member whose symbol is stored in the YAML scalar ``node``.

        Returns:
            Operation: the member matching ``node.value``.
        """
        symbol = node.value
        return cls(symbol)
269
+
270
+
271
@yaml.register_class
class ComparisonOperation(str, Enum):
    """Relational operator of a comparison term; each value is the operator symbol."""

    LT = "<"
    LEQ = "<="
    EQ = "=="
    NEQ = "!="
    GT = ">"
    GEQ = ">="

    @classmethod
    def to_yaml(cls, representer: RoundTripRepresenter, node: ComparisonOperation) -> ScalarNode:
        """Serialise ``node`` as a ``!ComparisonOperation``-tagged YAML scalar holding its symbol.

        Args:
            representer (RoundTripRepresenter): the ruamel.yaml representer in use.
            node (ComparisonOperation): the member to serialise.

        Returns:
            ScalarNode: the scalar node produced by ``representer``.
        """
        symbol = str(node.value)
        return representer.represent_scalar("!ComparisonOperation", symbol)

    @classmethod
    def from_yaml(cls, _, node: ScalarNode) -> ComparisonOperation:
        """Rebuild the member whose symbol is stored in the YAML scalar ``node``.

        Returns:
            ComparisonOperation: the member matching ``node.value``.
        """
        symbol = node.value
        return cls(symbol)
299
+
300
+
301
@yaml.register_class
class Encoding(ABC):
    """Abstract interface for encoding a variable into binary variables.

    An encoding defines four things for a non-binary (integer or real) variable:

    * how the variable is represented by a term built only from binary variables;
    * how many binary variables that takes;
    * how to recover the variable's value from a binary sample;
    * which constraint a binary sample must satisfy to be a valid encoding.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Encoding's name.

        Returns:
            str: The name of the encoding.
        """

    @staticmethod
    @abstractmethod
    def encode(var: Variable, precision: float = 1e-2) -> Term:
        """Return a term made only of binary variables that represents ``var``.

        Args:
            var (Variable): The variable to be encoded.
            precision (float): The discretisation step for real variables (only applies if
                the variable domain is ``Domain.REAL``).

        Returns:
            Term: a term that only contains binary variables.
        """

    @staticmethod
    @abstractmethod
    def encoding_constraint(var: Variable, precision: float = 1e-2) -> ComparisonTerm:
        """Return a constraint that every valid binary sample of ``var`` must satisfy.

        Args:
            var (Variable): The variable to be encoded.
            precision (float): The discretisation step for real variables (only applies if
                the variable domain is ``Domain.REAL``).

        Returns:
            ComparisonTerm: a constraint term that ensures the encoding is respected.
        """

    @staticmethod
    @abstractmethod
    def evaluate(var: Variable, value: list[int] | int, precision: float = 1e-2) -> float:
        """Compute the value of ``var`` represented by a binary sample.

        Args:
            var (Variable): The variable to be evaluated.
            value (list[int] | int): a list of binary values, or an integer that the concrete
                encoding first converts into such a list.
            precision (float): The discretisation step for real variables (only applies if
                the variable domain is ``Domain.REAL``).

        Returns:
            float: the value of the variable given the specified binary values.
        """

    @staticmethod
    @abstractmethod
    def num_binary_equivalent(var: Variable, precision: float = 1e-2) -> int:
        """Return the number of binary variables needed to encode ``var``.

        Args:
            var (Variable): The variable to be encoded.
            precision (float): The discretisation step for real variables (only applies if
                the variable domain is ``Domain.REAL``).

        Returns:
            int: the number of binary variables needed to encode it.
        """

    @staticmethod
    @abstractmethod
    def check_valid(value: list[int] | int) -> tuple[bool, int]:
        """Check whether a binary sample is a valid sample in this encoding.

        Args:
            value (list[int] | int): a list of binary values or an integer value.

        Returns:
            tuple[bool, int]: the boolean is True if the sample is a valid encoding. The int
            is a penalty measuring the violation; the encodings in this module return 0 for
            a valid sample.
        """
388
+
389
+
390
@yaml.register_class
class Bitwise(Encoding):
    """Bitwise (binary-expansion) variable encoding.

    Let ``[lb, ub]`` be the bounds, scaled as described below, and ``d = |ub - lb|`` the span.
    The variable is written as

        ``lb + sum_{i<n} 2**i * b_i + (d + 1 - 2**n) * b_n``,   with ``n = floor(log2(d))``.

    For an integer span ``d >= 1``, the reachable offsets are exactly ``0..d``, using
    ``n + 1`` binary variables. For ``Domain.REAL`` variables, the bounds are first divided
    by ``precision`` and the resulting term is multiplied back by the same ``precision``.
    """

    name = "Bitwise"

    @staticmethod
    def _bitwise_encode(x: int, N: int) -> list[int]:
        """Encode the integer ``x`` as a list of bits, least significant bit first.

        Args:
            x (int): the integer to be encoded.
            N (int): the minimum number of bits; more are emitted if ``x`` needs them.

        Returns:
            list[int]: a binary list representing the Bitwise encoding of the integer x.
        """
        return list(reversed([int(b) for b in format(x, f"0{N}b")]))

    @staticmethod
    def _scaled_bounds(var: Variable, precision: float) -> tuple[float, float]:
        """Return ``var``'s bounds, divided by ``precision`` for real-valued variables."""
        lower, upper = var.bounds
        if var.domain is Domain.REAL:
            return lower / precision, upper / precision
        return lower, upper

    @staticmethod
    def _max_exponent(span: float) -> int:
        """Return the exponent ``n`` such that a span needs ``n + 1`` binary variables.

        A span below 1 (including 0, i.e. a fixed variable) needs a single binary variable.
        For such spans, 0 is returned rather than ``log2(0)`` (an OverflowError when cast to
        int) or a negative exponent (which would leave no binary variable at all).
        """
        if span < 1:
            return 0
        return int(np.floor(np.log2(span)))

    @staticmethod
    def encode(var: Variable, precision: float = 1e-2) -> Term:
        """Return the bitwise-encoded term representing ``var``.

        Args:
            var (Variable): the variable to encode.
            precision (float): discretisation step, only used for ``Domain.REAL`` variables.

        Returns:
            Term: a term that only contains binary variables labelled ``<var.label>(<i>)``.
        """
        lower, upper = Bitwise._scaled_bounds(var, precision)
        span = np.abs(upper - lower)
        n_binary = Bitwise._max_exponent(span)
        binary_vars = [BinaryVariable(var.label + f"({i})") for i in range(n_binary + 1)]

        term = sum(2**i * binary_vars[i] for i in range(n_binary))
        # The last bit's weight makes the maximum reachable offset exactly ``span``.
        term += (span + 1 - 2**n_binary) * binary_vars[-1]
        term += lower
        # Undo the scaling with the same precision that was used to scale the bounds.
        return term * precision if var.domain is Domain.REAL else term

    @staticmethod
    def evaluate(var: Variable, value: list[int] | int, precision: float = 1e-2) -> float:
        """Compute the value of ``var`` represented by a bitwise binary sample.

        Args:
            var (Variable): the variable to evaluate.
            value (list[int] | int): bits ordered by binary-variable index, or a number whose
                binary expansion gives those bits. Shorter lists are padded with zeros.
            precision (float): discretisation step, only used for ``Domain.REAL`` variables.

        Raises:
            ValueError: on a length mismatch, a non-real evaluation result, or a result outside
                the variable's domain.

        Returns:
            float: the decoded value (an ``int`` for integer domains).
        """
        term = Bitwise.encode(var, precision)
        binary_var = sorted(
            term.variables(),
            key=lambda x: _extract_number(x.label),
        )

        # Copy list input so the zero padding below never mutates the caller's list.
        binary_list = Bitwise._bitwise_encode(value, len(binary_var)) if isinstance(value, Number) else list(value)

        if not Bitwise.check_valid(binary_list)[0]:
            # Unreachable today: Bitwise.check_valid accepts every sample.
            raise ValueError(f"invalid binary string {binary_list} with the Bitwise encoding.")

        if len(binary_list) < len(binary_var):
            binary_list.extend([0] * (len(binary_var) - len(binary_list)))
        elif len(binary_list) != len(binary_var):
            raise ValueError(f"expected {len(binary_var)} variables but received {len(binary_list)}")

        binary_dict: dict[BaseVariable, list[int]] = {bv: [bit] for bv, bit in zip(binary_var, binary_list)}

        _out = term.evaluate(binary_dict)

        if isinstance(_out, RealNumber):
            out = float(_out)
        elif isinstance(_out, complex) and _out.imag == 0:
            out = float(_out.real)
        else:
            raise ValueError(f"Evaluation answer ({_out}) is outside the variable domain ({var.domain}).")

        out = int(out) if var.domain in {Domain.INTEGER, Domain.POSITIVE_INTEGER} else out

        if not var.domain.check_value(out):
            # Defensive guard; report the domain's value, not the enum class name.
            raise ValueError(f"The value {out} violates the domain {var.domain.value} of the variable {var}")
        return out

    @staticmethod
    def encoding_constraint(var: Variable, precision: float = 1e-2) -> ComparisonTerm:
        """Not supported: every bitwise sample is already valid.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Bitwise encoding constraints are not supported at the moment")

    @staticmethod
    def num_binary_equivalent(var: Variable, precision: float = 1e-2) -> int:
        """Return how many binary variables ``encode`` creates for ``var``.

        Args:
            var (Variable): the variable to encode.
            precision (float): discretisation step, only used for ``Domain.REAL`` variables.

        Returns:
            int: ``n + 1``, consistent with ``encode`` (including the zero-span case).
        """
        lower, upper = Bitwise._scaled_bounds(var, precision)
        return Bitwise._max_exponent(np.abs(upper - lower)) + 1

    @staticmethod
    def check_valid(value: list[int] | int) -> tuple[bool, int]:
        """Every bit string is a valid bitwise encoding, so this always returns ``(True, 0)``."""
        return True, 0
481
+
482
+
483
@yaml.register_class
class OneHot(Encoding):
    """One-hot variable encoding.

    Let ``[lb, ub]`` be the bounds, scaled as described below. The variable uses
    ``n = int(|ub - lb|) + 1`` binary variables ``b_0 .. b_{n-1}`` and is written as
    ``sum_i (lb + i) * b_i``; a valid sample has exactly one ``b_i`` set. For
    ``Domain.REAL`` variables, the bounds are first divided by ``precision`` and the
    resulting term is multiplied back by the same ``precision``.
    """

    name = "One-Hot"

    @staticmethod
    def _one_hot_encode(x: int, N: int) -> list[int]:
        """One-hot encode integer x in range [0, N-1].

        Args:
            x (int): the value to be encoded
            N (int): the number of bits to encode x.

        Raises:
            ValueError: if x is not in the range [0, N-1].

        Returns:
            list[int]: a binary list representing the one hot encoding of the integer x.
        """
        if not (0 <= x < N):
            raise ValueError(f"the input value ({x}) must be in range [0, {N - 1}]")
        return [1 if i == x else 0 for i in range(N)]

    @staticmethod
    def _scaled_bounds(var: Variable, precision: float) -> tuple[float, float]:
        """Return ``var``'s bounds, divided by ``precision`` for real-valued variables."""
        lower, upper = var.bounds
        if var.domain is Domain.REAL:
            return lower / precision, upper / precision
        return lower, upper

    @staticmethod
    def _find_zero(var: Variable) -> int:
        """Return the index of the first entry of ``var.bin_vars`` that is absent from ``var.term``.

        Only the first ``var.num_binary_equivalent()`` entries are searched. Returns 0 when all
        of them appear, which is indistinguishable from index 0 being the missing one.
        """
        binary_var = var.bin_vars
        term = var.term
        for i in range(var.num_binary_equivalent()):
            if binary_var[i] not in term:
                return i
        return 0

    @staticmethod
    def encode(var: Variable, precision: float = 1e-2) -> Term:
        """Return the one-hot encoded term representing ``var``.

        Args:
            var (Variable): the variable to encode.
            precision (float): discretisation step, only used for ``Domain.REAL`` variables.

        Returns:
            Term: a term that only contains binary variables labelled ``<var.label>(<i>)``.
        """
        lower, upper = OneHot._scaled_bounds(var, precision)
        n_binary = int(np.abs(upper - lower)) + 1

        binary_vars = [BinaryVariable(var.label + f"({i})") for i in range(n_binary)]

        term = Term([(lower + i) * binary_vars[i] for i in range(n_binary)], Operation.ADD)

        # Undo the scaling with the same precision that was used to scale the bounds.
        return term * precision if var.domain is Domain.REAL else term

    @staticmethod
    def evaluate(var: Variable, value: list[int] | int, precision: float = 1e-2) -> float:
        """Compute the value of ``var`` represented by a one-hot binary sample.

        Args:
            var (Variable): the variable to evaluate.
            value (list[int] | int): bits ordered by binary-variable index, or the index of the
                single set bit. Shorter lists are padded with zeros.
            precision (float): discretisation step, only used for ``Domain.REAL`` variables.

        Raises:
            ValueError: on an invalid one-hot sample, a length mismatch, a non-real evaluation
                result, or a result outside the variable's domain.

        Returns:
            float: the decoded value (an ``int`` for integer domains).
        """
        term = OneHot.encode(var, precision)
        n_binary = OneHot.num_binary_equivalent(var, precision)
        binary_var = sorted(
            term.variables(),
            key=lambda x: _extract_number(x.label),
        )

        # Copy list input so the zero padding below never mutates the caller's list.
        binary_list = OneHot._one_hot_encode(value, n_binary) if isinstance(value, int) else list(value)

        if not OneHot.check_valid(binary_list)[0]:
            raise ValueError(f"invalid binary string {binary_list} with the one hot encoding.")

        if len(binary_list) < n_binary:
            binary_list.extend([0] * (n_binary - len(binary_list)))
        elif len(binary_list) != n_binary:
            raise ValueError(f"expected {n_binary} variables but received {len(binary_list)}")

        # A variable whose coefficient ``lb + i`` is zero may be dropped from ``term``. Map each
        # remaining variable to its bit through the index in its label, not its position, so
        # the mapping holds whether zero lies inside the bounds or not.
        binary_dict: dict[BaseVariable, list[int]] = {
            bv: [binary_list[_extract_number(bv.label)]] for bv in binary_var
        }

        _out = term.evaluate(binary_dict)

        if isinstance(_out, RealNumber):
            out = float(_out)
        elif isinstance(_out, complex) and _out.imag == 0:
            out = float(_out.real)
        else:
            raise ValueError(f"Evaluation answer ({_out}) is outside the variable domain ({var.domain}).")

        out = int(out) if var.domain in {Domain.INTEGER, Domain.POSITIVE_INTEGER} else out

        if not var.domain.check_value(out):
            # Defensive guard; report the domain's value, not the enum class name.
            raise ValueError(f"The value {out} violates the domain {var.domain.value} of the variable {var}")

        return out

    @staticmethod
    def encoding_constraint(var: Variable, precision: float = 1e-2) -> ComparisonTerm:
        """Return the constraint ``sum_i b_i == 1`` (exactly one bit set).

        Args:
            var (Variable): the encoded variable.
            precision (float): discretisation step, only used for ``Domain.REAL`` variables.

        Returns:
            ComparisonTerm: the one-hot constraint over ``var``'s binary variables.
        """
        n_binary = OneHot.num_binary_equivalent(var, precision)
        binary_vars = [BinaryVariable(var.label + f"({i})") for i in range(n_binary)]
        return ComparisonTerm(lhs=sum(binary_vars), rhs=1, operation=ComparisonOperation.EQ)

    @staticmethod
    def num_binary_equivalent(var: Variable, precision: float = 1e-2) -> int:
        """Return how many binary variables ``encode`` creates for ``var``: ``int(|ub - lb|) + 1``."""
        lower, upper = OneHot._scaled_bounds(var, precision)
        return int(np.abs(upper - lower)) + 1

    @staticmethod
    def check_valid(value: list[int] | int) -> tuple[bool, int]:
        """Check that exactly one bit is set.

        Args:
            value (list[int] | int): a binary list, or the (non-negative) index of the set bit.

        Raises:
            ValueError: if an integer ``value`` is negative.

        Returns:
            tuple[bool, int]: validity flag and the penalty ``(ones - 1) ** 2``.
        """
        # ``value + 1`` bits are needed to hold a one-hot bit at index ``value``; the former
        # ``_one_hot_encode(value, value)`` rejected every integer input.
        binary_list = OneHot._one_hot_encode(value, value + 1) if isinstance(value, int) else value
        num_ones = binary_list.count(1)
        return num_ones == 1, (num_ones - 1) ** 2
603
+
604
+
605
@yaml.register_class
class DomainWall(Encoding):
    """Domain-wall variable encoding.

    Let ``[lb, ub]`` be the bounds, scaled as described below. The variable uses
    ``n = int(|ub - lb|)`` binary variables and is written as ``lb + sum_i b_i``. A valid
    sample is a run of ones followed by a run of zeros (``1...10...0``). For ``Domain.REAL``
    variables, the bounds are first divided by ``precision`` and the resulting term is
    multiplied back by the same ``precision``.
    """

    name = "Domain Wall"

    @staticmethod
    def _domain_wall_encode(x: int, N: int) -> list[int]:
        """Domain-wall encode integer x in range [0, N].

        Args:
            x (int): the value to be encoded (the number of leading ones).
            N (int): the number of bits to encode x.

        Raises:
            ValueError: if x is larger than N or less than 0

        Returns:
            list[int]: ``x`` ones followed by ``N - x`` zeros.
        """
        if not (0 <= x <= N):
            raise ValueError(f"the input value ({x}) must be in range [0, {N}]")
        return [1] * x + [0] * (N - x)

    @staticmethod
    def _scaled_bounds(var: Variable, precision: float) -> tuple[float, float]:
        """Return ``var``'s bounds, divided by ``precision`` for real-valued variables."""
        lower, upper = var.bounds
        if var.domain is Domain.REAL:
            return lower / precision, upper / precision
        return lower, upper

    @staticmethod
    def encode(var: Variable, precision: float = 1e-2) -> Term:
        """Return the domain-wall encoded term representing ``var``.

        Args:
            var (Variable): the variable to encode.
            precision (float): discretisation step, only used for ``Domain.REAL`` variables.

        Returns:
            Term: a term that only contains binary variables labelled ``<var.label>(<i>)``.
        """
        lower, upper = DomainWall._scaled_bounds(var, precision)
        n_binary = int(np.abs(upper - lower))

        binary_vars = [BinaryVariable(var.label + f"({i})") for i in range(n_binary)]

        term = Term([0], Operation.ADD)
        for binary_var in binary_vars:
            term += binary_var

        term += lower

        # Undo the scaling with the same precision that was used to scale the bounds.
        return term * precision if var.domain is Domain.REAL else term

    @staticmethod
    def evaluate(var: Variable, value: list[int] | int, precision: float = 1e-2) -> float:
        """Compute the value of ``var`` represented by a domain-wall binary sample.

        Args:
            var (Variable): the variable to evaluate.
            value (list[int] | int): bits ordered by binary-variable index, or the number of
                leading ones. Shorter lists are padded with zeros.
            precision (float): discretisation step, only used for ``Domain.REAL`` variables.

        Raises:
            ValueError: on an invalid domain-wall sample, a length mismatch, a non-real
                evaluation result, or a result outside the variable's domain.

        Returns:
            float: the decoded value (an ``int`` for integer domains).
        """
        term = DomainWall.encode(var, precision)
        binary_var = sorted(
            term.variables(),
            key=lambda x: _extract_number(x.label),
        )

        # Copy list input so the zero padding below never mutates the caller's list.
        binary_list: list[int] = (
            DomainWall._domain_wall_encode(value, len(binary_var)) if isinstance(value, RealNumber) else list(value)
        )

        if not DomainWall.check_valid(binary_list)[0]:
            raise ValueError(f"invalid binary string {binary_list} with the domain wall encoding.")

        if len(binary_list) < len(binary_var):
            binary_list.extend([0] * (len(binary_var) - len(binary_list)))
        elif len(binary_list) != len(binary_var):
            raise ValueError(f"expected {len(binary_var)} variables but received {len(binary_list)}")

        binary_dict: dict[BaseVariable, list[int]] = {bv: [bit] for bv, bit in zip(binary_var, binary_list)}

        _out = term.evaluate(binary_dict)

        if isinstance(_out, RealNumber):
            out = float(_out)
        elif isinstance(_out, complex) and _out.imag == 0:
            out = float(_out.real)
        else:
            raise ValueError(f"Evaluation answer ({_out}) is outside the variable domain ({var.domain}).")

        out = int(out) if var.domain in {Domain.INTEGER, Domain.POSITIVE_INTEGER} else out

        if not var.domain.check_value(out):
            # Defensive guard; report the domain's value, not the enum class name.
            raise ValueError(f"The value {out} violates the domain {var.domain.value} of the variable {var}")
        return out

    @staticmethod
    def encoding_constraint(var: Variable, precision: float = 1e-2) -> ComparisonTerm:
        """Return the constraint forbidding any ``0`` immediately followed by a ``1``.

        Args:
            var (Variable): the encoded variable.
            precision (float): discretisation step, only used for ``Domain.REAL`` variables.

        Returns:
            ComparisonTerm: ``sum_i b_{i+1} * (1 - b_i) == 0``.
        """
        n_binary = DomainWall.num_binary_equivalent(var, precision)

        binary_vars = [BinaryVariable(var.label + f"({i})") for i in range(n_binary)]
        return ComparisonTerm(
            lhs=sum(binary_vars[i + 1] * (1 - binary_vars[i]) for i in range(len(binary_vars) - 1)),
            rhs=0,
            operation=ComparisonOperation.EQ,
        )

    @staticmethod
    def num_binary_equivalent(var: Variable, precision: float = 1e-2) -> int:
        """Return how many binary variables ``encode`` creates for ``var``: ``int(|ub - lb|)``."""
        lower, upper = DomainWall._scaled_bounds(var, precision)
        return int(np.abs(upper - lower))

    @staticmethod
    def check_valid(value: list[int] | int) -> tuple[bool, int]:
        """Check that the sample has the form ``1...10...0``.

        Args:
            value (list[int] | int): a binary list, or the (non-negative) number of leading ones.

        Returns:
            tuple[bool, int]: validity flag and the number of ``0 -> 1`` transitions found.
        """
        binary_list = DomainWall._domain_wall_encode(value, value) if isinstance(value, int) else value
        violations = sum(nxt * (1 - cur) for cur, nxt in zip(binary_list, binary_list[1:]))
        return violations == 0, violations
718
+
719
+
720
+ # Variables ###
721
+
722
+
723
class BaseVariable(ABC):
    """
    Abstract base class for symbolic decision variables.

    A variable is identified by its label, its domain and its bounds: two variables compare equal (and hash
    equally) exactly when these three attributes match.
    """

    def __init__(self, label: str, domain: Domain, bounds: tuple[float | None, float | None] = (None, None)) -> None:
        """Initialize a new variable.

        Args:
            label (str): The name of the variable.
            domain (Domain): The domain of the values this variable can take.
            bounds (tuple[float | None, float | None], optional): The bounds on the variable's values, given as
                (lower_bound, upper_bound), both included. Defaults to (None, None).
                Note: a None entry is replaced by the lowest/highest value of the variable's domain.

        Raises:
            OutOfBoundsException: the lower bound or the upper bound don't correspond to the variable domain.
            InvalidBoundsError: the lower bound is higher than the upper bound.
        """
        self._label = label
        self._domain = domain
        # Validate through the private helper instead of ``self.set_bounds``: subclasses such as ``Parameter``
        # override ``set_bounds`` and read attributes that are not initialized yet at this point.
        self._bounds = self._validated_bounds(bounds[0], bounds[1])

    def _validated_bounds(self, lower_bound: float | None, upper_bound: float | None) -> tuple[float, float]:
        """Resolve ``None`` bounds against the current domain and validate them.

        Args:
            lower_bound (float | None): The lower bound, or None for the domain minimum.
            upper_bound (float | None): The upper bound, or None for the domain maximum.

        Raises:
            OutOfBoundsException: the lower bound or the upper bound don't correspond to the variable domain.
            InvalidBoundsError: the lower bound is higher than the upper bound.

        Returns:
            tuple[float, float]: the validated (lower_bound, upper_bound) pair.
        """
        if lower_bound is None:
            lower_bound = self._domain.min()
        if upper_bound is None:
            upper_bound = self._domain.max()
        if not self._domain.check_value(lower_bound):
            raise OutOfBoundsException(
                f"the lower bound ({lower_bound}) does not respect the domain of the variable ({self._domain})"
            )
        if not self._domain.check_value(upper_bound):
            raise OutOfBoundsException(
                f"the upper bound ({upper_bound}) does not respect the domain of the variable ({self._domain})"
            )
        if lower_bound > upper_bound:
            raise InvalidBoundsError(
                f"the lower bound ({lower_bound}) should not be greater than the upper bound ({upper_bound})"
            )
        return (lower_bound, upper_bound)

    @property
    def bounds(self) -> tuple[float, float]:
        """The bounds of the values the variable is allowed to take.

        Returns:
            tuple(float, float): The lower and upper bound of the variable.
        """
        return self._bounds

    @property
    def lower_bound(self) -> float:
        """The lower bound of the variable.

        Returns:
            float: the value of the lower bound.
        """
        return self._bounds[0]

    @property
    def upper_bound(self) -> float:
        """The upper bound of the variable.

        Returns:
            float: the value of the upper bound.
        """
        return self._bounds[1]

    @property
    def label(self) -> str:
        """The label (name) of the variable.

        Returns:
            string: the name of the variable.
        """
        return self._label

    @property
    def domain(self) -> Domain:
        """The domain of values that the variable is allowed to take.

        Returns:
            Domain: The domain of the values the variable can take.
        """
        return self._domain

    def set_bounds(self, lower_bound: float | None, upper_bound: float | None) -> None:
        """Set the bounds of the variable.

        Args:
            lower_bound (float | None): The lower bound (if None the lowest allowed bound in the variable domain is
                selected).
            upper_bound (float | None): The upper bound (if None the highest allowed bound in the variable domain is
                selected).

        Raises:
            OutOfBoundsException: the lower bound or the upper bound don't correspond to the variable domain.
            InvalidBoundsError: the lower bound is higher than the upper bound.
        """
        self._bounds = self._validated_bounds(lower_bound, upper_bound)

    @abstractmethod
    def num_binary_equivalent(self) -> int:
        """
        Returns:
            int: the number of binary variables that are needed to represent this variable in the given encoding.
        """

    @abstractmethod
    def evaluate(self, value: list[int] | RealNumber) -> RealNumber:
        """Evaluates the value of the variable given a binary string or a number.

        Args:
            value (list[int] | int | float): the value used to evaluate the variable.
                If the value provided is binary list (list[int]) then the value of the variable is evaluated based on
                its binary representation. This representation is constructed using the encoding, bounds and domain
                of the variable. To check the binary representation of a variable you can check the method `to_binary()`

        Returns:
            int | float | complex: the evaluated value of the variable.
        """

    def update_variable(self, domain: Domain, bounds: tuple[float | None, float | None] = (None, None)) -> None:
        """Replace the domain and bounds of the variable.

        The update is atomic: if the new bounds are rejected, the previous domain is restored.

        Args:
            domain (Domain): The updated domain of the variable.
            bounds (tuple[float | None, float | None]): The updated bounds of the variable. Defaults to (None, None)

        Raises:
            OutOfBoundsException: a bound does not belong to the new domain.
            InvalidBoundsError: the lower bound is higher than the upper bound.
        """
        previous_domain = self._domain
        self._domain = domain
        try:
            self.set_bounds(bounds[0], bounds[1])
        except Exception:
            # Roll back so a rejected update leaves the variable unchanged.
            self._domain = previous_domain
            raise

    @abstractmethod
    def to_binary(self) -> Term:
        """Returns the binary representation of a variable.

        Returns:
            Term: the binary representation of a variable.
        """

    def __repr__(self) -> str:
        return f"{self._label}"

    def __str__(self) -> str:
        return f"{self._label}"

    def __add__(self, other: Number | BaseVariable | Term) -> Term:
        if not isinstance(other, (Number, BaseVariable, Term)):
            return NotImplemented
        if isinstance(other, Term):
            return other + self

        if isinstance(other, np.generic):
            other = other.item()

        return Term(elements=[self, other], operation=Operation.ADD)

    __radd__ = __add__
    __iadd__ = __add__

    def __mul__(self, other: Number | BaseVariable | Term) -> Term:
        if not isinstance(other, (Number, BaseVariable, Term)):
            return NotImplemented
        if isinstance(other, Term):
            return other * self

        if isinstance(other, np.generic):
            other = other.item()

        return Term(elements=[self, other], operation=Operation.MUL)

    def __rmul__(self, other: Number | BaseVariable | Term) -> Term:
        if not isinstance(other, (Number, BaseVariable, Term)):
            return NotImplemented
        if isinstance(other, Term):
            return other * self

        if isinstance(other, np.generic):
            other = other.item()

        return Term(elements=[other, self], operation=Operation.MUL)

    __imul__ = __mul__

    def __sub__(self, other: Number | BaseVariable | Term) -> Term:
        if not isinstance(other, (Number, BaseVariable, Term)):
            return NotImplemented

        if isinstance(other, np.generic):
            other = other.item()

        return self + -1 * other

    def __rsub__(self, other: Number | BaseVariable | Term) -> Term:
        if not isinstance(other, (Number, BaseVariable, Term)):
            return NotImplemented

        if isinstance(other, np.generic):
            other = other.item()

        return -1 * self + other

    __isub__ = __sub__

    def __neg__(self) -> Term:
        return -1 * self

    def __truediv__(self, other: RealNumber) -> Term:
        if not isinstance(other, RealNumber):
            raise NotImplementedError("Only division by real numbers is currently supported")

        if other == 0:
            raise ValueError("Division by zero is not allowed")

        if isinstance(other, np.generic):
            other = other.item()
        other = 1 / other
        return self * other

    __itruediv__ = __truediv__

    def __rtruediv__(self, other: Number | BaseVariable | Term) -> Term:
        raise NotSupportedOperation("Only division by numbers is currently supported")

    def __rfloordiv__(self, other: Number | BaseVariable | Term) -> Term:
        raise NotSupportedOperation("Only division by numbers is currently supported")

    def __pow__(self, a: int) -> Term:
        out: BaseVariable | Term = copy.copy(self)

        if a < 0:
            raise NotImplementedError("Negative Power is not Supported.")

        if a == 0:
            return Term(elements=[1], operation=Operation.ADD)

        for _ in range(a - 1):
            out *= copy.copy(self)

        if isinstance(out, BaseVariable):
            out = Term(elements=[out], operation=Operation.ADD)
        return out

    def _identity(self) -> tuple:
        """The attributes that identify a variable (used by both ``__hash__`` and ``__eq__``)."""
        return (self._label, self._domain.value, self._bounds)

    def __hash__(self) -> int:
        return hash(self._identity())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, BaseVariable):
            return False
        # Compare the identifying attributes directly: equal hashes do not imply equality
        # (e.g. in CPython hash(-1) == hash(-2), so bounds (-1, 5) and (-2, 5) would collide).
        return self._identity() == other._identity()
991
+
992
+
993
@yaml.register_class
class BinaryVariable(BaseVariable):
    """
    Binary decision variable restricted to the set ``{0, 1}``.

    Example:
        .. code-block:: python

            from qilisdk.core.variables import BinaryVariable

            x = BinaryVariable("x")
    """

    def __init__(self, label: str) -> None:
        super().__init__(label=label, domain=Domain.BINARY)

    def num_binary_equivalent(self) -> int:  # noqa: PLR6301
        """
        Returns:
            int: always 1, a binary variable is its own binary representation.
        """
        return 1

    def evaluate(self, value: list[int] | RealNumber) -> RealNumber:
        """Evaluate the variable from a scalar or from a one-element binary list.

        Args:
            value (list[int] | int | float): either a scalar (0 or 1, NumPy scalars accepted) or a list holding
                exactly one binary value.

        Raises:
            EvaluationError: the value is not 0/1, or the list does not hold exactly one item.

        Returns:
            int: the value of the variable (0 or 1). For list input the list element is returned as given.
        """
        # NumPy scalars are not ``int``/``float`` instances; unwrap them like the arithmetic operators do.
        if isinstance(value, np.generic):
            value = value.item()
        if isinstance(value, int | float):
            if value in {0, 1}:
                return int(value)
            raise EvaluationError(f"Evaluating a Binary variable with a value {value} that is outside the domain.")
        if len(value) != 1:
            raise EvaluationError("Evaluating a Binary variable with a binary list of more than one item.")
        if value[0] not in {0, 1}:
            raise EvaluationError(
                f"Evaluating a Binary variable with a value {value[0]} that is outside the domain."
            )
        return value[0]

    def update_variable(self, domain: Domain, bounds: tuple[float | None, float | None] = (None, None)) -> None:
        """Binary variables have a fixed domain and bounds; updating them is not supported.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def to_binary(self) -> Term:
        """
        Returns:
            Term: a single-element ADD term wrapping this variable.
        """
        return Term([self], Operation.ADD)

    def __copy__(self) -> BinaryVariable:
        return BinaryVariable(label=self.label)
1031
+
1032
+
1033
@yaml.register_class
class SpinVariable(BaseVariable):
    """Spin decision variable, taking the values -1 and 1.

    During evaluation the value ``0`` is also accepted and interpreted as ``-1``.
    """

    def __init__(self, label: str) -> None:
        super().__init__(label=label, domain=Domain.SPIN, bounds=(-1, 1))

    def num_binary_equivalent(self) -> int:  # noqa: PLR6301
        """
        Returns:
            int: always 1, a spin is represented by a single binary variable.
        """
        return 1

    def update_variable(self, domain: Domain, bounds: tuple[float | None, float | None] = (None, None)) -> None:
        """Spin variables have a fixed domain and bounds; updating them is not supported.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def evaluate(self, value: list[int] | RealNumber) -> RealNumber:
        """Evaluate the spin from a scalar or from a one-element list.

        Args:
            value (list[int] | int | float): a number, or a list holding exactly one item.

        Raises:
            EvaluationError: a scalar outside the spin domain (other than 0) or a list whose length is not 1.

        Returns:
            int: -1 when the value is 0 or -1, 1 otherwise.
        """
        if isinstance(value, Number):
            outside_domain = not self.domain.check_value(value) and value != 0
            if outside_domain:
                raise EvaluationError(f"Evaluating a Spin variable with a value {value} that is outside the domain.")
            sample = value
        else:
            if len(value) != 1:
                raise EvaluationError("Evaluating a Spin variable with a list of more than one item.")
            sample = value[0]
        if sample in {0, -1}:
            return -1
        return 1

    def to_binary(self) -> Term:
        """
        Returns:
            Term: a single-element ADD term wrapping this variable.
        """
        return Term([self], Operation.ADD)

    def __copy__(self) -> SpinVariable:
        return SpinVariable(label=self.label)
1060
+
1061
+
1062
@yaml.register_class
class Variable(BaseVariable):
    """
    Generic (possibly continuous) optimization variable with configurable encoding.

    Example:
        .. code-block:: python

            from qilisdk.core.variables import Domain, Variable

            price = Variable("price", domain=Domain.REAL, bounds=(0, 10))
            binary_term = price.to_binary()
    """

    def __init__(
        self,
        label: str,
        domain: Domain,
        bounds: tuple[float | None, float | None] = (None, None),
        encoding: type[Encoding] = Bitwise,
        precision: float = 1e-2,
    ) -> None:
        """Initialize a new variable.

        Args:
            label (str): The name of the variable.
            domain (Domain): The domain of the values this variable can take.
            bounds (tuple[float | None, float | None], optional): the bounds on the values of the variable. The bounds
                have the structure (lower_bound, upper_bound), both values included. Defaults to (None, None).
                Note: if None is selected then the lowest/highest possible value of the variable's domain is chosen.
            encoding (type[Encoding], optional): The encoding class used to represent the variable with binary
                variables. Defaults to Bitwise.
            precision (float, optional): The floating point precision for REAL variables. Defaults to 1e-2.
        """
        super().__init__(label=label, domain=domain, bounds=bounds)
        self._encoding = encoding
        self._precision = precision
        # Binary encoding and its binary variables, built lazily by ``to_binary`` and reset when the
        # encoding-relevant settings change.
        self._term: Term | None = None
        self._bin_vars: list[BaseVariable] = []

    @property
    def encoding(self) -> type[Encoding]:
        """The encoding class used to translate the variable into binary variables."""
        return self._encoding

    @property
    def precision(self) -> float:
        """The floating point precision used when encoding the variable."""
        return self._precision

    @property
    def term(self) -> Term:
        """The binary encoding of the variable, computed on first access.

        A warning is logged when the bounds exceed ``LARGE_BOUND``, since encoding them is expensive.
        """
        if self._term is None:
            if self.bounds[1] > LARGE_BOUND or self.bounds[0] < -LARGE_BOUND:
                logger.warning(
                    f"Encoding variable {self.label} which has the bounds {self.bounds} "
                    "is very expensive and may take a very long time."
                )
            self._term = self.to_binary()
        return self._term

    @property
    def bin_vars(self) -> list[BaseVariable]:
        """The binary variables used by the encoding (computed on first access)."""
        if self._term is None:
            self.to_binary()
        return self._bin_vars

    def set_precision(self, precision: float) -> None:
        """Change the precision and invalidate the cached binary encoding.

        Args:
            precision (float): The new floating point precision.
        """
        self._precision = precision
        self._term = None

    def __copy__(self) -> Variable:
        # Carry the precision over as well; otherwise the copy silently reverts to the default precision.
        return Variable(
            label=self.label,
            domain=self.domain,
            bounds=self.bounds,
            encoding=self._encoding,
            precision=self._precision,
        )

    def __getitem__(self, item: int) -> BaseVariable:
        """Return the binary variable at position ``item`` of the encoding."""
        if self._term is None:
            self.to_binary()
        return self._bin_vars[item]

    def update_variable(
        self,
        domain: Domain,
        bounds: tuple[float | None, float | None] = (None, None),
        encoding: type[Encoding] | None = None,
    ) -> None:
        """Replace the domain, bounds and (optionally) the encoding of the variable.

        The cached binary encoding is invalidated.

        Args:
            domain (Domain): The updated domain of the variable.
            bounds (tuple[float | None, float | None]): The updated bounds. Defaults to (None, None).
            encoding (type[Encoding] | None): The new encoding, or None to keep the current one.
        """
        self._encoding = encoding if encoding is not None else self._encoding
        self._term = None
        return super().update_variable(domain, bounds)

    def evaluate(self, value: list[int] | RealNumber) -> RealNumber:
        """Evaluate the variable from a number or from a binary sample.

        Args:
            value (list[int] | int | float): a number (checked against the domain and bounds) or a binary list
                decoded through the variable's encoding.

        Raises:
            ValueError: the number is outside the domain or the bounds.

        Returns:
            int | float: the value of the variable.
        """
        if isinstance(value, int | float):
            if not self.domain.check_value(value):
                raise ValueError(f"The value {value} is invalid for the domain {self.domain.value}")
            if value < self.lower_bound or value > self.upper_bound:
                raise ValueError(f"The value {value} is outside the defined bounds {self.bounds}")
            return value
        return self.encoding.evaluate(self, value, self._precision)

    def to_binary(self) -> Term:
        """Return (and cache) the binary encoding of the variable.

        Returns:
            Term: the variable expressed with binary variables.
        """
        if self._term is None:
            term = self.encoding.encode(self, precision=self._precision)
            self._term = copy.copy(term)
            self._bin_vars = [BinaryVariable(f"{self.label}({i})") for i in range(self.num_binary_equivalent())]
            self._bin_vars = sorted(
                self._bin_vars,
                key=lambda x: _extract_number(x.label),
            )
        return self._term

    def num_binary_equivalent(self) -> int:
        """
        Returns:
            int: the number of binary variables needed to encode the continuous variable.
        """
        return self.encoding.num_binary_equivalent(self, precision=self._precision)

    def check_valid(self, binary_list: list[int]) -> tuple[bool, int]:
        """Check whether a binary sample is valid in the variable's encoding.

        Args:
            binary_list (list[int]): a list of binary values.

        Returns:
            tuple[bool, int]: the boolean is True if the sample is a valid encoding,
            and the integer is the error in the encoding.
        """
        return self.encoding.check_valid(binary_list)

    def encoding_constraint(self) -> ComparisonTerm:
        """Given a continuous variable return a Comparison Term that ensures that the encoding is respected.

        Returns:
            ComparisonTerm: a Comparison Term that ensures the encoding is respected.
        """
        return self.encoding.encoding_constraint(self, precision=self._precision)
1195
+
1196
+
1197
@yaml.register_class
class Parameter(BaseVariable):
    """
    Symbolic scalar used to parametrize expressions while remaining differentiable.

    A parameter always carries a concrete value: it needs no binary variables and evaluates to its current
    value regardless of the input it is given.

    Example:
        .. code-block:: python

            from qilisdk.core.variables import Parameter

            theta = Parameter("theta", value=0.5)
            theta.set_value(0.75)
    """

    def __init__(
        self,
        label: str,
        value: RealNumber,
        domain: Domain = Domain.REAL,
        bounds: tuple[float | None, float | None] = (None, None),
    ) -> None:
        """Initialize a new parameter.

        Args:
            label (str): The name of the parameter.
            value (RealNumber): The initial value of the parameter.
            domain (Domain, optional): The domain of the parameter's values. Defaults to Domain.REAL.
            bounds (tuple[float | None, float | None], optional): (lower_bound, upper_bound); None entries default
                to the domain's minimum/maximum. Defaults to (None, None).

        Raises:
            ValueError: the value does not belong to the domain or lies outside the bounds.
        """
        super().__init__(label=label, domain=domain, bounds=bounds)

        if not self.domain.check_value(value):
            raise ValueError(
                f"Parameter value provided ({value}) doesn't correspond to the parameter's domain ({self.domain.name})"
            )
        self._value = value
        # Re-apply the bounds now that the value exists, so the value is checked against them.
        self.set_bounds(bounds[0], bounds[1])

    @property
    def value(self) -> RealNumber:
        """The current value of the parameter."""
        return self._value

    def set_value(self, value: RealNumber) -> None:
        """Change the value of the parameter.

        Args:
            value (RealNumber): The new value.

        Raises:
            ValueError: the value does not belong to the domain or lies outside the bounds.
        """
        if not self.domain.check_value(value):
            raise ValueError(
                f"Parameter value provided ({value}) doesn't correspond to the parameter's domain ({self.domain.name})"
            )
        if value > self.bounds[1] or value < self.bounds[0]:
            raise ValueError(f"The value provided ({value}) is outside the bound of the parameter {self.bounds}")
        self._value = value

    def num_binary_equivalent(self) -> int:  # noqa: PLR6301
        """
        Returns:
            int: always 0, a parameter is not encoded with binary variables.
        """
        return 0

    def evaluate(self, value: list[int] | RealNumber = 0) -> RealNumber:
        """Return the current value of the parameter.

        Args:
            value (list[int] | int | float): accepted for interface compatibility with other variables; ignored.

        Returns:
            float: the current value of the parameter.
        """
        return self.value

    def to_binary(self) -> Term:
        """Returns the binary representation of the parameter.

        Returns:
            Term: an ADD term holding the parameter's current value as a constant.
        """
        return Term([self.value], operation=Operation.ADD)

    def set_bounds(self, lower_bound: float | None, upper_bound: float | None) -> None:
        """Set the bounds of the parameter, which must contain its current value.

        Args:
            lower_bound (float | None): The lower bound (None for the domain minimum).
            upper_bound (float | None): The upper bound (None for the domain maximum).

        Raises:
            ValueError: the current value lies outside the new bounds.
            OutOfBoundsException: a bound does not belong to the domain.
            InvalidBoundsError: the lower bound is higher than the upper bound.
        """
        upper_bound = upper_bound if upper_bound is not None else self.domain.max()
        lower_bound = lower_bound if lower_bound is not None else self.domain.min()
        if self.value > upper_bound or self.value < lower_bound:
            raise ValueError(
                f"The current value of the parameter ({self.value}) is outside the bounds ({lower_bound}, {upper_bound})"
            )
        super().set_bounds(lower_bound, upper_bound)

    def update_variable(self, domain: Domain, bounds: tuple[float | None, float | None] = (None, None)) -> None:
        """Replace the domain and bounds of the parameter while keeping its value.

        The update is atomic: if the new bounds are rejected, the previous domain is restored.

        Args:
            domain (Domain): The new domain; it must contain the current value.
            bounds (tuple[float | None, float | None]): The new (lower_bound, upper_bound). Defaults to (None, None).

        Raises:
            ValueError: malformed bounds, a domain incompatible with the value, or bounds excluding the value.
            OutOfBoundsException: a bound does not belong to the new domain.
            InvalidBoundsError: the lower bound is higher than the upper bound.
        """
        if len(bounds) != 2:  # noqa: PLR2004
            raise ValueError(
                "Invalid bounds provided: the bounds need to be a tuple with the format (lowe_bound, upper_bound)"
            )

        if not domain.check_value(self.value):
            raise ValueError(
                f"The provided domain ({domain.name}) is incompatible with the current parameter value ({self.value})"
            )

        previous_domain = self._domain
        self._domain = domain
        try:
            self.set_bounds(lower_bound=bounds[0], upper_bound=bounds[1])
        except Exception:
            # Roll back so a rejected update does not leave the new domain paired with the old bounds.
            self._domain = previous_domain
            raise
1292
+
1293
+
1294
+ # Terms ###
1295
+
1296
+
1297
@yaml.register_class
class Term:
    """Represents a mathematical term (e.g. 3x*y, 2x, ...).

    A term combines its elements with a single ``Operation`` and is built from:
        - ``Variable``'s: The decision variables of the model (x, y, ...).
        - Other ``Term``'s: Allowing for complex expressions to be constructed.
        - Numbers: accumulated as the coefficient stored under the ``CONST`` placeholder variable.
    """

    # Placeholder variable used as the dictionary key under which the numeric constant of a term is stored.
    CONST = Variable(CONST_KEY, Domain.REAL)
1307
+
1308
    def __init__(self, elements: Sequence[BaseVariable | Term | Number], operation: Operation) -> None:
        """Initialize a new term object.

        The elements are folded into a mapping from variable/sub-term to coefficient. In an ADD term the
        coefficient multiplies the element; in a MUL term it acts as the element's exponent. Numbers are
        combined (using ``operation``) into the coefficient of the ``CONST`` key.

        Args:
            elements (Sequence[BaseVariable | Term | Number]): a list of elements in the term.
            operation (Operation): the mathematical operation between these elements.

        Raises:
            ValueError: if the items inside elements are not from the listed types (BaseVariable | Term | Number).
        """
        self._operation = operation
        self._elements: dict[BaseVariable | Term, Number] = {}  # Maps each element to its coefficient.
        # key: the term or variable | value: the coefficient corresponding to that value.
        for e in elements:
            if isinstance(e, BaseVariable):
                if e in self:
                    if self._is_constant(e):
                        self[e] = self._apply_operation_on_constants([self[e], 1])
                    elif isinstance(e, BinaryVariable) and self.operation == Operation.MUL:
                        # A binary variable is idempotent under multiplication (x * x == x): keep exponent 1.
                        self[e] = 1
                    else:
                        # Repeated variable: increase its multiplier (ADD) or exponent (MUL).
                        self[e] += 1
                else:
                    self[e] = 1
            elif isinstance(e, Number):
                # Numbers are merged into the constant slot using this term's operation.
                if self.CONST in self:
                    self[self.CONST] = self._apply_operation_on_constants([self[self.CONST], e])
                else:
                    self[self.CONST] = e
            elif isinstance(e, Term):
                if len(e) == 0:
                    # An empty term is treated as the number 0.
                    if self.CONST in self:
                        self[self.CONST] = self._apply_operation_on_constants([self[self.CONST], 0])
                    else:
                        self[self.CONST] = 0
                elif e.operation == self._operation:
                    # Same operation: flatten the sub-term's elements into this term.
                    for key in e:
                        if key in self:
                            if isinstance(key, BaseVariable) and self._is_constant(key):
                                self[key] = self._apply_operation_on_constants([self[key], e[key]])
                            elif isinstance(key, BinaryVariable) and self.operation == Operation.MUL:
                                self[key] = 1
                            else:
                                self[key] += e[key]
                        else:
                            self[key] = e[key]
                else:
                    # Different operation: store the sub-term as a single element. For a MUL sub-term its
                    # constant factor is pulled out and becomes the element's coefficient.
                    e_copy = copy.copy(e)
                    coeff = complex(1.0)
                    if e_copy.operation == Operation.MUL and self.CONST in e_copy:
                        coeff = e_copy.pop(self.CONST)
                    simple_e = e_copy._simplify()  # noqa: SLF001
                    # A sub-term left empty after removing its constant is a pure number: use the CONST key.
                    simple_e = self.CONST if isinstance(simple_e, Term) and len(simple_e) == 0 else simple_e
                    if simple_e in self:
                        if isinstance(simple_e, BaseVariable) and self._is_constant(simple_e):
                            self[simple_e] = self._apply_operation_on_constants([self[simple_e], coeff])
                        else:
                            self[simple_e] += coeff
                    else:
                        self[simple_e] = coeff
            else:
                raise ValueError(
                    f"Term accepts object of types Term or Variable but an object of type {e.__class__} was given"
                )
        self._remove_zeros()
1373
+
1374
+ @property
1375
+ def operation(self) -> Operation:
1376
+ """
1377
+ Returns:
1378
+ Operation: the operation between the term's elements.
1379
+ """
1380
+ return self._operation
1381
+
1382
+ @property
1383
+ def degree(self) -> int:
1384
+ """
1385
+ Returns:
1386
+ int: the highest degree in the term.
1387
+ """
1388
+ degree = 0
1389
+ if self.operation == Operation.MUL:
1390
+ for element in self:
1391
+ if isinstance(element, Term):
1392
+ degree += element.degree
1393
+ elif isinstance(element, BaseVariable) and not self._is_constant(element):
1394
+ degree += int(_assert_real(self[element]))
1395
+ return degree
1396
+
1397
+ for element in self:
1398
+ if isinstance(element, Term):
1399
+ degree = max(degree, element.degree)
1400
+ elif isinstance(element, BaseVariable) and not self._is_constant(element):
1401
+ degree = max(degree, 1)
1402
+ return degree
1403
+
1404
+ def to_binary(self) -> Term:
1405
+ """Returns the term in binary format. That is encoding all continuous variables into
1406
+ binary according to the encoding defined in the variable.
1407
+
1408
+ Raises:
1409
+ ValueError: The term contains operations that are not addition or multiplication.
1410
+ ValueError: the term contains an element that is not a Term or a BaseVariable.
1411
+
1412
+ Returns:
1413
+ Term: the term after transforming all the variables into binary.
1414
+ """
1415
+ if self.operation not in {Operation.ADD, Operation.MUL}:
1416
+ raise ValueError("Can not evaluate any operation that is not Addition of Multiplication")
1417
+ out_list: list[BaseVariable | Term | Number] = []
1418
+ for e in self:
1419
+ if isinstance(e, Term):
1420
+ out_list.append(self[e] * e.to_binary())
1421
+ elif isinstance(e, BaseVariable):
1422
+ if self._is_constant(e):
1423
+ out_list.append(self[e])
1424
+ elif isinstance(e, Variable):
1425
+ x = e.to_binary()
1426
+ if self.operation == Operation.MUL:
1427
+ out_list.append(x ** int(_assert_real(self[e])))
1428
+ else:
1429
+ out_list.append(self[e] * x)
1430
+ else:
1431
+ out_list.append(self[e] * e)
1432
+ else:
1433
+ raise ValueError(f"Evaluating term with elements of type {e.__class__} is not supported.")
1434
+
1435
+ return Term(out_list, self.operation)
1436
+
1437
+ def _apply_operation_on_constants(self, const_list: list[Number]) -> Number:
1438
+ out = complex(const_list[0])
1439
+ for i in range(1, len(const_list)):
1440
+ if self.operation is Operation.ADD:
1441
+ out += const_list[i]
1442
+ elif self.operation is Operation.SUB:
1443
+ out -= const_list[i]
1444
+ elif self.operation is Operation.MUL:
1445
+ out *= const_list[i]
1446
+ elif self.operation is Operation.DIV:
1447
+ out /= const_list[i]
1448
+
1449
+ return out
1450
+
1451
+ def variables(self) -> list[BaseVariable]:
1452
+ """Returns the unique list of variables in the Term
1453
+
1454
+ Returns:
1455
+ list[Variable]: The unique list of variables in the Term.
1456
+ """
1457
+ var = set()
1458
+ for e in self:
1459
+ if isinstance(e, BaseVariable) and not self._is_constant(e):
1460
+ var.add(e)
1461
+ elif isinstance(e, Term):
1462
+ var.update(e.variables())
1463
+ return sorted(var, key=lambda x: x.label)
1464
+
1465
+ def _simplify(self) -> Term | BaseVariable:
1466
+ """Simplify the term object.
1467
+
1468
+ Returns:
1469
+ (Term | BaseVariable): the simplified term.
1470
+ """
1471
+ if len(self) == 1:
1472
+ item = next(iter(self._elements.keys()))
1473
+ if self._elements[item] == 1:
1474
+ return item
1475
+ return self
1476
+
1477
+ def pop(self, item: BaseVariable | Term) -> Number:
1478
+ """Remove an item from the term.
1479
+
1480
+ Args:
1481
+ item (BaseVariable | Term): the item to be removed.
1482
+
1483
+ Raises:
1484
+ KeyError: if item is not in the term.
1485
+
1486
+ Returns:
1487
+ Number: the coefficient of the removed item.
1488
+ """
1489
+ try:
1490
+ return self._elements.pop(item)
1491
+ except KeyError as e:
1492
+ raise KeyError(f'item "{item}" not found in the term.') from e
1493
+
1494
+ def _is_constant(self, variable: BaseVariable) -> bool:
1495
+ """Checks if the variable is a constant variable as defined by the Term class.
1496
+
1497
+ Args:
1498
+ variable (BaseVariable): the variable to be checked.
1499
+
1500
+ Returns:
1501
+ bool: True if the variable is a constant, False otherwise.
1502
+ """
1503
+ return variable == self.CONST
1504
+
1505
+ def to_list(self) -> list[BaseVariable | Term | Number]:
1506
+ """Exports the current term into a list of its elements.
1507
+
1508
+ Returns:
1509
+ list[BaseVariable | Term | Number]: A list of the elements inside the term.
1510
+ """
1511
+ out_list: list[BaseVariable | Term | Number] = []
1512
+ for e in self:
1513
+ if isinstance(e, BaseVariable) and self._is_constant(e):
1514
+ out_list.append(self[e])
1515
+ elif self.operation == Operation.MUL:
1516
+ for _ in range(int(_assert_real(self[e]))):
1517
+ out_list.append(e)
1518
+ else:
1519
+ out_list.append(self[e] * e if self[e] != 1 else e)
1520
+ return out_list
1521
+
1522
    def _unfold_parentheses(self) -> Term:
        """Expand the ADD sub-terms (parentheses) of a product.

        Only MUL terms are expanded; any other term is returned as an unchanged copy. Each ADD sub-term is
        raised to its exponent and distributed over the rest of the product, producing an ADD term.

        Returns:
            Term: A new term with a more simplified form.
        """
        out = copy.copy(self)
        if out.operation != Operation.MUL:
            return out

        # (sub-term, exponent) pairs for every ADD factor of the product.
        parentheses: list[tuple[Term, Number]] = []

        for e in out:
            if isinstance(e, Term) and e.operation == Operation.ADD:
                parentheses.append((copy.copy(e), out[e]))

        # Remove the ADD factors; ``out`` keeps only the remaining product.
        for term, _ in parentheses:
            out.pop(term)

        # If the product consisted only of parentheses, start the expansion from the neutral factor 1.
        if len(out) == 0 and len(parentheses) != 0:
            out = Term([1], Operation.ADD)

        for _term, coeff in parentheses:
            term = copy.copy(_term)
            _coeff = _assert_real(coeff)
            # Apply the exponent (integer-truncated) before distributing.
            if _coeff > 1:
                term **= int(_coeff)
            # Distribute: every summand (times its coefficient) multiplies the accumulated product.
            final_out = []
            for t in term:
                final_out.append(t * out * term[t])
            out = Term(final_out, Operation.ADD)

        return out
1555
+
1556
+ def _remove_zeros(self) -> None:
1557
+ """Simplifies any un-necessary zeros from terms."""
1558
+ to_be_popped = []
1559
+ if self.operation == Operation.MUL and self.CONST in self and self[self.CONST] == 0:
1560
+ l = len(self)
1561
+ for _ in range(l):
1562
+ self._elements.popitem()
1563
+ for e in self:
1564
+ if self[e] == 0:
1565
+ to_be_popped.append(e)
1566
+ for p in to_be_popped:
1567
+ self._elements.pop(p)
1568
+
1569
+ def evaluate(self, var_values: Mapping[BaseVariable, list[int] | RealNumber]) -> Number:
1570
+ """Evaluates the term given a set of values for the variables in the term.
1571
+
1572
+ Args:
1573
+ var_values (Mapping[BaseVariable, list[int] | Number]): the values of the variables in the term.
1574
+ If the value provided is binary list (list[int]) then the value of the variable is evaluated based on
1575
+ its binary representation. This representation is constructed using the encoding, bounds and domain
1576
+ of the variable. To check the binary representation of a variable you can check the method `to_binary()`
1577
+
1578
+ Raises:
1579
+ ValueError: if not all variables in the term are provided a value.
1580
+
1581
+ Returns:
1582
+ float: the result from evaluating the term.
1583
+ """
1584
+ if len(self._elements) == 0:
1585
+ return 0
1586
+ _var_values = dict(var_values)
1587
+ for var in self.variables():
1588
+ if isinstance(var, Parameter):
1589
+ _var_values[var] = var.value
1590
+ if var not in _var_values:
1591
+ raise ValueError(f"Can not evaluate term because the value of the variable {var} is not provided.")
1592
+ output = complex(0.0) if self.operation in {Operation.ADD, Operation.SUB} else complex(1.0)
1593
+ for e in self:
1594
+ if isinstance(e, Term):
1595
+ output = self._apply_operation_on_constants([output, e.evaluate(_var_values) * self[e]])
1596
+ elif isinstance(e, BaseVariable):
1597
+ if e == self.CONST:
1598
+ output = self._apply_operation_on_constants([output, self[e]])
1599
+ elif self.operation == Operation.MUL:
1600
+ output = self._apply_operation_on_constants([output, e.evaluate(_var_values[e]) ** self[e]])
1601
+ else:
1602
+ output = self._apply_operation_on_constants([output, e.evaluate(_var_values[e]) * self[e]])
1603
+ if isinstance(output, RealNumber):
1604
+ return float(output)
1605
+ if isinstance(output, complex) and output.imag == 0:
1606
+ return float(output.real)
1607
+ return output
1608
+
1609
+ def get_constant(self) -> Number:
1610
+ """
1611
+ Returns:
1612
+ Number: The constant value of the term.
1613
+ """
1614
+ if self.CONST in self:
1615
+ return self[self.CONST]
1616
+ return 0 if self.operation in {Operation.ADD, Operation.SUB} else 1
1617
+
1618
+ def is_parameterized_term(self) -> bool:
1619
+ return all(isinstance(var, Parameter) for var in self.variables())
1620
+
1621
+ def __copy__(self) -> Term:
1622
+ return Term(copy.copy(self.to_list()), copy.copy(self.operation))
1623
+
1624
    def __repr__(self) -> str:
        """Human-readable form of the term, e.g. ``(2.0) * x + y``.

        Elements are joined with the operation symbol (``Operation.value``). A neutral constant (0 for ADD/SUB,
        1 for MUL/DIV) is omitted, and exponents in products are written as ``x^n``.
        """
        if len(self) == 0:
            return "0"
        output_string = ""
        const = self.get_constant()
        keys = list(self._elements.keys())

        # Drop a neutral constant from the keys so no dangling operator is emitted for it.
        if (
            (self.operation in {Operation.ADD, Operation.SUB} and const == 0)
            or (self.operation in {Operation.MUL, Operation.DIV} and const == 1)
        ) and Term.CONST in keys:
            keys.remove(Term.CONST)

        for i, e in enumerate(keys):
            if isinstance(e, Term):
                term_str = str(e).strip()
                if len(term_str) > 0:
                    # Strip one outer pair of parentheses before re-wrapping, to avoid doubling them.
                    if term_str[0] == "(" and term_str[-1] == ")":
                        term_str = term_str.removeprefix("(").removesuffix(")")
                    output_string += (
                        f"({term_str}) " if self[e] == 1 else f"({_float_if_real(self[e])}) * ({term_str}) "
                    )
            elif isinstance(e, BaseVariable):
                if self._is_constant(e):
                    if self.operation in {Operation.ADD, Operation.SUB} and self[e] == 0:
                        continue
                    if self.operation in {Operation.MUL, Operation.DIV} and self[e] == 1:
                        continue
                    output_string += f"({_float_if_real(self[e])}) "
                elif (self.operation is Operation.MUL or self.operation is Operation.DIV) and _assert_real(self[e]) > 1:
                    # In a product the coefficient is an exponent.
                    output_string += f"({e}^{_float_if_real(self[e])}) "
                else:
                    output_string += f"{e} " if self[e] == 1 else f"({_float_if_real(self[e])}) * {e} "
            else:
                output_string += f"{e} "
            if i < len(keys) - 1:
                output_string += f"{self.operation.value} "

        return output_string.strip()

    __str__ = __repr__
1665
+
1666
def __getitem__(self, item: BaseVariable | Term) -> Number:
    """Return the value stored for ``item``.

    In additive terms this is the element's coefficient; in product terms it is
    the exponent applied to the element.
    """
    elements = self._elements
    return elements[item]

def __setitem__(self, key: BaseVariable | Term, item: Number) -> None:
    """Store ``item`` as the coefficient/exponent of ``key``, replacing any previous value."""
    elements = self._elements
    elements[key] = item

def __iter__(self) -> Iterator[BaseVariable | Term]:
    """Yield the term's elements (variables, nested terms and, if present, the constant key)."""
    for element in self._elements:
        yield element

def __contains__(self, item: BaseVariable | Term) -> bool:
    """Report whether ``item`` is one of the term's elements."""
    elements = self._elements
    return item in elements

# NOTE(review): a Term is an iterable, not an iterator. Aliasing ``__next__`` to the
# generator-based ``__iter__`` makes ``next(term)`` return a fresh generator rather
# than an element. Kept for backward compatibility; confirm no caller relies on it
# before removing.
__next__ = __iter__

def __len__(self) -> int:
    """Number of stored elements, counting the constant key when it is present."""
    elements = self._elements
    return len(elements)
1682
+
1683
def __add__(self, other: Number | BaseVariable | Term) -> Term:
    """Build ``self + other`` as an addition term.

    If ``self`` is already an addition term, its elements are reused directly
    (flattening the sum). Otherwise a copy of ``self`` becomes a single operand.
    NumPy scalars are converted to native Python numbers before being appended.

    Returns:
        Term: the sum, or ``NotImplemented`` for unsupported operand types.
    """
    if not isinstance(other, (Number, BaseVariable, Term)):
        return NotImplemented
    if self.operation == Operation.ADD:
        operands = self.to_list()
    else:
        operands = [copy.copy(self)]
    if isinstance(other, np.generic):
        other = other.item()
    operands.append(other)
    return Term(operands, Operation.ADD)

# ``+=`` is an alias of ``__add__`` and rebinds the name to its result.
__iadd__ = __add__

def __radd__(self, other: Number | BaseVariable | Term) -> Term:
    """Build ``other + self`` as an addition term, with ``other`` placed first.

    Returns:
        Term: the sum, or ``NotImplemented`` for unsupported operand types.
    """
    if not isinstance(other, (Number, BaseVariable, Term)):
        return NotImplemented
    if self.operation == Operation.ADD:
        operands = self.to_list()
    else:
        operands = [copy.copy(self)]
    if isinstance(other, np.generic):
        other = other.item()
    operands.insert(0, other)
    return Term(operands, Operation.ADD)
1705
+
1706
def __mul__(self, other: Number | BaseVariable | Term) -> Term:
    """Build ``self * other`` as a multiplication term and unfold its parentheses.

    If ``self`` is already a product, its factors are reused; otherwise a copy of
    ``self`` becomes a single factor. An empty factor list is replaced by ``[0]``
    (an empty term renders as ``"0"``). NumPy scalars are converted to native
    Python numbers.

    Returns:
        Term: the product, or ``NotImplemented`` for unsupported operand types.
    """
    if not isinstance(other, (Number, BaseVariable, Term)):
        return NotImplemented
    if self.operation == Operation.MUL:
        factors = self.to_list()
    else:
        factors = [copy.copy(self)]
    if not factors:
        factors = [0]
    if isinstance(other, np.generic):
        other = other.item()
    factors.append(other)
    product = Term(factors, Operation.MUL)
    return product._unfold_parentheses()

# ``*=`` is an alias of ``__mul__`` and rebinds the name to its result.
__imul__ = __mul__

def __rmul__(self, other: Number | BaseVariable | Term) -> Term:
    """Build ``other * self`` as a multiplication term, with ``other`` placed first.

    Returns:
        Term: the product, or ``NotImplemented`` for unsupported operand types.
    """
    if not isinstance(other, (Number, BaseVariable, Term)):
        return NotImplemented
    if self.operation == Operation.MUL:
        factors = self.to_list()
    else:
        factors = [copy.copy(self)]
    if not factors:
        factors = [0]
    if isinstance(other, np.generic):
        other = other.item()
    factors.insert(0, other)
    product = Term(factors, Operation.MUL)
    return product._unfold_parentheses()
1733
+
1734
def __neg__(self) -> Term:
    """Return ``-self``, computed as ``-1 * self``."""
    return -1 * self

def __sub__(self, other: Number | BaseVariable | Term) -> Term:
    """Build ``self - other`` as ``self + (-1 * other)``.

    Returns:
        Term: the difference, or ``NotImplemented`` for unsupported operand types.
    """
    if not isinstance(other, (Number, BaseVariable, Term)):
        return NotImplemented
    if isinstance(other, np.generic):
        other = other.item()
    negated_other = -1 * other
    return self + negated_other

def __rsub__(self, other: Number | BaseVariable | Term) -> Term:
    """Build ``other - self`` as ``(-1 * self) + other``.

    Returns:
        Term: the difference, or ``NotImplemented`` for unsupported operand types.
    """
    if not isinstance(other, (Number, BaseVariable, Term)):
        return NotImplemented
    negated_self = -1 * self
    return negated_self + other

# ``-=`` is an alias of ``__sub__`` and rebinds the name to its result.
__isub__ = __sub__
1752
+
1753
def __truediv__(self, other: Number) -> Term:
    """Divide the term by a number, implemented as multiplication by ``1 / other``.

    Raises:
        NotImplementedError: if ``other`` is not a number.
        ValueError: if ``other`` is zero.
    """
    if isinstance(other, Number):
        if other == 0:
            raise ValueError("Division by zero is not allowed")
        reciprocal = 1 / other
        return self * reciprocal
    raise NotImplementedError("Only division by numbers is currently supported")

# ``/=`` is an alias of ``__truediv__`` and rebinds the name to its result.
__itruediv__ = __truediv__

def __rtruediv__(self, other: Number | BaseVariable | Term) -> Term:
    """Dividing anything by a term is unsupported.

    Raises:
        NotSupportedOperation: always.
    """
    raise NotSupportedOperation("Only division by numbers is currently supported")

def __rfloordiv__(self, other: Number | BaseVariable | Term) -> Term:
    """Floor-dividing anything by a term is unsupported.

    Raises:
        NotSupportedOperation: always.
    """
    raise NotSupportedOperation("Only division by numbers is currently supported")
1770
+
1771
def __pow__(self, a: int) -> Term:
    """Raise the term to an integer power.

    * Addition terms are expanded by repeated distribution. ``a - 1`` times, the
      running result is multiplied by every element of the original term
      (weighted by that element's coefficient), and the products are summed.
    * Multiplication terms multiply every variable exponent by ``a`` and raise
      the constant factor to the power ``a``.

    Args:
        a (int): the exponent. It must be non-negative for addition terms.

    Returns:
        Term: a new term equal to ``self ** a``.

    Raises:
        ValueError: if ``a`` is not an integer, or if ``a`` is negative for an
            addition term.
        NotImplementedError: if the term is neither an addition nor a
            multiplication.
    """
    if not isinstance(a, int):
        raise ValueError(f"Only integer exponents are allowed, but provided {type(a)}")
    if self.operation == Operation.ADD:
        if a < 0:
            # A sum raised to a negative power is not a polynomial and cannot be expanded;
            # without this check the loop below would not run and ``self`` would be returned.
            raise ValueError(f"Negative exponents are not supported for addition terms, but provided {a}")
        if a == 0:
            # Anything to the power 0 is 1; the expansion loop below would otherwise
            # run zero times and return the term itself.
            return Term([1], Operation.ADD)
        out = copy.copy(self)
        for _ in range(a - 1):
            out = Term([out * copy.copy(element) * self[element] for element in self], Operation.ADD)
        return out
    if self.operation == Operation.MUL:
        out = copy.copy(self)
        for element in out:
            if element is Term.CONST:
                # The constant factor is a plain number: raise it to the power.
                out[element] **= a
            else:
                # Variables store exponents: (x^k)^a == x^(k*a).
                out[element] *= a
        return out
    raise NotImplementedError(
        "The power operation for terms that are not addition or multiplication is not supported."
    )
1793
+
1794
def __hash__(self) -> int:
    """Hash the order-independent set of (element, value) pairs together with the operation."""
    return hash((frozenset(self._elements.items()), self.operation))

def __eq__(self, other: object) -> bool:
    """Structural equality: same operation and the same element -> value mapping.

    Non-``Term`` operands always compare unequal.

    Returns:
        bool: ``True`` if both terms have the same operation and identical contents.
    """
    if not isinstance(other, Term):
        return False
    if self is other:
        return True
    # Equal hashes do not imply equal objects: comparing only hashes would make
    # two distinct terms equal on a collision. Compare exactly what ``__hash__``
    # digests instead, so equal terms still hash equally.
    return self.operation == other.operation and frozenset(self._elements.items()) == frozenset(
        other._elements.items()
    )
1801
+
1802
+
1803
@yaml.register_class
class ComparisonTerm:
    """A symbolic equality or inequality between two expressions (e.g. ``x + y > 0``, ``x >= 2``).

    Either side may be a real number, a ``BaseVariable`` (a decision variable such
    as ``x``) or a ``Term`` (a compound expression such as ``x + y``).

    On construction the comparison is normalised. ``lhs - rhs`` is computed and its
    constant part ``c`` is removed; the comparison is stored as
    ``(lhs - rhs - c) OP -c``, so the right-hand side holds only a constant.

    NOTE(review): the instance attributes ``_lhs``, ``_rhs`` and ``_operation`` are
    kept exactly as named because the class is registered with
    ``yaml.register_class`` and they are presumably what gets serialised. Confirm
    before renaming them.
    """

    def __init__(
        self,
        lhs: RealNumber | BaseVariable | Term,
        rhs: RealNumber | BaseVariable | Term,
        operation: ComparisonOperation,
    ) -> None:
        """Create a comparison between ``lhs`` and ``rhs``.

        Args:
            lhs (RealNumber | BaseVariable | Term): the left hand side of the comparison term.
            rhs (RealNumber | BaseVariable | Term): the right hand side of the comparison term.
            operation (ComparisonOperation): the comparison operation between the left and right hand sides.
        """
        difference = lhs - rhs
        if not isinstance(difference, Term):
            # Reachable whenever the subtraction does not produce a Term, for
            # example when both sides are plain numbers. Wrap the result so the
            # rest of the class can rely on Term methods.
            difference = Term([difference], Operation.ADD)
        if Term.CONST in difference:
            constant = -1 * difference.pop(Term.CONST)
        else:
            constant = 0
        self._lhs = difference
        self._rhs = Term([constant], Operation.ADD)
        self._operation = operation

    @property
    def operation(self) -> ComparisonOperation:
        """
        Returns:
            ComparisonOperation: the comparison operator relating the two sides.
        """
        return self._operation

    @property
    def lhs(self) -> Term:
        """
        Returns:
            Term: the normalised left hand side (every non-constant element).
        """
        return self._lhs

    @property
    def rhs(self) -> Term:
        """
        Returns:
            Term: the normalised right hand side (a constant-only term).
        """
        return self._rhs

    def variables(self) -> list[BaseVariable]:
        """Collect the distinct variables appearing on either side of the comparison.

        Returns:
            list[BaseVariable]: the unique variables, sorted by their ``label``.
        """
        unique = {*self._lhs.variables(), *self._rhs.variables()}
        return sorted(unique, key=lambda variable: variable.label)

    @property
    def degree(self) -> int:
        """
        Returns:
            int: the larger of the two sides' degrees.
        """
        rhs_degree = self.rhs.degree
        lhs_degree = self.lhs.degree
        return max(rhs_degree, lhs_degree)

    def to_list(self) -> list:
        """Export the comparison as a flat list of elements of ``lhs - rhs``.

        The right hand side is negated and appended after the left hand side, so
        the list describes ``lhs - rhs OP 0``. A log message recording this
        transformation is emitted on every call.

        Returns:
            list: the elements of the left hand side followed by those of ``-rhs``.
        """
        symbol = self.operation.value
        logger.info(
            "to_list(): The elements of output list assume the comparison term has been transformed "
            f"from (lhs {symbol} rhs) to (lhs - rhs {symbol} 0).",
        )
        return [*self.lhs.to_list(), *(-1 * self.rhs).to_list()]

    def to_binary(self) -> ComparisonTerm:
        """Encode every continuous variable on both sides into binary variables.

        The encoding used for each variable is the one defined on that variable.

        Returns:
            ComparisonTerm: a new comparison built from the binary-encoded sides.
        """
        binary_rhs = self.rhs.to_binary()
        binary_lhs = self.lhs.to_binary()
        return ComparisonTerm(lhs=binary_lhs, rhs=binary_rhs, operation=self.operation)

    def _apply_comparison_operation(self, v1: RealNumber, v2: RealNumber) -> bool:
        """Compare two numbers using this term's comparison operator.

        Args:
            v1 (Number): the left hand side value.
            v2 (Number): the right hand side value.

        Raises:
            ValueError: if the comparison term's operation is not one of the supported operators.

        Returns:
            bool: the outcome of ``v1 OP v2``.
        """
        operation = self.operation
        comparisons = (
            (ComparisonOperation.EQ, lambda a, b: a == b),
            (ComparisonOperation.GEQ, lambda a, b: a >= b),
            (ComparisonOperation.GT, lambda a, b: a > b),
            (ComparisonOperation.LEQ, lambda a, b: a <= b),
            (ComparisonOperation.LT, lambda a, b: a < b),
            (ComparisonOperation.NEQ, lambda a, b: a != b),
        )
        for candidate, compare in comparisons:
            if operation is candidate:
                return compare(v1, v2)
        raise ValueError(f"Unsupported Operation of type {self.operation.value}")

    def evaluate(self, var_values: Mapping[BaseVariable, RealNumber | list[int]]) -> bool:
        """Evaluate both sides for the given variable values and compare them.

        Args:
            var_values (Mapping[BaseVariable, list[int] | RealNumber]): the values of the variables in the comparison term.

        Returns:
            bool: whether the comparison holds for ``var_values``.

        Raises:
            ValueError: if either side evaluates to a complex number with a non-zero imaginary part.
        """

        def _real_part(value):
            # Complex results are accepted only when their imaginary part is exactly zero.
            if not isinstance(value, complex):
                return value
            if value.imag != 0:
                raise ValueError("evaluating inequality constraints with complex values is not allowed")
            return value.real

        lhs_value = self._lhs.evaluate(var_values)
        rhs_value = self._rhs.evaluate(var_values)
        return self._apply_comparison_operation(_real_part(lhs_value), _real_part(rhs_value))

    def __copy__(self) -> ComparisonTerm:
        """Return a new comparison built from shallow copies of both sides."""
        rhs_copy = copy.copy(self.rhs)
        lhs_copy = copy.copy(self.lhs)
        return ComparisonTerm(lhs=lhs_copy, rhs=rhs_copy, operation=self.operation)

    def __repr__(self) -> str:
        left = str(self.lhs).strip()
        right = str(self.rhs).strip()
        return f"{left} {self.operation.value} {right}"

    __str__ = __repr__

    def __bool__(self) -> bool:
        # A symbolic constraint has no truth value until it is evaluated.
        raise TypeError(
            "Symbolic Constraint Term objects do not have an inherent truth value. "
            "Use a method like .evaluate() to obtain a Boolean value."
        )