mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,349 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """SPU (Secure Processing Unit) dialect for the EDSL.
16
+
17
+ This dialect implements an "Encrypted Virtual Machine" model where the SPU
18
+ is treated as a logical device composed of multiple parties. It leverages
19
+ the `simp` dialect for data movement (encryption/decryption) and execution.
20
+
21
+ Concepts:
22
+ - SPUConfig: SPU runtime settings (protocol, field, fixed-point bits).
23
+ - make_shares: Generates secret shares on the source party.
24
+ - reconstruct: Reconstructs secret from shares on the target party.
25
+ - run_jax: Executes JAX computations on the SPU.
26
+
27
+ Example:
28
+ ```python
29
+ import jax.numpy as jnp
30
+ from mplang.v2.dialects import spu, tensor, simp
31
+ import mplang.v2.edsl.typing as elt
32
+
33
+ # 0. Setup
34
+ spu_parties = (0, 1, 2)
35
+ config = spu.SPUConfig()
36
+
37
+ # 1. Define computation
38
+ def secure_add(x, y):
39
+ return x + y
40
+
41
+
42
+ # 2. Encrypt (Public -> SPU)
43
+ # Assume x, y are on party 0
44
+ # Generate shares locally (inside a simp.pcall region on party 0)
45
+ x_shares = spu.make_shares(config, x, count=3)
46
+ y_shares = spu.make_shares(config, y, count=3)
47
+
48
+ # Distribute shares to SPU parties
49
+ x_dist = []
50
+ y_dist = []
51
+ for i, target in enumerate(spu_parties):
52
+ x_dist.append(simp.shuffle_static(x_shares[i], {target: 0}))
53
+ y_dist.append(simp.shuffle_static(y_shares[i], {target: 0}))
54
+
55
+ # Converge to logical SPU variables
56
+ x_enc = simp.converge(*x_dist)
57
+ y_enc = simp.converge(*y_dist)
58
+
59
+ # 3. Execute (SPU -> SPU), inside a simp.pcall region over spu_parties
60
+ z_enc = spu.run_jax(config, secure_add, x_enc, y_enc)
61
+
62
+ # 4. Decrypt (SPU -> Public)
63
+ # Gather shares to party 0
64
+ z_shares = []
65
+ for source in spu_parties:
66
+ # Extract share from logical variable
67
+ share = simp.pcall_static((source,), lambda x: x, z_enc)
68
+ # Move to target
69
+ z_shares.append(simp.shuffle_static(share, {0: source}))
70
+
71
+ # Reconstruct (inside a simp.pcall region on party 0)
72
+ z = spu.reconstruct(config, tuple(z_shares))
73
+ ```
74
+ """
75
+
76
+ from __future__ import annotations
77
+
78
+ from collections.abc import Callable
79
+ from dataclasses import dataclass
80
+ from typing import Any, ClassVar, Literal, cast
81
+
82
+ import spu.utils.frontend as spu_fe
83
+ from jax import ShapeDtypeStruct
84
+ from jax.tree_util import tree_flatten, tree_unflatten
85
+
86
+ import mplang.v2.edsl as el
87
+ import mplang.v2.edsl.typing as elt
88
+ from mplang.v1.utils.func_utils import normalize_fn
89
+ from mplang.v2.dialects import dtypes
90
+ from mplang.v2.edsl import serde
91
+
92
+ # ==============================================================================
93
+ # --- Configuration
94
+ # ==============================================================================
95
+
96
+
97
@serde.register_class
@dataclass(frozen=True)
class SPUConfig:
    """SPU configuration (subset of libspu.RuntimeConfig).

    Attributes:
        protocol: SPU protocol (e.g., "SEMI2K", "ABY3").
        field: SPU field type (e.g., "FM64", "FM128").
        fxp_fraction_bits: Fixed-point fraction bits.
    """

    protocol: str = "SEMI2K"
    field: str = "FM128"
    fxp_fraction_bits: int = 18

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> SPUConfig:
        """Build a config from a possibly partial mapping.

        Keys that are absent fall back to the defaults declared on the fields
        above, so those defaults live in exactly one place. Keys that are not
        SPUConfig fields are ignored, so a larger runtime-config dict can be
        passed unchanged.

        Args:
            d: Mapping with any of ``protocol``, ``field`` and
                ``fxp_fraction_bits``.

        Returns:
            A new SPUConfig.
        """
        # Local import keeps the helper self-contained; ``fields`` is distinct
        # from the ``field`` attribute declared on this class.
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        return cls(**{name: value for name, value in d.items() if name in known})

    # --- Serde methods ---
    # Registry key used by the serde layer to identify this class.
    _serde_kind: ClassVar[str] = "spu.SPUConfig"

    def to_json(self) -> dict[str, Any]:
        """Serialize all fields to a JSON-compatible dict."""
        return {
            "protocol": self.protocol,
            "field": self.field,
            "fxp_fraction_bits": self.fxp_fraction_bits,
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> SPUConfig:
        """Deserialize from ``to_json`` output.

        Unlike ``from_dict``, every field is required; a missing key raises
        KeyError.
        """
        return cls(
            protocol=data["protocol"],
            field=data["field"],
            fxp_fraction_bits=data["fxp_fraction_bits"],
        )
137
+
138
+
139
+ # ==============================================================================
140
+ # --- Primitives (Local Operations)
141
+ # ==============================================================================
142
+
143
# These primitives operate locally on a single party.
# They are used inside simp.pcall to construct the distributed protocols.

makeshares_p = el.Primitive[tuple[el.Object, ...]]("spu.makeshares")
reconstruct_p = el.Primitive[el.Object]("spu.reconstruct")
exec_p = el.Primitive[Any]("spu.exec")


@makeshares_p.def_abstract_eval
def _makeshares_ae(
    data: elt.TensorType, *, count: int, config: SPUConfig
) -> tuple[elt.SSType, ...]:
    """Type rule for ``spu.makeshares``: split a tensor into ``count`` shares.

    Args:
        data: Plaintext tensor type to be shared.
        count: Number of shares to produce; must be positive.
        config: SPU configuration (not consulted by this type rule).

    Returns:
        A tuple of ``count`` SS-typed shares wrapping ``data``.

    Raises:
        TypeError: If ``data`` is not a TensorType.
        ValueError: If ``count`` is not positive.
    """
    if not isinstance(data, elt.TensorType):
        raise TypeError(f"makeshares expects TensorType, got {data}")
    # count <= 0 would yield an empty share tuple, which reconstruct rejects
    # ("requires at least one share"); fail here, at trace time, instead.
    if count <= 0:
        raise ValueError(f"makeshares count must be positive, got {count}")
    # Shares have same shape/dtype as data (simplified additive sharing).
    return tuple(elt.SS(data) for _ in range(count))
161
+
162
+
163
@reconstruct_p.def_abstract_eval
def _reconstruct_ae(*shares: elt.SSType, config: SPUConfig) -> elt.TensorType:
    """Type rule for ``spu.reconstruct``: recover the plaintext tensor type.

    Args:
        *shares: Share types; each must be ``SS[Tensor]``.
        config: SPU configuration (not consulted by this type rule).

    Returns:
        The plaintext TensorType carried by the first share.

    Raises:
        ValueError: If no shares are given.
        TypeError: If any share is not an SSType wrapping a TensorType.
    """
    if not shares:
        raise ValueError("reconstruct requires at least one share")
    # Check every share, not only the first: a stray non-SS operand in the
    # share list would otherwise pass type checking.
    for share in shares:
        if not isinstance(share, elt.SSType):
            raise TypeError(f"reconstruct expects SSType shares, got {share}")
        if not isinstance(share.pt_type, elt.TensorType):
            raise TypeError(f"reconstruct expects SS[Tensor], got {share}")
    # NOTE(review): shares are not checked to agree on shape/dtype; confirm
    # whether elt types support structural equality before adding that check.
    return shares[0].pt_type
175
+
176
+
177
# Visibility type for IR attrs (string-based, mapped to libspu.Visibility at runtime)
Visibility = Literal["secret", "public", "private"]


@exec_p.def_abstract_eval
def _exec_ae(
    *args: elt.SSType | elt.TensorType,
    executable: bytes,
    input_vis: list[Visibility],
    output_vis: list[Visibility],
    output_shapes: list[tuple[int, ...]],
    output_dtypes: list[elt.ScalarType],
    input_names: list[str],
    output_names: list[str],
    config: SPUConfig,
) -> tuple[elt.SSType, ...] | elt.SSType:
    """Type rule for ``spu.exec``: run a compiled SPU kernel on its operands.

    ``executable``, ``input_names``, ``output_names`` and ``config`` are
    carried as attributes but do not influence the result type.

    Args:
        *args: Operands; each must be an SSType or a TensorType.
        input_vis: One visibility per operand.
        output_vis: One visibility per declared output.
        output_shapes: Shape of each output.
        output_dtypes: Element type of each output (same length as shapes).

    Returns:
        A single ``SS[Tensor]`` when there is exactly one output, otherwise a
        tuple of them (possibly empty).

    Raises:
        TypeError: If an operand is neither SSType nor TensorType.
        ValueError: If the per-input/per-output attribute lists do not line up.
    """
    for arg in args:
        if not isinstance(arg, (elt.SSType, elt.TensorType)):
            raise TypeError(f"spu.exec expects SSType or TensorType inputs, got {arg}")

    # Attribute lists must align with operands/outputs; a mismatch means the
    # op was assembled inconsistently and would misbehave at runtime.
    if len(input_vis) != len(args):
        raise ValueError(
            f"spu.exec got {len(args)} inputs but {len(input_vis)} input_vis entries"
        )
    if len(output_vis) != len(output_shapes):
        raise ValueError(
            f"spu.exec declares {len(output_shapes)} outputs but "
            f"{len(output_vis)} output_vis entries"
        )

    # Every output is typed as a secret-shared tensor (output_vis is not
    # consulted here); zip(strict=True) enforces shapes/dtypes alignment.
    outputs = tuple(
        elt.SS(elt.Tensor(dtype, shape))
        for shape, dtype in zip(output_shapes, output_dtypes, strict=True)
    )

    # A single output is returned unwrapped; run_jax handles both forms.
    return outputs[0] if len(outputs) == 1 else outputs
207
+
208
+
209
+ # ==============================================================================
210
+ # --- High-Level API (Distributed Protocols)
211
+ # ==============================================================================
212
+
213
+
214
def make_shares(
    config: SPUConfig, data: el.Object, count: int
) -> tuple[el.Object, ...]:
    """Split ``data`` into ``count`` secret shares on the calling party.

    Nothing is transferred: call this from within a ``simp.pcall`` region and
    distribute the resulting shares separately.

    Args:
        config: SPU configuration attached to the primitive.
        data: Object whose type is a TensorType.
        count: How many shares to produce.

    Returns:
        A tuple of SS-typed objects, one per share.
    """
    shares = makeshares_p.bind(data, count=count, config=config)
    return shares
230
+
231
+
232
def reconstruct(config: SPUConfig, shares: tuple[el.Object, ...]) -> el.Object:
    """Combine secret shares back into a plaintext tensor on the calling party.

    Nothing is transferred: the shares must already be local, so call this
    from within a ``simp.pcall`` region.

    Args:
        config: SPU configuration attached to the primitive.
        shares: SS-typed share objects.

    Returns:
        The reconstructed TensorType object.
    """
    plaintext = reconstruct_p.bind(*shares, config=config)
    return plaintext
245
+
246
+
247
def run_jax(config: SPUConfig, fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Compile a JAX function for SPU and bind it as a local ``spu.exec`` op.

    Call this from within a ``simp.pcall`` region. ``normalize_fn`` splits the
    arguments into EDSL objects (variables) and raw values (immediates); only
    the EDSL objects become operands of the ``spu.exec`` op.

    Args:
        config: SPU configuration attached to the ``spu.exec`` op.
        fn: The function to compile and execute.
        *args: Positional arguments. EDSL objects must be typed ``SS[Tensor]``
            (compiled as secret inputs) or ``Tensor`` (compiled as public
            inputs).
        **kwargs: Keyword arguments, split the same way as ``args``.

    Returns:
        ``fn``'s output structure with every leaf replaced by an SS-typed
        EDSL object.

    Raises:
        TypeError: If an EDSL input is neither SSType nor TensorType, or is
            not tensor-based.
        ValueError: If an input visibility has no libspu mapping.
    """
    # 1. Separate EDSL objects (variables) from raw values (immediates).
    normalized_fn, in_vars = normalize_fn(
        fn, args, kwargs, lambda arg: isinstance(arg, el.Object)
    )

    # 2. Validate inputs and build the abstract JAX signature in one pass.
    jax_args_flat: list[ShapeDtypeStruct] = []
    input_vis: list[Visibility] = []  # String-based visibility for IR
    for var in in_vars:
        var_type = var.type
        vis: Visibility
        if isinstance(var_type, elt.SSType):
            pt_type = var_type.pt_type
            vis = "secret"
        elif isinstance(var_type, elt.TensorType):
            pt_type = var_type
            vis = "public"
        else:
            raise TypeError(
                f"spu.run_jax inputs must be SSType or TensorType, got {var_type}"
            )
        if not isinstance(pt_type, elt.TensorType):
            raise TypeError(
                f"spu.run_jax inputs must be Tensor-based, got {pt_type}"
            )

        jax_dtype = dtypes.to_jax(cast(elt.ScalarType, pt_type.element_type))
        # NOTE(review): dynamic dims (-1) are compiled as size 1, which
        # specialises the executable to that size; confirm this is intended.
        shape = tuple(1 if d == -1 else d for d in pt_type.shape)
        jax_args_flat.append(ShapeDtypeStruct(shape, jax_dtype))
        input_vis.append(vis)

    # 3. Compile. libspu is imported only here; the IR keeps string visibility.
    import spu.libspu as libspu

    # Explicit table: an unmapped visibility (e.g. "private") is rejected
    # instead of being silently compiled as public.
    vis_table = {
        "secret": libspu.Visibility.VIS_SECRET,
        "public": libspu.Visibility.VIS_PUBLIC,
    }

    def vis_to_libspu(v: Visibility) -> libspu.Visibility:
        try:
            return vis_table[v]
        except KeyError:
            raise ValueError(
                f"spu.run_jax does not support input visibility {v!r}"
            ) from None

    # Note: normalized_fn takes a list of variables as its single argument.
    executable, output_info = spu_fe.compile(
        spu_fe.Kind.JAX,
        normalized_fn,
        [jax_args_flat],
        {},
        input_names=[f"in{i}" for i in range(len(in_vars))],
        input_vis=[vis_to_libspu(v) for v in input_vis],
        outputNameGen=lambda outs: [f"out{i}" for i in range(len(outs))],
    )

    # 4. Bind the SPU kernel; every output is produced as a secret share.
    flat_outputs_info, out_tree = tree_flatten(output_info)
    res_shares = exec_p.bind(
        *in_vars,
        executable=executable.code,
        input_vis=input_vis,
        output_vis=["secret"] * len(flat_outputs_info),
        output_shapes=[out.shape for out in flat_outputs_info],
        output_dtypes=[dtypes.from_dtype(out.dtype) for out in flat_outputs_info],
        input_names=executable.input_names,
        output_names=executable.output_names,
        config=config,
    )

    # 5. exec_p yields a bare object for a single output, a tuple otherwise.
    if isinstance(res_shares, (tuple, list)):
        leaves = list(res_shares)
    else:
        leaves = [res_shares]
    return tree_unflatten(out_tree, leaves)
@@ -0,0 +1,63 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Store dialect: save/load primitives for internal state."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import mplang.v2.edsl as el
20
+ import mplang.v2.edsl.typing as elt
21
+
22
# Both primitives produce an el.Object when bound.
save_p: el.Primitive[el.Object] = el.Primitive("store.save")
load_p: el.Primitive[el.Object] = el.Primitive("store.load")


@save_p.def_abstract_eval
def _save_abstract(obj: elt.BaseType, *, uri_base: str) -> elt.BaseType:
    """Type rule for ``store.save``.

    Saving is typed as an identity: the result has exactly the operand's type.
    ``uri_base`` does not affect typing.
    """
    return obj


@load_p.def_abstract_eval
def _load_abstract(*, uri_base: str, expected_type: elt.BaseType) -> elt.BaseType:
    """Type rule for ``store.load``.

    The op has no operands; its result type is the caller-supplied
    ``expected_type``. ``uri_base`` does not affect typing.
    """
    return expected_type
36
+
37
+
38
def save(obj: el.Object, uri_base: str) -> el.Object:
    """Persist ``obj`` under ``uri_base`` and pass it through.

    SPMD operation: every party holding ``obj`` writes its own local portion
    to the location specified by ``uri_base``.

    Args:
        obj: Object to persist.
        uri_base: Base URI of the storage location.

    Returns:
        An object of the same type as ``obj`` (identity), usable for
        dependency chaining.
    """
    saved = save_p.bind(obj, uri_base=uri_base)
    return saved
48
+
49
+
50
def load(uri_base: str, expected_type: elt.BaseType) -> el.Object:
    """Read a previously saved object back from persistent storage.

    SPMD operation: every party loads its own local portion from a path
    derived from ``uri_base``.

    Args:
        uri_base: Base URI for the checkpoint package.
        expected_type: The type of the object to load (reconstructed from manifest).

    Returns:
        The loaded object, typed as ``expected_type``.
    """
    loaded = load_p.bind(uri_base=uri_base, expected_type=expected_type)
    return loaded