mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/libs/mpc/psi/rr22.py +303 -0
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang/v2/libs/mpc/psi/rr22.py +0 -344
- mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/kernels/context.py
DELETED
|
@@ -1,369 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from collections.abc import Mapping
|
|
18
|
-
from typing import Any
|
|
19
|
-
|
|
20
|
-
from mplang.v1.core.dtypes import UINT8, DType
|
|
21
|
-
from mplang.v1.core.pfunc import PFunction
|
|
22
|
-
from mplang.v1.core.table import PandasTableLike, TableLike, TableType
|
|
23
|
-
from mplang.v1.core.tensor import TensorLike, TensorType
|
|
24
|
-
from mplang.v1.kernels import base
|
|
25
|
-
from mplang.v1.kernels.base import KernelContext, get_kernel_spec, kernel_exists
|
|
26
|
-
|
|
27
|
-
# Kernel implementation modules register their @kernel_def entries as an
# import side effect. Their import is deferred until _ensure_impl_imported()
# is first called; the ``_impl_*`` aliases are never referenced directly,
# which is why each import carries a ``noqa: F401`` marker.
_IMPL_IMPORTED = False


def _ensure_impl_imported() -> None:
    """Import every v1 kernel implementation module once.

    Later calls are no-ops. The flag is flipped only after all imports have
    succeeded, so if one of them raises, the next call tries again.
    """
    global _IMPL_IMPORTED
    if not _IMPL_IMPORTED:
        from mplang.v1.kernels import basic as _impl_basic  # noqa: F401
        from mplang.v1.kernels import crypto as _impl_crypto  # noqa: F401
        from mplang.v1.kernels import fhe as _impl_fhe  # noqa: F401
        from mplang.v1.kernels import mock_tee as _impl_tee  # noqa: F401
        from mplang.v1.kernels import phe as _impl_phe  # noqa: F401
        from mplang.v1.kernels import spu as _impl_spu  # noqa: F401
        from mplang.v1.kernels import sql_duckdb as _impl_sql_duckdb  # noqa: F401
        from mplang.v1.kernels import stablehlo as _impl_stablehlo  # noqa: F401

        _IMPL_IMPORTED = True
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
# Default op -> kernel-id table. Every op below is served by a kernel
# registered under the same id, except the generic ``sql.run`` op, which is
# routed to the DuckDB-specific kernel. TEE ops (tee.quote_gen / tee.attest
# -> mock_tee.*) are intentionally left unbound for now.
_DEFAULT_BINDINGS: dict[str, str] = {
    **{
        op: op
        for op in (
            # basic
            "basic.identity",
            "basic.read",
            "basic.write",
            "basic.constant",
            "basic.rank",
            "basic.prand",
            "basic.table_to_tensor",
            "basic.tensor_to_table",
            "basic.debug_print",
            "basic.pack",
            "basic.unpack",
            # crypto
            "crypto.keygen",
            "crypto.enc",
            "crypto.dec",
            "crypto.kem_keygen",
            "crypto.kem_derive",
            "crypto.hkdf",
            # phe
            "phe.keygen",
            "phe.encrypt",
            "phe.mul",
            "phe.add",
            "phe.decrypt",
            "phe.dot",
            "phe.gather",
            "phe.scatter",
            "phe.concat",
            "phe.reshape",
            "phe.transpose",
            # fhe
            "fhe.keygen",
            "fhe.encrypt",
            "fhe.decrypt",
            "fhe.add",
            "fhe.mul",
            "fhe.dot",
            "fhe.polyval",
            "fhe.sub",
            "fhe.negate",
            "fhe.square",
            # spu
            "spu.seed_env",
            "spu.makeshares",
            "spu.reconstruct",
            "spu.run_pphlo",
            # stablehlo
            "mlir.stablehlo",
        )
    },
    # sql: generic SQL op bound to the duckdb backend kernel id
    "sql.run": "duckdb.run_sql",
}
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
# --- RuntimeContext ---
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
class RuntimeContext:
|
|
115
|
-
"""Per-runtime execution context with isolated op->kernel bindings.
|
|
116
|
-
|
|
117
|
-
This object owns ONLY static dispatch metadata ("op bindings") and mutable
|
|
118
|
-
per-rank kernel side state/cache/stats. It does NOT store per-evaluation
|
|
119
|
-
variable bindings (those are provided to the evaluator at evaluation time).
|
|
120
|
-
|
|
121
|
-
Parameters
|
|
122
|
-
----------
|
|
123
|
-
rank : int
|
|
124
|
-
Local rank of this participant.
|
|
125
|
-
world_size : int
|
|
126
|
-
Total number of participants.
|
|
127
|
-
initial_bindings : Mapping[str, str] | None, optional
|
|
128
|
-
Optional partial overrides applied on top of the default binding table
|
|
129
|
-
during construction (override semantics, not replace). These map
|
|
130
|
-
op_type -> kernel_id and form a *template* for dispatch. After
|
|
131
|
-
initialization, all (re)binding must go through ``bind_op`` /
|
|
132
|
-
``rebind_op`` on this context (scoped to THIS runtime only).
|
|
133
|
-
state : dict, optional
|
|
134
|
-
Mutable per-runtime key/value storage for kernels. Flat key space;
|
|
135
|
-
callers SHOULD use dotted prefixes (e.g. "stablehlo.compile_cache").
|
|
136
|
-
Kernels own their *state* (functional correctness data, caches,
|
|
137
|
-
handles, compiled objects, RNGs, etc.). Runtime does not interpret
|
|
138
|
-
structure—values may themselves be dicts if a kernel wants its own
|
|
139
|
-
pocket. Created empty when omitted.
|
|
140
|
-
stats : dict, optional
|
|
141
|
-
Mutable statistics/telemetry owned by the runtime (usage counters,
|
|
142
|
-
timings, profiling aids). Kernels may increment counters but should
|
|
143
|
-
avoid storing functional state here. A default "op_calls" mapping is
|
|
144
|
-
ensured. Created empty when omitted.
|
|
145
|
-
"""
|
|
146
|
-
|
|
147
|
-
__slots__ = ("_ibindings", "rank", "state", "stats", "world_size")
|
|
148
|
-
|
|
149
|
-
def __init__(
|
|
150
|
-
self,
|
|
151
|
-
rank: int,
|
|
152
|
-
world_size: int,
|
|
153
|
-
initial_bindings: Mapping[str, str] | None = None,
|
|
154
|
-
*,
|
|
155
|
-
state: dict[str, Any] | None = None,
|
|
156
|
-
stats: dict[str, Any] | None = None,
|
|
157
|
-
) -> None:
|
|
158
|
-
_ensure_impl_imported()
|
|
159
|
-
self.rank = rank
|
|
160
|
-
self.world_size = world_size
|
|
161
|
-
# Merge defaults with user overrides (override semantics)
|
|
162
|
-
self._ibindings: dict[str, str] = {
|
|
163
|
-
**_DEFAULT_BINDINGS,
|
|
164
|
-
**(initial_bindings or {}),
|
|
165
|
-
}
|
|
166
|
-
self.state = state if state is not None else {}
|
|
167
|
-
self.stats = stats if stats is not None else {}
|
|
168
|
-
self.stats.setdefault("op_calls", {})
|
|
169
|
-
|
|
170
|
-
def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
|
|
171
|
-
fn_type = pfunc.fn_type
|
|
172
|
-
kid = self._ibindings.get(fn_type)
|
|
173
|
-
if kid is None:
|
|
174
|
-
raise NotImplementedError(f"no backend kernel registered for op {fn_type}")
|
|
175
|
-
spec = get_kernel_spec(kid)
|
|
176
|
-
fn = spec.fn # kernel implementation
|
|
177
|
-
if len(arg_list) != len(pfunc.ins_info):
|
|
178
|
-
raise ValueError(
|
|
179
|
-
f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
|
|
180
|
-
)
|
|
181
|
-
for idx, (ins_spec, val) in enumerate(
|
|
182
|
-
zip(pfunc.ins_info, arg_list, strict=True)
|
|
183
|
-
):
|
|
184
|
-
if isinstance(ins_spec, TableType):
|
|
185
|
-
_validate_table_arg(fn_type, idx, ins_spec, val)
|
|
186
|
-
continue
|
|
187
|
-
if isinstance(ins_spec, TensorType):
|
|
188
|
-
_validate_tensor_arg(fn_type, idx, ins_spec, val)
|
|
189
|
-
continue
|
|
190
|
-
|
|
191
|
-
# install kernel context
|
|
192
|
-
kctx = KernelContext(rank=self.rank, world_size=self.world_size, runtime=self)
|
|
193
|
-
token = base._CTX_VAR.set(kctx)
|
|
194
|
-
try:
|
|
195
|
-
raw = fn(pfunc, *arg_list)
|
|
196
|
-
finally:
|
|
197
|
-
base._CTX_VAR.reset(token)
|
|
198
|
-
|
|
199
|
-
try:
|
|
200
|
-
op_calls = self.stats.setdefault("op_calls", {})
|
|
201
|
-
op_calls[fn_type] = op_calls.get(fn_type, 0) + 1
|
|
202
|
-
except Exception: # pragma: no cover - never raise due to stats
|
|
203
|
-
pass
|
|
204
|
-
expected = len(pfunc.outs_info)
|
|
205
|
-
if expected == 0:
|
|
206
|
-
if raw in (None, (), []):
|
|
207
|
-
return []
|
|
208
|
-
raise ValueError(
|
|
209
|
-
f"kernel {fn_type} should return no values; got {type(raw).__name__}"
|
|
210
|
-
)
|
|
211
|
-
if expected == 1:
|
|
212
|
-
if isinstance(raw, (tuple, list)):
|
|
213
|
-
if len(raw) != 1:
|
|
214
|
-
raise ValueError(
|
|
215
|
-
f"kernel {fn_type} produced {len(raw)} outputs, expected 1"
|
|
216
|
-
)
|
|
217
|
-
return [raw[0]]
|
|
218
|
-
return [raw]
|
|
219
|
-
if not isinstance(raw, (tuple, list)):
|
|
220
|
-
raise TypeError(
|
|
221
|
-
f"kernel {fn_type} must return sequence (len={expected}), got {type(raw).__name__}"
|
|
222
|
-
)
|
|
223
|
-
if len(raw) != expected:
|
|
224
|
-
raise ValueError(
|
|
225
|
-
f"kernel {fn_type} produced {len(raw)} outputs, expected {expected}"
|
|
226
|
-
)
|
|
227
|
-
return list(raw)
|
|
228
|
-
|
|
229
|
-
def reset(self) -> None:
|
|
230
|
-
self.state.clear()
|
|
231
|
-
|
|
232
|
-
# ---- runtime state API (flat key space) ----
|
|
233
|
-
# Keys are treated atomically; convention encourages dotted prefixes
|
|
234
|
-
# (e.g. 'stablehlo.compile_cache.hash', 'crypto.rng'). Implementation
|
|
235
|
-
# does NOT parse or create hierarchical dicts—any grouping is purely
|
|
236
|
-
# by string prefix. Values themselves MAY be dicts if callers want a
|
|
237
|
-
# manual pocket. This keeps semantics simple and predictable.
|
|
238
|
-
|
|
239
|
-
def ensure_state(self, key: str, factory: type | Any = dict) -> Any:
|
|
240
|
-
"""Return value for key; if absent create via factory and store.
|
|
241
|
-
|
|
242
|
-
Key is not parsed; dotted forms are allowed but treated as a single
|
|
243
|
-
map key. Use consistent prefixes for grouping (e.g. 'spu.config').
|
|
244
|
-
"""
|
|
245
|
-
if not key:
|
|
246
|
-
raise ValueError("empty state key")
|
|
247
|
-
val = self.state.get(key)
|
|
248
|
-
if val is None:
|
|
249
|
-
val = factory()
|
|
250
|
-
self.state[key] = val
|
|
251
|
-
return val
|
|
252
|
-
|
|
253
|
-
def get_state(self, key: str, default: Any | None = None) -> Any:
|
|
254
|
-
if not key:
|
|
255
|
-
raise ValueError("empty state key")
|
|
256
|
-
return self.state.get(key, default)
|
|
257
|
-
|
|
258
|
-
def set_state(self, key: str, value: Any) -> None:
|
|
259
|
-
if not key:
|
|
260
|
-
raise ValueError("empty state key")
|
|
261
|
-
self.state[key] = value
|
|
262
|
-
|
|
263
|
-
def del_state(self, key: str) -> None:
|
|
264
|
-
if not key:
|
|
265
|
-
raise ValueError("empty state key")
|
|
266
|
-
self.state.pop(key, None)
|
|
267
|
-
|
|
268
|
-
def list_state(self, prefix: str = "") -> dict[str, Any]:
|
|
269
|
-
"""Return mapping of key -> value; optional prefix filter.
|
|
270
|
-
|
|
271
|
-
Prefix match is string-based; if prefix is non-empty include keys
|
|
272
|
-
where key == prefix or key starts with prefix + '.'.
|
|
273
|
-
"""
|
|
274
|
-
if not prefix:
|
|
275
|
-
return dict(self.state)
|
|
276
|
-
pref = prefix if prefix.endswith(".") else prefix + "."
|
|
277
|
-
out: dict[str, Any] = {}
|
|
278
|
-
for k, v in self.state.items():
|
|
279
|
-
if k == prefix or k.startswith(pref):
|
|
280
|
-
out[k] = v
|
|
281
|
-
return out
|
|
282
|
-
|
|
283
|
-
# ---- explicit (re)binding API ----
|
|
284
|
-
def bind_op(self, op_type: str, kernel_id: str, *, force: bool = False) -> None:
|
|
285
|
-
"""Bind an operation to a kernel for THIS context only.
|
|
286
|
-
|
|
287
|
-
force=False (default) keeps existing binding (no silent override).
|
|
288
|
-
"""
|
|
289
|
-
if not kernel_exists(kernel_id):
|
|
290
|
-
raise KeyError(f"kernel_id {kernel_id} not registered")
|
|
291
|
-
if not force and op_type in self._ibindings:
|
|
292
|
-
return
|
|
293
|
-
self._ibindings[op_type] = kernel_id
|
|
294
|
-
|
|
295
|
-
def rebind_op(self, op_type: str, kernel_id: str) -> None:
|
|
296
|
-
"""Force rebind an operation to a different kernel (shorthand)."""
|
|
297
|
-
self.bind_op(op_type, kernel_id, force=True)
|
|
298
|
-
|
|
299
|
-
# Introspection helpers
|
|
300
|
-
def list_bound_ops(self) -> list[str]: # pragma: no cover - convenience
|
|
301
|
-
return sorted(self._ibindings.keys())
|
|
302
|
-
|
|
303
|
-
def get_binding(self, op_type: str) -> str | None: # pragma: no cover
|
|
304
|
-
return self._ibindings.get(op_type)
|
|
305
|
-
|
|
306
|
-
def __repr__(self) -> str: # pragma: no cover - debug aid
|
|
307
|
-
return (
|
|
308
|
-
f"RuntimeContext(rank={self.rank}, world_size={self.world_size}, "
|
|
309
|
-
f"bound_ops={len(self._ibindings)})"
|
|
310
|
-
)
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
def _validate_table_arg(
    fn_type: str, arg_index: int, spec: TableType, value: Any
) -> None:
    """Check that a kernel table argument matches its declared ``TableType``.

    Only the column count is compared. Column names are read from
    ``.columns`` for pandas-backed tables and ``.column_names`` otherwise.

    Raises
    ------
    TypeError
        If ``value`` is not ``TableLike``.
    ValueError
        If the number of columns differs from ``len(spec.columns)``.
    """
    if isinstance(value, TableLike):
        if isinstance(value, PandasTableLike):
            actual_cols = value.columns
        else:
            actual_cols = value.column_names
        n_got = len(actual_cols)
        n_want = len(spec.columns)
        if n_got == n_want:
            return
        raise ValueError(
            f"kernel {fn_type} input[{arg_index}] column count mismatch: got {n_got}, expected {n_want}"
        )
    raise TypeError(
        f"kernel {fn_type} input[{arg_index}] expects TableLike, got {type(value).__name__}"
    )
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
def _validate_tensor_arg(
    fn_type: str, arg_index: int, spec: TensorType, value: Any
) -> None:
    """Check a kernel tensor argument against its declared ``TensorType``.

    Python scalars (int/float/bool/complex) are treated as rank-0 values
    whose dtype is derived from their Python type. Anything else must be
    ``TensorLike`` and is checked through its ``shape``/``dtype`` attributes.
    Negative spec dimensions are not compared (treated as wildcards).

    Raises
    ------
    TypeError
        If ``value`` is neither a Python scalar nor ``TensorLike``, or its
        dtype object cannot be converted by ``DType.from_any``.
    ValueError
        On a rank, dimension, or dtype mismatch.
    """
    # Sentinel spec for backend-only handles (e.g. PHE keys): shape (-1, 0)
    # with UINT8 dtype skips every structural check.
    if tuple(spec.shape) == (-1, 0) and spec.dtype == UINT8:
        return

    if isinstance(value, (int, float, bool, complex)):
        got_shape: tuple[Any, ...] = ()
        raw_dtype: Any = type(value)
    elif isinstance(value, TensorLike):
        got_shape = getattr(value, "shape", ())
        raw_dtype = getattr(value, "dtype", None)
    else:
        raise TypeError(
            f"kernel {fn_type} input[{arg_index}] expects TensorLike, got {type(value).__name__}"
        )

    if len(got_shape) != len(spec.shape):
        raise ValueError(
            f"kernel {fn_type} input[{arg_index}] rank mismatch: got {got_shape}, expected {spec.shape}"
        )

    paired_dims = zip(spec.shape, got_shape, strict=True)
    for axis, (want, have) in enumerate(paired_dims):
        if want < 0 or want == have:
            continue
        raise ValueError(
            f"kernel {fn_type} input[{arg_index}] shape mismatch at dim {axis}: got {have}, expected {want}"
        )

    try:
        got_dtype = DType.from_any(raw_dtype)
    except (ValueError, TypeError):  # pragma: no cover
        raise TypeError(
            f"kernel {fn_type} input[{arg_index}] has unsupported dtype object {raw_dtype!r}"
        ) from None
    if got_dtype != spec.dtype:
        raise ValueError(
            f"kernel {fn_type} input[{arg_index}] dtype mismatch: got {got_dtype}, expected {spec.dtype}"
        )
|
mplang/v1/kernels/crypto.py
DELETED
|
@@ -1,122 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import os
|
|
18
|
-
|
|
19
|
-
import numpy as np
|
|
20
|
-
|
|
21
|
-
from mplang.v1.core import PFunction
|
|
22
|
-
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
23
|
-
from mplang.v1.kernels.value import TensorValue
|
|
24
|
-
from mplang.v1.utils.crypto import blake2b
|
|
25
|
-
|
|
26
|
-
__all__: list[str] = []  # No public exports currently


def _get_rng() -> np.random.Generator:
    """Return this rank's crypto RNG, creating and caching it on first use.

    The generator is kept in runtime state under ``"crypto.rng"``. When it is
    first created, it is seeded with ``MPLANG_CRYPTO_SEED`` (default ``0``)
    plus ``rank * 7919``, so each rank gets its own deterministic seed.
    Runtime state is untyped, so the cached object is narrowed explicitly
    for mypy.
    """
    kctx = cur_kctx()
    runtime = kctx.runtime
    cached = runtime.get_state("crypto.rng")
    if cached is not None:
        assert isinstance(cached, np.random.Generator)  # narrow
        return cached
    base_seed = int(os.environ.get("MPLANG_CRYPTO_SEED", "0"))
    fresh = np.random.default_rng(base_seed + kctx.rank * 7919)
    runtime.set_state("crypto.rng", fresh)
    return fresh
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
def _keystream(key: bytes, nonce: bytes, length: int) -> bytes:
    """Return ``length`` pseudo-random bytes derived from ``key`` and ``nonce``.

    WARNING (INSECURE): mock hash-based stream for the v1 crypto kernels,
    not a vetted cipher.

    The stream is the concatenation of ``blake2b(key || nonce || counter)``
    for counter = 0, 1, 2, ..., with each counter encoded as 8 little-endian
    bytes, truncated to ``length``. Mixing in the counter is what makes
    successive blocks differ. Hashing ``key || nonce`` alone would repeat the
    same block, so ciphertext bytes one digest apart would share keystream
    and leak the XOR of the corresponding plaintext bytes. A non-positive
    ``length`` yields ``b""``.
    """
    out = bytearray()
    counter = 0
    while len(out) < length:
        out.extend(blake2b(key + nonce + counter.to_bytes(8, "little")))
        counter += 1
    return bytes(out[:length])
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
@kernel_def("crypto.keygen")
def _crypto_keygen(pfunc: PFunction) -> TensorValue:
    """Draw a symmetric key of ``attrs["length"]`` bytes (default 32).

    Bytes come from this rank's crypto RNG (see ``_get_rng``) as a 1-D
    uint8 tensor.
    """
    n_bytes = int(pfunc.attrs.get("length", 32))
    key_bytes = _get_rng().integers(0, 256, size=(n_bytes,), dtype=np.uint8)
    return TensorValue(key_bytes)
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
@kernel_def("crypto.enc")
def _crypto_encrypt(
    pfunc: PFunction, pt_bytes: TensorValue, key: TensorValue
) -> TensorValue:
    """Encrypt ``pt_bytes`` under ``key`` with the mock XOR stream cipher.

    A fresh 16-byte nonce is drawn from the rank's crypto RNG. The result
    is ``nonce || (plaintext XOR keystream)`` as a uint8 tensor.
    NOTE(review): the XOR and concatenation assume a 1-D byte plaintext;
    multi-dimensional inputs are not flattened. TODO confirm callers.
    """
    plaintext = pt_bytes.to_numpy().astype(np.uint8, copy=False)
    key_arr = key.to_numpy().astype(np.uint8, copy=False)
    nonce = _get_rng().integers(0, 256, size=(16,), dtype=np.uint8)
    pad_bytes = _keystream(key_arr.tobytes(), nonce.tobytes(), plaintext.size)
    pad = np.frombuffer(pad_bytes, dtype=np.uint8)
    ciphertext = (plaintext ^ pad).astype(np.uint8)
    packed = np.concatenate([nonce, ciphertext]).astype(np.uint8)
    return TensorValue(packed)
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
@kernel_def("crypto.dec")
def _crypto_decrypt(
    pfunc: PFunction, ct_with_nonce: TensorValue, key: TensorValue
) -> TensorValue:
    """Invert ``crypto.enc``: split off the 16-byte nonce and XOR out the keystream.

    Parameters
    ----------
    pfunc : PFunction
        Unused; present for the kernel calling convention.
    ct_with_nonce : TensorValue
        Bytes laid out as ``nonce (16 bytes) || ciphertext``.
    key : TensorValue
        Symmetric key bytes.

    Returns
    -------
    TensorValue
        uint8 plaintext, one byte per ciphertext byte.

    Raises
    ------
    ValueError
        If the input holds fewer than 16 bytes and so cannot contain a nonce.
    """
    ct_np = ct_with_nonce.to_numpy().astype(np.uint8, copy=False)
    key_np = key.to_numpy().astype(np.uint8, copy=False)
    # Without this check a truncated buffer would be silently "decrypted"
    # with a short, wrong nonce into an empty plaintext.
    if ct_np.size < 16:
        raise ValueError(
            f"crypto.dec: input has {ct_np.size} bytes, fewer than the 16-byte nonce prefix"
        )
    nonce = ct_np[:16]
    ct = ct_np[16:]
    stream = np.frombuffer(
        _keystream(key_np.tobytes(), nonce.tobytes(), len(ct)), dtype=np.uint8
    )
    pt_bytes = (ct ^ stream).astype(np.uint8)
    return TensorValue(pt_bytes)
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
@kernel_def("crypto.kem_keygen")
def _crypto_kem_keygen(pfunc: PFunction) -> tuple[TensorValue, TensorValue]:
    """Generate a mock KEM key pair ``(sk, pk)``.

    ``sk`` is 32 bytes from the rank's crypto RNG. ``pk`` is the first 32
    bytes of ``blake2b(sk)``. This is an insecure mock, not a real KEM.
    """
    secret_key = _get_rng().integers(0, 256, size=(32,), dtype=np.uint8)
    digest = blake2b(secret_key.tobytes())
    public_key = np.frombuffer(digest[:32], dtype=np.uint8)
    return TensorValue(secret_key), TensorValue(public_key)
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
@kernel_def("crypto.kem_derive")
def _crypto_kem_derive(
    pfunc: PFunction, sk: TensorValue, peer_pk: TensorValue
) -> TensorValue:
    """Derive a mock shared secret from our ``sk`` and the peer's public key.

    Our public key is recomputed as ``blake2b(sk)[:32]`` and XOR-ed with
    ``peer_pk``. The result is ``blake2b(xor)[:32]``. Because XOR is
    symmetric, both parties obtain the same value.
    WARNING (INSECURE): the output depends only on the two public keys, so
    anyone who sees both can compute it. This is a mock only.
    """
    sk_raw = sk.to_numpy().astype(np.uint8, copy=False)
    peer_raw = peer_pk.to_numpy().astype(np.uint8, copy=False)

    own_pk = np.frombuffer(blake2b(sk_raw.tobytes())[:32], dtype=np.uint8)
    mixed = (own_pk ^ peer_raw).astype(np.uint8)
    shared_digest = blake2b(mixed.tobytes())[:32]
    return TensorValue(np.frombuffer(shared_digest, dtype=np.uint8))
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
@kernel_def("crypto.hkdf")
def _crypto_hkdf(pfunc: PFunction, secret: TensorValue) -> TensorValue:
    """Derive key material from ``secret`` and the ``info`` attribute.

    Mock HKDF: a single ``blake2b(secret || utf8(info))`` truncated to its
    first 32 bytes. There is no salt and no extract/expand step.
    ``info`` defaults to the empty string.
    """
    ikm = secret.to_numpy().astype(np.uint8, copy=False).tobytes()
    info_bytes = str(pfunc.attrs.get("info", "")).encode("utf-8")
    okm = blake2b(ikm + info_bytes)[:32]
    return TensorValue(np.frombuffer(okm, dtype=np.uint8))
|