mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
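
Taken together, the renames above show this release deleting the legacy mplang.v1 tree and promoting the mplang.v2 modules to the package root. A minimal sketch of the implied import migration, inferred only from the rename paths above (the public API surface is not part of this diff):

    # 0.1.dev269: modules lived under the v2 namespace
    from mplang.v2.edsl import jit
    from mplang.v2.dialects import tensor

    # 0.1.dev271: the same modules sit at the package root
    from mplang.edsl import jit
    from mplang.dialects import tensor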
mplang/v1/_device.py DELETED
@@ -1,602 +0,0 @@
- # Copyright 2025 Ant Group Co., Ltd.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """
- This module provides the device oriented programming interface for MPC.
-
- The device oriented programming interface is designed to provide a high-level
- abstraction for the MPC programming. It allows the user to write the program
- in a device-oriented manner, and the runtime will take care of the data
- transformation between devices.
- """
-
- from __future__ import annotations
-
- from collections.abc import Callable
- from functools import partial, wraps
- from typing import Any, cast
-
- from jax.tree_util import tree_map, tree_unflatten
-
- from mplang.v1.core import (
-     ClusterSpec,
-     Device,
-     Mask,
-     MPContext,
-     MPObject,
-     TableLike,
-     TensorLike,
-     cur_ctx,
-     peval,
- )
- from mplang.v1.ops import basic, crypto, jax_cc, nnx_cc, spu, tee
- from mplang.v1.ops.base import FeOperation
- from mplang.v1.ops.jax_cc import JaxRunner
- from mplang.v1.simp import mpi
- from mplang.v1.simp.api import run_at
-
- # Automatic transfer between devices when parameter is not on the target device.
- g_auto_trans: bool = True
-
- _HKDF_INFO_LITERAL: str = "mplang/device/tee/v1"
- # Default KEM suite for TEE session establishment; make configurable via ClusterSpec in future.
- _TEE_KEM_SUITE: str = "x25519"
-
-
- # Context-aware session management
- def _get_context_id(ctx: MPContext) -> int:
-     """
-     Get unique identifier for a context.
-
-     Args:
-         ctx: The context object (TraceContext or InterpContext)
-
-     Returns:
-         Unique integer ID for this context instance
-     """
-     return id(ctx)
-
-
- # magic attribute name to mark a MPObject as a device object
- DEVICE_ATTR_NAME = "__device__"
-
-
- def is_device_obj(obj: Any) -> bool:
-     if not isinstance(obj, MPObject):
-         return False
-     return DEVICE_ATTR_NAME in obj.attrs
-
-
- def set_dev_attr(obj: MPObject, dev_id: str) -> MPObject:
-     if not isinstance(obj, MPObject):
-         raise TypeError(f"Input must be an instance of MPObject, {obj}")
-     obj.attrs[DEVICE_ATTR_NAME] = dev_id
-     return obj
-
-
- def get_dev_attr(obj: MPObject) -> str:
-     if not isinstance(obj, MPObject):
-         raise TypeError("Input must be an instance of MPObject")
-
-     return str(obj.attrs[DEVICE_ATTR_NAME])
-
-
- def _infer_device_from_args(*args: Any, **kwargs: Any) -> str:
-     """Infer target device from function arguments.
-
-     Inference strategy:
-     1. Collect all MPObject arguments and check device attributes
-        - If MPObject has no device attr -> error (user must set_devid)
-        - If no MPObject arguments -> error (explicit device required)
-
-     2. Analyze device distribution
-        2.1 All objects on same device -> use that device
-        2.2 Multiple devices with g_auto_trans enabled:
-            - Single SPU (+ PPUs) -> use SPU (auto-transfer from PPUs)
-            - Single TEE (+ PPUs) -> use TEE (auto-transfer from PPUs)
-            - Multiple PPUs only -> error (ambiguous, need explicit device)
-            - Multiple SPUs -> error (ambiguous)
-            - Multiple TEEs -> error (ambiguous)
-            - SPU + TEE -> error (conflicting secure devices)
-        2.3 Multiple devices with g_auto_trans disabled -> error
-
-     Args:
-         *args: Positional arguments
-         **kwargs: Keyword arguments
-
-     Returns:
-         Device id string
-
-     Raises:
-         ValueError: When inference fails or is ambiguous
-     """
-     from jax.tree_util import tree_flatten
-
-     # Step 1: Collect all MPObject arguments and validate device attributes
-     all_args = tree_flatten((args, kwargs))[0]
-     device_objs = []
-
-     for obj in all_args:
-         if isinstance(obj, MPObject):
-             if not is_device_obj(obj):
-                 raise ValueError(
-                     "MPObject is missing device attribute. "
-                     "If you're mixing device-level and simp-level code, "
-                     "use set_dev_attr(obj, 'device_id') to mark the device explicitly."
-                 )
-             device_objs.append(obj)
-
-     if not device_objs:
-         raise ValueError(
-             "Cannot infer device: no MPObject arguments found. "
-             "Please specify device explicitly using device('device_id')(fn)."
-         )
-
-     # Step 2: Extract all unique devices
-     devices = {get_dev_attr(obj) for obj in device_objs}
-
-     if len(devices) == 1:
-         return devices.pop()  # All arguments on same device
-
-     # Step 3: Multiple devices - check if auto-transfer is enabled
-     if not g_auto_trans:
-         raise ValueError(
-             f"Cannot infer device: arguments from multiple devices {devices} "
-             f"but auto-transfer is disabled (g_auto_trans=False). "
-             f"Please enable auto-transfer or put all data on same device first."
-         )
-
-     # Step 4: Analyze device kinds for auto-transfer scenario
-     cluster_spec = cur_ctx().cluster_spec
-     device_kinds = {
-         dev_id: cluster_spec.devices[dev_id].kind.upper() for dev_id in devices
-     }
-
-     # Count devices by type
-     spu_devs = [d for d, k in device_kinds.items() if k == "SPU"]
-     tee_devs = [d for d, k in device_kinds.items() if k == "TEE"]
-     ppu_devs = [d for d, k in device_kinds.items() if k == "PPU"]
-
-     # Decision logic
-     # Case 1: Only PPUs -> ambiguous
-     if not spu_devs and not tee_devs:
-         raise ValueError(
-             f"Cannot infer device: arguments from multiple PPU devices {ppu_devs}. "
-             f"Please specify device explicitly or use put() to consolidate data."
-         )
-
-     # Case 2: Single SPU (possibly with PPUs) -> use SPU
-     if len(spu_devs) == 1 and len(tee_devs) == 0:
-         return spu_devs[0]
-
-     # Case 3: Single TEE (possibly with PPUs) -> use TEE
-     if len(tee_devs) == 1 and len(spu_devs) == 0:
-         return tee_devs[0]
-
-     # Case 4: Multiple SPUs -> ambiguous
-     if len(spu_devs) > 1:
-         raise ValueError(
-             f"Ambiguous device inference: arguments from multiple SPU devices {spu_devs}. "
-             f"Please specify which SPU to use explicitly."
-         )
-
-     # Case 5: Multiple TEEs -> ambiguous
-     if len(tee_devs) > 1:
-         raise ValueError(
-             f"Ambiguous device inference: arguments from multiple TEE devices {tee_devs}. "
-             f"Please specify which TEE to use explicitly."
-         )
-
-     # Case 6: Both SPU and TEE -> conflicting
-     if spu_devs and tee_devs:
-         raise ValueError(
-             f"Ambiguous device inference: arguments from both SPU {spu_devs} and TEE {tee_devs}. "
-             f"Please specify which secure device to use explicitly."
-         )
-
-     # Should never reach here
-     raise ValueError(f"Unexpected device configuration: {devices}")
-
-
- def _device_run_spu(
-     dev_info: Device, op: FeOperation, fn: Callable, *args: Any, **kwargs: Any
- ) -> Any:
-     if not isinstance(op, JaxRunner):
-         raise ValueError("SPU device only supports JAX frontend.")
-     spu_mask = Mask.from_ranks([member.rank for member in dev_info.members])
-     pfunc, in_vars, out_tree = spu.jax_compile(fn, *args, **kwargs)
-     assert all(var.pmask == spu_mask for var in in_vars), in_vars
-     out_flat = peval(pfunc, in_vars, spu_mask)
-     result = tree_unflatten(out_tree, out_flat)
-     return tree_map(partial(set_dev_attr, dev_id=dev_info.name), result)
-
-
- def _device_run_tee(
-     dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
- ) -> Any:
-     # TODO(jint): should we filter out all IO operations?
-     assert len(dev_info.members) == 1
-     rank = dev_info.members[0].rank
-     var = run_at(rank, op, *args, **kwargs)
-     return tree_map(partial(set_dev_attr, dev_id=dev_info.name), var)
-
-
- def _device_run_ppu(
-     dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
- ) -> Any:
-     assert len(dev_info.members) == 1
-     rank = dev_info.members[0].rank
-     var = run_at(rank, op, *args, **kwargs)
-     return tree_map(partial(set_dev_attr, dev_id=dev_info.name), var)
-
-
- def _device_run(dev_id: str, op: FeOperation, *args: Any, **kwargs: Any) -> Any:
-     assert isinstance(op, FeOperation)
-     cluster_spec = cur_ctx().cluster_spec
-     if dev_id not in cluster_spec.devices:
-         raise ValueError(f"Device {dev_id} not found in cluster spec.")
-     dev_info = cluster_spec.devices[dev_id]
-
-     if g_auto_trans:
-
-         def trans(obj: Any) -> Any:
-             if isinstance(obj, MPObject):
-                 assert is_device_obj(obj)
-                 return _d2d(dev_id, obj)
-             else:
-                 return obj
-
-         args, kwargs = tree_map(trans, (args, kwargs))
-
-     if dev_info.kind.upper() == "SPU":
-         return _device_run_spu(dev_info, op, *args, **kwargs)
-     elif dev_info.kind.upper() == "TEE":
-         return _device_run_tee(dev_info, op, *args, **kwargs)
-     elif dev_info.kind.upper() == "PPU":
-         return _device_run_ppu(dev_info, op, *args, **kwargs)
-     else:
-         raise ValueError(f"Unknown device type: {dev_info.kind}")
-
-
- def device(
-     dev_or_fn: str | Callable | None = None, *, fe_type: str = "jax"
- ) -> Callable:
-     """Decorator to mark a function to be executed on a specific device.
-
-     Supports both explicit device specification and automatic device inference:
-
-     1. Explicit device placement:
-         @device("P0")
-         def foo(x, y): return x + y
-
-     2. Auto device inference:
-         @device
-         def foo(x, y): return x + y
-         # Device is inferred from x, y at runtime
-
-     3. Inline usage:
-         result = device(lambda x, y: x + y)(x_on_p0, y_on_p0)
-         # Automatically infers device from arguments
-
-     Args:
-         dev_or_fn: Either a device id string ("P0", "SPU", etc.) for explicit placement,
-             a callable function for auto inference, or None (same as not providing arg).
-         fe_type: The frontend type of the device, could be "jax" or "nnx".
-             Not needed if the decorated function is already a FeOperation.
-
-     Returns:
-         A decorator (when dev_or_fn is a string or None) or decorated function (when callable).
-
-     Raises:
-         TypeError: When dev_or_fn is not a string, callable, or None.
-         ValueError: When device cannot be inferred or inference is ambiguous.
-
-     Device Inference Strategy:
-         - Same device: All arguments on device D -> execute on D
-         - PPU + SPU: Arguments from PPU and SPU -> execute on SPU (secure compute)
-         - PPU + TEE: Arguments from PPU and TEE -> execute on TEE (trusted execution)
-         - Multiple PPUs: Ambiguous -> error (explicit device required)
-         - No device objects: Cannot infer -> error (explicit device required)
-
-     Example:
-         >>> # Explicit device
-         >>> @device("P0")
-         ... def add_explicit(x, y):
-         ...     return x + y
-         >>>
-         >>> # Auto inference
-         >>> @device
-         ... def add_auto(x, y):
-         ...     return x + y
-         >>>
-         >>> x_on_p0 = ...  # data on P0
-         >>> y_on_p0 = ...  # data on P0
-         >>> result = add_auto(x_on_p0, y_on_p0)  # Inferred to P0
-         >>>
-         >>> x_on_spu = ...  # data on SPU
-         >>> y_on_p1 = ...  # data on P1
-         >>> result = add_auto(x_on_spu, y_on_p1)  # Inferred to SPU
-     """
-
-     def _execute_on_device(dev_id: str, fn: Callable, *args: Any, **kwargs: Any) -> Any:
-         """Helper to execute function on specified device with appropriate frontend."""
-         if isinstance(fn, FeOperation):
-             return _device_run(dev_id, fn, *args, **kwargs)
-         else:
-             if fe_type == "jax":
-                 return _device_run(dev_id, jax_cc.run_jax, fn, *args, **kwargs)
-             elif fe_type == "nnx":
-                 return _device_run(dev_id, nnx_cc.run_nnx, fn, *args, **kwargs)
-             else:
-                 raise ValueError(f"Unsupported frontend type: {fe_type}")
-
-     # Case 1: device("P0") - Explicit device specification
-     if isinstance(dev_or_fn, str):
-         dev_id = dev_or_fn
-
-         def deco(fn: Callable) -> Callable:
-             @wraps(fn)
-             def wrapped(*args: Any, **kwargs: Any) -> Any:
-                 return _execute_on_device(dev_id, fn, *args, **kwargs)
-
-             return wrapped
-
-         return deco
-
-     # Case 2: device(fn) or @device - Auto device inference
-     elif callable(dev_or_fn):
-         fn = dev_or_fn
-
-         @wraps(fn)
-         def wrapped(*args: Any, **kwargs: Any) -> Any:
-             try:
-                 dev_id = _infer_device_from_args(*args, **kwargs)
-             except ValueError as e:
-                 # Enhance error message with function context
-                 raise ValueError(
-                     f"Cannot infer device for function '{fn.__name__}': {e!s}"
-                 ) from e
-
-             return _execute_on_device(dev_id, fn, *args, **kwargs)
-
-         return wrapped
-
-     # Case 3: device() or @device() - Return auto-inference decorator
-     elif dev_or_fn is None:
-
-         def deco(fn: Callable) -> Callable:
-             return device(fn, fe_type=fe_type)
-
-         return deco
-
-     else:
-         # More helpful error message for common mistakes
-         raise TypeError(
-             f"device() expects a device id (string), a function (callable), or nothing. "
-             f"Got: {type(dev_or_fn).__name__}.\n"
-             f"Usage:\n"
-             f" - Explicit device: @device('P0') or device('P0')(fn)\n"
-             f" - Auto inference: @device or device(fn)"
-         )
-
-
- def _spu_reveal(spu_dev: Device, obj: MPObject, to_mask: Mask) -> MPObject:
-     spu_mask = Mask.from_ranks([m.rank for m in spu_dev.members])
-     assert obj.pmask == spu_mask, (obj.pmask, spu_mask)
-
-     # (n_parties, n_shares)
-     shares = [mpi.bcast_m(to_mask, rank, obj) for rank in Mask(spu_mask)]
-     assert len(shares) == Mask(spu_mask).num_parties(), (shares, spu_mask)
-     assert all(share.pmask == to_mask for share in shares)
-
-     # Reconstruct the original object from shares
-     pfunc, ins, _ = spu.reconstruct(*shares)
-     return peval(pfunc, ins, to_mask)[0]  # type: ignore[no-any-return]
-
-
- def _spu_seal(spu_dev: Device, obj: MPObject) -> list[MPObject]:
-     """Seal plaintext into SPU shares on a specific SPU device.
-
-     Low-level API: device id is mandatory to avoid ambiguity.
-     """
-     if obj.pmask is None:
-         raise ValueError("Seal can not apply to dynamic mask objects.")
-
-     spu_mask = Mask.from_ranks([member.rank for member in spu_dev.members])
-     spu_wsize = Mask(spu_mask).num_parties()
-     pfunc, ins, _ = spu.makeshares(
-         obj, world_size=spu_wsize, visibility=spu.Visibility.SECRET
-     )
-     assert len(ins) == 1
-     shares = peval(pfunc, ins)
-
-     # scatter the shares to each party.
-     outs = [mpi.scatter_m(spu_mask, rank, shares) for rank in obj.pmask]
-     return outs
-
-
- def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
-     assert isinstance(obj, MPObject)
-     frm_dev_id = get_dev_attr(obj)
-
-     if frm_dev_id == to_dev_id:
-         return obj
-
-     cluster_spec: ClusterSpec = cur_ctx().cluster_spec
-     frm_dev = cluster_spec.devices[frm_dev_id]
-     to_dev = cluster_spec.devices[to_dev_id]
-     frm_to_pair = (frm_dev.kind.upper(), to_dev.kind.upper())
-
-     if frm_to_pair == ("SPU", "SPU"):
-         raise NotImplementedError("Only one SPU is supported for now.")
-     elif frm_to_pair == ("SPU", "PPU"):
-         assert len(to_dev.members) == 1
-         to_rank = to_dev.members[0].rank
-         var = _spu_reveal(frm_dev, obj, Mask.from_ranks([to_rank]))
-         return tree_map(partial(set_dev_attr, dev_id=to_dev_id), var)  # type: ignore[no-any-return]
-     elif frm_to_pair == ("PPU", "SPU"):
-         assert len(frm_dev.members) == 1
-         frm_rank = frm_dev.members[0].rank
-         vars = _spu_seal(to_dev, obj)
-         assert len(vars) == 1, "Expected single share from PPU to SPU seal."
-         return tree_map(partial(set_dev_attr, dev_id=to_dev_id), vars[0])  # type: ignore[no-any-return]
-     elif frm_to_pair == ("PPU", "PPU"):
-         assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
-         frm_rank = frm_dev.members[0].rank
-         to_rank = to_dev.members[0].rank
-         var = mpi.p2p(frm_rank, to_rank, obj)
-         return tree_map(partial(set_dev_attr, dev_id=to_dev_id), var)  # type: ignore[no-any-return]
-     elif frm_to_pair == ("PPU", "TEE"):
-         # Transparent handshake + encryption for the first transfer; reuse thereafter
-         assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
-         frm_rank = frm_dev.members[0].rank
-         tee_rank = to_dev.members[0].rank
-         # Ensure sessions (both directions) exist for this PPU<->TEE pair
-         sess_p, sess_t = _ensure_tee_session(frm_dev_id, to_dev_id, frm_rank, tee_rank)
-         # Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
-         obj_ty = obj.mptype.raw_type()
-         b = run_at(frm_rank, basic.pack, obj)
-         ct = run_at(frm_rank, crypto.enc, b, sess_p)
-         ct_at_tee = mpi.p2p(frm_rank, tee_rank, ct)
-         b_at_tee = run_at(tee_rank, crypto.dec, ct_at_tee, sess_t)
-         pt_at_tee = run_at(tee_rank, basic.unpack, b_at_tee, out_ty=obj_ty)
-         return tree_map(partial(set_dev_attr, dev_id=to_dev_id), pt_at_tee)  # type: ignore[no-any-return]
-     elif frm_to_pair == ("TEE", "PPU"):
-         # Transparent encryption from TEE to a specific PPU using the reverse-direction session key
-         assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
-         tee_rank = frm_dev.members[0].rank
-         ppu_rank = to_dev.members[0].rank
-         # Ensure bidirectional session established for this pair
-         sess_p, sess_t = _ensure_tee_session(to_dev_id, frm_dev_id, ppu_rank, tee_rank)
-         obj_ty = obj.mptype.raw_type()
-         b = run_at(tee_rank, basic.pack, obj)
-         ct = run_at(tee_rank, crypto.enc, b, sess_t)
-         ct_at_ppu = mpi.p2p(tee_rank, ppu_rank, ct)
-         b_at_ppu = run_at(ppu_rank, crypto.dec, ct_at_ppu, sess_p)
-         pt_at_ppu = run_at(ppu_rank, basic.unpack, b_at_ppu, out_ty=obj_ty)
-         return tree_map(partial(set_dev_attr, dev_id=to_dev_id), pt_at_ppu)  # type: ignore[no-any-return]
-     elif frm_to_pair == ("TEE", "SPU"):
-         assert len(frm_dev.members) == 1
-         frm_rank = frm_dev.members[0].rank
-         vars = _spu_seal(to_dev, obj)
-         assert len(vars) == 1, "Expected single share from TEE to SPU seal."
-         return tree_map(partial(set_dev_attr, dev_id=to_dev_id), vars[0])  # type: ignore[no-any-return]
-     elif frm_to_pair == ("SPU", "TEE"):
-         assert len(to_dev.members) == 1
-         to_rank = to_dev.members[0].rank
-         var = _spu_reveal(frm_dev, obj, Mask.from_ranks([to_rank]))
-         return tree_map(partial(set_dev_attr, dev_id=to_dev_id), var)  # type: ignore[no-any-return]
-     else:
-         supported = [
-             ("SPU", "PPU"),
-             ("PPU", "SPU"),
-             ("PPU", "PPU"),
-             ("PPU", "TEE"),
-             ("TEE", "PPU"),
-             ("TEE", "SPU"),
-             ("SPU", "TEE"),
-         ]
-         raise ValueError(
-             f"Unsupported device transfer: {frm_to_pair}. Supported pairs: {supported}."
-         )
-
-
- def _ensure_tee_session(
-     frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int
- ) -> tuple[MPObject, MPObject]:
-     """Ensure a TEE session (sess_p at sender, sess_t at TEE) exists.
-
-     Context-aware version: caches include context ID to ensure isolation
-     between different TraceContext instances, preventing TraceVar pollution.
-
-     Returns (sess_p, sess_t).
-     """
-     # Get current context and its unique ID
-     current_ctx = cur_ctx()
-     current_context_id = _get_context_id(current_ctx)
-
-     # Get root context for cache storage
-     root_ctx = current_ctx.root()
-     if not hasattr(root_ctx, "_tee_sessions"):
-         root_ctx._tee_sessions = {}  # type: ignore[attr-defined]
-     cache: dict[tuple[str, str], tuple[int, MPObject, MPObject]] = (
-         root_ctx._tee_sessions  # type: ignore[attr-defined]
-     )
-
-     key = (frm_dev_id, to_dev_id)
-
-     # Check cache with context awareness
-     if key in cache:
-         cached_context_id, sess_p, sess_t = cache[key]
-
-         # Only reuse cache from the same context
-         if cached_context_id == current_context_id:
-             return sess_p, sess_t
-         else:
-             # Different context, cannot reuse cache, clean up old entry
-             del cache[key]
-
-     # 1) TEE generates (sk, pk) and quote(pk)
-     # KEM suite currently constant; future: read from tee device config (e.g. cluster_spec.devices[dev_id].config)
-     tee_sk, tee_pk = run_at(tee_rank, crypto.kem_keygen, _TEE_KEM_SUITE)
-     quote = run_at(tee_rank, tee.quote_gen, tee_pk)
-
-     # 2) Send quote to sender and attest to obtain TEE pk
-     quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
-     tee_pk_at_sender = run_at(frm_rank, tee.attest, quote_at_sender)
-
-     # 3) Sender generates its ephemeral keypair and sends its pk to TEE
-     v_sk, v_pk = run_at(frm_rank, crypto.kem_keygen, _TEE_KEM_SUITE)
-     v_pk_at_tee = mpi.p2p(frm_rank, tee_rank, v_pk)
-
-     # 4) Both sides derive the shared secret and session key
-     shared_p = run_at(
-         frm_rank, crypto.kem_derive, v_sk, tee_pk_at_sender, _TEE_KEM_SUITE
-     )
-     shared_t = run_at(tee_rank, crypto.kem_derive, tee_sk, v_pk_at_tee, _TEE_KEM_SUITE)
-     # Use a fixed ASCII string literal for HKDF info on both sides
-     sess_p = run_at(frm_rank, crypto.hkdf, shared_p, _HKDF_INFO_LITERAL)
-     sess_t = run_at(tee_rank, crypto.hkdf, shared_t, _HKDF_INFO_LITERAL)
-
-     # Cache with context ID for isolation
-     cache[key] = (current_context_id, sess_p, sess_t)
-     return sess_p, sess_t
-
-
- def _host_to_device(to_dev_id: str, obj: Any) -> MPObject:
-     if isinstance(obj, TensorLike):
-         # run jax identity on the target device to put the tensor there
-         return device(to_dev_id)(lambda x: x)(obj)  # type: ignore[no-any-return]
-     elif isinstance(obj, TableLike):
-         dev_info = cur_ctx().cluster_spec.devices[to_dev_id]
-         if dev_info.kind.upper() not in ["PPU", "TEE"]:
-             raise ValueError(
-                 f"TableLike put() only supports PPU or TEE devices, got {dev_info.kind}"
-             )
-         assert len(dev_info.members) == 1
-         rank = dev_info.members[0].rank
-         obj_mp = cast(MPObject, run_at(rank, basic.constant, obj))
-         set_dev_attr(obj_mp, to_dev_id)
-         return obj_mp
-     else:
-         raise TypeError(
-             f"put() only supports TensorLike or TableLike objects, got {type(obj)}"
-         )
-
-
- def put(to_dev_id: str, obj: Any) -> MPObject:
-     if not isinstance(obj, MPObject):
-         return _host_to_device(to_dev_id, obj)
-     assert isinstance(obj, MPObject)
-     return _d2d(to_dev_id, obj)
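
For context, here is a minimal sketch of how the deleted v1 device API fit together, pieced together from the docstrings above. It assumes device and put were re-exported from mplang.v1 (the exports of the deleted mplang/v1/__init__.py are not shown in this hunk), and uses hypothetical device ids P0, P1, and SPU from a ClusterSpec:

    import numpy as np
    from mplang.v1 import device, put  # assumed re-exports; removed in 0.1.dev271

    x = put("P0", np.array([1.0, 2.0]))  # host value -> party P0
    y = put("P1", np.array([3.0, 4.0]))  # host value -> party P1

    @device("SPU")  # explicit placement; PPU inputs are auto-sealed into SPU shares
    def add(a, b):
        return a + b

    s = add(x, y)       # executes on the SPU device via the JAX frontend
    out = put("P0", s)  # SPU -> PPU transfer reveals the result to P0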
mplang/v1/analysis/__init__.py DELETED
@@ -1,37 +0,0 @@
- # Copyright 2025 Ant Group Co., Ltd.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Analysis and visualization utilities for mplang.
-
- This subpackage hosts non-core developer aids: diagram rendering, IR dumps,
- profiling helpers (future), etc.
- """
-
- from mplang.v1.analysis.diagram import (
-     DumpResult,
-     FlowchartOptions,
-     SequenceDiagramOptions,
-     dump,
-     to_flowchart,
-     to_sequence_diagram,
- )
-
- __all__ = [
-     "DumpResult",
-     "FlowchartOptions",
-     "SequenceDiagramOptions",
-     "dump",
-     "to_flowchart",
-     "to_sequence_diagram",
- ]
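
With this file gone, the v1 analysis entry points disappear as well. A sketch of imports that were valid in 0.1.dev269 and stop working in 0.1.dev271 (names taken verbatim from the deleted __all__ above):

    from mplang.v1.analysis import dump, to_flowchart, to_sequence_diagram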