mplang-nightly 0.1.dev192-py3-none-any.whl → 0.1.dev268-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
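
The renames above move the existing implementation under the mplang.v1 namespace (core → v1/core, ops → v1/ops, runtime → v1/runtime, and so on) and introduce a new mplang.v2 tree. As an illustration only, not taken from the package and assuming the moved modules keep their public contents, downstream import paths would change roughly like this:

# Hypothetical sketch of the 0.1.dev192 → 0.1.dev268 module moves (paths grounded in the file list above).
import mplang.v1.core.mptype      # previously mplang.core.mptype
import mplang.v1.ops.phe          # previously mplang.ops.phe
import mplang.v2.dialects.tensor  # new in this release (entry 110)

The 1175-line hunk shown below corresponds to entry 110 in the list, mplang/v2/dialects/tensor.py.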
@@ -0,0 +1,1175 @@
+ # Copyright 2025 Ant Group Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Tensor dialect: tensor ops backed by plaintext/private JAX execution.
+
+ Design Philosophy
+ -----------------
+ This dialect is intentionally *lightweight* — it focuses on **structural/shape
+ operations** (slice, reshape, transpose, gather, scatter, concat) rather than
+ full-fledged element-wise arithmetic.
+
+ Why not add bitwise_and / bitwise_or / arithmetic primitives here?
+
+ 1. **Shape Dialect**: The primitives defined here perform *index arithmetic* on
+    tensor metadata (offsets, strides, dim sizes). They don't interpret element
+    values — that's left to the backend (JAX/XLA).
+
+ 2. **Delegate to run_jax**: For element-wise logic (bitwise ops, arithmetic),
+    use `tensor.run_jax(jnp.bitwise_xor, a, b)`. This leverages JAX's mature XLA
+    backend without duplicating op definitions or abstract_eval rules for every
+    possible JAX op.
+
+ 3. **Type Preservation**: `run_jax` infers output types from JAX's shape/dtype
+    inference, avoiding the need to re-implement type rules for hundreds of ops.
+
+ For domain-specific ops (GF(2^128) mul, AES expand), use dedicated dialects
+ like `field` which have optimized C++ kernel backends.
+
+ Helper Functions
+ ----------------
+ - `bitcast(x, dtype)`: Type reinterpretation (SSA-safe, same bytes).
+ - For random tensor generation, see `crypto.random_tensor`.
+ """
+
+ from __future__ import annotations
+
+ import base64
+ import math
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from itertools import count
+ from typing import Any, cast
+ from weakref import WeakKeyDictionary
+
+ import jax
+ import numpy as np
+ from jax import ShapeDtypeStruct
+ from jax.tree_util import PyTreeDef, tree_flatten
+
+ import mplang.v2.edsl as el
+ import mplang.v2.edsl.typing as elt
+ from mplang.v1.utils.func_utils import normalize_fn
+ from mplang.v2.dialects import dtypes
+
+ run_jax_p = el.Primitive[Any]("tensor.run_jax")
+ constant_p = el.Primitive[el.Object]("tensor.constant")
+
+
+ @dataclass
+ class RunJaxCompilation:
+     """Compilation record for tensor.run_jax functions.
+
+     Stores both the compilation artifacts (StableHLO MLIR, types, tree structure)
+     and metadata needed for execution (arg_keep_map for JAX's unused param elimination).
+     """
+
+     fn: Callable[..., Any]
+     stablehlo: str
+     out_tree: PyTreeDef
+     output_types: list[elt.BaseType]
+     arg_keep_map: list[int] | None = None
+
+
+ _RUN_JAX_REGISTRY: dict[str, RunJaxCompilation] = {}
+ _RUN_JAX_ID_GENERATOR = count()
+
+
+ def _current_tracer() -> el.Tracer:
+     ctx = el.get_current_context()
+     if not isinstance(ctx, el.Tracer):
+         raise TypeError(f"Expected Tracer context, got {type(ctx)}")
+     return ctx
+
+
+ def _scalar_to_numpy_dtype(scalar: elt.ScalarType) -> np.dtype[np.generic]:
+     return np.dtype(dtypes.to_jax(scalar))  # type: ignore[no-any-return]
+
+
+ def _numpy_dtype_to_scalar(dtype: Any) -> elt.ScalarType:
+     return dtypes.from_dtype(dtype)
+
+
+ def _tensor_type_to_placeholder(
+     tensor_type: elt.TensorType | elt.ScalarType,
+ ) -> ShapeDtypeStruct:
+     if isinstance(tensor_type, elt.ScalarType):
+         # Treat scalar as rank-0 tensor
+         dtype = _scalar_to_numpy_dtype(tensor_type)
+         return ShapeDtypeStruct((), dtype)
+
+     normalized_shape: list[int] = []
+     for idx, dim in enumerate(tensor_type.shape):
+         if dim is None:
+             raise TypeError(
+                 f"tensor.run_jax argument dimension {idx} is None; "
+                 "please provide a static dimension."
+             )
+         if dim == -1:
+             raise TypeError(
+                 "tensor.run_jax does not yet support dynamic (-1) dimensions"
+             )
+         if dim <= 0 and dim != 0:
+             raise ValueError(f"Invalid tensor dimension {dim}")
+         normalized_shape.append(dim)
+     # element_type must be ScalarType for conversion to numpy dtype
+     if not isinstance(tensor_type.element_type, elt.ScalarType):
+         raise TypeError(
+             f"Expected ScalarType element, got {type(tensor_type.element_type)}"
+         )
+     dtype = _scalar_to_numpy_dtype(tensor_type.element_type)
+     return ShapeDtypeStruct(tuple(normalized_shape), dtype)
+
+
+ def _out_info_to_edsl(out_info: Any) -> elt.TensorType:
+     scalar = _numpy_dtype_to_scalar(out_info.dtype)
+     shape = tuple(out_info.shape)
+     return elt.TensorType(scalar, shape)
+
+
+ def _register_compilation(compilation: RunJaxCompilation) -> str:
+     compilation_id = f"tensor.run_jax::{next(_RUN_JAX_ID_GENERATOR)}"
+     _RUN_JAX_REGISTRY[compilation_id] = compilation
+     return compilation_id
+
+
+ def get_run_jax_compilation(compilation_id: str) -> RunJaxCompilation:
+     """Get compilation record by ID.
+
+     Returns:
+         The compilation record containing StableHLO MLIR, types, and metadata.
+     """
+     try:
+         return _RUN_JAX_REGISTRY[compilation_id]
+     except KeyError as exc:
+         raise KeyError(
+             f"Unknown tensor.run_jax compilation id '{compilation_id}'"
+         ) from exc
+
+
+ def _compile_run_jax(
+     fn: Callable[..., Any],
+     normalized_fn: Callable[..., Any],
+     placeholders: list[ShapeDtypeStruct],
+ ) -> tuple[RunJaxCompilation, str]:
+     """Compile JAX function to StableHLO MLIR.
+
+     Pipeline: jit → lower → StableHLO MLIR
+
+     Args:
+         fn: Original JAX function
+         normalized_fn: Function accepting list of variables (for JAX lower API)
+         placeholders: JAX ShapeDtypeStruct list for lowering
+
+     Returns:
+         Tuple of (compilation record, compilation_id)
+     """
+     jitted = jax.jit(normalized_fn)
+     lowered = jitted.lower(placeholders)
+     stablehlo_text = str(lowered.compiler_ir("stablehlo"))
+
+     # Handle JAX's unused parameter elimination
+     arg_keep_map: list[int] | None = None
+     try:
+         compile_args = lowered._lowering.compile_args
+         kept_var_idx = compile_args["kept_var_idx"]
+         kept_indices = sorted(kept_var_idx)
+         if len(kept_indices) < len(placeholders):
+             arg_keep_map = kept_indices
+     except (AttributeError, KeyError, TypeError) as e:
+         raise RuntimeError(
+             f"Cannot access JAX's kept_var_idx for unused parameter handling. "
+             f"JAX may have optimized away unused parameters. Error: {e}"
+         ) from e
+
+     # Convert output info to EDSL types
+     output_types: list[elt.BaseType]
+     if isinstance(lowered.out_info, tuple):
+         output_types = [_out_info_to_edsl(info) for info in lowered.out_info]
+     else:
+         output_types = [_out_info_to_edsl(lowered.out_info)]
+
+     compilation = RunJaxCompilation(
+         fn=fn,
+         stablehlo=stablehlo_text,
+         out_tree=lowered.out_tree,
+         output_types=output_types,
+         arg_keep_map=arg_keep_map,
+     )
+     compilation_id = _register_compilation(compilation)
+     return compilation, compilation_id
+
+
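
For reference, the jit → lower → StableHLO pipeline used by _compile_run_jax above can be reproduced with plain public JAX APIs (a standalone sketch, independent of this module; the accesses to kept_var_idx in the code above go through private internals and are not shown here):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.square(x) + 1.0

# Lower against a shape/dtype placeholder instead of concrete data,
# then fetch the StableHLO text, as _compile_run_jax does above.
placeholder = jax.ShapeDtypeStruct((4,), jnp.float32)
lowered = jax.jit(f).lower(placeholder)
stablehlo_text = str(lowered.compiler_ir("stablehlo"))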
+ @run_jax_p.def_trace
+ def _run_jax_trace(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
+     """Trace tensor.run_jax primitive.
+
+     Compiles JAX function to StableHLO and emits graph operation.
+
+     Args:
+         fn: JAX-compatible callable
+         *args: Positional arguments (TraceObjects become dynamic, others static)
+         **kwargs: Keyword arguments (TraceObjects become dynamic, others static)
+
+     Returns:
+         PyTree of TraceObjects matching fn's output structure
+     """
+     if not callable(fn):
+         raise TypeError(f"run_jax expects callable, got {type(fn)}")
+
+     tracer = _current_tracer()
+
+     # Extract TraceObjects (dynamic args) from args/kwargs
+     def _is_trace_object(value: Any) -> bool:
+         return isinstance(value, el.TraceObject)
+
+     normalized_fn, variables = normalize_fn(fn, args, kwargs, _is_trace_object)
+
+     # Convert TraceObjects to JAX placeholders for compilation
+     placeholders: list[ShapeDtypeStruct] = []
+     for var in variables:
+         if not isinstance(var, el.TraceObject):
+             raise TypeError(f"Expected TraceObject, got {type(var)}")
+         if not isinstance(var.type, (elt.TensorType, elt.ScalarType)):
+             raise TypeError(f"run_jax only supports Tensors/Scalars, got {var.type}")
+         placeholders.append(_tensor_type_to_placeholder(var.type))
+
+     # Compile to StableHLO
+     compilation, text_ref = _compile_run_jax(fn, normalized_fn, placeholders)
+
+     # Emit graph operation
+     input_values = [var._graph_value for var in variables]
+     result_values = tracer.graph.add_op(
+         opcode="tensor.run_jax",
+         inputs=input_values,
+         output_types=compilation.output_types,
+         attrs={
+             "ir_type": "stablehlo",
+             "text_ref": text_ref,
+             "stablehlo_code": compilation.stablehlo,
+             "arg_keep_map": compilation.arg_keep_map,
+         },
+     )
+
+     # Reconstruct output PyTree (JAX outputs are all variables)
+     out_var_pos = list(range(len(result_values)))
+     return tracer.reconstruct_outputs(
+         out_var_pos, [], compilation.out_tree, result_values
+     )
+
+
+ def run_jax(
+     fn: Callable[..., Any],
+     *args: Any,
+     **kwargs: Any,
+ ) -> Any:
+     """Trace a tensor JAX function as a graph op.
+
+     Args:
+         fn: Callable that accepts JAX-compatible tensors.
+         *args: Positional arguments to the callable. TraceObjects are treated
+             as dynamic tensors, while non-Object values become static parameters.
+         **kwargs: Keyword arguments for the callable. TraceObjects are treated
+             as dynamic tensors, while non-Object values become static parameters.
+
+     Returns:
+         PyTree of TraceObjects with the same structure as fn's output.
+     """
+     return run_jax_p.bind(fn, *args, **kwargs)
+
+
+ def jax_fn(fn: Callable[..., Any]) -> Callable[..., Any]:
+     """Wrap a JAX function for use with pcall.
+
+     This creates a callable that can be passed to pcall primitives,
+     providing a cleaner user interface:
+
+     Instead of:
+         pcall_static((0,), lambda x, y: run_jax(native_fn, x, y), x_p0, y_p0)
+
+     You can write:
+         pcall_static((0,), jax_fn(native_fn), x_p0, y_p0)
+
+     Args:
+         fn: JAX function to wrap
+
+     Returns:
+         Wrapped function that calls run_jax when invoked
+
+     Example:
+         >>> def square(x):
+         ...     return jnp.square(x)
+         >>> wrapped = jax_fn(square)
+         >>> result = pcall_static((0,), wrapped, x_p0)
+     """
+
+     def wrapped(*args: Any, **kwargs: Any) -> Any:
+         return run_jax(fn, *args, **kwargs)
+
+     # Preserve function name for better IR readability
+     wrapped.__name__ = fn.__name__
+     wrapped.__doc__ = fn.__doc__
+     return wrapped
+
+
+ @constant_p.def_trace
+ def _constant_trace(data: Any) -> el.TraceObject:
+     """Create constant tensor from data.
+
+     Args:
+         data: Scalar, numpy array, or array-like object
+
+     Returns:
+         TraceObject with inferred tensor type
+
+     Raises:
+         TypeError: If data cannot be converted to a tensor
+     """
+     tracer = _current_tracer()
+
+     # Unified numpy conversion for all data types
+     np_array = np.array(data)
+     dtype = _numpy_dtype_to_scalar(np_array.dtype)
+     shape = tuple(np_array.shape)
+     output_type: elt.TensorType = elt.TensorType(dtype, shape)
+
+     # Emit graph operation with data as attribute
+     # Use base64 encoded bytes for efficiency and precision
+     data_b64 = base64.b64encode(np_array.tobytes()).decode("ascii")
+
+     [value] = tracer.graph.add_op(
+         opcode="tensor.constant",
+         inputs=[],
+         output_types=[output_type],
+         attrs={
+             "value_b64": data_b64,
+         },
+         regions=[],
+     )
+
+     return el.TraceObject(value, tracer)
+
+
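
The value_b64 attribute written above can be decoded back to the original array with nothing but numpy and base64 (a sketch of the encoding round-trip; how the runtime actually materializes constants is not shown in this diff):

import base64
import numpy as np

arr = np.array([[1, 2], [3, 4]], dtype=np.int64)
value_b64 = base64.b64encode(arr.tobytes()).decode("ascii")  # what _constant_trace stores

# Decoding needs the dtype and shape recorded in the op's output type.
decoded = np.frombuffer(base64.b64decode(value_b64), dtype=arr.dtype).reshape(arr.shape)
assert (decoded == arr).all()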
+ # Constant cache: Tracer -> { (dtype, shape, bytes) -> Object }
+ _CONSTANT_CACHE: WeakKeyDictionary[
+     el.Tracer, dict[tuple[str, tuple[int, ...], bytes], el.Object]
+ ] = WeakKeyDictionary()
+
+
+ def constant(data: Any) -> el.Object:
+     """Create a tensor constant value.
+
+     This creates a constant tensor that can be used in tensor computations.
+     The constant value is embedded directly into the computation graph.
+     Duplicate constants (same data and shape) are cached per-Tracer to
+     minimize graph size.
+
+     Args:
+         data: Constant data. Can be:
+             - A scalar value (int, float, bool, complex)
+             - A numpy array
+             - Any array-like object that can be converted to numpy
+
+     Returns:
+         Object representing the constant tensor
+
+     Raises:
+         TypeError: If data cannot be converted to a tensor
+
+     Example:
+         >>> x = constant(3.14)  # Scalar constant
+         >>> y = constant(np.array([1, 2, 3]))  # Array constant
+         >>> z = constant([[1, 2], [3, 4]])  # Nested list constant
+     """
+     # Normalize data to numpy
+     np_array = np.array(data)
+
+     # Ensure canonical form for cache key
+     key_shape = tuple(np_array.shape)
+     key_dtype = np_array.dtype
+     # Use simple bytes for cache key. For very large constants this might
+     # be expensive, but typically constants in MPC are small (params, masks).
+     key_bytes = np_array.tobytes()
+
+     try:
+         tracer = _current_tracer()
+     except TypeError:
+         # If no tracer is active (e.g. eager execution), skip caching logic
+         # and fall back to standard bind which will handle eager/trace check.
+         return cast(el.Object, constant_p.bind(np_array))
+
+     inner_key = (str(key_dtype), key_shape, key_bytes)
+
+     tracer_cache: dict[tuple[str, tuple[int, ...], bytes], el.Object] = (
+         _CONSTANT_CACHE.setdefault(tracer, {})
+     )
+     if inner_key in tracer_cache:
+         return tracer_cache[inner_key]
+
+     # Create new constant
+     obj = cast(el.Object, constant_p.bind(np_array))
+
+     # Store in cache
+     tracer_cache[inner_key] = obj
+     return obj
+
+
+ # ==============================================================================
+ # --- Tensor Structural Operations (Element-type agnostic)
+ # ==============================================================================
+
+ transpose_p = el.Primitive[el.Object]("tensor.transpose")
+ reshape_p = el.Primitive[el.Object]("tensor.reshape")
+ concat_p = el.Primitive[el.Object]("tensor.concat")
+ gather_p = el.Primitive[el.Object]("tensor.gather")
+ scatter_p = el.Primitive[el.Object]("tensor.scatter")
+ slice_p = el.Primitive[el.Object]("tensor.slice")
+ elementwise_p = el.Primitive[el.Object]("tensor.elementwise")
+
+
+ class _ElementwiseTracer(el.Tracer):
+     """Tracer for element-wise function body.
+
+     Unwraps TensorType→element type during lift, enabling the traced function
+     to operate on scalar element types instead of full tensors. Non-tensor
+     arguments (scalars, custom types) are passed through unchanged.
+
+     Validates that all tensor inputs have the same shape, tracking the first
+     tensor's shape in _tensor_shape for result type construction.
+     """
+
+     def __init__(self) -> None:
+         """Initialize elementwise tracer."""
+         super().__init__()
+         self._tensor_shape: tuple[int, ...] | None = None
+
+     def _lift_type(self, obj: el.Object) -> elt.BaseType:
+         """Override to unwrap Tensor→element type, keep scalar as-is.
+
+         Args:
+             obj: Object to lift (can be Tensor or Scalar typed)
+
+         Returns:
+             element type (for Tensor) or original type (for Scalar)
+
+         Raises:
+             ValueError: If tensor shapes don't match
+         """
+         obj_type = obj.type
+
+         if isinstance(obj_type, elt.TensorType):
+             # Validate and track shape
+             new_shape = obj_type.shape
+             if self._tensor_shape is None:
+                 self._tensor_shape = new_shape
+             elif self._tensor_shape == new_shape:
+                 pass  # Shapes match
+             elif self._tensor_shape == ():
+                 # Upgrade tracked shape from scalar to tensor
+                 self._tensor_shape = new_shape
+             elif new_shape == ():
+                 # Input is scalar, broadcasts to tracked shape
+                 pass
+             else:
+                 raise ValueError(
+                     f"All tensor arguments must have the same shape. "
+                     f"Expected {self._tensor_shape}, got {obj_type.shape}"
+                 )
+
+             # Unwrap to element type
+             return cast(elt.BaseType, obj_type.element_type)
+         else:
+             # Non-tensor (scalar, custom type) - keep as-is
+             return cast(elt.BaseType, obj_type)
+
+
+ @elementwise_p.def_trace
+ def _elementwise_trace(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
+     """Apply element-wise operation to tensor elements.
+
+     This primitive maps an element-level callable to tensor elements while
+     preserving shape. All tensor arguments must have the same shape.
+     Supports mixing tensor and scalar arguments (scalars passed unchanged to each element).
+
+     Args:
+         fn: Callable/traceable function operating on scalar elements.
+             Must NOT capture any variables (closure-free).
+         *args: Arguments to pass to fn (can be Tensor or Scalar types)
+         **kwargs: Keyword arguments to pass to fn
+
+     Returns:
+         PyTree whose leaves are TraceObjects with tensor types.
+         Each output tensor has the same shape as input tensors,
+         with element types determined by tracing fn.
+
+     Raises:
+         ValueError: If fn captures variables or tensor shapes don't match
+         TypeError: If outputs contain non-scalar types
+     """
+     tracer = _current_tracer()
+
+     # Trace fn with element inputs using custom tracer
+     # The tracer will automatically:
+     # 1. Unwrap Tensor→element, keep Scalar as-is
+     # 2. Validate all tensors have the same shape
+     # 3. Track the tensor shape in _tensor_shape
+     element_tracer = _ElementwiseTracer()
+     traced_fn = element_tracer.run(fn, *args, **kwargs)
+
+     # Get result shape from the tracer (set by first tensor in _lift)
+     if element_tracer._tensor_shape is None:
+         # If no tensor arguments were found, it means we only had
+         # non-tensor arguments (scalars/custom types).
+         # Degrade to scalar operation (shape ()).
+         result_shape: tuple[int, ...] = ()
+     else:
+         result_shape = element_tracer._tensor_shape
+
+     # Check that fn doesn't capture variables (closure-free requirement)
+     if traced_fn.captured:
+         captured_names = [f"{type(obj).__name__}" for obj in traced_fn.captured]
+         raise ValueError(
+             f"elementwise function must not capture variables. "
+             f"Found {len(traced_fn.captured)} captured object(s): {captured_names}. "
+             f"Pass all dependencies as explicit arguments."
+         )
+
+     # Get output type from traced graph
+     if not traced_fn.graph.outputs:
+         raise TypeError("elementwise function must return a value, got empty outputs")
+
+     if traced_fn.out_imms:
+         raise TypeError(
+             "elementwise function outputs must be TraceObjects (no pure Python constants)"
+         )
+
+     output_types: list[elt.BaseType] = []
+     for idx, output_value in enumerate(traced_fn.graph.outputs):
+         output_element_type = output_value.type
+         # Allow rank-0 tensors as scalars (produced by run_jax)
+         if (
+             isinstance(output_element_type, elt.TensorType)
+             and output_element_type.shape == ()
+         ):
+             output_element_type = output_element_type.element_type
+
+         if not isinstance(output_element_type, elt.BaseType):
+             raise TypeError(
+                 "elementwise function must return BaseType leaves, "
+                 f"got {type(output_element_type).__name__} at output index {idx}. "
+                 "Elementwise only supports operations producing valid MPLang types."
+             )
+         output_types.append(elt.TensorType(output_element_type, result_shape))
+     flat_inputs, _ = tree_flatten((args, kwargs))
+     input_values = [
+         value._graph_value for value in flat_inputs if isinstance(value, el.TraceObject)
+     ]
+
+     # Emit graph operation with traced subgraph as region
+     result_values = tracer.graph.add_op(
+         opcode="tensor.elementwise",
+         inputs=input_values,
+         output_types=output_types,
+         attrs={},
+         regions=[traced_fn.graph],
+     )
+
+     return tracer.reconstruct_outputs(
+         traced_fn.out_var_pos,
+         traced_fn.out_imms,
+         traced_fn.out_tree,
+         result_values,
+     )
+
+
+ def elementwise(fn: Callable[..., Any], *inputs: el.Object, **kwargs: Any) -> el.Object:
+     """Apply element-wise operation to tensor elements.
+
+     Maps an element-level callable to tensor elements while preserving shape.
+     All tensor arguments must have the same shape. Allows mixing tensor and
+     scalar arguments (scalars are passed unchanged to fn for each element).
+
+     The function `fn` must be closure-free (no captured variables) - all
+     dependencies must be passed as explicit arguments. This ensures the
+     computation graph captures all data dependencies.
+
+     Type Promotion Rule:
+         If all arguments are scalars, the result will be lifted to a rank-0 tensor (shape=()).
+
+     Args:
+         fn: Callable/traceable function operating on scalar elements.
+             Can be a lambda, regular function, or Primitive.bind.
+             Must not capture variables (closure-free).
+             Must return ScalarType values - no tensor nesting.
+         *inputs: Tensor or Scalar arguments to pass to fn.
+             All tensor inputs must have the same shape.
+         **kwargs: Keyword arguments to pass to fn
+
+     Returns:
+         PyTree whose leaves are Tensors with the same shape as the input tensors.
+         The PyTree structure matches the return value of `fn`.
+         Each leaf has element type determined by fn's corresponding output.
+
+     Raises:
+         ValueError: If fn captures variables or tensor shapes don't match
+         TypeError: If fn returns non-scalar types
+
+     Example:
+         >>> # Element-wise addition with lambda
+         >>> t1 = ...  # Tensor[f32, (10,)]
+         >>> t2 = ...  # Tensor[f32, (10,)]
+         >>> result = elementwise(lambda x, y: x + y, t1, t2)
+         >>> # result: Tensor[f32, (10,)]
+         >>>
+         >>> # PHE encryption: mixing tensor and scalar (key)
+         >>> plaintext = ...  # Tensor[f32, (10,)]
+         >>> public_key = ...  # PHEPublicKey (scalar)
+         >>> ciphertext = elementwise(phe.encrypt, plaintext, public_key)
+         >>> # ciphertext: Tensor[HE[f32], (10,)]
+         >>>
+         >>> # Multiple tensors with same shape
+         >>> t1 = ...  # Tensor[f32, (3, 4)]
+         >>> t2 = ...  # Tensor[f32, (3, 4)]
+         >>> result = elementwise(lambda x, y: x * y, t1, t2)
+         >>> # result: Tensor[f32, (3, 4)]
+         >>>
+         >>> # Tensor-scalar operation
+         >>> tensor = ...  # Tensor[f32, (10,)]
+         >>> scalar = ...  # f32
+         >>> result = elementwise(lambda x, s: x * s, tensor, scalar)
+         >>> # result: Tensor[f32, (10,)]
+     """
+     return elementwise_p.bind(fn, *inputs, **kwargs)  # type: ignore[no-any-return]
+
+
+ @transpose_p.def_abstract_eval
+ def _transpose_ae(input: elt.TensorType, *, perm: tuple[int, ...]) -> elt.TensorType:
+     """Transpose tensor dimensions.
+
+     Args:
+         input: Input tensor type
+         perm: Permutation of dimensions (e.g., (1, 0) for 2D transpose)
+
+     Returns:
+         Tensor type with permuted shape
+
+     Raises:
+         TypeError: If input is not a TensorType
+         ValueError: If permutation is invalid
+     """
+     if not isinstance(input, elt.TensorType):
+         raise TypeError(f"transpose expects TensorType, got {type(input)}")
+
+     # Shape is always a tuple (TensorType enforces ranked tensors)
+     rank = len(input.shape)
+     if len(perm) != rank:
+         raise ValueError(
+             f"Permutation length {len(perm)} doesn't match tensor rank {rank}"
+         )
+
+     if set(perm) != set(range(rank)):
+         raise ValueError(
+             f"Invalid permutation {perm}, expected permutation of 0..{rank - 1}"
+         )
+
+     # Apply permutation to shape
+     new_shape = tuple(input.shape[i] for i in perm)
+     return elt.TensorType(input.element_type, new_shape)
+
+
+ @reshape_p.def_abstract_eval
+ def _reshape_ae(input: elt.TensorType, new_shape: tuple[int, ...]) -> elt.TensorType:
+     """Reshape tensor to new shape.
+
+     Args:
+         input: Input tensor type
+         new_shape: Target shape (can contain -1 for inferred dimension)
+
+     Returns:
+         Tensor type with new shape
+
+     Raises:
+         TypeError: If input is not a TensorType
+         ValueError: If reshape is invalid
+     """
+     if not isinstance(input, elt.TensorType):
+         raise TypeError(f"reshape expects TensorType, got {type(input)}")
+
+     # Validate new_shape
+     if not isinstance(new_shape, tuple):
+         raise TypeError(f"new_shape must be tuple, got {type(new_shape)}")
+
+     neg_one_count = sum(1 for d in new_shape if d == -1)
+     if neg_one_count > 1:
+         raise ValueError("new_shape can contain at most one -1 dimension")
+
+     # Compute output shape
+     if input.is_fully_static:
+         # Input size is known - we can infer or validate
+         input_size = math.prod(input.shape)
+
+         if neg_one_count == 0:
+             # No -1: validate total size matches
+             new_size = math.prod(new_shape)
+             if input_size != new_size:
+                 raise ValueError(
+                     f"Cannot reshape tensor of size {input_size} to shape {new_shape} (size {new_size})"
+                 )
+             output_shape = new_shape
+         else:
+             # One -1: infer that dimension
+             known_size = math.prod(d for d in new_shape if d != -1)
+             if known_size == 0:
+                 raise ValueError("Cannot reshape: new_shape has zero-size dimensions")
+             if input_size % known_size != 0:
+                 raise ValueError(
+                     f"Cannot infer dimension: {input_size} is not divisible by {known_size}"
+                 )
+             inferred_dim = input_size // known_size
+             output_shape = tuple(inferred_dim if d == -1 else d for d in new_shape)
+     else:
+         # Input has dynamic dims - output inherits uncertainty
+         # Keep -1 in output (we cannot infer at trace time)
+         output_shape = new_shape
+
+     return elt.TensorType(input.element_type, output_shape)
+
+
+ @concat_p.def_abstract_eval
+ def _concat_ae(in_types: list[elt.BaseType], *, axis: int = 0) -> elt.TensorType:
+     """Concatenate tensors along axis.
+
+     Args:
+         in_types: List of input tensor types
+         axis: Axis along which to concatenate (default: 0)
+
+     Returns:
+         Concatenated tensor type
+
+     Raises:
+         TypeError: If inputs are not TensorTypes
+         ValueError: If shapes are incompatible
+     """
+     if not in_types:
+         raise ValueError("concat requires at least one input tensor")
+
+     # Verify all inputs are TensorType
+     for i, t in enumerate(in_types):
+         if not isinstance(t, elt.TensorType):
+             raise TypeError(f"Input {i} is not TensorType: {type(t)}")
+
+     tensor_types = cast(list[elt.TensorType], in_types)
+
+     # Check element types match
+     element_type = tensor_types[0].element_type
+     for i, t in enumerate(tensor_types[1:], 1):
+         if t.element_type != element_type:
+             raise TypeError(
+                 f"Element type mismatch: tensor 0 has {element_type}, "
+                 f"tensor {i} has {t.element_type}"
+             )
+
+     # All tensors are ranked (shape is always a tuple)
+     first_shape = tensor_types[0].shape
+     rank = len(first_shape)
+
+     # Normalize negative axis
+     normalized_axis = axis if axis >= 0 else rank + axis
+     if normalized_axis < 0 or normalized_axis >= rank:
+         raise ValueError(f"axis {axis} out of bounds for rank {rank}")
+
+     # Check shape compatibility
+     result_shape = list(first_shape)
+     concat_dim_size = first_shape[normalized_axis]
+
+     for i, t in enumerate(tensor_types[1:], 1):
+         if len(t.shape) != rank:
+             raise ValueError(
+                 f"Rank mismatch: tensor 0 has rank {rank}, tensor {i} has rank {len(t.shape)}"
+             )
+
+         for dim_idx in range(rank):
+             if dim_idx == normalized_axis:
+                 # Concatenation dimension
+                 if concat_dim_size == -1 or t.shape[dim_idx] == -1:
+                     concat_dim_size = -1  # Result is dynamic
+                 else:
+                     concat_dim_size += t.shape[dim_idx]
+             else:
+                 # Other dimensions must match (or be dynamic)
+                 if (
+                     result_shape[dim_idx] != -1
+                     and t.shape[dim_idx] != -1
+                     and result_shape[dim_idx] != t.shape[dim_idx]
+                 ):
+                     raise ValueError(
+                         f"Dimension {dim_idx} mismatch: tensor 0 has {result_shape[dim_idx]}, "
+                         f"tensor {i} has {t.shape[dim_idx]}"
+                     )
+                 if t.shape[dim_idx] == -1:
+                     result_shape[dim_idx] = -1
+
+     result_shape[normalized_axis] = concat_dim_size
+     return elt.TensorType(element_type, tuple(result_shape))
+
+
+ @gather_p.def_abstract_eval
+ def _gather_ae(
+     input: elt.TensorType, index: elt.TensorType, *, axis: int = 0
+ ) -> elt.TensorType:
+     """Gather elements along axis using indices.
+
+     Args:
+         input: Input tensor type
+         index: Integer indices tensor type
+         axis: Axis along which to gather
+
+     Returns:
+         Tensor type with gathered elements
+
+     Raises:
+         TypeError: If inputs are not TensorTypes or indices are not integer
+         ValueError: If axis is invalid
+     """
+     if not isinstance(input, elt.TensorType):
+         raise TypeError(f"gather expects TensorType, got {type(input)}")
+     if not isinstance(index, elt.TensorType):
+         raise TypeError(f"indices must be TensorType, got {type(index)}")
+
+     # Verify indices are integer type (ScalarType includes IntegerType)
+     if not isinstance(index.element_type, elt.IntegerType):
+         raise TypeError(
+             f"indices must have IntegerType element, got {type(index.element_type).__name__}"
+         )
+     # Check for 32-bit or 64-bit integers
+     if index.element_type.bitwidth not in (32, 64):
+         raise TypeError(
+             f"indices must be 32-bit or 64-bit integers (i32/i64/u32/u64), got {index.element_type}"
+         )
+
+     # Both inputs must be ranked (shape is always a tuple now)
+     rank = len(input.shape)
+     normalized_axis = axis if axis >= 0 else rank + axis
+     if normalized_axis < 0 or normalized_axis >= rank:
+         raise ValueError(f"axis {axis} out of bounds for rank {rank}")
+
+     # Result shape: replace axis dimension with indices shape
+     result_shape = (
+         input.shape[:normalized_axis] + index.shape + input.shape[normalized_axis + 1 :]
+     )
+     return elt.TensorType(input.element_type, result_shape)
+
+
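
The result-shape rule in _gather_ae above (replace the gathered axis with the index shape) matches numpy.take, which can serve as a plaintext reference (illustration only, not part of the package):

import numpy as np

x = np.arange(12).reshape(3, 4)    # shape (3, 4)
idx = np.array([[0, 2], [1, 1]])   # shape (2, 2), integer indices
y = np.take(x, idx, axis=1)        # result shape = (3,) + (2, 2) = (3, 2, 2)
assert y.shape == (3, 2, 2)        # input.shape[:axis] + index.shape + input.shape[axis + 1:]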
+ @scatter_p.def_abstract_eval
+ def _scatter_ae(
+     tensor_type: elt.TensorType,
+     indices_type: elt.TensorType,
+     updates_type: elt.TensorType,
+     axis: int = 0,
+ ) -> elt.TensorType:
+     """Scatter updates into tensor at indices.
+
+     Args:
+         tensor_type: Input tensor type
+         indices_type: Integer indices tensor type
+         updates_type: Updates tensor type
+         axis: Axis along which to scatter
+
+     Returns:
+         Tensor type (same as input)
+
+     Raises:
+         TypeError: If inputs are not compatible
+         ValueError: If shapes are incompatible
+     """
+     if not isinstance(tensor_type, elt.TensorType):
+         raise TypeError(f"scatter expects TensorType, got {type(tensor_type)}")
+     if not isinstance(indices_type, elt.TensorType):
+         raise TypeError(f"indices must be TensorType, got {type(indices_type)}")
+     if not isinstance(updates_type, elt.TensorType):
+         raise TypeError(f"updates must be TensorType, got {type(updates_type)}")
+
+     # Verify element types match
+     if updates_type.element_type != tensor_type.element_type:
+         raise TypeError(
+             f"Element type mismatch: tensor has {tensor_type.element_type}, "
+             f"updates has {updates_type.element_type}"
+         )
+
+     # Scatter returns same type as input
+     return tensor_type
+
+
+ @slice_p.def_abstract_eval
+ def _slice_ae(
+     tensor_type: elt.TensorType,
+     starts: tuple[int, ...],
+     ends: tuple[int, ...],
+     strides: tuple[int, ...] | None = None,
+ ) -> elt.TensorType:
+     """Slice tensor along dimensions.
+
+     Args:
+         tensor_type: Input tensor type
+         starts: Start indices for each dimension
+         ends: End indices for each dimension
+         strides: Stride for each dimension (defaults to 1)
+
+     Returns:
+         Sliced tensor type
+
+     Raises:
+         TypeError: If input is not TensorType
+         ValueError: If slice parameters are invalid
+     """
+     if not isinstance(tensor_type, elt.TensorType):
+         raise TypeError(f"slice expects TensorType, got {type(tensor_type)}")
+
+     # Tensor is always ranked (shape is always a tuple)
+     rank = len(tensor_type.shape)
+     if len(starts) != rank or len(ends) != rank:
+         raise ValueError(
+             f"starts and ends must have length {rank}, got {len(starts)} and {len(ends)}"
+         )
+
+     if strides is None:
+         strides = tuple([1] * rank)
+     elif len(strides) != rank:
+         raise ValueError(f"strides must have length {rank}, got {len(strides)}")
+
+     # Compute result shape
+     result_shape = []
+     for dim_idx in range(rank):
+         dim_size = tensor_type.shape[dim_idx]
+         if dim_size == -1:
+             # Dynamic dimension - result is also dynamic
+             result_shape.append(-1)
+         else:
+             # Static dimension - compute slice size
+             start = starts[dim_idx]
+             end = ends[dim_idx]
+             stride = strides[dim_idx]
+
+             if stride <= 0:
+                 raise ValueError(
+                     f"stride must be positive, got {stride} at dim {dim_idx}"
+                 )
+
+             # Handle negative indices
+             if start < 0:
+                 start = max(0, dim_size + start)
+             if end < 0:
+                 end = max(0, dim_size + end)
+
+             # Clamp to valid range
+             start = max(0, min(start, dim_size))
+             end = max(0, min(end, dim_size))
+
+             # Compute slice length
+             if end <= start:
+                 slice_len = 0
+             else:
+                 slice_len = (end - start + stride - 1) // stride
+
+             result_shape.append(slice_len)
+
+     return elt.TensorType(tensor_type.element_type, tuple(result_shape))
+
+
+ # User-facing API
+ def transpose(tensor: el.Object, perm: tuple[int, ...]) -> el.Object:
+     """Transpose tensor dimensions.
+
+     Args:
+         tensor: Input tensor
+         perm: Permutation of dimensions
+
+     Returns:
+         Transposed tensor
+
+     Example:
+         >>> x = constant([[1, 2], [3, 4]])  # shape (2, 2)
+         >>> y = transpose(x, (1, 0))  # shape (2, 2), transposed
+     """
+     return transpose_p.bind(tensor, perm=perm)  # type: ignore[no-any-return]
+
+
+ def reshape(tensor: el.Object, new_shape: tuple[int, ...]) -> el.Object:
+     """Reshape tensor to new shape.
+
+     Args:
+         tensor: Input tensor
+         new_shape: Target shape (can contain -1 for inferred dimension)
+
+     Returns:
+         Reshaped tensor
+
+     Example:
+         >>> x = constant([1, 2, 3, 4, 5, 6])  # shape (6,)
+         >>> y = reshape(x, (2, 3))  # shape (2, 3)
+         >>> z = reshape(x, (2, -1))  # shape (2, 3), -1 inferred
+     """
+     return reshape_p.bind(tensor, new_shape=new_shape)  # type: ignore[no-any-return]
+
+
+ def concat(tensors: list[el.Object], axis: int = 0) -> el.Object:
+     """Concatenate tensors along axis.
+
+     Args:
+         tensors: List of tensors to concatenate
+         axis: Axis along which to concatenate
+
+     Returns:
+         Concatenated tensor
+
+     Example:
+         >>> x = constant([1, 2, 3])
+         >>> y = constant([4, 5, 6])
+         >>> z = concat([x, y], axis=0)  # [1, 2, 3, 4, 5, 6]
+     """
+     return concat_p.bind(*tensors, axis=axis)  # type: ignore[no-any-return]
+
+
+ def gather(tensor: el.Object, indices: el.Object, axis: int = 0) -> el.Object:
+     """Gather elements along axis using indices.
+
+     Args:
+         tensor: Input tensor
+         indices: Integer indices tensor
+         axis: Axis along which to gather
+
+     Returns:
+         Gathered tensor
+
+     Example:
+         >>> x = constant([10, 20, 30, 40])
+         >>> idx = constant([0, 2, 1])
+         >>> y = gather(x, idx)  # [10, 30, 20]
+     """
+     return gather_p.bind(tensor, indices, axis=axis)  # type: ignore[no-any-return]
+
+
+ def scatter(
+     tensor: el.Object,
+     indices: el.Object,
+     updates: el.Object,
+     axis: int = 0,
+ ) -> el.Object:
+     """Scatter updates into tensor at indices.
+
+     Args:
+         tensor: Input tensor
+         indices: Integer indices tensor
+         updates: Updates tensor
+         axis: Axis along which to scatter
+
+     Returns:
+         Updated tensor
+
+     Example:
+         >>> x = constant([1, 2, 3, 4])
+         >>> idx = constant([0, 2])
+         >>> updates = constant([10, 30])
+         >>> y = scatter(x, idx, updates)  # [10, 2, 30, 4]
+     """
+     return scatter_p.bind(tensor, indices, updates, axis=axis)  # type: ignore[no-any-return]
+
+
+ def slice_tensor(
+     tensor: el.Object,
+     starts: tuple[int, ...],
+     ends: tuple[int, ...],
+     strides: tuple[int, ...] | None = None,
+ ) -> el.Object:
+     """Slice tensor along dimensions.
+
+     Args:
+         tensor: Input tensor
+         starts: Start indices for each dimension
+         ends: End indices for each dimension
+         strides: Stride for each dimension (defaults to 1)
+
+     Returns:
+         Sliced tensor
+
+     Example:
+         >>> x = constant([[1, 2, 3], [4, 5, 6]])
+         >>> y = slice_tensor(x, (0, 1), (2, 3))  # [[2, 3], [5, 6]]
+     """
+     return slice_p.bind(tensor, starts=starts, ends=ends, strides=strides)  # type: ignore[no-any-return]
+
+
+ # ==============================================================================
+ # --- Type Reinterpretation (via run_jax)
+ # ==============================================================================
+
+
+ def bitcast(x: el.Object, dtype: elt.ScalarType) -> el.Object:
+     """Reinterpret tensor bytes as a different dtype.
+
+     This is a zero-copy (at execution time) type reinterpretation that views
+     the same underlying bytes as a different element type. The total byte
+     count must remain the same.
+
+     This follows LLVM/MLIR `bitcast` semantics: the operation produces a new
+     SSA value with different type but same bit representation.
+
+     Args:
+         x: Input tensor.
+         dtype: Target element type (e.g., elt.u64, elt.u8, elt.i32).
+
+     Returns:
+         Tensor with same bytes reinterpreted as dtype.
+         Shape changes to preserve total bytes.
+
+     Example:
+         >>> # Tensor[u8, (8,)] -> Tensor[u64, (1,)]
+         >>> packed = tensor.bitcast(bytes_tensor, elt.u64)
+         >>> # Tensor[u64, (10, 2)] -> Tensor[u8, (10, 16)]
+         >>> unpacked = tensor.bitcast(u64_tensor, elt.u8)
+     """
+     from typing import cast
+
+     jax_dtype = dtypes.to_jax(dtype)
+
+     def _bitcast(arr: Any) -> Any:
+         return arr.view(jax_dtype)
+
+     return cast(el.Object, run_jax(_bitcast, x))
+
+
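
The byte-preserving shape change in the bitcast examples above mirrors numpy's ndarray.view, which the JAX-backed _bitcast relies on (plaintext illustration only, not part of the package):

import numpy as np

u8 = np.arange(16, dtype=np.uint8).reshape(2, 8)
u64 = u8.view(np.uint64)   # last dimension shrinks by the itemsize ratio: (2, 8) -> (2, 1)
assert u64.shape == (2, 1)
back = u64.view(np.uint8)  # (2, 1) -> (2, 8), same underlying bytes
assert (back == u8).all()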
+ __all__ = [
+     "RunJaxCompilation",
+     "bitcast",
+     "concat",
+     "concat_p",
+     "constant",
+     "constant_p",
+     "elementwise",
+     "elementwise_p",
+     "gather",
+     "gather_p",
+     "get_run_jax_compilation",
+     "jax_fn",
+     "reshape",
+     "reshape_p",
+     "run_jax",
+     "run_jax_p",
+     "scatter",
+     "scatter_p",
+     "slice_p",
+     "slice_tensor",
+     "transpose",
+     "transpose_p",
+ ]