mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/core/__init__.py
DELETED
|
@@ -1,92 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Core components for multi-party computation.
|
|
17
|
-
|
|
18
|
-
This package provides the fundamental building blocks for multi-party computation,
|
|
19
|
-
including type systems, tracing mechanisms, and interpreter contexts.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
# Core type system
|
|
23
|
-
# Communication interfaces & core symbols
|
|
24
|
-
from mplang.core.comm import (
|
|
25
|
-
CollectiveMixin,
|
|
26
|
-
CommunicatorBase,
|
|
27
|
-
ICollective,
|
|
28
|
-
ICommunicator,
|
|
29
|
-
)
|
|
30
|
-
from mplang.core.dtype import DType
|
|
31
|
-
from mplang.core.interp import InterpContext, InterpVar
|
|
32
|
-
from mplang.core.mask import Mask
|
|
33
|
-
from mplang.core.mpobject import MPContext, MPObject
|
|
34
|
-
from mplang.core.mptype import MPType, Rank, Shape
|
|
35
|
-
from mplang.core.pfunc import PFunction
|
|
36
|
-
from mplang.core.primitive import (
|
|
37
|
-
constant,
|
|
38
|
-
debug_print,
|
|
39
|
-
function,
|
|
40
|
-
pconv,
|
|
41
|
-
peval,
|
|
42
|
-
prand,
|
|
43
|
-
prank,
|
|
44
|
-
pshfl,
|
|
45
|
-
pshfl_s,
|
|
46
|
-
psize,
|
|
47
|
-
set_mask,
|
|
48
|
-
uniform_cond,
|
|
49
|
-
while_loop,
|
|
50
|
-
)
|
|
51
|
-
from mplang.core.table import TableLike, TableType
|
|
52
|
-
from mplang.core.tensor import TensorLike, TensorType
|
|
53
|
-
from mplang.core.tracer import TraceContext, TracedFunction, TraceVar, VarNamer, trace
|
|
54
|
-
|
|
55
|
-
# Public API of ``mplang.core``: every name listed here is imported at the
# top of this module and re-exported (e.g. for ``from mplang.core import *``).
__all__ = [
    "CollectiveMixin",
    "CommunicatorBase",
    "DType",
    "ICollective",
    "ICommunicator",
    "InterpContext",
    "InterpVar",
    "MPContext",
    "MPObject",
    "MPType",
    "Mask",
    "PFunction",
    "Rank",
    "Shape",
    "TableLike",
    "TableType",
    "TensorLike",
    "TensorType",
    "TraceContext",
    "TraceVar",
    "TracedFunction",
    "VarNamer",
    "constant",
    "debug_print",
    "function",
    "pconv",
    "peval",
    "prand",
    "prank",
    "pshfl",
    "pshfl_s",
    "psize",
    "set_mask",
    "trace",
    "uniform_cond",
    "while_loop",
]
|
mplang/device.py
DELETED
|
@@ -1,340 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
This module provides the device oriented programming interface for MPC.
|
|
17
|
-
|
|
18
|
-
The device oriented programming interface is designed to provide a high-level
|
|
19
|
-
abstraction for the MPC programming. It allows the user to write the program
|
|
20
|
-
in a device-oriented manner, and the runtime will take care of the data
|
|
21
|
-
transformation between devices.
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
from __future__ import annotations
|
|
25
|
-
|
|
26
|
-
from collections.abc import Callable
|
|
27
|
-
from functools import partial, wraps
|
|
28
|
-
from typing import Any
|
|
29
|
-
|
|
30
|
-
from jax.tree_util import tree_map
|
|
31
|
-
|
|
32
|
-
import mplang.api as mapi
|
|
33
|
-
from mplang import simp
|
|
34
|
-
from mplang.core import InterpContext, MPObject, primitive
|
|
35
|
-
from mplang.core.cluster import ClusterSpec, Device
|
|
36
|
-
from mplang.core.context_mgr import cur_ctx
|
|
37
|
-
from mplang.core.tensor import TensorType
|
|
38
|
-
from mplang.ops import builtin, crypto, ibis_cc, jax_cc, tee
|
|
39
|
-
from mplang.ops.base import FeOperation
|
|
40
|
-
from mplang.ops.ibis_cc import IbisCompiler
|
|
41
|
-
from mplang.ops.jax_cc import JaxCompiler
|
|
42
|
-
from mplang.simp import mpi, smpc
|
|
43
|
-
|
|
44
|
-
# Automatic transfer between devices when parameter is not on the target device.
# Read at call time by ``_device_run``; when False, arguments are handed to the
# target device runner unchanged (no ``_d2d`` transfer is attempted).
g_auto_trans: bool = True

# HKDF ``info`` string used by both the PPU side and the TEE side when deriving
# the session key in ``_ensure_tee_session``; both sides pass the same value.
_HKDF_INFO_LITERAL: str = "mplang/device/tee/v1"
# Default KEM suite for TEE session establishment; make configurable via ClusterSpec in future.
_TEE_KEM_SUITE: str = "x25519"


# `function` decorator could also compile device-level apis.
function = primitive.function

# magic attribute name to mark a MPObject as a device object
# (the key in ``MPObject.attrs``; its value is the device id string).
DEVICE_ATTR_NAME = "_devid_"
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def _is_device_obj(obj: Any) -> bool:
    """Return True iff ``obj`` is an MPObject carrying a device-id attribute."""
    return isinstance(obj, MPObject) and DEVICE_ATTR_NAME in obj.attrs
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
def _set_devid(obj: MPObject, dev_id: str) -> MPObject:
    """Tag ``obj`` in place with ``dev_id`` and hand back the same object.

    Raises:
        TypeError: if ``obj`` is not an MPObject.
    """
    if isinstance(obj, MPObject):
        obj.attrs[DEVICE_ATTR_NAME] = dev_id
        return obj
    raise TypeError(f"Input must be an instance of Object, {obj}")
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def _get_devid(obj: MPObject) -> str:
    """Return the device id previously stored on ``obj`` by ``_set_devid``.

    Raises:
        TypeError: if ``obj`` is not an MPObject.
        The lookup error of ``obj.attrs`` if no device id was ever set.
    """
    if isinstance(obj, MPObject):
        attrs = obj.attrs
        return attrs[DEVICE_ATTR_NAME]  # type: ignore[no-any-return]
    raise TypeError("Input must be an instance of Object")
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
def _is_mpobj(x: Any) -> bool:
    """Return True if ``x`` is an MPObject.

    Defined with ``def`` rather than a lambda assigned to a name (PEP 8
    E731), so it has a proper ``__name__`` in tracebacks and can carry a
    docstring and annotations.
    """
    return isinstance(x, MPObject)
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
def _device_run_spu(
    dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
) -> Any:
    """Run a JAX function on an SPU device through ``smpc.srun``.

    The first positional argument is the user function; the remaining
    positional arguments and all keyword arguments are forwarded to it.
    Every output leaf is tagged with the SPU device's name.

    Raises:
        ValueError: if ``op`` is not a JaxCompiler.
    """
    if isinstance(op, JaxCompiler):
        user_fn, *fn_args = args
        outputs = smpc.srun(user_fn)(*fn_args, **kwargs)
        tag = partial(_set_devid, dev_id=dev_info.name)
        return tree_map(tag, outputs)
    raise ValueError("SPU device only supports JAX frontend.")
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
def _device_run_tee(
    dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
) -> Any:
    """Run a JAX or Ibis operation on the single party hosting a TEE device.

    Outputs are tagged with the TEE device's name.

    Raises:
        ValueError: if ``op`` is neither a JaxCompiler nor an IbisCompiler.
    """
    if not isinstance(op, (JaxCompiler, IbisCompiler)):
        raise ValueError("TEE device only supports JAX and Ibis frontend.")
    assert len(dev_info.members) == 1
    tee_rank = dev_info.members[0].rank
    outputs = simp.runAt(tee_rank, op)(*args, **kwargs)
    tag = partial(_set_devid, dev_id=dev_info.name)
    return tree_map(tag, outputs)
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
def _device_run_ppu(
    dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
) -> Any:
    """Run ``op`` on the single party that makes up a PPU device.

    No frontend restriction is applied here; outputs are tagged with the
    PPU device's name.
    """
    assert len(dev_info.members) == 1
    owner_rank = dev_info.members[0].rank
    outputs = simp.runAt(owner_rank, op)(*args, **kwargs)
    tag = partial(_set_devid, dev_id=dev_info.name)
    return tree_map(tag, outputs)
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
def _device_run(dev_id: str, op: FeOperation, *args: Any, **kwargs: Any) -> Any:
    """Execute ``op`` on device ``dev_id`` of the current cluster.

    When ``g_auto_trans`` is enabled, every MPObject leaf in ``args`` and
    ``kwargs`` (each must already carry a device id) is moved to ``dev_id``
    with ``_d2d`` first; non-MPObject leaves pass through untouched. The
    device kind (case-insensitive) then selects the SPU, TEE or PPU runner.

    Raises:
        ValueError: if ``dev_id`` is not in the cluster spec, or its kind is
            not one of SPU / TEE / PPU.
    """
    assert isinstance(op, FeOperation)
    cluster_spec = mapi.cur_ctx().cluster_spec
    if dev_id not in cluster_spec.devices:
        raise ValueError(f"Device {dev_id} not found in cluster spec.")

    if g_auto_trans:

        def _move_to_target(leaf: Any) -> Any:
            if not _is_mpobj(leaf):
                return leaf
            assert _is_device_obj(leaf)
            return _d2d(dev_id, leaf)

        args, kwargs = tree_map(_move_to_target, (args, kwargs))

    dev_info = cluster_spec.devices[dev_id]
    runners = {
        "SPU": _device_run_spu,
        "TEE": _device_run_tee,
        "PPU": _device_run_ppu,
    }
    runner = runners.get(dev_info.kind.upper())
    if runner is None:
        raise ValueError(f"Unknown device type: {dev_info.kind}")
    return runner(dev_info, op, *args, **kwargs)
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
def device(dev_id: str, *, fe_type: str = "jax") -> Callable:
    """Decorator to mark a function to be executed on a specific device.

    Args:
        dev_id: Id of the target device in the current cluster spec.
        fe_type: Frontend used to compile a plain Python callable, either
            ``"jax"`` (default) or ``"ibis"``. Ignored when the decorated
            object is already a FeOperation. An unsupported value raises
            ``ValueError`` when the wrapped function is called, not when it
            is decorated.

    Example:
        >>> @device("P0")
        ... def foo(x, y):
        ...     return x + y
    """

    def deco(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            # Already-compiled operations run directly on the device.
            if isinstance(fn, FeOperation):
                return _device_run(dev_id, fn, *args, **kwargs)
            # Otherwise pick the frontend compiler; ``fn`` becomes its first argument.
            if fe_type == "jax":
                compiler = jax_cc.jax_compile
            elif fe_type == "ibis":
                compiler = ibis_cc.ibis_compile
            else:
                raise ValueError(f"Unsupported frontend type: {fe_type}")
            return _device_run(dev_id, compiler, fn, *args, **kwargs)

        return wrapped

    return deco
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
    """Move device object ``obj`` to device ``to_dev_id``.

    The source device is read from ``obj``'s device-id attribute. If it
    already equals ``to_dev_id``, ``obj`` is returned unchanged. Otherwise the
    strategy is selected by the (source kind, target kind) pair, compared
    case-insensitively:

    - SPU -> PPU: ``smpc.revealTo`` the PPU's single member.
    - PPU -> SPU: ``smpc.sealFrom`` the PPU's single member.
    - PPU -> PPU: plain ``mpi.p2p`` between the two members.
    - PPU -> TEE and TEE -> PPU: pack to bytes, encrypt with a session key
      from ``_ensure_tee_session``, ``p2p``, decrypt, and unpack with the
      ``TensorType`` of ``obj``.

    Every output leaf is tagged with ``to_dev_id``.

    Raises:
        NotImplementedError: for SPU -> SPU transfers.
        ValueError: if a TEE device lacks a ``platform`` config entry, or the
            device-kind pair is not supported.
    """
    assert isinstance(obj, MPObject)
    frm_dev_id = _get_devid(obj)

    # Already on the target device: nothing to move.
    if frm_dev_id == to_dev_id:
        return obj

    cluster_spec: ClusterSpec = mapi.cur_ctx().cluster_spec
    frm_dev = cluster_spec.devices[frm_dev_id]
    to_dev = cluster_spec.devices[to_dev_id]
    frm_to_pair = (frm_dev.kind.upper(), to_dev.kind.upper())

    if frm_to_pair == ("SPU", "SPU"):
        raise NotImplementedError("Only one SPU is supported for now.")
    elif frm_to_pair == ("SPU", "PPU"):
        # The PPU must have exactly one member; reveal the secret to it.
        assert len(to_dev.members) == 1
        to_rank = to_dev.members[0].rank
        var = smpc.revealTo(obj, to_rank)
        return tree_map(partial(_set_devid, dev_id=to_dev_id), var)  # type: ignore[no-any-return]
    elif frm_to_pair == ("PPU", "SPU"):
        # Secret-share the single PPU member's value into the SPU.
        assert len(frm_dev.members) == 1
        frm_rank = frm_dev.members[0].rank
        var = smpc.sealFrom(obj, frm_rank)
        return tree_map(partial(_set_devid, dev_id=to_dev_id), var)  # type: ignore[no-any-return]
    elif frm_to_pair == ("PPU", "PPU"):
        # Plaintext point-to-point copy between two single-member devices.
        assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
        frm_rank = frm_dev.members[0].rank
        to_rank = to_dev.members[0].rank
        var = mpi.p2p(frm_rank, to_rank, obj)
        return tree_map(partial(_set_devid, dev_id=to_dev_id), var)  # type: ignore[no-any-return]
    elif frm_to_pair == ("PPU", "TEE"):
        # Transparent handshake + encryption for the first transfer; reuse thereafter
        assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
        frm_rank = frm_dev.members[0].rank
        tee_rank = to_dev.members[0].rank
        platform = to_dev.config.get("platform")
        if not platform:
            raise ValueError(
                f"TEE device '{to_dev_id}' is missing 'platform' in its config."
            )
        # Get (or establish and cache) the session for this PPU<->TEE pair:
        # sess_p is held by the PPU, sess_t by the TEE.
        sess_p, sess_t = _ensure_tee_session(
            frm_dev_id, to_dev_id, frm_rank, tee_rank, platform
        )
        # Bytes-only path: pack -> enc -> p2p -> dec -> unpack (with static out type)
        # NOTE(review): the output type comes from TensorType.from_obj(obj);
        # unclear from here whether table-typed objects are supported — confirm.
        obj_ty = TensorType.from_obj(obj)
        b = simp.runAt(frm_rank, builtin.pack)(obj)
        ct = simp.runAt(frm_rank, crypto.enc)(b, sess_p)
        ct_at_tee = mpi.p2p(frm_rank, tee_rank, ct)
        b_at_tee = simp.runAt(tee_rank, crypto.dec)(ct_at_tee, sess_t)
        pt_at_tee = simp.runAt(tee_rank, builtin.unpack)(b_at_tee, out_ty=obj_ty)
        return tree_map(partial(_set_devid, dev_id=to_dev_id), pt_at_tee)  # type: ignore[no-any-return]
    elif frm_to_pair == ("TEE", "PPU"):
        # Transparent encryption from the TEE to a specific PPU. The session is
        # looked up with (PPU id, TEE id) as the key — the same session a
        # PPU -> TEE transfer from this PPU uses. The TEE encrypts with sess_t
        # and the PPU decrypts with sess_p.
        assert len(frm_dev.members) == 1 and len(to_dev.members) == 1
        tee_rank = frm_dev.members[0].rank
        ppu_rank = to_dev.members[0].rank
        platform = frm_dev.config.get("platform")
        if not platform:
            raise ValueError(
                f"TEE device '{frm_dev_id}' is missing 'platform' in its config."
            )
        # Arguments are passed PPU-first to match the PPU -> TEE cache key.
        sess_p, sess_t = _ensure_tee_session(
            to_dev_id, frm_dev_id, ppu_rank, tee_rank, platform
        )
        obj_ty = TensorType.from_obj(obj)
        b = simp.runAt(tee_rank, builtin.pack)(obj)
        ct = simp.runAt(tee_rank, crypto.enc)(b, sess_t)
        ct_at_ppu = mpi.p2p(tee_rank, ppu_rank, ct)
        b_at_ppu = simp.runAt(ppu_rank, crypto.dec)(ct_at_ppu, sess_p)
        pt_at_ppu = simp.runAt(ppu_rank, builtin.unpack)(b_at_ppu, out_ty=obj_ty)
        return tree_map(partial(_set_devid, dev_id=to_dev_id), pt_at_ppu)  # type: ignore[no-any-return]
    else:
        supported = [
            ("SPU", "PPU"),
            ("PPU", "SPU"),
            ("PPU", "PPU"),
            ("PPU", "TEE"),
            ("TEE", "PPU"),
        ]
        raise ValueError(
            f"Unsupported device transfer: {frm_to_pair}. Supported pairs: {supported}."
        )
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
def _ensure_tee_session(
    frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int, platform: str
) -> tuple[MPObject, MPObject]:
    """Ensure a TEE session (sess_p at sender, sess_t at TEE) exists.

    The first call for a given ``(frm_dev_id, to_dev_id)`` pair traces a
    handshake:

    1. The TEE generates a KEM keypair and a quote over its public key.
    2. The quote is sent to the sender, which attests it for ``platform`` to
       obtain the TEE public key.
    3. The sender generates its own KEM keypair and sends its public key to
       the TEE.
    4. Both sides derive the shared secret and run HKDF with
       ``_HKDF_INFO_LITERAL`` to obtain their session key.

    The resulting pair is cached in the ``_tee_sessions`` attribute of the
    root context and returned as-is on later calls with the same key.

    Args:
        frm_dev_id: Device id of the non-TEE (sender) side; first half of the
            cache key. Callers in ``_d2d`` always pass the PPU id here.
        to_dev_id: Device id of the TEE; second half of the cache key.
        frm_rank: Rank of the sender party.
        tee_rank: Rank of the TEE party.
        platform: TEE platform identifier passed to ``tee.attest``.

    Returns (sess_p, sess_t).
    """
    # NOTE(review): ``root()`` presumably returns the outermost context, so
    # nested contexts share one cache — confirm.
    ctx = cur_ctx().root()
    if not hasattr(ctx, "_tee_sessions"):
        ctx._tee_sessions = {}  # type: ignore[attr-defined]
    cache: dict[tuple[str, str], tuple[MPObject, MPObject]] = ctx._tee_sessions  # type: ignore

    # NOTE(review): the key is only (frm_dev_id, to_dev_id); frm_rank,
    # tee_rank and platform are not part of it, so a cached session is reused
    # even if those arguments differ on a later call.
    key = (frm_dev_id, to_dev_id)
    if key in cache:
        return cache[key]

    # 1) TEE generates (sk, pk) and quote(pk)
    # KEM suite currently constant; future: read from tee device config (e.g. cluster_spec.devices[dev_id].config)
    tee_sk, tee_pk = simp.runAt(tee_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
    quote = simp.runAt(tee_rank, tee.quote_gen)(tee_pk)

    # 2) Send quote to sender and attest to obtain TEE pk
    quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
    tee_pk_at_sender = simp.runAt(frm_rank, tee.attest)(quote_at_sender, platform)

    # 3) Sender generates its ephemeral keypair and sends its pk to TEE
    v_sk, v_pk = simp.runAt(frm_rank, crypto.kem_keygen)(_TEE_KEM_SUITE)
    v_pk_at_tee = mpi.p2p(frm_rank, tee_rank, v_pk)

    # 4) Both sides derive the shared secret and session key
    shared_p = simp.runAt(frm_rank, crypto.kem_derive)(
        v_sk, tee_pk_at_sender, _TEE_KEM_SUITE
    )
    shared_t = simp.runAt(tee_rank, crypto.kem_derive)(
        tee_sk, v_pk_at_tee, _TEE_KEM_SUITE
    )
    # Use a fixed ASCII string literal for HKDF info on both sides
    sess_p = simp.runAt(frm_rank, crypto.hkdf)(shared_p, _HKDF_INFO_LITERAL)
    sess_t = simp.runAt(tee_rank, crypto.hkdf)(shared_t, _HKDF_INFO_LITERAL)

    cache[key] = (sess_p, sess_t)
    return sess_p, sess_t
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
def put(to_dev_id: str, obj: Any) -> MPObject:
    """Place ``obj`` on device ``to_dev_id``.

    An MPObject is transferred with ``_d2d``. Any other value is turned into
    a device object by running an identity function on the target device
    through ``device``.
    """
    if isinstance(obj, MPObject):
        return _d2d(to_dev_id, obj)
    return device(to_dev_id)(lambda x: x)(obj)  # type: ignore[no-any-return]
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
def _fetch(interp: InterpContext, obj: MPObject) -> Any:
    """Fetch the concrete value of device object ``obj`` from ``interp``.

    - SPU: the secret is first revealed, then the value held by the device's
      first member is returned.
    - PPU / TEE: the device must have exactly one member; the value at that
      member's rank is returned.

    Raises:
        ValueError: if the device kind is not SPU, PPU or TEE.
    """
    dev_id = _get_devid(obj)
    dev_info = interp.cluster_spec.devices[dev_id]
    dev_kind = dev_info.kind.upper()

    if dev_kind == "SPU":
        revealed = mapi.evaluate(interp, smpc.reveal, obj)
        result = mapi.fetch(interp, revealed)
        # After reveal all members hold the same value; take the first member's.
        return result[dev_info.members[0].rank]
    if dev_kind in ("PPU", "TEE"):
        # Single-member devices: read the value at that member's rank.
        assert len(dev_info.members) == 1
        rank = dev_info.members[0].rank
        result = mapi.fetch(interp, obj)
        return result[rank]
    # The device id itself is valid (the lookup above succeeded); it is the
    # kind that is unsupported, so report both.
    raise ValueError(f"Unknown device kind: {dev_info.kind} (device id: {dev_id})")
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
def fetch(interp: InterpContext | None, objs: Any) -> Any:
    """Fetch concrete values for every device object in the pytree ``objs``.

    Args:
        interp: Interpreter context to fetch from. If None, the current
            context (``mapi.cur_ctx()``) is used instead.
        objs: A pytree whose leaves are device objects.

    Returns:
        A pytree with the same structure holding the fetched values.

    Raises:
        TypeError: if the resolved context is not an InterpContext.
    """
    # Explicit None check: do not rely on the truthiness of a context object.
    ctx = interp if interp is not None else mapi.cur_ctx()
    # Raise instead of assert so the check survives ``python -O``.
    if not isinstance(ctx, InterpContext):
        raise TypeError(f"Expect InterpContext, got {ctx}")
    return tree_map(partial(_fetch, ctx), objs)
|
mplang/kernels/builtin.py
DELETED
|
@@ -1,207 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from typing import Any
|
|
18
|
-
|
|
19
|
-
import numpy as np
|
|
20
|
-
import pandas as pd
|
|
21
|
-
|
|
22
|
-
from mplang.core.pfunc import PFunction
|
|
23
|
-
from mplang.core.table import TableType
|
|
24
|
-
from mplang.core.tensor import TensorType
|
|
25
|
-
from mplang.kernels.base import cur_kctx, kernel_def
|
|
26
|
-
from mplang.runtime.data_providers import get_provider, resolve_uri
|
|
27
|
-
from mplang.utils import table_utils
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def _to_numpy(obj: Any) -> np.ndarray:
    """Best-effort conversion of ``obj`` into a NumPy array.

    NumPy arrays are returned unchanged. Objects exposing a ``numpy()``
    method are converted through it; if that path raises, the generic
    ``np.asarray(obj)`` conversion is used instead.
    """
    if isinstance(obj, np.ndarray):
        return obj
    if not hasattr(obj, "numpy"):
        return np.asarray(obj)
    try:
        return np.asarray(obj.numpy())  # type: ignore
    except Exception:
        # Deliberate fallback: the numpy() hook failed, use plain asarray.
        return np.asarray(obj)
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
@kernel_def("builtin.identity")
def _identity(pfunc: PFunction, value: Any) -> Any:
    """Pass the single input through unchanged; ``pfunc`` is not consulted."""
    # The runtime supplies exactly one argument, so no arity check is done here.
    result = value
    return result
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
@kernel_def("builtin.read")
def _read(pfunc: PFunction) -> Any:
    """Load a value from the resource named by the ``path`` attribute.

    The path is resolved to a URI whose scheme selects a data provider; the
    provider then reads the data as ``pfunc.outs_info[0]`` using the current
    kernel context.

    Raises:
        ValueError: if the ``path`` attribute is missing.
        NotImplementedError: if no provider handles the URI scheme.
        RuntimeError: wrapping any error raised by the provider.
    """
    raw_path = pfunc.attrs.get("path")
    if raw_path is None:
        raise ValueError("missing path attr for builtin.read")
    expected_type = pfunc.outs_info[0]
    uri = resolve_uri(str(raw_path))
    provider = get_provider(uri.scheme)
    if provider is None:
        raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
    kctx = cur_kctx()
    try:
        return provider.read(uri, expected_type, ctx=kctx)
    except Exception as exc:  # pragma: no cover - provider errors
        raise RuntimeError(f"builtin.read failed: {exc}") from exc
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
@kernel_def("builtin.write")
def _write(pfunc: PFunction, obj: Any) -> Any:
    """Persist ``obj`` to the resource named by the ``path`` attribute.

    The value is written through the provider for the URI's scheme and then
    returned unchanged, so the kernel can be chained.

    Raises:
        ValueError: if the ``path`` attribute is missing.
        NotImplementedError: if no provider handles the URI scheme.
        RuntimeError: wrapping any error raised by the provider.
    """
    raw_path = pfunc.attrs.get("path")
    if raw_path is None:
        raise ValueError("missing path attr for builtin.write")
    uri = resolve_uri(str(raw_path))
    provider = get_provider(uri.scheme)
    if provider is None:
        raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
    kctx = cur_kctx()
    try:
        provider.write(uri, obj, ctx=kctx)
    except Exception as exc:  # pragma: no cover
        raise RuntimeError(f"builtin.write failed: {exc}") from exc
    return obj
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
@kernel_def("builtin.constant")
def _constant(pfunc: PFunction) -> Any:
    """Materialize the constant stored in the ``data_bytes`` attribute.

    Table outputs must use the ``"bytes[csv]"`` data format and are parsed
    into a DataFrame. Tensor outputs reinterpret the bytes as the output
    dtype and reshape them to the output shape.

    Raises:
        ValueError: if ``data_bytes`` is missing, or a table constant uses an
            unsupported format.
    """
    payload = pfunc.attrs.get("data_bytes")
    if payload is None:
        raise ValueError("missing data_bytes attr for builtin.constant")
    out_t = pfunc.outs_info[0]
    fmt = pfunc.attrs.get("data_format")

    if not isinstance(out_t, TableType):
        # Tensor: raw bytes -> typed, shaped ndarray.
        shape = out_t.shape  # type: ignore[attr-defined,union-attr]
        dtype = out_t.dtype.numpy_dtype()  # type: ignore[attr-defined,union-attr]
        return np.frombuffer(payload, dtype=dtype).reshape(shape)

    if fmt != "bytes[csv]":
        raise ValueError(f"unsupported table constant format {fmt}")
    return table_utils.csv_to_dataframe(payload)
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
@kernel_def("builtin.rank")
def _rank(pfunc: PFunction) -> Any:
    """Return the current party's rank as a 0-d ``uint64`` array."""
    my_rank = cur_kctx().rank
    return np.array(my_rank, dtype=np.uint64)
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
@kernel_def("builtin.prand")
def _prand(pfunc: PFunction) -> Any:
    """Draw uniform random ``uint64`` values of the requested shape.

    The shape comes from the ``shape`` attribute (a scalar when it is
    absent). Values cover the whole ``uint64`` range with both ends
    inclusive, drawn from a generator created without an explicit seed.
    """
    out_shape = pfunc.attrs.get("shape", ())
    generator = np.random.default_rng()
    bounds = np.iinfo(np.uint64)
    return generator.integers(
        low=bounds.min,
        high=bounds.max,
        size=out_shape,
        dtype=np.uint64,
        endpoint=True,
    )
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
@kernel_def("builtin.table_to_tensor")
def _table_to_tensor(pfunc: PFunction, table: Any) -> Any:
    """Stack a DataFrame's columns side by side into one NumPy matrix.

    Raises:
        TypeError: if ``table`` is not a pandas DataFrame.
        ValueError: if the table has no columns.
    """
    if not isinstance(table, pd.DataFrame):
        raise TypeError("expected pandas DataFrame")
    n_cols = table.shape[1]
    if n_cols == 0:
        raise ValueError("cannot pack empty table")
    column_arrays = [table[name].to_numpy() for name in table.columns]
    return np.column_stack(column_arrays)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
@kernel_def("builtin.tensor_to_table")
def _tensor_to_table(pfunc: PFunction, tensor: Any) -> Any:
    """Wrap a rank-2 array as a DataFrame named by the ``column_names`` attr.

    Raises:
        ValueError: if the array is not rank-2 or ``column_names`` is missing.
    """
    matrix = _to_numpy(tensor)
    if matrix.ndim != 2:
        raise ValueError("tensor_to_table expects rank-2 array")
    names = pfunc.attrs.get("column_names")
    if names is None:
        raise ValueError("missing column_names attr")
    return pd.DataFrame(matrix, columns=list(names))
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
def _summ(v: Any) -> str:
    """Render a short, human-readable preview of ``v`` for debug output.

    DataFrames show at most their first 8 rows without the index. Other
    values go through ``_to_numpy`` and NumPy's summarizing formatter. Any
    rendering failure yields a placeholder string instead of raising.
    """
    try:
        if not isinstance(v, pd.DataFrame):
            arr = _to_numpy(v)
            rendered = np.array2string(
                arr, threshold=64, edgeitems=3, precision=6, suppress_small=True
            )
            return str(rendered)
        return str(v.head(8).to_string(index=False))
    except Exception as e:  # pragma: no cover
        return f"<unprintable {type(v).__name__}: {e}>"
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
@kernel_def("builtin.debug_print")
def _debug_print(pfunc: PFunction, val: Any) -> Any:
    """Print a rank-tagged preview of ``val`` to stdout and return ``val``."""
    kctx = cur_kctx()
    prefix = pfunc.attrs.get("prefix", "")
    line = f"[debug_print][rank={kctx.rank}] {prefix}{_summ(val)}"
    print(line)
    return val
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
@kernel_def("builtin.pack")
def _pack(pfunc: PFunction, value: Any) -> Any:
    """Serialize ``value`` into a flat ``uint8`` array.

    DataFrames are encoded as CSV bytes. Any other value goes through
    ``_to_numpy`` and its C-order raw buffer is exposed as bytes.

    Raises:
        ValueError: if ``pfunc`` does not declare exactly one output.
        TypeError: if that output is not a TensorType with ``uint8`` dtype.
    """
    outs_info = pfunc.outs_info
    if len(outs_info) != 1:
        raise ValueError("builtin.pack expects single output type")
    out_ty = outs_info[0]
    if not isinstance(out_ty, TensorType):
        raise TypeError("builtin.pack must return TensorType")
    if out_ty.dtype.numpy_dtype() != np.uint8:
        raise TypeError("builtin.pack output dtype must be uint8")

    raw = (
        table_utils.dataframe_to_csv(value)
        if isinstance(value, pd.DataFrame)
        else _to_numpy(value).tobytes(order="C")
    )
    return np.frombuffer(raw, dtype=np.uint8)
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
@kernel_def("builtin.unpack")
def _unpack(pfunc: PFunction, packed: Any) -> Any:
    """Decode a byte array produced by ``builtin.pack`` into the output type.

    Tensor outputs need a fully static shape and exactly
    ``prod(shape) * itemsize`` bytes. Table outputs are parsed as CSV.

    Raises:
        ValueError: if ``pfunc`` does not declare exactly one output, the
            tensor shape is dynamic, or the byte count does not match.
        TypeError: if the output type is neither TensorType nor TableType.
    """
    outs_info = pfunc.outs_info
    if len(outs_info) != 1:
        raise ValueError("builtin.unpack expects single output type")
    out_ty = outs_info[0]

    flat = np.asarray(packed, dtype=np.uint8).reshape(-1)

    if isinstance(out_ty, TensorType):
        np_dtype = out_ty.dtype.numpy_dtype()
        shape = tuple(out_ty.shape)
        if any(dim < 0 for dim in shape):
            raise ValueError("builtin.unpack does not support dynamic tensor shapes")
        n_bytes = int(np.prod(shape)) * np.dtype(np_dtype).itemsize
        if flat.size != n_bytes:
            raise ValueError(
                f"unpack size mismatch: got {flat.size} bytes, expect {n_bytes} for {np_dtype} {shape}"
            )
        return np.frombuffer(flat.tobytes(), dtype=np_dtype).reshape(shape)

    if isinstance(out_ty, TableType):
        return table_utils.csv_to_dataframe(flat.tobytes())

    raise TypeError("builtin.unpack output type must be TensorType or TableType")
|