mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,614 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tracer: Python Function → Graph IR.
16
+
17
+ Responsible for converting Python functions to Graph IR, handling:
18
+ - Function parameters
19
+ - Free variables (external references including captures)
20
+ - Polymorphic handling of TraceObject/InterpObject
21
+
22
+ Tracer is a Context (inherits from Context abstract base class).
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import inspect
28
+ from collections.abc import Callable
29
+ from dataclasses import dataclass
30
+ from typing import TYPE_CHECKING, Any, cast
31
+
32
+ from jax.tree_util import PyTreeDef, tree_flatten, tree_map
33
+
34
+ from mplang.v2.edsl.context import Context
35
+ from mplang.v2.edsl.graph import Graph
36
+ from mplang.v2.edsl.graph import Value as GraphValue
37
+ from mplang.v2.edsl.object import Object
38
+ from mplang.v2.edsl.typing import BaseType
39
+
40
+ if TYPE_CHECKING:
41
+ from mplang.v2.edsl.primitive import Primitive
42
+
43
+
44
class TraceObject(Object):
    """Object that exists only while a function is being traced (JIT time).

    A TraceObject pairs one Value of the Graph IR with the Tracer (the active
    Context) that created it. Operations applied to it are delegated to
    primitives, which record new ops into that tracer's Graph.

    Example:
        >>> from mplang.v2.edsl import trace
        >>> def compute(x, y):
        ...     z = x + y  # TraceObject.__add__ → add_p.bind(x, y)
        ...     return z
        >>> traced = trace(compute, x_interp, y_interp)
    """

    def __init__(self, graph_value: GraphValue, tracer: Tracer):
        # The owning tracer is stored as `_context`; `_tracer` below aliases it.
        self._graph_value = graph_value
        self._context = tracer

    @property
    def type(self) -> BaseType:
        """Type of the wrapped graph value."""
        value = self._graph_value
        return value.type

    @property
    def _tracer(self) -> Tracer:
        """Legacy alias that returns the owning Tracer stored in ``_context``."""
        return self._context

    def __repr__(self) -> str:
        value = self._graph_value
        return f"TraceObject({value.name}: {self.type})"
73
+
74
+
75
class Tracer(Context):
    """Converter from Python Function to Graph IR.

    Inherits from Context and implements bind_primitive() by recording
    operations into ``self.graph``.

    Responsibilities:
        1. Convert Python functions to Graph IR (see ``run``)
        2. Manage free variables (function params and captured external references)
        3. Handle the Object hierarchy (TraceObject/InterpObject)
        4. Promote InterpObject → TraceObject (see ``lift``)
        5. Implement Context.bind_primitive() by recording to Graph

    Example:
        >>> tracer = Tracer()
        >>> traced = tracer.run(lambda x, y: x + y, x_interp, y_interp)
        >>> print(traced.graph)
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Reset graph state so a tracer instance can be reused.

        Creates a fresh empty Graph, clears the capture cache and restarts the
        ``%argN`` input-naming counter.
        """
        self.graph = Graph()
        # Cache for captured variables (closures), keyed by id(obj).
        # The object itself is kept too: run() reports it in `captured`, and
        # holding the reference keeps its id() from being reused while cached.
        # Does NOT include function parameters - those are created per-position.
        self._captured_vars: dict[int, tuple[Object, GraphValue]] = {}
        self._arg_counter = 0

    def bind_primitive(
        self, primitive: Primitive, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> TraceObject | list[TraceObject] | Any:
        """Execute primitive by recording to Graph IR (trace mode).

        Handles two modes (checked in this order):
            1. def_trace: the primitive has full control and builds the graph
               via other primitives; its return value is passed through as-is.
            2. def_abstract_eval: the tracer infers output types from the
               TraceObject inputs and records a single op into ``self.graph``.

        In abstract_eval mode only positional ``TraceObject`` arguments become
        op inputs; other positional arguments are not recorded. ``kwargs`` are
        forwarded to the abstract_eval function and stored as the op's attrs.

        Args:
            primitive: The primitive to trace
            args: Positional arguments (can be Objects, opaques like callables, or constants)
            kwargs: Keyword arguments (can be Objects, opaques, or constants)

        Returns:
            In trace mode, whatever the primitive's trace function returns
            (e.g. a PyTree containing TraceObjects). In abstract_eval mode, a
            single TraceObject when the op has exactly one output, otherwise a
            list of TraceObjects.

        Raises:
            RuntimeError: If primitive has neither trace nor abstract_eval defined
        """
        if primitive._trace is not None:
            return primitive._trace(*args, **kwargs)

        if primitive._abstract_eval is not None:
            input_objects = [arg for arg in args if isinstance(arg, TraceObject)]
            input_types = [obj.type for obj in input_objects]

            sig = inspect.signature(primitive._abstract_eval)
            params = list(sig.parameters.values())
            # Flat style: the first parameter is named "in_types" and receives
            # all input types as one list; otherwise the types are splatted.
            is_flat_style = len(params) >= 1 and params[0].name == "in_types"

            if is_flat_style:
                output_types = primitive._abstract_eval(input_types, **kwargs)
            else:
                output_types = primitive._abstract_eval(*input_types, **kwargs)

            # Normalize to list: single type or sequence → list
            if isinstance(output_types, BaseType):
                output_types = [output_types]
            else:
                output_types = list(output_types)

            input_values = [obj._graph_value for obj in input_objects]
            result_values = self.graph.add_op(
                opcode=primitive.name,
                inputs=input_values,
                output_types=output_types,
                attrs=kwargs,
            )
            outs = [TraceObject(v, self) for v in result_values]
            return outs[0] if len(outs) == 1 else outs

        raise RuntimeError(
            f"Primitive '{primitive.name}' has neither trace nor abstract_eval defined. "
            f"Define one using @{primitive.name}_p.def_trace or @{primitive.name}_p.def_abstract_eval"
        )

    def lift(self, obj: Any, *, is_param: bool = False) -> Any:
        """Lift an object to TraceObject.

        Converts objects to TraceObject for use in tracing:
        - Non-Object types: return as-is (int, float, np.ndarray, callables, etc.)
        - TraceObject (same context): return as-is (idempotent)
        - TraceObject (different context): create graph input
        - InterpObject: promote to TraceObject

        Args:
            obj: Value to lift (Object or non-Object constant)
            is_param: If True, create independent graph input (no caching).
                If False, cache by id() for captures (same object → same input).

        Returns:
            TraceObject for Objects, or original value for non-Objects

        Note:
            - Parameters (is_param=True): Each position gets independent input,
              so `trace(fn, x, x)` creates two separate graph inputs.
            - Captures (is_param=False): Cached by id(), so the same captured
              object always maps to the same graph input.

        Subclass extension:
            Override _lift_type() to customize type transformation
            (e.g., unwrap MPType → value_type, TensorType → element_type).
        """
        # Early return for non-Object types (constants, callables, etc.)
        if not isinstance(obj, Object):
            return obj

        # Same-context TraceObject → return as-is (idempotent)
        if isinstance(obj, TraceObject) and obj._context is self:
            return obj

        # Parameters: always create fresh input (no caching)
        if is_param:
            return self._new_arg(self._lift_type(obj))

        # Captures: cache by id()
        obj_id = id(obj)
        if obj_id in self._captured_vars:
            _, graph_value = self._captured_vars[obj_id]
            return TraceObject(graph_value, self)

        lifted = self._new_arg(self._lift_type(obj))
        self._captured_vars[obj_id] = (obj, lifted._graph_value)
        return lifted

    def _lift_type(self, obj: Object) -> BaseType:
        """Get the graph input type for an object.

        Subclasses override this to customize type transformation:
        - _LocalMPTracer: unwrap MPType → value_type
        - _ElementwiseTracer: unwrap TensorType → element_type

        The base class preserves the object's type unchanged.

        Args:
            obj: Object being lifted to a graph input

        Returns:
            The type to use for the graph input
        """
        return cast(BaseType, obj.type)

    def _new_arg(self, arg_type: BaseType) -> TraceObject:
        """Create a new graph input for the given type.

        Internal method - prefer using lift() which handles caching logic.
        Inputs are named ``%arg0``, ``%arg1``, ... in creation order.

        Args:
            arg_type: The type of the argument

        Returns:
            TraceObject wrapping a new graph input Value
        """
        name = f"%arg{self._arg_counter}"
        self._arg_counter += 1
        graph_value = self.graph.add_input(
            name=name,
            type=arg_type,
        )
        return TraceObject(graph_value, self)

    def finalize(self, result: Any) -> Graph:
        """Finalize the graph by setting outputs.

        Every leaf of ``result`` is added, in flattened order, as an output of
        ``self.graph``.

        Args:
            result: Traced result, PyTree containing TraceObjects

        Returns:
            The finalized graph (self.graph with outputs set)

        Raises:
            TypeError: If a leaf is not a TraceObject created by this tracer.

        Example:
            >>> tracer = Tracer()
            >>> push_context(tracer)
            >>> result = do_something(x, y)
            >>> pop_context()
            >>> graph = tracer.finalize(result)
        """
        out_flat, _out_tree = tree_flatten(result)
        for out in out_flat:
            if not isinstance(out, TraceObject) or out._context is not self:
                raise TypeError(
                    f"Graph output must be TraceObject from this Tracer context, got: {type(out)}"
                )
            self.graph.add_output(out._graph_value)

        return self.graph  # type: ignore[return-value]

    def run(
        self,
        fn: Callable[..., Any],
        *args: Any,
        **kwargs: Any,
    ) -> TracedFunction:
        """Trace `fn` using this tracer instance.

        The tracer is reset first, so any graph from a previous run is discarded.

        Parameter handling:
            Each Object leaf of (args, kwargs) gets an independent graph input
            via lift(..., is_param=True), even if the same Python object is
            passed multiple times. This ensures correct semantics:
            `trace(fn, x, x)` creates two separate inputs. Non-Object leaves are
            passed to `fn` unchanged and recorded as input immediates.

        Capture handling:
            Objects lifted with the default is_param=False (e.g. closure
            variables, or Objects returned by `fn` that do not belong to this
            tracer) are cached by id(), so the same captured object always maps
            to the same graph input.

        Args:
            fn: Python callable to trace.
            *args: Positional arguments for `fn` (PyTrees of Objects/constants).
            **kwargs: Keyword arguments for `fn` (PyTrees of Objects/constants).

        Returns:
            TracedFunction holding the graph plus the metadata needed to map
            Python inputs/outputs onto graph inputs/outputs.

        Raises:
            TypeError: If `fn` is not callable, or (via finalize()) if an output
                cannot be used as a graph output.
        """
        self.reset()
        if not callable(fn):
            raise TypeError(f"fn must be callable, got {type(fn)}")

        fn_name = getattr(fn, "__name__", "anonymous")
        in_flat, in_treedef = tree_flatten((args, kwargs))
        in_imms, in_var_pos, in_vars = _separate_vars_and_imms(in_flat)

        with self:
            # Lift each Object leaf as an independent parameter input; all other
            # leaves (constants, callables, ...) are passed through unchanged.
            def lift_param(obj: Any) -> Any:
                if isinstance(obj, Object):
                    return self.lift(obj, is_param=True)
                return obj

            # tree_map visits leaves in tree_flatten order, so the parameter
            # graph inputs are created in the same order as `in_vars`.
            args_traced, kwargs_traced = tree_map(lift_param, (args, kwargs))

            result = fn(*args_traced, **kwargs_traced)
            # Lift any Objects in result (captures use default is_param=False)
            result = tree_map(self.lift, result)

        output_flat, output_treedef = tree_flatten(result)
        out_imms, out_var_pos, out_vars = _separate_vars_and_imms(output_flat)

        if out_vars:
            graph = self.finalize(out_vars)
        else:
            graph = self.graph
            graph.outputs = []

        # Captured objects are those in _captured_vars (excludes parameters)
        captured_objects: list[Object] = [
            obj for obj, _ in self._captured_vars.values()
        ]

        return TracedFunction(
            name=fn_name,
            graph=graph,
            in_imms=in_imms,
            in_var_pos=in_var_pos,
            in_tree=in_treedef,
            out_imms=out_imms,
            out_var_pos=out_var_pos,
            out_tree=output_treedef,
            params=in_vars,  # Original parameter objects
            captured=captured_objects,
        )

    def reconstruct_outputs(
        self,
        out_var_pos: list[int],
        out_imms: list[Any],
        out_tree: PyTreeDef,
        result_values: list[GraphValue],
    ) -> Any:
        """Rebuild PyTree outputs from recorded metadata.

        Each graph value in ``result_values`` is wrapped as a TraceObject of
        this tracer and placed at the next flat position listed in
        ``out_var_pos``; every other flat position is filled from
        ``out_imms`` in order.

        Args:
            out_var_pos: Flat positions (ascending, as produced by tracing)
                that hold traced values.
            out_imms: Immediate values for the remaining flat positions.
            out_tree: PyTreeDef used to unflatten the combined leaves.
            result_values: Graph values for the traced positions, in order.

        Returns:
            ``out_tree.unflatten(...)`` of the merged leaves.

        Raises:
            ValueError: If the number of result values does not match the
                number of traced output positions.
        """
        # Without this check a short list surfaces as a bare StopIteration and
        # a long list silently drops values.
        if len(result_values) != len(out_var_pos):
            raise ValueError(
                f"Expected {len(out_var_pos)} result values, got {len(result_values)}"
            )

        var_iter = iter([TraceObject(val, self) for val in result_values])
        var_pos_iter = iter(out_var_pos)
        next_var_pos = next(var_pos_iter, None)
        imm_idx = 0
        total_len = len(out_imms) + len(out_var_pos)
        flat_out: list[Any] = []
        for idx in range(total_len):
            if next_var_pos is not None and idx == next_var_pos:
                flat_out.append(next(var_iter))
                next_var_pos = next(var_pos_iter, None)
            else:
                flat_out.append(out_imms[imm_idx])
                imm_idx += 1
        return out_tree.unflatten(flat_out)
368
+
369
+
370
+ def _separate_vars_and_imms(
371
+ flat_values: list[Any],
372
+ ) -> tuple[list[Any], list[int], list[Any]]:
373
+ """Separate a flattened list into variables (Objects) and immediates (constants).
374
+
375
+ Args:
376
+ flat_values: Flattened list of values (mix of Objects and constants)
377
+
378
+ Returns:
379
+ Tuple of (imms, var_pos, vars) where:
380
+ - imms: List of immediate values (constants) in order
381
+ - var_pos: List of positions where variables appear in flat_values
382
+ - vars: List of variable values (Objects) in order
383
+ """
384
+ imms = []
385
+ var_pos = []
386
+ vars_list = []
387
+
388
+ for i, val in enumerate(flat_values):
389
+ if isinstance(val, Object):
390
+ var_pos.append(i)
391
+ vars_list.append(val)
392
+ else:
393
+ imms.append(val)
394
+
395
+ return imms, var_pos, vars_list
396
+
397
+
398
+ @dataclass
399
+ class TracedFunction:
400
+ """Result of tracing a Python function into Graph IR.
401
+
402
+ Represents a fully Pythonic function captured as a graph, distinguishing
403
+ between constants (immediates) and traced values (graph inputs/outputs).
404
+
405
+ Graph Inputs Order Convention:
406
+ graph.inputs = [*params_inputs, *captured_inputs]
407
+ - First len(params) inputs correspond to function parameters
408
+ - Remaining inputs correspond to captured variables (closures)
409
+
410
+ Attributes:
411
+ name: Function name (from fn.__name__)
412
+ graph: The finalized Graph IR containing traced computations
413
+ in_imms: Input immediates (constants) in flattened order
414
+ in_var_pos: Positions of graph.inputs in the flattened input list
415
+ in_tree: PyTreeDef to reconstruct (args, kwargs) from flattened inputs
416
+ out_imms: Output immediates (constants) in flattened order
417
+ out_var_pos: Positions of graph.outputs in the flattened output list
418
+ out_tree: PyTreeDef to reconstruct result from flattened outputs
419
+ params: Original parameter Objects (in order matching graph.inputs[:len(params)])
420
+ captured: Captured Objects from closures (in order matching graph.inputs[len(params):])
421
+
422
+ Reconstruction:
423
+ To reconstruct *args, **kwargs from graph.inputs:
424
+ 1. Create flattened list: [in_imms[i] if i not in in_var_pos else graph.inputs[...]]
425
+ 2. Use in_tree.unflatten() to get (args, kwargs)
426
+
427
+ To reconstruct result from graph.outputs:
428
+ 1. Create flattened list: [out_imms[i] if i not in out_var_pos else graph.outputs[...]]
429
+ 2. Use out_tree.unflatten() to get result
430
+
431
+ Example:
432
+ >>> def fn(x, y, *, scale=2.0):
433
+ ... return x + y, scale
434
+ >>> traced = make_graph(fn, x_obj, y_obj, scale=2.0)
435
+ >>> # in_imms = [2.0], in_var_pos = [0, 1] (x, y are vars)
436
+ >>> # out_imms = [2.0], out_var_pos = [0] (x+y is var, scale is constant)
437
+ >>> # params = [x_obj, y_obj], captured = []
438
+ """
439
+
440
+ name: str
441
+ graph: Graph
442
+ in_imms: list[Any]
443
+ in_var_pos: list[int]
444
+ in_tree: PyTreeDef
445
+ out_imms: list[Any]
446
+ out_var_pos: list[int]
447
+ out_tree: PyTreeDef
448
+ params: list[Object] # Original parameter objects
449
+ captured: list[Object] # Captured objects from closures
450
+
451
+ def is_input_signature_match(self, other: TracedFunction) -> bool:
452
+ """Check if this TracedFunction has the same input signature as another.
453
+
454
+ Args:
455
+ other: Another TracedFunction to compare against
456
+
457
+ Returns:
458
+ True if input counts and types match, False otherwise
459
+ """
460
+ if len(self.graph.inputs) != len(other.graph.inputs):
461
+ return False
462
+ return all(
463
+ self_in.type == other_in.type
464
+ for self_in, other_in in zip(
465
+ self.graph.inputs, other.graph.inputs, strict=True
466
+ )
467
+ )
468
+
469
+ def is_output_signature_match(self, other: TracedFunction) -> bool:
470
+ """Check if this TracedFunction has the same output signature as another.
471
+
472
+ Args:
473
+ other: Another TracedFunction to compare against
474
+
475
+ Returns:
476
+ True if output counts and types match, False otherwise
477
+ """
478
+ if len(self.graph.outputs) != len(other.graph.outputs):
479
+ return False
480
+ return all(
481
+ self_out.type == other_out.type
482
+ for self_out, other_out in zip(
483
+ self.graph.outputs, other.graph.outputs, strict=True
484
+ )
485
+ )
486
+
487
+ def compiler_ir(self, verbose: bool = False) -> str:
488
+ """Get human-readable IR representation of the traced function.
489
+
490
+ This is useful for debugging, auditing, and understanding what
491
+ operations were captured during tracing.
492
+
493
+ Args:
494
+ verbose: If True, include type annotations in the output
495
+
496
+ Returns:
497
+ String representation of the Graph IR
498
+
499
+ Example:
500
+ >>> traced = compile(lambda x, y: x + y, x_obj, y_obj)
501
+ >>> print(traced.compiler_ir())
502
+ %arg0 = input
503
+ %arg1 = input
504
+ %0 = add(%arg0, %arg1)
505
+ return %0
506
+ """
507
+ return self.graph.to_string(verbose=verbose)
508
+
509
+ def align_region_inputs(
510
+ self, leading_count: int, capture_order: list[Object]
511
+ ) -> None:
512
+ """Align region graph inputs as [leading_values..., captures...] sequence.
513
+
514
+ Reorders the graph inputs to have a standardized structure:
515
+ - First `leading_count` inputs: explicit function parameters
516
+ - Remaining inputs: captured variables in the specified order
517
+
518
+ This is essential for multi-region primitives (e.g., uniform_cond, while_loop)
519
+ where different regions need to share the same capture ordering.
520
+
521
+ Args:
522
+ leading_count: Number of explicit function parameters (non-captured)
523
+ capture_order: Desired order of captured variables
524
+
525
+ Example:
526
+ >>> # Align two branches to have same capture order
527
+ >>> all_captures = merge_captures(then_fn.captured, else_fn.captured)
528
+ >>> then_fn.align_region_inputs(num_args, all_captures)
529
+ >>> else_fn.align_region_inputs(num_args, all_captures)
530
+ """
531
+ assert len(self.graph.inputs) >= leading_count
532
+
533
+ leading_inputs = self.graph.inputs[:leading_count]
534
+ capture_inputs = self.graph.inputs[leading_count:]
535
+ capture_map = (
536
+ dict(zip(self.captured, capture_inputs, strict=True))
537
+ if self.captured
538
+ else {}
539
+ )
540
+
541
+ new_capture_inputs = []
542
+ for capture_obj in capture_order:
543
+ value = capture_map.get(capture_obj)
544
+ if value is None:
545
+ value = self.graph.add_input(
546
+ name=f"%capture{len(self.graph.inputs)}",
547
+ type=capture_obj.type,
548
+ )
549
+ new_capture_inputs.append(value)
550
+
551
+ self.graph.inputs = leading_inputs + new_capture_inputs
552
+ self.captured = list(capture_order)
553
+
554
+ def prepare_inputs(self, *args: Any, **kwargs: Any) -> list[Any]:
555
+ """Flatten arguments and map them to graph inputs.
556
+
557
+ Used by the runtime to prepare inputs for graph execution.
558
+
559
+ Args:
560
+ *args: Positional arguments for the function.
561
+ **kwargs: Keyword arguments for the function.
562
+
563
+ Returns:
564
+ List of values corresponding to graph.inputs (may include InterpObject).
565
+ The caller is responsible for unwrapping InterpObject at execution boundary.
566
+ """
567
+ flat_args, _ = tree_flatten((args, kwargs))
568
+
569
+ # Map to graph inputs
570
+ # fn.in_var_pos contains indices in flat_args that correspond to graph inputs
571
+ # Note: graph.inputs = [explicit_inputs...] + [captured_inputs...]
572
+ explicit_inputs = [flat_args[i] for i in self.in_var_pos]
573
+ all_inputs = explicit_inputs + list(self.captured)
574
+ return all_inputs
575
+
576
+ def reconstruct_outputs(self, execution_result: list[Any]) -> Any:
577
+ """Reconstruct structured output from execution result.
578
+
579
+ Used by the runtime to format the result of graph execution.
580
+
581
+ Args:
582
+ execution_result: List of results from interpreter.evaluate_graph().
583
+
584
+ Returns:
585
+ Structured output matching the original function's return signature.
586
+ """
587
+ # execution_result is always a list (now that evaluate_graph returns list)
588
+ results = execution_result
589
+
590
+ # Reconstruct
591
+ total_len = len(self.out_imms) + len(self.out_var_pos)
592
+ flat_out = [None] * total_len
593
+
594
+ var_indices = set(self.out_var_pos)
595
+ imm_iter = iter(self.out_imms)
596
+ res_iter = iter(results)
597
+
598
+ for i in range(total_len):
599
+ if i in var_indices:
600
+ flat_out[i] = next(res_iter)
601
+ else:
602
+ flat_out[i] = next(imm_iter)
603
+
604
+ return self.out_tree.unflatten(flat_out)
605
+
606
+
607
def trace(
    fn: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> TracedFunction:
    """Trace ``fn`` with a freshly constructed default Tracer.

    Convenience wrapper: builds a new ``Tracer`` and returns the
    ``TracedFunction`` produced by its ``run`` method for the given arguments.
    """
    tracer = Tracer()
    traced = tracer.run(fn, *args, **kwargs)
    return traced