mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v1/ops/crypto.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Crypto frontend operations: operation signatures, types, and high-level semantics.
|
|
17
|
+
|
|
18
|
+
Scope and contracts:
|
|
19
|
+
- This module defines portable API shapes; it does not implement cryptography.
|
|
20
|
+
- Backends execute the operations and must meet the security semantics required
|
|
21
|
+
by the deployment (confidentiality, authenticity, correctness, etc.).
|
|
22
|
+
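- The enc/dec API in this frontend assumes a nonce-prefixed ciphertext
  (ciphertext = nonce || payload), and dec expects that format. The total
  per-algorithm size overhead is 16 bytes for "aes-ctr" (nonce only) and
  28 bytes for "aes-gcm" / "sm4-gcm" (12-byte nonce + 16-byte tag). Other
  security properties (e.g., AEAD) are backend responsibilities.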
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from __future__ import annotations
|
|
28
|
+
|
|
29
|
+
from jax.tree_util import PyTreeDef, tree_flatten
|
|
30
|
+
|
|
31
|
+
from mplang.v1.core import UINT8, TensorType
|
|
32
|
+
from mplang.v1.core.mpobject import MPObject
|
|
33
|
+
from mplang.v1.core.pfunc import PFunction
|
|
34
|
+
from mplang.v1.ops.base import stateless_mod
|
|
35
|
+
|
|
36
|
+
# Frontend module for the "crypto" namespace. The ops below are declared through
# its `simple_op()` / `op_def()` decorators, and their PFunction fn_type strings
# use the matching "crypto." prefix.
_CRYPTO_MOD = stateless_mod("crypto")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _get_algo_overhead(algo: str) -> int:
|
|
40
|
+
"""Get ciphertext overhead for a given encryption algorithm.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
algo: Encryption algorithm identifier
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
int: Number of overhead bytes added to plaintext length
|
|
47
|
+
"""
|
|
48
|
+
overhead_map = {
|
|
49
|
+
"aes-ctr": 16, # nonce only (legacy compatibility)
|
|
50
|
+
"aes-gcm": 28, # nonce(12) + tag(16) for AES-GCM
|
|
51
|
+
"sm4-gcm": 28, # nonce(12) + tag(16) for SM4-GCM
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
if algo not in overhead_map:
|
|
55
|
+
# return unknown overhead as -1
|
|
56
|
+
return -1
|
|
57
|
+
return overhead_map[algo]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@_CRYPTO_MOD.simple_op()
def keygen(*, length: int = 32) -> TensorType:
    """Describe a fresh random byte string (e.g. a symmetric key).

    API: keygen(length: int = 32) -> key: u8[length]

    Only the result type/shape is fixed here; the random bytes themselves
    are supplied by the backend.

    Args:
        length: Number of bytes to produce (keyword-only).

    Returns:
        TensorType: a 1-D UINT8 tensor type of shape ``(length,)``.

    Raises:
        ValueError: if ``length`` is not positive.
    """
    if length <= 0:
        raise ValueError("length must be > 0")
    key_shape = (length,)
    return TensorType(UINT8, key_shape)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@_CRYPTO_MOD.op_def()
def enc(
    plaintext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Symmetric encryption whose output length accounts for the algorithm overhead.

    API: enc(plaintext: u8[N], key: u8[M], algo: str = "aes-ctr") -> ciphertext: u8[N + overhead]

    Known algorithms and their overhead:
    - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
    - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
    - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    Any other ``algo`` string is not rejected. It is forwarded to the backend,
    and the ciphertext is typed with a dynamic length (-1). A dynamic
    plaintext length also yields a dynamic ciphertext length. ``algo`` is
    recorded in the PFunction attributes for the backend.

    NOTE(review): the key's dtype/shape is not checked here.

    Raises:
        TypeError: if ``plaintext`` is not a 1-D UINT8 tensor.
    """
    if plaintext.dtype != UINT8:
        raise TypeError("enc expects UINT8 plaintext")
    if len(plaintext.shape) != 1:
        raise TypeError("enc expects 1-D plaintext")

    overhead = _get_algo_overhead(algo)
    n_bytes = plaintext.shape[0]
    # A negative value on either side (dynamic length / unknown algo) makes the
    # ciphertext length dynamic as well.
    known = n_bytes >= 0 and overhead >= 0
    ct_ty = TensorType(UINT8, (n_bytes + overhead if known else -1,))

    pfunc = PFunction(
        fn_type="crypto.enc",
        ins_info=(TensorType.from_obj(plaintext), TensorType.from_obj(key)),
        outs_info=(ct_ty,),
        algo=algo,
    )
    _, treedef = tree_flatten(ct_ty)
    return pfunc, [plaintext, key], treedef
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@_CRYPTO_MOD.op_def()
def dec(
    ciphertext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Symmetric decryption whose output length strips the algorithm overhead.

    API: dec(ciphertext: u8[N + overhead], key: u8[M], algo: str = "aes-ctr") -> plaintext: u8[N]

    Known algorithms and their overhead:
    - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
    - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
    - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    Any other ``algo`` string is forwarded to the backend unvalidated, and the
    plaintext is typed with a dynamic length (-1). The backend is responsible
    for parsing the ciphertext format according to ``algo``.

    NOTE(review): the key's dtype/shape is not checked here.

    Raises:
        TypeError: if ``ciphertext`` is not a 1-D UINT8 tensor, or if its
            statically known length is shorter than the algorithm overhead.
    """
    if ciphertext.dtype != UINT8:
        raise TypeError("dec expects UINT8 ciphertext")
    if len(ciphertext.shape) != 1:
        raise TypeError("dec expects 1-D ciphertext")

    overhead = _get_algo_overhead(algo)
    length = ciphertext.shape[0]
    # Sizes can only be reasoned about when both values are known (non-negative).
    static = length >= 0 and overhead >= 0

    if static and length < overhead:
        raise TypeError(
            f"dec expects ciphertext with at least {overhead} bytes for algo='{algo}', but got {length} bytes"
        )

    pt_ty = TensorType(UINT8, (length - overhead if static else -1,))
    pfunc = PFunction(
        fn_type="crypto.dec",
        ins_info=(TensorType.from_obj(ciphertext), TensorType.from_obj(key)),
        outs_info=(pt_ty,),
        algo=algo,
    )
    _, treedef = tree_flatten(pt_ty)
    return pfunc, [ciphertext, key], treedef
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@_CRYPTO_MOD.op_def()
def kem_keygen(suite: str = "x25519") -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """KEM-style keypair generation producing ``(sk, pk)`` byte tensors.

    API: kem_keygen(suite: str = "x25519") -> (sk: u8[32], pk: u8[32])

    For "x25519" both keys are typed as 32 bytes. Any other suite yields
    dynamic (-1) lengths. ``suite`` is recorded in the PFunction attributes
    for the backend. The op takes no inputs.
    """
    key_len = 32 if suite == "x25519" else -1
    outs_info = (TensorType(UINT8, (key_len,)), TensorType(UINT8, (key_len,)))

    pfunc = PFunction(
        fn_type="crypto.kem_keygen",
        ins_info=(),
        outs_info=outs_info,
        suite=suite,
    )
    _, treedef = tree_flatten(outs_info)
    return pfunc, [], treedef
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
@_CRYPTO_MOD.op_def()
def kem_derive(
    sk: MPObject, peer_pk: MPObject, suite: str = "x25519"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """KEM-style shared secret derivation: returns secret bytes.

    API: kem_derive(sk: u8[32], peer_pk: u8[32], suite: str = "x25519") -> secret: u8[32]

    ``suite`` is stored in the PFunction attributes for backend use. For
    "x25519" the secret is typed as 32 bytes. Any other suite yields a
    dynamic (-1) output length.

    Args:
        sk: Local secret key, a 1-D UINT8 tensor.
        peer_pk: Peer public key, a 1-D UINT8 tensor.
        suite: Key-agreement suite name.

    Returns:
        (pfunc, [sk, peer_pk], treedef) describing a single secret output.

    Raises:
        TypeError: if either input is not a 1-D UINT8 tensor, or, for
            "x25519", if a statically known key length is not 32 bytes.
    """
    # Validate input types
    if sk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 secret key")
    if peer_pk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 peer public key")
    if len(sk.shape) != 1 or len(peer_pk.shape) != 1:
        raise TypeError("kem_derive expects 1-D inputs")

    if suite == "x25519":
        # A negative length denotes a dynamic size. This is the same convention
        # this module uses for its own outputs, e.g. dec with an unknown algo.
        # Such lengths cannot be checked at trace time, so only statically
        # known lengths are enforced here.
        if any(n >= 0 and n != 32 for n in (sk.shape[0], peer_pk.shape[0])):
            raise TypeError("kem_derive expects 32-byte keys for suite 'x25519'")
        secret_ty = TensorType(UINT8, (32,))
    else:
        # Unknown suite, return dynamic length
        secret_ty = TensorType(UINT8, (-1,))
    outs_info = (secret_ty,)

    ins_info = (TensorType.from_obj(sk), TensorType.from_obj(peer_pk))
    pfunc = PFunction(
        fn_type="crypto.kem_derive",
        ins_info=ins_info,
        outs_info=outs_info,
        suite=suite,
    )
    _, treedef = tree_flatten(outs_info[0])
    return pfunc, [sk, peer_pk], treedef
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
@_CRYPTO_MOD.op_def()
def hkdf(
    secret: MPObject, info: str, hash: str = "SHA-256"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """HKDF-style key derivation from a secret byte tensor.

    API: hkdf(secret: u8[N], info: str, hash: str = "SHA-256") -> key: u8[32]

    For ``hash`` "SHA-256" or "SM3" the derived key is typed as 32 bytes.
    Any other hash name is forwarded unchanged, and the key gets a dynamic
    (-1) length. Both ``hash`` and ``info`` are recorded in the PFunction
    attributes for the backend. (``hash`` shadows the builtin but is part of
    the public signature.)

    Raises:
        TypeError: if ``secret`` is not a 1-D UINT8 tensor.
    """
    if secret.dtype != UINT8:
        raise TypeError("hkdf expects UINT8 secret")
    if len(secret.shape) != 1:
        raise TypeError("hkdf expects 1-D secret")

    key_len = 32 if hash in ("SHA-256", "SM3") else -1
    key_ty = TensorType(UINT8, (key_len,))

    pfunc = PFunction(
        fn_type="crypto.hkdf",
        ins_info=(TensorType.from_obj(secret),),
        outs_info=(key_ty,),
        hash=hash,
        info=info,
    )
    _, treedef = tree_flatten(key_ty)
    return pfunc, [secret], treedef
|
mplang/{ops → v1/ops}/fhe.py
RENAMED
|
@@ -12,8 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from mplang.core import UINT8, TensorType
|
|
16
|
-
from mplang.ops.base import stateless_mod
|
|
15
|
+
from mplang.v1.core import UINT8, TensorType
|
|
16
|
+
from mplang.v1.ops.base import stateless_mod
|
|
17
17
|
|
|
18
18
|
# Stateless frontend module for the "fhe" op namespace.
# NOTE(review): the lowercase "_fhe_MOD" spelling differs from the
# _CRYPTO_MOD / _PHE_MOD naming used by sibling op modules. Renaming it
# requires updating every use in this file (not visible in this hunk).
_fhe_MOD = stateless_mod("fhe")
|
|
19
19
|
|
mplang/{ops → v1/ops}/jax_cc.py
RENAMED
|
@@ -14,16 +14,18 @@
|
|
|
14
14
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
|
+
import logging
|
|
17
18
|
from collections.abc import Callable
|
|
18
19
|
from typing import Any
|
|
19
20
|
|
|
20
21
|
import jax
|
|
21
22
|
import jax.numpy as jnp
|
|
23
|
+
from jax import export
|
|
22
24
|
from jax.tree_util import PyTreeDef, tree_flatten
|
|
23
25
|
|
|
24
|
-
from mplang.core import MPObject, PFunction, TensorType, get_fn_name
|
|
25
|
-
from mplang.ops.base import FeOperation, stateless_mod
|
|
26
|
-
from mplang.utils.func_utils import normalize_fn
|
|
26
|
+
from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
|
|
27
|
+
from mplang.v1.ops.base import FeOperation, stateless_mod
|
|
28
|
+
from mplang.v1.utils.func_utils import normalize_fn
|
|
27
29
|
|
|
28
30
|
# Enable 64-bit precision for JAX to match tensor types
|
|
29
31
|
# NOTE: this runs at import time and changes JAX's global configuration, so it
# also affects JAX code outside this module.
jax.config.update("jax_enable_x64", True)
|
|
@@ -36,7 +38,8 @@ def jax2stablehlo(
|
|
|
36
38
|
|
|
37
39
|
Translates high-level JAX functions into StableHLO MLIR representations,
|
|
38
40
|
enabling execution on JAX backends across different processes and platforms.
|
|
39
|
-
Uses
|
|
41
|
+
Uses a hybrid approach: traditional JAX trace/lower for compilation compatibility,
|
|
42
|
+
with stable jax.export API for parameter tracking.
|
|
40
43
|
|
|
41
44
|
Args:
|
|
42
45
|
is_variable: Predicate function to classify parameters as variables vs. constants.
|
|
@@ -52,34 +55,6 @@ def jax2stablehlo(
|
|
|
52
55
|
Non-variable parameters are captured as compile-time constants within
|
|
53
56
|
the PFunction body, while variables become runtime input parameters.
|
|
54
57
|
- PyTreeDef: Tree structure template for reconstructing nested output values
|
|
55
|
-
|
|
56
|
-
Rationale:
|
|
57
|
-
JAX Serialization Options Analysis:
|
|
58
|
-
1. jax.export (JAX ≥0.4.35) - Official export API with StableHLO backend
|
|
59
|
-
2. HLO protobuf - Raw XLA HloModule serialization
|
|
60
|
-
3. HLO text - Human-readable HLO representation
|
|
61
|
-
4. StableHLO MLIR - Portable intermediate representation
|
|
62
|
-
5. JAX compiled object pickling - Limited to same-process execution
|
|
63
|
-
|
|
64
|
-
Current Choice: StableHLO MLIR
|
|
65
|
-
Advantages:
|
|
66
|
-
- ✅ Available in current JAX version (0.4.34)
|
|
67
|
-
- ✅ Cross-version compatibility guaranteed by StableHLO design
|
|
68
|
-
- ✅ Direct compilation support via XLA client.compile(mlir_string)
|
|
69
|
-
- ✅ Handles complex functions (multi-input/output, control flow)
|
|
70
|
-
- ✅ Preserves numerical precision
|
|
71
|
-
- ✅ Platform-independent representation
|
|
72
|
-
|
|
73
|
-
Alternative Options Issues:
|
|
74
|
-
- jax.export: Not available in JAX 0.4.34
|
|
75
|
-
- HLO protobuf: Version compatibility issues with StableHLO parser
|
|
76
|
-
- HLO text: Parser compatibility issues with XLA client
|
|
77
|
-
- Pickle: Cannot serialize XLA LoadedExecutable objects
|
|
78
|
-
|
|
79
|
-
Future Migration Path:
|
|
80
|
-
- JAX ≥0.4.35: Migrate to jax.export.export() + jax.export.deserialize()
|
|
81
|
-
- JAX ≥0.5.x: Consider new portable formats if available
|
|
82
|
-
- Long-term: Adopt official JAX serialization standards as they mature
|
|
83
58
|
"""
|
|
84
59
|
# Flatten (args, kwargs) and capture immediates using the moved logic from primitive.py
|
|
85
60
|
normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)
|
|
@@ -89,47 +64,39 @@ def jax2stablehlo(
|
|
|
89
64
|
jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
|
|
90
65
|
]
|
|
91
66
|
|
|
92
|
-
#
|
|
67
|
+
# Hybrid approach: Use standard JAX trace/lower for compatibility, but jax.export for parameter tracking
|
|
93
68
|
jitted_fn = jax.jit(normalized_fn)
|
|
94
69
|
traced = jitted_fn.trace(jax_params)
|
|
95
70
|
lowered = traced.lower()
|
|
96
71
|
|
|
97
|
-
# Get StableHLO MLIR representation
|
|
98
|
-
# compiler_ir("stablehlo") returns jaxlib.mlir.ir.Module object
|
|
99
|
-
# str() converts to serializable text format
|
|
72
|
+
# Get StableHLO MLIR representation using traditional approach
|
|
100
73
|
stablehlo_mlir = lowered.compiler_ir("stablehlo")
|
|
101
74
|
mlir_text = str(stablehlo_mlir)
|
|
102
75
|
|
|
103
|
-
# Get output info
|
|
76
|
+
# Get output info using traditional approach
|
|
104
77
|
out_info_flat, out_tree = tree_flatten(lowered.out_info)
|
|
105
78
|
out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
|
|
106
79
|
|
|
107
|
-
# Extract argument keep mapping
|
|
108
|
-
#
|
|
109
|
-
# receives all original arguments. We need the mapping to filter them correctly.
|
|
80
|
+
# Extract argument keep mapping using stable jax.export API for parameter tracking
|
|
81
|
+
# We use jax.export only for getting the kept_var_idx information, not for the main compilation
|
|
110
82
|
arg_keep_map = None
|
|
111
83
|
original_arg_count = len(in_vars)
|
|
112
84
|
|
|
113
85
|
try:
|
|
114
|
-
#
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
kept_var_idx =
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
arg_keep_map =
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
#
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
f"Cannot access JAX's kept_var_idx to handle unused parameter elimination. "
|
|
129
|
-
f"This function may have unused parameters that JAX optimized away, "
|
|
130
|
-
f"but we cannot determine which ones without the internal API. "
|
|
131
|
-
f"Original error: {e}"
|
|
132
|
-
) from e
|
|
86
|
+
# Use jax.export just to get the stable parameter tracking information
|
|
87
|
+
export_fn = export.export(jitted_fn)
|
|
88
|
+
exported = export_fn(jax_params)
|
|
89
|
+
kept_var_idx = exported.module_kept_var_idx
|
|
90
|
+
if kept_var_idx is not None and len(kept_var_idx) < original_arg_count:
|
|
91
|
+
# JAX eliminated some unused parameters during compilation
|
|
92
|
+
# Keep the indices in sorted order for consistent mapping
|
|
93
|
+
arg_keep_map = sorted(kept_var_idx)
|
|
94
|
+
except Exception as e:
|
|
95
|
+
# Fallback: if jax.export fails, we can still use the compiled result without parameter tracking
|
|
96
|
+
# This ensures backward compatibility even if export has issues
|
|
97
|
+
logging.warning(
|
|
98
|
+
f"jax.export failed to get kept_var_idx, proceeding without it. Error: {e}"
|
|
99
|
+
)
|
|
133
100
|
|
|
134
101
|
# This format tells JaxRT how to handle the compiled result
|
|
135
102
|
pfn_kwargs: dict[str, Any] = {
|
mplang/v1/ops/nnx_cc.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from collections.abc import Callable
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import jax
|
|
22
|
+
import jax.numpy as jnp
|
|
23
|
+
from flax import nnx
|
|
24
|
+
from jax import export
|
|
25
|
+
from jax.tree_util import PyTreeDef, tree_flatten
|
|
26
|
+
|
|
27
|
+
from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
|
|
28
|
+
from mplang.v1.ops.base import FeOperation, stateless_mod
|
|
29
|
+
from mplang.v1.utils.func_utils import normalize_fn
|
|
30
|
+
|
|
31
|
+
# Enable 64-bit precision for JAX to match tensor types
|
|
32
|
+
# NOTE: this runs at import time and changes JAX's global configuration, so it
# also affects JAX code outside this module.
jax.config.update("jax_enable_x64", True)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def nnx2stablehlo(
    is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any
) -> tuple[PFunction, list[Any], PyTreeDef]:
    """Compile NNX function to StableHLO MLIR format for remote execution.

    The MLIR text and output types come from
    ``nnx.jit(...).trace(...).lower()``. ``jax.export`` is used only to learn
    which arguments survived unused-parameter elimination (``arg_keep_map``).
    That second step is best-effort: any failure in it, including accessing
    the underlying function of the NNX wrapper, is logged, and the PFunction
    is emitted without ``arg_keep_map``.

    NOTE(review): the keep map is computed from a separate ``jax.jit`` trace of
    ``nnx.jit(...).fun``, not from the NNX lowering that produced ``fn_text``.
    Confirm that both agree on the flattened argument order.

    Args:
        is_variable: Predicate function to classify parameters as variables vs. constants.
            Returns True for parameters that should be treated as PFunction inputs.
        flat_fn: NNX function to be compiled into StableHLO format
        *args: Positional arguments passed to the function during compilation
        **kwargs: Keyword arguments passed to the function during compilation

    Returns:
        tuple[PFunction, list, PyTreeDef]: Compilation artifacts containing:
            - PFunction: Serialized function with embedded MLIR text and type metadata
            - list: Extracted variable parameters (those satisfying is_variable predicate).
              Non-variable parameters are captured as compile-time constants within
              the PFunction body, while variables become runtime input parameters.
            - PyTreeDef: Tree structure template for reconstructing nested output values
    """
    # Flatten (args, kwargs); non-variable leaves are captured as constants.
    normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)

    # Abstract input signature for tracing: shape and dtype only.
    jax_params = [
        jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
    ]

    nnx_jitted = nnx.jit(normalized_fn)
    # The NNX lowered wrapper exposes the plain JAX lowered object as `.lowered`.
    jax_lowered = nnx_jitted.trace(jax_params).lower().lowered
    mlir_text = str(jax_lowered.compiler_ir("stablehlo"))

    out_info_flat, out_tree = _nnx_output_info(jax_lowered.out_info)
    arg_keep_map = _nnx_arg_keep_map(nnx_jitted, jax_params, len(in_vars))

    # Same PFunction layout as the JAX frontend, since NNX compiles to the same backend.
    pfn_kwargs: dict[str, Any] = {
        "fn_type": "mlir.stablehlo",  # Key: specify StableHLO MLIR format
        "ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
        "outs_info": tuple(out_info_flat),
        "fn_name": get_fn_name(flat_fn),
        "fn_text": mlir_text,  # MLIR text, serializable for transmission
    }
    if arg_keep_map is not None:
        pfn_kwargs["arg_keep_map"] = arg_keep_map

    return PFunction(**pfn_kwargs), in_vars, out_tree


def _nnx_output_info(raw_out_info: Any) -> tuple[list[TensorType], PyTreeDef]:
    """Strip NNX's (args, kwargs, result) wrapper and convert output leaves to TensorType."""
    if isinstance(raw_out_info, tuple) and len(raw_out_info) == 3:
        # NNX format: (args, kwargs, result); keep only the result.
        _, _, raw_out_info = raw_out_info
    # Otherwise use the structure as-is (not expected with NNX).
    flat, tree = tree_flatten(raw_out_info)
    return [TensorType.from_obj(info) for info in flat], tree


def _nnx_arg_keep_map(
    nnx_jitted: Any, jax_params: list[Any], original_arg_count: int
) -> list[int] | None:
    """Best-effort sorted indices of arguments kept after JAX dead-argument elimination.

    Returns None when every argument is kept, when JAX reports nothing, or
    when any step fails. Failures are logged rather than raised, because the
    compiled MLIR is still usable without this mapping.
    """
    try:
        # Accessing `.fun` is inside the try: it only serves parameter tracking.
        underlying_jax_fn = nnx_jitted.fun
        exported = export.export(jax.jit(underlying_jax_fn))(jax_params)
        kept_var_idx = exported.module_kept_var_idx
        if kept_var_idx is not None and len(kept_var_idx) < original_arg_count:
            # Sorted for a deterministic, order-preserving mapping.
            return sorted(kept_var_idx)
    except Exception as e:
        logging.warning(
            "jax.export failed to get kept_var_idx, proceeding without it. Error: %s",
            e,
        )
    return None
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class NnxRunner(FeOperation):
    """Frontend operation that runs an NNX function.

    Tracing delegates to ``nnx2stablehlo``; only ``MPObject`` arguments are
    classified as variable inputs of the resulting PFunction.
    """

    def trace(
        self, nnx_fn: Callable, *args: Any, **kwargs: Any
    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
        """
        Compile an NNX function to StableHLO and describe how to evaluate it.

        Args:
            nnx_fn: The NNX function to compile
            *args: Positional arguments to the function
            **kwargs: Keyword arguments to the function

        Returns:
            tuple[PFunction, list[MPObject], PyTreeDef]: the compiled
            PFunction, the variable (MPObject) inputs, and the output tree.
        """

        def _is_mp_variable(candidate: Any) -> bool:
            # Predicate handed to nnx2stablehlo: MPObjects are the runtime
            # variables; how other arguments are treated is decided there.
            return isinstance(candidate, MPObject)

        compiled_fn, variables, output_tree = nnx2stablehlo(
            _is_mp_variable, nnx_fn, *args, **kwargs
        )
        return compiled_fn, variables, output_tree


# Stateless "nnx" module; the runner is registered under the "run" op name.
_NNX_MOD = stateless_mod("nnx")

run_nnx = NnxRunner(_NNX_MOD, "run")
|
mplang/{ops → v1/ops}/phe.py
RENAMED
|
@@ -14,21 +14,34 @@
|
|
|
14
14
|
|
|
15
15
|
"""PHE (Partially Homomorphic Encryption) frontend operations."""
|
|
16
16
|
|
|
17
|
-
from mplang.core import UINT8, TensorType
|
|
18
|
-
from mplang.ops.base import stateless_mod
|
|
17
|
+
from mplang.v1.core import UINT8, TensorType
|
|
18
|
+
from mplang.v1.ops.base import stateless_mod
|
|
19
19
|
|
|
20
20
|
_PHE_MOD = stateless_mod("phe")
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
@_PHE_MOD.simple_op()
def keygen(
    *,
    scheme: str = "paillier",
    key_size: int = 2048,
    max_value: int | None = None,
    fxp_bits: int | None = None,
) -> tuple[TensorType, TensorType]:
    """Generate a PHE key pair: returns (public_key, private_key).

    Both keys are typed with the sentinel TensorType UINT8[(-1, 0)], marking
    them as opaque, backend-only handles rather than structured tensors.
    Runtime validation treats this shape as a placeholder and skips
    dtype/shape checks.

    Attributes (forwarded to backend):
        scheme: PHE scheme (default: 'paillier')
        key_size: Modulus size in bits (default: 2048)
        max_value: Optional range-encoding bound B. If provided, the backend
            will encode/decode integers/floats within [-B, B] and treat
            (B, N-B) as overflow. Pick B to exceed the largest intermediate
            magnitude you expect in homomorphic combinations. If omitted, the
            backend default is used (currently 2**32).
        fxp_bits: Optional fixed-point fractional bits for float encoding
            (default backend value).
    """
    # The keyword arguments are not read here; they only describe the op's
    # attributes. Both keys share one opaque placeholder spec object.
    opaque_handle = TensorType(UINT8, (-1, 0))
    return (opaque_handle, opaque_handle)
|
mplang/{ops → v1/ops}/spu.py
RENAMED
|
@@ -23,9 +23,9 @@ import spu.utils.frontend as spu_fe
|
|
|
23
23
|
from jax import ShapeDtypeStruct
|
|
24
24
|
from jax.tree_util import PyTreeDef, tree_flatten
|
|
25
25
|
|
|
26
|
-
from mplang.core import MPObject, PFunction, TensorType, get_fn_name
|
|
27
|
-
from mplang.ops.base import stateless_mod
|
|
28
|
-
from mplang.utils.func_utils import normalize_fn
|
|
26
|
+
from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
|
|
27
|
+
from mplang.v1.ops.base import stateless_mod
|
|
28
|
+
from mplang.v1.utils.func_utils import normalize_fn
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
class Visibility:
|