mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -15,16 +15,38 @@
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
17
|
from dataclasses import dataclass
|
|
18
|
-
from typing import Any
|
|
18
|
+
from typing import Any, ClassVar
|
|
19
19
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
import spu.api as spu_api
|
|
22
22
|
import spu.libspu as libspu
|
|
23
23
|
|
|
24
|
-
from mplang.core
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
24
|
+
from mplang.v1.core import (
|
|
25
|
+
BOOL,
|
|
26
|
+
FLOAT32,
|
|
27
|
+
FLOAT64,
|
|
28
|
+
INT8,
|
|
29
|
+
INT16,
|
|
30
|
+
INT32,
|
|
31
|
+
INT64,
|
|
32
|
+
UINT8,
|
|
33
|
+
UINT16,
|
|
34
|
+
UINT32,
|
|
35
|
+
UINT64,
|
|
36
|
+
DType,
|
|
37
|
+
PFunction,
|
|
38
|
+
)
|
|
39
|
+
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
40
|
+
from mplang.v1.kernels.value import (
|
|
41
|
+
TensorValue,
|
|
42
|
+
Value,
|
|
43
|
+
ValueDecodeError,
|
|
44
|
+
ValueProtoBuilder,
|
|
45
|
+
ValueProtoReader,
|
|
46
|
+
register_value,
|
|
47
|
+
)
|
|
48
|
+
from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
|
|
49
|
+
from mplang.v1.runtime.link_comm import LinkCommunicator
|
|
28
50
|
|
|
29
51
|
|
|
30
52
|
def shape_spu_to_np(spu_shape: Any) -> tuple[int, ...]:
|
|
@@ -32,36 +54,106 @@ def shape_spu_to_np(spu_shape: Any) -> tuple[int, ...]:
|
|
|
32
54
|
return tuple(spu_shape.dims)
|
|
33
55
|
|
|
34
56
|
|
|
35
|
-
def
|
|
36
|
-
"""Convert
|
|
57
|
+
def dtype_spu_to_mpl(spu_dtype: libspu.DataType) -> DType:
|
|
58
|
+
"""Convert libspu.DataType to MPLang DType."""
|
|
37
59
|
MAP = {
|
|
38
|
-
libspu.DataType.DT_F32:
|
|
39
|
-
libspu.DataType.DT_F64:
|
|
40
|
-
libspu.DataType.DT_I1:
|
|
41
|
-
libspu.DataType.DT_I8:
|
|
42
|
-
libspu.DataType.DT_U8:
|
|
43
|
-
libspu.DataType.DT_I16:
|
|
44
|
-
libspu.DataType.DT_U16:
|
|
45
|
-
libspu.DataType.DT_I32:
|
|
46
|
-
libspu.DataType.DT_U32:
|
|
47
|
-
libspu.DataType.DT_I64:
|
|
48
|
-
libspu.DataType.DT_U64:
|
|
60
|
+
libspu.DataType.DT_F32: FLOAT32,
|
|
61
|
+
libspu.DataType.DT_F64: FLOAT64,
|
|
62
|
+
libspu.DataType.DT_I1: BOOL,
|
|
63
|
+
libspu.DataType.DT_I8: INT8,
|
|
64
|
+
libspu.DataType.DT_U8: UINT8,
|
|
65
|
+
libspu.DataType.DT_I16: INT16,
|
|
66
|
+
libspu.DataType.DT_U16: UINT16,
|
|
67
|
+
libspu.DataType.DT_I32: INT32,
|
|
68
|
+
libspu.DataType.DT_U32: UINT32,
|
|
69
|
+
libspu.DataType.DT_I64: INT64,
|
|
70
|
+
libspu.DataType.DT_U64: UINT64,
|
|
49
71
|
}
|
|
50
|
-
return MAP[spu_dtype]
|
|
72
|
+
return MAP[spu_dtype]
|
|
51
73
|
|
|
52
74
|
|
|
75
|
+
@register_value
|
|
53
76
|
@dataclass
|
|
54
|
-
class SpuValue:
|
|
55
|
-
"""SPU value container for secure computation."""
|
|
77
|
+
class SpuValue(Value):
|
|
78
|
+
"""SPU value container for secure computation (Value type)."""
|
|
79
|
+
|
|
80
|
+
KIND: ClassVar[str] = "mplang.spu.SpuValue"
|
|
81
|
+
WIRE_VERSION: ClassVar[int] = 1
|
|
56
82
|
|
|
57
83
|
shape: tuple[int, ...]
|
|
58
|
-
dtype:
|
|
84
|
+
dtype: DType # Now uses MPLang's unified DType
|
|
59
85
|
vtype: libspu.Visibility
|
|
60
86
|
share: libspu.Share
|
|
61
87
|
|
|
62
88
|
def __repr__(self) -> str:
|
|
63
89
|
return f"SpuValue({self.shape},{self.dtype},{self.vtype})"
|
|
64
90
|
|
|
91
|
+
def to_proto(self) -> _value_pb2.ValueProto:
|
|
92
|
+
"""Serialize SpuValue to wire format.
|
|
93
|
+
|
|
94
|
+
libspu.Share has two attributes:
|
|
95
|
+
- meta: bytes (protobuf serialized metadata)
|
|
96
|
+
- share_chunks: list[bytes] (the actual secret share data)
|
|
97
|
+
|
|
98
|
+
Strategy: Store shape/dtype/vtype in runtime_attrs, concatenate share.meta + all chunks in payload.
|
|
99
|
+
"""
|
|
100
|
+
# Store metadata in runtime_attrs; keep chunk lengths for payload splitting
|
|
101
|
+
chunk_lengths = [len(chunk) for chunk in self.share.share_chunks]
|
|
102
|
+
|
|
103
|
+
# Payload contains only share chunks (meta stored in attrs)
|
|
104
|
+
payload = b""
|
|
105
|
+
for chunk in self.share.share_chunks:
|
|
106
|
+
payload += chunk
|
|
107
|
+
|
|
108
|
+
return (
|
|
109
|
+
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
|
110
|
+
.set_attr("shape", list(self.shape))
|
|
111
|
+
.set_attr("dtype", self.dtype.name) # Serialize DType name
|
|
112
|
+
.set_attr("vtype", int(self.vtype))
|
|
113
|
+
.set_attr("share_meta", self.share.meta)
|
|
114
|
+
.set_attr("chunk_lengths", chunk_lengths)
|
|
115
|
+
.set_payload(payload)
|
|
116
|
+
.build()
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
@classmethod
|
|
120
|
+
def from_proto(cls, proto: _value_pb2.ValueProto) -> SpuValue:
|
|
121
|
+
"""Deserialize SpuValue from wire format."""
|
|
122
|
+
reader = ValueProtoReader(proto)
|
|
123
|
+
if reader.version != cls.WIRE_VERSION:
|
|
124
|
+
raise ValueDecodeError(f"Unsupported SpuValue version {reader.version}")
|
|
125
|
+
|
|
126
|
+
# Read metadata from runtime_attrs
|
|
127
|
+
shape = tuple(reader.get_attr("shape"))
|
|
128
|
+
dtype_name = reader.get_attr("dtype")
|
|
129
|
+
# Reconstruct DType from serialized name (numpy dtype string)
|
|
130
|
+
dtype = DType.from_numpy(dtype_name)
|
|
131
|
+
vtype = libspu.Visibility(reader.get_attr("vtype"))
|
|
132
|
+
share_meta = reader.get_attr("share_meta")
|
|
133
|
+
chunk_lengths = reader.get_attr("chunk_lengths")
|
|
134
|
+
|
|
135
|
+
# Parse payload: [chunk_0][chunk_1]...
|
|
136
|
+
payload = reader.payload
|
|
137
|
+
offset = 0
|
|
138
|
+
|
|
139
|
+
share_chunks: list[bytes] = []
|
|
140
|
+
for chunk_len in chunk_lengths:
|
|
141
|
+
chunk = payload[offset : offset + chunk_len]
|
|
142
|
+
offset += chunk_len
|
|
143
|
+
share_chunks.append(chunk)
|
|
144
|
+
|
|
145
|
+
# Reconstruct libspu.Share
|
|
146
|
+
share = libspu.Share()
|
|
147
|
+
share.meta = share_meta
|
|
148
|
+
share.share_chunks = share_chunks
|
|
149
|
+
|
|
150
|
+
return cls(
|
|
151
|
+
shape=shape,
|
|
152
|
+
dtype=dtype,
|
|
153
|
+
vtype=vtype,
|
|
154
|
+
share=share,
|
|
155
|
+
)
|
|
156
|
+
|
|
65
157
|
|
|
66
158
|
def _get_spu_config_and_world() -> tuple[libspu.RuntimeConfig, int]:
|
|
67
159
|
kctx = cur_kctx()
|
|
@@ -128,33 +220,25 @@ def _spu_seed_env(pfunc: PFunction, *args: Any) -> Any:
|
|
|
128
220
|
|
|
129
221
|
|
|
130
222
|
@kernel_def("spu.makeshares")
|
|
131
|
-
def _spu_makeshares(pfunc: PFunction,
|
|
132
|
-
"""Create SPU shares from input data.
|
|
133
|
-
|
|
134
|
-
Args:
|
|
135
|
-
pfunc: PFunction containing makeshares metadata
|
|
136
|
-
args: Input data to be shared (single tensor)
|
|
137
|
-
|
|
138
|
-
Returns:
|
|
139
|
-
Tuple of SPU shares (SpuValue), one for each party.
|
|
140
|
-
"""
|
|
141
|
-
assert len(args) == 1
|
|
142
|
-
|
|
223
|
+
def _spu_makeshares(pfunc: PFunction, tensor: TensorValue) -> tuple[SpuValue, ...]:
|
|
224
|
+
"""Create SPU shares from input TensorValue data."""
|
|
143
225
|
visibility_value = pfunc.attrs.get("visibility", libspu.Visibility.VIS_SECRET.value)
|
|
144
226
|
if isinstance(visibility_value, int):
|
|
145
227
|
visibility = libspu.Visibility(visibility_value)
|
|
146
228
|
else:
|
|
147
229
|
visibility = visibility_value
|
|
148
230
|
|
|
149
|
-
arg =
|
|
231
|
+
arg = tensor.to_numpy()
|
|
150
232
|
cfg, world = _get_spu_config_and_world()
|
|
151
233
|
spu_io = spu_api.Io(world, cfg)
|
|
152
234
|
shares = spu_io.make_shares(arg, visibility)
|
|
153
235
|
assert len(shares) == world, f"Expected {world} shares, got {len(shares)}"
|
|
236
|
+
# Store MPLang DType instead of libspu.DataType
|
|
237
|
+
dtype = DType.from_numpy(arg.dtype)
|
|
154
238
|
return tuple(
|
|
155
239
|
SpuValue(
|
|
156
240
|
shape=arg.shape,
|
|
157
|
-
dtype=
|
|
241
|
+
dtype=dtype,
|
|
158
242
|
vtype=visibility,
|
|
159
243
|
share=share,
|
|
160
244
|
)
|
|
@@ -163,24 +247,29 @@ def _spu_makeshares(pfunc: PFunction, *args: Any) -> Any:
|
|
|
163
247
|
|
|
164
248
|
|
|
165
249
|
@kernel_def("spu.reconstruct")
|
|
166
|
-
def _spu_reconstruct(pfunc: PFunction, *
|
|
250
|
+
def _spu_reconstruct(pfunc: PFunction, *shares: SpuValue) -> TensorValue:
|
|
167
251
|
"""Reconstruct plaintext data from SPU shares."""
|
|
168
252
|
cfg, world = _get_spu_config_and_world()
|
|
169
|
-
assert len(
|
|
170
|
-
for i,
|
|
171
|
-
if not isinstance(
|
|
253
|
+
assert len(shares) == world, f"Expected {world} shares, got {len(shares)}"
|
|
254
|
+
for i, share in enumerate(shares):
|
|
255
|
+
if not isinstance(share, SpuValue):
|
|
172
256
|
raise ValueError(
|
|
173
|
-
f"Input {i} must be SpuValue, got {type(
|
|
257
|
+
f"Input {i} must be SpuValue, got {type(share)}. Reconstruction requires SPU shares as input."
|
|
174
258
|
)
|
|
175
|
-
spu_args: list[SpuValue] = list(
|
|
176
|
-
|
|
259
|
+
spu_args: list[SpuValue] = list(shares) # type: ignore
|
|
260
|
+
share_payloads = [spu_arg.share for spu_arg in spu_args]
|
|
177
261
|
spu_io = spu_api.Io(world, cfg)
|
|
178
|
-
reconstructed = spu_io.reconstruct(
|
|
179
|
-
|
|
262
|
+
reconstructed = spu_io.reconstruct(share_payloads)
|
|
263
|
+
base = np.array(reconstructed, copy=False)
|
|
264
|
+
# Respect semantic dtype/shape recorded on shares (all shares share same meta).
|
|
265
|
+
semantic_dtype = shares[0].dtype.to_numpy() # DType now has to_numpy() method
|
|
266
|
+
semantic_shape = shares[0].shape
|
|
267
|
+
restored = np.asarray(base, dtype=semantic_dtype).reshape(semantic_shape)
|
|
268
|
+
return TensorValue(np.array(restored, copy=False))
|
|
180
269
|
|
|
181
270
|
|
|
182
271
|
@kernel_def("spu.run_pphlo")
|
|
183
|
-
def _spu_run_mlir(pfunc: PFunction, *args:
|
|
272
|
+
def _spu_run_mlir(pfunc: PFunction, *args: SpuValue) -> tuple[SpuValue, ...]:
|
|
184
273
|
"""Execute compiled SPU function (spu.run_pphlo) and return SpuValue outputs.
|
|
185
274
|
|
|
186
275
|
Participation rule: a rank participates iff its entry in the stored
|
|
@@ -240,10 +329,10 @@ def _spu_run_mlir(pfunc: PFunction, *args: Any) -> Any:
|
|
|
240
329
|
spu_rt.run(executable)
|
|
241
330
|
shares = [spu_rt.get_var(out_name) for out_name in output_names]
|
|
242
331
|
metas = [spu_rt.get_var_meta(out_name) for out_name in output_names]
|
|
243
|
-
results: list[
|
|
332
|
+
results: list[SpuValue] = [
|
|
244
333
|
SpuValue(
|
|
245
334
|
shape=shape_spu_to_np(meta.shape),
|
|
246
|
-
dtype=
|
|
335
|
+
dtype=dtype_spu_to_mpl(meta.data_type),
|
|
247
336
|
vtype=meta.visibility,
|
|
248
337
|
share=shares[idx],
|
|
249
338
|
)
|
|
@@ -14,16 +14,14 @@
|
|
|
14
14
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
|
-
from
|
|
18
|
-
|
|
19
|
-
from mplang.
|
|
20
|
-
from mplang.kernels.base import kernel_def
|
|
17
|
+
from mplang.v1.core import PFunction
|
|
18
|
+
from mplang.v1.kernels.base import kernel_def
|
|
19
|
+
from mplang.v1.kernels.value import TableValue
|
|
21
20
|
|
|
22
21
|
|
|
23
22
|
@kernel_def("duckdb.run_sql")
|
|
24
|
-
def _duckdb_sql(pfunc: PFunction, *args:
|
|
23
|
+
def _duckdb_sql(pfunc: PFunction, *args: TableValue) -> TableValue:
|
|
25
24
|
import duckdb
|
|
26
|
-
import pandas as pd
|
|
27
25
|
|
|
28
26
|
# TODO: maybe we could translate the sql to duckdb dialect
|
|
29
27
|
# instead of raising an exception
|
|
@@ -36,12 +34,11 @@ def _duckdb_sql(pfunc: PFunction, *args: Any) -> Any:
|
|
|
36
34
|
if in_names is None:
|
|
37
35
|
raise ValueError("duckdb sql missing in_names attr")
|
|
38
36
|
for arg, name in zip(args, in_names, strict=True):
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
return res_df
|
|
37
|
+
# Use Arrow directly for zero-copy data transfer
|
|
38
|
+
arrow_table = arg.to_arrow()
|
|
39
|
+
conn.register(name, arrow_table)
|
|
40
|
+
# Fetch result as Arrow table for consistency
|
|
41
|
+
if pfunc.fn_text is None:
|
|
42
|
+
raise ValueError("SQL function text is None")
|
|
43
|
+
res_arrow = conn.execute(pfunc.fn_text).fetch_arrow_table()
|
|
44
|
+
return TableValue(res_arrow)
|
|
@@ -17,12 +17,14 @@ from __future__ import annotations
|
|
|
17
17
|
from typing import Any
|
|
18
18
|
|
|
19
19
|
import jax
|
|
20
|
+
import jax.extend as jxt
|
|
20
21
|
import jax.numpy as jnp
|
|
21
|
-
|
|
22
|
-
from jax.
|
|
22
|
+
import numpy as np
|
|
23
|
+
from jax._src import compiler
|
|
23
24
|
|
|
24
|
-
from mplang.core
|
|
25
|
-
from mplang.kernels.base import cur_kctx, kernel_def
|
|
25
|
+
from mplang.v1.core import PFunction
|
|
26
|
+
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
27
|
+
from mplang.v1.kernels.value import TensorValue
|
|
26
28
|
|
|
27
29
|
|
|
28
30
|
@kernel_def("mlir.stablehlo")
|
|
@@ -45,11 +47,13 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
|
45
47
|
key = f"stablehlo.compile_cache.{h}"
|
|
46
48
|
compiled = rt.get_state(key)
|
|
47
49
|
if compiled is None:
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
50
|
+
client = jxt.backend.get_backend()
|
|
51
|
+
compile_options = compiler.get_compile_options(num_replicas=1, num_partitions=1)
|
|
52
|
+
|
|
51
53
|
try:
|
|
52
|
-
compiled = client.
|
|
54
|
+
compiled = client.compile_and_load(
|
|
55
|
+
mlir_text, client.devices(), compile_options
|
|
56
|
+
)
|
|
53
57
|
except Exception as e: # pragma: no cover
|
|
54
58
|
raise RuntimeError(f"StableHLO compile failed: {e}") from e
|
|
55
59
|
rt.set_state(key, compiled)
|
|
@@ -61,23 +65,26 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
|
61
65
|
# Filter out arguments that were eliminated by JAX during compilation
|
|
62
66
|
runtime_args = tuple(args[i] for i in keep_indices)
|
|
63
67
|
|
|
64
|
-
|
|
65
|
-
for arg in runtime_args:
|
|
66
|
-
if
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
68
|
+
tensor_args: list[TensorValue] = []
|
|
69
|
+
for idx, arg in enumerate(runtime_args):
|
|
70
|
+
if not isinstance(arg, TensorValue):
|
|
71
|
+
raise TypeError(
|
|
72
|
+
f"StableHLO kernel expects TensorValue inputs, got {type(arg).__name__} at position {idx}"
|
|
73
|
+
)
|
|
74
|
+
tensor_args.append(arg)
|
|
75
|
+
|
|
76
|
+
jax_args = [
|
|
77
|
+
jax.device_put(jnp.asarray(tensor.to_numpy())) for tensor in tensor_args
|
|
78
|
+
]
|
|
71
79
|
|
|
72
80
|
try:
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
flat.extend([jnp.array(a) for a in lst])
|
|
81
|
+
# Execute with the new LoadedExecutable interface
|
|
82
|
+
result = compiled.execute(jax_args)
|
|
83
|
+
|
|
84
|
+
# Use jax.tree_util.tree_flatten to robustly handle any PyTree structure
|
|
85
|
+
flat_results, _ = jax.tree_util.tree_flatten(result)
|
|
86
|
+
flat = [TensorValue(np.asarray(item)) for item in flat_results]
|
|
87
|
+
|
|
81
88
|
return tuple(flat)
|
|
82
89
|
except Exception as e: # pragma: no cover
|
|
83
90
|
raise RuntimeError(f"StableHLO execute failed: {e}") from e
|