mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,813 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Device-oriented programming interface for MPLang2.
17
+
18
+ This module provides high-level abstractions for device placement and data movement.
19
+ It allows users to write programs in a device-centric way, handling data transfers
20
+ and execution dispatch automatically.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from collections.abc import Callable
26
+ from functools import partial, wraps
27
+ from typing import Any, cast
28
+
29
+ from jax.tree_util import tree_flatten, tree_map
30
+
31
+ from mplang.v2.backends import load_builtins
32
+ from mplang.v2.dialects import crypto, simp, spu, tee
33
+ from mplang.v2.edsl.object import Object
34
+ from mplang.v2.libs.device.cluster import Device
35
+
36
+ load_builtins()
37
+
38
+
39
def _resolve_cluster() -> Any:
    """Return the cluster spec attached to the active execution context.

    Asks ``find_context`` for a context whose ``_cluster_spec`` attribute is
    set (i.e. present and not ``None``) and returns that attribute.  Per the
    upstream description the search starts from the most recently pushed
    context, so an inner context can override the cluster of an outer one.

    Raises:
        RuntimeError: If no context carrying a cluster spec is found.
    """
    from mplang.v2.edsl.context import find_context

    def _carries_cluster_spec(candidate: Any) -> bool:
        # A context qualifies only when the attribute exists and is non-None.
        return getattr(candidate, "_cluster_spec", None) is not None

    owner = find_context(_carries_cluster_spec)
    if owner is None:
        raise RuntimeError(
            "No active device context found. Please use 'with simulator:' "
            "or 'push_context(sim)' to set the execution environment."
        )
    return owner._cluster_spec  # type: ignore[attr-defined]
56
+
57
+
58
# Name of the attribute that tags an Object with the ID of the device
# it currently resides on.
DEVICE_ATTR_NAME = "__device__"

# KEM suite passed to crypto.kem_keygen when a TEE session is established.
_TEE_KEM_SUITE: str = "x25519"

# Info string passed to crypto.hkdf when deriving TEE session keys; it
# separates these keys from keys derived for any other purpose.
_TEE_HKDF_INFO: str = "mplang/device/tee/v2"

# Process-wide TEE session cache, keyed by (frm_dev_id, to_dev_id).
# Every value is (context_id, sess_frm, sess_tee); the context_id lets the
# lookup reject sessions that were created under a different trace/interp
# context.
_tee_session_cache: dict[tuple[str, str], tuple[int, Object, Object]] = {}

# When True, arguments living on another device are moved to the target
# device automatically before a device function runs.
g_auto_trans: bool = True
74
+
75
+
76
class DeviceError(Exception):
    """Root of the exception hierarchy for device-related failures."""


class DeviceNotFoundError(DeviceError):
    """Signals that a requested device ID is absent from the cluster."""


class DeviceInferenceError(DeviceError):
    """Signals that no target device could be inferred from the arguments."""
86
+
87
+
88
def is_device_obj(obj: Any) -> bool:
    """Tell whether ``obj`` is an ``Object`` tagged with a device.

    Returns ``True`` only when ``obj`` is an ``Object`` instance that has the
    ``DEVICE_ATTR_NAME`` attribute; anything else yields ``False``.
    """
    return isinstance(obj, Object) and hasattr(obj, DEVICE_ATTR_NAME)
93
+
94
+
95
def set_dev_attr(obj: Object, dev_id: str) -> Object:
    """Tag ``obj`` as living on device ``dev_id`` and return it.

    The object is modified in place (the ``DEVICE_ATTR_NAME`` attribute is
    set) and the same instance is returned.

    Raises:
        TypeError: If ``obj`` is not an ``Object``.
    """
    if isinstance(obj, Object):
        setattr(obj, DEVICE_ATTR_NAME, dev_id)
        return obj
    raise TypeError(f"Input must be an instance of Object, got {type(obj)}")
101
+
102
+
103
def get_dev_attr(obj: Object) -> str:
    """Read the device ID that ``obj`` has been tagged with.

    Returns:
        The value of the ``DEVICE_ATTR_NAME`` attribute, converted to ``str``.

    Raises:
        TypeError: If ``obj`` is not an ``Object``.
        ValueError: If ``obj`` carries no device attribute.
    """
    if not isinstance(obj, Object):
        raise TypeError("Input must be an instance of Object")
    try:
        tagged_device = getattr(obj, DEVICE_ATTR_NAME)
    except AttributeError:
        # Missing marker: report it as a value problem, without chaining.
        raise ValueError("Object does not have a device attribute") from None
    return str(tagged_device)
110
+
111
+
112
def _maybe_set_dev_attr(dev_id: str, obj: Any) -> Any:
    """Tag ``obj`` with ``dev_id`` when it is an ``Object``; pass others through.

    Note the argument order (device first) so the function can be partially
    applied and mapped over the leaves of a result tree.
    """
    return set_dev_attr(obj, dev_id) if isinstance(obj, Object) else obj
117
+
118
+
119
def _infer_device_from_args(*args: Any, **kwargs: Any) -> str:
    """Pick the device a call should run on by inspecting its arguments.

    All leaves of ``(args, kwargs)`` are scanned; only device-tagged Objects
    take part in the decision.  Resolution rules:

    * every tagged argument on one device -> that device;
    * several devices and ``g_auto_trans`` disabled -> error;
    * several devices, exactly one SPU and no TEE -> the SPU;
    * several devices, exactly one TEE and no SPU -> the TEE;
    * anything else (no SPU/TEE at all, several SPUs, several TEEs, or an
      SPU together with a TEE) -> error.

    Raises:
        DeviceInferenceError: If no tagged argument exists or the choice is
            ambiguous.
    """
    leaves, _ = tree_flatten((args, kwargs))
    # is_device_obj already rejects anything that is not an Object, so plain
    # host values and untagged Objects are filtered out in one step.
    tagged = [leaf for leaf in leaves if is_device_obj(leaf)]

    if not tagged:
        raise DeviceInferenceError(
            "Cannot infer device: no device-bound Object arguments found. "
            "Please specify device explicitly using device('device_id')."
        )

    devices = {get_dev_attr(leaf) for leaf in tagged}

    if len(devices) == 1:
        return devices.pop()  # All arguments on same device

    if not g_auto_trans:
        raise DeviceInferenceError(
            f"Cannot infer device: arguments from multiple devices {devices} "
            f"but auto-transfer is disabled (g_auto_trans=False). "
            f"Please enable auto-transfer or put all data on same device first."
        )

    cluster = _resolve_cluster()

    # Bucket the involved device IDs by (upper-cased) kind; kinds other than
    # SPU/TEE/PPU are looked up but not collected.
    by_kind: dict[str, list[str]] = {"SPU": [], "TEE": [], "PPU": []}
    for dev_id in devices:
        kind = cluster.devices[dev_id].kind.upper()
        if kind in by_kind:
            by_kind[kind].append(dev_id)

    spu_devs = by_kind["SPU"]
    tee_devs = by_kind["TEE"]
    ppu_devs = by_kind["PPU"]
    spu_count = len(spu_devs)
    tee_count = len(tee_devs)

    # No secure device involved: choosing one plain party would be arbitrary.
    if spu_count == 0 and tee_count == 0:
        raise DeviceInferenceError(
            f"Cannot infer device: arguments from multiple PPU devices {ppu_devs}. "
            f"Please specify device explicitly or use put() to consolidate data."
        )

    # Exactly one SPU (optionally mixed with PPUs) wins.
    if tee_count == 0 and spu_count == 1:
        return spu_devs[0]

    # Exactly one TEE (optionally mixed with PPUs) wins.
    if spu_count == 0 and tee_count == 1:
        return tee_devs[0]

    if spu_count > 1:
        raise DeviceInferenceError(
            f"Ambiguous device inference: arguments from multiple SPU devices {spu_devs}. "
            f"Please specify which SPU to use explicitly."
        )

    if tee_count > 1:
        raise DeviceInferenceError(
            f"Ambiguous device inference: arguments from multiple TEE devices {tee_devs}. "
            f"Please specify which TEE to use explicitly."
        )

    # What remains is one SPU together with one TEE.
    if spu_devs and tee_devs:
        raise DeviceInferenceError(
            f"Ambiguous device inference: arguments from both SPU {spu_devs} and TEE {tee_devs}. "
            f"Please specify which secure device to use explicitly."
        )

    # Defensive fallback; the branches above are expected to be exhaustive.
    raise DeviceInferenceError(f"Unexpected device configuration: {devices}")
196
+
197
+
198
def _device_run_spu(dev_info: Device, fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Execute ``fn`` on the SPU device described by ``dev_info``.

    ``spu.run_jax`` is invoked through ``simp.pcall_static`` on the ranks of
    all SPU members, with the SPU config built from ``dev_info.config``.
    Arguments are expected to already reside on the SPU (the caller moves
    them there).  Every leaf of the result tree is tagged with the device
    name before it is returned.
    """
    member_ranks = tuple(member.rank for member in dev_info.members)
    config = spu.SPUConfig.from_dict(dev_info.config)

    outputs = simp.pcall_static(
        member_ranks, spu.run_jax, config, fn, *args, **kwargs
    )

    owner = dev_info.name

    def _tag(leaf: Any) -> Any:
        # set_dev_attr rejects non-Object leaves with a TypeError.
        return set_dev_attr(leaf, dev_id=owner)

    return tree_map(_tag, outputs)
216
+
217
+
218
def _device_run_ppu(
    dev_info: Device,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Execute ``fn`` on the single party that backs a PPU device.

    The function runs through ``simp.pcall_static`` on the rank of the
    device's only member; Object leaves of the result are tagged with the
    device name, other leaves are returned unchanged.
    """
    assert len(dev_info.members) == 1
    party_rank = dev_info.members[0].rank

    outputs = simp.pcall_static((party_rank,), fn, *args, **kwargs)

    owner = dev_info.name

    def _tag(leaf: Any) -> Any:
        return _maybe_set_dev_attr(owner, leaf)

    return tree_map(_tag, outputs)
230
+
231
+
232
def _device_run_tee(
    dev_info: Device,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Execute ``fn`` on the single party that backs a TEE device.

    Dispatch is the same as for a PPU: ``simp.pcall_static`` on the rank of
    the device's only member, followed by tagging Object leaves of the result
    with the device name.  The enclave isolation itself is provided by the
    TEE party, not by this function.
    """
    assert len(dev_info.members) == 1
    enclave_rank = dev_info.members[0].rank

    outputs = simp.pcall_static((enclave_rank,), fn, *args, **kwargs)

    owner = dev_info.name
    return tree_map(lambda leaf: _maybe_set_dev_attr(owner, leaf), outputs)
248
+
249
+
250
def _device_run(
    dev_id: str,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Run ``fn`` on device ``dev_id``, moving arguments there if allowed.

    When ``g_auto_trans`` is enabled, every device-tagged Object among the
    arguments is transferred to ``dev_id`` first; other values are passed
    through untouched.  The call is then handed to the runner matching the
    device kind (SPU, TEE or PPU, compared case-insensitively).

    Raises:
        DeviceNotFoundError: If ``dev_id`` is not part of the active cluster.
        DeviceError: If the device kind is not one of SPU/TEE/PPU.
    """
    cluster = _resolve_cluster()
    if dev_id not in cluster.devices:
        available = list(cluster.devices.keys())
        raise DeviceNotFoundError(
            f"Device '{dev_id}' not found in cluster. Available devices: {available}"
        )
    dev_info = cluster.devices[dev_id]

    if g_auto_trans:
        # is_device_obj implies "is an Object", so one check is sufficient.
        args, kwargs = tree_map(
            lambda leaf: _d2d(dev_id, leaf) if is_device_obj(leaf) else leaf,
            (args, kwargs),
        )

    runners = {
        "SPU": _device_run_spu,
        "TEE": _device_run_tee,
        "PPU": _device_run_ppu,
    }
    runner = runners.get(dev_info.kind.upper())
    if runner is None:
        raise DeviceError(f"Unknown device type: {dev_info.kind}")
    return runner(dev_info, fn, *args, **kwargs)
283
+
284
+
285
class DeviceContext:
    """Context for device-specific operations.

    Supports explicit device specification or auto-inference from arguments.

    Examples:
        # Explicit device
        @device("P0")
        def add(a, b): ...

        # Auto-infer device from arguments
        @device()
        def add(a, b): ...

        # JAX frontend via .jax property (recommended for PPU)
        @device("P0").jax
        def add(a, b): return a + b

        # Or use separate decorators (equivalent)
        @device("P0")
        @jax_fn
        def add(a, b): return a + b

        # Inline call style
        result = device("P0").jax(fn)(x, y)
    """

    def __init__(self, dev_id: str | None = None):
        """Create a DeviceContext.

        Args:
            dev_id: Device ID (e.g., "P0", "SP0") or None for auto-inference.
        """
        self.dev_id = dev_id

    def _resolve_device(self, *args: Any, **kwargs: Any) -> str:
        """Return the explicit device ID, or infer one from the call arguments.

        Raises:
            DeviceInferenceError: If no explicit device was given and none can
                be inferred from ``args``/``kwargs``.
        """
        if self.dev_id is not None:
            return self.dev_id
        return _infer_device_from_args(*args, **kwargs)

    @staticmethod
    def _kind_is_spu(dev_id: str) -> bool:
        """Check whether ``dev_id`` names an SPU device in the active cluster.

        Unknown device IDs yield ``False``; reporting them is left to
        ``_device_run`` so the error is raised in one place.
        """
        cluster = _resolve_cluster()
        if dev_id not in cluster.devices:
            return False
        return bool(cluster.devices[dev_id].kind.upper() == "SPU")

    def _is_spu_device(self) -> bool:
        """Check if this context explicitly targets an SPU device.

        Returns ``False`` when no explicit device ID was configured.
        """
        if self.dev_id is None:
            return False
        return self._kind_is_spu(self.dev_id)

    @property
    def jax(self) -> Callable[[Callable], Callable]:
        """Return a decorator that wraps JAX functions for this device.

        For PPU/TEE: applies tensor.jax_fn to compile JAX code via StableHLO.
        For SPU: no-op wrapper, as SPU natively uses JAX via spu.run_jax.

        The SPU/non-SPU decision is taken when the decorated function is
        *called*, not when it is decorated.  This matters in two situations:

        * auto-inference (``device().jax``): the target device is only known
          once the arguments are available, so an inferred SPU must not get
          the ``jax_fn`` wrapping;
        * decoration outside an active context (e.g. at import time): looking
          up the cluster while decorating would raise ``RuntimeError``.

        Examples:
            # As decorator
            @device("P0").jax
            def add(a, b): return a + b

            # As inline call
            result = device("P0").jax(fn)(x, y)
        """

        def wrapper(fn: Callable) -> Callable:
            # jax_fn(fn) is built at most once, on the first non-SPU call.
            jax_wrapped: Callable | None = None

            @wraps(fn)
            def wrapped(*args: Any, **kwargs: Any) -> Any:
                nonlocal jax_wrapped
                dev_id = self._resolve_device(*args, **kwargs)

                # SPU natively uses JAX via spu.run_jax, no extra wrapping needed
                if self._kind_is_spu(dev_id):
                    return _device_run(dev_id, fn, *args, **kwargs)

                # PPU/TEE need tensor.jax_fn to compile JAX code
                if jax_wrapped is None:
                    from mplang.v2.dialects.tensor import jax_fn

                    jax_wrapped = jax_fn(fn)
                return _device_run(dev_id, jax_wrapped, *args, **kwargs)

            return wrapped

        return wrapper

    def __call__(self, fn: Callable) -> Callable:
        """Wrap function for execution on this device.

        The returned callable resolves the target device per call (explicit
        ID or inference from the arguments) and delegates to ``_device_run``.
        """

        @wraps(fn)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            dev_id = self._resolve_device(*args, **kwargs)
            return _device_run(dev_id, fn, *args, **kwargs)

        return wrapped
374
+
375
+
376
def device(dev_id: str | None = None) -> DeviceContext:
    """Build a ``DeviceContext`` used to place function execution on a device.

    Args:
        dev_id: Device ID (e.g., "P0", "SP0"); ``None`` asks for the device to
            be inferred from the arguments of each call.

    Returns:
        A ``DeviceContext`` that can decorate or wrap functions.

    Usage patterns:
        # Explicit device + generic function
        @device("P0")
        def fn(a, b): ...

        # Device inferred from the arguments
        @device()
        def fn(a, b): ...

        # JAX frontend via the .jax property (recommended for PPU)
        @device("P0").jax
        def add(a, b): return a + b

        # Inline call
        result = device("P0").jax(fn)(x, y)

        # Stacked decorators (equivalent to the .jax form)
        @device("P0")
        @jax_fn
        def add(a, b): return a + b
    """
    context = DeviceContext(dev_id)
    return context
407
+
408
+
409
def _ensure_tee_session(
    frm_dev_id: str,
    to_dev_id: str,
    frm_rank: int,
    tee_rank: int,
) -> tuple[Object, Object]:
    """Return the session keys shared by a sender party and a TEE party.

    A session established earlier for the same ``(frm_dev_id, to_dev_id)``
    pair is reused as long as it was created under the current context;
    otherwise the stale cache entry is dropped and a new handshake is traced.

    Handshake, in the order the operations are issued:
        1. The TEE creates a KEM key pair (suite ``_TEE_KEM_SUITE``) and a
           quote over its public key (``tee.quote_gen``).
        2. The quote is moved to the sender.
        3. The sender runs ``tee.attest`` on the quote, which yields the
           TEE public key on the sender side.
        4. The sender creates its own key pair and moves the public key to
           the TEE.
        5. Each side calls ``crypto.kem_derive`` with its private key and the
           peer's public key.
        6. Each side feeds its shared secret and ``_TEE_HKDF_INFO`` into
           ``crypto.hkdf`` (described upstream as HKDF-SHA256, following the
           NIST SP 800-56C rule that a shared secret is not used directly as
           a key) to obtain the session key.

    Args:
        frm_dev_id: Source device ID (PPU).
        to_dev_id: Target device ID (TEE).
        frm_rank: Rank of the source party.
        tee_rank: Rank of the TEE party.

    Returns:
        ``(sess_frm, sess_tee)``: the session key Object held by the sender
        and the one held by the TEE.
    """
    import mplang.v2.edsl as el

    # The identity of the current context isolates cache entries per context.
    # NOTE(review): id() values can be reused once an object is garbage
    # collected, so a new context might match a stale entry -- confirm that
    # contexts outlive their cached sessions.
    context_token = id(el.get_current_context())

    cache_key = (frm_dev_id, to_dev_id)
    entry = _tee_session_cache.get(cache_key)
    if entry is not None:
        owner_token, cached_frm, cached_tee = entry
        if owner_token == context_token:
            return cached_frm, cached_tee
        # Created under a different context: not reusable.
        del _tee_session_cache[cache_key]

    # Shorthands for "run this on the TEE party" / "run this on the sender".
    on_tee = partial(simp.pcall_static, (tee_rank,))
    on_sender = partial(simp.pcall_static, (frm_rank,))

    # 1. Key pair and attestation quote on the TEE.
    tee_private, tee_public = on_tee(crypto.kem_keygen, _TEE_KEM_SUITE)
    tee_quote = on_tee(tee.quote_gen, tee_public)

    # 2. Ship the quote to the sender.
    quote_on_sender = simp.shuffle_static(tee_quote, {frm_rank: tee_rank})

    # 3. Attestation on the sender yields the TEE public key.
    tee_public_on_sender = on_sender(tee.attest, quote_on_sender)

    # 4. Sender key pair; the public half travels to the TEE.
    sender_private, sender_public = on_sender(crypto.kem_keygen, _TEE_KEM_SUITE)
    sender_public_on_tee = simp.shuffle_static(sender_public, {tee_rank: frm_rank})

    # 5. Shared secret on both sides; kem_derive takes (private_key, public_key)
    #    and reads the suite from the key type.
    secret_on_sender = on_sender(
        crypto.kem_derive, sender_private, tee_public_on_sender
    )
    secret_on_tee = on_tee(crypto.kem_derive, tee_private, sender_public_on_tee)

    # 6. Session keys via HKDF with the protocol-specific info string.
    sess_frm = on_sender(crypto.hkdf, secret_on_sender, _TEE_HKDF_INFO)
    sess_tee = on_tee(crypto.hkdf, secret_on_tee, _TEE_HKDF_INFO)

    _tee_session_cache[cache_key] = (context_token, sess_frm, sess_tee)
    return sess_frm, sess_tee
493
+
494
+
495
+ def _d2d(to_dev_id: str, obj: Object) -> Object:
496
+ """Transfer object to target device."""
497
+ if not isinstance(obj, Object):
498
+ raise TypeError(f"Expected Object, got {type(obj)}")
499
+
500
+ frm_dev_id = get_dev_attr(obj)
501
+ if frm_dev_id == to_dev_id:
502
+ return obj
503
+
504
+ cluster = _resolve_cluster()
505
+ frm_dev = cluster.devices[frm_dev_id]
506
+ to_dev = cluster.devices[to_dev_id]
507
+ frm_to_pair = (frm_dev.kind.upper(), to_dev.kind.upper())
508
+
509
+ # PPU -> PPU
510
+ if frm_to_pair == ("PPU", "PPU"):
511
+ assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
512
+ to_rank = to_dev.members[0].rank
513
+ frm_rank = frm_dev.members[0].rank
514
+
515
+ var = simp.shuffle_static(obj, {to_rank: frm_rank})
516
+ return set_dev_attr(var, to_dev_id)
517
+
518
+ # PPU -> SPU (Seal)
519
+ elif frm_to_pair == ("PPU", "SPU"):
520
+ assert len(frm_dev.members) == 1
521
+ frm_rank = frm_dev.members[0].rank
522
+ spu_parties = tuple(m.rank for m in to_dev.members)
523
+ spu_config = spu.SPUConfig.from_dict(to_dev.config)
524
+
525
+ # 1. Generate shares on source
526
+ # We call spu.make_shares inside pcall on the source party
527
+ shares_on_source = simp.pcall_static(
528
+ (frm_rank,),
529
+ spu.make_shares,
530
+ spu_config,
531
+ obj,
532
+ count=len(spu_parties),
533
+ )
534
+
535
+ # 2. Distribute shares
536
+ distributed_shares = []
537
+ for i, target_rank in enumerate(spu_parties):
538
+ # Extract i-th share (still on source)
539
+ # shares_on_source is MP[tuple[SS, ...], (frm_rank)]
540
+ # We need to extract the i-th element.
541
+ # Since pcall returns MPType, we can't index it directly if it's a tuple of shares.
542
+ # Wait, pcall returns a PyTree of MPObjects if the function returns a PyTree.
543
+ # So shares_on_source IS a tuple of MPObjects.
544
+ share_i = shares_on_source[i]
545
+
546
+ share_at_target = simp.shuffle_static(share_i, {target_rank: frm_rank})
547
+ distributed_shares.append(share_at_target)
548
+
549
+ # 3. Converge
550
+ var = simp.converge(*distributed_shares)
551
+ return set_dev_attr(var, to_dev_id)
552
+
553
+ # SPU -> PPU (Reveal)
554
+ elif frm_to_pair == ("SPU", "PPU"):
555
+ assert len(to_dev.members) == 1
556
+ to_rank = to_dev.members[0].rank
557
+ spu_parties = tuple(m.rank for m in frm_dev.members)
558
+ spu_config = spu.SPUConfig.from_dict(frm_dev.config)
559
+
560
+ # 1. Gather shares to target
561
+ gathered_shares = []
562
+ for source_rank in spu_parties:
563
+ # Extract share from logical variable
564
+ share_on_source = simp.pcall_static((source_rank,), lambda x: x, obj)
565
+
566
+ # Move to target
567
+ share_at_target = simp.shuffle_static(
568
+ share_on_source, {to_rank: source_rank}
569
+ )
570
+ gathered_shares.append(share_at_target)
571
+
572
+ # 2. Reconstruct on target
573
+ # We call spu.reconstruct inside pcall on the target party
574
+ var = simp.pcall_static(
575
+ (to_rank,), lambda *s: spu.reconstruct(spu_config, s), *gathered_shares
576
+ )
577
+ return set_dev_attr(var, to_dev_id)
578
+
579
+ # SPU -> SPU
580
+ elif frm_to_pair == ("SPU", "SPU"):
581
+ raise NotImplementedError("SPU to SPU transfer not implemented yet.")
582
+
583
+ # PPU -> TEE (Encrypted transfer)
584
+ elif frm_to_pair == ("PPU", "TEE"):
585
+ assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
586
+ frm_rank = frm_dev.members[0].rank
587
+ tee_rank = to_dev.members[0].rank
588
+
589
+ # Establish encrypted session (includes remote attestation)
590
+ sess_frm, sess_tee = _ensure_tee_session(
591
+ frm_dev_id, to_dev_id, frm_rank, tee_rank
592
+ )
593
+
594
+ # Encrypt on sender and send to TEE
595
+ ct = simp.pcall_static((frm_rank,), crypto.sym_encrypt, sess_frm, obj)
596
+ ct_at_tee = simp.shuffle_static(ct, {tee_rank: frm_rank})
597
+
598
+ # Decrypt on TEE
599
+ var = simp.pcall_static(
600
+ (tee_rank,),
601
+ crypto.sym_decrypt,
602
+ sess_tee,
603
+ ct_at_tee,
604
+ obj.type.value_type if hasattr(obj.type, "value_type") else obj.type,
605
+ )
606
+ return set_dev_attr(var, to_dev_id)
607
+
608
+ # TEE -> PPU (Encrypted transfer)
609
+ elif frm_to_pair == ("TEE", "PPU"):
610
+ assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
611
+ tee_rank = frm_dev.members[0].rank
612
+ ppu_rank = to_dev.members[0].rank
613
+
614
+ # Establish encrypted session (reuse existing or create new)
615
+ # Note: We pass (ppu, tee) order to match the session key derivation
616
+ sess_ppu, sess_tee = _ensure_tee_session(
617
+ to_dev_id, frm_dev_id, ppu_rank, tee_rank
618
+ )
619
+
620
+ # Encrypt on TEE and send to PPU
621
+ ct = simp.pcall_static((tee_rank,), crypto.sym_encrypt, sess_tee, obj)
622
+ ct_at_ppu = simp.shuffle_static(ct, {ppu_rank: tee_rank})
623
+
624
+ # Decrypt on PPU
625
+ var = simp.pcall_static(
626
+ (ppu_rank,),
627
+ crypto.sym_decrypt,
628
+ sess_ppu,
629
+ ct_at_ppu,
630
+ obj.type.value_type if hasattr(obj.type, "value_type") else obj.type,
631
+ )
632
+ return set_dev_attr(var, to_dev_id)
633
+
634
+ # TEE -> SPU (TEE acts like a PPU for SPU sealing)
635
+ elif frm_to_pair == ("TEE", "SPU"):
636
+ assert len(frm_dev.members) == 1
637
+ frm_rank = frm_dev.members[0].rank
638
+ spu_parties = tuple(m.rank for m in to_dev.members)
639
+ spu_config = spu.SPUConfig.from_dict(to_dev.config)
640
+
641
+ # Generate shares on TEE (same logic as PPU -> SPU)
642
+ shares_on_source = simp.pcall_static(
643
+ (frm_rank,),
644
+ spu.make_shares,
645
+ spu_config,
646
+ obj,
647
+ count=len(spu_parties),
648
+ )
649
+
650
+ # Distribute shares to SPU parties
651
+ distributed_shares = []
652
+ for i, target_rank in enumerate(spu_parties):
653
+ share_i = shares_on_source[i]
654
+ share_at_target = simp.shuffle_static(share_i, {target_rank: frm_rank})
655
+ distributed_shares.append(share_at_target)
656
+
657
+ # Converge shares
658
+ var = simp.converge(*distributed_shares)
659
+ return set_dev_attr(var, to_dev_id)
660
+
661
+ # SPU -> TEE (Reveal to TEE)
662
+ elif frm_to_pair == ("SPU", "TEE"):
663
+ assert len(to_dev.members) == 1
664
+ to_rank = to_dev.members[0].rank
665
+ spu_parties = tuple(m.rank for m in frm_dev.members)
666
+ spu_config = spu.SPUConfig.from_dict(frm_dev.config)
667
+
668
+ # Gather shares to TEE (same logic as SPU -> PPU)
669
+ gathered_shares = []
670
+ for source_rank in spu_parties:
671
+ share_on_source = simp.pcall_static((source_rank,), lambda x: x, obj)
672
+ share_at_target = simp.shuffle_static(
673
+ share_on_source, {to_rank: source_rank}
674
+ )
675
+ gathered_shares.append(share_at_target)
676
+
677
+ # Reconstruct on TEE
678
+ var = simp.pcall_static(
679
+ (to_rank,), lambda *s: spu.reconstruct(spu_config, s), *gathered_shares
680
+ )
681
+ return set_dev_attr(var, to_dev_id)
682
+
683
+ # TEE -> TEE
684
+ elif frm_to_pair == ("TEE", "TEE"):
685
+ raise NotImplementedError("TEE to TEE transfer not implemented yet.")
686
+
687
+ else:
688
+ raise DeviceError(f"Unsupported device transfer: {frm_to_pair}")
689
+
690
+
691
def put(to_dev_id: str, obj: Any) -> Object:
    """Place ``obj`` on the device identified by ``to_dev_id``.

    An object that already carries a device attribute is handed to the
    device-to-device transfer routine; anything else is treated as host data
    and uploaded to the target device.

    Args:
        to_dev_id: Target device ID (e.g., "P0", "SP0").
        obj: A device object to move, or a host value (e.g. numpy array)
            to upload.

    Returns:
        The object residing on the target device.

    Raises:
        DeviceNotFoundError: If ``to_dev_id`` is not a device of the cluster.
        DeviceError: If the target device kind is not PPU, SPU or TEE.
    """
    cluster = _resolve_cluster()
    if to_dev_id not in cluster.devices:
        available = list(cluster.devices.keys())
        raise DeviceNotFoundError(
            f"Device '{to_dev_id}' not found in cluster. Available devices: {available}"
        )

    # Device -> Device
    if isinstance(obj, Object) and is_device_obj(obj):
        return _d2d(to_dev_id, obj)

    # Host -> Device
    dev_info = cluster.devices[to_dev_id]
    kind = dev_info.kind.upper()

    if kind == "SPU":
        # Running the identity function on the SPU yields a Public
        # (replicated) value; SPU operations promote it to Secret
        # automatically when needed.
        return cast(Object, device(to_dev_id)(lambda x: x)(obj))

    if kind in ("PPU", "TEE"):
        # Both kinds are single-member devices: materialize the value as a
        # constant on that member's rank.
        assert len(dev_info.members) == 1
        rank = dev_info.members[0].rank
        placed = simp.constant((rank,), obj)
        return set_dev_attr(placed, to_dev_id)

    raise DeviceError(f"Cannot put to device kind '{dev_info.kind}'")
737
+
738
+
739
def fetch(obj: Object) -> Any:
    """Fetch data from device to host based on the object's device attribute.

    The device the object resides on decides how the value is retrieved:

    * PPU/TEE: fetched from the single member rank.
    * SPU: shares are fetched from every member, but only the FIRST share is
      returned -- host-side reconstruction is not implemented yet. Reveal
      SPU values to a PPU before fetching if the plain value is required.

    Args:
        obj: Object with device attribute to fetch.

    Returns:
        Python value (numpy array, scalar, etc.). ``WrapValue`` results are
        unwrapped to their underlying data.

    Raises:
        DeviceError: If ``obj`` has no device attribute, or if it is backed
            by a ``DriverVar`` on a device kind other than PPU, TEE or SPU.
        RuntimeError: If the current context is not an ``Interpreter``.
    """
    from mplang.v2.backends.simp_driver.state import SimpDriver
    from mplang.v2.backends.simp_driver.values import DriverVar
    from mplang.v2.edsl.context import get_current_context
    from mplang.v2.runtime.interpreter import InterpObject, Interpreter
    from mplang.v2.runtime.value import WrapValue

    def _unwrap_value(val: Any) -> Any:
        """Unwrap WrapValue to get the underlying data."""
        if isinstance(val, WrapValue):
            return val.data
        return val

    # 1. Only objects tagged with a device can be fetched through this path.
    if not is_device_obj(obj):
        raise DeviceError(
            "Object does not have device attribute. Use mp.fetch() directly."
        )

    # 2. Look up the device description from the object's device id.
    dev_id = get_dev_attr(obj)
    cluster = _resolve_cluster()
    dev_info = cluster.devices[dev_id]

    # Fetching needs a live interpreter that owns the runtime values.
    ctx = get_current_context()
    if not isinstance(ctx, Interpreter):
        raise RuntimeError("No interpreter context available for fetch")

    simp_state = ctx.get_dialect_state("simp")
    assert isinstance(simp_state, SimpDriver), "DriverVar requires simp state"

    # Unwrap InterpObject to get runtime value
    assert isinstance(obj, InterpObject), f"Expected InterpObject, got {type(obj)}"
    runtime_obj = obj.runtime_obj

    # Direct value (not DriverVar): nothing to pull from remote ranks.
    if not isinstance(runtime_obj, DriverVar):
        return _unwrap_value(runtime_obj)

    def _fetch_from_rank(rank: int) -> Any:
        """Fetch value from a rank (DriverVar values are always URIs)."""
        uri = runtime_obj.values[rank]
        assert isinstance(uri, str) and "://" in uri, f"Expected URI, got {uri}"
        return simp_state.fetch(rank, uri).result()

    # 3. Match device type and do corresponding fetch action
    kind = dev_info.kind.upper()

    # 3.1 PPU/TEE: single member, fetch directly
    if kind in ("PPU", "TEE"):
        assert len(dev_info.members) == 1
        result = _fetch_from_rank(dev_info.members[0].rank)
        return _unwrap_value(result)

    # 3.2 SPU: fetch shares from all members
    if kind == "SPU":
        shares = [_fetch_from_rank(m.rank) for m in dev_info.members]
        # TODO: implement spu.reconstruct on the host. Until then only the
        # first share is returned; SPU values should be revealed to a PPU
        # first when the plain value is needed.
        result = shares[0] if shares else None
        return _unwrap_value(result)

    # A DriverVar on an unknown device kind cannot be resolved to a host
    # value; fail loudly instead of leaking the raw driver handle.
    raise DeviceError(f"Cannot fetch from device kind '{dev_info.kind}'")