mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,262 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Crypto frontend operations: operation signatures, types, and high-level semantics.
17
+
18
+ Scope and contracts:
19
+ - This module defines portable API shapes; it does not implement cryptography.
20
+ - Backends execute the operations and must meet the security semantics required
21
+ by the deployment (confidentiality, authenticity, correctness, etc.).
22
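+ - The enc/dec API in this frontend assumes a nonce-prefixed ciphertext
+ (ciphertext = nonce || payload, where the payload may carry an authentication
+ tag), and dec expects that format. The bytes added to the plaintext length are
+ algorithm-dependent: 16 for "aes-ctr" and 28 (12-byte nonce + 16-byte tag) for
+ "aes-gcm"/"sm4-gcm". Other security properties (e.g., AEAD) are backend responsibilities.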
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ from jax.tree_util import PyTreeDef, tree_flatten
30
+
31
+ from mplang.v1.core import UINT8, TensorType
32
+ from mplang.v1.core.mpobject import MPObject
33
+ from mplang.v1.core.pfunc import PFunction
34
+ from mplang.v1.ops.base import stateless_mod
35
+
36
# Stateless frontend module whose decorators (simple_op / op_def) define the
# crypto operations in this file.
_CRYPTO_MOD = stateless_mod("crypto")
37
+
38
+
39
+ def _get_algo_overhead(algo: str) -> int:
40
+ """Get ciphertext overhead for a given encryption algorithm.
41
+
42
+ Args:
43
+ algo: Encryption algorithm identifier
44
+
45
+ Returns:
46
+ int: Number of overhead bytes added to plaintext length
47
+ """
48
+ overhead_map = {
49
+ "aes-ctr": 16, # nonce only (legacy compatibility)
50
+ "aes-gcm": 28, # nonce(12) + tag(16) for AES-GCM
51
+ "sm4-gcm": 28, # nonce(12) + tag(16) for SM4-GCM
52
+ }
53
+
54
+ if algo not in overhead_map:
55
+ # return unknown overhead as -1
56
+ return -1
57
+ return overhead_map[algo]
58
+
59
+
60
@_CRYPTO_MOD.simple_op()
def keygen(*, length: int = 32) -> TensorType:
    """Describe the output of symmetric key / random-byte generation.

    API: keygen(length: int = 32) -> key: u8[length]

    Only the result type is decided here; the backend supplies the bytes.

    Raises:
        ValueError: if ``length`` is not positive.
    """
    if length <= 0:
        raise ValueError("length must be > 0")
    key_shape = (length,)
    return TensorType(UINT8, key_shape)
73
+
74
+
75
@_CRYPTO_MOD.op_def()
def enc(
    plaintext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Describe symmetric encryption; the output size accounts for algorithm overhead.

    API: enc(plaintext: u8[N], key: u8[M], algo: str = "aes-ctr") -> ciphertext: u8[N + overhead]

    Overhead per algorithm (see ``_get_algo_overhead``):
      - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
      - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
      - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    ``algo`` is not validated here. An unrecognized name, or a plaintext whose
    length is dynamic (negative), yields a ciphertext of dynamic length (-1).
    The name is recorded as a PFunction attribute for the backend. The key's
    type is forwarded without checks.

    Returns:
        (pfunc, [plaintext, key], treedef) describing the single ciphertext output.

    Raises:
        TypeError: if the plaintext is not a 1-D UINT8 tensor.
    """
    if plaintext.dtype != UINT8:
        raise TypeError("enc expects UINT8 plaintext")
    if len(plaintext.shape) != 1:
        raise TypeError("enc expects 1-D plaintext")

    overhead = _get_algo_overhead(algo)
    pt_len = plaintext.shape[0]
    # A static output size needs both the plaintext length and the overhead.
    ct_len = pt_len + overhead if (pt_len >= 0 and overhead >= 0) else -1
    cipher_ty = TensorType(UINT8, (ct_len,))

    pfunc = PFunction(
        fn_type="crypto.enc",
        ins_info=(TensorType.from_obj(plaintext), TensorType.from_obj(key)),
        outs_info=(cipher_ty,),
        algo=algo,
    )
    _, treedef = tree_flatten(cipher_ty)
    return pfunc, [plaintext, key], treedef
114
+
115
+
116
@_CRYPTO_MOD.op_def()
def dec(
    ciphertext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Describe symmetric decryption; the result size accounts for algorithm overhead.

    API: dec(ciphertext: u8[N + overhead], key: u8[M], algo: str = "aes-ctr") -> plaintext: u8[N]

    Overhead per algorithm (see ``_get_algo_overhead``):
      - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
      - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
      - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    ``algo`` is recorded as a PFunction attribute. Parsing the ciphertext
    layout is left to the backend. When either the ciphertext length or the
    overhead is unknown (negative), the plaintext length is dynamic (-1).

    Returns:
        (pfunc, [ciphertext, key], treedef) describing the single plaintext output.

    Raises:
        TypeError: if the ciphertext is not a 1-D UINT8 tensor, or if its
            static length is shorter than the known overhead.
    """
    if ciphertext.dtype != UINT8:
        raise TypeError("dec expects UINT8 ciphertext")
    if len(ciphertext.shape) != 1:
        raise TypeError("dec expects 1-D ciphertext")

    overhead = _get_algo_overhead(algo)
    length = ciphertext.shape[0]
    statically_sized = length >= 0 and overhead >= 0

    # A statically-sized ciphertext must at least hold the algorithm overhead.
    if statically_sized and length < overhead:
        raise TypeError(
            f"dec expects ciphertext with at least {overhead} bytes for algo='{algo}', but got {length} bytes"
        )

    pt_len = length - overhead if statically_sized else -1
    plain_ty = TensorType(UINT8, (pt_len,))

    pfunc = PFunction(
        fn_type="crypto.dec",
        ins_info=(TensorType.from_obj(ciphertext), TensorType.from_obj(key)),
        outs_info=(plain_ty,),
        algo=algo,
    )
    _, treedef = tree_flatten(plain_ty)
    return pfunc, [ciphertext, key], treedef
164
+
165
+
166
@_CRYPTO_MOD.op_def()
def kem_keygen(suite: str = "x25519") -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Describe KEM-style keypair generation producing (sk, pk) byte tensors.

    API: kem_keygen(suite: str = "x25519") -> (sk: u8[32], pk: u8[32])

    For "x25519" both keys are 32 bytes. Any other suite gets dynamic (-1)
    lengths. ``suite`` is recorded as a PFunction attribute for the backend.
    The operation takes no tensor inputs.
    """
    key_len = 32 if suite == "x25519" else -1
    outs_info = (TensorType(UINT8, (key_len,)), TensorType(UINT8, (key_len,)))

    pfunc = PFunction(
        fn_type="crypto.kem_keygen",
        ins_info=(),
        outs_info=outs_info,
        suite=suite,
    )
    # The output tree is the (sk, pk) pair itself.
    _, treedef = tree_flatten(outs_info)
    return pfunc, [], treedef
191
+
192
+
193
@_CRYPTO_MOD.op_def()
def kem_derive(
    sk: MPObject, peer_pk: MPObject, suite: str = "x25519"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Describe KEM-style shared-secret derivation from (sk, peer_pk).

    API: kem_derive(sk: u8[32], peer_pk: u8[32], suite: str = "x25519") -> secret: u8[32]

    For "x25519" the secret is 32 bytes. Statically known key lengths must be
    32, while dynamic lengths (negative, e.g. -1) are accepted and left to the
    backend to check. This matches how ``enc``/``dec`` only enforce sizes that
    are known. Any other suite yields a secret of dynamic length (-1).
    ``suite`` is recorded as a PFunction attribute for the backend.

    Returns:
        (pfunc, [sk, peer_pk], treedef) describing the single secret output.

    Raises:
        TypeError: if either input is not a 1-D UINT8 tensor, or if, for
            "x25519", a statically known key length differs from 32.
    """
    if sk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 secret key")
    if peer_pk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 peer public key")
    if len(sk.shape) != 1 or len(peer_pk.shape) != 1:
        raise TypeError("kem_derive expects 1-D inputs")

    if suite == "x25519":
        # Only reject lengths that are statically known (>= 0); -1 marks a
        # dynamic length produced elsewhere in this module.
        key_lens = (sk.shape[0], peer_pk.shape[0])
        if any(n >= 0 and n != 32 for n in key_lens):
            raise TypeError("kem_derive expects 32-byte keys for suite 'x25519'")
        secret_ty = TensorType(UINT8, (32,))
    else:
        # Unknown suite, return dynamic length
        secret_ty = TensorType(UINT8, (-1,))
    outs_info = (secret_ty,)

    ins_info = (TensorType.from_obj(sk), TensorType.from_obj(peer_pk))
    pfunc = PFunction(
        fn_type="crypto.kem_derive",
        ins_info=ins_info,
        outs_info=outs_info,
        suite=suite,
    )
    _, treedef = tree_flatten(secret_ty)
    return pfunc, [sk, peer_pk], treedef
229
+
230
+
231
@_CRYPTO_MOD.op_def()
def hkdf(
    secret: MPObject, info: str, hash: str = "SHA-256"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Describe HKDF-style key derivation from a byte secret.

    API: hkdf(secret: u8[N], info: str, hash: str = "SHA-256") -> key: u8[32]

    "SHA-256" and "SM3" produce a 32-byte key. Any other ``hash`` name yields a
    key of dynamic length (-1). Both ``hash`` and ``info`` are recorded as
    PFunction attributes for the backend.

    Raises:
        TypeError: if ``secret`` is not a 1-D UINT8 tensor.
    """
    if secret.dtype != UINT8:
        raise TypeError("hkdf expects UINT8 secret")
    if len(secret.shape) != 1:
        raise TypeError("hkdf expects 1-D secret")

    key_len = 32 if hash in ("SHA-256", "SM3") else -1
    key_ty = TensorType(UINT8, (key_len,))

    pfunc = PFunction(
        fn_type="crypto.hkdf",
        ins_info=(TensorType.from_obj(secret),),
        outs_info=(key_ty,),
        hash=hash,
        info=info,
    )
    _, treedef = tree_flatten(key_ty)
    return pfunc, [secret], treedef
@@ -12,8 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from mplang.core import UINT8, TensorType
16
- from mplang.ops.base import stateless_mod
15
+ from mplang.v1.core import UINT8, TensorType
16
+ from mplang.v1.ops.base import stateless_mod
17
17
 
18
18
  _fhe_MOD = stateless_mod("fhe")
19
19
 
@@ -14,16 +14,18 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ import logging
17
18
  from collections.abc import Callable
18
19
  from typing import Any
19
20
 
20
21
  import jax
21
22
  import jax.numpy as jnp
23
+ from jax import export
22
24
  from jax.tree_util import PyTreeDef, tree_flatten
23
25
 
24
- from mplang.core import MPObject, PFunction, TensorType, get_fn_name
25
- from mplang.ops.base import FeOperation, stateless_mod
26
- from mplang.utils.func_utils import normalize_fn
26
+ from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
27
+ from mplang.v1.ops.base import FeOperation, stateless_mod
28
+ from mplang.v1.utils.func_utils import normalize_fn
27
29
 
28
30
  # Enable 64-bit precision for JAX to match tensor types
29
31
  jax.config.update("jax_enable_x64", True)
@@ -36,7 +38,8 @@ def jax2stablehlo(
36
38
 
37
39
  Translates high-level JAX functions into StableHLO MLIR representations,
38
40
  enabling execution on JAX backends across different processes and platforms.
39
- Uses the standard JAX compilation pipeline: jit tracelower StableHLO MLIR.
41
+ Uses a hybrid approach: traditional JAX trace/lower for compilation compatibility,
42
+ with stable jax.export API for parameter tracking.
40
43
 
41
44
  Args:
42
45
  is_variable: Predicate function to classify parameters as variables vs. constants.
@@ -52,34 +55,6 @@ def jax2stablehlo(
52
55
  Non-variable parameters are captured as compile-time constants within
53
56
  the PFunction body, while variables become runtime input parameters.
54
57
  - PyTreeDef: Tree structure template for reconstructing nested output values
55
-
56
- Rationale:
57
- JAX Serialization Options Analysis:
58
- 1. jax.export (JAX ≥0.4.35) - Official export API with StableHLO backend
59
- 2. HLO protobuf - Raw XLA HloModule serialization
60
- 3. HLO text - Human-readable HLO representation
61
- 4. StableHLO MLIR - Portable intermediate representation
62
- 5. JAX compiled object pickling - Limited to same-process execution
63
-
64
- Current Choice: StableHLO MLIR
65
- Advantages:
66
- - ✅ Available in current JAX version (0.4.34)
67
- - ✅ Cross-version compatibility guaranteed by StableHLO design
68
- - ✅ Direct compilation support via XLA client.compile(mlir_string)
69
- - ✅ Handles complex functions (multi-input/output, control flow)
70
- - ✅ Preserves numerical precision
71
- - ✅ Platform-independent representation
72
-
73
- Alternative Options Issues:
74
- - jax.export: Not available in JAX 0.4.34
75
- - HLO protobuf: Version compatibility issues with StableHLO parser
76
- - HLO text: Parser compatibility issues with XLA client
77
- - Pickle: Cannot serialize XLA LoadedExecutable objects
78
-
79
- Future Migration Path:
80
- - JAX ≥0.4.35: Migrate to jax.export.export() + jax.export.deserialize()
81
- - JAX ≥0.5.x: Consider new portable formats if available
82
- - Long-term: Adopt official JAX serialization standards as they mature
83
58
  """
84
59
  # Flatten (args, kwargs) and capture immediates using the moved logic from primitive.py
85
60
  normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)
@@ -89,47 +64,39 @@ def jax2stablehlo(
89
64
  jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
90
65
  ]
91
66
 
92
- # Standard JAX serialization pipeline: jit trace lower StableHLO MLIR
67
+ # Hybrid approach: Use standard JAX trace/lower for compatibility, but jax.export for parameter tracking
93
68
  jitted_fn = jax.jit(normalized_fn)
94
69
  traced = jitted_fn.trace(jax_params)
95
70
  lowered = traced.lower()
96
71
 
97
- # Get StableHLO MLIR representation - the portable format
98
- # compiler_ir("stablehlo") returns jaxlib.mlir.ir.Module object
99
- # str() converts to serializable text format
72
+ # Get StableHLO MLIR representation using traditional approach
100
73
  stablehlo_mlir = lowered.compiler_ir("stablehlo")
101
74
  mlir_text = str(stablehlo_mlir)
102
75
 
103
- # Get output info and tree structure for result reconstruction after remote execution
76
+ # Get output info using traditional approach
104
77
  out_info_flat, out_tree = tree_flatten(lowered.out_info)
105
78
  out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
106
79
 
107
- # Extract argument keep mapping to handle JAX's unused parameter elimination
108
- # JAX can eliminate unused parameters during compilation, but the runtime still
109
- # receives all original arguments. We need the mapping to filter them correctly.
80
+ # Extract argument keep mapping using stable jax.export API for parameter tracking
81
+ # We use jax.export only for getting the kept_var_idx information, not for the main compilation
110
82
  arg_keep_map = None
111
83
  original_arg_count = len(in_vars)
112
84
 
113
85
  try:
114
- # Access JAX internal kept_var_idx - the authoritative source
115
- # This tells us exactly which original parameters survived compilation
116
- compile_args = lowered._lowering.compile_args
117
- kept_var_idx = compile_args["kept_var_idx"]
118
-
119
- kept_indices = sorted(kept_var_idx)
120
- if len(kept_indices) < original_arg_count:
121
- arg_keep_map = kept_indices
122
-
123
- except (AttributeError, KeyError, TypeError) as e:
124
- # JAX internal API is not available or changed
125
- # This is a hard error - we cannot reliably handle unused parameters
126
- # without knowing exactly which ones were kept
127
- raise RuntimeError(
128
- f"Cannot access JAX's kept_var_idx to handle unused parameter elimination. "
129
- f"This function may have unused parameters that JAX optimized away, "
130
- f"but we cannot determine which ones without the internal API. "
131
- f"Original error: {e}"
132
- ) from e
86
+ # Use jax.export just to get the stable parameter tracking information
87
+ export_fn = export.export(jitted_fn)
88
+ exported = export_fn(jax_params)
89
+ kept_var_idx = exported.module_kept_var_idx
90
+ if kept_var_idx is not None and len(kept_var_idx) < original_arg_count:
91
+ # JAX eliminated some unused parameters during compilation
92
+ # Keep the indices in sorted order for consistent mapping
93
+ arg_keep_map = sorted(kept_var_idx)
94
+ except Exception as e:
95
+ # Fallback: if jax.export fails, we can still use the compiled result without parameter tracking
96
+ # This ensures backward compatibility even if export has issues
97
+ logging.warning(
98
+ f"jax.export failed to get kept_var_idx, proceeding without it. Error: {e}"
99
+ )
133
100
 
134
101
  # This format tells JaxRT how to handle the compiled result
135
102
  pfn_kwargs: dict[str, Any] = {
@@ -0,0 +1,168 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ from collections.abc import Callable
19
+ from typing import Any
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+ from flax import nnx
24
+ from jax import export
25
+ from jax.tree_util import PyTreeDef, tree_flatten
26
+
27
+ from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
28
+ from mplang.v1.ops.base import FeOperation, stateless_mod
29
+ from mplang.v1.utils.func_utils import normalize_fn
30
+
31
# Turn on JAX's 64-bit mode so traced dtypes can match the tensor types used
# here. This is a process-wide JAX setting applied at import time.
jax.config.update("jax_enable_x64", True)
33
+
34
+
35
def nnx2stablehlo(
    is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any
) -> tuple[PFunction, list[Any], PyTreeDef]:
    """Compile an NNX function to StableHLO MLIR text for remote execution.

    The function is traced and lowered through ``nnx.jit``. The resulting
    StableHLO module (as text) becomes the body of an ``mlir.stablehlo``
    PFunction. Separately, ``jax.export`` is run on ``jax.jit`` of the wrapped
    callable, only to learn which inputs JAX pruned as unused.

    Args:
        is_variable: Predicate classifying each flattened argument. Values for
            which it returns True become PFunction inputs; the rest are
            captured as compile-time constants.
        flat_fn: NNX function to compile.
        *args: Positional arguments used for tracing.
        **kwargs: Keyword arguments used for tracing.

    Returns:
        (pfunc, in_vars, out_tree):
        - pfunc: PFunction with ``fn_type="mlir.stablehlo"``, the MLIR text,
          input/output TensorTypes and, when some inputs were pruned, an
          ``arg_keep_map`` holding the sorted indices of the kept inputs.
        - in_vars: the variable arguments, in PFunction input order.
        - out_tree: PyTreeDef for rebuilding the nested output structure.
    """
    # Split (args, kwargs) into runtime variables and captured constants.
    normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)

    # Abstract signature for tracing: one ShapeDtypeStruct per variable input.
    jax_params = [
        jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
    ]

    nnx_jitted = nnx.jit(normalized_fn)
    # The callable wrapped by nnx.jit. It is reused below for jax.export.
    underlying_jax_fn = nnx_jitted.fun

    # Lower through nnx.jit. The NNX lowered object wraps a plain JAX lowered
    # object (``.lowered``), which provides the StableHLO module and out_info.
    jax_lowered = nnx_jitted.trace(jax_params).lower().lowered
    mlir_text = str(jax_lowered.compiler_ir("stablehlo"))

    # NNX reports outputs as (args, kwargs, result); keep only the result part.
    # NOTE(review): if NNX ever stopped wrapping, a user function that itself
    # returns a 3-tuple would be mis-sliced here. Confirm against flax version.
    out_info = jax_lowered.out_info
    if isinstance(out_info, tuple) and len(out_info) == 3:
        out_info = out_info[2]
    out_info_flat, out_tree = tree_flatten(out_info)
    outs_info = tuple(TensorType.from_obj(info) for info in out_info_flat)

    # Ask jax.export which inputs survived compilation, so the runtime can drop
    # arguments for parameters JAX eliminated as unused.
    arg_keep_map: list[int] | None = None
    try:
        exported = export.export(jax.jit(underlying_jax_fn))(jax_params)
        kept_var_idx = exported.module_kept_var_idx
        if kept_var_idx is not None and len(kept_var_idx) < len(in_vars):
            arg_keep_map = sorted(kept_var_idx)
    except Exception as e:
        # Deliberate best effort: on failure no keep map is attached.
        # NOTE(review): if the lowered MLIR did prune inputs, the runtime would
        # then receive more arguments than the module expects. Consider
        # failing hard here instead.
        logging.getLogger(__name__).warning(
            "jax.export failed to get kept_var_idx, proceeding without it. Error: %s",
            e,
        )

    # Same PFunction format as the JAX frontend, since NNX lowers to StableHLO too.
    pfn_kwargs: dict[str, Any] = {
        "fn_type": "mlir.stablehlo",
        "ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
        "outs_info": outs_info,
        "fn_name": get_fn_name(flat_fn),
        "fn_text": mlir_text,  # MLIR text, serializable for transmission
    }
    if arg_keep_map is not None:
        pfn_kwargs["arg_keep_map"] = arg_keep_map

    return PFunction(**pfn_kwargs), in_vars, out_tree
136
+
137
+
138
class NnxRunner(FeOperation):
    """Frontend operation that compiles an NNX callable into a StableHLO PFunction."""

    def trace(
        self, nnx_fn: Callable, *args: Any, **kwargs: Any
    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
        """Compile ``nnx_fn`` for the given call arguments.

        Every ``MPObject`` among the arguments becomes a runtime input of the
        PFunction. All other values are captured as compile-time constants.

        Args:
            nnx_fn: The NNX function to compile.
            *args: Positional arguments to the function.
            **kwargs: Keyword arguments to the function.

        Returns:
            (pfunc, in_vars, out_tree) as produced by ``nnx2stablehlo``.
        """

        def _is_mp_object(value: Any) -> bool:
            return isinstance(value, MPObject)

        compiled, variables, output_tree = nnx2stablehlo(
            _is_mp_object, nnx_fn, *args, **kwargs
        )
        return compiled, variables, output_tree
164
+
165
+
166
# Stateless frontend module for NNX operations, and its single "run" op.
_NNX_MOD = stateless_mod("nnx")
run_nnx = NnxRunner(_NNX_MOD, "run")
@@ -14,21 +14,34 @@
14
14
 
15
15
  """PHE (Partially Homomorphic Encryption) frontend operations."""
16
16
 
17
- from mplang.core import UINT8, TensorType
18
- from mplang.ops.base import stateless_mod
17
+ from mplang.v1.core import UINT8, TensorType
18
+ from mplang.v1.ops.base import stateless_mod
19
19
 
20
20
  _PHE_MOD = stateless_mod("phe")
21
21
 
22
22
 
23
23
  @_PHE_MOD.simple_op()
24
24
  def keygen(
25
- *, scheme: str = "paillier", key_size: int = 2048
25
+ *,
26
+ scheme: str = "paillier",
27
+ key_size: int = 2048,
28
+ max_value: int | None = None,
29
+ fxp_bits: int | None = None,
26
30
  ) -> tuple[TensorType, TensorType]:
27
31
  """Generate a PHE key pair: returns (public_key, private_key).
28
32
 
29
33
  Keys are represented with a sentinel TensorType UINT8[(-1, 0)] to indicate
30
34
  non-structural, backend-only handles. Runtime validation will treat this
31
35
  shape as an opaque placeholder and skip dtype/shape checks.
36
+
37
+ Attributes (forwarded to backend):
38
+ scheme: PHE scheme (default: 'paillier')
39
+ key_size: Modulus size in bits (default: 2048)
40
+ max_value: Optional range-encoding bound B. If provided, the backend will
41
+ encode/decode integers/floats within [-B, B] and treat (B, N-B) as overflow.
42
+ Pick B to exceed the largest intermediate magnitude you expect in homomorphic
43
+ combinations. If omitted, backend default is used (currently 2**32).
44
+ fxp_bits: Optional fixed-point fractional bits for float encoding (default backend value).
32
45
  """
33
46
  key_spec = TensorType(UINT8, (-1, 0))
34
47
  return key_spec, key_spec
@@ -23,9 +23,9 @@ import spu.utils.frontend as spu_fe
23
23
  from jax import ShapeDtypeStruct
24
24
  from jax.tree_util import PyTreeDef, tree_flatten
25
25
 
26
- from mplang.core import MPObject, PFunction, TensorType, get_fn_name
27
- from mplang.ops.base import stateless_mod
28
- from mplang.utils.func_utils import normalize_fn
26
+ from mplang.v1.core import MPObject, PFunction, TensorType, get_fn_name
27
+ from mplang.v1.ops.base import stateless_mod
28
+ from mplang.v1.utils.func_utils import normalize_fn
29
29
 
30
30
 
31
31
  class Visibility: