classiq-0.58.0-py3-none-any.whl → classiq-0.59.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. classiq/_internals/api_wrapper.py +8 -3
  2. classiq/_internals/jobs.py +3 -5
  3. classiq/execution/execution_session.py +36 -20
  4. classiq/executor.py +2 -1
  5. classiq/interface/_version.py +1 -1
  6. classiq/interface/generator/arith/arithmetic_operations.py +1 -0
  7. classiq/interface/generator/register_role.py +8 -0
  8. classiq/interface/model/handle_binding.py +22 -3
  9. classiq/model_expansions/capturing/captured_vars.py +316 -0
  10. classiq/model_expansions/capturing/mangling_utils.py +18 -9
  11. classiq/model_expansions/closure.py +29 -74
  12. classiq/model_expansions/function_builder.py +51 -66
  13. classiq/model_expansions/interpreter.py +4 -7
  14. classiq/model_expansions/quantum_operations/bind.py +1 -3
  15. classiq/model_expansions/quantum_operations/call_emitter.py +46 -11
  16. classiq/model_expansions/quantum_operations/classicalif.py +2 -5
  17. classiq/model_expansions/quantum_operations/control.py +13 -16
  18. classiq/model_expansions/quantum_operations/emitter.py +36 -8
  19. classiq/model_expansions/quantum_operations/expression_operation.py +9 -19
  20. classiq/model_expansions/quantum_operations/inplace_binary_operation.py +4 -6
  21. classiq/model_expansions/quantum_operations/invert.py +5 -8
  22. classiq/model_expansions/quantum_operations/power.py +5 -10
  23. classiq/model_expansions/quantum_operations/quantum_assignment_operation.py +1 -3
  24. classiq/model_expansions/quantum_operations/quantum_function_call.py +1 -3
  25. classiq/model_expansions/quantum_operations/repeat.py +3 -3
  26. classiq/model_expansions/quantum_operations/variable_decleration.py +1 -1
  27. classiq/model_expansions/quantum_operations/within_apply.py +1 -5
  28. classiq/model_expansions/scope.py +2 -2
  29. classiq/model_expansions/transformers/var_splitter.py +32 -19
  30. classiq/model_expansions/utils/handles_collector.py +33 -0
  31. classiq/model_expansions/visitors/variable_references.py +18 -2
  32. classiq/qmod/qfunc.py +9 -13
  33. classiq/qmod/quantum_expandable.py +1 -21
  34. classiq/qmod/quantum_function.py +16 -0
  35. {classiq-0.58.0.dist-info → classiq-0.59.0.dist-info}/METADATA +2 -2
  36. {classiq-0.58.0.dist-info → classiq-0.59.0.dist-info}/RECORD +37 -38
  37. classiq/interface/executor/aws_execution_cost.py +0 -90
  38. classiq/model_expansions/capturing/captured_var_manager.py +0 -48
  39. classiq/model_expansions/capturing/propagated_var_stack.py +0 -194
  40. {classiq-0.58.0.dist-info → classiq-0.59.0.dist-info}/WHEEL +0 -0
@@ -1,34 +1,29 @@
1
+ import dataclasses
1
2
  import json
2
3
  import uuid
3
- from collections import defaultdict
4
- from collections.abc import Collection, Sequence
4
+ from collections.abc import Collection, Iterator, Sequence
5
+ from contextlib import contextmanager
5
6
  from dataclasses import dataclass, field
6
- from functools import cached_property, singledispatch
7
+ from functools import singledispatch
7
8
  from symtable import Symbol
8
- from typing import Any, Optional, Union
9
+ from typing import Any, Optional
9
10
 
10
11
  from typing_extensions import Self
11
12
 
12
- from classiq.interface.exceptions import (
13
- ClassiqInternalExpansionError,
14
- )
13
+ from classiq.interface.exceptions import ClassiqInternalExpansionError
15
14
  from classiq.interface.generator.functions.builtins.internal_operators import (
16
15
  All_BUILTINS_OPERATORS,
17
16
  )
18
- from classiq.interface.generator.visitor import Visitor
19
17
  from classiq.interface.model.port_declaration import PortDeclaration
20
- from classiq.interface.model.quantum_function_call import QuantumFunctionCall
21
18
  from classiq.interface.model.quantum_function_declaration import (
22
19
  NamedParamsQuantumFunctionDeclaration,
23
20
  PositionalArg,
24
21
  QuantumOperandDeclaration,
25
22
  )
26
23
  from classiq.interface.model.quantum_statement import QuantumStatement
27
- from classiq.interface.model.variable_declaration_statement import (
28
- VariableDeclarationStatement,
29
- )
30
24
 
31
25
  from classiq import ClassicalParameterDeclaration
26
+ from classiq.model_expansions.capturing.captured_vars import CapturedVars
32
27
  from classiq.model_expansions.expression_renamer import ExpressionRenamer
33
28
  from classiq.model_expansions.scope import (
34
29
  Evaluated,
@@ -46,6 +41,7 @@ class Closure:
46
41
  blocks: dict[str, Sequence[QuantumStatement]]
47
42
  scope: Scope
48
43
  positional_arg_declarations: Sequence[PositionalArg] = tuple()
44
+ captured_vars: CapturedVars = field(default_factory=CapturedVars)
49
45
 
50
46
  @property
51
47
  def port_declarations(self) -> dict[str, PortDeclaration]:
@@ -55,6 +51,11 @@ class Closure:
55
51
  if isinstance(param, PortDeclaration)
56
52
  }
57
53
 
54
+ @contextmanager
55
+ def freeze(self) -> Iterator[None]:
56
+ with self.scope.freeze(), self.captured_vars.freeze():
57
+ yield
58
+
58
59
 
59
60
  @dataclass(frozen=True)
60
61
  class GenerativeClosure(Closure):
@@ -66,6 +67,13 @@ class FunctionClosure(Closure):
66
67
  is_lambda: bool = False
67
68
  is_atomic: bool = False
68
69
  signature_scope: Scope = field(default_factory=Scope)
70
+ _depth: Optional[int] = None
71
+
72
+ @property
73
+ def depth(self) -> int:
74
+ if self._depth is None:
75
+ raise ClassiqInternalExpansionError
76
+ return self._depth
69
77
 
70
78
  # creates a unique id for the function closure based on the arguments values.
71
79
  # The closure is changing across the interpreter flow so it's closure_id may change
@@ -89,12 +97,6 @@ class FunctionClosure(Closure):
89
97
  return []
90
98
  return self.blocks["body"]
91
99
 
92
- @cached_property
93
- def colliding_variables(self) -> set[str]:
94
- # Note that this has to be accessed after adding the parameters from the signature and not during
95
- # initialization
96
- return VariableCollector(self.scope).get_colliding_variables(self.body)
97
-
98
100
  @classmethod
99
101
  def create(
100
102
  cls,
@@ -122,6 +124,7 @@ class FunctionClosure(Closure):
122
124
  blocks,
123
125
  scope,
124
126
  positional_arg_declarations,
127
+ CapturedVars(),
125
128
  is_lambda,
126
129
  is_atomic,
127
130
  **kwargs,
@@ -134,68 +137,20 @@ class FunctionClosure(Closure):
134
137
  "name": declaration.name,
135
138
  "positional_arg_declarations": declaration.positional_arg_declarations,
136
139
  }
137
- fields.pop("colliding_variables", 0)
138
140
  return type(self)(**fields)
139
141
 
142
+ def set_depth(self, depth: int) -> Self:
143
+ return dataclasses.replace(self, _depth=depth)
140
144
 
141
- @dataclass(frozen=True)
142
- class GenerativeFunctionClosure(GenerativeClosure, FunctionClosure):
143
- pass
144
-
145
-
146
- NestedFunctionClosureT = Union[FunctionClosure, list["NestedFunctionClosureT"]]
147
-
148
-
149
- class VariableCollector(Visitor):
150
- def __init__(self, function_scope: Scope) -> None:
151
- self._function_scope = function_scope
152
- self._variables: defaultdict[str, set[Optional[str]]] = defaultdict(set)
153
- for var in self._function_scope.data:
154
- defining_function = self._function_scope[var].defining_function
155
- if defining_function is not None:
156
- self._variables[var].add(defining_function.name)
157
-
158
- def get_colliding_variables(self, body: Sequence[QuantumStatement]) -> set[str]:
159
- self.visit(body)
160
- return {
161
- var
162
- for var, defining_functions in self._variables.items()
163
- if len(defining_functions) > 1
164
- }
165
-
166
- def visit_VariableDeclarationStatement(
167
- self, node: VariableDeclarationStatement
168
- ) -> None:
169
- self._variables[node.name].add(None)
170
-
171
- def visit_QuantumFunctionCall(self, node: QuantumFunctionCall) -> None:
172
- # The else case corresponds to operand identifiers. In case of operand identifiers, we scan
173
- # the whole list of operands because we can't evaluate the index yet.
174
- identifier = (
175
- node.function if isinstance(node.function, str) else node.function.name
145
+ def copy_scope(self) -> Self: # Remove when scoping is normal (CAD-24980)
146
+ return dataclasses.replace(
147
+ self, scope=Scope(self.scope.data, parent=self.scope.parent)
176
148
  )
177
- self._add_variables(self._function_scope[identifier].value)
178
-
179
- def _add_variables(self, evaluated: NestedFunctionClosureT) -> None:
180
- if isinstance(evaluated, list):
181
- for elem in evaluated:
182
- self._add_variables(elem)
183
- return
184
- if not isinstance(evaluated, FunctionClosure):
185
- raise ClassiqInternalExpansionError
186
- self._add_variables_from_closure(evaluated)
187
149
 
188
- def _add_variables_from_closure(self, closure: FunctionClosure) -> None:
189
- if not closure.is_lambda:
190
- return
191
- lambda_environment = closure.scope.parent
192
- if lambda_environment is None:
193
- raise ClassiqInternalExpansionError
194
150
 
195
- for var in lambda_environment.iter_without_top_level():
196
- defining_function = lambda_environment[var].defining_function
197
- if defining_function is not None:
198
- self._variables[var].add(defining_function.name)
151
+ @dataclass(frozen=True)
152
+ class GenerativeFunctionClosure(GenerativeClosure, FunctionClosure):
153
+ pass
199
154
 
200
155
 
201
156
  def _generate_closure_id(
@@ -1,12 +1,9 @@
1
- from collections.abc import Iterable, Iterator, Sequence
1
+ from collections.abc import Iterator, Sequence
2
2
  from contextlib import contextmanager
3
3
  from dataclasses import dataclass, field
4
4
  from typing import Generic, Optional, TypeVar
5
5
 
6
- from classiq.interface.exceptions import (
7
- ClassiqExpansionError,
8
- ClassiqInternalExpansionError,
9
- )
6
+ from classiq.interface.exceptions import ClassiqInternalExpansionError
10
7
  from classiq.interface.generator.compiler_keywords import (
11
8
  EXPANDED_KEYWORD,
12
9
  LAMBDA_KEYWORD,
@@ -14,9 +11,6 @@ from classiq.interface.generator.compiler_keywords import (
14
11
  from classiq.interface.generator.functions.builtins.internal_operators import (
15
12
  WITHIN_APPLY_NAME,
16
13
  )
17
- from classiq.interface.generator.functions.port_declaration import (
18
- PortDeclarationDirection,
19
- )
20
14
  from classiq.interface.model.model import MAIN_FUNCTION_NAME
21
15
  from classiq.interface.model.native_function_definition import (
22
16
  NativeFunctionDefinition,
@@ -28,8 +22,10 @@ from classiq.interface.model.quantum_function_declaration import (
28
22
  from classiq.interface.model.quantum_statement import QuantumStatement
29
23
  from classiq.interface.source_reference import SourceReference
30
24
 
31
- from classiq.model_expansions.capturing.captured_var_manager import update_captured_vars
32
- from classiq.model_expansions.capturing.mangling_utils import demangle_name
25
+ from classiq.model_expansions.capturing.captured_vars import (
26
+ CapturedVars,
27
+ validate_captured_directions,
28
+ )
33
29
  from classiq.model_expansions.closure import Closure, FunctionClosure
34
30
  from classiq.model_expansions.scope import Scope
35
31
 
@@ -39,7 +35,7 @@ ClosureType = TypeVar("ClosureType", bound=Closure)
39
35
  @dataclass
40
36
  class Block:
41
37
  statements: list[QuantumStatement] = field(default_factory=list)
42
- captured_vars: list[PortDeclaration] = field(default_factory=list)
38
+ captured_vars: CapturedVars = field(default_factory=CapturedVars)
43
39
 
44
40
 
45
41
  @dataclass
@@ -69,10 +65,6 @@ class FunctionContext(OperationContext[FunctionClosure]):
69
65
  def body(self) -> list[QuantumStatement]:
70
66
  return self.statements("body")
71
67
 
72
- @property
73
- def captured_vars(self) -> list[PortDeclaration]:
74
- return self.blocks["body"].captured_vars
75
-
76
68
  @property
77
69
  def is_lambda(self) -> bool:
78
70
  return self.closure.is_lambda
@@ -91,14 +83,26 @@ class OperationBuilder:
91
83
 
92
84
  @property
93
85
  def current_function(self) -> FunctionClosure:
94
- for operation in reversed(self._operations):
86
+ return self._get_last_function(self._operations)
87
+
88
+ @property
89
+ def parent_function(self) -> FunctionClosure:
90
+ return self._get_last_function(self._operations[:-1])
91
+
92
+ @staticmethod
93
+ def _get_last_function(operations: list[OperationContext]) -> FunctionClosure:
94
+ for operation in reversed(operations):
95
95
  if isinstance(operation.closure, FunctionClosure):
96
96
  return operation.closure
97
97
  raise ClassiqInternalExpansionError("No function found")
98
98
 
99
+ @property
100
+ def current_block(self) -> Block:
101
+ return self._operations[-1].blocks[self._blocks[-1]]
102
+
99
103
  @property
100
104
  def _current_statements(self) -> list[QuantumStatement]:
101
- return self._operations[-1].blocks[self._blocks[-1]].statements
105
+ return self.current_block.statements
102
106
 
103
107
  def emit_statement(self, statement: QuantumStatement) -> None:
104
108
  if self._current_source_ref is not None:
@@ -109,16 +113,20 @@ class OperationBuilder:
109
113
  def current_statement(self) -> QuantumStatement:
110
114
  return self._current_statements[-1]
111
115
 
112
- def add_captured_vars(self, captured_vars: Iterable[PortDeclaration]) -> None:
113
- self._operations[-1].blocks[self._blocks[-1]].captured_vars.extend(
114
- captured_vars
115
- )
116
-
117
116
  @contextmanager
118
117
  def block_context(self, block_name: str) -> Iterator[None]:
119
118
  self._blocks.append(block_name)
120
119
  self._operations[-1].blocks[block_name] = Block()
121
120
  yield
121
+ captured_vars = self.current_block.captured_vars
122
+ if (
123
+ not isinstance(self.current_operation, FunctionClosure)
124
+ and self.current_operation.name != WITHIN_APPLY_NAME
125
+ ):
126
+ validate_captured_directions(captured_vars)
127
+ self.current_operation.captured_vars.update(
128
+ captured_vars.filter(self.current_function)
129
+ )
122
130
  self._blocks.pop()
123
131
 
124
132
  @contextmanager
@@ -132,9 +140,28 @@ class OperationBuilder:
132
140
  context = OperationContext(closure=original_operation)
133
141
  self._operations.append(context)
134
142
  yield context
135
- self._update_captured_vars()
143
+ self._finalize_within_apply()
144
+ self._propagate_captured_vars()
136
145
  self._operations.pop()
137
146
 
147
+ def _finalize_within_apply(self) -> None:
148
+ if self.current_operation.name != WITHIN_APPLY_NAME:
149
+ return
150
+ within_captured_vars = self._operations[-1].blocks["within"].captured_vars
151
+ self.current_operation.captured_vars.update(
152
+ within_captured_vars.filter(self.current_function).negate()
153
+ )
154
+
155
+ def _propagate_captured_vars(self) -> None:
156
+ captured_vars = self.current_operation.captured_vars
157
+ if isinstance(self.current_operation, FunctionClosure):
158
+ captured_vars = captured_vars.set_propagated()
159
+ validate_captured_directions(captured_vars)
160
+ if len(self._operations) < 2:
161
+ return
162
+ parent_block = self._operations[-2].blocks[self._blocks[-1]]
163
+ parent_block.captured_vars.update(captured_vars.filter(self.parent_function))
164
+
138
165
  @contextmanager
139
166
  def source_ref_context(
140
167
  self, source_ref: Optional[SourceReference]
@@ -144,29 +171,6 @@ class OperationBuilder:
144
171
  yield
145
172
  self._current_source_ref = previous_source_ref
146
173
 
147
- def _update_captured_vars(self) -> None:
148
- for block in self._operations[-1].blocks.values():
149
- block.captured_vars = update_captured_vars(block.captured_vars)
150
- if not self._is_within_apply_context():
151
- validate_captured_vars(block.captured_vars)
152
-
153
- def is_compute_context(self) -> bool:
154
- return self._is_within_apply_context("within")
155
-
156
- def _is_within_apply_context(self, block_name: Optional[str] = None) -> bool:
157
- return self._is_op_within_apply_context(block_name, -1) or (
158
- len(self._operations) > 1
159
- and isinstance(self._operations[-1], FunctionContext)
160
- and self._is_op_within_apply_context(block_name, -2)
161
- )
162
-
163
- def _is_op_within_apply_context(
164
- self, block_name: Optional[str], index: int
165
- ) -> bool:
166
- return self._operations[index].name == WITHIN_APPLY_NAME and (
167
- block_name is None or self._blocks[index] == block_name
168
- )
169
-
170
174
  def create_definition(
171
175
  self, function_context: FunctionContext
172
176
  ) -> NativeFunctionDefinition:
@@ -183,29 +187,10 @@ class OperationBuilder:
183
187
  param
184
188
  for param in function_context.positional_arg_declarations
185
189
  if isinstance(param, PortDeclaration)
186
- ] + function_context.captured_vars
190
+ ]
187
191
 
188
192
  return NativeFunctionDefinition(
189
193
  name=name,
190
194
  body=function_context.body,
191
195
  positional_arg_declarations=new_parameters,
192
196
  )
193
-
194
-
195
- def validate_captured_vars(captured_vars: list[PortDeclaration]) -> None:
196
- if input_captured := [
197
- demangle_name(var.name)
198
- for var in captured_vars
199
- if var.direction is PortDeclarationDirection.Input
200
- ]:
201
- raise ClassiqExpansionError(
202
- f"Captured quantum variables {input_captured!r} cannot be used as inputs"
203
- )
204
- if output_captured := [
205
- demangle_name(var.name)
206
- for var in captured_vars
207
- if var.direction is PortDeclarationDirection.Output
208
- ]:
209
- raise ClassiqExpansionError(
210
- f"Captured quantum variables {output_captured!r} cannot be used as outputs"
211
- )
@@ -50,7 +50,6 @@ from classiq.interface.model.variable_declaration_statement import (
50
50
  )
51
51
  from classiq.interface.model.within_apply_operation import WithinApply
52
52
 
53
- from classiq.model_expansions.capturing.propagated_var_stack import PropagatedVarStack
54
53
  from classiq.model_expansions.closure import (
55
54
  Closure,
56
55
  FunctionClosure,
@@ -108,9 +107,6 @@ class Interpreter:
108
107
  self._top_level_scope = self._current_scope
109
108
  self._builder = OperationBuilder(self._top_level_scope)
110
109
  self._expanded_functions: dict[str, NativeFunctionDefinition] = {}
111
- self._propagated_var_stack = PropagatedVarStack(
112
- self._current_scope, self._builder
113
- )
114
110
 
115
111
  self._main_renamer: ExpressionRenamer = get_main_renamer(self._model.functions)
116
112
 
@@ -132,11 +128,9 @@ class Interpreter:
132
128
 
133
129
  prev_context_scope = self._current_scope
134
130
  self._current_scope = scope
135
- self._propagated_var_stack.set_scope(scope)
136
131
  yield
137
132
  prev_context_scope.data = prev_context_scope_data
138
133
  self._current_scope = prev_context_scope
139
- self._propagated_var_stack.set_scope(prev_context_scope)
140
134
 
141
135
  scope.data = scope_data
142
136
 
@@ -147,6 +141,7 @@ class Interpreter:
147
141
  body=self._model.main_func.body,
148
142
  scope=Scope(parent=self._current_scope),
149
143
  expr_renamer=self._main_renamer,
144
+ _depth=0,
150
145
  )
151
146
 
152
147
  add_entry_point_params_to_scope(
@@ -367,6 +362,7 @@ class Interpreter:
367
362
  def _expand_permute(self) -> None:
368
363
  functions = self.evaluate("functions").as_type(list)
369
364
  functions_permutation = permutation(np.array(range(len(functions))))
365
+ calls: list[QuantumFunctionCall] = []
370
366
  for function_index in functions_permutation:
371
367
  permute_call = QuantumFunctionCall(
372
368
  function=OperandIdentifier(
@@ -374,7 +370,8 @@ class Interpreter:
374
370
  )
375
371
  )
376
372
  permute_call.set_func_decl(permute.func_decl)
377
- self.emit(permute_call)
373
+ calls.append(permute_call)
374
+ self._expand_block(calls, "body")
378
375
 
379
376
  def _get_function_declarations(self) -> Sequence[QuantumFunctionDeclaration]:
380
377
  return (
@@ -14,8 +14,6 @@ from classiq.model_expansions.scope import QuantumSymbol
14
14
 
15
15
  class BindEmitter(Emitter[BindOperation]):
16
16
  def emit(self, bind: BindOperation, /) -> None:
17
- with self._propagated_var_stack.capture_variables(bind):
18
- pass
19
17
  inputs: list[QuantumSymbol] = [
20
18
  self._interpreter.evaluate(arg).as_type(QuantumSymbol)
21
19
  for arg in bind.in_handles
@@ -57,6 +55,6 @@ class BindEmitter(Emitter[BindOperation]):
57
55
  f"The total size for the input and output of the bind operation must be the same. The in size is {input_size} and the out size is {output_size}"
58
56
  )
59
57
 
60
- self._builder.emit_statement(
58
+ self.emit_statement(
61
59
  BindOperation(in_handles=bind.in_handles, out_handles=bind.out_handles)
62
60
  )
@@ -1,5 +1,7 @@
1
1
  from collections.abc import Sequence
2
+ from itertools import chain
2
3
  from typing import (
4
+ TYPE_CHECKING,
3
5
  Generic,
4
6
  cast,
5
7
  )
@@ -10,6 +12,7 @@ from classiq.interface.model.classical_parameter_declaration import (
10
12
  ClassicalParameterDeclaration,
11
13
  )
12
14
  from classiq.interface.model.handle_binding import HandleBinding
15
+ from classiq.interface.model.native_function_definition import NativeFunctionDefinition
13
16
  from classiq.interface.model.port_declaration import PortDeclaration
14
17
  from classiq.interface.model.quantum_function_call import ArgValue, QuantumFunctionCall
15
18
  from classiq.interface.model.quantum_function_declaration import (
@@ -21,7 +24,7 @@ from classiq.interface.model.variable_declaration_statement import (
21
24
  VariableDeclarationStatement,
22
25
  )
23
26
 
24
- from classiq.model_expansions.capturing.propagated_var_stack import (
27
+ from classiq.model_expansions.capturing.captured_vars import (
25
28
  validate_args_are_not_propagated,
26
29
  )
27
30
  from classiq.model_expansions.closure import FunctionClosure
@@ -40,10 +43,18 @@ from classiq.model_expansions.quantum_operations.emitter import (
40
43
  QuantumStatementT,
41
44
  )
42
45
  from classiq.model_expansions.scope import Evaluated, QuantumSymbol, Scope
46
+ from classiq.model_expansions.transformers.var_splitter import VarSplitter
43
47
  from classiq.qmod.builtins.functions import allocate, free
44
48
 
49
+ if TYPE_CHECKING:
50
+ from classiq.model_expansions.interpreter import Interpreter
51
+
52
+
53
+ class CallEmitter(Generic[QuantumStatementT], Emitter[QuantumStatementT], VarSplitter):
54
+ def __init__(self, interpreter: "Interpreter") -> None:
55
+ Emitter.__init__(self, interpreter)
56
+ VarSplitter.__init__(self, interpreter._current_scope)
45
57
 
46
- class CallEmitter(Generic[QuantumStatementT], Emitter[QuantumStatementT]):
47
58
  @staticmethod
48
59
  def _should_wrap(body: Sequence[QuantumStatement]) -> bool:
49
60
  # This protects shadowing of captured variables (i.e, bad user code) by wrapping the body in a function
@@ -54,7 +65,10 @@ class CallEmitter(Generic[QuantumStatementT], Emitter[QuantumStatementT]):
54
65
  self, name: str, body: Sequence[QuantumStatement]
55
66
  ) -> QuantumFunctionCall:
56
67
  wrapping_function = FunctionClosure.create(
57
- name=name, body=body, scope=Scope(parent=self._current_scope)
68
+ name=name,
69
+ body=body,
70
+ scope=Scope(parent=self._current_scope),
71
+ is_lambda=True,
58
72
  )
59
73
  return self._create_quantum_function_call(wrapping_function, list())
60
74
 
@@ -62,12 +76,14 @@ class CallEmitter(Generic[QuantumStatementT], Emitter[QuantumStatementT]):
62
76
  self, function: FunctionClosure, args: list[ArgValue]
63
77
  ) -> QuantumFunctionCall:
64
78
  call = self._create_quantum_function_call(function, args)
65
- self._builder.emit_statement(call)
79
+ self.emit_statement(call)
66
80
  return call
67
81
 
68
82
  def _create_quantum_function_call(
69
83
  self, function: FunctionClosure, args: list[ArgValue]
70
84
  ) -> QuantumFunctionCall:
85
+ function = function.set_depth(self._builder.current_function.depth + 1)
86
+ function = function.copy_scope()
71
87
  evaluated_args = [self._interpreter.evaluate(arg) for arg in args]
72
88
  new_declaration = self._prepare_fully_typed_declaration(
73
89
  function, evaluated_args
@@ -86,7 +102,7 @@ class CallEmitter(Generic[QuantumStatementT], Emitter[QuantumStatementT]):
86
102
  closure_id = function_context.closure.closure_id
87
103
  function_def = self._expanded_functions.get(closure_id)
88
104
  if function_def is None:
89
- function_def = self._builder.create_definition(function_context)
105
+ function_def = self._create_function_definition(function_context)
90
106
  self._expanded_functions[closure_id] = function_def
91
107
  self._top_level_scope[function_def.name] = Evaluated(
92
108
  value=function_context.closure.with_new_declaration(function_def)
@@ -104,6 +120,11 @@ class CallEmitter(Generic[QuantumStatementT], Emitter[QuantumStatementT]):
104
120
  new_positional_args = self._get_new_positional_args(
105
121
  evaluated_args, is_atomic, new_positional_arg_decls
106
122
  )
123
+ captured_args = function.captured_vars.get_captured_args(
124
+ self._builder.current_function
125
+ )
126
+ validate_args_are_not_propagated(new_positional_args, captured_args)
127
+ new_positional_args.extend(captured_args)
107
128
  new_call = QuantumFunctionCall(
108
129
  function=new_function_name,
109
130
  positional_args=new_positional_args,
@@ -133,6 +154,26 @@ class CallEmitter(Generic[QuantumStatementT], Emitter[QuantumStatementT]):
133
154
  new_call.set_func_decl(new_declaration)
134
155
  return new_call
135
156
 
157
+ def _create_function_definition(
158
+ self, function_context: FunctionContext
159
+ ) -> NativeFunctionDefinition:
160
+ func_def = self._builder.create_definition(function_context)
161
+
162
+ captured_ports = function_context.closure.captured_vars.get_captured_ports()
163
+ if len(captured_ports) == 0:
164
+ return func_def
165
+ func_def.positional_arg_declarations = list(
166
+ chain.from_iterable((func_def.positional_arg_declarations, captured_ports))
167
+ )
168
+
169
+ if not function_context.is_lambda:
170
+ return func_def
171
+ func_def.body = self.rewrite(
172
+ func_def.body, function_context.closure.captured_vars.get_captured_mapping()
173
+ )
174
+
175
+ return func_def
176
+
136
177
  @staticmethod
137
178
  def _add_params_to_scope(
138
179
  parameters: Sequence[PositionalArg],
@@ -171,12 +212,6 @@ class CallEmitter(Generic[QuantumStatementT], Emitter[QuantumStatementT]):
171
212
  arg.emit() for arg in evaluated_args if isinstance(arg.value, QuantumSymbol)
172
213
  ]
173
214
 
174
- propagated_variables = self._propagated_var_stack.get_propagated_variables(
175
- flatten=True
176
- )
177
- validate_args_are_not_propagated(positional_args, propagated_variables)
178
- positional_args.extend(propagated_variables)
179
-
180
215
  return positional_args
181
216
 
182
217
  def _prepare_fully_typed_declaration(
@@ -18,10 +18,6 @@ def _is_all_identity_calls(body: Sequence[QuantumStatement]) -> bool:
18
18
 
19
19
  class ClassicalIfEmitter(CallEmitter[ClassicalIf]):
20
20
  def emit(self, classical_if: ClassicalIf, /) -> None:
21
- with self._propagated_var_stack.capture_variables(classical_if):
22
- self._emit_propagated(classical_if)
23
-
24
- def _emit_propagated(self, classical_if: ClassicalIf) -> None:
25
21
  condition = self._interpreter.evaluate(classical_if.condition).as_type(bool)
26
22
  op_name = "then" if condition else "else"
27
23
  is_generative = classical_if.is_generative()
@@ -40,7 +36,7 @@ class ClassicalIfEmitter(CallEmitter[ClassicalIf]):
40
36
  if _is_all_identity_calls(body):
41
37
  return
42
38
 
43
- if not self._should_wrap(body):
39
+ if is_generative or not self._should_wrap(body):
44
40
  for stmt in body:
45
41
  if is_generative:
46
42
  self._interpreter._builder.emit_statement(stmt)
@@ -52,5 +48,6 @@ class ClassicalIfEmitter(CallEmitter[ClassicalIf]):
52
48
  name=op_name,
53
49
  body=body,
54
50
  scope=Scope(parent=self._current_scope),
51
+ is_lambda=True,
55
52
  )
56
53
  self._emit_quantum_function_call(then_else_func, list())
@@ -1,3 +1,5 @@
1
+ from typing import cast
2
+
1
3
  from sympy import Equality
2
4
  from sympy.logic.boolalg import Boolean
3
5
  from typing_extensions import TypeGuard
@@ -33,10 +35,10 @@ from classiq.interface.model.variable_declaration_statement import (
33
35
  )
34
36
  from classiq.interface.model.within_apply_operation import WithinApply
35
37
 
36
- from classiq.model_expansions.capturing.mangling_utils import ARRAY_CAST_SUFFIX
37
- from classiq.model_expansions.capturing.propagated_var_stack import (
38
+ from classiq.model_expansions.capturing.captured_vars import (
38
39
  validate_args_are_not_propagated,
39
40
  )
41
+ from classiq.model_expansions.capturing.mangling_utils import ARRAY_CAST_SUFFIX
40
42
  from classiq.model_expansions.closure import Closure
41
43
  from classiq.model_expansions.evaluators.control import (
42
44
  resolve_num_condition,
@@ -47,6 +49,7 @@ from classiq.model_expansions.quantum_operations.expression_operation import (
47
49
  )
48
50
  from classiq.model_expansions.scope import Scope
49
51
  from classiq.model_expansions.transformers.var_splitter import SymbolParts
52
+ from classiq.model_expansions.utils.handles_collector import extract_handles
50
53
  from classiq.qmod.builtins.functions.standard_gates import X
51
54
 
52
55
 
@@ -59,8 +62,7 @@ class ControlEmitter(ExpressionOperationEmitter[Control]):
59
62
  )
60
63
  if len(arrays_with_subscript) > 0:
61
64
  if control.is_generative():
62
- with self._propagated_var_stack.capture_variables(control):
63
- control = self._expand_generative_control(control)
65
+ control = self._expand_generative_control(control)
64
66
  self._emit_with_split(control, condition, arrays_with_subscript)
65
67
  return
66
68
 
@@ -108,14 +110,11 @@ class ControlEmitter(ExpressionOperationEmitter[Control]):
108
110
 
109
111
  def _emit_canonical_control(self, control: Control) -> None:
110
112
  # canonical means control(q, body) where q is a single quantum variable
111
- with self._propagated_var_stack.capture_variables(control):
112
- self._emit_propagated(control)
113
-
114
- def _emit_propagated(self, control: Control) -> None:
115
- if control.is_generative():
113
+ is_generative = control.is_generative()
114
+ if is_generative:
116
115
  control = self._expand_generative_control(control)
117
116
 
118
- if self._should_wrap_control(control):
117
+ if not is_generative and self._should_wrap_control(control):
119
118
  self._emit_wrapped(control)
120
119
  return
121
120
  self._emit_as_operation(control)
@@ -139,10 +138,10 @@ class ControlEmitter(ExpressionOperationEmitter[Control]):
139
138
  context = self._expand_operation(control_operation)
140
139
  validate_args_are_not_propagated(
141
140
  control.var_handles,
142
- self._propagated_var_stack.get_propagated_variables(flatten=False),
141
+ extract_handles([block.statements for block in context.blocks.values()]),
143
142
  )
144
143
  self._update_control_state(control)
145
- self._builder.emit_statement(
144
+ self.emit_statement(
146
145
  control.model_copy(update=dict(body=context.statements("body")))
147
146
  )
148
147
 
@@ -152,12 +151,10 @@ class ControlEmitter(ExpressionOperationEmitter[Control]):
152
151
  )
153
152
  validate_args_are_not_propagated(
154
153
  control.var_handles,
155
- self._propagated_var_stack.get_propagated_variables(flatten=False),
154
+ cast(list[HandleBinding], wrapping_function.positional_args),
156
155
  )
157
156
  self._update_control_state(control)
158
- self._builder.emit_statement(
159
- control.model_copy(update=dict(body=[wrapping_function]))
160
- )
157
+ self.emit_statement(control.model_copy(update=dict(body=[wrapping_function])))
161
158
 
162
159
  @staticmethod
163
160
  def _update_control_state(control: Control) -> None: