mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/edsl/tracer.py
ADDED
|
@@ -0,0 +1,614 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Tracer: Python Function → Graph IR.
|
|
16
|
+
|
|
17
|
+
Responsible for converting Python functions to Graph IR, handling:
|
|
18
|
+
- Function parameters
|
|
19
|
+
- Free variables (external references including captures)
|
|
20
|
+
- Polymorphic handling of TraceObject/InterpObject
|
|
21
|
+
|
|
22
|
+
Tracer is a Context (inherits from Context abstract base class).
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
import inspect
|
|
28
|
+
from collections.abc import Callable
|
|
29
|
+
from dataclasses import dataclass
|
|
30
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
31
|
+
|
|
32
|
+
from jax.tree_util import PyTreeDef, tree_flatten, tree_map
|
|
33
|
+
|
|
34
|
+
from mplang.v2.edsl.context import Context
|
|
35
|
+
from mplang.v2.edsl.graph import Graph
|
|
36
|
+
from mplang.v2.edsl.graph import Value as GraphValue
|
|
37
|
+
from mplang.v2.edsl.object import Object
|
|
38
|
+
from mplang.v2.edsl.typing import BaseType
|
|
39
|
+
|
|
40
|
+
if TYPE_CHECKING:
|
|
41
|
+
from mplang.v2.edsl.primitive import Primitive
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TraceObject(Object):
    """Object that lives inside a graph being traced (JIT trace time).

    Wraps one Graph IR ``Value`` together with the Tracer (the active
    Context) that produced it. Operations on it are delegated to primitives,
    which record new ops into the tracer's Graph rather than computing.

    Example:
        >>> from mplang.v2.edsl import trace
        >>> def compute(x, y):
        ...     z = x + y  # TraceObject.__add__ → add_p.bind(x, y)
        ...     return z
        >>> traced = trace(compute, x_interp, y_interp)
    """

    def __init__(self, graph_value: GraphValue, tracer: Tracer):
        # Both attributes are read directly by Tracer (e.g. `obj._context is self`).
        self._graph_value = graph_value
        self._context = tracer

    @property
    def type(self) -> BaseType:
        """Type of the wrapped graph value."""
        graph_value = self._graph_value
        return graph_value.type

    @property
    def _tracer(self) -> Tracer:
        """Legacy alias for ``_context``, kept for backward compatibility."""
        return self._context

    def __repr__(self) -> str:
        value_name = self._graph_value.name
        return f"TraceObject({value_name}: {self.type})"
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class Tracer(Context):
    """Converter from a Python function to Graph IR.

    Inherits from Context and implements bind_primitive() by recording
    operations into ``self.graph``.

    Responsibilities:
        1. Convert Python functions to Graph IR
        2. Manage free variables (function params and captured external references)
        3. Handle the Object hierarchy (TraceObject/InterpObject)
        4. Promote InterpObject → TraceObject
        5. Implement Context.bind_primitive() by recording to Graph

    Example:
        >>> tracer = Tracer()
        >>> traced = tracer.run(lambda x, y: x + y, x_interp, y_interp)
        >>> print(traced.graph)
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Reset graph state so a tracer instance can be reused.

        Note:
            TraceObjects created before a reset still report this tracer as
            their context but refer to values of the discarded graph, so they
            must not be fed back into this tracer afterwards.
        """
        self.graph = Graph()
        # Cache for captured variables (closures), keyed by id(obj).
        # The object itself is stored next to its graph input, which keeps it
        # alive so its id() cannot be recycled while the cache exists.
        # Does NOT include function parameters - those are created per-position.
        self._captured_vars: dict[int, tuple[Object, GraphValue]] = {}
        self._arg_counter = 0

    def bind_primitive(
        self, primitive: Primitive, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> TraceObject | list[TraceObject] | Any:
        """Execute primitive by recording to Graph IR (trace mode).

        Handles two modes:
            1. def_trace: Primitive has full control - builds graph via other primitives
            2. def_abstract_eval: Tracer controls - infers types and builds one operation

        Args:
            primitive: The primitive to trace
            args: Positional arguments (can be Objects, opaques like callables, or constants)
            kwargs: Keyword arguments (can be Objects, opaques, or constants);
                in abstract_eval mode they are forwarded to abstract_eval and
                stored as the op's attrs.

        Returns:
            TraceObject, list[TraceObject], or PyTree containing TraceObjects

        Raises:
            RuntimeError: If primitive has neither trace nor abstract_eval defined
        """
        # Mode 1: the primitive builds its own sub-graph from other primitives.
        if primitive._trace is not None:
            return primitive._trace(*args, **kwargs)

        if primitive._abstract_eval is None:
            raise RuntimeError(
                f"Primitive '{primitive.name}' has neither trace nor abstract_eval defined. "
                f"Define one using @{primitive.name}_p.def_trace or @{primitive.name}_p.def_abstract_eval"
            )

        # Mode 2: record a single op. Only TraceObject positional args become
        # graph inputs; every other positional value is not recorded.
        # NOTE(review): non-TraceObject Objects (e.g. InterpObject) are skipped
        # here as well; presumably callers lift them before dispatch - confirm.
        input_objects = [arg for arg in args if isinstance(arg, TraceObject)]
        input_types = [obj.type for obj in input_objects]

        # "Flat style" abstract_eval functions receive all input types as one
        # list. They are detected purely by the first parameter being named
        # `in_types`; otherwise the input types are passed positionally.
        params = list(inspect.signature(primitive._abstract_eval).parameters.values())
        is_flat_style = len(params) >= 1 and params[0].name == "in_types"

        if is_flat_style:
            output_types = primitive._abstract_eval(input_types, **kwargs)
        else:
            output_types = primitive._abstract_eval(*input_types, **kwargs)

        # Normalize to list: single type or sequence → list
        if isinstance(output_types, BaseType):
            output_types = [output_types]
        else:
            output_types = list(output_types)

        result_values = self.graph.add_op(
            opcode=primitive.name,
            inputs=[obj._graph_value for obj in input_objects],
            output_types=output_types,
            attrs=kwargs,
        )
        outs = [TraceObject(v, self) for v in result_values]
        return outs[0] if len(outs) == 1 else outs

    def lift(self, obj: Any, *, is_param: bool = False) -> Any:
        """Lift an object to TraceObject.

        Converts objects to TraceObject for use in tracing:
            - Non-Object types: return as-is (int, float, np.ndarray, callables, etc.)
            - TraceObject (same context): return as-is (idempotent)
            - TraceObject (different context): create graph input
            - InterpObject: promote to TraceObject

        Args:
            obj: Value to lift (Object or non-Object constant)
            is_param: If True, create independent graph input (no caching).
                If False, cache by id() for captures (same object → same input).

        Returns:
            TraceObject for Objects, or original value for non-Objects

        Note:
            - Parameters (is_param=True): Each position gets independent input,
              so `trace(fn, x, x)` creates two separate graph inputs.
            - Captures (is_param=False): Cached by id(), so the same captured
              object always maps to the same graph input.

        Subclass extension:
            Override _lift_type() to customize type transformation
            (e.g., unwrap MPType → value_type, TensorType → element_type).
        """
        # Early return for non-Object types (constants, callables, etc.)
        if not isinstance(obj, Object):
            return obj

        # Same-context TraceObject → return as-is (idempotent)
        if isinstance(obj, TraceObject) and obj._context is self:
            return obj

        # Parameters: always create fresh input (no caching)
        if is_param:
            return self._new_arg(self._lift_type(obj))

        # Captures: cache by id()
        obj_id = id(obj)
        if obj_id in self._captured_vars:
            _, graph_value = self._captured_vars[obj_id]
            return TraceObject(graph_value, self)

        lifted = self._new_arg(self._lift_type(obj))
        self._captured_vars[obj_id] = (obj, lifted._graph_value)
        return lifted

    def _lift_type(self, obj: Object) -> BaseType:
        """Get the graph input type for an object.

        Subclasses override this to customize type transformation:
            - _LocalMPTracer: unwrap MPType → value_type
            - _ElementwiseTracer: unwrap TensorType → element_type

        The base class preserves the object's type unchanged.

        Args:
            obj: Object being lifted to a graph input

        Returns:
            The type to use for the graph input
        """
        return cast(BaseType, obj.type)

    def _new_arg(self, arg_type: BaseType) -> TraceObject:
        """Create a new graph input for the given type.

        Internal method - prefer using lift() which handles caching logic.
        Inputs are named sequentially: %arg0, %arg1, ...

        Args:
            arg_type: The type of the argument

        Returns:
            TraceObject wrapping a new graph input Value
        """
        name = f"%arg{self._arg_counter}"
        self._arg_counter += 1
        graph_value = self.graph.add_input(
            name=name,
            type=arg_type,
        )
        return TraceObject(graph_value, self)

    def finalize(self, result: Any) -> Graph:
        """Finalize the graph by setting outputs.

        Marks every leaf of the traced result as a graph output, in flattened
        PyTree order. All leaves are validated before any output is added, so
        an invalid result leaves the graph's outputs untouched.

        Args:
            result: Traced result, PyTree containing TraceObjects

        Returns:
            The finalized graph (self.graph with outputs set)

        Raises:
            TypeError: If a leaf is not a TraceObject, or is a TraceObject
                belonging to a different Tracer.

        Example:
            >>> tracer = Tracer()
            >>> with tracer:
            ...     result = do_something(x, y)
            >>> graph = tracer.finalize(result)
        """
        out_flat, _out_tree = tree_flatten(result)
        for out in out_flat:
            if not isinstance(out, TraceObject):
                raise TypeError(
                    f"Graph output must be TraceObject from this Tracer context, got: {type(out)}"
                )
            if out._context is not self:
                raise TypeError(
                    "Graph output must be TraceObject from this Tracer context, "
                    "got a TraceObject bound to a different Tracer"
                )

        for out in out_flat:
            self.graph.add_output(out._graph_value)

        return self.graph

    def run(
        self,
        fn: Callable[..., Any],
        *args: Any,
        **kwargs: Any,
    ) -> TracedFunction:
        """Trace `fn` using this tracer instance.

        The tracer is reset first, so any previously built graph is discarded.

        Parameter handling:
            Each Object parameter position gets an independent graph input,
            even if the same Python object is passed multiple times. This ensures
            correct semantics: `trace(fn, x, x)` creates two separate inputs.
            Non-Object arguments are recorded as immediates (constants).

        Capture handling:
            Variables captured from closures are cached by id() via lift(),
            so the same captured object always maps to the same graph input.

        Args:
            fn: Callable to trace.
            *args: Positional example arguments.
            **kwargs: Keyword example arguments.

        Returns:
            TracedFunction holding the graph plus input/output mapping metadata.

        Raises:
            TypeError: If `fn` is not callable.
        """
        self.reset()
        if not callable(fn):
            raise TypeError(f"fn must be callable, got {type(fn)}")

        fn_name = getattr(fn, "__name__", "anonymous")
        in_flat, in_treedef = tree_flatten((args, kwargs))
        in_imms, in_var_pos, in_vars = _separate_vars_and_imms(in_flat)

        with self:
            # Lift Object parameters to fresh graph inputs; other values
            # (constants, callables, ...) are passed through unchanged.
            def lift_param(obj: Any) -> Any:
                if isinstance(obj, Object):
                    return self.lift(obj, is_param=True)
                return obj

            # Lift parameters with is_param=True (each position gets independent input)
            args_traced, kwargs_traced = tree_map(lift_param, (args, kwargs))

            result = fn(*args_traced, **kwargs_traced)
            # Lift any Objects in result (captures use default is_param=False)
            result = tree_map(self.lift, result)

        output_flat, output_treedef = tree_flatten(result)
        out_imms, out_var_pos, out_vars = _separate_vars_and_imms(output_flat)

        if out_vars:
            graph = self.finalize(out_vars)
        else:
            graph = self.graph
            graph.outputs = []

        # Captured objects are those in _captured_vars (excludes parameters)
        captured_objects: list[Object] = [
            obj for obj, _ in self._captured_vars.values()
        ]

        return TracedFunction(
            name=fn_name,
            graph=graph,
            in_imms=in_imms,
            in_var_pos=in_var_pos,
            in_tree=in_treedef,
            out_imms=out_imms,
            out_var_pos=out_var_pos,
            out_tree=output_treedef,
            params=in_vars,  # Original parameter objects
            captured=captured_objects,
        )

    def reconstruct_outputs(
        self,
        out_var_pos: list[int],
        out_imms: list[Any],
        out_tree: PyTreeDef,
        result_values: list[GraphValue],
    ) -> Any:
        """Rebuild PyTree outputs from recorded metadata.

        Each entry of `result_values` is wrapped as a TraceObject of this
        tracer and placed at the next position listed in `out_var_pos`; all
        other positions are filled from `out_imms` in order.

        Raises:
            ValueError: If the number of result values differs from
                ``len(out_var_pos)``.
        """
        if len(result_values) != len(out_var_pos):
            raise ValueError(
                f"Expected {len(out_var_pos)} result values, got {len(result_values)}"
            )

        var_iter = iter([TraceObject(val, self) for val in result_values])
        var_pos_iter = iter(out_var_pos)
        next_var_pos = next(var_pos_iter, None)
        imm_idx = 0
        total_len = len(out_imms) + len(out_var_pos)
        flat_out: list[Any] = []
        for idx in range(total_len):
            if next_var_pos is not None and idx == next_var_pos:
                flat_out.append(next(var_iter))
                next_var_pos = next(var_pos_iter, None)
            else:
                flat_out.append(out_imms[imm_idx])
                imm_idx += 1
        return out_tree.unflatten(flat_out)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def _separate_vars_and_imms(
|
|
371
|
+
flat_values: list[Any],
|
|
372
|
+
) -> tuple[list[Any], list[int], list[Any]]:
|
|
373
|
+
"""Separate a flattened list into variables (Objects) and immediates (constants).
|
|
374
|
+
|
|
375
|
+
Args:
|
|
376
|
+
flat_values: Flattened list of values (mix of Objects and constants)
|
|
377
|
+
|
|
378
|
+
Returns:
|
|
379
|
+
Tuple of (imms, var_pos, vars) where:
|
|
380
|
+
- imms: List of immediate values (constants) in order
|
|
381
|
+
- var_pos: List of positions where variables appear in flat_values
|
|
382
|
+
- vars: List of variable values (Objects) in order
|
|
383
|
+
"""
|
|
384
|
+
imms = []
|
|
385
|
+
var_pos = []
|
|
386
|
+
vars_list = []
|
|
387
|
+
|
|
388
|
+
for i, val in enumerate(flat_values):
|
|
389
|
+
if isinstance(val, Object):
|
|
390
|
+
var_pos.append(i)
|
|
391
|
+
vars_list.append(val)
|
|
392
|
+
else:
|
|
393
|
+
imms.append(val)
|
|
394
|
+
|
|
395
|
+
return imms, var_pos, vars_list
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
@dataclass
class TracedFunction:
    """Result of tracing a Python function into Graph IR.

    Represents a Python function captured as a graph, distinguishing
    between constants (immediates) and traced values (graph inputs/outputs).

    Graph Inputs Order Convention:
        graph.inputs = [*params_inputs, *captured_inputs]
        - First len(params) inputs correspond to function parameters
        - Remaining inputs correspond to captured variables (closures)

    Attributes:
        name: Function name (from fn.__name__)
        graph: The finalized Graph IR containing traced computations
        in_imms: Input immediates (constants) in flattened order
        in_var_pos: Positions of graph.inputs in the flattened input list
        in_tree: PyTreeDef to reconstruct (args, kwargs) from flattened inputs
        out_imms: Output immediates (constants) in flattened order
        out_var_pos: Positions of graph.outputs in the flattened output list
        out_tree: PyTreeDef to reconstruct result from flattened outputs
        params: Original parameter Objects (in order matching graph.inputs[:len(params)])
        captured: Captured Objects from closures (in order matching graph.inputs[len(params):])

    Reconstruction:
        To rebuild (args, kwargs) from values for graph.inputs[:len(params)]:
            1. Walk flat positions 0 .. len(in_imms) + len(in_var_pos) - 1; at
               positions listed in in_var_pos take the next input value,
               elsewhere take the next entry of in_imms.
            2. Use in_tree.unflatten() on that list to get (args, kwargs).

        The result is rebuilt the same way from graph outputs using
        out_imms, out_var_pos and out_tree (see reconstruct_outputs()).

    Example:
        >>> def fn(x, y, *, scale=2.0):
        ...     return x + y, scale
        >>> traced = trace(fn, x_obj, y_obj, scale=2.0)
        >>> # in_imms = [2.0], in_var_pos = [0, 1] (x, y are vars)
        >>> # out_imms = [2.0], out_var_pos = [0] (x+y is var, scale is constant)
        >>> # params = [x_obj, y_obj], captured = []
    """

    name: str
    graph: Graph
    in_imms: list[Any]
    in_var_pos: list[int]
    in_tree: PyTreeDef
    out_imms: list[Any]
    out_var_pos: list[int]
    out_tree: PyTreeDef
    params: list[Object]  # Original parameter objects
    captured: list[Object]  # Captured objects from closures

    def is_input_signature_match(self, other: TracedFunction) -> bool:
        """Check if this TracedFunction has the same input signature as another.

        Args:
            other: Another TracedFunction to compare against

        Returns:
            True if input counts and types match, False otherwise
        """
        if len(self.graph.inputs) != len(other.graph.inputs):
            return False
        return all(
            self_in.type == other_in.type
            for self_in, other_in in zip(
                self.graph.inputs, other.graph.inputs, strict=True
            )
        )

    def is_output_signature_match(self, other: TracedFunction) -> bool:
        """Check if this TracedFunction has the same output signature as another.

        Args:
            other: Another TracedFunction to compare against

        Returns:
            True if output counts and types match, False otherwise
        """
        if len(self.graph.outputs) != len(other.graph.outputs):
            return False
        return all(
            self_out.type == other_out.type
            for self_out, other_out in zip(
                self.graph.outputs, other.graph.outputs, strict=True
            )
        )

    def compiler_ir(self, verbose: bool = False) -> str:
        """Get human-readable IR representation of the traced function.

        This is useful for debugging, auditing, and understanding what
        operations were captured during tracing.

        Args:
            verbose: Forwarded to ``Graph.to_string(verbose=...)``
                (presumably adds type annotations - confirm in Graph).

        Returns:
            String representation of the Graph IR

        Example:
            >>> traced = trace(lambda x, y: x + y, x_obj, y_obj)
            >>> print(traced.compiler_ir())  # illustrative output
            %arg0 = input
            %arg1 = input
            %0 = add(%arg0, %arg1)
            return %0
        """
        return self.graph.to_string(verbose=verbose)

    def align_region_inputs(
        self, leading_count: int, capture_order: list[Object]
    ) -> None:
        """Align region graph inputs as [leading_values..., captures...] sequence.

        Reorders the graph inputs to have a standardized structure:
            - First `leading_count` inputs: explicit function parameters
            - Remaining inputs: captured variables in the specified order

        Captures in `capture_order` that this function does not already have
        get a fresh graph input. Captures are matched by identity (id()),
        consistent with how Tracer caches captures.

        This is essential for multi-region primitives (e.g., uniform_cond, while_loop)
        where different regions need to share the same capture ordering.

        Args:
            leading_count: Number of explicit function parameters (non-captured)
            capture_order: Desired order of captured variables; must include
                every object already in ``self.captured``.

        Raises:
            ValueError: If `leading_count` is out of range, if the inputs after
                `leading_count` do not line up with ``self.captured``, or if
                `capture_order` omits an existing capture (dropping it would
                remove a graph input the traced body may still reference).

        Example:
            >>> # Align two branches to have same capture order
            >>> all_captures = merge_captures(then_fn.captured, else_fn.captured)
            >>> then_fn.align_region_inputs(num_args, all_captures)
            >>> else_fn.align_region_inputs(num_args, all_captures)
        """
        num_inputs = len(self.graph.inputs)
        if not 0 <= leading_count <= num_inputs:
            raise ValueError(
                f"leading_count={leading_count} is out of range for a graph "
                f"with {num_inputs} inputs"
            )

        leading_inputs = self.graph.inputs[:leading_count]
        capture_inputs = self.graph.inputs[leading_count:]
        if len(capture_inputs) != len(self.captured):
            raise ValueError(
                f"Expected {len(self.captured)} capture inputs after the first "
                f"{leading_count} inputs, found {len(capture_inputs)}"
            )

        # Key by id() rather than by the Object itself: matches Tracer's
        # id()-keyed capture cache and never relies on Object __eq__/__hash__.
        capture_map = {
            id(obj): value
            for obj, value in zip(self.captured, capture_inputs, strict=True)
        }

        wanted_ids = {id(obj) for obj in capture_order}
        missing = [obj for obj in self.captured if id(obj) not in wanted_ids]
        if missing:
            raise ValueError(
                f"capture_order omits {len(missing)} existing capture(s) of "
                f"'{self.name}'; they would be dropped from graph inputs"
            )

        new_capture_inputs = []
        for capture_obj in capture_order:
            value = capture_map.get(id(capture_obj))
            if value is None:
                # Not captured by this region: add an input so all regions
                # share the same input layout.
                value = self.graph.add_input(
                    name=f"%capture{len(self.graph.inputs)}",
                    type=capture_obj.type,
                )
            new_capture_inputs.append(value)

        self.graph.inputs = leading_inputs + new_capture_inputs
        self.captured = list(capture_order)

    def prepare_inputs(self, *args: Any, **kwargs: Any) -> list[Any]:
        """Flatten arguments and map them to graph inputs.

        Used by the runtime to prepare inputs for graph execution.

        Args:
            *args: Positional arguments for the function.
            **kwargs: Keyword arguments for the function.

        Returns:
            List of values corresponding to graph.inputs (may include InterpObject):
            the flattened arguments at ``in_var_pos`` followed by ``captured``.
            The caller is responsible for unwrapping InterpObject at execution boundary.
        """
        flat_args, _ = tree_flatten((args, kwargs))

        # in_var_pos holds the indices in flat_args that correspond to graph
        # inputs; graph.inputs = [explicit_inputs...] + [captured_inputs...]
        explicit_inputs = [flat_args[i] for i in self.in_var_pos]
        return explicit_inputs + list(self.captured)

    def reconstruct_outputs(self, execution_result: list[Any]) -> Any:
        """Reconstruct structured output from execution result.

        Used by the runtime to format the result of graph execution.

        Args:
            execution_result: Sequence of results from interpreter.evaluate_graph(),
                one per graph output (i.e. one per entry of ``out_var_pos``).

        Returns:
            Structured output matching the original function's return signature.

        Raises:
            ValueError: If the number of results differs from ``len(out_var_pos)``.
        """
        results = list(execution_result)
        if len(results) != len(self.out_var_pos):
            raise ValueError(
                f"Expected {len(self.out_var_pos)} execution results for "
                f"'{self.name}', got {len(results)}"
            )

        total_len = len(self.out_imms) + len(self.out_var_pos)
        var_indices = set(self.out_var_pos)
        imm_iter = iter(self.out_imms)
        res_iter = iter(results)

        # Traced values go at their recorded positions; constants fill the rest.
        flat_out = [
            next(res_iter) if i in var_indices else next(imm_iter)
            for i in range(total_len)
        ]
        return self.out_tree.unflatten(flat_out)
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
def trace(
    fn: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> TracedFunction:
    """Trace ``fn`` with a fresh default Tracer.

    Convenience wrapper equivalent to ``Tracer().run(fn, *args, **kwargs)``.
    """
    tracer = Tracer()
    return tracer.run(fn, *args, **kwargs)
|