mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,519 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Tensor Runtime Implementation.
|
|
16
|
+
|
|
17
|
+
Implements execution logic for Tensor primitives.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import base64
|
|
23
|
+
import hashlib
|
|
24
|
+
import os
|
|
25
|
+
import time
|
|
26
|
+
from typing import Any, ClassVar, cast
|
|
27
|
+
|
|
28
|
+
import jax
|
|
29
|
+
import jax.extend as jxt
|
|
30
|
+
import jax.numpy as jnp
|
|
31
|
+
import numpy as np
|
|
32
|
+
from jax._src import compiler
|
|
33
|
+
from numpy.typing import ArrayLike
|
|
34
|
+
|
|
35
|
+
import mplang.v2.edsl.typing as elt
|
|
36
|
+
from mplang.v2.dialects import dtypes, tensor
|
|
37
|
+
from mplang.v2.edsl import serde
|
|
38
|
+
from mplang.v2.edsl.graph import Operation
|
|
39
|
+
from mplang.v2.runtime.interpreter import Interpreter
|
|
40
|
+
from mplang.v2.runtime.value import Value, WrapValue
|
|
41
|
+
|
|
42
|
+
# =============================================================================
|
|
43
|
+
# TensorValue Wrapper
|
|
44
|
+
# =============================================================================
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@serde.register_class
class TensorValue(WrapValue[Any]):
    """Runtime value wrapping a numpy array or JAX array.

    Handles numpy arrays, JAX arrays, and other numpy-like objects via duck
    typing. Numeric arrays are serialized as base64-encoded raw bytes.

    Note: this wrapper is primarily meant for numeric tensors. Object-dtype
    arrays (for example the per-element results that ``elementwise_impl``
    wraps) are tolerated: ``to_json``/``from_json`` serialize them item by
    item through ``serde``, and ``as_jax`` makes a best-effort attempt to
    coerce them to a numeric dtype.
    """

    _serde_kind: ClassVar[str] = "tensor_impl.TensorValue"

    # Expose common array properties for convenience
    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the wrapped array."""
        return cast(tuple[int, ...], self._data.shape)

    @property
    def dtype(self) -> np.dtype[Any]:
        """Dtype of the wrapped array, normalized to a numpy dtype."""
        return np.dtype(self._data.dtype)  # type: ignore[no-any-return]

    @property
    def ndim(self) -> int:
        """Number of dimensions of the wrapped array."""
        return cast(int, self._data.ndim)

    def __getitem__(self, key: Any) -> Any:
        """Allow indexing into the underlying array."""
        return self._data[key]

    # =========== Wrap/Unwrap ===========

    @staticmethod
    def _is_jax_like(data: Any) -> bool:
        """Return True when ``data`` should be stored as-is as a JAX value.

        Real JAX arrays are detected with ``isinstance(data, jax.Array)``.
        Objects exposing ``__jax_array__`` and objects whose ``__module__``
        mentions ``jax`` are still accepted, as before.
        """
        if isinstance(data, jax.Array) or hasattr(data, "__jax_array__"):
            return True
        module_name = getattr(data, "__module__", None)
        return module_name is not None and "jax" in module_name

    def _convert(self, data: Any) -> Any:
        """Convert input data to numpy array or JAX array.

        NOTE(review): presumably invoked by the ``WrapValue`` base class when
        wrapping; the base class is not visible from this file.
        """
        if isinstance(data, TensorValue):
            return data._data

        # JAX arrays (and JAX-provided array-likes) pass through untouched so
        # that no device-to-host copy happens at wrap time.
        if self._is_jax_like(data):
            return data

        if isinstance(data, np.ndarray):
            return data

        # Anything else (Python scalars, lists, other array-likes) is
        # converted with np.asarray.
        return np.asarray(data)

    def unwrap(self) -> np.ndarray:
        """Get the underlying data as a numpy array.

        If the data is a JAX array, it will be transferred to host.
        """
        return np.asarray(self._data)

    def as_jax(self) -> Any:
        """Get the underlying data as a JAX array.

        If the data is a numpy array, it will be transferred to device.
        """
        data = self._data

        # Already a JAX value: hand it back without another conversion.
        if isinstance(data, jax.Array) or hasattr(data, "__jax_array__"):
            return data

        # Handle object arrays that might contain numbers (e.g. from elementwise)
        if isinstance(data, np.ndarray) and data.dtype == object:
            try:
                # Rebuilding from nested Python lists lets numpy infer a
                # numeric dtype when every element is a plain number.
                numeric = np.array(data.tolist())
                if numeric.dtype != object:
                    return jax.device_put(jnp.asarray(numeric))
            except Exception:
                # Best effort only: fall through to the generic conversion
                # below, which will raise a JAX error if the data really is
                # not numeric.
                pass

        return jax.device_put(jnp.asarray(data))

    # =========== Serialization ===========

    def to_json(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict.

        Returns a dict with ``kind`` set to ``"object"`` (items serialized one
        by one via ``serde.to_json``) or ``"numeric"`` (dtype string, shape and
        base64-encoded raw bytes).
        """
        # Ensure we have numpy data for serialization
        # This forces synchronization if data is on device
        data_np = np.asarray(self._data)

        # Handle object dtype arrays - serialize element by element
        if data_np.dtype == np.object_:
            return {
                "kind": "object",
                "shape": list(data_np.shape),
                "items": [serde.to_json(item) for item in data_np.flat],
            }

        # Standard numeric arrays - use raw bytes
        return {
            "kind": "numeric",
            "dtype": str(data_np.dtype),
            "shape": list(data_np.shape),
            "data": base64.b64encode(data_np.tobytes()).decode("ascii"),
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> TensorValue:
        """Rebuild a TensorValue from the dict produced by ``to_json``.

        A missing ``kind`` key is treated as ``"numeric"``.
        """
        kind = data.get("kind", "numeric")
        shape = tuple(data["shape"])

        if kind == "object":
            items = [serde.from_json(item) for item in data["items"]]
            arr = np.empty(len(items), dtype=object)
            # Assign one slot at a time: a bulk assignment would let numpy
            # try to broadcast items that are themselves sequences.
            for i, item in enumerate(items):
                arr[i] = item
            return cls(arr.reshape(shape))

        arr = np.frombuffer(
            base64.b64decode(data["data"]),
            dtype=np.dtype(data["dtype"]),
        )
        # np.frombuffer returns a read-only view over the decoded bytes;
        # copy so the wrapped array is writable and owns its memory.
        return cls(arr.reshape(shape).copy())
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# Module-level helpers for convenience (delegate to class methods)
|
|
173
|
+
def _wrap(val: ArrayLike | TensorValue) -> TensorValue:
    """Return ``val`` as a TensorValue by delegating to ``TensorValue.wrap``."""
    wrapped = TensorValue.wrap(val)
    return wrapped
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _unwrap(val: TensorValue | np.ndarray | ArrayLike) -> np.ndarray:
    """Return ``val`` as an np.ndarray.

    A TensorValue is unwrapped through its own ``unwrap`` method and an
    existing ndarray is returned unchanged. Everything else, JAX arrays
    included, goes through ``np.asarray``.
    """
    if isinstance(val, TensorValue):
        return val.unwrap()
    return val if isinstance(val, np.ndarray) else np.asarray(val)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
# _ensure_tensor_value removed - callers should unwrap InterpObject before calling impls
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
# =============================================================================
|
|
194
|
+
# Tensor Primitive Implementations
|
|
195
|
+
# =============================================================================
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
@tensor.constant_p.def_impl
def constant_impl(interpreter: Interpreter, op: Operation) -> TensorValue:
    """Materialize a constant tensor from the op's attributes.

    The dtype and shape come from the IR type of the first output; the
    payload is the base64 string stored in ``op.attrs["value_b64"]``.

    Raises:
        TypeError: if the output type is not a TensorType.
        ValueError: if the element type has no JAX dtype mapping.
    """
    out_type = op.outputs[0].type
    if not isinstance(out_type, elt.TensorType):
        raise TypeError(f"Expected TensorType, got {out_type}")

    element_dtype = dtypes.to_jax(cast(elt.ScalarType, out_type.element_type))
    if element_dtype is None:
        raise ValueError(f"Unsupported scalar type {out_type.element_type}")

    raw_bytes = base64.b64decode(op.attrs["value_b64"])
    flat_view = np.frombuffer(raw_bytes, dtype=cast(Any, element_dtype))
    # The frombuffer view is read-only; copy after reshaping so the result
    # owns writable memory.
    return _wrap(flat_view.reshape(out_type.shape).copy())
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@tensor.concat_p.def_impl
def concat_impl(
    interpreter: Interpreter, op: Operation, *args: TensorValue
) -> TensorValue:
    """Concatenate the inputs along ``op.attrs["axis"]`` (default 0)."""
    host_arrays = list(map(_unwrap, args))
    joined = np.concatenate(host_arrays, axis=op.attrs.get("axis", 0))
    return _wrap(joined)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _elementwise_is_iterated(ir_value: Any) -> bool:
    """Return True when the outer IR type marks this input as a non-scalar tensor."""
    ir_type = ir_value.type
    return isinstance(ir_type, elt.TensorType) and ir_type.shape != ()


def _elementwise_iteration_shape(op: Operation, args: tuple[Any, ...]) -> Any:
    """Return the runtime shape of the first non-scalar tensor input, or ``()``.

    We assume the tracer ensured all non-scalar tensors have compatible
    shapes, so the first one found decides the iteration space.
    """
    for i, ir_value in enumerate(op.inputs):
        if _elementwise_is_iterated(ir_value):
            arg = args[i]
            if hasattr(arg, "shape"):
                return arg.shape
    return ()


def _elementwise_scalar_input(arg: Any, index: tuple[int, ...], iterated: bool) -> Any:
    """Build the per-element input for one argument.

    Iterated arguments contribute ``arg[index]``; all others are broadcast
    unchanged. Anything that is not already a ``Value`` is wrapped in a
    TensorValue to maintain the Value-only contract of the sub-graph.
    """
    picked = cast(Any, arg)[index] if iterated else arg
    if isinstance(picked, Value):
        return picked
    return _wrap(np.array(picked))


def _elementwise_storable(out: Any) -> Any:
    """Turn one per-element output into what is stored in the object array.

    TensorValue results are unwrapped to raw numpy data (and to a plain
    Python scalar when 0-d); any other value is stored as-is.
    """
    if isinstance(out, TensorValue):
        raw = out.unwrap()
        return raw.item() if raw.shape == () else raw
    return out


@tensor.elementwise_p.def_impl
def elementwise_impl(interpreter: Interpreter, op: Operation, *args: Value) -> Any:
    """Execute elementwise operation by iterating over tensor elements.

    ``op.regions[0]`` is the scalar computation graph; it is evaluated once
    per index of the iteration shape.

    Note: args typed as Value (base class) because elementwise handles polymorphic
    inputs - TensorValue for numeric tensors, or np.ndarray with dtype=object
    containing encrypted values (BFVValue, etc.) that are processed element-wise.

    Returns:
        For a scalar iteration shape, whatever the sub-graph returns (a single
        value, or the list when it has several results). Otherwise one
        TensorValue wrapping an object-dtype array, or a list of them when
        the op has more than one output.
    """
    subgraph = op.regions[0]

    # 1. Determine shape from IR types and runtime args
    shape = _elementwise_iteration_shape(op, args)

    if shape == ():
        # Scalar case - nothing to iterate, evaluate the sub-graph once.
        result = interpreter.evaluate_graph(subgraph, list(args))
        return result[0] if len(result) == 1 else result

    # 2. Construct output container
    # Object arrays are used because results may mix types (e.g. encrypted
    # objects alongside plain numbers).
    num_outputs = len(op.outputs)
    multi_output = num_outputs > 1
    results: Any
    if multi_output:
        results = [np.empty(shape, dtype=object) for _ in range(num_outputs)]
    else:
        results = np.empty(shape, dtype=object)

    # Whether an argument is iterated depends only on the outer IR type, so
    # decide it once instead of once per element.
    iterated = [_elementwise_is_iterated(op.inputs[i]) for i in range(len(args))]

    # 3. Iterate and execute
    for index in np.ndindex(shape):
        # Inputs for this element, ordered like subgraph.inputs
        scalar_inputs = [
            _elementwise_scalar_input(arg, index, iterated[i])
            for i, arg in enumerate(args)
        ]

        # Recursive execution
        scalar_out_list = interpreter.evaluate_graph(subgraph, scalar_inputs)
        scalar_out = (
            scalar_out_list[0] if len(scalar_out_list) == 1 else scalar_out_list
        )

        # Raw values are stored in the object arrays; the arrays themselves
        # are wrapped at the end. Every output is unwrapped the same way so
        # multi-output results do not end up holding TensorValue wrappers.
        if multi_output:
            for i, val in enumerate(scalar_out):
                results[i][index] = _elementwise_storable(val)
        else:
            results[index] = _elementwise_storable(scalar_out)

    if multi_output:
        return [_wrap(res) for res in results]
    return _wrap(results)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
# Global cache for compiled StableHLO executables, keyed by the SHA256 hex
# digest of the StableHLO text.
_STABLEHLO_CACHE: dict[str, Any] = {}


def _read_cached_executable(client: Any, cache_path: str, compile_options: Any) -> Any:
    """Try to load a compiled executable from ``cache_path``.

    Returns the deserialized executable, or None when the file does not exist
    or cannot be loaded (a message is printed in the latter case).
    """
    if not os.path.exists(cache_path):
        return None
    try:
        with open(cache_path, "rb") as f:
            serialized = f.read()
        return client.deserialize_executable(
            serialized, client.devices(), compile_options
        )
    except Exception as e:
        # Best effort: an unreadable cache entry just means we recompile.
        print(f"[JAX] Failed to load from disk cache: {e}")
        return None


def _write_cached_executable(client: Any, compiled: Any, cache_path: str) -> None:
    """Persist ``compiled`` to ``cache_path``; failures are reported, not raised.

    The payload is written to a uniquely named temporary file in the same
    directory and then moved into place with ``os.replace``, so a concurrent
    reader never observes a partially written cache entry.
    """
    tmp_path = f"{cache_path}.{os.getpid()}.{time.time_ns()}.tmp"
    try:
        payload = client.serialize_executable(compiled)
        with open(tmp_path, "wb") as f:
            f.write(payload)
        os.replace(tmp_path, cache_path)
    except Exception as e:
        print(f"[JAX] Failed to save to disk cache: {e}")
        # Do not leave a stray temp file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass


def _load_or_compile_stablehlo(interpreter: Interpreter, stablehlo_code: str) -> Any:
    """Return a compiled executable for ``stablehlo_code``.

    Lookup order: in-process cache, then the on-disk cache under
    ``interpreter.root_dir / "cache" / "jax"``, then a fresh compile (whose
    result is written back to disk). The executable is stored in the
    in-process cache before returning.

    Raises:
        RuntimeError: if compiling the StableHLO code fails.
    """
    client = jxt.backend.get_backend()

    # Use SHA256 of code as cache key for stability across runs
    # Note: We assume compile_options are constant (num_replicas=1, num_partitions=1)
    code_hash = hashlib.sha256(stablehlo_code.encode("utf-8")).hexdigest()
    if code_hash in _STABLEHLO_CACHE:
        return _STABLEHLO_CACHE[code_hash]

    # NOTE: `compiler` is imported from jax._src, i.e. a private JAX module.
    compile_options = compiler.get_compile_options(num_replicas=1, num_partitions=1)

    cache_dir = interpreter.root_dir / "cache" / "jax"
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_path = str(cache_dir / f"{code_hash}.pjrt")

    compiled = _read_cached_executable(client, cache_path, compile_options)
    if compiled is None:
        try:
            compiled = client.compile_and_load(
                stablehlo_code, client.devices(), compile_options
            )
        except Exception as e:
            raise RuntimeError(f"StableHLO compile failed: {e}") from e
        _write_cached_executable(client, compiled, cache_path)

    _STABLEHLO_CACHE[code_hash] = compiled
    return compiled


def _to_jax_input(arg: Any, input_type: Any) -> Any:
    """Convert one runtime argument to a JAX array, casting to the IR dtype.

    The cast is the boundary type guard that allows users to pass Python
    ints/floats to functions expecting f32/i32. ``input_type`` is None when
    the op declares no IR input for this argument; no cast happens then.
    """
    target_dtype = None
    if isinstance(input_type, elt.TensorType):
        target_dtype = dtypes.to_jax(cast(elt.ScalarType, input_type.element_type))

    val = arg.as_jax() if isinstance(arg, TensorValue) else jnp.asarray(arg)

    if (
        target_dtype is not None
        and isinstance(val, (jnp.ndarray, np.ndarray))
        and val.dtype != target_dtype
    ):
        val = val.astype(target_dtype)
    return val


@tensor.run_jax_p.def_impl
def run_jax_impl(
    interpreter: Interpreter, op: Operation, *args: TensorValue
) -> TensorValue | list[TensorValue]:
    """Execute JAX function.

    The function is supplied as StableHLO text in
    ``op.attrs["stablehlo_code"]``; it is compiled (or fetched from cache),
    run on the given arguments, and its outputs are wrapped as TensorValue.

    Returns:
        A single TensorValue when the op has exactly one output and exactly
        one array was produced, otherwise the list of all output values.

    Raises:
        NotImplementedError: if the ``stablehlo_code`` attribute is missing.
        RuntimeError: if compilation or execution fails.
    """
    t0 = time.time()

    # Execute via StableHLO
    stablehlo_code = op.attrs.get("stablehlo_code")
    if stablehlo_code is None:
        raise NotImplementedError(
            "run_jax execution requires 'stablehlo_code' attribute"
        )

    compiled = _load_or_compile_stablehlo(interpreter, stablehlo_code)

    # Cast inputs to expected types (Boundary Type Guard)
    t1 = time.time()

    num_typed_inputs = len(op.inputs)
    jax_args = [
        _to_jax_input(arg, op.inputs[i].type if i < num_typed_inputs else None)
        for i, arg in enumerate(args)
    ]

    # Handle JAX's unused parameter elimination via arg_keep_map
    arg_keep_map = op.attrs.get("arg_keep_map")
    if arg_keep_map is not None:
        # Filter out arguments that were eliminated by JAX during compilation
        jax_args = [jax_args[i] for i in arg_keep_map]

    t2 = time.time()

    try:
        # No separate transfer step exists between t2 and t3; the timestamp
        # is kept so the "JAX Transfer In" trace event is still emitted.
        t3 = time.time()
        result = compiled.execute_sharded(jax_args)
        t4 = time.time()
        arrays = result.disassemble_into_single_device_arrays()
        flat: list[TensorValue] = []
        for per_output in arrays:
            # Wrap JAX arrays directly, avoiding np.asarray
            flat.extend(_wrap(a) for a in per_output)
        t5 = time.time()

        if interpreter.tracer:
            p = interpreter.tracer
            p.log_custom_event("JAX Compile/Cache", t0, t1, cat="jax")
            p.log_custom_event("JAX Prep", t1, t2, cat="jax")
            p.log_custom_event("JAX Transfer In", t2, t3, cat="jax")
            p.log_custom_event("JAX Exec", t3, t4, cat="jax")
            p.log_custom_event("JAX Transfer Out", t4, t5, cat="jax")

        # If op has 1 output, flat should have 1 element: return it directly.
        if len(op.outputs) == 1 and len(flat) == 1:
            return flat[0]
        return flat
    except Exception as e:
        raise RuntimeError(f"StableHLO execute failed: {e}") from e
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
@tensor.gather_p.def_impl
def gather_impl(
    interpreter: Interpreter, op: Operation, operand: TensorValue, indices: TensorValue
) -> TensorValue:
    """Gather entries of ``operand`` at ``indices`` along ``op.attrs["axis"]`` (default 0)."""
    axis = op.attrs.get("axis", 0)
    source = _unwrap(operand)
    positions = _unwrap(indices)
    # Indices may arrive with a non-integer dtype; cast them to int before
    # handing them to np.take.
    as_type = getattr(positions, "astype", None)
    if as_type is not None:
        positions = as_type(int)
    gathered = np.take(source, positions, axis=axis)
    return _wrap(gathered)
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
@tensor.slice_p.def_impl
def slice_impl(
    interpreter: Interpreter, op: Operation, operand: TensorValue
) -> TensorValue:
    """Slice ``operand`` using the ``starts``/``ends``/``strides`` attributes.

    ``strides`` is optional; a missing or empty value means a step of 1 on
    every sliced axis. Axes beyond those listed in ``starts`` are taken whole.
    """
    starts = op.attrs["starts"]
    ends = op.attrs["ends"]
    strides = op.attrs.get("strides")

    selectors: list[Any] = [
        slice(begin, ends[axis], strides[axis] if strides else 1)
        for axis, begin in enumerate(starts)
    ]

    source = _unwrap(operand)
    # Fewer slices than dimensions: an Ellipsis selects the remaining axes
    # in full.
    if len(selectors) < source.ndim:
        selectors.append(Ellipsis)

    return _wrap(source[tuple(selectors)])
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
@tensor.reshape_p.def_impl
def reshape_impl(
    interpreter: Interpreter, op: Operation, tensor_data: TensorValue
) -> TensorValue:
    """Reshape ``tensor_data`` to ``op.attrs["new_shape"]``."""
    target_shape = op.attrs["new_shape"]
    host_array = _unwrap(tensor_data)
    return _wrap(host_array.reshape(target_shape))
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
@tensor.transpose_p.def_impl
def transpose_impl(
    interpreter: Interpreter, op: Operation, tensor_data: TensorValue
) -> TensorValue:
    """Permute the axes of ``tensor_data`` according to ``op.attrs["perm"]``.

    When ``perm`` is absent, ``None`` is passed to ``np.transpose``, which
    then reverses the axis order.
    """
    axes_order = op.attrs.get("perm")
    host_array = _unwrap(tensor_data)
    return _wrap(np.transpose(host_array, axes=axes_order))
|