mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from mplang.v1.core import PFunction, TableType, TensorType
|
|
20
|
+
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
21
|
+
from mplang.v1.kernels.value import TableValue, TensorValue, Value
|
|
22
|
+
from mplang.v1.runtime.data_providers import get_provider, resolve_uri
|
|
23
|
+
from mplang.v1.utils import table_utils
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@kernel_def("basic.identity")
def _identity(pfunc: PFunction, value: Value) -> Value:
    """Pass the single input through untouched.

    Args:
        pfunc: Kernel descriptor; not consulted by this kernel.
        value: The input value.

    Returns:
        The very same ``value`` object that was passed in.
    """
    # The runtime already guarantees exactly one argument, so no arity
    # validation is performed here.
    return value
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@kernel_def("basic.read")
def _read(pfunc: PFunction) -> Value:
    """Load a value from the resource named by the ``path`` attribute.

    The path is resolved to a URI, the provider registered for the URI scheme
    is looked up, and that provider performs the read. A provider result that
    is already a ``Value`` is returned as is; anything else is wrapped
    according to the declared output type (``pfunc.outs_info[0]``).

    Raises:
        ValueError: The ``path`` attribute is absent.
        NotImplementedError: No provider is registered for the URI scheme.
        RuntimeError: The provider raised while reading (original error chained).
        TypeError: The declared output is neither a table nor a tensor type.
    """
    raw_path = pfunc.attrs.get("path")
    if raw_path is None:
        raise ValueError("missing path attr for basic.read")

    declared_type = pfunc.outs_info[0]
    location = resolve_uri(str(raw_path))
    provider = get_provider(location.scheme)
    if provider is None:
        raise NotImplementedError(
            f"no resource provider for scheme: {location.scheme}"
        )

    kctx = cur_kctx()
    try:
        loaded = provider.read(location, declared_type, ctx=kctx)
    except Exception as err:  # pragma: no cover - provider errors
        raise RuntimeError(f"basic.read failed: {err}") from err

    # Providers are allowed to hand back an already-wrapped Value.
    if isinstance(loaded, Value):
        return loaded

    # Otherwise wrap the raw payload based on the declared output type.
    if isinstance(declared_type, TableType):
        return TableValue(loaded)
    if isinstance(declared_type, TensorType):
        return TensorValue(np.asarray(loaded))
    raise TypeError(
        f"basic.read only supports TableType/TensorType outputs, got {type(declared_type).__name__}"
    )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@kernel_def("basic.write")
def _write(pfunc: PFunction, obj: Value) -> Value:
    """Persist ``obj`` to the resource named by the ``path`` attribute.

    The ``Value`` is forwarded to the provider as is; the provider decides how
    to serialize it.

    Returns:
        ``obj`` itself, so the write can be chained in a dataflow.

    Raises:
        ValueError: The ``path`` attribute is absent.
        NotImplementedError: No provider is registered for the URI scheme.
        RuntimeError: The provider raised while writing (original error chained).
    """
    raw_path = pfunc.attrs.get("path")
    if raw_path is None:
        raise ValueError("missing path attr for basic.write")

    location = resolve_uri(str(raw_path))
    provider = get_provider(location.scheme)
    if provider is None:
        raise NotImplementedError(
            f"no resource provider for scheme: {location.scheme}"
        )

    kctx = cur_kctx()
    try:
        # No unwrapping here: the provider receives the Value object directly.
        provider.write(location, obj, ctx=kctx)
    except Exception as err:  # pragma: no cover
        raise RuntimeError(f"basic.write failed: {err}") from err
    return obj
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@kernel_def("basic.constant")
def _constant(pfunc: PFunction) -> Value:
    """Materialize an embedded constant as a ``TableValue`` or ``TensorValue``.

    Attributes read from ``pfunc.attrs``:
        data_bytes: Serialized payload (required).
        data_format: Payload encoding. Must be ``"bytes[parquet]"`` when the
            declared output is a table; not consulted for tensor outputs.

    The declared output type is ``pfunc.outs_info[0]``. For tensors the payload
    is reinterpreted with the declared dtype and reshaped to the declared shape.

    Raises:
        ValueError: ``data_bytes`` is missing, or a table constant uses an
            unsupported format.
        TypeError: The declared output is neither a table nor a tensor type.
    """
    data_bytes = pfunc.attrs.get("data_bytes")
    if data_bytes is None:
        raise ValueError("missing data_bytes attr for basic.constant")
    out_t = pfunc.outs_info[0]
    fmt = pfunc.attrs.get("data_format")

    if isinstance(out_t, TableType):
        if fmt != "bytes[parquet]":
            raise ValueError(f"unsupported table constant format {fmt}")
        return TableValue(table_utils.decode_table(data_bytes, format="parquet"))

    # Reject unknown output kinds explicitly instead of failing later with an
    # AttributeError on ``.shape`` / ``.dtype``; mirrors basic.read/basic.unpack.
    if not isinstance(out_t, TensorType):
        raise TypeError(
            f"basic.constant only supports TableType/TensorType outputs, got {type(out_t).__name__}"
        )

    np_dtype = out_t.dtype.numpy_dtype()
    # np.frombuffer builds a view over the payload rather than copying it.
    arr = np.frombuffer(data_bytes, dtype=np_dtype).reshape(out_t.shape)
    return TensorValue(arr)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@kernel_def("basic.rank")
def _rank(pfunc: PFunction) -> TensorValue:
    """Expose the rank of the current kernel context as a uint64 scalar tensor.

    Args:
        pfunc: Kernel descriptor; not consulted by this kernel.
    """
    my_rank = cur_kctx().rank
    return TensorValue(np.array(my_rank, dtype=np.uint64))
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@kernel_def("basic.prand")
def _prand(pfunc: PFunction) -> TensorValue:
    """Draw random uint64 values of the requested shape.

    The shape comes from the ``shape`` attribute and defaults to ``()`` (a
    scalar). Values are sampled over the whole uint64 range, both ends
    included, from a generator created anew on every call.
    """
    dims = pfunc.attrs.get("shape", ())
    generator = np.random.default_rng()
    bounds = np.iinfo(np.uint64)
    # endpoint=True makes ``high`` inclusive, so the maximum uint64 is reachable.
    samples = generator.integers(
        low=bounds.min,
        high=bounds.max,
        size=dims,
        dtype=np.uint64,
        endpoint=True,
    )
    return TensorValue(samples)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@kernel_def("basic.table_to_tensor")
def _table_to_tensor(pfunc: PFunction, table: TableValue) -> TensorValue:
    """Stack the columns of a table side by side into one tensor.

    Column ``i`` of the table becomes column ``i`` of the resulting array.

    Raises:
        ValueError: The table has no columns.
    """
    tbl = table.to_arrow()
    n_cols = tbl.num_columns
    if n_cols == 0:
        raise ValueError("cannot pack empty table")

    # Pull each Arrow column out as a numpy array, preserving column order.
    per_column = []
    for idx in range(n_cols):
        per_column.append(tbl.column(idx).to_numpy())
    return TensorValue(np.column_stack(per_column))
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@kernel_def("basic.tensor_to_table")
def _tensor_to_table(pfunc: PFunction, tensor: TensorValue) -> TableValue:
    """Split a rank-2 tensor into named columns and return them as a table.

    Column names are taken from the ``column_names`` attribute and paired
    positionally with the tensor's columns.

    Raises:
        ValueError: The tensor is not rank 2, ``column_names`` is missing, or
            the number of names differs from the number of columns (raised by
            the strict ``zip``).
    """
    import pyarrow as pa  # type: ignore

    matrix = tensor.to_numpy()
    if matrix.ndim != 2:
        raise ValueError("tensor_to_table expects rank-2 array")

    names = pfunc.attrs.get("column_names")
    if names is None:
        raise ValueError("missing column_names attr")

    # One Arrow array per tensor column, built straight from the numpy slices.
    n_cols = matrix.shape[1]
    per_column = [pa.array(matrix[:, j]) for j in range(n_cols)]
    named_columns = dict(zip(names, per_column, strict=True))
    return TableValue(pa.table(named_columns))
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _summ(v: Value) -> str:
    """Build a short, human-readable preview of a value for debug output.

    Tables are shown through Arrow's string form limited to the first 8 rows;
    tensors through an abbreviated ``np.array2string``; anything else through
    ``repr``. Never raises: a failure while rendering is reported inside the
    returned string.
    """
    try:
        if isinstance(v, TableValue):
            tbl = v.to_arrow()
            # Cap the preview at 8 rows.
            head_rows = min(8, tbl.num_rows)
            return str(tbl.slice(0, head_rows))

        if isinstance(v, TensorValue):
            rendered = np.array2string(
                v.to_numpy(),
                threshold=64,
                edgeitems=3,
                precision=6,
                suppress_small=True,
            )
            return str(rendered)

        return repr(v)
    except Exception as err:  # pragma: no cover
        return f"<unprintable {type(v).__name__}: {err}>"
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@kernel_def("basic.debug_print")
def _debug_print(pfunc: PFunction, val: Value) -> Value:
    """Print a rank-tagged preview of ``val`` to stdout and pass it through.

    The optional ``prefix`` attribute (default empty) is placed directly in
    front of the preview text.

    Returns:
        ``val`` unchanged.
    """
    label = pfunc.attrs.get("prefix", "")
    kctx = cur_kctx()
    line = f"[debug_print][rank={kctx.rank}] {label}{_summ(val)}"
    print(line)
    return val
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@kernel_def("basic.pack")
def _pack(pfunc: PFunction, value: Value) -> TensorValue:
    """Serialize a table or tensor into a flat uint8 tensor.

    Tables are written as an Arrow IPC stream; tensors are dumped as their raw
    bytes in C order. The kernel must be declared with exactly one output, and
    that output must be a uint8 tensor type.

    Raises:
        ValueError: The kernel declares a number of outputs other than one.
        TypeError: The declared output is not a uint8 tensor, or ``value`` is
            neither a ``TableValue`` nor a ``TensorValue``.
    """
    declared = pfunc.outs_info
    if len(declared) != 1:
        raise ValueError("basic.pack expects single output type")
    target = declared[0]
    if not isinstance(target, TensorType):
        raise TypeError("basic.pack must return TensorType")
    if target.dtype.numpy_dtype() != np.uint8:
        raise TypeError("basic.pack output dtype must be uint8")

    if isinstance(value, TableValue):
        # Arrow IPC stream format, matching the Value serde convention.
        import pyarrow as pa  # type: ignore
        import pyarrow.ipc as pa_ipc  # type: ignore

        tbl = value.to_arrow()
        buffer_sink = pa.BufferOutputStream()
        with pa_ipc.new_stream(buffer_sink, tbl.schema) as ipc_writer:  # type: ignore[arg-type]
            ipc_writer.write_table(tbl)  # type: ignore[arg-type]
        payload = buffer_sink.getvalue().to_pybytes()
    elif isinstance(value, TensorValue):
        payload = value.to_numpy().tobytes(order="C")
    else:
        raise TypeError(
            f"basic.pack does not support Value type {type(value).__name__}"
        )

    return TensorValue(np.frombuffer(payload, dtype=np.uint8))
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@kernel_def("basic.unpack")
def _unpack(pfunc: PFunction, packed: TensorValue) -> Value:
    """Rebuild a tensor or table from a flat byte tensor.

    The declared output type (``pfunc.outs_info[0]``) selects the decoding:
    tensors are reinterpreted from raw bytes with the declared dtype and
    shape; tables are read back from an Arrow IPC stream.

    Raises:
        ValueError: The kernel declares a number of outputs other than one,
            the tensor shape has a negative (dynamic) dimension, or the byte
            count does not match the declared dtype and shape.
        TypeError: The declared output is neither a tensor nor a table type.
    """
    declared = pfunc.outs_info
    if len(declared) != 1:
        raise ValueError("basic.unpack expects single output type")
    target = declared[0]

    flat = packed.to_numpy().astype(np.uint8, copy=False).reshape(-1)

    if isinstance(target, TensorType):
        elem_dtype = target.dtype.numpy_dtype()
        dims = tuple(target.shape)
        for extent in dims:
            if extent < 0:
                raise ValueError("basic.unpack does not support dynamic tensor shapes")

        # The byte count must match the declared dtype and shape exactly.
        n_elems = int(np.prod(dims))
        needed = n_elems * np.dtype(elem_dtype).itemsize
        if flat.size != needed:
            raise ValueError(
                f"unpack size mismatch: got {flat.size} bytes, expect {needed} for {elem_dtype} {dims}"
            )
        decoded = np.frombuffer(flat.tobytes(), dtype=elem_dtype)
        return TensorValue(decoded.reshape(dims))

    if isinstance(target, TableType):
        # Inverse of the Arrow IPC stream encoding.
        import pyarrow as pa  # type: ignore
        import pyarrow.ipc as pa_ipc  # type: ignore

        stream_reader = pa_ipc.open_stream(pa.py_buffer(flat.tobytes()))
        return TableValue(stream_reader.read_all())

    raise TypeError("basic.unpack output type must be TensorType or TableType")
|
|
@@ -17,12 +17,12 @@ from __future__ import annotations
|
|
|
17
17
|
from collections.abc import Mapping
|
|
18
18
|
from typing import Any
|
|
19
19
|
|
|
20
|
-
from mplang.core.
|
|
21
|
-
from mplang.core.pfunc import PFunction
|
|
22
|
-
from mplang.core.table import TableLike, TableType
|
|
23
|
-
from mplang.core.tensor import TensorLike, TensorType
|
|
24
|
-
from mplang.kernels import base
|
|
25
|
-
from mplang.kernels.base import KernelContext, get_kernel_spec, kernel_exists
|
|
20
|
+
from mplang.v1.core.dtypes import UINT8, DType
|
|
21
|
+
from mplang.v1.core.pfunc import PFunction
|
|
22
|
+
from mplang.v1.core.table import PandasTableLike, TableLike, TableType
|
|
23
|
+
from mplang.v1.core.tensor import TensorLike, TensorType
|
|
24
|
+
from mplang.v1.kernels import base
|
|
25
|
+
from mplang.v1.kernels.base import KernelContext, get_kernel_spec, kernel_exists
|
|
26
26
|
|
|
27
27
|
# Default bindings
|
|
28
28
|
# Import kernel implementation modules explicitly so their @kernel_def entries
|
|
@@ -35,13 +35,14 @@ def _ensure_impl_imported() -> None:
|
|
|
35
35
|
global _IMPL_IMPORTED
|
|
36
36
|
if _IMPL_IMPORTED:
|
|
37
37
|
return
|
|
38
|
-
from mplang.kernels import
|
|
39
|
-
from mplang.kernels import crypto as _impl_crypto # noqa: F401
|
|
40
|
-
from mplang.kernels import
|
|
41
|
-
from mplang.kernels import
|
|
42
|
-
from mplang.kernels import
|
|
43
|
-
from mplang.kernels import
|
|
44
|
-
from mplang.kernels import
|
|
38
|
+
from mplang.v1.kernels import basic as _impl_basic # noqa: F401
|
|
39
|
+
from mplang.v1.kernels import crypto as _impl_crypto # noqa: F401
|
|
40
|
+
from mplang.v1.kernels import fhe as _impl_fhe # noqa: F401
|
|
41
|
+
from mplang.v1.kernels import mock_tee as _impl_tee # noqa: F401
|
|
42
|
+
from mplang.v1.kernels import phe as _impl_phe # noqa: F401
|
|
43
|
+
from mplang.v1.kernels import spu as _impl_spu # noqa: F401
|
|
44
|
+
from mplang.v1.kernels import sql_duckdb as _impl_sql_duckdb # noqa: F401
|
|
45
|
+
from mplang.v1.kernels import stablehlo as _impl_stablehlo # noqa: F401
|
|
45
46
|
|
|
46
47
|
_IMPL_IMPORTED = True
|
|
47
48
|
|
|
@@ -49,18 +50,18 @@ def _ensure_impl_imported() -> None:
|
|
|
49
50
|
# imports consolidated above
|
|
50
51
|
|
|
51
52
|
_DEFAULT_BINDINGS: dict[str, str] = {
|
|
52
|
-
#
|
|
53
|
-
"
|
|
54
|
-
"
|
|
55
|
-
"
|
|
56
|
-
"
|
|
57
|
-
"
|
|
58
|
-
"
|
|
59
|
-
"
|
|
60
|
-
"
|
|
61
|
-
"
|
|
62
|
-
"
|
|
63
|
-
"
|
|
53
|
+
# basic
|
|
54
|
+
"basic.identity": "basic.identity",
|
|
55
|
+
"basic.read": "basic.read",
|
|
56
|
+
"basic.write": "basic.write",
|
|
57
|
+
"basic.constant": "basic.constant",
|
|
58
|
+
"basic.rank": "basic.rank",
|
|
59
|
+
"basic.prand": "basic.prand",
|
|
60
|
+
"basic.table_to_tensor": "basic.table_to_tensor",
|
|
61
|
+
"basic.tensor_to_table": "basic.tensor_to_table",
|
|
62
|
+
"basic.debug_print": "basic.debug_print",
|
|
63
|
+
"basic.pack": "basic.pack",
|
|
64
|
+
"basic.unpack": "basic.unpack",
|
|
64
65
|
# crypto
|
|
65
66
|
"crypto.keygen": "crypto.keygen",
|
|
66
67
|
"crypto.enc": "crypto.enc",
|
|
@@ -80,6 +81,17 @@ _DEFAULT_BINDINGS: dict[str, str] = {
|
|
|
80
81
|
"phe.concat": "phe.concat",
|
|
81
82
|
"phe.reshape": "phe.reshape",
|
|
82
83
|
"phe.transpose": "phe.transpose",
|
|
84
|
+
# fhe
|
|
85
|
+
"fhe.keygen": "fhe.keygen",
|
|
86
|
+
"fhe.encrypt": "fhe.encrypt",
|
|
87
|
+
"fhe.decrypt": "fhe.decrypt",
|
|
88
|
+
"fhe.add": "fhe.add",
|
|
89
|
+
"fhe.mul": "fhe.mul",
|
|
90
|
+
"fhe.dot": "fhe.dot",
|
|
91
|
+
"fhe.polyval": "fhe.polyval",
|
|
92
|
+
"fhe.sub": "fhe.sub",
|
|
93
|
+
"fhe.negate": "fhe.negate",
|
|
94
|
+
"fhe.square": "fhe.square",
|
|
83
95
|
# spu
|
|
84
96
|
"spu.seed_env": "spu.seed_env",
|
|
85
97
|
"spu.makeshares": "spu.makeshares",
|
|
@@ -305,9 +317,12 @@ def _validate_table_arg(
|
|
|
305
317
|
raise TypeError(
|
|
306
318
|
f"kernel {fn_type} input[{arg_index}] expects TableLike, got {type(value).__name__}"
|
|
307
319
|
)
|
|
308
|
-
|
|
320
|
+
columns = (
|
|
321
|
+
value.columns if isinstance(value, PandasTableLike) else value.column_names
|
|
322
|
+
)
|
|
323
|
+
if len(columns) != len(spec.columns):
|
|
309
324
|
raise ValueError(
|
|
310
|
-
f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(
|
|
325
|
+
f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(columns)}, expected {len(spec.columns)}"
|
|
311
326
|
)
|
|
312
327
|
|
|
313
328
|
|
|
@@ -15,15 +15,15 @@
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
17
|
import os
|
|
18
|
-
from typing import Any
|
|
19
18
|
|
|
20
19
|
import numpy as np
|
|
21
20
|
|
|
22
|
-
from mplang.core
|
|
23
|
-
from mplang.kernels.base import cur_kctx, kernel_def
|
|
24
|
-
from mplang.
|
|
21
|
+
from mplang.v1.core import PFunction
|
|
22
|
+
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
23
|
+
from mplang.v1.kernels.value import TensorValue
|
|
24
|
+
from mplang.v1.utils.crypto import blake2b
|
|
25
25
|
|
|
26
|
-
__all__: list[str] = [] #
|
|
26
|
+
__all__: list[str] = [] # No public exports currently
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
def _get_rng() -> np.random.Generator:
|
|
@@ -45,71 +45,78 @@ def _get_rng() -> np.random.Generator:
|
|
|
45
45
|
def _keystream(key: bytes, nonce: bytes, length: int) -> bytes:
    """Return ``length`` keystream bytes derived from ``key`` and ``nonce``.

    WARNING (INSECURE): the stream is the single digest ``blake2b(key || nonce)``
    repeated back to back and truncated to ``length``. No counter is mixed in,
    so the stream is periodic with the digest length.

    NOTE(review): the previous revision mixed a per-block counter into the
    hash. Reintroducing it changes the produced bytes, so it must be done
    together with every peer that encrypts or decrypts with this stream.

    Args:
        key: Secret key bytes.
        nonce: Per-message nonce bytes.
        length: Number of bytes requested; values <= 0 yield ``b""``.

    Raises:
        ValueError: ``blake2b`` returned an empty digest (the stream could
            never reach the requested length).
    """
    if length <= 0:
        return b""
    # The digest does not depend on the position in the stream: compute it once.
    block = blake2b(key + nonce)
    if not block:
        raise ValueError("blake2b returned an empty digest; cannot build keystream")
    repeats = -(-length // len(block))  # ceiling division
    return bytes((block * repeats)[:length])
|
|
54
52
|
|
|
55
53
|
|
|
56
54
|
@kernel_def("crypto.keygen")
def _crypto_keygen(pfunc: PFunction) -> TensorValue:
    """Generate random key bytes as a 1-D uint8 tensor.

    The number of bytes comes from the ``length`` attribute (default 32).
    """
    n_bytes = int(pfunc.attrs.get("length", 32))
    generator = _get_rng()
    material = generator.integers(0, 256, size=(n_bytes,), dtype=np.uint8)
    return TensorValue(material)
|
|
62
60
|
|
|
63
61
|
|
|
64
62
|
@kernel_def("crypto.enc")
def _crypto_encrypt(
    pfunc: PFunction, pt_bytes: TensorValue, key: TensorValue
) -> TensorValue:
    """XOR the plaintext with a keystream and prepend the nonce.

    A fresh 16-byte nonce is drawn per call. The result is laid out as
    ``nonce (16 bytes) || ciphertext``.

    Args:
        pfunc: Kernel descriptor; not consulted by this kernel.
        pt_bytes: Plaintext bytes.
        key: Key bytes.

    Returns:
        A uint8 tensor holding the nonce followed by the ciphertext.
    """
    plain = pt_bytes.to_numpy().astype(np.uint8, copy=False)
    key_raw = key.to_numpy().astype(np.uint8, copy=False)

    generator = _get_rng()
    nonce = generator.integers(0, 256, size=(16,), dtype=np.uint8)

    pad_bytes = _keystream(key_raw.tobytes(), nonce.tobytes(), plain.size)
    pad = np.frombuffer(pad_bytes, dtype=np.uint8)

    cipher = (plain ^ pad).astype(np.uint8)
    framed = np.concatenate([nonce, cipher]).astype(np.uint8)
    return TensorValue(framed)
|
|
76
76
|
|
|
77
77
|
|
|
78
78
|
@kernel_def("crypto.dec")
def _crypto_decrypt(
    pfunc: PFunction, ct_with_nonce: TensorValue, key: TensorValue
) -> TensorValue:
    """Split off the nonce and XOR the ciphertext with the keystream.

    The input layout is ``nonce (16 bytes) || ciphertext``, as produced by
    ``crypto.enc``.

    Args:
        pfunc: Kernel descriptor; not consulted by this kernel.
        ct_with_nonce: Nonce followed by ciphertext bytes.
        key: Key bytes.

    Returns:
        The recovered plaintext as a uint8 tensor.

    Raises:
        ValueError: The input is too short to contain the 16-byte nonce.
    """
    ct_np = ct_with_nonce.to_numpy().astype(np.uint8, copy=False)
    key_np = key.to_numpy().astype(np.uint8, copy=False)

    # A truncated input used to yield a short nonce and an empty "plaintext"
    # without any error; reject it instead. Exactly 16 bytes is still valid
    # (empty plaintext).
    if ct_np.ndim == 0 or ct_np.shape[0] < 16:
        raise ValueError(
            "crypto.dec input too short: expected at least 16 bytes (nonce)"
        )

    nonce = ct_np[:16]
    ct = ct_np[16:]
    stream = np.frombuffer(
        _keystream(key_np.tobytes(), nonce.tobytes(), len(ct)), dtype=np.uint8
    )
    pt_bytes = (ct ^ stream).astype(np.uint8)
    return TensorValue(pt_bytes)
|
|
89
91
|
|
|
90
92
|
|
|
91
93
|
@kernel_def("crypto.kem_keygen")
def _crypto_kem_keygen(pfunc: PFunction) -> tuple[TensorValue, TensorValue]:
    """Create a ``(secret key, public key)`` pair of 32-byte uint8 tensors.

    The secret key is 32 random bytes; the public key is the first 32 bytes of
    ``blake2b`` applied to the secret key.
    """
    generator = _get_rng()
    secret_key = generator.integers(0, 256, size=(32,), dtype=np.uint8)

    digest = blake2b(secret_key.tobytes())
    public_key = np.frombuffer(digest[:32], dtype=np.uint8)
    return (TensorValue(secret_key), TensorValue(public_key))
|
|
97
100
|
|
|
98
101
|
|
|
99
102
|
@kernel_def("crypto.kem_derive")
def _crypto_kem_derive(
    pfunc: PFunction, sk: TensorValue, peer_pk: TensorValue
) -> TensorValue:
    """Derive a 32-byte shared secret from a secret key and a peer public key.

    The caller's public key is recomputed as ``blake2b(sk)[:32]``, XORed with
    ``peer_pk``, and the XOR result is hashed again; the first 32 bytes of
    that hash are the secret. XOR is commutative, so swapping the roles of
    the two key pairs produces the same bytes.

    NOTE(review): the result depends only on the two public keys.
    """
    own_sk = sk.to_numpy().astype(np.uint8, copy=False)
    other_pk = peer_pk.to_numpy().astype(np.uint8, copy=False)

    own_pk_digest = blake2b(own_sk.tobytes())
    own_pk = np.frombuffer(own_pk_digest[:32], dtype=np.uint8)

    mixed = (own_pk ^ other_pk).astype(np.uint8)
    shared_digest = blake2b(mixed.tobytes())
    shared = np.frombuffer(shared_digest[:32], dtype=np.uint8)
    return TensorValue(shared)
|
|
107
114
|
|
|
108
115
|
|
|
109
116
|
@kernel_def("crypto.hkdf")
def _crypto_hkdf(pfunc: PFunction, secret: TensorValue) -> TensorValue:
    """Derive 32 bytes from ``secret`` and the ``info`` attribute.

    The output is the first 32 bytes of ``blake2b(secret || info)``, where
    ``info`` is the UTF-8 encoding of the ``info`` attribute (default empty).
    """
    ikm = secret.to_numpy().astype(np.uint8, copy=False)
    label = str(pfunc.attrs.get("info", "")).encode("utf-8")

    digest = blake2b(ikm.tobytes() + label)
    derived = np.frombuffer(digest[:32], dtype=np.uint8)
    return TensorValue(derived)
|