classiq 0.84.0__py3-none-any.whl → 0.86.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. classiq/applications/combinatorial_optimization/combinatorial_problem.py +24 -45
  2. classiq/evaluators/classical_expression.py +32 -15
  3. classiq/evaluators/qmod_annotated_expression.py +207 -0
  4. classiq/evaluators/qmod_expression_visitors/__init__.py +0 -0
  5. classiq/evaluators/qmod_expression_visitors/qmod_expression_bwc.py +134 -0
  6. classiq/evaluators/qmod_expression_visitors/qmod_expression_evaluator.py +232 -0
  7. classiq/evaluators/qmod_expression_visitors/qmod_expression_renamer.py +44 -0
  8. classiq/evaluators/qmod_expression_visitors/qmod_expression_simplifier.py +308 -0
  9. classiq/evaluators/qmod_node_evaluators/__init__.py +0 -0
  10. classiq/evaluators/qmod_node_evaluators/attribute_evaluation.py +112 -0
  11. classiq/evaluators/qmod_node_evaluators/binary_op_evaluation.py +132 -0
  12. classiq/evaluators/qmod_node_evaluators/bool_op_evaluation.py +70 -0
  13. classiq/evaluators/qmod_node_evaluators/classical_function_evaluation.py +311 -0
  14. classiq/evaluators/qmod_node_evaluators/compare_evaluation.py +107 -0
  15. classiq/evaluators/qmod_node_evaluators/constant_evaluation.py +67 -0
  16. classiq/evaluators/qmod_node_evaluators/list_evaluation.py +107 -0
  17. classiq/evaluators/qmod_node_evaluators/measurement_evaluation.py +25 -0
  18. classiq/evaluators/qmod_node_evaluators/name_evaluation.py +50 -0
  19. classiq/evaluators/qmod_node_evaluators/struct_instantiation_evaluation.py +66 -0
  20. classiq/evaluators/qmod_node_evaluators/subscript_evaluation.py +225 -0
  21. classiq/evaluators/qmod_node_evaluators/unary_op_evaluation.py +58 -0
  22. classiq/evaluators/qmod_node_evaluators/utils.py +80 -0
  23. classiq/execution/execution_session.py +53 -6
  24. classiq/interface/_version.py +1 -1
  25. classiq/interface/analyzer/analysis_params.py +1 -1
  26. classiq/interface/analyzer/result.py +1 -1
  27. classiq/interface/debug_info/debug_info.py +0 -4
  28. classiq/interface/executor/quantum_code.py +2 -2
  29. classiq/interface/generator/arith/arithmetic_expression_validator.py +5 -1
  30. classiq/interface/generator/arith/binary_ops.py +43 -51
  31. classiq/interface/generator/arith/number_utils.py +3 -2
  32. classiq/interface/generator/arith/register_user_input.py +15 -0
  33. classiq/interface/generator/arith/unary_ops.py +32 -28
  34. classiq/interface/generator/expressions/atomic_expression_functions.py +5 -0
  35. classiq/interface/generator/expressions/expression_types.py +2 -2
  36. classiq/interface/generator/expressions/proxies/classical/qmod_struct_instance.py +7 -0
  37. classiq/interface/generator/functions/builtins/internal_operators.py +2 -0
  38. classiq/interface/generator/functions/classical_function_declaration.py +0 -4
  39. classiq/interface/generator/functions/classical_type.py +0 -32
  40. classiq/interface/generator/functions/concrete_types.py +20 -0
  41. classiq/interface/generator/generated_circuit_data.py +7 -10
  42. classiq/interface/generator/quantum_program.py +6 -1
  43. classiq/interface/generator/synthesis_metadata/synthesis_execution_data.py +29 -0
  44. classiq/interface/ide/operation_registry.py +45 -0
  45. classiq/interface/ide/visual_model.py +84 -2
  46. classiq/interface/model/bounds.py +12 -2
  47. classiq/interface/model/quantum_expressions/arithmetic_operation.py +7 -4
  48. classiq/interface/model/quantum_type.py +67 -33
  49. classiq/interface/model/variable_declaration_statement.py +33 -6
  50. classiq/model_expansions/arithmetic.py +115 -0
  51. classiq/model_expansions/arithmetic_compute_result_attrs.py +71 -0
  52. classiq/model_expansions/atomic_expression_functions_defs.py +10 -6
  53. classiq/model_expansions/function_builder.py +4 -1
  54. classiq/model_expansions/generative_functions.py +15 -2
  55. classiq/model_expansions/interpreters/base_interpreter.py +7 -0
  56. classiq/model_expansions/interpreters/generative_interpreter.py +18 -1
  57. classiq/model_expansions/quantum_operations/assignment_result_processor.py +63 -21
  58. classiq/model_expansions/quantum_operations/bounds.py +7 -1
  59. classiq/model_expansions/quantum_operations/call_emitter.py +5 -2
  60. classiq/model_expansions/quantum_operations/classical_var_emitter.py +16 -0
  61. classiq/model_expansions/quantum_operations/variable_decleration.py +30 -10
  62. classiq/model_expansions/scope.py +7 -0
  63. classiq/model_expansions/scope_initialization.py +2 -0
  64. classiq/model_expansions/sympy_conversion/sympy_to_python.py +1 -1
  65. classiq/model_expansions/transformers/type_modifier_inference.py +5 -0
  66. classiq/model_expansions/transformers/var_splitter.py +1 -1
  67. classiq/model_expansions/visitors/boolean_expression_transformers.py +1 -1
  68. classiq/open_library/functions/__init__.py +0 -2
  69. classiq/open_library/functions/qaoa_penalty.py +8 -1
  70. classiq/open_library/functions/state_preparation.py +1 -32
  71. classiq/qmod/__init__.py +2 -0
  72. classiq/qmod/builtins/operations.py +66 -2
  73. classiq/qmod/classical_variable.py +74 -0
  74. classiq/qmod/declaration_inferrer.py +5 -3
  75. classiq/qmod/native/pretty_printer.py +18 -14
  76. classiq/qmod/pretty_print/pretty_printer.py +34 -15
  77. classiq/qmod/qfunc.py +2 -19
  78. classiq/qmod/qmod_variable.py +5 -8
  79. classiq/qmod/quantum_expandable.py +1 -1
  80. classiq/qmod/quantum_function.py +42 -2
  81. classiq/qmod/symbolic_type.py +2 -1
  82. classiq/qmod/write_qmod.py +3 -1
  83. {classiq-0.84.0.dist-info → classiq-0.86.0.dist-info}/METADATA +1 -1
  84. {classiq-0.84.0.dist-info → classiq-0.86.0.dist-info}/RECORD +86 -62
  85. classiq/interface/model/quantum_variable_declaration.py +0 -7
  86. /classiq/{model_expansions/sympy_conversion/arithmetics.py → evaluators/qmod_expression_visitors/sympy_wrappers.py} +0 -0
  87. {classiq-0.84.0.dist-info → classiq-0.86.0.dist-info}/WHEEL +0 -0
@@ -1,12 +1,11 @@
1
1
  import math
2
2
  import re
3
- from typing import Callable, Optional
3
+ from typing import Callable, Optional, cast
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
7
7
  import pyomo.core as pyo
8
8
  import scipy
9
- from tqdm import tqdm
10
9
 
11
10
  from classiq.interface.executor.execution_preferences import ExecutionPreferences
12
11
  from classiq.interface.executor.result import ExecutionDetails
@@ -44,22 +43,22 @@ class CombinatorialProblem:
44
43
  self.num_layers_ = num_layers
45
44
  self.model_ = None
46
45
  self.qprog_ = None
47
- self.es_ = None
48
- self.optimized_params_ = None
46
+ self._es: ExecutionSession | None = None
47
+ self._optimized_params: list[float] | None = None
49
48
  self.params_trace_: list[np.ndarray] = []
50
- self.cost_trace_: list = []
49
+ self._cost_trace: list[float] = []
51
50
 
52
51
  @property
53
- def cost_trace(self) -> list:
54
- return self.cost_trace_
52
+ def cost_trace(self) -> list[float]:
53
+ return self._cost_trace
55
54
 
56
55
  @property
57
56
  def params_trace(self) -> list[np.ndarray]:
58
57
  return self.params_trace_
59
58
 
60
59
  @property
61
- def optimized_params(self) -> list:
62
- return self.optimized_params_ # type:ignore[return-value]
60
+ def optimized_params(self) -> list[float]:
61
+ return self._optimized_params # type:ignore[return-value]
63
62
 
64
63
  def get_model(
65
64
  self,
@@ -100,22 +99,9 @@ class CombinatorialProblem:
100
99
  ) -> list[float]:
101
100
  if self.qprog_ is None:
102
101
  self.get_qprog()
103
- self.es_ = ExecutionSession(
104
- self.qprog_, execution_preferences # type:ignore[assignment,arg-type]
102
+ _es = ExecutionSession(
103
+ self.qprog_, execution_preferences # type:ignore[arg-type]
105
104
  )
106
- self.params_trace_ = []
107
- self.cost_trace_ = []
108
-
109
- def estimate_cost_wrapper(params: np.ndarray) -> float:
110
- cost = self.es_.estimate_cost( # type:ignore[attr-defined]
111
- lambda state: self.cost_func(state["v"]),
112
- {"params": params.tolist()},
113
- quantile=quantile,
114
- )
115
- self.cost_trace_.append(cost)
116
- self.params_trace_.append(params)
117
- return cost
118
-
119
105
  initial_params = (
120
106
  np.concatenate(
121
107
  (
@@ -125,31 +111,24 @@ class CombinatorialProblem:
125
111
  )
126
112
  * math.pi
127
113
  )
128
-
129
- with tqdm(total=maxiter, desc="Optimization Progress", leave=True) as pbar:
130
-
131
- def _minimze_callback(xk: np.ndarray) -> None:
132
- pbar.update(1) # increment progress bar
133
- self.optimized_params_ = xk.tolist() # save recent optimized value
134
-
135
- self.optimized_params_ = scipy.optimize.minimize(
136
- estimate_cost_wrapper,
137
- callback=_minimze_callback,
138
- x0=initial_params,
139
- method="COBYLA",
140
- options={"maxiter": maxiter},
141
- ).x.tolist()
142
-
143
- return self.optimized_params_ # type:ignore[return-value]
114
+ result = _es.minimize(
115
+ lambda v: self.cost_func(v), # type:ignore[arg-type]
116
+ {"params": initial_params.tolist()},
117
+ maxiter,
118
+ quantile,
119
+ )
120
+ _optimized_params = cast(list[float], result[-1][1]["params"])
121
+ self._optimized_params = _optimized_params
122
+ self._cost_trace = [cost for cost, _ in result]
123
+ self._es = _es
124
+ return _optimized_params
144
125
 
145
126
  def sample_uniform(self) -> pd.DataFrame:
146
127
  return self.sample([0] * self.num_layers_ * 2)
147
128
 
148
129
  def sample(self, params: list) -> pd.DataFrame:
149
- assert self.es_ is not None
150
- res = self.es_.sample( # type:ignore[unreachable]
151
- {"params": params}
152
- )
130
+ assert self._es is not None
131
+ res = self._es.sample({"params": params})
153
132
  parsed_result = [
154
133
  {
155
134
  "solution": {
@@ -157,7 +136,7 @@ class CombinatorialProblem:
157
136
  for key, value in sampled.state["v"].items()
158
137
  if not re.match(".*_slack_var_.*", key)
159
138
  },
160
- "probability": sampled.shots / res.num_shots,
139
+ "probability": sampled.shots / res.num_shots, # type:ignore[operator]
161
140
  "cost": self.cost_func(sampled.state["v"]),
162
141
  }
163
142
  for sampled in res.parsed_counts
@@ -11,26 +11,43 @@ from classiq.interface.generator.expressions.proxies.classical.any_classical_val
11
11
  from classiq.interface.model.handle_binding import HandleBinding
12
12
 
13
13
  from classiq.evaluators.expression_evaluator import evaluate
14
- from classiq.model_expansions.scope import Evaluated, QuantumSymbol, Scope
14
+ from classiq.model_expansions.scope import (
15
+ ClassicalSymbol,
16
+ Evaluated,
17
+ QuantumSymbol,
18
+ Scope,
19
+ )
15
20
 
16
21
 
17
22
  def evaluate_classical_expression(expr: Expression, scope: Scope) -> Evaluated:
18
23
  all_symbols = scope.items()
19
- locals_dict = {
20
- name: EvaluatedExpression(value=evaluated.value)
21
- for name, evaluated in all_symbols
22
- if isinstance(evaluated.value, get_args(ExpressionValue))
23
- } | {
24
- name: EvaluatedExpression(
25
- value=(
26
- evaluated.value.quantum_type.get_proxy(HandleBinding(name=name))
27
- if evaluated.value.quantum_type.is_evaluated
28
- else AnyClassicalValue(name)
24
+ locals_dict = (
25
+ {
26
+ name: EvaluatedExpression(value=evaluated.value)
27
+ for name, evaluated in all_symbols
28
+ if isinstance(evaluated.value, get_args(ExpressionValue))
29
+ }
30
+ | {
31
+ name: EvaluatedExpression(
32
+ value=(
33
+ evaluated.value.quantum_type.get_proxy(HandleBinding(name=name))
34
+ if evaluated.value.quantum_type.is_evaluated
35
+ else AnyClassicalValue(name)
36
+ )
37
+ )
38
+ for name, evaluated in all_symbols
39
+ if isinstance(evaluated.value, QuantumSymbol)
40
+ }
41
+ | {
42
+ name: EvaluatedExpression(
43
+ value=evaluated.value.classical_type.get_classical_proxy(
44
+ HandleBinding(name=name)
45
+ )
29
46
  )
30
- )
31
- for name, evaluated in all_symbols
32
- if isinstance(evaluated.value, QuantumSymbol)
33
- }
47
+ for name, evaluated in all_symbols
48
+ if isinstance(evaluated.value, ClassicalSymbol)
49
+ }
50
+ )
34
51
 
35
52
  ret = evaluate(expr, locals_dict)
36
53
  return Evaluated(value=ret.value)
@@ -0,0 +1,207 @@
1
+ import ast
2
+ from collections.abc import Sequence
3
+ from dataclasses import dataclass
4
+ from typing import Any, Union, cast
5
+
6
+ from classiq.interface.model.handle_binding import HandleBinding
7
+
8
+ from classiq.evaluators.qmod_node_evaluators.utils import QmodType, is_classical_type
9
+
10
# Annotation key: the ``id()`` of an ``ast.AST`` node of the annotated tree.
QmodExprNodeId = int
11
+
12
+
13
@dataclass(frozen=True)
class QuantumSubscriptAnnotation:
    """Immutable record of a subscript ``value[index]``.

    Both parts are stored by node id, as done by
    ``QmodAnnotatedExpression.set_quantum_subscript``.
    """

    value: QmodExprNodeId  # id of the subscripted sub-expression
    index: QmodExprNodeId  # id of the index sub-expression
17
+
18
+
19
@dataclass(frozen=True)
class QuantumTypeAttributeAnnotation:
    """Immutable record of an attribute lookup ``value.attr``.

    The owner expression is stored by node id, as done by
    ``QmodAnnotatedExpression.set_quantum_type_attr``.
    """

    value: QmodExprNodeId  # id of the sub-expression the attribute is read from
    attr: str  # name of the attribute being read
23
+
24
+
25
@dataclass(frozen=True)
class ConcatenationAnnotation:
    """Immutable record of a concatenation, listing its parts by node id.

    NOTE: because ``elements`` is a list, the generated ``__hash__`` raises
    ``TypeError`` even though the dataclass is frozen.
    """

    elements: list[QmodExprNodeId]  # ids of the concatenated sub-expressions, in order
28
+
29
+
30
class _ExprInliner(ast.NodeTransformer):
    """Build an inlined copy of an expression tree, ready for unparsing.

    A node that carries a value annotation becomes ``Name(str(value))``;
    otherwise a node that carries a variable annotation becomes
    ``Name(str(var))``. Every other node is rebuilt with its children visited.

    The input tree is never modified. ``NodeTransformer.generic_visit``
    rewrites fields in place, which would replace annotated subtrees of the
    caller's tree. That would leave the id-keyed annotations pointing at
    detached nodes.
    """

    def __init__(self, expr_val: "QmodAnnotatedExpression") -> None:
        self._expr_val = expr_val

    def visit(self, node: ast.AST) -> Any:
        """Return *node* inlined, or a rebuilt copy of it with inlined children."""
        if self._expr_val.has_value(node):
            return ast.Name(id=str(self._expr_val.get_value(node)))
        if self._expr_val.has_var(node):
            return ast.Name(id=str(self._expr_val.get_var(node)))
        # Rebuild the node rather than mutating it; annotation lookups above
        # are done on the original children, so their ids still match.
        new_fields: dict[str, Any] = {}
        for field_name, field_value in ast.iter_fields(node):
            if isinstance(field_value, ast.AST):
                new_fields[field_name] = self.visit(field_value)
            elif isinstance(field_value, list):
                new_fields[field_name] = [
                    self.visit(item) if isinstance(item, ast.AST) else item
                    for item in field_value
                ]
            else:
                new_fields[field_name] = field_value
        return ast.copy_location(type(node)(**new_fields), node)
40
+
41
+
42
class QmodAnnotatedExpression:
    """Side tables of annotations for the nodes of a parsed Qmod expression.

    Every annotation is keyed by ``id()`` of an ``ast.AST`` node (a
    ``QmodExprNodeId``). Each accessor accepts either the node itself or an
    already-computed id. Getters raise ``KeyError`` when the node carries no
    annotation of the requested kind.
    """

    def __init__(self, expr_ast: ast.AST) -> None:
        self.root = expr_ast
        # Only nodes registered through ``set_type`` are recorded here; this
        # is what allows ``get_node`` to map an id back to its node.
        self._node_mapping: dict[QmodExprNodeId, ast.AST] = {}
        self._values: dict[QmodExprNodeId, Any] = {}
        self._types: dict[QmodExprNodeId, QmodType] = {}
        # ``set_var`` routes a handle into one of these two tables according
        # to whether the node's recorded type is classical.
        self._classical_vars: dict[QmodExprNodeId, HandleBinding] = {}
        self._quantum_vars: dict[QmodExprNodeId, HandleBinding] = {}
        self._quantum_subscripts: dict[
            QmodExprNodeId, QuantumSubscriptAnnotation
        ] = {}
        self._quantum_type_attrs: dict[
            QmodExprNodeId, QuantumTypeAttributeAnnotation
        ] = {}
        self._concatenations: dict[QmodExprNodeId, ConcatenationAnnotation] = {}

    @staticmethod
    def _to_id(node: Union[ast.AST, QmodExprNodeId]) -> QmodExprNodeId:
        """Normalize a node-or-id argument to the id used as annotation key."""
        return id(node) if isinstance(node, ast.AST) else node

    def to_qmod_expr(self) -> str:
        """Unparse ``root``, inlining value- and variable-annotated nodes as names."""
        return ast.unparse(_ExprInliner(self).visit(self.root))

    def has_node(self, node_id: QmodExprNodeId) -> bool:
        """Return whether *node_id* was registered through ``set_type``."""
        return node_id in self._node_mapping

    def get_node(self, node_id: QmodExprNodeId) -> ast.AST:
        """Return the node registered under *node_id* by ``set_type``."""
        return self._node_mapping[node_id]

    def set_value(self, node: Union[ast.AST, QmodExprNodeId], value: Any) -> None:
        """Record the evaluated *value* of *node*."""
        self._values[self._to_id(node)] = value

    def get_value(self, node: Union[ast.AST, QmodExprNodeId]) -> Any:
        """Return the value recorded for *node*."""
        return self._values[self._to_id(node)]

    def has_value(self, node: Union[ast.AST, QmodExprNodeId]) -> bool:
        """Return whether a value was recorded for *node*."""
        return self._to_id(node) in self._values

    def set_type(
        self, node: Union[ast.AST, QmodExprNodeId], qmod_type: QmodType
    ) -> None:
        """Record the Qmod type of *node*.

        When *node* is an AST node (not a bare id) it is also registered so
        that ``get_node`` can retrieve it.
        """
        node_id = self._to_id(node)
        if isinstance(node, ast.AST):
            self._node_mapping[node_id] = node
        self._types[node_id] = qmod_type

    def get_type(self, node: Union[ast.AST, QmodExprNodeId]) -> QmodType:
        """Return the Qmod type recorded for *node*."""
        return self._types[self._to_id(node)]

    def set_var(self, node: Union[ast.AST, QmodExprNodeId], var: HandleBinding) -> None:
        """Record the variable handle (``var.collapse()``) that *node* refers to.

        The handle is stored as classical or quantum according to the type
        previously recorded for *node* via ``set_type``; a missing type
        raises ``KeyError``.
        """
        var = var.collapse()
        node_id = self._to_id(node)
        if is_classical_type(self.get_type(node_id)):
            self._classical_vars[node_id] = var
        else:
            self._quantum_vars[node_id] = var

    def get_var(self, node: Union[ast.AST, QmodExprNodeId]) -> HandleBinding:
        """Return the variable handle recorded for *node*, classical or quantum."""
        node_id = self._to_id(node)
        # Direct lookups instead of materializing a merged dict on every call.
        # If an id is somehow present in both tables, the quantum entry wins.
        if node_id in self._quantum_vars:
            return self._quantum_vars[node_id]
        return self._classical_vars[node_id]

    def has_var(self, node: Union[ast.AST, QmodExprNodeId]) -> bool:
        """Return whether any (classical or quantum) variable is recorded for *node*."""
        return self.has_classical_var(node) or self.has_quantum_var(node)

    def has_classical_var(self, node: Union[ast.AST, QmodExprNodeId]) -> bool:
        """Return whether a classical variable is recorded for *node*."""
        return self._to_id(node) in self._classical_vars

    def has_quantum_var(self, node: Union[ast.AST, QmodExprNodeId]) -> bool:
        """Return whether a quantum variable is recorded for *node*."""
        return self._to_id(node) in self._quantum_vars

    def remove_var(self, node: Union[ast.AST, QmodExprNodeId]) -> None:
        """Drop the variable recorded for *node*.

        Raises ``KeyError`` if *node* has no variable annotation.
        """
        node_id = self._to_id(node)
        if node_id in self._classical_vars:
            self._classical_vars.pop(node_id)
        else:
            self._quantum_vars.pop(node_id)

    def set_quantum_subscript(
        self,
        node: Union[ast.AST, QmodExprNodeId],
        value: Union[ast.AST, QmodExprNodeId],
        index: Union[ast.AST, QmodExprNodeId],
    ) -> None:
        """Mark *node* as the subscript ``value[index]`` (parts stored by id)."""
        self._quantum_subscripts[self._to_id(node)] = QuantumSubscriptAnnotation(
            value=self._to_id(value), index=self._to_id(index)
        )

    def has_quantum_subscript(self, node: Union[ast.AST, QmodExprNodeId]) -> bool:
        """Return whether *node* is annotated as a quantum subscript."""
        return self._to_id(node) in self._quantum_subscripts

    def get_quantum_subscripts(
        self,
    ) -> dict[QmodExprNodeId, QuantumSubscriptAnnotation]:
        """Return the live table of quantum-subscript annotations."""
        return self._quantum_subscripts

    def set_quantum_type_attr(
        self,
        node: Union[ast.AST, QmodExprNodeId],
        value: Union[ast.AST, QmodExprNodeId],
        attr: str,
    ) -> None:
        """Mark *node* as the attribute access ``value.attr`` (value stored by id)."""
        self._quantum_type_attrs[self._to_id(node)] = QuantumTypeAttributeAnnotation(
            value=self._to_id(value), attr=attr
        )

    def has_quantum_type_attribute(self, node: Union[ast.AST, QmodExprNodeId]) -> bool:
        """Return whether *node* is annotated as a quantum-type attribute access."""
        return self._to_id(node) in self._quantum_type_attrs

    def get_quantum_type_attributes(
        self,
    ) -> dict[QmodExprNodeId, QuantumTypeAttributeAnnotation]:
        """Return the live table of quantum-type attribute annotations."""
        return self._quantum_type_attrs

    def set_concatenation(
        self,
        node: Union[ast.AST, QmodExprNodeId],
        elements: Sequence[Union[ast.AST, QmodExprNodeId]],
    ) -> None:
        """Mark *node* as a concatenation of *elements* (stored by id, in order)."""
        self._concatenations[self._to_id(node)] = ConcatenationAnnotation(
            elements=[self._to_id(element) for element in elements]
        )

    def has_concatenation(self, node: Union[ast.AST, QmodExprNodeId]) -> bool:
        """Return whether *node* is annotated as a concatenation."""
        return self._to_id(node) in self._concatenations

    def get_concatenations(self) -> dict[QmodExprNodeId, ConcatenationAnnotation]:
        """Return the live table of concatenation annotations."""
        return self._concatenations

    def get_classical_vars(self) -> dict[QmodExprNodeId, HandleBinding]:
        """Return the live table of classical-variable annotations."""
        return self._classical_vars

    def get_quantum_vars(self) -> dict[QmodExprNodeId, HandleBinding]:
        """Return the live table of quantum-variable annotations."""
        return self._quantum_vars
@@ -0,0 +1,134 @@
1
+ import ast
2
+ from typing import Any, cast
3
+
4
+ import sympy
5
+
6
# Upper bound on the number of index values probed when rewriting a legacy
# ``Piecewise`` call into a list subscript.
MAX_PIECEWISE_LOOPS = 1_000
7
+
8
+
9
# FIXME: Remove with deprecation (CLS-3214)
class QmodExpressionBwc(ast.NodeTransformer):
    """Rewrite legacy function-call syntax into native Python AST nodes.

    Handled calls (after their arguments are visited):
    ``BitwiseNot(x)`` -> ``~x``; ``LShift``/``RShift``/``BitwiseOr``/
    ``BitwiseXor``/``BitwiseAnd``/``LogicalXor`` -> binary operators;
    ``Eq``/``Ne``/``Lt``/``Le``/``Gt``/``Ge`` -> comparisons;
    ``struct_literal(S, f=v)`` -> ``S(f=v)``; ``do_subscript(x, i)`` -> ``x[i]``;
    ``get_field(x, "f")`` -> ``x.f``; and a restricted ``Piecewise`` form
    -> ``[v0, v1, ...][i]``. A call whose shape does not match is returned
    unchanged.
    """

    # Two-argument legacy calls mapped onto ``ast.BinOp`` operators.
    _BINARY_OPS: dict[str, type[ast.operator]] = {
        "LShift": ast.LShift,
        "RShift": ast.RShift,
        "BitwiseOr": ast.BitOr,
        "BitwiseXor": ast.BitXor,
        "BitwiseAnd": ast.BitAnd,
        "LogicalXor": ast.BitXor,
    }
    # Two-argument legacy calls mapped onto single-operator ``ast.Compare``.
    _COMPARE_OPS: dict[str, type[ast.cmpop]] = {
        "Eq": ast.Eq,
        "Ne": ast.NotEq,
        "Lt": ast.Lt,
        "Le": ast.LtE,
        "Gt": ast.Gt,
        "Ge": ast.GtE,
    }

    def visit_Call(self, node: ast.Call) -> Any:
        node = cast(ast.Call, self.generic_visit(node))
        if not isinstance(node.func, ast.Name):
            return node
        func = node.func.id
        args = node.args
        num_args = len(args)
        num_kwargs = len(node.keywords)
        is_plain_binary = num_args == 2 and num_kwargs == 0

        if func == "BitwiseNot":
            if num_args != 1 or num_kwargs != 0:
                return node
            return ast.UnaryOp(op=ast.Invert(), operand=args[0])

        if func in self._BINARY_OPS:
            if not is_plain_binary:
                return node
            return ast.BinOp(
                left=args[0], op=self._BINARY_OPS[func](), right=args[1]
            )

        if func in self._COMPARE_OPS:
            if not is_plain_binary:
                return node
            return ast.Compare(
                left=args[0], ops=[self._COMPARE_OPS[func]()], comparators=[args[1]]
            )

        if func == "struct_literal":
            if num_args != 1:
                return node
            return ast.Call(func=args[0], args=[], keywords=node.keywords)

        if func == "do_subscript":
            if not is_plain_binary:
                return node
            return ast.Subscript(value=args[0], slice=args[1])

        if func == "get_field":
            if not is_plain_binary:
                return node
            attr = args[1]
            # Only a string constant can become an attribute name; anything
            # else would produce an Attribute node that cannot be unparsed.
            if not isinstance(attr, ast.Constant) or not isinstance(attr.value, str):
                return node
            return ast.Attribute(value=args[0], attr=attr.value)

        if func == "Piecewise":
            return self._convert_piecewise(node)

        return node

    def _convert_piecewise(self, node: ast.Call) -> Any:
        """Rewrite a legacy ``Piecewise(...)`` call into ``[v0, v1, ...][i]``.

        The first piece must be a 2-tuple whose condition is a single-op
        comparison (optionally the right operand of a binary op) with a name
        ``i`` on its left; the last piece must be a 2-tuple whose value is a
        constant or a unary op on a constant. The call is evaluated with sympy
        for ``i = 0, 1, ...`` and each result is collected until the fallback
        (last-piece) value is selected. If the shape does not match, or no
        index within ``MAX_PIECEWISE_LOOPS`` selects the fallback, *node* is
        returned unchanged.
        """
        args = node.args
        if not args:
            return node
        first_piece = args[0]
        if not isinstance(first_piece, ast.Tuple) or len(first_piece.elts) != 2:
            return node
        first_cond = first_piece.elts[1]
        if isinstance(first_cond, ast.BinOp):
            first_cond = first_cond.right
        if not isinstance(first_cond, ast.Compare) or len(first_cond.ops) != 1:
            return node
        index_var_node = first_cond.left
        if not isinstance(index_var_node, ast.Name):
            return node
        index_var = index_var_node.id
        last_piece = args[-1]
        if not isinstance(last_piece, ast.Tuple) or len(last_piece.elts) != 2:
            return node
        last_value = last_piece.elts[0]
        if not isinstance(last_value, ast.Constant) and (
            not isinstance(last_value, ast.UnaryOp)
            or not isinstance(last_value.operand, ast.Constant)
        ):
            return node

        # Swap the fallback value for a marker name only long enough to render
        # the probe source, then restore it so `node` is intact if we bail out.
        dummy_var_name = f"{index_var}_not_it"
        last_piece.elts[0] = ast.Name(id=dummy_var_name)
        try:
            probe_source = ast.unparse(node)
        finally:
            last_piece.elts[0] = last_value

        items: list = []
        for idx in range(MAX_PIECEWISE_LOOPS):
            item = sympy.sympify(probe_source, locals={index_var: idx})
            if str(item) == dummy_var_name:
                items.append(last_value)
                break
            items.append(ast.parse(str(item), mode="eval").body)
        else:
            # The fallback was never reached within the bound: give up.
            return node
        return ast.Subscript(value=ast.List(elts=items), slice=ast.Name(id=index_var))