mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v1/_device.py ADDED
@@ -0,0 +1,602 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This module provides the device oriented programming interface for MPC.
17
+
18
+ The device oriented programming interface is designed to provide a high-level
19
+ abstraction for the MPC programming. It allows the user to write the program
20
+ in a device-oriented manner, and the runtime will take care of the data
21
+ transformation between devices.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from collections.abc import Callable
27
+ from functools import partial, wraps
28
+ from typing import Any, cast
29
+
30
+ from jax.tree_util import tree_map, tree_unflatten
31
+
32
+ from mplang.v1.core import (
33
+ ClusterSpec,
34
+ Device,
35
+ Mask,
36
+ MPContext,
37
+ MPObject,
38
+ TableLike,
39
+ TensorLike,
40
+ cur_ctx,
41
+ peval,
42
+ )
43
+ from mplang.v1.ops import basic, crypto, jax_cc, nnx_cc, spu, tee
44
+ from mplang.v1.ops.base import FeOperation
45
+ from mplang.v1.ops.jax_cc import JaxRunner
46
+ from mplang.v1.simp import mpi
47
+ from mplang.v1.simp.api import run_at
48
+
49
# When True, arguments that live on another device are automatically moved to
# the target device before a device function runs.
g_auto_trans: bool = True

# Fixed HKDF "info" string used by both sides of a PPU<->TEE session when
# deriving the session key.
_HKDF_INFO_LITERAL: str = "mplang/device/tee/v1"

# Default KEM suite for TEE session establishment; make configurable via
# ClusterSpec in future.
_TEE_KEM_SUITE: str = "x25519"
55
+
56
+
57
+ # Context-aware session management
58
+ def _get_context_id(ctx: MPContext) -> int:
59
+ """
60
+ Get unique identifier for a context.
61
+
62
+ Args:
63
+ ctx: The context object (TraceContext or InterpContext)
64
+
65
+ Returns:
66
+ Unique integer ID for this context instance
67
+ """
68
+ return id(ctx)
69
+
70
+
71
# Key under which a device id is stored in ``MPObject.attrs``; an MPObject
# carrying this key is treated as a device object.
DEVICE_ATTR_NAME = "__device__"
73
+
74
+
75
def is_device_obj(obj: Any) -> bool:
    """Return True if ``obj`` is an MPObject tagged with a device id.

    Non-MPObject values always yield False.
    """
    return isinstance(obj, MPObject) and DEVICE_ATTR_NAME in obj.attrs
79
+
80
+
81
def set_dev_attr(obj: MPObject, dev_id: str) -> MPObject:
    """Tag ``obj`` in place with device id ``dev_id`` and return the same object.

    Raises:
        TypeError: If ``obj`` is not an MPObject.
    """
    if isinstance(obj, MPObject):
        obj.attrs[DEVICE_ATTR_NAME] = dev_id
        return obj
    raise TypeError(f"Input must be an instance of MPObject, {obj}")
86
+
87
+
88
def get_dev_attr(obj: MPObject) -> str:
    """Return the device id stored on ``obj``, converted to ``str``.

    Raises:
        TypeError: If ``obj`` is not an MPObject.
        Whatever ``obj.attrs`` raises for a missing key (KeyError if it is a
        dict) when ``obj`` carries no device attribute.
    """
    if isinstance(obj, MPObject):
        dev_id = obj.attrs[DEVICE_ATTR_NAME]
        return str(dev_id)
    raise TypeError("Input must be an instance of MPObject")
93
+
94
+
95
def _infer_device_from_args(*args: Any, **kwargs: Any) -> str:
    """Infer target device from function arguments.

    Inference strategy:
    1. Collect all MPObject arguments and check device attributes
       - If MPObject has no device attr -> error (user must set_dev_attr)
       - If no MPObject arguments -> error (explicit device required)

    2. Analyze device distribution
       2.1 All objects on same device -> use that device
       2.2 Multiple devices with g_auto_trans enabled:
           - Single SPU (+ PPUs) -> use SPU (auto-transfer from PPUs)
           - Single TEE (+ PPUs) -> use TEE (auto-transfer from PPUs)
           - Multiple PPUs only -> error (ambiguous, need explicit device)
           - Multiple SPUs -> error (ambiguous)
           - Multiple TEEs -> error (ambiguous)
           - SPU + TEE -> error (conflicting secure devices)
       2.3 Multiple devices with g_auto_trans disabled -> error

    Device ids are sorted before analysis so that error messages list them in
    a deterministic order (set iteration order of strings is randomized).

    Args:
        *args: Positional arguments
        **kwargs: Keyword arguments

    Returns:
        Device id string

    Raises:
        ValueError: When inference fails or is ambiguous, including when an
            argument's device id is not present in the cluster spec.
    """
    from jax.tree_util import tree_flatten

    # Step 1: Collect all MPObject leaves and validate their device attributes.
    device_objs: list[MPObject] = []
    for obj in tree_flatten((args, kwargs))[0]:
        if not isinstance(obj, MPObject):
            continue
        if not is_device_obj(obj):
            raise ValueError(
                "MPObject is missing device attribute. "
                "If you're mixing device-level and simp-level code, "
                "use set_dev_attr(obj, 'device_id') to mark the device explicitly."
            )
        device_objs.append(obj)

    if not device_objs:
        raise ValueError(
            "Cannot infer device: no MPObject arguments found. "
            "Please specify device explicitly using device('device_id')(fn)."
        )

    # Step 2: Extract all unique devices (sorted for deterministic messages).
    devices = sorted({get_dev_attr(obj) for obj in device_objs})

    if len(devices) == 1:
        return devices[0]  # All arguments on same device

    # Step 3: Multiple devices - check if auto-transfer is enabled
    if not g_auto_trans:
        raise ValueError(
            f"Cannot infer device: arguments from multiple devices {devices} "
            f"but auto-transfer is disabled (g_auto_trans=False). "
            f"Please enable auto-transfer or put all data on same device first."
        )

    # Step 4: Analyze device kinds for auto-transfer scenario
    cluster_spec = cur_ctx().cluster_spec
    missing = [dev_id for dev_id in devices if dev_id not in cluster_spec.devices]
    if missing:
        # Raise ValueError (not a bare KeyError) so device() can enrich it
        # with the function name like every other inference failure.
        raise ValueError(
            f"Cannot infer device: devices {missing} not found in cluster spec."
        )
    device_kinds = {
        dev_id: cluster_spec.devices[dev_id].kind.upper() for dev_id in devices
    }

    # Count devices by type
    spu_devs = [d for d, k in device_kinds.items() if k == "SPU"]
    tee_devs = [d for d, k in device_kinds.items() if k == "TEE"]
    ppu_devs = [d for d, k in device_kinds.items() if k == "PPU"]

    # Decision logic
    # Case 1: Only PPUs -> ambiguous
    if not spu_devs and not tee_devs:
        raise ValueError(
            f"Cannot infer device: arguments from multiple PPU devices {ppu_devs}. "
            f"Please specify device explicitly or use put() to consolidate data."
        )

    # Case 2: Single SPU (possibly with PPUs) -> use SPU
    if len(spu_devs) == 1 and not tee_devs:
        return spu_devs[0]

    # Case 3: Single TEE (possibly with PPUs) -> use TEE
    if len(tee_devs) == 1 and not spu_devs:
        return tee_devs[0]

    # Case 4: Multiple SPUs -> ambiguous
    if len(spu_devs) > 1:
        raise ValueError(
            f"Ambiguous device inference: arguments from multiple SPU devices {spu_devs}. "
            f"Please specify which SPU to use explicitly."
        )

    # Case 5: Multiple TEEs -> ambiguous
    if len(tee_devs) > 1:
        raise ValueError(
            f"Ambiguous device inference: arguments from multiple TEE devices {tee_devs}. "
            f"Please specify which TEE to use explicitly."
        )

    # Case 6: Both SPU and TEE -> conflicting
    if spu_devs and tee_devs:
        raise ValueError(
            f"Ambiguous device inference: arguments from both SPU {spu_devs} and TEE {tee_devs}. "
            f"Please specify which secure device to use explicitly."
        )

    # Should never reach here: cases 1-6 cover every SPU/TEE count combination.
    raise ValueError(f"Unexpected device configuration: {devices}")
210
+
211
+
212
def _device_run_spu(
    dev_info: Device, op: FeOperation, fn: Callable, *args: Any, **kwargs: Any
) -> Any:
    """Compile ``fn`` for an SPU device and evaluate it on the SPU parties.

    ``op`` is only checked to be a ``JaxRunner``; compilation itself goes
    through ``spu.jax_compile``.

    Returns:
        The output pytree of ``fn`` with every leaf tagged with the SPU
        device id.

    Raises:
        ValueError: If ``op`` is not a JAX frontend.
    """
    if isinstance(op, JaxRunner):
        member_ranks = [member.rank for member in dev_info.members]
        spu_mask = Mask.from_ranks(member_ranks)
        pfunc, inputs, out_tree = spu.jax_compile(fn, *args, **kwargs)
        # Every compiled input must already be held by exactly the SPU parties.
        assert all(var.pmask == spu_mask for var in inputs), inputs
        flat_outputs = peval(pfunc, inputs, spu_mask)
        tag = partial(set_dev_attr, dev_id=dev_info.name)
        return tree_map(tag, tree_unflatten(out_tree, flat_outputs))
    raise ValueError("SPU device only supports JAX frontend.")
223
+
224
+
225
def _device_run_tee(
    dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
) -> Any:
    """Run ``op`` on the single party backing a TEE device.

    Returns the result pytree with every leaf tagged with the TEE device id.
    """
    # TODO(jint): should we filter out all IO operations?
    assert len(dev_info.members) == 1
    tee_rank = dev_info.members[0].rank
    result = run_at(tee_rank, op, *args, **kwargs)
    tag = partial(set_dev_attr, dev_id=dev_info.name)
    return tree_map(tag, result)
233
+
234
+
235
def _device_run_ppu(
    dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
) -> Any:
    """Run ``op`` on the single party backing a PPU device.

    Returns the result pytree with every leaf tagged with the PPU device id.
    """
    assert len(dev_info.members) == 1
    ppu_rank = dev_info.members[0].rank
    result = run_at(ppu_rank, op, *args, **kwargs)
    tag = partial(set_dev_attr, dev_id=dev_info.name)
    return tree_map(tag, result)
242
+
243
+
244
def _device_run(dev_id: str, op: FeOperation, *args: Any, **kwargs: Any) -> Any:
    """Run frontend operation ``op`` on device ``dev_id``.

    When ``g_auto_trans`` is enabled, every MPObject argument is first moved
    to ``dev_id`` with ``_d2d``. Execution is then dispatched on the device
    kind (SPU / TEE / PPU).

    Raises:
        ValueError: If ``dev_id`` is not in the cluster spec, if an MPObject
            argument carries no device attribute (with auto-transfer on), or
            if the device kind is unknown.
    """
    assert isinstance(op, FeOperation)
    cluster_spec = cur_ctx().cluster_spec
    if dev_id not in cluster_spec.devices:
        raise ValueError(f"Device {dev_id} not found in cluster spec.")
    dev_info = cluster_spec.devices[dev_id]

    if g_auto_trans:

        def trans(obj: Any) -> Any:
            if not isinstance(obj, MPObject):
                return obj
            if not is_device_obj(obj):
                # Explicit error rather than an assert: asserts are stripped
                # under -O and the failure would then surface as an obscure
                # missing-key error inside get_dev_attr.
                raise ValueError(
                    "MPObject is missing device attribute. "
                    "If you're mixing device-level and simp-level code, "
                    "use set_dev_attr(obj, 'device_id') to mark the device explicitly."
                )
            return _d2d(dev_id, obj)

        args, kwargs = tree_map(trans, (args, kwargs))

    kind = dev_info.kind.upper()
    if kind == "SPU":
        return _device_run_spu(dev_info, op, *args, **kwargs)
    if kind == "TEE":
        return _device_run_tee(dev_info, op, *args, **kwargs)
    if kind == "PPU":
        return _device_run_ppu(dev_info, op, *args, **kwargs)
    raise ValueError(f"Unknown device type: {dev_info.kind}")
270
+
271
+
272
def device(
    dev_or_fn: str | Callable | None = None, *, fe_type: str = "jax"
) -> Callable:
    """Decorator to mark a function to be executed on a specific device.

    Supports both explicit device specification and automatic device inference:

    1. Explicit device placement:
        @device("P0")
        def foo(x, y): return x + y

    2. Auto device inference:
        @device
        def foo(x, y): return x + y
        # Device is inferred from x, y at runtime

    3. Inline usage:
        result = device(lambda x, y: x + y)(x_on_p0, y_on_p0)
        # Automatically infers device from arguments

    Args:
        dev_or_fn: Either a device id string ("P0", "SPU", etc.) for explicit placement,
            a callable function for auto inference, or None (same as not providing arg).
        fe_type: The frontend type of the device, could be "jax" or "nnx".
            Not needed if the decorated function is already a FeOperation.

    Returns:
        A decorator (when dev_or_fn is a string or None) or decorated function (when callable).

    Raises:
        TypeError: When dev_or_fn is not a string, callable, or None.
        ValueError: When device cannot be inferred or inference is ambiguous,
            or (at call time) when fe_type is not supported.

    Device Inference Strategy:
        - Same device: All arguments on device D -> execute on D
        - PPU + SPU: Arguments from PPU and SPU -> execute on SPU (secure compute)
        - PPU + TEE: Arguments from PPU and TEE -> execute on TEE (trusted execution)
        - Multiple PPUs: Ambiguous -> error (explicit device required)
        - No device objects: Cannot infer -> error (explicit device required)

    Example:
        >>> # Explicit device
        >>> @device("P0")
        ... def add_explicit(x, y):
        ...     return x + y
        >>>
        >>> # Auto inference
        >>> @device
        ... def add_auto(x, y):
        ...     return x + y
        >>>
        >>> x_on_p0 = ...  # data on P0
        >>> y_on_p0 = ...  # data on P0
        >>> result = add_auto(x_on_p0, y_on_p0)  # Inferred to P0
        >>>
        >>> x_on_spu = ...  # data on SPU
        >>> y_on_p1 = ...  # data on P1
        >>> result = add_auto(x_on_spu, y_on_p1)  # Inferred to SPU
    """

    def _execute_on_device(dev_id: str, fn: Callable, *args: Any, **kwargs: Any) -> Any:
        """Helper to execute function on specified device with appropriate frontend."""
        if isinstance(fn, FeOperation):
            return _device_run(dev_id, fn, *args, **kwargs)
        if fe_type == "jax":
            return _device_run(dev_id, jax_cc.run_jax, fn, *args, **kwargs)
        if fe_type == "nnx":
            return _device_run(dev_id, nnx_cc.run_nnx, fn, *args, **kwargs)
        raise ValueError(f"Unsupported frontend type: {fe_type}")

    # Case 1: device("P0") - Explicit device specification
    if isinstance(dev_or_fn, str):
        dev_id = dev_or_fn

        def deco(fn: Callable) -> Callable:
            @wraps(fn)
            def wrapped(*args: Any, **kwargs: Any) -> Any:
                return _execute_on_device(dev_id, fn, *args, **kwargs)

            return wrapped

        return deco

    # Case 2: device(fn) or @device - Auto device inference
    elif callable(dev_or_fn):
        fn = dev_or_fn

        @wraps(fn)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            try:
                dev_id = _infer_device_from_args(*args, **kwargs)
            except ValueError as e:
                # Not every callable has __name__ (functools.partial objects,
                # callable instances); fall back to repr() so the original
                # inference error is not masked by an AttributeError.
                fn_name = getattr(fn, "__name__", repr(fn))
                # Enhance error message with function context
                raise ValueError(
                    f"Cannot infer device for function '{fn_name}': {e!s}"
                ) from e

            return _execute_on_device(dev_id, fn, *args, **kwargs)

        return wrapped

    # Case 3: device() or @device() - Return auto-inference decorator
    elif dev_or_fn is None:

        def deco(fn: Callable) -> Callable:
            return device(fn, fe_type=fe_type)

        return deco

    else:
        # More helpful error message for common mistakes
        raise TypeError(
            f"device() expects a device id (string), a function (callable), or nothing. "
            f"Got: {type(dev_or_fn).__name__}.\n"
            f"Usage:\n"
            f"  - Explicit device: @device('P0') or device('P0')(fn)\n"
            f"  - Auto inference: @device or device(fn)"
        )
392
+
393
+
394
def _spu_reveal(spu_dev: Device, obj: MPObject, to_mask: Mask) -> MPObject:
    """Reveal an SPU-held ``obj`` to the parties in ``to_mask``.

    ``mpi.bcast_m`` is called once per SPU rank (targeting ``to_mask``) to
    collect the shares, which are then combined with ``spu.reconstruct`` and
    evaluated on ``to_mask``.
    """
    spu_mask = Mask.from_ranks([m.rank for m in spu_dev.members])
    assert obj.pmask == spu_mask, (obj.pmask, spu_mask)

    # (n_parties, n_shares): one broadcast result per SPU party.
    shares = [mpi.bcast_m(to_mask, src_rank, obj) for src_rank in Mask(spu_mask)]
    assert len(shares) == Mask(spu_mask).num_parties(), (shares, spu_mask)
    assert all(s.pmask == to_mask for s in shares)

    # Reconstruct the original object from shares
    pfunc, inputs, _ = spu.reconstruct(*shares)
    revealed = peval(pfunc, inputs, to_mask)
    return revealed[0]  # type: ignore[no-any-return]
406
+
407
+
408
def _spu_seal(spu_dev: Device, obj: MPObject) -> list[MPObject]:
    """Seal plaintext into SPU shares on a specific SPU device.

    Low-level API: device id is mandatory to avoid ambiguity.

    Returns:
        One ``mpi.scatter_m`` result per rank in ``obj.pmask``.

    Raises:
        ValueError: If ``obj`` has a dynamic (``None``) pmask.
    """
    if obj.pmask is None:
        raise ValueError("Seal can not apply to dynamic mask objects.")

    spu_mask = Mask.from_ranks([m.rank for m in spu_dev.members])
    world_size = Mask(spu_mask).num_parties()
    pfunc, inputs, _ = spu.makeshares(
        obj, world_size=world_size, visibility=spu.Visibility.SECRET
    )
    assert len(inputs) == 1
    shares = peval(pfunc, inputs)

    # scatter the shares to each party.
    return [mpi.scatter_m(spu_mask, src_rank, shares) for src_rank in obj.pmask]
427
+
428
+
429
def _tee_encrypted_p2p(
    obj: MPObject, src_rank: int, src_key: MPObject, dst_rank: int, dst_key: MPObject
) -> Any:
    """Send ``obj`` from ``src_rank`` to ``dst_rank`` as an encrypted byte blob.

    Bytes-only path: pack -> enc (src_key) -> p2p -> dec (dst_key) -> unpack,
    with the static output type taken from ``obj``.
    """
    obj_ty = obj.mptype.raw_type()
    packed = run_at(src_rank, basic.pack, obj)
    ct = run_at(src_rank, crypto.enc, packed, src_key)
    ct_at_dst = mpi.p2p(src_rank, dst_rank, ct)
    packed_at_dst = run_at(dst_rank, crypto.dec, ct_at_dst, dst_key)
    return run_at(dst_rank, basic.unpack, packed_at_dst, out_ty=obj_ty)


def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
    """Transfer device object ``obj`` to device ``to_dev_id``.

    Returns ``obj`` unchanged when it already lives on ``to_dev_id``;
    otherwise the result is tagged with ``to_dev_id``.

    Raises:
        NotImplementedError: For transfers between two different SPUs.
        ValueError: If either device id is missing from the cluster spec, or
            the (source kind, target kind) pair is not supported.
    """
    assert isinstance(obj, MPObject)
    frm_dev_id = get_dev_attr(obj)

    if frm_dev_id == to_dev_id:
        return obj

    cluster_spec: ClusterSpec = cur_ctx().cluster_spec
    for dev_id in (frm_dev_id, to_dev_id):
        if dev_id not in cluster_spec.devices:
            # Same message as _device_run, instead of a bare KeyError.
            raise ValueError(f"Device {dev_id} not found in cluster spec.")
    frm_dev = cluster_spec.devices[frm_dev_id]
    to_dev = cluster_spec.devices[to_dev_id]
    frm_kind, to_kind = frm_dev.kind.upper(), to_dev.kind.upper()
    frm_to_pair = (frm_kind, to_kind)
    tag = partial(set_dev_attr, dev_id=to_dev_id)

    if frm_to_pair == ("SPU", "SPU"):
        raise NotImplementedError("Only one SPU is supported for now.")

    if frm_kind == "SPU" and to_kind in ("PPU", "TEE"):
        # Reveal the SPU shares to the single target party.
        assert len(to_dev.members) == 1
        to_rank = to_dev.members[0].rank
        var = _spu_reveal(frm_dev, obj, Mask.from_ranks([to_rank]))
        return tree_map(tag, var)  # type: ignore[no-any-return]

    if frm_kind in ("PPU", "TEE") and to_kind == "SPU":
        # Secret-share the plaintext from the single source party into the SPU.
        assert len(frm_dev.members) == 1
        shares = _spu_seal(to_dev, obj)
        assert len(shares) == 1, f"Expected single share from {frm_kind} to SPU seal."
        return tree_map(tag, shares[0])  # type: ignore[no-any-return]

    if frm_to_pair == ("PPU", "PPU"):
        assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
        frm_rank = frm_dev.members[0].rank
        to_rank = to_dev.members[0].rank
        var = mpi.p2p(frm_rank, to_rank, obj)
        return tree_map(tag, var)  # type: ignore[no-any-return]

    if frm_to_pair == ("PPU", "TEE"):
        # Transparent handshake + encryption for the first transfer; reuse thereafter
        assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
        ppu_rank = frm_dev.members[0].rank
        tee_rank = to_dev.members[0].rank
        # Ensure sessions (both directions) exist for this PPU<->TEE pair
        sess_p, sess_t = _ensure_tee_session(frm_dev_id, to_dev_id, ppu_rank, tee_rank)
        var = _tee_encrypted_p2p(obj, ppu_rank, sess_p, tee_rank, sess_t)
        return tree_map(tag, var)  # type: ignore[no-any-return]

    if frm_to_pair == ("TEE", "PPU"):
        # Reuse the session of the PPU->TEE direction (same cache key), with
        # the TEE-side key encrypting and the PPU-side key decrypting.
        assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
        tee_rank = frm_dev.members[0].rank
        ppu_rank = to_dev.members[0].rank
        sess_p, sess_t = _ensure_tee_session(to_dev_id, frm_dev_id, ppu_rank, tee_rank)
        var = _tee_encrypted_p2p(obj, tee_rank, sess_t, ppu_rank, sess_p)
        return tree_map(tag, var)  # type: ignore[no-any-return]

    supported = [
        ("SPU", "PPU"),
        ("PPU", "SPU"),
        ("PPU", "PPU"),
        ("PPU", "TEE"),
        ("TEE", "PPU"),
        ("TEE", "SPU"),
        ("SPU", "TEE"),
    ]
    raise ValueError(
        f"Unsupported device transfer: {frm_to_pair}. Supported pairs: {supported}."
    )
513
+
514
+
515
def _ensure_tee_session(
    frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int
) -> tuple[MPObject, MPObject]:
    """Return the session keys for a sender<->TEE device pair, creating them if needed.

    Sessions are cached on the root context under ``(frm_dev_id, to_dev_id)``
    together with the id of the context that created them. A cached entry is
    reused only by that same context; an entry from any other context is
    dropped and rebuilt. This keeps TraceVars from one TraceContext out of
    another.

    On a cache miss the handshake is:
      1. The TEE generates a KEM keypair and a quote over its public key.
      2. The quote goes to the sender, which attests it to obtain the TEE pk.
      3. The sender generates its own keypair and sends its pk to the TEE.
      4. Both sides derive the shared secret, then the session key via HKDF.

    Returns:
        ``(sess_p, sess_t)``: the session key held by the sender and by the TEE.
    """
    ctx = cur_ctx()
    ctx_id = _get_context_id(ctx)

    # The cache itself is stored on the root context.
    # NOTE(review): ids from id() may be reused after a context is garbage
    # collected; confirm cached sessions cannot outlive the context they
    # were created in.
    root = ctx.root()
    if not hasattr(root, "_tee_sessions"):
        root._tee_sessions = {}  # type: ignore[attr-defined]
    sessions: dict[tuple[str, str], tuple[int, MPObject, MPObject]] = (
        root._tee_sessions  # type: ignore[attr-defined]
    )

    pair = (frm_dev_id, to_dev_id)
    cached = sessions.get(pair)
    if cached is not None:
        owner_id, sess_p, sess_t = cached
        if owner_id == ctx_id:
            return sess_p, sess_t
        # Created by a different context: discard and establish a new session.
        del sessions[pair]

    # 1) TEE keypair + quote over its public key.
    # KEM suite currently constant; future: read from tee device config (e.g. cluster_spec.devices[dev_id].config)
    tee_sk, tee_pk = run_at(tee_rank, crypto.kem_keygen, _TEE_KEM_SUITE)
    quote = run_at(tee_rank, tee.quote_gen, tee_pk)

    # 2) Sender receives the quote and attests it to learn the TEE pk.
    quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
    tee_pk_at_sender = run_at(frm_rank, tee.attest, quote_at_sender)

    # 3) Sender's ephemeral keypair; its pk goes to the TEE.
    v_sk, v_pk = run_at(frm_rank, crypto.kem_keygen, _TEE_KEM_SUITE)
    v_pk_at_tee = mpi.p2p(frm_rank, tee_rank, v_pk)

    # 4) Shared secret on both sides, then the session key with a fixed HKDF info.
    shared_p = run_at(
        frm_rank, crypto.kem_derive, v_sk, tee_pk_at_sender, _TEE_KEM_SUITE
    )
    shared_t = run_at(tee_rank, crypto.kem_derive, tee_sk, v_pk_at_tee, _TEE_KEM_SUITE)
    sess_p = run_at(frm_rank, crypto.hkdf, shared_p, _HKDF_INFO_LITERAL)
    sess_t = run_at(tee_rank, crypto.hkdf, shared_t, _HKDF_INFO_LITERAL)

    sessions[pair] = (ctx_id, sess_p, sess_t)
    return sess_p, sess_t
575
+
576
+
577
def _host_to_device(to_dev_id: str, obj: Any) -> MPObject:
    """Place a host value ``obj`` on device ``to_dev_id``.

    Tensors are placed by running an identity function on the target device
    through ``device``. Tables are embedded with ``basic.constant`` on the
    single party of a PPU or TEE device.

    Raises:
        ValueError: If a table is put on a device that is neither PPU nor TEE.
        TypeError: If ``obj`` is neither TensorLike nor TableLike.
    """
    if isinstance(obj, TensorLike):
        # run jax identity on the target device to put the tensor there
        return device(to_dev_id)(lambda x: x)(obj)  # type: ignore[no-any-return]

    if not isinstance(obj, TableLike):
        raise TypeError(
            f"put() only supports TensorLike or TableLike objects, got {type(obj)}"
        )

    dev_info = cur_ctx().cluster_spec.devices[to_dev_id]
    if dev_info.kind.upper() not in ("PPU", "TEE"):
        raise ValueError(
            f"TableLike put() only supports PPU or TEE devices, got {dev_info.kind}"
        )
    assert len(dev_info.members) == 1
    owner_rank = dev_info.members[0].rank
    table_obj = cast(MPObject, run_at(owner_rank, basic.constant, obj))
    return set_dev_attr(table_obj, to_dev_id)
596
+
597
+
598
def put(to_dev_id: str, obj: Any) -> MPObject:
    """Place ``obj`` on device ``to_dev_id``.

    MPObjects are transferred device-to-device with ``_d2d``; any other value
    is treated as host data and handed to ``_host_to_device``.
    """
    if isinstance(obj, MPObject):
        return _d2d(to_dev_id, obj)
    return _host_to_device(to_dev_id, obj)
@@ -18,7 +18,7 @@ This subpackage hosts non-core developer aids: diagram rendering, IR dumps,
18
18
  profiling helpers (future), etc.
19
19
  """
20
20
 
21
- from mplang.analysis.diagram import (
21
+ from mplang.v1.analysis.diagram import (
22
22
  DumpResult,
23
23
  FlowchartOptions,
24
24
  SequenceDiagramOptions,
@@ -14,7 +14,7 @@
14
14
 
15
15
  """Diagram rendering (Mermaid) and markdown dump helpers.
16
16
 
17
- Moved from mplang.utils.mermaid to dedicated analysis namespace.
17
+ Moved from mplang.v1.utils.mermaid to dedicated analysis namespace.
18
18
  """
19
19
 
20
20
  from __future__ import annotations
@@ -23,9 +23,9 @@ from dataclasses import dataclass
23
23
  from pathlib import Path
24
24
  from typing import TypedDict
25
25
 
26
- from mplang.core import ClusterSpec, IrWriter, Mask, TracedFunction
27
- from mplang.core.mpir import get_graph_statistics
28
- from mplang.protos.v1alpha1 import mpir_pb2
26
+ from mplang.v1.core import ClusterSpec, IrWriter, Mask, TracedFunction
27
+ from mplang.v1.core.mpir import get_graph_statistics
28
+ from mplang.v1.protos.v1alpha1 import mpir_pb2
29
29
 
30
30
  # ----------------------------- Core helpers (copied) -----------------------------
31
31
 
@@ -21,15 +21,15 @@ including type systems, tracing mechanisms, and interpreter contexts.
21
21
 
22
22
  # Core type system
23
23
  # Communication interfaces & core symbols
24
- from mplang.core.cluster import ClusterSpec, Device, Node, RuntimeInfo
25
- from mplang.core.comm import (
24
+ from mplang.v1.core.cluster import ClusterSpec, Device, Node, RuntimeInfo
25
+ from mplang.v1.core.comm import (
26
26
  CollectiveMixin,
27
27
  CommunicatorBase,
28
28
  ICollective,
29
29
  ICommunicator,
30
30
  )
31
- from mplang.core.context_mgr import cur_ctx, set_ctx, with_ctx
32
- from mplang.core.dtypes import (
31
+ from mplang.v1.core.context_mgr import cur_ctx, set_ctx, with_ctx
32
+ from mplang.v1.core.dtypes import (
33
33
  BINARY,
34
34
  BOOL,
35
35
  COMPLEX64,
@@ -55,16 +55,16 @@ from mplang.core.dtypes import (
55
55
  UUID,
56
56
  DType,
57
57
  )
58
- from mplang.core.interp import InterpContext, InterpVar
59
- from mplang.core.mask import Mask
60
- from mplang.core.mpir import IrReader, IrWriter
61
- from mplang.core.mpobject import MPContext, MPObject
62
- from mplang.core.mptype import MPType, Rank, Shape
63
- from mplang.core.pfunc import PFunction, get_fn_name
58
+ from mplang.v1.core.interp import InterpContext, InterpVar
59
+ from mplang.v1.core.mask import Mask
60
+ from mplang.v1.core.mpir import IrReader, IrWriter
61
+ from mplang.v1.core.mpobject import MPContext, MPObject
62
+ from mplang.v1.core.mptype import MPType, Rank, Shape
63
+ from mplang.v1.core.pfunc import PFunction, get_fn_name
64
64
 
65
65
  # Import primitive-dependent symbols at the end to avoid cycles when frontend ops
66
66
  # import from the core facade during package initialization.
67
- from mplang.core.primitive import (
67
+ from mplang.v1.core.primitive import (
68
68
  builtin_function,
69
69
  function,
70
70
  pconv,
@@ -76,9 +76,15 @@ from mplang.core.primitive import (
76
76
  uniform_cond,
77
77
  while_loop,
78
78
  )
79
- from mplang.core.table import TableLike, TableType
80
- from mplang.core.tensor import ScalarType, TensorLike, TensorType
81
- from mplang.core.tracer import TraceContext, TracedFunction, TraceVar, VarNamer, trace
79
+ from mplang.v1.core.table import TableLike, TableType
80
+ from mplang.v1.core.tensor import ScalarType, TensorLike, TensorType
81
+ from mplang.v1.core.tracer import (
82
+ TraceContext,
83
+ TracedFunction,
84
+ TraceVar,
85
+ VarNamer,
86
+ trace,
87
+ )
82
88
 
83
89
  __all__ = [
84
90
  "BINARY",