mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/_device.py
DELETED
|
@@ -1,602 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
This module provides the device oriented programming interface for MPC.
|
|
17
|
-
|
|
18
|
-
The device oriented programming interface is designed to provide a high-level
|
|
19
|
-
abstraction for the MPC programming. It allows the user to write the program
|
|
20
|
-
in a device-oriented manner, and the runtime will take care of the data
|
|
21
|
-
transformation between devices.
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
from __future__ import annotations
|
|
25
|
-
|
|
26
|
-
from collections.abc import Callable
|
|
27
|
-
from functools import partial, wraps
|
|
28
|
-
from typing import Any, cast
|
|
29
|
-
|
|
30
|
-
from jax.tree_util import tree_map, tree_unflatten
|
|
31
|
-
|
|
32
|
-
from mplang.v1.core import (
|
|
33
|
-
ClusterSpec,
|
|
34
|
-
Device,
|
|
35
|
-
Mask,
|
|
36
|
-
MPContext,
|
|
37
|
-
MPObject,
|
|
38
|
-
TableLike,
|
|
39
|
-
TensorLike,
|
|
40
|
-
cur_ctx,
|
|
41
|
-
peval,
|
|
42
|
-
)
|
|
43
|
-
from mplang.v1.ops import basic, crypto, jax_cc, nnx_cc, spu, tee
|
|
44
|
-
from mplang.v1.ops.base import FeOperation
|
|
45
|
-
from mplang.v1.ops.jax_cc import JaxRunner
|
|
46
|
-
from mplang.v1.simp import mpi
|
|
47
|
-
from mplang.v1.simp.api import run_at
|
|
48
|
-
|
|
49
|
-
# When True, MPObject arguments that live on a different device than the
# target are moved there automatically (see `_device_run`), and device
# inference may pick a secure device (SPU/TEE) for mixed-device arguments.
g_auto_trans: bool = True

# NOTE(review): presumably the HKDF "info" label for TEE session key
# derivation; not referenced in the visible part of this module — confirm.
_HKDF_INFO_LITERAL: str = "mplang/device/tee/v1"
# KEM suite passed to `crypto.kem_keygen` when a TEE session is established.
# Fixed for now; could later come from the TEE device's ClusterSpec config.
_TEE_KEM_SUITE: str = "x25519"
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
# Context-aware session management
|
|
58
|
-
def _get_context_id(ctx: MPContext) -> int:
    """Return an integer identifying ``ctx`` while it is alive.

    The value is Python's object identity (``id``). Two distinct live contexts
    never share it, but an id may be reused once a context has been
    garbage-collected.

    Args:
        ctx: The context object (e.g. a trace or interpreter context).

    Returns:
        The object identity of ``ctx``.
    """
    identity = id(ctx)
    return identity
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
# Key stored in an MPObject's ``attrs`` to record the id of the owning device.
DEVICE_ATTR_NAME = "__device__"


def is_device_obj(obj: Any) -> bool:
    """Tell whether ``obj`` is an MPObject that carries a device tag.

    Args:
        obj: Any value.

    Returns:
        True only when ``obj`` is an MPObject whose ``attrs`` contain
        ``DEVICE_ATTR_NAME``; False for any other value, including
        non-MPObjects.
    """
    return isinstance(obj, MPObject) and DEVICE_ATTR_NAME in obj.attrs
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
def set_dev_attr(obj: MPObject, dev_id: str) -> MPObject:
    """Tag ``obj`` in place as living on device ``dev_id``.

    Args:
        obj: The object to tag; its ``attrs`` mapping is mutated.
        dev_id: Device id to record under ``DEVICE_ATTR_NAME``.

    Returns:
        The same ``obj`` instance, now carrying the device attribute.

    Raises:
        TypeError: If ``obj`` is not an MPObject.
    """
    if isinstance(obj, MPObject):
        obj.attrs[DEVICE_ATTR_NAME] = dev_id
        return obj
    raise TypeError(f"Input must be an instance of MPObject, {obj}")
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
def get_dev_attr(obj: MPObject) -> str:
    """Read the device id recorded on ``obj``.

    Args:
        obj: A device-tagged MPObject.

    Returns:
        The stored device id, converted with ``str``.

    Raises:
        TypeError: If ``obj`` is not an MPObject.
        Lookup error from ``obj.attrs`` if the object was never tagged
        (presumably KeyError, if ``attrs`` is a dict — confirm).
    """
    if isinstance(obj, MPObject):
        device_id = obj.attrs[DEVICE_ATTR_NAME]
        return str(device_id)
    raise TypeError("Input must be an instance of MPObject")
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
def _infer_device_from_args(*args: Any, **kwargs: Any) -> str:
    """Choose the device a function should run on, judging by its arguments.

    All pytree leaves of ``(args, kwargs)`` are inspected; only MPObject leaves
    take part in the decision.

    Rules:
      * Any MPObject without a device tag -> error (tag it with set_dev_attr).
      * No MPObject leaves at all -> error (an explicit device is required).
      * Every MPObject on the same device -> that device.
      * Several devices while ``g_auto_trans`` is False -> error.
      * Several devices while ``g_auto_trans`` is True, by device kind:
          - exactly one SPU and no TEE (PPUs allowed) -> that SPU
          - exactly one TEE and no SPU (PPUs allowed) -> that TEE
          - only PPUs -> error (ambiguous)
          - more than one SPU, or more than one TEE -> error (ambiguous)
          - SPU and TEE together -> error (conflicting secure devices)

    Args:
        *args: Positional arguments of the call being placed.
        **kwargs: Keyword arguments of the call being placed.

    Returns:
        The chosen device id.

    Raises:
        ValueError: When inference is impossible or ambiguous.
    """
    from jax.tree_util import tree_flatten

    leaves, _ = tree_flatten((args, kwargs))

    # Keep only MPObject leaves; each of them must already carry a device tag.
    tagged: list[MPObject] = []
    for leaf in leaves:
        if not isinstance(leaf, MPObject):
            continue
        if not is_device_obj(leaf):
            raise ValueError(
                "MPObject is missing device attribute. "
                "If you're mixing device-level and simp-level code, "
                "use set_dev_attr(obj, 'device_id') to mark the device explicitly."
            )
        tagged.append(leaf)

    if len(tagged) == 0:
        raise ValueError(
            "Cannot infer device: no MPObject arguments found. "
            "Please specify device explicitly using device('device_id')(fn)."
        )

    devices = {get_dev_attr(leaf) for leaf in tagged}
    if len(devices) == 1:
        (only_device,) = devices
        return only_device

    if not g_auto_trans:
        raise ValueError(
            f"Cannot infer device: arguments from multiple devices {devices} "
            "but auto-transfer is disabled (g_auto_trans=False). "
            "Please enable auto-transfer or put all data on same device first."
        )

    # Group the involved device ids by (upper-cased) kind, keeping the order
    # in which they are encountered; unknown kinds are ignored.
    cluster_spec = cur_ctx().cluster_spec
    kind_of = {
        dev_id: cluster_spec.devices[dev_id].kind.upper() for dev_id in devices
    }
    grouped: dict[str, list[str]] = {"SPU": [], "TEE": [], "PPU": []}
    for dev_id, kind in kind_of.items():
        if kind in grouped:
            grouped[kind].append(dev_id)
    spus, tees, ppus = grouped["SPU"], grouped["TEE"], grouped["PPU"]

    if not (spus or tees):
        raise ValueError(
            f"Cannot infer device: arguments from multiple PPU devices {ppus}. "
            "Please specify device explicitly or use put() to consolidate data."
        )
    if len(spus) == 1 and not tees:
        return spus[0]
    if len(tees) == 1 and not spus:
        return tees[0]
    if len(spus) > 1:
        raise ValueError(
            f"Ambiguous device inference: arguments from multiple SPU devices {spus}. "
            "Please specify which SPU to use explicitly."
        )
    if len(tees) > 1:
        raise ValueError(
            f"Ambiguous device inference: arguments from multiple TEE devices {tees}. "
            "Please specify which TEE to use explicitly."
        )
    if spus and tees:
        raise ValueError(
            f"Ambiguous device inference: arguments from both SPU {spus} and TEE {tees}. "
            "Please specify which secure device to use explicitly."
        )

    # Unreachable given the checks above; kept as a defensive fallback.
    raise ValueError(f"Unexpected device configuration: {devices}")
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
def _device_run_spu(
    dev_info: Device, op: FeOperation, fn: Callable, *args: Any, **kwargs: Any
) -> Any:
    """Compile ``fn`` with ``spu.jax_compile`` and evaluate it on an SPU device.

    Args:
        dev_info: The SPU device; its members' ranks form the evaluation mask.
        op: Frontend operation. Only checked to be a JaxRunner; ``fn`` is what
            gets compiled.
        fn: The function to compile for SPU.
        *args: Positional arguments for ``fn``.
        **kwargs: Keyword arguments for ``fn``.

    Returns:
        ``fn``'s output pytree, with every leaf tagged with ``dev_info.name``.

    Raises:
        ValueError: If ``op`` is not a JaxRunner.
    """
    if isinstance(op, JaxRunner):
        member_ranks = [m.rank for m in dev_info.members]
        spu_mask = Mask.from_ranks(member_ranks)
        pfunc, in_vars, out_tree = spu.jax_compile(fn, *args, **kwargs)
        # Compiled inputs are expected to already reside on the SPU parties.
        assert all(v.pmask == spu_mask for v in in_vars), in_vars
        flat_outputs = peval(pfunc, in_vars, spu_mask)
        outputs = tree_unflatten(out_tree, flat_outputs)
        return tree_map(partial(set_dev_attr, dev_id=dev_info.name), outputs)
    raise ValueError("SPU device only supports JAX frontend.")
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
def _device_run_tee(
    dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
) -> Any:
    """Run ``op`` on the single party that hosts a TEE device.

    Args:
        dev_info: The TEE device; it must have exactly one member.
        op: Frontend operation executed via ``run_at`` on that member's rank.
        *args: Positional arguments forwarded to ``op``.
        **kwargs: Keyword arguments forwarded to ``op``.

    Returns:
        The result pytree, with every leaf tagged with ``dev_info.name``.
        ``set_dev_attr`` raises TypeError for leaves that are not MPObjects.

    Raises:
        ValueError: If the device does not have exactly one member.
    """
    # TODO(jint): should we filter out all IO operations?
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would silently run on members[0] of a misconfigured device.
    if len(dev_info.members) != 1:
        raise ValueError(
            f"TEE device {dev_info.name} must have exactly one member, "
            f"got {len(dev_info.members)}."
        )
    rank = dev_info.members[0].rank
    var = run_at(rank, op, *args, **kwargs)
    return tree_map(partial(set_dev_attr, dev_id=dev_info.name), var)
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
def _device_run_ppu(
    dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
) -> Any:
    """Run ``op`` on the single party that backs a PPU device.

    Args:
        dev_info: The PPU device; it must have exactly one member.
        op: Frontend operation executed via ``run_at`` on that member's rank.
        *args: Positional arguments forwarded to ``op``.
        **kwargs: Keyword arguments forwarded to ``op``.

    Returns:
        The result pytree, with every leaf tagged with ``dev_info.name``.
        ``set_dev_attr`` raises TypeError for leaves that are not MPObjects.

    Raises:
        ValueError: If the device does not have exactly one member.
    """
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would silently run on members[0] of a misconfigured device.
    if len(dev_info.members) != 1:
        raise ValueError(
            f"PPU device {dev_info.name} must have exactly one member, "
            f"got {len(dev_info.members)}."
        )
    rank = dev_info.members[0].rank
    var = run_at(rank, op, *args, **kwargs)
    return tree_map(partial(set_dev_attr, dev_id=dev_info.name), var)
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
def _device_run(dev_id: str, op: FeOperation, *args: Any, **kwargs: Any) -> Any:
    """Execute ``op`` on device ``dev_id`` of the current cluster.

    When ``g_auto_trans`` is enabled, every MPObject leaf of the arguments is
    first moved to ``dev_id`` via ``_d2d``. Dispatch then goes to the SPU, TEE
    or PPU runner according to the device kind (case-insensitive).

    Args:
        dev_id: Target device id; must exist in the current cluster spec.
        op: Frontend operation to execute.
        *args: Positional arguments forwarded to the device runner.
        **kwargs: Keyword arguments forwarded to the device runner.

    Returns:
        Whatever the selected device runner returns.

    Raises:
        ValueError: If ``dev_id`` is unknown or its kind is unsupported.
    """
    assert isinstance(op, FeOperation)
    cluster_spec = cur_ctx().cluster_spec
    if dev_id not in cluster_spec.devices:
        raise ValueError(f"Device {dev_id} not found in cluster spec.")
    dev_info = cluster_spec.devices[dev_id]

    if g_auto_trans:

        def _move_here(leaf: Any) -> Any:
            # Non-MPObject leaves (plain Python values etc.) pass through.
            if not isinstance(leaf, MPObject):
                return leaf
            assert is_device_obj(leaf)
            return _d2d(dev_id, leaf)

        args, kwargs = tree_map(_move_here, (args, kwargs))

    runners: dict[str, Callable[..., Any]] = {
        "SPU": _device_run_spu,
        "TEE": _device_run_tee,
        "PPU": _device_run_ppu,
    }
    runner = runners.get(dev_info.kind.upper())
    if runner is None:
        raise ValueError(f"Unknown device type: {dev_info.kind}")
    return runner(dev_info, op, *args, **kwargs)
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
def device(
    dev_or_fn: str | Callable | None = None, *, fe_type: str = "jax"
) -> Callable:
    """Decorator to mark a function to be executed on a specific device.

    Supports both explicit device specification and automatic device inference:

    1. Explicit device placement:
        @device("P0")
        def foo(x, y): return x + y

    2. Auto device inference:
        @device
        def foo(x, y): return x + y
        # Device is inferred from x, y at call time

    3. Inline usage:
        result = device(lambda x, y: x + y)(x_on_p0, y_on_p0)
        # Device is inferred from the arguments

    Args:
        dev_or_fn: A device id string ("P0", "SPU", ...) for explicit placement,
            a callable for auto inference, or None (same as passing nothing).
        fe_type: Frontend used to wrap plain callables, "jax" or "nnx".
            Ignored when the decorated callable is already a FeOperation.
            An unsupported value is only reported when the function is called.

    Returns:
        A decorator (when ``dev_or_fn`` is a string or None) or the wrapped
        function (when ``dev_or_fn`` is callable).

    Raises:
        TypeError: When ``dev_or_fn`` is not a string, callable, or None.
        ValueError: At call time, when the device cannot be inferred, the
            inference is ambiguous, or ``fe_type`` is unsupported.

    Device Inference Strategy:
        - Same device: All arguments on device D -> execute on D
        - PPU + SPU: Arguments from PPU and SPU -> execute on SPU (secure compute)
        - PPU + TEE: Arguments from PPU and TEE -> execute on TEE (trusted execution)
        - Multiple PPUs: Ambiguous -> error (explicit device required)
        - No device objects: Cannot infer -> error (explicit device required)

    Example:
        >>> # Explicit device
        >>> @device("P0")
        ... def add_explicit(x, y):
        ...     return x + y
        >>>
        >>> # Auto inference
        >>> @device
        ... def add_auto(x, y):
        ...     return x + y
        >>>
        >>> x_on_p0 = ...  # data on P0
        >>> y_on_p0 = ...  # data on P0
        >>> result = add_auto(x_on_p0, y_on_p0)  # Inferred to P0
        >>>
        >>> x_on_spu = ...  # data on SPU
        >>> y_on_p1 = ...  # data on P1
        >>> result = add_auto(x_on_spu, y_on_p1)  # Inferred to SPU
    """

    def _execute_on_device(dev_id: str, fn: Callable, *args: Any, **kwargs: Any) -> Any:
        """Run ``fn`` on ``dev_id``, wrapping plain callables in the frontend."""
        if isinstance(fn, FeOperation):
            return _device_run(dev_id, fn, *args, **kwargs)
        if fe_type == "jax":
            return _device_run(dev_id, jax_cc.run_jax, fn, *args, **kwargs)
        if fe_type == "nnx":
            return _device_run(dev_id, nnx_cc.run_nnx, fn, *args, **kwargs)
        raise ValueError(f"Unsupported frontend type: {fe_type}")

    # Case 1: device("P0") - Explicit device specification
    if isinstance(dev_or_fn, str):
        dev_id = dev_or_fn

        def deco(fn: Callable) -> Callable:
            @wraps(fn)
            def wrapped(*args: Any, **kwargs: Any) -> Any:
                return _execute_on_device(dev_id, fn, *args, **kwargs)

            return wrapped

        return deco

    # Case 2: device(fn) or @device - Auto device inference
    elif callable(dev_or_fn):
        fn = dev_or_fn

        @wraps(fn)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            try:
                dev_id = _infer_device_from_args(*args, **kwargs)
            except ValueError as e:
                # Not every callable has __name__ (functools.partial, callable
                # instances); fall back to repr so reporting cannot itself fail
                # and mask the inference error.
                fn_name = getattr(fn, "__name__", repr(fn))
                raise ValueError(
                    f"Cannot infer device for function '{fn_name}': {e!s}"
                ) from e

            return _execute_on_device(dev_id, fn, *args, **kwargs)

        return wrapped

    # Case 3: device() or @device() - Return auto-inference decorator
    elif dev_or_fn is None:

        def deco(fn: Callable) -> Callable:
            return device(fn, fe_type=fe_type)

        return deco

    else:
        # More helpful error message for common mistakes
        raise TypeError(
            f"device() expects a device id (string), a function (callable), or nothing. "
            f"Got: {type(dev_or_fn).__name__}.\n"
            f"Usage:\n"
            f"  - Explicit device: @device('P0') or device('P0')(fn)\n"
            f"  - Auto inference: @device or device(fn)"
        )
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
def _spu_reveal(spu_dev: Device, obj: MPObject, to_mask: Mask) -> MPObject:
    """Reconstruct an SPU-shared value in plaintext on the parties of ``to_mask``.

    Each SPU member broadcasts its share of ``obj`` to ``to_mask``. The
    receivers then combine all shares with ``spu.reconstruct``.

    Args:
        spu_dev: The SPU device currently holding ``obj``.
        obj: The shared object; its pmask must equal the SPU members' mask.
        to_mask: Parties that should receive the revealed value.

    Returns:
        The first output of the reconstruction, evaluated on ``to_mask``.
    """
    spu_mask = Mask.from_ranks([m.rank for m in spu_dev.members])
    assert obj.pmask == spu_mask, (obj.pmask, spu_mask)

    # One entry per SPU party: that party's share, now held by to_mask.
    shares = []
    for src_rank in Mask(spu_mask):
        shares.append(mpi.bcast_m(to_mask, src_rank, obj))
    assert len(shares) == Mask(spu_mask).num_parties(), (shares, spu_mask)
    assert all(s.pmask == to_mask for s in shares)

    pfunc, ins, _ = spu.reconstruct(*shares)
    revealed = peval(pfunc, ins, to_mask)
    return revealed[0]  # type: ignore[no-any-return]
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
def _spu_seal(spu_dev: Device, obj: MPObject) -> list[MPObject]:
    """Secret-share a plaintext object into a specific SPU device.

    Low-level API: the SPU device is passed explicitly to avoid ambiguity.
    Shares are produced with ``spu.makeshares`` (SECRET visibility, one per
    SPU member) and scattered to the SPU members.

    Args:
        spu_dev: Destination SPU device.
        obj: Plaintext object; must have a static (non-None) pmask.

    Returns:
        One ``mpi.scatter_m`` result per rank in ``obj.pmask``, in that
        mask's iteration order.

    Raises:
        ValueError: If ``obj`` has a dynamic (None) pmask.
    """
    if obj.pmask is None:
        raise ValueError("Seal can not apply to dynamic mask objects.")

    member_ranks = [member.rank for member in spu_dev.members]
    spu_mask = Mask.from_ranks(member_ranks)
    n_spu_parties = Mask(spu_mask).num_parties()
    pfunc, ins, _ = spu.makeshares(
        obj, world_size=n_spu_parties, visibility=spu.Visibility.SECRET
    )
    assert len(ins) == 1
    shares = peval(pfunc, ins)

    # Every party currently holding obj scatters its shares to the SPU members.
    scattered = []
    for src_rank in obj.pmask:
        scattered.append(mpi.scatter_m(spu_mask, src_rank, shares))
    return scattered
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
    """Move a device-tagged object onto device ``to_dev_id``.

    The source device is read from ``obj``'s device attribute. Both devices
    are looked up in the current context's cluster spec, and the mechanism
    depends on the (source kind, destination kind) pair:

    * PPU -> PPU: point-to-point send.
    * PPU -> SPU, TEE -> SPU: secret-share (seal) into the SPU.
    * SPU -> PPU, SPU -> TEE: reveal the SPU shares at the target party.
    * PPU -> TEE, TEE -> PPU: pack to bytes, encrypt with a cached TEE session,
      send, then decrypt and unpack at the receiver.

    Args:
        to_dev_id: Destination device id.
        obj: Object carrying a device attribute.

    Returns:
        ``obj`` itself if it already lives on ``to_dev_id``; otherwise the
        transferred object, tagged with ``to_dev_id``.

    Raises:
        NotImplementedError: For SPU -> SPU transfers.
        ValueError: If a PPU/TEE endpoint does not have exactly one member,
            or the kind pair is not supported.
    """
    assert isinstance(obj, MPObject)
    frm_dev_id = get_dev_attr(obj)

    if frm_dev_id == to_dev_id:
        return obj

    cluster_spec: ClusterSpec = cur_ctx().cluster_spec
    frm_dev = cluster_spec.devices[frm_dev_id]
    to_dev = cluster_spec.devices[to_dev_id]
    frm_to_pair = (frm_dev.kind.upper(), to_dev.kind.upper())
    tag = partial(set_dev_attr, dev_id=to_dev_id)

    def _sole_rank(dev: Device) -> int:
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would silently use members[0] of a multi-member device.
        if len(dev.members) != 1:
            raise ValueError(
                f"Device {dev.name} must have exactly one member for a "
                f"{frm_to_pair[0]} -> {frm_to_pair[1]} transfer, "
                f"got {len(dev.members)}."
            )
        return dev.members[0].rank

    if frm_to_pair == ("SPU", "SPU"):
        raise NotImplementedError("Only one SPU is supported for now.")
    elif frm_to_pair == ("SPU", "PPU"):
        to_rank = _sole_rank(to_dev)
        var = _spu_reveal(frm_dev, obj, Mask.from_ranks([to_rank]))
        return tree_map(tag, var)  # type: ignore[no-any-return]
    elif frm_to_pair == ("PPU", "SPU"):
        # Validation only: sealing itself is driven by obj.pmask.
        _sole_rank(frm_dev)
        shares = _spu_seal(to_dev, obj)
        assert len(shares) == 1, "Expected single share from PPU to SPU seal."
        return tree_map(tag, shares[0])  # type: ignore[no-any-return]
    elif frm_to_pair == ("PPU", "PPU"):
        frm_rank = _sole_rank(frm_dev)
        to_rank = _sole_rank(to_dev)
        var = mpi.p2p(frm_rank, to_rank, obj)
        return tree_map(tag, var)  # type: ignore[no-any-return]
    elif frm_to_pair == ("PPU", "TEE"):
        # Transparent handshake + encryption for the first transfer; reuse thereafter
        frm_rank = _sole_rank(frm_dev)
        tee_rank = _sole_rank(to_dev)
        # Ensure sessions (both directions) exist for this PPU<->TEE pair
        sess_p, sess_t = _ensure_tee_session(frm_dev_id, to_dev_id, frm_rank, tee_rank)
        # Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
        obj_ty = obj.mptype.raw_type()
        b = run_at(frm_rank, basic.pack, obj)
        ct = run_at(frm_rank, crypto.enc, b, sess_p)
        ct_at_tee = mpi.p2p(frm_rank, tee_rank, ct)
        b_at_tee = run_at(tee_rank, crypto.dec, ct_at_tee, sess_t)
        pt_at_tee = run_at(tee_rank, basic.unpack, b_at_tee, out_ty=obj_ty)
        return tree_map(tag, pt_at_tee)  # type: ignore[no-any-return]
    elif frm_to_pair == ("TEE", "PPU"):
        # Encryption from TEE to a specific PPU using the reverse-direction session key
        tee_rank = _sole_rank(frm_dev)
        ppu_rank = _sole_rank(to_dev)
        # Ensure bidirectional session established for this pair
        sess_p, sess_t = _ensure_tee_session(to_dev_id, frm_dev_id, ppu_rank, tee_rank)
        obj_ty = obj.mptype.raw_type()
        b = run_at(tee_rank, basic.pack, obj)
        ct = run_at(tee_rank, crypto.enc, b, sess_t)
        ct_at_ppu = mpi.p2p(tee_rank, ppu_rank, ct)
        b_at_ppu = run_at(ppu_rank, crypto.dec, ct_at_ppu, sess_p)
        pt_at_ppu = run_at(ppu_rank, basic.unpack, b_at_ppu, out_ty=obj_ty)
        return tree_map(tag, pt_at_ppu)  # type: ignore[no-any-return]
    elif frm_to_pair == ("TEE", "SPU"):
        # Validation only: sealing itself is driven by obj.pmask.
        _sole_rank(frm_dev)
        shares = _spu_seal(to_dev, obj)
        assert len(shares) == 1, "Expected single share from TEE to SPU seal."
        return tree_map(tag, shares[0])  # type: ignore[no-any-return]
    elif frm_to_pair == ("SPU", "TEE"):
        # NOTE(review): unlike the PPU <-> TEE paths, no TEE session encryption
        # is applied here; all shares reach the TEE rank via _spu_reveal.
        # Confirm this matches the intended threat model.
        to_rank = _sole_rank(to_dev)
        var = _spu_reveal(frm_dev, obj, Mask.from_ranks([to_rank]))
        return tree_map(tag, var)  # type: ignore[no-any-return]
    else:
        supported = [
            ("SPU", "PPU"),
            ("PPU", "SPU"),
            ("PPU", "PPU"),
            ("PPU", "TEE"),
            ("TEE", "PPU"),
            ("TEE", "SPU"),
            ("SPU", "TEE"),
        ]
        raise ValueError(
            f"Unsupported device transfer: {frm_to_pair}. Supported pairs: {supported}."
        )
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
def _ensure_tee_session(
    frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int
) -> tuple[MPObject, MPObject]:
    """Return the (sender-side, TEE-side) session keys for a device pair.

    Session keys are cached on the root context in a ``_tee_sessions`` dict.
    The dict is keyed by ``(frm_dev_id, to_dev_id)``, and each entry is tagged
    with the id of the context that traced it. A cached pair is reused only
    when the current context id matches. This keeps traced session objects
    from leaking across TraceContext instances. An entry from another context
    is evicted and the handshake is traced again.

    Handshake (traced in exactly this order):
      1. TEE generates a KEM keypair and a quote over its public key.
      2. The quote is sent to the sender, which attests it to obtain the TEE pk.
      3. The sender generates an ephemeral KEM keypair and sends its pk to TEE.
      4. Each side runs KEM derive, then HKDF with the same fixed info literal.

    Args:
        frm_dev_id: Device id of the non-TEE peer (first half of the cache key).
        to_dev_id: Device id of the TEE side (second half of the cache key).
            NOTE(review): the TEE->PPU branch passes the PPU id first and the
            TEE id second; confirm that every caller follows this ordering.
        frm_rank: Rank of the non-TEE peer; ``sess_p`` lives here.
        tee_rank: Rank of the TEE; ``sess_t`` lives here.

    Returns:
        ``(sess_p, sess_t)``: the session key at ``frm_rank`` and at
        ``tee_rank`` respectively.
    """
    ctx = cur_ctx()
    ctx_id = _get_context_id(ctx)

    # The cache lives on the root context and is created lazily on first use.
    root = ctx.root()
    if not hasattr(root, "_tee_sessions"):
        root._tee_sessions = {}  # type: ignore[attr-defined]
    sessions: dict[tuple[str, str], tuple[int, MPObject, MPObject]] = (
        root._tee_sessions  # type: ignore[attr-defined]
    )

    pair = (frm_dev_id, to_dev_id)
    if pair in sessions:
        owner_ctx_id, cached_p, cached_t = sessions[pair]
        if owner_ctx_id == ctx_id:
            return cached_p, cached_t
        # The entry was traced in a different context, so its values must not
        # be reused here. Drop it and run the handshake again below.
        del sessions[pair]

    # 1) TEE creates its KEM keypair and a quote over the public key.
    #    The KEM suite is a module constant for now; it could later be read
    #    from the TEE device config.
    tee_sk, tee_pk = run_at(tee_rank, crypto.kem_keygen, _TEE_KEM_SUITE)
    tee_quote = run_at(tee_rank, tee.quote_gen, tee_pk)

    # 2) Ship the quote to the sender, which attests it to recover the TEE pk.
    sender_quote = mpi.p2p(tee_rank, frm_rank, tee_quote)
    sender_tee_pk = run_at(frm_rank, tee.attest, sender_quote)

    # 3) Sender creates an ephemeral keypair and ships its public key to TEE.
    peer_sk, peer_pk = run_at(frm_rank, crypto.kem_keygen, _TEE_KEM_SUITE)
    tee_peer_pk = mpi.p2p(frm_rank, tee_rank, peer_pk)

    # 4) Each side derives the shared secret and expands it with HKDF. Both
    #    sides use the same fixed ASCII info literal so they end up with
    #    matching session keys.
    secret_p = run_at(
        frm_rank, crypto.kem_derive, peer_sk, sender_tee_pk, _TEE_KEM_SUITE
    )
    secret_t = run_at(
        tee_rank, crypto.kem_derive, tee_sk, tee_peer_pk, _TEE_KEM_SUITE
    )
    sess_p = run_at(frm_rank, crypto.hkdf, secret_p, _HKDF_INFO_LITERAL)
    sess_t = run_at(tee_rank, crypto.hkdf, secret_t, _HKDF_INFO_LITERAL)

    # Remember the pair together with the context that produced it.
    sessions[pair] = (ctx_id, sess_p, sess_t)
    return sess_p, sess_t
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
def _host_to_device(to_dev_id: str, obj: Any) -> MPObject:
    """Place a host-side (non-MPObject) value onto device ``to_dev_id``.

    Tensor-like values are routed through an identity function executed on
    the target device. Table-like values become a ``basic.constant`` on the
    single member rank of a PPU or TEE device, and are tagged with the device
    id.

    Args:
        to_dev_id: Id of the destination device in the current cluster spec.
        obj: Host value to place; must be ``TensorLike`` or ``TableLike``.

    Returns:
        The value as an ``MPObject`` associated with ``to_dev_id``.

    Raises:
        ValueError: If a table is placed on a device whose kind is not PPU or
            TEE, or on a device that does not have exactly one member.
        TypeError: If ``obj`` is neither ``TensorLike`` nor ``TableLike``.
    """
    if isinstance(obj, TensorLike):
        # Run an identity function on the target device to put the tensor there.
        return device(to_dev_id)(lambda x: x)(obj)  # type: ignore[no-any-return]

    if isinstance(obj, TableLike):
        dev_info = cur_ctx().cluster_spec.devices[to_dev_id]
        if dev_info.kind.upper() not in ("PPU", "TEE"):
            raise ValueError(
                f"TableLike put() only supports PPU or TEE devices, got {dev_info.kind}"
            )
        # Validate explicitly rather than with `assert`: this is a check on
        # cluster configuration, and asserts are stripped under `python -O`.
        # With the assert gone, a multi-member device would silently use the
        # first member.
        members = dev_info.members
        if len(members) != 1:
            raise ValueError(
                f"TableLike put() expects device {to_dev_id!r} to have exactly "
                f"one member, got {len(members)}"
            )
        rank = members[0].rank
        obj_mp = cast(MPObject, run_at(rank, basic.constant, obj))
        set_dev_attr(obj_mp, to_dev_id)
        return obj_mp

    raise TypeError(
        f"put() only supports TensorLike or TableLike objects, got {type(obj)}"
    )
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
def put(to_dev_id: str, obj: Any) -> MPObject:
    """Place ``obj`` on the device identified by ``to_dev_id``.

    An ``MPObject`` is forwarded to the device-to-device transfer path
    (``_d2d``). Any other value is treated as a host value and handed to
    ``_host_to_device``.

    Args:
        to_dev_id: Id of the destination device.
        obj: Either an existing ``MPObject`` or a host value.

    Returns:
        The resulting ``MPObject`` on ``to_dev_id``.

    Raises:
        Errors raised by the helpers propagate unchanged, for example
        ``TypeError`` from ``_host_to_device`` for unsupported host values.
    """
    if isinstance(obj, MPObject):
        return _d2d(to_dev_id, obj)
    return _host_to_device(to_dev_id, obj)
|
mplang/v1/analysis/__init__.py
DELETED
|
@@ -1,37 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""Analysis and visualization utilities for mplang.
|
|
16
|
-
|
|
17
|
-
This subpackage hosts non-core developer aids: diagram rendering, IR dumps,
|
|
18
|
-
profiling helpers (future), etc.
|
|
19
|
-
"""
|
|
20
|
-
|
|
21
|
-
from mplang.v1.analysis.diagram import (
|
|
22
|
-
DumpResult,
|
|
23
|
-
FlowchartOptions,
|
|
24
|
-
SequenceDiagramOptions,
|
|
25
|
-
dump,
|
|
26
|
-
to_flowchart,
|
|
27
|
-
to_sequence_diagram,
|
|
28
|
-
)
|
|
29
|
-
|
|
30
|
-
__all__ = [
|
|
31
|
-
"DumpResult",
|
|
32
|
-
"FlowchartOptions",
|
|
33
|
-
"SequenceDiagramOptions",
|
|
34
|
-
"dump",
|
|
35
|
-
"to_flowchart",
|
|
36
|
-
"to_sequence_diagram",
|
|
37
|
-
]
|