mplang-nightly 0.1.dev192-py3-none-any.whl → 0.1.dev268-py3-none-any.whl
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/dialects/tensor.py (new file)
@@ -0,0 +1,1175 @@
# Copyright 2025 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tensor dialect: tensor ops backed by plaintext/private JAX execution.

Design Philosophy
-----------------
This dialect is intentionally *lightweight* — it focuses on **structural/shape
operations** (slice, reshape, transpose, gather, scatter, concat) rather than
full-fledged element-wise arithmetic.

Why not add bitwise_and / bitwise_or / arithmetic primitives here?

1. **Shape Dialect**: The primitives defined here perform *index arithmetic* on
   tensor metadata (offsets, strides, dim sizes). They don't interpret element
   values — that's left to the backend (JAX/XLA).

2. **Delegate to run_jax**: For element-wise logic (bitwise ops, arithmetic),
   use `tensor.run_jax(jnp.bitwise_xor, a, b)`. This leverages JAX's mature XLA
   backend without duplicating op definitions or abstract_eval rules for every
   possible JAX op.

3. **Type Preservation**: `run_jax` infers output types from JAX's shape/dtype
   inference, avoiding the need to re-implement type rules for hundreds of ops.

For domain-specific ops (GF(2^128) mul, AES expand), use dedicated dialects
like `field` which have optimized C++ kernel backends.

Helper Functions
----------------
- `bitcast(x, dtype)`: Type reinterpretation (SSA-safe, same bytes).
- For random tensor generation, see `crypto.random_tensor`.
"""

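# Annotation (not part of the packaged file): a minimal usage sketch of the
# "delegate to run_jax" approach described above, assuming an active Tracer
# context and that `a` and `b` are traced tensor Objects of equal shape/dtype:
#
#     import jax.numpy as jnp
#     from mplang.v2.dialects import tensor
#
#     c = tensor.run_jax(jnp.bitwise_xor, a, b)  # element-wise op via XLA
#     d = tensor.run_jax(lambda x: x * x, c)     # output type inferred by JAX
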
from __future__ import annotations

import base64
import math
from collections.abc import Callable
from dataclasses import dataclass
from itertools import count
from typing import Any, cast
from weakref import WeakKeyDictionary

import jax
import numpy as np
from jax import ShapeDtypeStruct
from jax.tree_util import PyTreeDef, tree_flatten

import mplang.v2.edsl as el
import mplang.v2.edsl.typing as elt
from mplang.v1.utils.func_utils import normalize_fn
from mplang.v2.dialects import dtypes

run_jax_p = el.Primitive[Any]("tensor.run_jax")
constant_p = el.Primitive[el.Object]("tensor.constant")


@dataclass
class RunJaxCompilation:
    """Compilation record for tensor.run_jax functions.

    Stores both the compilation artifacts (StableHLO MLIR, types, tree structure)
    and metadata needed for execution (arg_keep_map for JAX's unused param elimination).
    """

    fn: Callable[..., Any]
    stablehlo: str
    out_tree: PyTreeDef
    output_types: list[elt.BaseType]
    arg_keep_map: list[int] | None = None


_RUN_JAX_REGISTRY: dict[str, RunJaxCompilation] = {}
_RUN_JAX_ID_GENERATOR = count()


def _current_tracer() -> el.Tracer:
    ctx = el.get_current_context()
    if not isinstance(ctx, el.Tracer):
        raise TypeError(f"Expected Tracer context, got {type(ctx)}")
    return ctx


def _scalar_to_numpy_dtype(scalar: elt.ScalarType) -> np.dtype[np.generic]:
    return np.dtype(dtypes.to_jax(scalar))  # type: ignore[no-any-return]


def _numpy_dtype_to_scalar(dtype: Any) -> elt.ScalarType:
    return dtypes.from_dtype(dtype)


def _tensor_type_to_placeholder(
    tensor_type: elt.TensorType | elt.ScalarType,
) -> ShapeDtypeStruct:
    if isinstance(tensor_type, elt.ScalarType):
        # Treat scalar as rank-0 tensor
        dtype = _scalar_to_numpy_dtype(tensor_type)
        return ShapeDtypeStruct((), dtype)

    normalized_shape: list[int] = []
    for idx, dim in enumerate(tensor_type.shape):
        if dim is None:
            raise TypeError(
                f"tensor.run_jax argument dimension {idx} is None; "
                "please provide a static dimension."
            )
        if dim == -1:
            raise TypeError(
                "tensor.run_jax does not yet support dynamic (-1) dimensions"
            )
        if dim <= 0 and dim != 0:
            raise ValueError(f"Invalid tensor dimension {dim}")
        normalized_shape.append(dim)
    # element_type must be ScalarType for conversion to numpy dtype
    if not isinstance(tensor_type.element_type, elt.ScalarType):
        raise TypeError(
            f"Expected ScalarType element, got {type(tensor_type.element_type)}"
        )
    dtype = _scalar_to_numpy_dtype(tensor_type.element_type)
    return ShapeDtypeStruct(tuple(normalized_shape), dtype)


def _out_info_to_edsl(out_info: Any) -> elt.TensorType:
    scalar = _numpy_dtype_to_scalar(out_info.dtype)
    shape = tuple(out_info.shape)
    return elt.TensorType(scalar, shape)


def _register_compilation(compilation: RunJaxCompilation) -> str:
    compilation_id = f"tensor.run_jax::{next(_RUN_JAX_ID_GENERATOR)}"
    _RUN_JAX_REGISTRY[compilation_id] = compilation
    return compilation_id


def get_run_jax_compilation(compilation_id: str) -> RunJaxCompilation:
    """Get compilation record by ID.

    Returns:
        The compilation record containing StableHLO MLIR, types, and metadata.
    """
    try:
        return _RUN_JAX_REGISTRY[compilation_id]
    except KeyError as exc:
        raise KeyError(
            f"Unknown tensor.run_jax compilation id '{compilation_id}'"
        ) from exc


def _compile_run_jax(
    fn: Callable[..., Any],
    normalized_fn: Callable[..., Any],
    placeholders: list[ShapeDtypeStruct],
) -> tuple[RunJaxCompilation, str]:
    """Compile JAX function to StableHLO MLIR.

    Pipeline: jit → lower → StableHLO MLIR

    Args:
        fn: Original JAX function
        normalized_fn: Function accepting list of variables (for JAX lower API)
        placeholders: JAX ShapeDtypeStruct list for lowering

    Returns:
        Tuple of (compilation record, compilation_id)
    """
    jitted = jax.jit(normalized_fn)
    lowered = jitted.lower(placeholders)
    stablehlo_text = str(lowered.compiler_ir("stablehlo"))

    # Handle JAX's unused parameter elimination
    arg_keep_map: list[int] | None = None
    try:
        compile_args = lowered._lowering.compile_args
        kept_var_idx = compile_args["kept_var_idx"]
        kept_indices = sorted(kept_var_idx)
        if len(kept_indices) < len(placeholders):
            arg_keep_map = kept_indices
    except (AttributeError, KeyError, TypeError) as e:
        raise RuntimeError(
            f"Cannot access JAX's kept_var_idx for unused parameter handling. "
            f"JAX may have optimized away unused parameters. Error: {e}"
        ) from e

    # Convert output info to EDSL types
    output_types: list[elt.BaseType]
    if isinstance(lowered.out_info, tuple):
        output_types = [_out_info_to_edsl(info) for info in lowered.out_info]
    else:
        output_types = [_out_info_to_edsl(lowered.out_info)]

    compilation = RunJaxCompilation(
        fn=fn,
        stablehlo=stablehlo_text,
        out_tree=lowered.out_tree,
        output_types=output_types,
        arg_keep_map=arg_keep_map,
    )
    compilation_id = _register_compilation(compilation)
    return compilation, compilation_id


@run_jax_p.def_trace
def _run_jax_trace(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Trace tensor.run_jax primitive.

    Compiles JAX function to StableHLO and emits graph operation.

    Args:
        fn: JAX-compatible callable
        *args: Positional arguments (TraceObjects become dynamic, others static)
        **kwargs: Keyword arguments (TraceObjects become dynamic, others static)

    Returns:
        PyTree of TraceObjects matching fn's output structure
    """
    if not callable(fn):
        raise TypeError(f"run_jax expects callable, got {type(fn)}")

    tracer = _current_tracer()

    # Extract TraceObjects (dynamic args) from args/kwargs
    def _is_trace_object(value: Any) -> bool:
        return isinstance(value, el.TraceObject)

    normalized_fn, variables = normalize_fn(fn, args, kwargs, _is_trace_object)

    # Convert TraceObjects to JAX placeholders for compilation
    placeholders: list[ShapeDtypeStruct] = []
    for var in variables:
        if not isinstance(var, el.TraceObject):
            raise TypeError(f"Expected TraceObject, got {type(var)}")
        if not isinstance(var.type, (elt.TensorType, elt.ScalarType)):
            raise TypeError(f"run_jax only supports Tensors/Scalars, got {var.type}")
        placeholders.append(_tensor_type_to_placeholder(var.type))

    # Compile to StableHLO
    compilation, text_ref = _compile_run_jax(fn, normalized_fn, placeholders)

    # Emit graph operation
    input_values = [var._graph_value for var in variables]
    result_values = tracer.graph.add_op(
        opcode="tensor.run_jax",
        inputs=input_values,
        output_types=compilation.output_types,
        attrs={
            "ir_type": "stablehlo",
            "text_ref": text_ref,
            "stablehlo_code": compilation.stablehlo,
            "arg_keep_map": compilation.arg_keep_map,
        },
    )

    # Reconstruct output PyTree (JAX outputs are all variables)
    out_var_pos = list(range(len(result_values)))
    return tracer.reconstruct_outputs(
        out_var_pos, [], compilation.out_tree, result_values
    )


def run_jax(
    fn: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Trace a tensor JAX function as a graph op.

    Args:
        fn: Callable that accepts JAX-compatible tensors.
        *args: Positional arguments to the callable. TraceObjects are treated
            as dynamic tensors, while non-Object values become static parameters.
        **kwargs: Keyword arguments for the callable. TraceObjects are treated
            as dynamic tensors, while non-Object values become static parameters.

    Returns:
        PyTree of TraceObjects with the same structure as fn's output.
    """
    return run_jax_p.bind(fn, *args, **kwargs)


def jax_fn(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a JAX function for use with pcall.

    This creates a callable that can be passed to pcall primitives,
    providing a cleaner user interface:

    Instead of:
        pcall_static((0,), lambda x, y: run_jax(native_fn, x, y), x_p0, y_p0)

    You can write:
        pcall_static((0,), jax_fn(native_fn), x_p0, y_p0)

    Args:
        fn: JAX function to wrap

    Returns:
        Wrapped function that calls run_jax when invoked

    Example:
        >>> def square(x):
        ...     return jnp.square(x)
        >>> wrapped = jax_fn(square)
        >>> result = pcall_static((0,), wrapped, x_p0)
    """

    def wrapped(*args: Any, **kwargs: Any) -> Any:
        return run_jax(fn, *args, **kwargs)

    # Preserve function name for better IR readability
    wrapped.__name__ = fn.__name__
    wrapped.__doc__ = fn.__doc__
    return wrapped


@constant_p.def_trace
def _constant_trace(data: Any) -> el.TraceObject:
    """Create constant tensor from data.

    Args:
        data: Scalar, numpy array, or array-like object

    Returns:
        TraceObject with inferred tensor type

    Raises:
        TypeError: If data cannot be converted to a tensor
    """
    tracer = _current_tracer()

    # Unified numpy conversion for all data types
    np_array = np.array(data)
    dtype = _numpy_dtype_to_scalar(np_array.dtype)
    shape = tuple(np_array.shape)
    output_type: elt.TensorType = elt.TensorType(dtype, shape)

    # Emit graph operation with data as attribute
    # Use base64 encoded bytes for efficiency and precision
    data_b64 = base64.b64encode(np_array.tobytes()).decode("ascii")

    [value] = tracer.graph.add_op(
        opcode="tensor.constant",
        inputs=[],
        output_types=[output_type],
        attrs={
            "value_b64": data_b64,
        },
        regions=[],
    )

    return el.TraceObject(value, tracer)


# Constant cache: Tracer -> { (dtype, shape, bytes) -> Object }
_CONSTANT_CACHE: WeakKeyDictionary[
    el.Tracer, dict[tuple[str, tuple[int, ...], bytes], el.Object]
] = WeakKeyDictionary()


def constant(data: Any) -> el.Object:
    """Create a tensor constant value.

    This creates a constant tensor that can be used in tensor computations.
    The constant value is embedded directly into the computation graph.
    Duplicate constants (same data and shape) are cached per-Tracer to
    minimize graph size.

    Args:
        data: Constant data. Can be:
            - A scalar value (int, float, bool, complex)
            - A numpy array
            - Any array-like object that can be converted to numpy

    Returns:
        Object representing the constant tensor

    Raises:
        TypeError: If data cannot be converted to a tensor

    Example:
        >>> x = constant(3.14)  # Scalar constant
        >>> y = constant(np.array([1, 2, 3]))  # Array constant
        >>> z = constant([[1, 2], [3, 4]])  # Nested list constant
    """
    # Normalize data to numpy
    np_array = np.array(data)

    # Ensure canonical form for cache key
    key_shape = tuple(np_array.shape)
    key_dtype = np_array.dtype
    # Use simple bytes for cache key. For very large constants this might
    # be expensive, but typically constants in MPC are small (params, masks).
    key_bytes = np_array.tobytes()

    try:
        tracer = _current_tracer()
    except TypeError:
        # If no tracer is active (e.g. eager execution), skip caching logic
        # and fall back to standard bind which will handle eager/trace check.
        return cast(el.Object, constant_p.bind(np_array))

    inner_key = (str(key_dtype), key_shape, key_bytes)

    tracer_cache: dict[tuple[str, tuple[int, ...], bytes], el.Object] = (
        _CONSTANT_CACHE.setdefault(tracer, {})
    )
    if inner_key in tracer_cache:
        return tracer_cache[inner_key]

    # Create new constant
    obj = cast(el.Object, constant_p.bind(np_array))

    # Store in cache
    tracer_cache[inner_key] = obj
    return obj


# ==============================================================================
# --- Tensor Structural Operations (Element-type agnostic)
# ==============================================================================

transpose_p = el.Primitive[el.Object]("tensor.transpose")
reshape_p = el.Primitive[el.Object]("tensor.reshape")
concat_p = el.Primitive[el.Object]("tensor.concat")
gather_p = el.Primitive[el.Object]("tensor.gather")
scatter_p = el.Primitive[el.Object]("tensor.scatter")
slice_p = el.Primitive[el.Object]("tensor.slice")
elementwise_p = el.Primitive[el.Object]("tensor.elementwise")


class _ElementwiseTracer(el.Tracer):
    """Tracer for element-wise function body.

    Unwraps TensorType→element type during lift, enabling the traced function
    to operate on scalar element types instead of full tensors. Non-tensor
    arguments (scalars, custom types) are passed through unchanged.

    Validates that all tensor inputs have the same shape, tracking the first
    tensor's shape in _tensor_shape for result type construction.
    """

    def __init__(self) -> None:
        """Initialize elementwise tracer."""
        super().__init__()
        self._tensor_shape: tuple[int, ...] | None = None

    def _lift_type(self, obj: el.Object) -> elt.BaseType:
        """Override to unwrap Tensor→element type, keep scalar as-is.

        Args:
            obj: Object to lift (can be Tensor or Scalar typed)

        Returns:
            element type (for Tensor) or original type (for Scalar)

        Raises:
            ValueError: If tensor shapes don't match
        """
        obj_type = obj.type

        if isinstance(obj_type, elt.TensorType):
            # Validate and track shape
            new_shape = obj_type.shape
            if self._tensor_shape is None:
                self._tensor_shape = new_shape
            elif self._tensor_shape == new_shape:
                pass  # Shapes match
            elif self._tensor_shape == ():
                # Upgrade tracked shape from scalar to tensor
                self._tensor_shape = new_shape
            elif new_shape == ():
                # Input is scalar, broadcasts to tracked shape
                pass
            else:
                raise ValueError(
                    f"All tensor arguments must have the same shape. "
                    f"Expected {self._tensor_shape}, got {obj_type.shape}"
                )

            # Unwrap to element type
            return cast(elt.BaseType, obj_type.element_type)
        else:
            # Non-tensor (scalar, custom type) - keep as-is
            return cast(elt.BaseType, obj_type)


@elementwise_p.def_trace
def _elementwise_trace(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Apply element-wise operation to tensor elements.

    This primitive maps an element-level callable to tensor elements while
    preserving shape. All tensor arguments must have the same shape.
    Supports mixing tensor and scalar arguments (scalars passed unchanged to each element).

    Args:
        fn: Callable/traceable function operating on scalar elements.
            Must NOT capture any variables (closure-free).
        *args: Arguments to pass to fn (can be Tensor or Scalar types)
        **kwargs: Keyword arguments to pass to fn

    Returns:
        PyTree whose leaves are TraceObjects with tensor types.
        Each output tensor has the same shape as input tensors,
        with element types determined by tracing fn.

    Raises:
        ValueError: If fn captures variables or tensor shapes don't match
        TypeError: If outputs contain non-scalar types
    """
    tracer = _current_tracer()

    # Trace fn with element inputs using custom tracer
    # The tracer will automatically:
    # 1. Unwrap Tensor→element, keep Scalar as-is
    # 2. Validate all tensors have the same shape
    # 3. Track the tensor shape in _tensor_shape
    element_tracer = _ElementwiseTracer()
    traced_fn = element_tracer.run(fn, *args, **kwargs)

    # Get result shape from the tracer (set by first tensor in _lift)
    if element_tracer._tensor_shape is None:
        # If no tensor arguments were found, it means we only had
        # non-tensor arguments (scalars/custom types).
        # Degrade to scalar operation (shape ()).
        result_shape: tuple[int, ...] = ()
    else:
        result_shape = element_tracer._tensor_shape

    # Check that fn doesn't capture variables (closure-free requirement)
    if traced_fn.captured:
        captured_names = [f"{type(obj).__name__}" for obj in traced_fn.captured]
        raise ValueError(
            f"elementwise function must not capture variables. "
            f"Found {len(traced_fn.captured)} captured object(s): {captured_names}. "
            f"Pass all dependencies as explicit arguments."
        )

    # Get output type from traced graph
    if not traced_fn.graph.outputs:
        raise TypeError("elementwise function must return a value, got empty outputs")

    if traced_fn.out_imms:
        raise TypeError(
            "elementwise function outputs must be TraceObjects (no pure Python constants)"
        )

    output_types: list[elt.BaseType] = []
    for idx, output_value in enumerate(traced_fn.graph.outputs):
        output_element_type = output_value.type
        # Allow rank-0 tensors as scalars (produced by run_jax)
        if (
            isinstance(output_element_type, elt.TensorType)
            and output_element_type.shape == ()
        ):
            output_element_type = output_element_type.element_type

        if not isinstance(output_element_type, elt.BaseType):
            raise TypeError(
                "elementwise function must return BaseType leaves, "
                f"got {type(output_element_type).__name__} at output index {idx}. "
                "Elementwise only supports operations producing valid MPLang types."
            )
        output_types.append(elt.TensorType(output_element_type, result_shape))
    flat_inputs, _ = tree_flatten((args, kwargs))
    input_values = [
        value._graph_value for value in flat_inputs if isinstance(value, el.TraceObject)
    ]

    # Emit graph operation with traced subgraph as region
    result_values = tracer.graph.add_op(
        opcode="tensor.elementwise",
        inputs=input_values,
        output_types=output_types,
        attrs={},
        regions=[traced_fn.graph],
    )

    return tracer.reconstruct_outputs(
        traced_fn.out_var_pos,
        traced_fn.out_imms,
        traced_fn.out_tree,
        result_values,
    )


def elementwise(fn: Callable[..., Any], *inputs: el.Object, **kwargs: Any) -> el.Object:
    """Apply element-wise operation to tensor elements.

    Maps an element-level callable to tensor elements while preserving shape.
    All tensor arguments must have the same shape. Allows mixing tensor and
    scalar arguments (scalars are passed unchanged to fn for each element).

    The function `fn` must be closure-free (no captured variables) - all
    dependencies must be passed as explicit arguments. This ensures the
    computation graph captures all data dependencies.

    Type Promotion Rule:
        If all arguments are scalars, the result will be lifted to a rank-0 tensor (shape=()).

    Args:
        fn: Callable/traceable function operating on scalar elements.
            Can be a lambda, regular function, or Primitive.bind.
            Must not capture variables (closure-free).
            Must return ScalarType values - no tensor nesting.
        *inputs: Tensor or Scalar arguments to pass to fn.
            All tensor inputs must have the same shape.
        **kwargs: Keyword arguments to pass to fn

    Returns:
        PyTree whose leaves are Tensors with the same shape as the input tensors.
        The PyTree structure matches the return value of `fn`.
        Each leaf has element type determined by fn's corresponding output.

    Raises:
        ValueError: If fn captures variables or tensor shapes don't match
        TypeError: If fn returns non-scalar types

    Example:
        >>> # Element-wise addition with lambda
        >>> t1 = ...  # Tensor[f32, (10,)]
        >>> t2 = ...  # Tensor[f32, (10,)]
        >>> result = elementwise(lambda x, y: x + y, t1, t2)
        >>> # result: Tensor[f32, (10,)]
        >>>
        >>> # PHE encryption: mixing tensor and scalar (key)
        >>> plaintext = ...  # Tensor[f32, (10,)]
        >>> public_key = ...  # PHEPublicKey (scalar)
        >>> ciphertext = elementwise(phe.encrypt, plaintext, public_key)
        >>> # ciphertext: Tensor[HE[f32], (10,)]
        >>>
        >>> # Multiple tensors with same shape
        >>> t1 = ...  # Tensor[f32, (3, 4)]
        >>> t2 = ...  # Tensor[f32, (3, 4)]
        >>> result = elementwise(lambda x, y: x * y, t1, t2)
        >>> # result: Tensor[f32, (3, 4)]
        >>>
        >>> # Tensor-scalar operation
        >>> tensor = ...  # Tensor[f32, (10,)]
        >>> scalar = ...  # f32
        >>> result = elementwise(lambda x, s: x * s, tensor, scalar)
        >>> # result: Tensor[f32, (10,)]
    """
    return elementwise_p.bind(fn, *inputs, **kwargs)  # type: ignore[no-any-return]


@transpose_p.def_abstract_eval
def _transpose_ae(input: elt.TensorType, *, perm: tuple[int, ...]) -> elt.TensorType:
    """Transpose tensor dimensions.

    Args:
        input: Input tensor type
        perm: Permutation of dimensions (e.g., (1, 0) for 2D transpose)

    Returns:
        Tensor type with permuted shape

    Raises:
        TypeError: If input is not a TensorType
        ValueError: If permutation is invalid
    """
    if not isinstance(input, elt.TensorType):
        raise TypeError(f"transpose expects TensorType, got {type(input)}")

    # Shape is always a tuple (TensorType enforces ranked tensors)
    rank = len(input.shape)
    if len(perm) != rank:
        raise ValueError(
            f"Permutation length {len(perm)} doesn't match tensor rank {rank}"
        )

    if set(perm) != set(range(rank)):
        raise ValueError(
            f"Invalid permutation {perm}, expected permutation of 0..{rank - 1}"
        )

    # Apply permutation to shape
    new_shape = tuple(input.shape[i] for i in perm)
    return elt.TensorType(input.element_type, new_shape)
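
# Annotation (not part of the packaged file): worked example of the shape rule
# above: perm=(2, 0, 1) applied to a tensor of shape (4, 8, 16) gives (16, 4, 8).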


@reshape_p.def_abstract_eval
def _reshape_ae(input: elt.TensorType, new_shape: tuple[int, ...]) -> elt.TensorType:
    """Reshape tensor to new shape.

    Args:
        input: Input tensor type
        new_shape: Target shape (can contain -1 for inferred dimension)

    Returns:
        Tensor type with new shape

    Raises:
        TypeError: If input is not a TensorType
        ValueError: If reshape is invalid
    """
    if not isinstance(input, elt.TensorType):
        raise TypeError(f"reshape expects TensorType, got {type(input)}")

    # Validate new_shape
    if not isinstance(new_shape, tuple):
        raise TypeError(f"new_shape must be tuple, got {type(new_shape)}")

    neg_one_count = sum(1 for d in new_shape if d == -1)
    if neg_one_count > 1:
        raise ValueError("new_shape can contain at most one -1 dimension")

    # Compute output shape
    if input.is_fully_static:
        # Input size is known - we can infer or validate
        input_size = math.prod(input.shape)

        if neg_one_count == 0:
            # No -1: validate total size matches
            new_size = math.prod(new_shape)
            if input_size != new_size:
                raise ValueError(
                    f"Cannot reshape tensor of size {input_size} to shape {new_shape} (size {new_size})"
                )
            output_shape = new_shape
        else:
            # One -1: infer that dimension
            known_size = math.prod(d for d in new_shape if d != -1)
            if known_size == 0:
                raise ValueError("Cannot reshape: new_shape has zero-size dimensions")
            if input_size % known_size != 0:
                raise ValueError(
                    f"Cannot infer dimension: {input_size} is not divisible by {known_size}"
                )
            inferred_dim = input_size // known_size
            output_shape = tuple(inferred_dim if d == -1 else d for d in new_shape)
    else:
        # Input has dynamic dims - output inherits uncertainty
        # Keep -1 in output (we cannot infer at trace time)
        output_shape = new_shape

    return elt.TensorType(input.element_type, output_shape)
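
# Annotation (not part of the packaged file): worked example of the -1
# inference above: reshaping a (2, 3, 4) tensor (size 24) to (4, -1) gives
# known_size = 4, inferred_dim = 24 // 4 = 6, so the output shape is (4, 6).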


@concat_p.def_abstract_eval
def _concat_ae(in_types: list[elt.BaseType], *, axis: int = 0) -> elt.TensorType:
    """Concatenate tensors along axis.

    Args:
        in_types: List of input tensor types
        axis: Axis along which to concatenate (default: 0)

    Returns:
        Concatenated tensor type

    Raises:
        TypeError: If inputs are not TensorTypes
        ValueError: If shapes are incompatible
    """
    if not in_types:
        raise ValueError("concat requires at least one input tensor")

    # Verify all inputs are TensorType
    for i, t in enumerate(in_types):
        if not isinstance(t, elt.TensorType):
            raise TypeError(f"Input {i} is not TensorType: {type(t)}")

    tensor_types = cast(list[elt.TensorType], in_types)

    # Check element types match
    element_type = tensor_types[0].element_type
    for i, t in enumerate(tensor_types[1:], 1):
        if t.element_type != element_type:
            raise TypeError(
                f"Element type mismatch: tensor 0 has {element_type}, "
                f"tensor {i} has {t.element_type}"
            )

    # All tensors are ranked (shape is always a tuple)
    first_shape = tensor_types[0].shape
    rank = len(first_shape)

    # Normalize negative axis
    normalized_axis = axis if axis >= 0 else rank + axis
    if normalized_axis < 0 or normalized_axis >= rank:
        raise ValueError(f"axis {axis} out of bounds for rank {rank}")

    # Check shape compatibility
    result_shape = list(first_shape)
    concat_dim_size = first_shape[normalized_axis]

    for i, t in enumerate(tensor_types[1:], 1):
        if len(t.shape) != rank:
            raise ValueError(
                f"Rank mismatch: tensor 0 has rank {rank}, tensor {i} has rank {len(t.shape)}"
            )

        for dim_idx in range(rank):
            if dim_idx == normalized_axis:
                # Concatenation dimension
                if concat_dim_size == -1 or t.shape[dim_idx] == -1:
                    concat_dim_size = -1  # Result is dynamic
                else:
                    concat_dim_size += t.shape[dim_idx]
            else:
                # Other dimensions must match (or be dynamic)
                if (
                    result_shape[dim_idx] != -1
                    and t.shape[dim_idx] != -1
                    and result_shape[dim_idx] != t.shape[dim_idx]
                ):
                    raise ValueError(
                        f"Dimension {dim_idx} mismatch: tensor 0 has {result_shape[dim_idx]}, "
                        f"tensor {i} has {t.shape[dim_idx]}"
                    )
                if t.shape[dim_idx] == -1:
                    result_shape[dim_idx] = -1

    result_shape[normalized_axis] = concat_dim_size
    return elt.TensorType(element_type, tuple(result_shape))
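
# Annotation (not part of the packaged file): for inputs of shape (2, 3) and
# (5, 3) with axis=0, the non-concat dimension (3) must match and the result
# shape is (7, 3); a -1 in any input's concat dimension makes the result -1.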


@gather_p.def_abstract_eval
def _gather_ae(
    input: elt.TensorType, index: elt.TensorType, *, axis: int = 0
) -> elt.TensorType:
    """Gather elements along axis using indices.

    Args:
        input: Input tensor type
        index: Integer indices tensor type
        axis: Axis along which to gather

    Returns:
        Tensor type with gathered elements

    Raises:
        TypeError: If inputs are not TensorTypes or indices are not integer
        ValueError: If axis is invalid
    """
    if not isinstance(input, elt.TensorType):
        raise TypeError(f"gather expects TensorType, got {type(input)}")
    if not isinstance(index, elt.TensorType):
        raise TypeError(f"indices must be TensorType, got {type(index)}")

    # Verify indices are integer type (ScalarType includes IntegerType)
    if not isinstance(index.element_type, elt.IntegerType):
        raise TypeError(
            f"indices must have IntegerType element, got {type(index.element_type).__name__}"
        )
    # Check for 32-bit or 64-bit integers
    if index.element_type.bitwidth not in (32, 64):
        raise TypeError(
            f"indices must be 32-bit or 64-bit integers (i32/i64/u32/u64), got {index.element_type}"
        )

    # Both inputs must be ranked (shape is always a tuple now)
    rank = len(input.shape)
    normalized_axis = axis if axis >= 0 else rank + axis
    if normalized_axis < 0 or normalized_axis >= rank:
        raise ValueError(f"axis {axis} out of bounds for rank {rank}")

    # Result shape: replace axis dimension with indices shape
    result_shape = (
        input.shape[:normalized_axis] + index.shape + input.shape[normalized_axis + 1 :]
    )
    return elt.TensorType(input.element_type, result_shape)
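
# Annotation (not part of the packaged file): gathering from input shape
# (5, 7) with index shape (3, 2) along axis=0 replaces that axis with the
# index shape, giving result shape (3, 2, 7).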


@scatter_p.def_abstract_eval
def _scatter_ae(
    tensor_type: elt.TensorType,
    indices_type: elt.TensorType,
    updates_type: elt.TensorType,
    axis: int = 0,
) -> elt.TensorType:
    """Scatter updates into tensor at indices.

    Args:
        tensor_type: Input tensor type
        indices_type: Integer indices tensor type
        updates_type: Updates tensor type
        axis: Axis along which to scatter

    Returns:
        Tensor type (same as input)

    Raises:
        TypeError: If inputs are not compatible
        ValueError: If shapes are incompatible
    """
    if not isinstance(tensor_type, elt.TensorType):
        raise TypeError(f"scatter expects TensorType, got {type(tensor_type)}")
    if not isinstance(indices_type, elt.TensorType):
        raise TypeError(f"indices must be TensorType, got {type(indices_type)}")
    if not isinstance(updates_type, elt.TensorType):
        raise TypeError(f"updates must be TensorType, got {type(updates_type)}")

    # Verify element types match
    if updates_type.element_type != tensor_type.element_type:
        raise TypeError(
            f"Element type mismatch: tensor has {tensor_type.element_type}, "
            f"updates has {updates_type.element_type}"
        )

    # Scatter returns same type as input
    return tensor_type


@slice_p.def_abstract_eval
def _slice_ae(
    tensor_type: elt.TensorType,
    starts: tuple[int, ...],
    ends: tuple[int, ...],
    strides: tuple[int, ...] | None = None,
) -> elt.TensorType:
    """Slice tensor along dimensions.

    Args:
        tensor_type: Input tensor type
        starts: Start indices for each dimension
        ends: End indices for each dimension
        strides: Stride for each dimension (defaults to 1)

    Returns:
        Sliced tensor type

    Raises:
        TypeError: If input is not TensorType
        ValueError: If slice parameters are invalid
    """
    if not isinstance(tensor_type, elt.TensorType):
        raise TypeError(f"slice expects TensorType, got {type(tensor_type)}")

    # Tensor is always ranked (shape is always a tuple)
    rank = len(tensor_type.shape)
    if len(starts) != rank or len(ends) != rank:
        raise ValueError(
            f"starts and ends must have length {rank}, got {len(starts)} and {len(ends)}"
        )

    if strides is None:
        strides = tuple([1] * rank)
    elif len(strides) != rank:
        raise ValueError(f"strides must have length {rank}, got {len(strides)}")

    # Compute result shape
    result_shape = []
    for dim_idx in range(rank):
        dim_size = tensor_type.shape[dim_idx]
        if dim_size == -1:
            # Dynamic dimension - result is also dynamic
            result_shape.append(-1)
        else:
            # Static dimension - compute slice size
            start = starts[dim_idx]
            end = ends[dim_idx]
            stride = strides[dim_idx]

            if stride <= 0:
                raise ValueError(
                    f"stride must be positive, got {stride} at dim {dim_idx}"
                )

            # Handle negative indices
            if start < 0:
                start = max(0, dim_size + start)
            if end < 0:
                end = max(0, dim_size + end)

            # Clamp to valid range
            start = max(0, min(start, dim_size))
            end = max(0, min(end, dim_size))

            # Compute slice length
            if end <= start:
                slice_len = 0
            else:
                slice_len = (end - start + stride - 1) // stride

            result_shape.append(slice_len)

    return elt.TensorType(tensor_type.element_type, tuple(result_shape))
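
# Annotation (not part of the packaged file): with dim_size=10, start=1,
# end=9, stride=3 the formula (end - start + stride - 1) // stride gives
# (9 - 1 + 2) // 3 = 3, matching the selected indices 1, 4, 7.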


# User-facing API
def transpose(tensor: el.Object, perm: tuple[int, ...]) -> el.Object:
    """Transpose tensor dimensions.

    Args:
        tensor: Input tensor
        perm: Permutation of dimensions

    Returns:
        Transposed tensor

    Example:
        >>> x = constant([[1, 2], [3, 4]])  # shape (2, 2)
        >>> y = transpose(x, (1, 0))  # shape (2, 2), transposed
    """
    return transpose_p.bind(tensor, perm=perm)  # type: ignore[no-any-return]


def reshape(tensor: el.Object, new_shape: tuple[int, ...]) -> el.Object:
    """Reshape tensor to new shape.

    Args:
        tensor: Input tensor
        new_shape: Target shape (can contain -1 for inferred dimension)

    Returns:
        Reshaped tensor

    Example:
        >>> x = constant([1, 2, 3, 4, 5, 6])  # shape (6,)
        >>> y = reshape(x, (2, 3))  # shape (2, 3)
        >>> z = reshape(x, (2, -1))  # shape (2, 3), -1 inferred
    """
    return reshape_p.bind(tensor, new_shape=new_shape)  # type: ignore[no-any-return]


def concat(tensors: list[el.Object], axis: int = 0) -> el.Object:
    """Concatenate tensors along axis.

    Args:
        tensors: List of tensors to concatenate
        axis: Axis along which to concatenate

    Returns:
        Concatenated tensor

    Example:
        >>> x = constant([1, 2, 3])
        >>> y = constant([4, 5, 6])
        >>> z = concat([x, y], axis=0)  # [1, 2, 3, 4, 5, 6]
    """
    return concat_p.bind(*tensors, axis=axis)  # type: ignore[no-any-return]


def gather(tensor: el.Object, indices: el.Object, axis: int = 0) -> el.Object:
    """Gather elements along axis using indices.

    Args:
        tensor: Input tensor
        indices: Integer indices tensor
        axis: Axis along which to gather

    Returns:
        Gathered tensor

    Example:
        >>> x = constant([10, 20, 30, 40])
        >>> idx = constant([0, 2, 1])
        >>> y = gather(x, idx)  # [10, 30, 20]
    """
    return gather_p.bind(tensor, indices, axis=axis)  # type: ignore[no-any-return]


def scatter(
    tensor: el.Object,
    indices: el.Object,
    updates: el.Object,
    axis: int = 0,
) -> el.Object:
    """Scatter updates into tensor at indices.

    Args:
        tensor: Input tensor
        indices: Integer indices tensor
        updates: Updates tensor
        axis: Axis along which to scatter

    Returns:
        Updated tensor

    Example:
        >>> x = constant([1, 2, 3, 4])
        >>> idx = constant([0, 2])
        >>> updates = constant([10, 30])
        >>> y = scatter(x, idx, updates)  # [10, 2, 30, 4]
    """
    return scatter_p.bind(tensor, indices, updates, axis=axis)  # type: ignore[no-any-return]


def slice_tensor(
    tensor: el.Object,
    starts: tuple[int, ...],
    ends: tuple[int, ...],
    strides: tuple[int, ...] | None = None,
) -> el.Object:
    """Slice tensor along dimensions.

    Args:
        tensor: Input tensor
        starts: Start indices for each dimension
        ends: End indices for each dimension
        strides: Stride for each dimension (defaults to 1)

    Returns:
        Sliced tensor

    Example:
        >>> x = constant([[1, 2, 3], [4, 5, 6]])
        >>> y = slice_tensor(x, (0, 1), (2, 3))  # [[2, 3], [5, 6]]
    """
    return slice_p.bind(tensor, starts=starts, ends=ends, strides=strides)  # type: ignore[no-any-return]


# ==============================================================================
# --- Type Reinterpretation (via run_jax)
# ==============================================================================


def bitcast(x: el.Object, dtype: elt.ScalarType) -> el.Object:
    """Reinterpret tensor bytes as a different dtype.

    This is a zero-copy (at execution time) type reinterpretation that views
    the same underlying bytes as a different element type. The total byte
    count must remain the same.

    This follows LLVM/MLIR `bitcast` semantics: the operation produces a new
    SSA value with different type but same bit representation.

    Args:
        x: Input tensor.
        dtype: Target element type (e.g., elt.u64, elt.u8, elt.i32).

    Returns:
        Tensor with same bytes reinterpreted as dtype.
        Shape changes to preserve total bytes.

    Example:
        >>> # Tensor[u8, (8,)] -> Tensor[u64, (1,)]
        >>> packed = tensor.bitcast(bytes_tensor, elt.u64)
        >>> # Tensor[u64, (10, 2)] -> Tensor[u8, (10, 16)]
        >>> unpacked = tensor.bitcast(u64_tensor, elt.u8)
    """
    from typing import cast

    jax_dtype = dtypes.to_jax(dtype)

    def _bitcast(arr: Any) -> Any:
        return arr.view(jax_dtype)

    return cast(el.Object, run_jax(_bitcast, x))


__all__ = [
    "RunJaxCompilation",
    "bitcast",
    "concat",
    "concat_p",
    "constant",
    "constant_p",
    "elementwise",
    "elementwise_p",
    "gather",
    "gather_p",
    "get_run_jax_compilation",
    "jax_fn",
    "reshape",
    "reshape_p",
    "run_jax",
    "run_jax_p",
    "scatter",
    "scatter_p",
    "slice_p",
    "slice_tensor",
    "transpose",
    "transpose_p",
]