mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/kernels/spu.py
DELETED
|
@@ -1,341 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from dataclasses import dataclass
|
|
18
|
-
from typing import Any, ClassVar
|
|
19
|
-
|
|
20
|
-
import numpy as np
|
|
21
|
-
import spu.api as spu_api
|
|
22
|
-
import spu.libspu as libspu
|
|
23
|
-
|
|
24
|
-
from mplang.v1.core import (
|
|
25
|
-
BOOL,
|
|
26
|
-
FLOAT32,
|
|
27
|
-
FLOAT64,
|
|
28
|
-
INT8,
|
|
29
|
-
INT16,
|
|
30
|
-
INT32,
|
|
31
|
-
INT64,
|
|
32
|
-
UINT8,
|
|
33
|
-
UINT16,
|
|
34
|
-
UINT32,
|
|
35
|
-
UINT64,
|
|
36
|
-
DType,
|
|
37
|
-
PFunction,
|
|
38
|
-
)
|
|
39
|
-
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
40
|
-
from mplang.v1.kernels.value import (
|
|
41
|
-
TensorValue,
|
|
42
|
-
Value,
|
|
43
|
-
ValueDecodeError,
|
|
44
|
-
ValueProtoBuilder,
|
|
45
|
-
ValueProtoReader,
|
|
46
|
-
register_value,
|
|
47
|
-
)
|
|
48
|
-
from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
|
|
49
|
-
from mplang.v1.runtime.link_comm import LinkCommunicator
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def shape_spu_to_np(spu_shape: Any) -> tuple[int, ...]:
|
|
53
|
-
"""Convert SPU shape to numpy tuple."""
|
|
54
|
-
return tuple(spu_shape.dims)
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
def dtype_spu_to_mpl(spu_dtype: libspu.DataType) -> DType:
|
|
58
|
-
"""Convert libspu.DataType to MPLang DType."""
|
|
59
|
-
MAP = {
|
|
60
|
-
libspu.DataType.DT_F32: FLOAT32,
|
|
61
|
-
libspu.DataType.DT_F64: FLOAT64,
|
|
62
|
-
libspu.DataType.DT_I1: BOOL,
|
|
63
|
-
libspu.DataType.DT_I8: INT8,
|
|
64
|
-
libspu.DataType.DT_U8: UINT8,
|
|
65
|
-
libspu.DataType.DT_I16: INT16,
|
|
66
|
-
libspu.DataType.DT_U16: UINT16,
|
|
67
|
-
libspu.DataType.DT_I32: INT32,
|
|
68
|
-
libspu.DataType.DT_U32: UINT32,
|
|
69
|
-
libspu.DataType.DT_I64: INT64,
|
|
70
|
-
libspu.DataType.DT_U64: UINT64,
|
|
71
|
-
}
|
|
72
|
-
return MAP[spu_dtype]
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
@register_value
|
|
76
|
-
@dataclass
|
|
77
|
-
class SpuValue(Value):
|
|
78
|
-
"""SPU value container for secure computation (Value type)."""
|
|
79
|
-
|
|
80
|
-
KIND: ClassVar[str] = "mplang.spu.SpuValue"
|
|
81
|
-
WIRE_VERSION: ClassVar[int] = 1
|
|
82
|
-
|
|
83
|
-
shape: tuple[int, ...]
|
|
84
|
-
dtype: DType # Now uses MPLang's unified DType
|
|
85
|
-
vtype: libspu.Visibility
|
|
86
|
-
share: libspu.Share
|
|
87
|
-
|
|
88
|
-
def __repr__(self) -> str:
|
|
89
|
-
return f"SpuValue({self.shape},{self.dtype},{self.vtype})"
|
|
90
|
-
|
|
91
|
-
def to_proto(self) -> _value_pb2.ValueProto:
|
|
92
|
-
"""Serialize SpuValue to wire format.
|
|
93
|
-
|
|
94
|
-
libspu.Share has two attributes:
|
|
95
|
-
- meta: bytes (protobuf serialized metadata)
|
|
96
|
-
- share_chunks: list[bytes] (the actual secret share data)
|
|
97
|
-
|
|
98
|
-
Strategy: Store shape/dtype/vtype in runtime_attrs, concatenate share.meta + all chunks in payload.
|
|
99
|
-
"""
|
|
100
|
-
# Store metadata in runtime_attrs; keep chunk lengths for payload splitting
|
|
101
|
-
chunk_lengths = [len(chunk) for chunk in self.share.share_chunks]
|
|
102
|
-
|
|
103
|
-
# Payload contains only share chunks (meta stored in attrs)
|
|
104
|
-
payload = b""
|
|
105
|
-
for chunk in self.share.share_chunks:
|
|
106
|
-
payload += chunk
|
|
107
|
-
|
|
108
|
-
return (
|
|
109
|
-
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
|
110
|
-
.set_attr("shape", list(self.shape))
|
|
111
|
-
.set_attr("dtype", self.dtype.name) # Serialize DType name
|
|
112
|
-
.set_attr("vtype", int(self.vtype))
|
|
113
|
-
.set_attr("share_meta", self.share.meta)
|
|
114
|
-
.set_attr("chunk_lengths", chunk_lengths)
|
|
115
|
-
.set_payload(payload)
|
|
116
|
-
.build()
|
|
117
|
-
)
|
|
118
|
-
|
|
119
|
-
@classmethod
|
|
120
|
-
def from_proto(cls, proto: _value_pb2.ValueProto) -> SpuValue:
|
|
121
|
-
"""Deserialize SpuValue from wire format."""
|
|
122
|
-
reader = ValueProtoReader(proto)
|
|
123
|
-
if reader.version != cls.WIRE_VERSION:
|
|
124
|
-
raise ValueDecodeError(f"Unsupported SpuValue version {reader.version}")
|
|
125
|
-
|
|
126
|
-
# Read metadata from runtime_attrs
|
|
127
|
-
shape = tuple(reader.get_attr("shape"))
|
|
128
|
-
dtype_name = reader.get_attr("dtype")
|
|
129
|
-
# Reconstruct DType from serialized name (numpy dtype string)
|
|
130
|
-
dtype = DType.from_numpy(dtype_name)
|
|
131
|
-
vtype = libspu.Visibility(reader.get_attr("vtype"))
|
|
132
|
-
share_meta = reader.get_attr("share_meta")
|
|
133
|
-
chunk_lengths = reader.get_attr("chunk_lengths")
|
|
134
|
-
|
|
135
|
-
# Parse payload: [chunk_0][chunk_1]...
|
|
136
|
-
payload = reader.payload
|
|
137
|
-
offset = 0
|
|
138
|
-
|
|
139
|
-
share_chunks: list[bytes] = []
|
|
140
|
-
for chunk_len in chunk_lengths:
|
|
141
|
-
chunk = payload[offset : offset + chunk_len]
|
|
142
|
-
offset += chunk_len
|
|
143
|
-
share_chunks.append(chunk)
|
|
144
|
-
|
|
145
|
-
# Reconstruct libspu.Share
|
|
146
|
-
share = libspu.Share()
|
|
147
|
-
share.meta = share_meta
|
|
148
|
-
share.share_chunks = share_chunks
|
|
149
|
-
|
|
150
|
-
return cls(
|
|
151
|
-
shape=shape,
|
|
152
|
-
dtype=dtype,
|
|
153
|
-
vtype=vtype,
|
|
154
|
-
share=share,
|
|
155
|
-
)
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
def _get_spu_config_and_world() -> tuple[libspu.RuntimeConfig, int]:
|
|
159
|
-
kctx = cur_kctx()
|
|
160
|
-
cfg = kctx.runtime.get_state("spu.config")
|
|
161
|
-
world = kctx.runtime.get_state("spu.world")
|
|
162
|
-
if cfg is None or world is None:
|
|
163
|
-
raise RuntimeError("SPU kernel state not initialized (config/world)")
|
|
164
|
-
return cfg, int(world)
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
def _register_spu_env(
|
|
168
|
-
config: libspu.RuntimeConfig, world_size: int, link_ctx: LinkCommunicator | None
|
|
169
|
-
) -> None:
|
|
170
|
-
"""Register SPU config/world/link inside current kernel context.
|
|
171
|
-
|
|
172
|
-
Idempotent: if config/world already set, they must match; link is recorded per rank.
|
|
173
|
-
This replaces previous global fallback seeding logic.
|
|
174
|
-
"""
|
|
175
|
-
kctx = cur_kctx()
|
|
176
|
-
prev_cfg = kctx.runtime.get_state("spu.config")
|
|
177
|
-
prev_world = kctx.runtime.get_state("spu.world")
|
|
178
|
-
if prev_cfg is None:
|
|
179
|
-
kctx.runtime.set_state("spu.config", config)
|
|
180
|
-
kctx.runtime.set_state("spu.world", world_size)
|
|
181
|
-
else:
|
|
182
|
-
# libspu RuntimeConfig may not implement __eq__; compare serialized repr
|
|
183
|
-
same_cfg = (
|
|
184
|
-
prev_cfg.SerializeToString() == config.SerializeToString() # type: ignore[attr-defined]
|
|
185
|
-
if hasattr(prev_cfg, "SerializeToString")
|
|
186
|
-
and hasattr(config, "SerializeToString")
|
|
187
|
-
else prev_cfg == config
|
|
188
|
-
)
|
|
189
|
-
if not (same_cfg and prev_world == world_size):
|
|
190
|
-
raise RuntimeError("Conflicting SPU env registration")
|
|
191
|
-
# Store single link per runtime (one runtime per rank)
|
|
192
|
-
if link_ctx is not None:
|
|
193
|
-
kctx.runtime.set_state("spu.link", link_ctx)
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
@kernel_def("spu.seed_env")
|
|
197
|
-
def _spu_seed_env(pfunc: PFunction, *args: Any) -> Any:
|
|
198
|
-
"""Backend kernel to seed SPU environment.
|
|
199
|
-
|
|
200
|
-
NOTE: This is a control-plane style operation (side-effect: installs SPU
|
|
201
|
-
config/link into the per-runtime state pocket) rather than a pure data
|
|
202
|
-
transformation. It remains a kernel temporarily for minimal surface
|
|
203
|
-
changes during the backend deglobalization refactor. Callers MUST invoke
|
|
204
|
-
it explicitly via `runtime.run_kernel(seed_pfunc, [])`, never through
|
|
205
|
-
`Evaluator.evaluate` (fast-path removed) to keep IR evaluation semantics
|
|
206
|
-
clean. A future cleanup may promote this to a dedicated runtime helper
|
|
207
|
-
(e.g. `seed_spu_env(runtime, config, world, link)`), at which point this
|
|
208
|
-
kernel can be deprecated.
|
|
209
|
-
|
|
210
|
-
Required attrs: config (RuntimeConfig), world (int)
|
|
211
|
-
Optional attr: link (LinkCommunicator or None)
|
|
212
|
-
"""
|
|
213
|
-
cfg = pfunc.attrs.get("config")
|
|
214
|
-
world = pfunc.attrs.get("world")
|
|
215
|
-
link_ctx = pfunc.attrs.get("link", None)
|
|
216
|
-
if cfg is None or world is None:
|
|
217
|
-
raise ValueError("spu.seed_env requires 'config' and 'world' attrs")
|
|
218
|
-
_register_spu_env(cfg, int(world), link_ctx)
|
|
219
|
-
return None
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
@kernel_def("spu.makeshares")
|
|
223
|
-
def _spu_makeshares(pfunc: PFunction, tensor: TensorValue) -> tuple[SpuValue, ...]:
|
|
224
|
-
"""Create SPU shares from input TensorValue data."""
|
|
225
|
-
visibility_value = pfunc.attrs.get("visibility", libspu.Visibility.VIS_SECRET.value)
|
|
226
|
-
if isinstance(visibility_value, int):
|
|
227
|
-
visibility = libspu.Visibility(visibility_value)
|
|
228
|
-
else:
|
|
229
|
-
visibility = visibility_value
|
|
230
|
-
|
|
231
|
-
arg = tensor.to_numpy()
|
|
232
|
-
cfg, world = _get_spu_config_and_world()
|
|
233
|
-
spu_io = spu_api.Io(world, cfg)
|
|
234
|
-
shares = spu_io.make_shares(arg, visibility)
|
|
235
|
-
assert len(shares) == world, f"Expected {world} shares, got {len(shares)}"
|
|
236
|
-
# Store MPLang DType instead of libspu.DataType
|
|
237
|
-
dtype = DType.from_numpy(arg.dtype)
|
|
238
|
-
return tuple(
|
|
239
|
-
SpuValue(
|
|
240
|
-
shape=arg.shape,
|
|
241
|
-
dtype=dtype,
|
|
242
|
-
vtype=visibility,
|
|
243
|
-
share=share,
|
|
244
|
-
)
|
|
245
|
-
for share in shares
|
|
246
|
-
)
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
@kernel_def("spu.reconstruct")
|
|
250
|
-
def _spu_reconstruct(pfunc: PFunction, *shares: SpuValue) -> TensorValue:
|
|
251
|
-
"""Reconstruct plaintext data from SPU shares."""
|
|
252
|
-
cfg, world = _get_spu_config_and_world()
|
|
253
|
-
assert len(shares) == world, f"Expected {world} shares, got {len(shares)}"
|
|
254
|
-
for i, share in enumerate(shares):
|
|
255
|
-
if not isinstance(share, SpuValue):
|
|
256
|
-
raise ValueError(
|
|
257
|
-
f"Input {i} must be SpuValue, got {type(share)}. Reconstruction requires SPU shares as input."
|
|
258
|
-
)
|
|
259
|
-
spu_args: list[SpuValue] = list(shares) # type: ignore
|
|
260
|
-
share_payloads = [spu_arg.share for spu_arg in spu_args]
|
|
261
|
-
spu_io = spu_api.Io(world, cfg)
|
|
262
|
-
reconstructed = spu_io.reconstruct(share_payloads)
|
|
263
|
-
base = np.array(reconstructed, copy=False)
|
|
264
|
-
# Respect semantic dtype/shape recorded on shares (all shares share same meta).
|
|
265
|
-
semantic_dtype = shares[0].dtype.to_numpy() # DType now has to_numpy() method
|
|
266
|
-
semantic_shape = shares[0].shape
|
|
267
|
-
restored = np.asarray(base, dtype=semantic_dtype).reshape(semantic_shape)
|
|
268
|
-
return TensorValue(np.array(restored, copy=False))
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
@kernel_def("spu.run_pphlo")
|
|
272
|
-
def _spu_run_mlir(pfunc: PFunction, *args: SpuValue) -> tuple[SpuValue, ...]:
|
|
273
|
-
"""Execute compiled SPU function (spu.run_pphlo) and return SpuValue outputs.
|
|
274
|
-
|
|
275
|
-
Participation rule: a rank participates iff its entry in the stored
|
|
276
|
-
link_ctx list is non-None. This allows us to allocate a world-sized list
|
|
277
|
-
(indexed by global rank) and simply assign None for non-SPU parties.
|
|
278
|
-
"""
|
|
279
|
-
if pfunc.fn_type != "spu.run_pphlo":
|
|
280
|
-
raise ValueError(
|
|
281
|
-
f"Unsupported format: {pfunc.fn_type}. Expected 'spu.run_pphlo'"
|
|
282
|
-
)
|
|
283
|
-
|
|
284
|
-
cfg, _ = _get_spu_config_and_world()
|
|
285
|
-
kctx = cur_kctx()
|
|
286
|
-
link_ctx = kctx.runtime.get_state("spu.link")
|
|
287
|
-
if link_ctx is None:
|
|
288
|
-
raise RuntimeError("Rank not participating in SPU; no link set via seed_env")
|
|
289
|
-
|
|
290
|
-
# Lazy runtime cache under key spu.runtime
|
|
291
|
-
spu_rt = kctx.runtime.get_state("spu.runtime")
|
|
292
|
-
if spu_rt is None:
|
|
293
|
-
spu_rt = spu_api.Runtime(link_ctx.get_lctx(), cfg)
|
|
294
|
-
kctx.runtime.set_state("spu.runtime", spu_rt)
|
|
295
|
-
|
|
296
|
-
# Validate that all inputs are SpuValue objects
|
|
297
|
-
for i, arg in enumerate(args):
|
|
298
|
-
if not isinstance(arg, SpuValue):
|
|
299
|
-
raise ValueError(
|
|
300
|
-
f"Input {i} must be SpuValue, got {type(arg)}. In real SPU environments, all inputs must be SpuValue objects."
|
|
301
|
-
)
|
|
302
|
-
|
|
303
|
-
# Cast for type checking (we've validated above)
|
|
304
|
-
spu_args: list[SpuValue] = list(args) # type: ignore
|
|
305
|
-
|
|
306
|
-
# Reconstruct SPU executable from MLIR code and metadata
|
|
307
|
-
if pfunc.fn_text is None:
|
|
308
|
-
raise ValueError("PFunction does not contain executable data")
|
|
309
|
-
if not isinstance(pfunc.fn_text, str):
|
|
310
|
-
raise ValueError(f"Expected str, got {type(pfunc.fn_text)}")
|
|
311
|
-
|
|
312
|
-
# Extract metadata for executable reconstruction
|
|
313
|
-
attrs: dict[str, Any] = dict(pfunc.attrs or {})
|
|
314
|
-
input_names = attrs.get("input_names", [])
|
|
315
|
-
output_names = attrs.get("output_names", [])
|
|
316
|
-
executable_name = attrs.get("executable_name", pfunc.fn_name)
|
|
317
|
-
|
|
318
|
-
# Create executable from MLIR code and metadata
|
|
319
|
-
executable = libspu.Executable(
|
|
320
|
-
name=executable_name,
|
|
321
|
-
input_names=input_names,
|
|
322
|
-
output_names=output_names,
|
|
323
|
-
code=pfunc.fn_text,
|
|
324
|
-
)
|
|
325
|
-
|
|
326
|
-
# Set input variables in SPU runtime
|
|
327
|
-
for idx, spu_arg in enumerate(spu_args):
|
|
328
|
-
spu_rt.set_var(input_names[idx], spu_arg.share)
|
|
329
|
-
spu_rt.run(executable)
|
|
330
|
-
shares = [spu_rt.get_var(out_name) for out_name in output_names]
|
|
331
|
-
metas = [spu_rt.get_var_meta(out_name) for out_name in output_names]
|
|
332
|
-
results: list[SpuValue] = [
|
|
333
|
-
SpuValue(
|
|
334
|
-
shape=shape_spu_to_np(meta.shape),
|
|
335
|
-
dtype=dtype_spu_to_mpl(meta.data_type),
|
|
336
|
-
vtype=meta.visibility,
|
|
337
|
-
share=shares[idx],
|
|
338
|
-
)
|
|
339
|
-
for idx, meta in enumerate(metas)
|
|
340
|
-
]
|
|
341
|
-
return tuple(results)
|
mplang/v1/kernels/sql_duckdb.py
DELETED
|
@@ -1,44 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from mplang.v1.core import PFunction
|
|
18
|
-
from mplang.v1.kernels.base import kernel_def
|
|
19
|
-
from mplang.v1.kernels.value import TableValue
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
@kernel_def("duckdb.run_sql")
|
|
23
|
-
def _duckdb_sql(pfunc: PFunction, *args: TableValue) -> TableValue:
|
|
24
|
-
import duckdb
|
|
25
|
-
|
|
26
|
-
# TODO: maybe we could translate the sql to duckdb dialect
|
|
27
|
-
# instead of raising an exception
|
|
28
|
-
if pfunc.attrs.get("dialect") != "duckdb":
|
|
29
|
-
raise ValueError("duckdb.run_sql must have dialect=duckdb attr")
|
|
30
|
-
|
|
31
|
-
conn = duckdb.connect(":memory:")
|
|
32
|
-
if args:
|
|
33
|
-
in_names = pfunc.attrs.get("in_names")
|
|
34
|
-
if in_names is None:
|
|
35
|
-
raise ValueError("duckdb sql missing in_names attr")
|
|
36
|
-
for arg, name in zip(args, in_names, strict=True):
|
|
37
|
-
# Use Arrow directly for zero-copy data transfer
|
|
38
|
-
arrow_table = arg.to_arrow()
|
|
39
|
-
conn.register(name, arrow_table)
|
|
40
|
-
# Fetch result as Arrow table for consistency
|
|
41
|
-
if pfunc.fn_text is None:
|
|
42
|
-
raise ValueError("SQL function text is None")
|
|
43
|
-
res_arrow = conn.execute(pfunc.fn_text).fetch_arrow_table()
|
|
44
|
-
return TableValue(res_arrow)
|
mplang/v1/kernels/stablehlo.py
DELETED
|
@@ -1,90 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from typing import Any
|
|
18
|
-
|
|
19
|
-
import jax
|
|
20
|
-
import jax.extend as jxt
|
|
21
|
-
import jax.numpy as jnp
|
|
22
|
-
import numpy as np
|
|
23
|
-
from jax._src import compiler
|
|
24
|
-
|
|
25
|
-
from mplang.v1.core import PFunction
|
|
26
|
-
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
27
|
-
from mplang.v1.kernels.value import TensorValue
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
@kernel_def("mlir.stablehlo")
|
|
31
|
-
def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
32
|
-
if pfunc.fn_type != "mlir.stablehlo":
|
|
33
|
-
raise ValueError("stablehlo kernel received wrong fn_type")
|
|
34
|
-
|
|
35
|
-
mlir_text = pfunc.fn_text
|
|
36
|
-
if mlir_text is None:
|
|
37
|
-
raise ValueError("StableHLO kernel missing fn_text")
|
|
38
|
-
if isinstance(mlir_text, bytes):
|
|
39
|
-
mlir_text = mlir_text.decode("utf-8")
|
|
40
|
-
|
|
41
|
-
# Flat-key compile cache: stablehlo.compile_cache.<hash>
|
|
42
|
-
ctx = cur_kctx()
|
|
43
|
-
rt = ctx.runtime
|
|
44
|
-
import hashlib
|
|
45
|
-
|
|
46
|
-
h = hashlib.sha256(mlir_text.encode("utf-8")).hexdigest()[:16]
|
|
47
|
-
key = f"stablehlo.compile_cache.{h}"
|
|
48
|
-
compiled = rt.get_state(key)
|
|
49
|
-
if compiled is None:
|
|
50
|
-
client = jxt.backend.get_backend()
|
|
51
|
-
compile_options = compiler.get_compile_options(num_replicas=1, num_partitions=1)
|
|
52
|
-
|
|
53
|
-
try:
|
|
54
|
-
compiled = client.compile_and_load(
|
|
55
|
-
mlir_text, client.devices(), compile_options
|
|
56
|
-
)
|
|
57
|
-
except Exception as e: # pragma: no cover
|
|
58
|
-
raise RuntimeError(f"StableHLO compile failed: {e}") from e
|
|
59
|
-
rt.set_state(key, compiled)
|
|
60
|
-
|
|
61
|
-
# Handle JAX's unused parameter elimination via arg_keep_map
|
|
62
|
-
runtime_args = args
|
|
63
|
-
if "arg_keep_map" in pfunc.attrs:
|
|
64
|
-
keep_indices = pfunc.attrs["arg_keep_map"]
|
|
65
|
-
# Filter out arguments that were eliminated by JAX during compilation
|
|
66
|
-
runtime_args = tuple(args[i] for i in keep_indices)
|
|
67
|
-
|
|
68
|
-
tensor_args: list[TensorValue] = []
|
|
69
|
-
for idx, arg in enumerate(runtime_args):
|
|
70
|
-
if not isinstance(arg, TensorValue):
|
|
71
|
-
raise TypeError(
|
|
72
|
-
f"StableHLO kernel expects TensorValue inputs, got {type(arg).__name__} at position {idx}"
|
|
73
|
-
)
|
|
74
|
-
tensor_args.append(arg)
|
|
75
|
-
|
|
76
|
-
jax_args = [
|
|
77
|
-
jax.device_put(jnp.asarray(tensor.to_numpy())) for tensor in tensor_args
|
|
78
|
-
]
|
|
79
|
-
|
|
80
|
-
try:
|
|
81
|
-
# Execute with the new LoadedExecutable interface
|
|
82
|
-
result = compiled.execute(jax_args)
|
|
83
|
-
|
|
84
|
-
# Use jax.tree_util.tree_flatten to robustly handle any PyTree structure
|
|
85
|
-
flat_results, _ = jax.tree_util.tree_flatten(result)
|
|
86
|
-
flat = [TensorValue(np.asarray(item)) for item in flat_results]
|
|
87
|
-
|
|
88
|
-
return tuple(flat)
|
|
89
|
-
except Exception as e: # pragma: no cover
|
|
90
|
-
raise RuntimeError(f"StableHLO execute failed: {e}") from e
|