mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,813 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Device-oriented programming interface for MPLang2.
|
|
17
|
+
|
|
18
|
+
This module provides high-level abstractions for device placement and data movement.
|
|
19
|
+
It allows users to write programs in a device-centric way, handling data transfers
|
|
20
|
+
and execution dispatch automatically.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
from collections.abc import Callable
|
|
26
|
+
from functools import partial, wraps
|
|
27
|
+
from typing import Any, cast
|
|
28
|
+
|
|
29
|
+
from jax.tree_util import tree_flatten, tree_map
|
|
30
|
+
|
|
31
|
+
from mplang.v2.backends import load_builtins
|
|
32
|
+
from mplang.v2.dialects import crypto, simp, spu, tee
|
|
33
|
+
from mplang.v2.edsl.object import Object
|
|
34
|
+
from mplang.v2.libs.device.cluster import Device
|
|
35
|
+
|
|
36
|
+
# Runs once at import time, before any device helper below is defined.
# NOTE(review): presumably registers the built-in backend implementations
# exposed by mplang.v2.backends -- confirm against that package.
load_builtins()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _resolve_cluster() -> Any:
    """Return the cluster spec carried by the nearest active context.

    The context stack is searched (via ``find_context``) for a context whose
    ``_cluster_spec`` attribute is set to something other than ``None``; the
    spec of the first match is returned, so an inner context can override the
    cluster of an outer one.

    Returns:
        The ``_cluster_spec`` object of the matching context.

    Raises:
        RuntimeError: If no context on the stack carries a cluster spec.
    """
    # Imported lazily, as in the rest of this module's context lookups.
    from mplang.v2.edsl.context import find_context

    def _has_cluster_spec(candidate: Any) -> bool:
        return getattr(candidate, "_cluster_spec", None) is not None

    holder = find_context(_has_cluster_spec)
    if holder is None:
        raise RuntimeError(
            "No active device context found. Please use 'with simulator:' "
            "or 'push_context(sim)' to set the execution environment."
        )
    return holder._cluster_spec  # type: ignore[attr-defined]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Name of the attribute that tags an Object with the ID of the device it
# lives on (see set_dev_attr / get_dev_attr).
DEVICE_ATTR_NAME = "__device__"

# KEM suite used when establishing a TEE session.
_TEE_KEM_SUITE: str = "x25519"

# Info string passed to HKDF when deriving TEE session keys; it provides
# domain separation for this protocol.
_TEE_HKDF_INFO: str = "mplang/device/tee/v2"

# Process-wide cache of TEE sessions, keyed by (frm_dev_id, to_dev_id).
# Each value is (context_id, sess_frm, sess_tee); the context_id is compared
# on lookup so that a session is never reused across different trace/interp
# contexts.
_tee_session_cache: dict[tuple[str, str], tuple[int, Object, Object]] = {}

# When True, arguments that live on another device are transferred to the
# target device automatically before a function runs.
g_auto_trans: bool = True
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class DeviceError(Exception):
    """Root of the exception hierarchy raised by the device API."""
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class DeviceNotFoundError(DeviceError):
    """Signals that a requested device ID is absent from the cluster."""
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class DeviceInferenceError(DeviceError):
    """Signals that no target device could be inferred from the arguments."""
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def is_device_obj(obj: Any) -> bool:
    """Tell whether *obj* is an ``Object`` that carries a device tag.

    Returns:
        True only when *obj* is an ``Object`` instance and has the
        ``DEVICE_ATTR_NAME`` attribute; False for anything else.
    """
    return isinstance(obj, Object) and hasattr(obj, DEVICE_ATTR_NAME)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def set_dev_attr(obj: Object, dev_id: str) -> Object:
    """Tag *obj* as living on device *dev_id* and hand it back.

    The tag is stored in place under ``DEVICE_ATTR_NAME``; the same object is
    returned so the call can be chained.

    Raises:
        TypeError: If *obj* is not an ``Object`` instance.
    """
    if isinstance(obj, Object):
        setattr(obj, DEVICE_ATTR_NAME, dev_id)
        return obj
    raise TypeError(f"Input must be an instance of Object, got {type(obj)}")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def get_dev_attr(obj: Object) -> str:
    """Read the device ID that *obj* was tagged with.

    Returns:
        The stored device tag, converted with ``str``.

    Raises:
        TypeError: If *obj* is not an ``Object`` instance.
        ValueError: If *obj* carries no device tag.
    """
    if not isinstance(obj, Object):
        raise TypeError("Input must be an instance of Object")

    # A private sentinel distinguishes "no tag" from any stored value.
    untagged = object()
    tag = getattr(obj, DEVICE_ATTR_NAME, untagged)
    if tag is untagged:
        raise ValueError("Object does not have a device attribute")
    return str(tag)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _maybe_set_dev_attr(dev_id: str, obj: Any) -> Any:
    """Tag *obj* with *dev_id* when it is an ``Object``; pass anything else through.

    The device ID comes first so the function can be partially applied and
    mapped over the leaves of a result tree.
    """
    return set_dev_attr(obj, dev_id) if isinstance(obj, Object) else obj
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _infer_device_from_args(*args: Any, **kwargs: Any) -> str:
    """Pick the target device implied by the device-tagged arguments.

    All positional and keyword arguments are flattened as one pytree. Only
    leaves that are ``Object`` instances carrying a device tag take part in
    the decision; untagged Objects and plain values are ignored.

    Decision rules, in order:
      * every tagged argument is on the same device -> that device;
      * several devices while ``g_auto_trans`` is off -> error;
      * only PPUs (no SPU, no TEE) -> error, the choice would be arbitrary;
      * exactly one SPU and no TEE (PPUs allowed) -> the SPU;
      * exactly one TEE and no SPU (PPUs allowed) -> the TEE;
      * anything else (several SPUs, several TEEs, SPU mixed with TEE) -> error.

    Returns:
        The ID of the inferred device.

    Raises:
        DeviceInferenceError: If there is no tagged argument or the choice is
            ambiguous.
    """
    leaves, _ = tree_flatten((args, kwargs))
    tagged = [
        leaf for leaf in leaves if isinstance(leaf, Object) and is_device_obj(leaf)
    ]
    if not tagged:
        raise DeviceInferenceError(
            "Cannot infer device: no device-bound Object arguments found. "
            "Please specify device explicitly using device('device_id')."
        )

    devices = {get_dev_attr(leaf) for leaf in tagged}
    if len(devices) == 1:
        # Everything already lives on one device.
        return devices.pop()

    if not g_auto_trans:
        raise DeviceInferenceError(
            f"Cannot infer device: arguments from multiple devices {devices} "
            "but auto-transfer is disabled (g_auto_trans=False). "
            "Please enable auto-transfer or put all data on same device first."
        )

    # Bucket the involved devices by kind (compared case-insensitively).
    # Devices of any other kind are left out of all three buckets.
    cluster = _resolve_cluster()
    grouped: dict[str, list[str]] = {"SPU": [], "TEE": [], "PPU": []}
    for dev_id in devices:
        bucket = grouped.get(cluster.devices[dev_id].kind.upper())
        if bucket is not None:
            bucket.append(dev_id)
    spu_devs, tee_devs, ppu_devs = grouped["SPU"], grouped["TEE"], grouped["PPU"]
    n_spu, n_tee = len(spu_devs), len(tee_devs)

    if n_spu == 0 and n_tee == 0:
        raise DeviceInferenceError(
            f"Cannot infer device: arguments from multiple PPU devices {ppu_devs}. "
            "Please specify device explicitly or use put() to consolidate data."
        )

    if n_tee == 0 and n_spu == 1:
        return spu_devs[0]

    if n_spu == 0 and n_tee == 1:
        return tee_devs[0]

    if n_spu > 1:
        raise DeviceInferenceError(
            f"Ambiguous device inference: arguments from multiple SPU devices {spu_devs}. "
            "Please specify which SPU to use explicitly."
        )

    if n_tee > 1:
        raise DeviceInferenceError(
            f"Ambiguous device inference: arguments from multiple TEE devices {tee_devs}. "
            "Please specify which TEE to use explicitly."
        )

    if spu_devs and tee_devs:
        raise DeviceInferenceError(
            f"Ambiguous device inference: arguments from both SPU {spu_devs} and TEE {tee_devs}. "
            "Please specify which secure device to use explicitly."
        )

    # Defensive fallback; the guards above are expected to be exhaustive.
    raise DeviceInferenceError(f"Unexpected device configuration: {devices}")
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _device_run_spu(dev_info: Device, fn: Callable, *args: Any, **kwargs: Any) -> Any:
    """Execute *fn* on an SPU device and tag the results with that device.

    ``spu.run_jax`` is invoked through ``simp.pcall_static`` on the ranks of
    all members of the SPU, using the SPU configuration built from
    ``dev_info.config``. The arguments are expected to be on the SPU already
    (the caller is responsible for transferring them).

    Every leaf of the result tree is tagged with ``dev_info.name`` through
    ``set_dev_attr``, so a leaf that is not an ``Object`` raises ``TypeError``.
    """
    member_ranks = tuple(member.rank for member in dev_info.members)
    config = spu.SPUConfig.from_dict(dev_info.config)

    outputs = simp.pcall_static(
        member_ranks, spu.run_jax, config, fn, *args, **kwargs
    )

    owner = dev_info.name
    return tree_map(lambda leaf: set_dev_attr(leaf, dev_id=owner), outputs)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _device_run_ppu(
    dev_info: Device,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Execute *fn* on a PPU device and tag Object results with that device.

    A PPU must have exactly one member; *fn* runs on that member's rank via
    ``simp.pcall_static``. Leaves of the result that are ``Object`` instances
    are tagged with ``dev_info.name``; other leaves are returned untouched.
    """
    members = dev_info.members
    assert len(members) == 1

    outputs = simp.pcall_static((members[0].rank,), fn, *args, **kwargs)

    owner = dev_info.name
    return tree_map(lambda leaf: _maybe_set_dev_attr(owner, leaf), outputs)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _device_run_tee(
    dev_info: Device,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Execute *fn* on a TEE device and tag Object results with that device.

    Dispatch is the same as for a PPU: the device must have exactly one
    member and *fn* runs on that member's rank via ``simp.pcall_static``. The
    difference lies in where that party runs (a trusted execution
    environment), not in this function.

    Leaves of the result that are ``Object`` instances are tagged with
    ``dev_info.name``; other leaves are returned untouched.
    """
    members = dev_info.members
    assert len(members) == 1
    tee_rank = members[0].rank

    outputs = simp.pcall_static((tee_rank,), fn, *args, **kwargs)

    def _tag(leaf: Any) -> Any:
        return _maybe_set_dev_attr(owner, leaf)

    owner = dev_info.name
    return tree_map(_tag, outputs)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _device_run(
    dev_id: str,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Run *fn* on device *dev_id*, moving tagged arguments there first.

    Steps:
      1. Look the device up in the active cluster.
      2. If ``g_auto_trans`` is on, transfer every device-tagged ``Object``
         found in the arguments to *dev_id*; other leaves are left alone.
      3. Dispatch to the runner that matches the device kind
         (case-insensitive): SPU, TEE or PPU.

    Raises:
        DeviceNotFoundError: If *dev_id* is not part of the cluster.
        DeviceError: If the device kind is not SPU, TEE or PPU. This check
            happens after the argument transfer.
    """
    cluster = _resolve_cluster()
    if dev_id not in cluster.devices:
        available = list(cluster.devices.keys())
        raise DeviceNotFoundError(
            f"Device '{dev_id}' not found in cluster. Available devices: {available}"
        )
    dev_info = cluster.devices[dev_id]

    if g_auto_trans:

        def _to_target(leaf: Any) -> Any:
            needs_move = isinstance(leaf, Object) and is_device_obj(leaf)
            return _d2d(dev_id, leaf) if needs_move else leaf

        args, kwargs = tree_map(_to_target, (args, kwargs))

    runners = {
        "SPU": _device_run_spu,
        "TEE": _device_run_tee,
        "PPU": _device_run_ppu,
    }
    runner = runners.get(dev_info.kind.upper())
    if runner is None:
        raise DeviceError(f"Unknown device type: {dev_info.kind}")
    return runner(dev_info, fn, *args, **kwargs)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
class DeviceContext:
    """Decorator factory that binds functions to a device.

    The target device is either given explicitly or, when ``dev_id`` is
    ``None``, inferred from the arguments of each call.

    Examples:
        # Explicit device
        @device("P0")
        def add(a, b): ...

        # Auto-infer device from arguments
        @device()
        def add(a, b): ...

        # JAX frontend via .jax property (recommended for PPU)
        @device("P0").jax
        def add(a, b): return a + b

        # Or use separate decorators (equivalent)
        @device("P0")
        @jax_fn
        def add(a, b): return a + b

        # Inline call style
        result = device("P0").jax(fn)(x, y)
    """

    def __init__(self, dev_id: str | None = None):
        """Remember the target device.

        Args:
            dev_id: Device ID (e.g., "P0", "SP0") or None for auto-inference.
        """
        self.dev_id = dev_id

    def _resolve_device(self, *args: Any, **kwargs: Any) -> str:
        """Return the explicit device ID, or infer one from the call arguments."""
        if self.dev_id is None:
            return _infer_device_from_args(*args, **kwargs)
        return self.dev_id

    def _is_spu_device(self) -> bool:
        """Tell whether the explicit target is an SPU of the active cluster.

        Returns False when no device was given or when the device is unknown
        to the cluster.
        """
        target = self.dev_id
        if target is None:
            return False
        devices = _resolve_cluster().devices
        if target not in devices:
            return False
        return bool(devices[target].kind.upper() == "SPU")

    @property
    def jax(self) -> Callable[[Callable], Callable]:
        """Decorator for JAX functions that should run on this device.

        For an SPU target the function is wrapped as-is, because the SPU
        runner already goes through ``spu.run_jax``. For every other case
        (PPU, TEE, or auto-inferred device) the function is first adapted
        with ``tensor.jax_fn``.

        This is syntax sugar for using the jax_fn adaptor:
            device("P0").jax(fn) == device("P0")(jax_fn(fn))

        Examples:
            # As decorator
            @device("P0").jax
            def add(a, b): return a + b

            # As inline call
            result = device("P0").jax(fn)(x, y)
        """

        def decorate(fn: Callable) -> Callable:
            if not self._is_spu_device():
                # Imported only when the adaptor is actually needed.
                from mplang.v2.dialects.tensor import jax_fn

                fn = jax_fn(fn)
            return self(fn)

        return decorate

    def __call__(self, fn: Callable) -> Callable:
        """Wrap *fn* so that each call is executed on the resolved device."""

        def run_on_device(*args: Any, **kwargs: Any) -> Any:
            target = self._resolve_device(*args, **kwargs)
            return _device_run(target, fn, *args, **kwargs)

        return wraps(fn)(run_on_device)
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def device(dev_id: str | None = None) -> DeviceContext:
    """Build a ``DeviceContext`` for the given device.

    Args:
        dev_id: Device ID (e.g., "P0", "SP0") or None for auto-inference.

    Returns:
        A DeviceContext; use it as a decorator (or call it inline) to run
        functions on the device.

    Usage patterns:
        # Explicit device + generic function
        @device("P0")
        def fn(a, b): ...

        # Auto-infer device from arguments
        @device()
        def fn(a, b): ...

        # JAX frontend via .jax property (recommended for PPU)
        @device("P0").jax
        def add(a, b): return a + b

        # Inline call
        result = device("P0").jax(fn)(x, y)

        # Separate decorators (equivalent to above)
        @device("P0")
        @jax_fn
        def add(a, b): return a + b
    """
    return DeviceContext(dev_id=dev_id)
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def _ensure_tee_session(
    frm_dev_id: str,
    to_dev_id: str,
    frm_rank: int,
    tee_rank: int,
) -> tuple[Object, Object]:
    """Return the session keys shared by a sender and a TEE, creating them if needed.

    A session is looked up in ``_tee_session_cache`` under
    ``(frm_dev_id, to_dev_id)`` and reused only when it was created under the
    current context; an entry from another context is dropped and the
    handshake is run again.

    Handshake (key agreement + remote attestation + HKDF):
      1. The TEE generates a KEM keypair and produces a quote over its
         public key.
      2. The quote is sent to the sender.
      3. The sender attests the quote, obtaining the TEE public key.
      4. The sender generates its own keypair and sends the public key to
         the TEE.
      5. Each side derives the shared secret from its private key and the
         peer's public key.
      6. Each side runs HKDF over the shared secret with ``_TEE_HKDF_INFO``
         so the raw secret is never used as a key and the result is bound to
         this protocol.

    Args:
        frm_dev_id: Source device ID (PPU).
        to_dev_id: Target device ID (TEE).
        frm_rank: Rank of the source party.
        tee_rank: Rank of the TEE party.

    Returns:
        ``(sess_frm, sess_tee)``: the session key Object on the sender and
        its counterpart on the TEE.
    """
    import mplang.v2.edsl as el

    # The cache entry remembers which context it was created under.
    # NOTE(review): the context is identified by id(); the cache keeps no
    # reference to it, so ids could in principle be reused -- confirm that
    # contexts outlive their cached sessions.
    ctx_id = id(el.get_current_context())

    cache_key = (frm_dev_id, to_dev_id)
    cached = _tee_session_cache.get(cache_key)
    if cached is not None:
        cached_ctx_id, sess_frm, sess_tee = cached
        if cached_ctx_id == ctx_id:
            return sess_frm, sess_tee
        # Created under a different context: unusable here.
        del _tee_session_cache[cache_key]

    on_sender = partial(simp.pcall_static, (frm_rank,))
    on_tee = partial(simp.pcall_static, (tee_rank,))

    # 1. TEE keypair and attestation quote over its public key.
    tee_sk, tee_pk = on_tee(crypto.kem_keygen, _TEE_KEM_SUITE)
    quote = on_tee(tee.quote_gen, tee_pk)

    # 2-3. Ship the quote to the sender, which attests it to get the TEE key.
    quote_at_sender = simp.shuffle_static(quote, {frm_rank: tee_rank})
    tee_pk_at_sender = on_sender(tee.attest, quote_at_sender)

    # 4. Sender keypair; its public half goes to the TEE.
    v_sk, v_pk = on_sender(crypto.kem_keygen, _TEE_KEM_SUITE)
    v_pk_at_tee = simp.shuffle_static(v_pk, {tee_rank: frm_rank})

    # 5. Shared secret on both sides; kem_derive takes (private_key, public_key).
    shared_frm = on_sender(crypto.kem_derive, v_sk, tee_pk_at_sender)
    shared_tee = on_tee(crypto.kem_derive, tee_sk, v_pk_at_tee)

    # 6. Session keys via HKDF with the protocol-specific info string.
    sess_frm = on_sender(crypto.hkdf, shared_frm, _TEE_HKDF_INFO)
    sess_tee = on_tee(crypto.hkdf, shared_tee, _TEE_HKDF_INFO)

    _tee_session_cache[cache_key] = (ctx_id, sess_frm, sess_tee)
    return sess_frm, sess_tee
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def _d2d(to_dev_id: str, obj: Object) -> Object:
|
|
496
|
+
"""Transfer object to target device."""
|
|
497
|
+
if not isinstance(obj, Object):
|
|
498
|
+
raise TypeError(f"Expected Object, got {type(obj)}")
|
|
499
|
+
|
|
500
|
+
frm_dev_id = get_dev_attr(obj)
|
|
501
|
+
if frm_dev_id == to_dev_id:
|
|
502
|
+
return obj
|
|
503
|
+
|
|
504
|
+
cluster = _resolve_cluster()
|
|
505
|
+
frm_dev = cluster.devices[frm_dev_id]
|
|
506
|
+
to_dev = cluster.devices[to_dev_id]
|
|
507
|
+
frm_to_pair = (frm_dev.kind.upper(), to_dev.kind.upper())
|
|
508
|
+
|
|
509
|
+
# PPU -> PPU
|
|
510
|
+
if frm_to_pair == ("PPU", "PPU"):
|
|
511
|
+
assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
|
|
512
|
+
to_rank = to_dev.members[0].rank
|
|
513
|
+
frm_rank = frm_dev.members[0].rank
|
|
514
|
+
|
|
515
|
+
var = simp.shuffle_static(obj, {to_rank: frm_rank})
|
|
516
|
+
return set_dev_attr(var, to_dev_id)
|
|
517
|
+
|
|
518
|
+
# PPU -> SPU (Seal)
|
|
519
|
+
elif frm_to_pair == ("PPU", "SPU"):
|
|
520
|
+
assert len(frm_dev.members) == 1
|
|
521
|
+
frm_rank = frm_dev.members[0].rank
|
|
522
|
+
spu_parties = tuple(m.rank for m in to_dev.members)
|
|
523
|
+
spu_config = spu.SPUConfig.from_dict(to_dev.config)
|
|
524
|
+
|
|
525
|
+
# 1. Generate shares on source
|
|
526
|
+
# We call spu.make_shares inside pcall on the source party
|
|
527
|
+
shares_on_source = simp.pcall_static(
|
|
528
|
+
(frm_rank,),
|
|
529
|
+
spu.make_shares,
|
|
530
|
+
spu_config,
|
|
531
|
+
obj,
|
|
532
|
+
count=len(spu_parties),
|
|
533
|
+
)
|
|
534
|
+
|
|
535
|
+
# 2. Distribute shares
|
|
536
|
+
distributed_shares = []
|
|
537
|
+
for i, target_rank in enumerate(spu_parties):
|
|
538
|
+
# Extract i-th share (still on source)
|
|
539
|
+
# shares_on_source is MP[tuple[SS, ...], (frm_rank)]
|
|
540
|
+
# We need to extract the i-th element.
|
|
541
|
+
# Since pcall returns MPType, we can't index it directly if it's a tuple of shares.
|
|
542
|
+
# Wait, pcall returns a PyTree of MPObjects if the function returns a PyTree.
|
|
543
|
+
# So shares_on_source IS a tuple of MPObjects.
|
|
544
|
+
share_i = shares_on_source[i]
|
|
545
|
+
|
|
546
|
+
share_at_target = simp.shuffle_static(share_i, {target_rank: frm_rank})
|
|
547
|
+
distributed_shares.append(share_at_target)
|
|
548
|
+
|
|
549
|
+
# 3. Converge
|
|
550
|
+
var = simp.converge(*distributed_shares)
|
|
551
|
+
return set_dev_attr(var, to_dev_id)
|
|
552
|
+
|
|
553
|
+
# SPU -> PPU (Reveal)
|
|
554
|
+
elif frm_to_pair == ("SPU", "PPU"):
|
|
555
|
+
assert len(to_dev.members) == 1
|
|
556
|
+
to_rank = to_dev.members[0].rank
|
|
557
|
+
spu_parties = tuple(m.rank for m in frm_dev.members)
|
|
558
|
+
spu_config = spu.SPUConfig.from_dict(frm_dev.config)
|
|
559
|
+
|
|
560
|
+
# 1. Gather shares to target
|
|
561
|
+
gathered_shares = []
|
|
562
|
+
for source_rank in spu_parties:
|
|
563
|
+
# Extract share from logical variable
|
|
564
|
+
share_on_source = simp.pcall_static((source_rank,), lambda x: x, obj)
|
|
565
|
+
|
|
566
|
+
# Move to target
|
|
567
|
+
share_at_target = simp.shuffle_static(
|
|
568
|
+
share_on_source, {to_rank: source_rank}
|
|
569
|
+
)
|
|
570
|
+
gathered_shares.append(share_at_target)
|
|
571
|
+
|
|
572
|
+
# 2. Reconstruct on target
|
|
573
|
+
# We call spu.reconstruct inside pcall on the target party
|
|
574
|
+
var = simp.pcall_static(
|
|
575
|
+
(to_rank,), lambda *s: spu.reconstruct(spu_config, s), *gathered_shares
|
|
576
|
+
)
|
|
577
|
+
return set_dev_attr(var, to_dev_id)
|
|
578
|
+
|
|
579
|
+
# SPU -> SPU
|
|
580
|
+
elif frm_to_pair == ("SPU", "SPU"):
|
|
581
|
+
raise NotImplementedError("SPU to SPU transfer not implemented yet.")
|
|
582
|
+
|
|
583
|
+
# PPU -> TEE (Encrypted transfer)
|
|
584
|
+
elif frm_to_pair == ("PPU", "TEE"):
|
|
585
|
+
assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
|
|
586
|
+
frm_rank = frm_dev.members[0].rank
|
|
587
|
+
tee_rank = to_dev.members[0].rank
|
|
588
|
+
|
|
589
|
+
# Establish encrypted session (includes remote attestation)
|
|
590
|
+
sess_frm, sess_tee = _ensure_tee_session(
|
|
591
|
+
frm_dev_id, to_dev_id, frm_rank, tee_rank
|
|
592
|
+
)
|
|
593
|
+
|
|
594
|
+
# Encrypt on sender and send to TEE
|
|
595
|
+
ct = simp.pcall_static((frm_rank,), crypto.sym_encrypt, sess_frm, obj)
|
|
596
|
+
ct_at_tee = simp.shuffle_static(ct, {tee_rank: frm_rank})
|
|
597
|
+
|
|
598
|
+
# Decrypt on TEE
|
|
599
|
+
var = simp.pcall_static(
|
|
600
|
+
(tee_rank,),
|
|
601
|
+
crypto.sym_decrypt,
|
|
602
|
+
sess_tee,
|
|
603
|
+
ct_at_tee,
|
|
604
|
+
obj.type.value_type if hasattr(obj.type, "value_type") else obj.type,
|
|
605
|
+
)
|
|
606
|
+
return set_dev_attr(var, to_dev_id)
|
|
607
|
+
|
|
608
|
+
# TEE -> PPU (Encrypted transfer)
|
|
609
|
+
elif frm_to_pair == ("TEE", "PPU"):
|
|
610
|
+
assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
|
|
611
|
+
tee_rank = frm_dev.members[0].rank
|
|
612
|
+
ppu_rank = to_dev.members[0].rank
|
|
613
|
+
|
|
614
|
+
# Establish encrypted session (reuse existing or create new)
|
|
615
|
+
# Note: We pass (ppu, tee) order to match the session key derivation
|
|
616
|
+
sess_ppu, sess_tee = _ensure_tee_session(
|
|
617
|
+
to_dev_id, frm_dev_id, ppu_rank, tee_rank
|
|
618
|
+
)
|
|
619
|
+
|
|
620
|
+
# Encrypt on TEE and send to PPU
|
|
621
|
+
ct = simp.pcall_static((tee_rank,), crypto.sym_encrypt, sess_tee, obj)
|
|
622
|
+
ct_at_ppu = simp.shuffle_static(ct, {ppu_rank: tee_rank})
|
|
623
|
+
|
|
624
|
+
# Decrypt on PPU
|
|
625
|
+
var = simp.pcall_static(
|
|
626
|
+
(ppu_rank,),
|
|
627
|
+
crypto.sym_decrypt,
|
|
628
|
+
sess_ppu,
|
|
629
|
+
ct_at_ppu,
|
|
630
|
+
obj.type.value_type if hasattr(obj.type, "value_type") else obj.type,
|
|
631
|
+
)
|
|
632
|
+
return set_dev_attr(var, to_dev_id)
|
|
633
|
+
|
|
634
|
+
# TEE -> SPU (TEE acts like a PPU for SPU sealing)
|
|
635
|
+
elif frm_to_pair == ("TEE", "SPU"):
|
|
636
|
+
assert len(frm_dev.members) == 1
|
|
637
|
+
frm_rank = frm_dev.members[0].rank
|
|
638
|
+
spu_parties = tuple(m.rank for m in to_dev.members)
|
|
639
|
+
spu_config = spu.SPUConfig.from_dict(to_dev.config)
|
|
640
|
+
|
|
641
|
+
# Generate shares on TEE (same logic as PPU -> SPU)
|
|
642
|
+
shares_on_source = simp.pcall_static(
|
|
643
|
+
(frm_rank,),
|
|
644
|
+
spu.make_shares,
|
|
645
|
+
spu_config,
|
|
646
|
+
obj,
|
|
647
|
+
count=len(spu_parties),
|
|
648
|
+
)
|
|
649
|
+
|
|
650
|
+
# Distribute shares to SPU parties
|
|
651
|
+
distributed_shares = []
|
|
652
|
+
for i, target_rank in enumerate(spu_parties):
|
|
653
|
+
share_i = shares_on_source[i]
|
|
654
|
+
share_at_target = simp.shuffle_static(share_i, {target_rank: frm_rank})
|
|
655
|
+
distributed_shares.append(share_at_target)
|
|
656
|
+
|
|
657
|
+
# Converge shares
|
|
658
|
+
var = simp.converge(*distributed_shares)
|
|
659
|
+
return set_dev_attr(var, to_dev_id)
|
|
660
|
+
|
|
661
|
+
# SPU -> TEE (Reveal to TEE)
|
|
662
|
+
elif frm_to_pair == ("SPU", "TEE"):
|
|
663
|
+
assert len(to_dev.members) == 1
|
|
664
|
+
to_rank = to_dev.members[0].rank
|
|
665
|
+
spu_parties = tuple(m.rank for m in frm_dev.members)
|
|
666
|
+
spu_config = spu.SPUConfig.from_dict(frm_dev.config)
|
|
667
|
+
|
|
668
|
+
# Gather shares to TEE (same logic as SPU -> PPU)
|
|
669
|
+
gathered_shares = []
|
|
670
|
+
for source_rank in spu_parties:
|
|
671
|
+
share_on_source = simp.pcall_static((source_rank,), lambda x: x, obj)
|
|
672
|
+
share_at_target = simp.shuffle_static(
|
|
673
|
+
share_on_source, {to_rank: source_rank}
|
|
674
|
+
)
|
|
675
|
+
gathered_shares.append(share_at_target)
|
|
676
|
+
|
|
677
|
+
# Reconstruct on TEE
|
|
678
|
+
var = simp.pcall_static(
|
|
679
|
+
(to_rank,), lambda *s: spu.reconstruct(spu_config, s), *gathered_shares
|
|
680
|
+
)
|
|
681
|
+
return set_dev_attr(var, to_dev_id)
|
|
682
|
+
|
|
683
|
+
# TEE -> TEE
|
|
684
|
+
elif frm_to_pair == ("TEE", "TEE"):
|
|
685
|
+
raise NotImplementedError("TEE to TEE transfer not implemented yet.")
|
|
686
|
+
|
|
687
|
+
else:
|
|
688
|
+
raise DeviceError(f"Unsupported device transfer: {frm_to_pair}")
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
def put(to_dev_id: str, obj: Any) -> Object:
    """Place ``obj`` on the device identified by ``to_dev_id``.

    Two paths are handled:

    * ``obj`` is an ``Object`` that already carries a device attribute: the
      call is delegated to ``_d2d`` (device-to-device transfer).
    * anything else is treated as a host value and uploaded according to
      the kind of the target device.

    Args:
        to_dev_id: Target device ID (e.g., "P0", "SP0").
        obj: Device object to move, or host value (e.g. numpy array) to
            upload.

    Returns:
        The object located on the target device.

    Raises:
        DeviceNotFoundError: ``to_dev_id`` is not a device of the resolved
            cluster.
        DeviceError: The target device kind is not PPU, SPU or TEE.
        AssertionError: A PPU/TEE target does not have exactly one member.
    """
    cluster = _resolve_cluster()
    if to_dev_id not in cluster.devices:
        known_ids = list(cluster.devices.keys())
        raise DeviceNotFoundError(
            f"Device '{to_dev_id}' not found in cluster. Available devices: {known_ids}"
        )

    # Device -> Device: the object already lives on some device.
    if isinstance(obj, Object) and is_device_obj(obj):
        return _d2d(to_dev_id, obj)

    # Host -> Device
    target = cluster.devices[to_dev_id]
    kind = target.kind.upper()

    if kind in ("PPU", "TEE"):
        # Both kinds are single-member devices: materialize the host value
        # as a constant on that member's rank and tag it with the device id.
        assert len(target.members) == 1
        placed = simp.constant((target.members[0].rank,), obj)
        return set_dev_attr(placed, to_dev_id)

    if kind == "SPU":
        # Host -> SPU: run an identity function on the SPU device.
        # Note (from the original author): this yields a Public (replicated)
        # value on the SPU; SPU operations promote it to Secret if needed.
        return cast(Object, device(to_dev_id)(lambda x: x)(obj))

    raise DeviceError(f"Cannot put to device kind '{target.kind}'")
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
def fetch(obj: Object) -> Any:
    """Fetch the host-side value of a device object.

    The device is looked up from the object's device attribute.

    * PPU / TEE: the value is fetched from the device's single member rank.
    * SPU: the value held by every member is fetched, but only the first
      one is returned. Reconstruction of the shares is not implemented
      here yet (TODO); reveal the value to a PPU first to obtain plaintext.
    * If the runtime object is not a ``DriverVar`` (or the device kind is
      none of the above), the runtime object itself is returned.

    A ``WrapValue`` result is unwrapped to its ``data`` before returning.

    Args:
        obj: Object with device attribute to fetch.

    Returns:
        Python value (numpy array, scalar, etc.)

    Raises:
        DeviceError: ``obj`` carries no device attribute.
        RuntimeError: The current context is not an ``Interpreter``.
        AssertionError: The "simp" dialect state is not a ``SimpDriver``,
            ``obj`` is not an ``InterpObject``, a PPU/TEE device does not
            have exactly one member, or a per-rank value is not a URI string.
    """
    from mplang.v2.backends.simp_driver.state import SimpDriver
    from mplang.v2.backends.simp_driver.values import DriverVar
    from mplang.v2.edsl.context import get_current_context
    from mplang.v2.runtime.interpreter import InterpObject, Interpreter
    from mplang.v2.runtime.value import WrapValue

    def _unwrap_value(val: Any) -> Any:
        """Return ``val.data`` for a WrapValue, otherwise ``val`` unchanged."""
        return val.data if isinstance(val, WrapValue) else val

    # 1. Only objects tagged with a device can be fetched through this API.
    if not is_device_obj(obj):
        raise DeviceError(
            "Object does not have device attribute. Use mp.fetch() directly."
        )

    # 2. Resolve the device description from the object's device id.
    dev_id = get_dev_attr(obj)
    dev_info = _resolve_cluster().devices[dev_id]

    # An interpreter context is required to reach the simp driver state.
    interp = get_current_context()
    if not isinstance(interp, Interpreter):
        raise RuntimeError("No interpreter context available for fetch")

    simp_state = interp.get_dialect_state("simp")
    assert isinstance(simp_state, SimpDriver), "DriverVar requires simp state"

    # Unwrap InterpObject to get the runtime value.
    assert isinstance(obj, InterpObject), f"Expected InterpObject, got {type(obj)}"
    runtime_obj = obj.runtime_obj

    def _fetch_from_rank(rank: int) -> Any:
        """Resolve the URI stored for ``rank`` through the simp driver."""
        uri = runtime_obj.values[rank]
        assert isinstance(uri, str) and "://" in uri, f"Expected URI, got {uri}"
        return simp_state.fetch(rank, uri).result()

    # 3. Direct value (not DriverVar): nothing to fetch remotely.
    if not isinstance(runtime_obj, DriverVar):
        return _unwrap_value(runtime_obj)

    kind = dev_info.kind.upper()

    # 3.1 PPU/TEE: single member, fetch directly.
    if kind in ("PPU", "TEE"):
        assert len(dev_info.members) == 1
        return _unwrap_value(_fetch_from_rank(dev_info.members[0].rank))

    # 3.2 SPU: fetch from all member ranks.
    if kind == "SPU":
        shares = [_fetch_from_rank(member.rank) for member in dev_info.members]
        # For now only the first share is returned
        # (TODO: implement spu.reconstruct).
        first_share = shares[0] if shares else None
        return _unwrap_value(first_share)

    # Any other device kind: hand back the runtime object as-is.
    return _unwrap_value(runtime_obj)
|