classiq 0.55.0__py3-none-any.whl → 0.56.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. classiq/interface/_version.py +1 -1
  2. classiq/interface/debug_info/debug_info.py +11 -0
  3. classiq/interface/executor/result.py +0 -3
  4. classiq/interface/generator/visitor.py +13 -1
  5. classiq/interface/ide/visual_model.py +2 -0
  6. classiq/interface/interface_version.py +1 -1
  7. classiq/interface/model/handle_binding.py +28 -0
  8. classiq/interface/model/quantum_statement.py +3 -0
  9. classiq/model_expansions/capturing/mangling_utils.py +22 -0
  10. classiq/model_expansions/capturing/propagated_var_stack.py +36 -25
  11. classiq/model_expansions/closure.py +3 -1
  12. classiq/model_expansions/function_builder.py +9 -4
  13. classiq/model_expansions/interpreter.py +7 -1
  14. classiq/model_expansions/quantum_operations/control.py +35 -9
  15. classiq/model_expansions/quantum_operations/emitter.py +13 -3
  16. classiq/model_expansions/quantum_operations/expression_operation.py +55 -17
  17. classiq/model_expansions/quantum_operations/power.py +5 -0
  18. classiq/model_expansions/quantum_operations/quantum_assignment_operation.py +4 -11
  19. classiq/model_expansions/quantum_operations/repeat.py +5 -0
  20. classiq/model_expansions/scope_initialization.py +2 -2
  21. classiq/qmod/__init__.py +0 -2
  22. classiq/qmod/builtins/functions/arithmetic.py +0 -2
  23. classiq/qmod/builtins/functions/discrete_sine_cosine_transform.py +0 -12
  24. classiq/qmod/builtins/functions/exponentiation.py +0 -6
  25. classiq/qmod/builtins/functions/grover.py +0 -17
  26. classiq/qmod/builtins/functions/linear_pauli_rotation.py +0 -5
  27. classiq/qmod/builtins/functions/modular_exponentiation.py +0 -3
  28. classiq/qmod/builtins/functions/qaoa_penalty.py +0 -8
  29. classiq/qmod/builtins/functions/qft_functions.py +0 -3
  30. classiq/qmod/builtins/functions/qpe.py +0 -6
  31. classiq/qmod/builtins/functions/qsvt.py +0 -12
  32. classiq/qmod/builtins/functions/standard_gates.py +0 -88
  33. classiq/qmod/builtins/functions/state_preparation.py +7 -15
  34. classiq/qmod/builtins/functions/swap_test.py +0 -3
  35. classiq/qmod/builtins/operations.py +39 -0
  36. classiq/qmod/qfunc.py +33 -1
  37. classiq/qmod/qmod_constant.py +31 -3
  38. {classiq-0.55.0.dist-info → classiq-0.56.1.dist-info}/METADATA +2 -3
  39. {classiq-0.55.0.dist-info → classiq-0.56.1.dist-info}/RECORD +40 -41
  40. classiq/qmod/synthesize_separately.py +0 -15
  41. {classiq-0.55.0.dist-info → classiq-0.56.1.dist-info}/WHEEL +0 -0
@@ -3,5 +3,5 @@ from packaging.version import Version
3
3
  # This file was generated automatically
4
4
  # Please don't track in version control (DONTTRACK)
5
5
 
6
- SEMVER_VERSION = '0.55.0'
6
+ SEMVER_VERSION = '0.56.1'
7
7
  VERSION = str(Version(SEMVER_VERSION))
@@ -5,6 +5,7 @@ from uuid import UUID
5
5
 
6
6
  from pydantic import BaseModel, Field
7
7
 
8
+ from classiq.interface.enum_utils import StrEnum
8
9
  from classiq.interface.generator.generated_circuit_data import (
9
10
  FunctionDebugInfoInterface,
10
11
  OperationLevel,
@@ -13,11 +14,21 @@ from classiq.interface.generator.generated_circuit_data import (
13
14
  ParameterValue = Union[float, int, str, None]
14
15
 
15
16
 
17
+ class StatementType(StrEnum):
18
+ CONTROL = "control"
19
+ POWER = "power"
20
+ INVERT = "invert"
21
+ WITHIN_APPLY = "within_apply"
22
+ ASSIGNMENT = "assignment"
23
+ REPEAT = "repeat"
24
+
25
+
16
26
  class FunctionDebugInfo(BaseModel):
17
27
  name: str
18
28
  # Parameters describe classical parameters passed to function
19
29
  parameters: dict[str, str]
20
30
  level: OperationLevel
31
+ statement_type: Union[StatementType, None] = None
21
32
  is_allocate_or_free: bool = Field(default=False)
22
33
  is_inverse: bool = Field(default=False)
23
34
  port_to_passed_variable_map: dict[str, str] = Field(default_factory=dict)
@@ -315,9 +315,6 @@ class EstimationMetadata(BaseModel, extra="allow"):
315
315
 
316
316
  class EstimationResult(BaseModel, QmodPyObject):
317
317
  value: Complex = pydantic.Field(..., description="Estimation for the operator")
318
- variance: Optional[Complex] = pydantic.Field(
319
- description="Variance of the estimation", default=None
320
- )
321
318
  metadata: EstimationMetadata = pydantic.Field(
322
319
  ..., description="Metadata for the estimation"
323
320
  )
@@ -1,8 +1,9 @@
1
- from collections import abc
1
+ from collections import abc, defaultdict
2
2
  from collections.abc import Collection, Mapping, Sequence
3
3
  from typing import (
4
4
  TYPE_CHECKING,
5
5
  Any,
6
+ Callable,
6
7
  Optional,
7
8
  TypeVar,
8
9
  Union,
@@ -84,6 +85,17 @@ class Transformer(Visitor):
84
85
  def visit_dict(self, node: dict[Key, NodeType]) -> dict[Key, RetType]:
85
86
  return {key: self.visit(value) for key, value in node.items()}
86
87
 
88
+ def visit_defaultdict(
89
+ self, node: defaultdict[Key, NodeType]
90
+ ) -> defaultdict[Key, RetType]:
91
+ new_default_factory: Callable[[], RetType] | None = None
92
+ if (default_factory := node.default_factory) is not None:
93
+
94
+ def new_default_factory() -> RetType:
95
+ return self.visit(default_factory()) # type: ignore[misc]
96
+
97
+ return defaultdict(new_default_factory, self.visit_dict(node))
98
+
87
99
  def visit_tuple(self, node: tuple[NodeType, ...]) -> tuple[RetType, ...]:
88
100
  return tuple(self.visit(value) for value in node)
89
101
 
@@ -105,6 +105,7 @@ class AtomicGate(StrEnum):
105
105
 
106
106
  class Operation(pydantic.BaseModel):
107
107
  name: str
108
+ qasm_name: str = pydantic.Field(default="")
108
109
  details: str = pydantic.Field(default="")
109
110
  children: list["Operation"]
110
111
  operation_data: Optional[OperationData] = None
@@ -120,6 +121,7 @@ class Operation(pydantic.BaseModel):
120
121
  gate: AtomicGate = pydantic.Field(
121
122
  default=AtomicGate.UNKNOWN, description="Gate type"
122
123
  )
124
+ is_daggered: bool = pydantic.Field(default=False)
123
125
 
124
126
 
125
127
  class ProgramVisualModel(VersionedModel):
@@ -1 +1 @@
1
- INTERFACE_VERSION = "4"
1
+ INTERFACE_VERSION = "5"
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Union
4
4
 
5
5
  import pydantic
6
6
  from pydantic import ConfigDict, Field
7
+ from typing_extensions import Self
7
8
 
8
9
  from classiq.interface.ast_node import ASTNode
9
10
  from classiq.interface.generator.expressions.expression import Expression
@@ -46,6 +47,16 @@ class HandleBinding(ASTNode):
46
47
  for self_prefix, other_prefix in zip(self_prefixes, other_prefixes)
47
48
  )
48
49
 
50
+ def rename(self, name: str) -> Self:
51
+ return self.model_copy(update=dict(name=name))
52
+
53
+ def replace_prefix(
54
+ self, prefix: "HandleBinding", replacement: "HandleBinding"
55
+ ) -> "HandleBinding":
56
+ if self == prefix:
57
+ return replacement
58
+ return self
59
+
49
60
 
50
61
  class NestedHandleBinding(HandleBinding):
51
62
  base_handle: "ConcreteHandleBinding"
@@ -70,6 +81,23 @@ class NestedHandleBinding(HandleBinding):
70
81
  def prefixes(self) -> Sequence["HandleBinding"]:
71
82
  return list(chain.from_iterable([self.base_handle.prefixes(), [self]]))
72
83
 
84
+ def rename(self, name: str) -> Self:
85
+ return self.model_copy(
86
+ update=dict(name=name, base_handle=self.base_handle.rename(name))
87
+ )
88
+
89
+ def replace_prefix(
90
+ self, prefix: HandleBinding, replacement: HandleBinding
91
+ ) -> HandleBinding:
92
+ if self == prefix:
93
+ return replacement
94
+ new_base_handle = self.base_handle.replace_prefix(prefix, replacement)
95
+ if new_base_handle is not self.base_handle:
96
+ return self.model_copy(
97
+ update=dict(name=new_base_handle.name, base_handle=new_base_handle)
98
+ )
99
+ return self
100
+
73
101
 
74
102
  class SubscriptHandleBinding(NestedHandleBinding):
75
103
  index: Expression
@@ -86,3 +86,6 @@ class QuantumOperation(QuantumStatement):
86
86
 
87
87
  def is_generative(self) -> bool:
88
88
  return len(self._generative_blocks) > 0
89
+
90
+ def clear_generative_blocks(self) -> None:
91
+ self._generative_blocks.clear()
@@ -1,11 +1,13 @@
1
1
  import re
2
2
 
3
3
  from classiq.interface.generator.compiler_keywords import CAPTURE_SUFFIX
4
+ from classiq.interface.model.handle_binding import HANDLE_ID_SEPARATOR, HandleBinding
4
5
 
5
6
  IDENTIFIER_PATTERN = r"[a-zA-Z_][a-zA-Z0-9_]*"
6
7
  CAPTURE_PATTERN = re.compile(
7
8
  rf"({IDENTIFIER_PATTERN}){CAPTURE_SUFFIX}{IDENTIFIER_PATTERN}__"
8
9
  )
10
+ ARRAY_CAST_SUFFIX = HANDLE_ID_SEPARATOR + "array_cast"
9
11
 
10
12
 
11
13
  def mangle_captured_var_name(var_name: str, defining_function: str) -> str:
@@ -15,3 +17,23 @@ def mangle_captured_var_name(var_name: str, defining_function: str) -> str:
15
17
  def demangle_name(name: str) -> str:
16
18
  match = re.match(CAPTURE_PATTERN, name)
17
19
  return match.group(1) if match else name
20
+
21
+
22
+ def demangle_handle(handle: HandleBinding) -> HandleBinding:
23
+ name = handle.name
24
+ if HANDLE_ID_SEPARATOR not in name:
25
+ return handle
26
+ if ARRAY_CAST_SUFFIX in name:
27
+ return HandleBinding(name=name.split(ARRAY_CAST_SUFFIX)[0])
28
+ name = re.sub(r"_\d+$", "", name)
29
+ name_parts = name.split(HANDLE_ID_SEPARATOR)
30
+ new_name = name_parts[0]
31
+ for part in name_parts[1:]:
32
+ if re.fullmatch(r"\d+", part):
33
+ new_name += f"[{part}]"
34
+ elif re.fullmatch(r"\d+_\d+", part):
35
+ part_left, part_right = part.split("_")
36
+ new_name += f"[{part_left}:{part_right}]"
37
+ else:
38
+ new_name += f".{part}"
39
+ return handle.rename(new_name)
@@ -9,12 +9,15 @@ from classiq.interface.exceptions import (
9
9
  from classiq.interface.generator.functions.port_declaration import (
10
10
  PortDeclarationDirection,
11
11
  )
12
- from classiq.interface.model.handle_binding import HANDLE_ID_SEPARATOR, HandleBinding
12
+ from classiq.interface.model.handle_binding import HandleBinding
13
13
  from classiq.interface.model.port_declaration import PortDeclaration
14
14
  from classiq.interface.model.quantum_function_call import ArgValue
15
15
  from classiq.interface.model.quantum_statement import QuantumOperation
16
16
 
17
- from classiq.model_expansions.capturing.mangling_utils import mangle_captured_var_name
17
+ from classiq.model_expansions.capturing.mangling_utils import (
18
+ demangle_handle,
19
+ mangle_captured_var_name,
20
+ )
18
21
  from classiq.model_expansions.closure import FunctionClosure, GenerativeFunctionClosure
19
22
  from classiq.model_expansions.function_builder import OperationBuilder
20
23
  from classiq.model_expansions.scope import QuantumSymbol, Scope
@@ -25,9 +28,12 @@ class PropagatedVariable:
25
28
  symbol: QuantumSymbol
26
29
  direction: PortDeclarationDirection
27
30
  defining_function: str
31
+ handle: HandleBinding
28
32
 
29
33
  @property
30
34
  def name(self) -> str:
35
+ name = self.symbol.handle.name
36
+ assert name == self.handle.name
31
37
  return self.symbol.handle.name
32
38
 
33
39
 
@@ -99,21 +105,22 @@ class PropagatedVarStack:
99
105
  direction: PortDeclarationDirection,
100
106
  ) -> dict[PropagatedVariable, None]:
101
107
  return {
102
- self._get_captured_var_with_direction(var.name, direction): None
108
+ self._get_captured_var_with_direction(var, direction): None
103
109
  for var in variables
104
110
  if self._is_captured(var.name)
105
111
  }
106
112
 
107
113
  def _get_captured_var_with_direction(
108
- self, var_name: str, direction: PortDeclarationDirection
114
+ self, var_handle: HandleBinding, direction: PortDeclarationDirection
109
115
  ) -> PropagatedVariable:
110
- defining_function = self._current_scope[var_name].defining_function
116
+ defining_function = self._current_scope[var_handle.name].defining_function
111
117
  if defining_function is None:
112
118
  raise ClassiqInternalExpansionError
113
119
  return PropagatedVariable(
114
- symbol=self._current_scope[var_name].as_type(QuantumSymbol),
120
+ symbol=self._current_scope[var_handle.name].as_type(QuantumSymbol),
115
121
  direction=direction,
116
122
  defining_function=defining_function.name,
123
+ handle=var_handle,
117
124
  )
118
125
 
119
126
  def _is_captured(self, var_name: str) -> bool:
@@ -133,15 +140,16 @@ class PropagatedVarStack:
133
140
  for var in self._stack[-1]
134
141
  )
135
142
 
136
- def get_propagated_variables(self) -> list[HandleBinding]:
137
- propagated_var_names: list[str] = [
138
- self._get_propagated_var_name(var) for var in self._stack[-1]
139
- ]
140
- return [
141
- HandleBinding(name=name) for name in dict.fromkeys(propagated_var_names)
142
- ]
143
+ def get_propagated_variables(self, flatten: bool) -> list[HandleBinding]:
144
+ return list(
145
+ dict.fromkeys(
146
+ [self._get_propagated_handle(var, flatten) for var in self._stack[-1]]
147
+ )
148
+ )
143
149
 
144
- def _get_propagated_var_name(self, var: PropagatedVariable) -> str:
150
+ def _get_propagated_handle(
151
+ self, var: PropagatedVariable, flatten: bool
152
+ ) -> HandleBinding:
145
153
  if (
146
154
  var.defining_function == self._builder.current_function.name
147
155
  or not isinstance(
@@ -155,7 +163,9 @@ class PropagatedVarStack:
155
163
  else:
156
164
  handle_name = mangle_captured_var_name(var.name, var.defining_function)
157
165
  self._to_mangle[var] = handle_name
158
- return handle_name
166
+ if flatten:
167
+ return HandleBinding(name=handle_name)
168
+ return var.handle.rename(handle_name)
159
169
 
160
170
  def _no_name_conflict(self, var: PropagatedVariable) -> bool:
161
171
  return var.name not in self._builder.current_function.colliding_variables
@@ -166,18 +176,19 @@ def validate_args_are_not_propagated(
166
176
  ) -> None:
167
177
  if not captured_vars:
168
178
  return
169
- captured_var_names = {var.name for var in captured_vars}
170
- arg_names = {
171
- demangle_suffixes(arg.name) for arg in args if isinstance(arg, HandleBinding)
179
+ captured_handles = {demangle_handle(handle) for handle in captured_vars}
180
+ arg_handles = {
181
+ demangle_handle(arg) for arg in args if isinstance(arg, HandleBinding)
172
182
  }
173
- if not captured_var_names.isdisjoint(arg_names):
174
- vars_msg = f"Explicitly passed variables: {arg_names}, captured variables: {captured_var_names}"
183
+ if any(
184
+ arg_handle.overlaps(captured_handle)
185
+ for arg_handle in arg_handles
186
+ for captured_handle in captured_handles
187
+ ):
188
+ captured_handles_str = {str(handle) for handle in captured_handles}
189
+ arg_handles_str = {str(handle) for handle in arg_handles}
190
+ vars_msg = f"Explicitly passed variables: {arg_handles_str}, captured variables: {captured_handles_str}"
175
191
  raise ClassiqExpansionError(
176
192
  f"Cannot capture variables that are explicitly passed as arguments. "
177
193
  f"{vars_msg}"
178
194
  )
179
-
180
-
181
- # TODO this is not a good long-term solution
182
- def demangle_suffixes(name: str) -> str:
183
- return name.split(HANDLE_ID_SEPARATOR)[0]
@@ -130,8 +130,10 @@ class FunctionClosure(Closure):
130
130
  self, declaration: NamedParamsQuantumFunctionDeclaration
131
131
  ) -> Self:
132
132
  fields: dict = self.__dict__ | {
133
- "positional_arg_declarations": declaration.positional_arg_declarations
133
+ "name": declaration.name,
134
+ "positional_arg_declarations": declaration.positional_arg_declarations,
134
135
  }
136
+ fields.pop("colliding_variables", 0)
135
137
  return type(self)(**fields)
136
138
 
137
139
 
@@ -30,6 +30,7 @@ from classiq.interface.model.quantum_statement import QuantumStatement
30
30
  from classiq.model_expansions.capturing.captured_var_manager import update_captured_vars
31
31
  from classiq.model_expansions.capturing.mangling_utils import demangle_name
32
32
  from classiq.model_expansions.closure import Closure, FunctionClosure
33
+ from classiq.model_expansions.scope import Scope
33
34
 
34
35
  ClosureType = TypeVar("ClosureType", bound=Closure)
35
36
 
@@ -77,10 +78,10 @@ class FunctionContext(OperationContext[FunctionClosure]):
77
78
 
78
79
 
79
80
  class OperationBuilder:
80
- def __init__(self) -> None:
81
+ def __init__(self, functions_scope: Scope) -> None:
81
82
  self._operations: list[OperationContext] = []
82
83
  self._blocks: list[str] = []
83
- self._counter = 0
84
+ self._functions_scope = functions_scope
84
85
 
85
86
  @property
86
87
  def current_operation(self) -> Closure:
@@ -158,8 +159,12 @@ class OperationBuilder:
158
159
  ) -> NativeFunctionDefinition:
159
160
  name = function_context.name
160
161
  if name != MAIN_FUNCTION_NAME:
161
- name = f"{name}_{LAMBDA_KEYWORD + '_0_0_' if function_context.is_lambda else ''}{EXPANDED_KEYWORD}_{self._counter}"
162
- self._counter += 1
162
+ idx = 0
163
+ new_name = name
164
+ while idx == 0 or new_name in self._functions_scope:
165
+ new_name = f"{name}_{LAMBDA_KEYWORD + '_0_0_' if function_context.is_lambda else ''}{EXPANDED_KEYWORD}_{idx}"
166
+ idx += 1
167
+ name = new_name
163
168
 
164
169
  new_parameters: list[PortDeclaration] = [
165
170
  param
@@ -11,6 +11,7 @@ from classiq.interface.exceptions import (
11
11
  ClassiqExpansionError,
12
12
  ClassiqInternalExpansionError,
13
13
  )
14
+ from classiq.interface.generator.constant import Constant
14
15
  from classiq.interface.generator.expressions.expression import Expression
15
16
  from classiq.interface.generator.types.compilation_metadata import CompilationMetadata
16
17
  from classiq.interface.model.bind_operation import BindOperation
@@ -81,6 +82,7 @@ from classiq.model_expansions.quantum_operations import (
81
82
  from classiq.model_expansions.quantum_operations.phase import PhaseEmitter
82
83
  from classiq.model_expansions.scope import Evaluated, QuantumSymbol, Scope
83
84
  from classiq.model_expansions.scope_initialization import (
85
+ add_constants_to_scope,
84
86
  add_entry_point_params_to_scope,
85
87
  get_main_renamer,
86
88
  init_top_level_scope,
@@ -107,7 +109,8 @@ class Interpreter:
107
109
  self._is_frontend = is_frontend
108
110
  self._model = model
109
111
  self._current_scope = Scope()
110
- self._builder = OperationBuilder()
112
+ self._top_level_scope = self._current_scope
113
+ self._builder = OperationBuilder(self._top_level_scope)
111
114
  self._expanded_functions: dict[str, NativeFunctionDefinition] = {}
112
115
  self._propagated_var_stack = PropagatedVarStack(
113
116
  self._current_scope, self._builder
@@ -384,3 +387,6 @@ class Interpreter:
384
387
  + [gen_func.func_decl for gen_func in self._generative_functions]
385
388
  + list(self._expanded_functions.values())
386
389
  )
390
+
391
+ def add_constant(self, constant: Constant) -> None:
392
+ add_constants_to_scope([constant], self._top_level_scope)
@@ -17,19 +17,24 @@ from classiq.interface.generator.functions.builtins.internal_operators import (
17
17
  )
18
18
  from classiq.interface.model.bind_operation import BindOperation
19
19
  from classiq.interface.model.control import Control
20
- from classiq.interface.model.handle_binding import HANDLE_ID_SEPARATOR, HandleBinding
20
+ from classiq.interface.model.handle_binding import HandleBinding
21
21
  from classiq.interface.model.quantum_expressions.arithmetic_operation import (
22
22
  ArithmeticOperation,
23
23
  ArithmeticOperationKind,
24
24
  )
25
25
  from classiq.interface.model.quantum_function_call import QuantumFunctionCall
26
- from classiq.interface.model.quantum_type import QuantumBit, QuantumBitvector
26
+ from classiq.interface.model.quantum_type import (
27
+ QuantumBit,
28
+ QuantumBitvector,
29
+ QuantumType,
30
+ )
27
31
  from classiq.interface.model.statement_block import ConcreteQuantumStatement
28
32
  from classiq.interface.model.variable_declaration_statement import (
29
33
  VariableDeclarationStatement,
30
34
  )
31
35
  from classiq.interface.model.within_apply_operation import WithinApply
32
36
 
37
+ from classiq.model_expansions.capturing.mangling_utils import ARRAY_CAST_SUFFIX
33
38
  from classiq.model_expansions.capturing.propagated_var_stack import (
34
39
  validate_args_are_not_propagated,
35
40
  )
@@ -44,8 +49,6 @@ from classiq.model_expansions.quantum_operations.expression_operation import (
44
49
  from classiq.model_expansions.scope import Scope
45
50
  from classiq.qmod.builtins.functions.standard_gates import X
46
51
 
47
- ARRAY_CAST_SUFFIX = HANDLE_ID_SEPARATOR + "array_cast"
48
-
49
52
 
50
53
  class ControlEmitter(ExpressionOperationEmitter[Control]):
51
54
  def emit(self, control: Control, /) -> None:
@@ -53,6 +56,9 @@ class ControlEmitter(ExpressionOperationEmitter[Control]):
53
56
 
54
57
  arrays_with_subscript = self._get_symbols_to_split(condition)
55
58
  if len(arrays_with_subscript) > 0:
59
+ if control.is_generative():
60
+ with self._propagated_var_stack.capture_variables(control):
61
+ control = self._expand_generative_control(control)
56
62
  self._emit_with_split(control, condition, arrays_with_subscript)
57
63
  return
58
64
 
@@ -86,16 +92,26 @@ class ControlEmitter(ExpressionOperationEmitter[Control]):
86
92
  )
87
93
  return Expression(expr=f"{lhs} == 1")
88
94
 
95
+ def _expand_generative_control(self, control: Control) -> Control:
96
+ block_names = ["body"]
97
+ if control.has_generative_block("else_block"):
98
+ block_names += ["else_block"]
99
+ context = self._register_generative_context(
100
+ control, CONTROL_OPERATOR_NAME, block_names
101
+ )
102
+ new_blocks = {"body": context.statements("body")}
103
+ if "else_block" in block_names:
104
+ new_blocks["else_block"] = context.statements("else_block")
105
+ return control.model_copy(update=new_blocks)
106
+
89
107
  def _emit_canonical_control(self, control: Control) -> None:
90
108
  # canonical means control(q, body) where q is a single quantum variable
91
- control = self._evaluate_types_in_expression(control, control.expression)
92
109
  with self._propagated_var_stack.capture_variables(control):
93
110
  self._emit_propagated(control)
94
111
 
95
112
  def _emit_propagated(self, control: Control) -> None:
96
113
  if control.is_generative():
97
- context = self._register_generative_context(control, CONTROL_OPERATOR_NAME)
98
- control = control.model_copy(update={"body": context.statements("body")})
114
+ control = self._expand_generative_control(control)
99
115
 
100
116
  if self._should_wrap_control(control):
101
117
  self._emit_wrapped(control)
@@ -121,7 +137,7 @@ class ControlEmitter(ExpressionOperationEmitter[Control]):
121
137
  context = self._expand_operation(control_operation)
122
138
  validate_args_are_not_propagated(
123
139
  control.var_handles,
124
- self._propagated_var_stack.get_propagated_variables(),
140
+ self._propagated_var_stack.get_propagated_variables(flatten=False),
125
141
  )
126
142
  self._update_control_state(control)
127
143
  self._builder.emit_statement(
@@ -134,7 +150,7 @@ class ControlEmitter(ExpressionOperationEmitter[Control]):
134
150
  )
135
151
  validate_args_are_not_propagated(
136
152
  control.var_handles,
137
- self._propagated_var_stack.get_propagated_variables(),
153
+ self._propagated_var_stack.get_propagated_variables(flatten=False),
138
154
  )
139
155
  self._update_control_state(control)
140
156
  self._builder.emit_statement(
@@ -197,6 +213,7 @@ class ControlEmitter(ExpressionOperationEmitter[Control]):
197
213
  action=[control_else_inner],
198
214
  )
199
215
  if control_else_inner.is_generative():
216
+ control_then.remove_generative_block("else_block")
200
217
  control_else_inner.set_generative_block(
201
218
  "body", control_else_inner.get_generative_block("else_block")
202
219
  )
@@ -290,6 +307,15 @@ class ControlEmitter(ExpressionOperationEmitter[Control]):
290
307
  if isinstance(condition_val, QmodQNumProxy):
291
308
  raise ClassiqExpansionError(_condition_err_msg(condition_val))
292
309
 
310
+ def _get_updated_op_split_symbols(
311
+ self, op: Control, symbol_mapping: dict[HandleBinding, tuple[str, QuantumType]]
312
+ ) -> Control:
313
+ new_body = self._rewrite(op.body, symbol_mapping)
314
+ new_else = None
315
+ if op.else_block is not None:
316
+ new_else = self._rewrite(op.else_block, symbol_mapping)
317
+ return op.model_copy(update=dict(body=new_body, else_block=new_else))
318
+
293
319
 
294
320
  def _condition_err_msg(condition_val: ExpressionValue) -> str:
295
321
  return (
@@ -104,6 +104,10 @@ class Emitter(Generic[QuantumStatementT]):
104
104
  def _current_scope(self) -> Scope:
105
105
  return self._interpreter._current_scope
106
106
 
107
+ @property
108
+ def _top_level_scope(self) -> Scope:
109
+ return self._interpreter._top_level_scope
110
+
107
111
  @property
108
112
  def _expanded_functions(self) -> dict[str, NativeFunctionDefinition]:
109
113
  return self._interpreter._expanded_functions
@@ -156,6 +160,10 @@ class Emitter(Generic[QuantumStatementT]):
156
160
  if function_def is None:
157
161
  function_def = self._builder.create_definition(function_context)
158
162
  self._expanded_functions[closure_id] = function_def
163
+ self._top_level_scope[function_def.name] = Evaluated(
164
+ value=function_context.closure.with_new_declaration(function_def)
165
+ )
166
+ new_declaration = function_def
159
167
  new_function_name = function_def.name
160
168
  compilation_metadata = self._functions_compilation_metadata.get(
161
169
  function.name
@@ -194,8 +202,7 @@ class Emitter(Generic[QuantumStatementT]):
194
202
  is_allocate_or_free=is_allocate_or_free,
195
203
  port_to_passed_variable_map=port_to_passed_variable_map,
196
204
  )
197
- if is_atomic:
198
- new_call.set_func_decl(new_declaration)
205
+ new_call.set_func_decl(new_declaration)
199
206
  return new_call
200
207
 
201
208
  @staticmethod
@@ -236,7 +243,9 @@ class Emitter(Generic[QuantumStatementT]):
236
243
  arg.emit() for arg in evaluated_args if isinstance(arg.value, QuantumSymbol)
237
244
  ]
238
245
 
239
- propagated_variables = self._propagated_var_stack.get_propagated_variables()
246
+ propagated_variables = self._propagated_var_stack.get_propagated_variables(
247
+ flatten=True
248
+ )
240
249
  validate_args_are_not_propagated(positional_args, propagated_variables)
241
250
  positional_args.extend(propagated_variables)
242
251
 
@@ -295,6 +304,7 @@ class Emitter(Generic[QuantumStatementT]):
295
304
  )
296
305
  context = self._interpreter._expand_operation(gen_closure)
297
306
  self._generative_contexts[context_name] = context
307
+ op.clear_generative_blocks()
298
308
  return context
299
309
 
300
310
  def _evaluate_expression(self, expression: Expression) -> Expression: