mplang-nightly 0.1.dev142__py3-none-any.whl → 0.1.dev144__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/backend/__init__.py +0 -7
- mplang/backend/base.py +71 -183
- mplang/backend/context.py +255 -0
- mplang/backend/phe.py +1448 -91
- mplang/backend/spu.py +6 -4
- mplang/backend/sql_duckdb.py +1 -1
- mplang/core/expr/evaluator.py +6 -6
- mplang/frontend/base.py +1 -1
- mplang/frontend/ibis_cc.py +2 -1
- mplang/frontend/phe.py +140 -3
- mplang/frontend/spu.py +4 -3
- mplang/runtime/resource.py +39 -62
- mplang/runtime/simulation.py +6 -13
- {mplang_nightly-0.1.dev142.dist-info → mplang_nightly-0.1.dev144.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev142.dist-info → mplang_nightly-0.1.dev144.dist-info}/RECORD +18 -17
- {mplang_nightly-0.1.dev142.dist-info → mplang_nightly-0.1.dev144.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev142.dist-info → mplang_nightly-0.1.dev144.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev142.dist-info → mplang_nightly-0.1.dev144.dist-info}/licenses/LICENSE +0 -0
mplang/backend/__init__.py
CHANGED
@@ -11,10 +11,3 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
|
15
|
-
"""
|
16
|
-
Backend module for mplang.
|
17
|
-
|
18
|
-
This module contains handlers that execute serialized functions on individual
|
19
|
-
parties in a multi-party computation system.
|
20
|
-
"""
|
mplang/backend/base.py
CHANGED
@@ -12,12 +12,20 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
"""
|
15
|
+
"""Backend kernel registry & per-participant runtime (explicit op->kernel binding).
|
16
16
|
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
17
|
+
This version decouples *kernel implementation registration* from *operation binding*.
|
18
|
+
|
19
|
+
Concepts:
|
20
|
+
* kernel_id: unique identifier of a concrete backend implementation.
|
21
|
+
* op_type: semantic operation name carried by ``PFunction.fn_type``.
|
22
|
+
* bind_op(op_type, kernel_id): performed by higher layer (see ``backend.context``)
|
23
|
+
to select which implementation handles an op. Runtime dispatch is now a 2-step:
|
24
|
+
pfunc.fn_type -> active kernel_id -> KernelSpec.fn
|
25
|
+
|
26
|
+
The previous implicit "import == register+bind" coupling is removed. Kernel
|
27
|
+
modules only call ``@kernel_def(kernel_id)``. Default bindings are established
|
28
|
+
centrally (lazy) the first time a runtime executes a kernel.
|
21
29
|
"""
|
22
30
|
|
23
31
|
from __future__ import annotations
|
@@ -27,22 +35,17 @@ from collections.abc import Callable
|
|
27
35
|
from dataclasses import dataclass
|
28
36
|
from typing import Any
|
29
37
|
|
30
|
-
from mplang.core.dtype import UINT8, DType
|
31
|
-
from mplang.core.pfunc import PFunction
|
32
|
-
from mplang.core.table import TableLike, TableType
|
33
|
-
from mplang.core.tensor import TensorLike, TensorType
|
34
|
-
|
35
38
|
__all__ = [
|
36
|
-
"BackendRuntime",
|
37
39
|
"KernelContext",
|
38
|
-
"
|
40
|
+
"KernelSpec",
|
41
|
+
"bind_op",
|
39
42
|
"cur_kctx",
|
40
|
-
"
|
41
|
-
"
|
43
|
+
"get_kernel_for_op",
|
44
|
+
"list_kernels",
|
45
|
+
"list_ops",
|
46
|
+
"unbind_op",
|
42
47
|
]
|
43
48
|
|
44
|
-
# ---------------- Context ----------------
|
45
|
-
|
46
49
|
|
47
50
|
@dataclass
|
48
51
|
class KernelContext:
|
@@ -99,189 +102,74 @@ def cur_kctx() -> KernelContext:
|
|
99
102
|
|
100
103
|
# ---------------- Registry ----------------
|
101
104
|
|
102
|
-
#
|
103
|
-
# - No **kwargs (explicitly disallowed)
|
104
|
-
# - Return normalization handled by BackendRuntime.run_kernel
|
105
|
+
# Kernel callable signature: (pfunc, *args) -> Any | sequence (no **kwargs)
|
105
106
|
KernelFn = Callable[..., Any]
|
106
107
|
|
107
|
-
_KERNELS: dict[str, KernelFn] = {}
|
108
108
|
|
109
|
+
@dataclass
|
110
|
+
class KernelSpec:
|
111
|
+
kernel_id: str
|
112
|
+
fn: KernelFn
|
113
|
+
meta: dict[str, Any]
|
109
114
|
|
110
|
-
def _validate_table_arg(
|
111
|
-
fn_type: str, arg_index: int, spec: TableType, value: Any
|
112
|
-
) -> None:
|
113
|
-
if not isinstance(value, TableLike):
|
114
|
-
raise TypeError(
|
115
|
-
f"kernel {fn_type} input[{arg_index}] expects TableLike, got {type(value).__name__}"
|
116
|
-
)
|
117
|
-
if len(value.columns) != len(spec.columns):
|
118
|
-
raise ValueError(
|
119
|
-
f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(value.columns)}, expected {len(spec.columns)}"
|
120
|
-
)
|
121
115
|
|
116
|
+
# All registered kernel implementations: kernel_id -> spec
|
117
|
+
_KERNELS: dict[str, KernelSpec] = {}
|
118
|
+
|
119
|
+
# Active op bindings: op_type -> kernel_id
|
120
|
+
_BINDINGS: dict[str, str] = {}
|
122
121
|
|
123
|
-
def _validate_tensor_arg(
|
124
|
-
fn_type: str, arg_index: int, spec: TensorType, value: Any
|
125
|
-
) -> None:
|
126
|
-
# Backend-only handle sentinel (e.g., PHE keys) bypasses all structural checks
|
127
|
-
if tuple(spec.shape) == (-1, 0) and spec.dtype == UINT8:
|
128
|
-
return
|
129
122
|
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
raise TypeError(
|
136
|
-
f"kernel {fn_type} input[{arg_index}] expects TensorLike, got {type(value).__name__}"
|
137
|
-
)
|
138
|
-
val_shape = getattr(value, "shape", ())
|
139
|
-
duck_dtype = getattr(value, "dtype", None)
|
140
|
-
|
141
|
-
if len(spec.shape) != len(val_shape):
|
142
|
-
raise ValueError(
|
143
|
-
f"kernel {fn_type} input[{arg_index}] rank mismatch: got {val_shape}, expected {spec.shape}"
|
144
|
-
)
|
145
|
-
|
146
|
-
for dim_idx, (spec_dim, val_dim) in enumerate(
|
147
|
-
zip(spec.shape, val_shape, strict=True)
|
148
|
-
):
|
149
|
-
if spec_dim >= 0 and spec_dim != val_dim:
|
150
|
-
raise ValueError(
|
151
|
-
f"kernel {fn_type} input[{arg_index}] shape mismatch at dim {dim_idx}: got {val_dim}, expected {spec_dim}"
|
152
|
-
)
|
153
|
-
|
154
|
-
try:
|
155
|
-
val_dtype = DType.from_any(duck_dtype)
|
156
|
-
except (ValueError, TypeError): # pragma: no cover
|
157
|
-
raise TypeError(
|
158
|
-
f"kernel {fn_type} input[{arg_index}] has unsupported dtype object {duck_dtype!r}"
|
159
|
-
) from None
|
160
|
-
if val_dtype != spec.dtype:
|
161
|
-
raise ValueError(
|
162
|
-
f"kernel {fn_type} input[{arg_index}] dtype mismatch: got {val_dtype}, expected {spec.dtype}"
|
163
|
-
)
|
164
|
-
|
165
|
-
|
166
|
-
def kernel_def(fn_type: str) -> Callable[[KernelFn], KernelFn]:
|
167
|
-
"""Decorator to register a backend kernel (new signature).
|
168
|
-
|
169
|
-
Expected Python signature form:
|
170
|
-
|
171
|
-
@kernel_def("namespace.op")
|
172
|
-
def _op(pfunc: PFunction, *args): ...
|
173
|
-
|
174
|
-
Rules:
|
175
|
-
* First parameter MUST be the PFunction object.
|
176
|
-
* Positional arguments correspond 1:1 to pfunc.ins_info order.
|
177
|
-
* **kwargs are NOT supported (will raise at call site if used).
|
178
|
-
* Return value forms accepted (n = len(pfunc.outs_info)):
|
179
|
-
- n == 0: return None / () / []
|
180
|
-
- n == 1: return scalar/object OR (value,) / [value]
|
181
|
-
- n > 1 : return tuple/list of length n
|
182
|
-
Anything else raises a ValueError.
|
123
|
+
def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]:
|
124
|
+
"""Decorator to register a concrete kernel implementation.
|
125
|
+
|
126
|
+
This ONLY registers the implementation (kernel_id -> fn). It does NOT bind
|
127
|
+
any op. Higher layer must call ``bind_op(op_type, kernel_id)`` explicitly.
|
183
128
|
"""
|
184
129
|
|
185
130
|
def _decorator(fn: KernelFn) -> KernelFn:
|
186
|
-
if
|
187
|
-
raise ValueError(f"duplicate
|
188
|
-
_KERNELS[
|
131
|
+
if kernel_id in _KERNELS:
|
132
|
+
raise ValueError(f"duplicate kernel_id={kernel_id}")
|
133
|
+
_KERNELS[kernel_id] = KernelSpec(kernel_id=kernel_id, fn=fn, meta=dict(meta))
|
189
134
|
return fn
|
190
135
|
|
191
136
|
return _decorator
|
192
137
|
|
193
138
|
|
194
|
-
def
|
195
|
-
|
139
|
+
def bind_op(op_type: str, kernel_id: str, *, force: bool = True) -> None:
|
140
|
+
"""Bind an op_type to a registered kernel implementation.
|
196
141
|
|
142
|
+
Args:
|
143
|
+
op_type: Semantic operation name.
|
144
|
+
kernel_id: Previously registered kernel identifier.
|
145
|
+
force: If False and op_type already bound, keep existing binding.
|
146
|
+
If True (default), overwrite.
|
147
|
+
"""
|
148
|
+
if kernel_id not in _KERNELS:
|
149
|
+
raise KeyError(f"kernel_id {kernel_id} not registered")
|
150
|
+
if not force and op_type in _BINDINGS:
|
151
|
+
return
|
152
|
+
_BINDINGS[op_type] = kernel_id
|
197
153
|
|
198
|
-
class BackendRuntime:
|
199
|
-
"""Per-rank backend execution environment.
|
200
154
|
|
201
|
-
|
202
|
-
|
203
|
-
|
155
|
+
def unbind_op(op_type: str) -> None:
|
156
|
+
_BINDINGS.pop(op_type, None)
|
157
|
+
|
158
|
+
|
159
|
+
def get_kernel_for_op(op_type: str) -> KernelSpec:
|
160
|
+
kid = _BINDINGS.get(op_type)
|
161
|
+
if kid is None:
|
162
|
+
# Tests expect NotImplementedError for unsupported operations
|
163
|
+
raise NotImplementedError(f"no backend kernel registered for op {op_type}")
|
164
|
+
spec = _KERNELS.get(kid)
|
165
|
+
if spec is None: # inconsistent state
|
166
|
+
raise RuntimeError(f"active kernel_id {kid} missing spec")
|
167
|
+
return spec
|
168
|
+
|
169
|
+
|
170
|
+
def list_kernels() -> list[str]:
|
171
|
+
return sorted(_KERNELS.keys())
|
172
|
+
|
204
173
|
|
205
|
-
|
206
|
-
|
207
|
-
self.world_size = world_size
|
208
|
-
self.state: dict[str, dict[str, Any]] = {}
|
209
|
-
self.cache: dict[str, Any] = {}
|
210
|
-
|
211
|
-
# Main entry
|
212
|
-
def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
|
213
|
-
fn_type = pfunc.fn_type
|
214
|
-
fn = _KERNELS.get(fn_type)
|
215
|
-
if fn is None:
|
216
|
-
raise NotImplementedError(f"no backend kernel registered for {fn_type}")
|
217
|
-
|
218
|
-
# Strict positional arg count validation (no kernel-managed arity bypass)
|
219
|
-
if len(arg_list) != len(pfunc.ins_info):
|
220
|
-
raise ValueError(
|
221
|
-
f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
|
222
|
-
)
|
223
|
-
|
224
|
-
for idx, (spec, val) in enumerate(zip(pfunc.ins_info, arg_list, strict=True)):
|
225
|
-
if isinstance(spec, TableType):
|
226
|
-
_validate_table_arg(fn_type, idx, spec, val)
|
227
|
-
continue
|
228
|
-
|
229
|
-
if isinstance(spec, TensorType):
|
230
|
-
_validate_tensor_arg(fn_type, idx, spec, val)
|
231
|
-
continue
|
232
|
-
|
233
|
-
# Unknown spec type: silently skip validation (legacy behavior)
|
234
|
-
continue
|
235
|
-
|
236
|
-
kctx = KernelContext(
|
237
|
-
rank=self.rank,
|
238
|
-
world_size=self.world_size,
|
239
|
-
state=self.state,
|
240
|
-
cache=self.cache,
|
241
|
-
)
|
242
|
-
token = _CTX_VAR.set(kctx)
|
243
|
-
try:
|
244
|
-
raw = fn(pfunc, *arg_list)
|
245
|
-
finally:
|
246
|
-
_CTX_VAR.reset(token)
|
247
|
-
|
248
|
-
# Normalize return values
|
249
|
-
expected = len(pfunc.outs_info)
|
250
|
-
if expected == 0:
|
251
|
-
if raw in (None, (), []):
|
252
|
-
return []
|
253
|
-
raise ValueError(
|
254
|
-
f"kernel {fn_type} should return no values; got {type(raw).__name__}"
|
255
|
-
)
|
256
|
-
|
257
|
-
# If multi-output expected, raw must be sequence of right length
|
258
|
-
if expected == 1:
|
259
|
-
if isinstance(raw, (tuple, list)):
|
260
|
-
if len(raw) != 1:
|
261
|
-
raise ValueError(
|
262
|
-
f"kernel {fn_type} produced {len(raw)} outputs, expected 1"
|
263
|
-
)
|
264
|
-
return [raw[0]]
|
265
|
-
# Single object
|
266
|
-
return [raw]
|
267
|
-
|
268
|
-
# expected > 1
|
269
|
-
if not isinstance(raw, (tuple, list)):
|
270
|
-
raise TypeError(
|
271
|
-
f"kernel {fn_type} must return sequence (len={expected}), got {type(raw).__name__}"
|
272
|
-
)
|
273
|
-
if len(raw) != expected:
|
274
|
-
raise ValueError(
|
275
|
-
f"kernel {fn_type} produced {len(raw)} outputs, expected {expected}"
|
276
|
-
)
|
277
|
-
return list(raw)
|
278
|
-
|
279
|
-
# Optional helper
|
280
|
-
def reset(self) -> None: # pragma: no cover - simple
|
281
|
-
self.state.clear()
|
282
|
-
self.cache.clear()
|
283
|
-
|
284
|
-
|
285
|
-
def create_runtime(rank: int, world_size: int) -> BackendRuntime:
|
286
|
-
"""Factory for BackendRuntime (allows future policy injection)."""
|
287
|
-
return BackendRuntime(rank, world_size)
|
174
|
+
def list_ops() -> list[str]:
|
175
|
+
return sorted(_BINDINGS.keys())
|
@@ -0,0 +1,255 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
from collections.abc import Mapping
|
18
|
+
from dataclasses import dataclass, field
|
19
|
+
from typing import Any
|
20
|
+
|
21
|
+
from mplang.backend import base
|
22
|
+
from mplang.backend.base import KernelContext, bind_op, get_kernel_for_op
|
23
|
+
from mplang.core.dtype import UINT8, DType
|
24
|
+
from mplang.core.pfunc import PFunction
|
25
|
+
from mplang.core.table import TableLike, TableType
|
26
|
+
from mplang.core.tensor import TensorLike, TensorType
|
27
|
+
|
28
|
+
# Default bindings
|
29
|
+
# Import kernel implementation modules explicitly so their @kernel_def entries
|
30
|
+
# register at import time. Keep imports grouped; alias with leading underscore
|
31
|
+
# to silence unused variable warnings without F401 pragmas.
|
32
|
+
_IMPL_IMPORTED = False
|
33
|
+
|
34
|
+
|
35
|
+
def _ensure_impl_imported() -> None:
|
36
|
+
global _IMPL_IMPORTED
|
37
|
+
if _IMPL_IMPORTED:
|
38
|
+
return
|
39
|
+
from mplang.backend import builtin as _impl_builtin # noqa: F401
|
40
|
+
from mplang.backend import crypto as _impl_crypto # noqa: F401
|
41
|
+
from mplang.backend import phe as _impl_phe # noqa: F401
|
42
|
+
from mplang.backend import spu as _impl_spu # noqa: F401
|
43
|
+
from mplang.backend import sql_duckdb as _impl_sql_duckdb # noqa: F401
|
44
|
+
from mplang.backend import stablehlo as _impl_stablehlo # noqa: F401
|
45
|
+
from mplang.backend import tee as _impl_tee # noqa: F401
|
46
|
+
|
47
|
+
_IMPL_IMPORTED = True
|
48
|
+
|
49
|
+
|
50
|
+
# imports consolidated above
|
51
|
+
|
52
|
+
_DEFAULT_BINDINGS: dict[str, str] = {
|
53
|
+
# builtin
|
54
|
+
"builtin.identity": "builtin.identity",
|
55
|
+
"builtin.read": "builtin.read",
|
56
|
+
"builtin.write": "builtin.write",
|
57
|
+
"builtin.constant": "builtin.constant",
|
58
|
+
"builtin.rank": "builtin.rank",
|
59
|
+
"builtin.prand": "builtin.prand",
|
60
|
+
"builtin.table_to_tensor": "builtin.table_to_tensor",
|
61
|
+
"builtin.tensor_to_table": "builtin.tensor_to_table",
|
62
|
+
"builtin.debug_print": "builtin.debug_print",
|
63
|
+
"builtin.pack": "builtin.pack",
|
64
|
+
"builtin.unpack": "builtin.unpack",
|
65
|
+
# crypto
|
66
|
+
"crypto.keygen": "crypto.keygen",
|
67
|
+
"crypto.enc": "crypto.enc",
|
68
|
+
"crypto.dec": "crypto.dec",
|
69
|
+
"crypto.kem_keygen": "crypto.kem_keygen",
|
70
|
+
"crypto.kem_derive": "crypto.kem_derive",
|
71
|
+
"crypto.hkdf": "crypto.hkdf",
|
72
|
+
# phe
|
73
|
+
"phe.keygen": "phe.keygen",
|
74
|
+
"phe.encrypt": "phe.encrypt",
|
75
|
+
"phe.mul": "phe.mul",
|
76
|
+
"phe.add": "phe.add",
|
77
|
+
"phe.decrypt": "phe.decrypt",
|
78
|
+
"phe.dot": "phe.dot",
|
79
|
+
"phe.gather": "phe.gather",
|
80
|
+
"phe.scatter": "phe.scatter",
|
81
|
+
"phe.concat": "phe.concat",
|
82
|
+
"phe.reshape": "phe.reshape",
|
83
|
+
"phe.transpose": "phe.transpose",
|
84
|
+
# spu
|
85
|
+
"spu.seed_env": "spu.seed_env",
|
86
|
+
"spu.makeshares": "spu.makeshares",
|
87
|
+
"spu.reconstruct": "spu.reconstruct",
|
88
|
+
"spu.run_pphlo": "spu.run_pphlo",
|
89
|
+
# stablehlo
|
90
|
+
"mlir.stablehlo": "mlir.stablehlo",
|
91
|
+
# sql
|
92
|
+
# generic SQL op; backend-specific kernel id for duckdb
|
93
|
+
"sql.run": "duckdb.run_sql",
|
94
|
+
# tee
|
95
|
+
"tee.quote": "tee.quote",
|
96
|
+
"tee.attest": "tee.attest",
|
97
|
+
}
|
98
|
+
|
99
|
+
|
100
|
+
# --- RuntimeContext ---
|
101
|
+
|
102
|
+
|
103
|
+
@dataclass
|
104
|
+
class RuntimeContext:
|
105
|
+
rank: int
|
106
|
+
world_size: int
|
107
|
+
bindings: Mapping[str, str] | None = None # optional overrides
|
108
|
+
state: dict[str, dict[str, Any]] = field(default_factory=dict)
|
109
|
+
cache: dict[str, Any] = field(default_factory=dict)
|
110
|
+
stats: dict[str, Any] = field(default_factory=dict)
|
111
|
+
|
112
|
+
def __post_init__(self) -> None:
|
113
|
+
_ensure_impl_imported()
|
114
|
+
if self.bindings is not None:
|
115
|
+
for op, kid in self.bindings.items():
|
116
|
+
bind_op(op, kid)
|
117
|
+
else:
|
118
|
+
for op, kid in _DEFAULT_BINDINGS.items():
|
119
|
+
bind_op(op, kid)
|
120
|
+
# Initialize stats pocket
|
121
|
+
self.stats.setdefault("op_calls", {})
|
122
|
+
|
123
|
+
def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
|
124
|
+
fn_type = pfunc.fn_type
|
125
|
+
spec = get_kernel_for_op(fn_type)
|
126
|
+
fn = spec.fn
|
127
|
+
if len(arg_list) != len(pfunc.ins_info):
|
128
|
+
raise ValueError(
|
129
|
+
f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
|
130
|
+
)
|
131
|
+
for idx, (ins_spec, val) in enumerate(
|
132
|
+
zip(pfunc.ins_info, arg_list, strict=True)
|
133
|
+
):
|
134
|
+
if isinstance(ins_spec, TableType):
|
135
|
+
_validate_table_arg(fn_type, idx, ins_spec, val)
|
136
|
+
continue
|
137
|
+
if isinstance(ins_spec, TensorType):
|
138
|
+
_validate_tensor_arg(fn_type, idx, ins_spec, val)
|
139
|
+
continue
|
140
|
+
# install kernel context
|
141
|
+
kctx = KernelContext(
|
142
|
+
rank=self.rank,
|
143
|
+
world_size=self.world_size,
|
144
|
+
state=self.state,
|
145
|
+
cache=self.cache,
|
146
|
+
)
|
147
|
+
token = base._CTX_VAR.set(kctx) # type: ignore[attr-defined]
|
148
|
+
try:
|
149
|
+
raw = fn(pfunc, *arg_list)
|
150
|
+
finally:
|
151
|
+
base._CTX_VAR.reset(token) # type: ignore[attr-defined]
|
152
|
+
# Stats (best effort)
|
153
|
+
try:
|
154
|
+
op_calls = self.stats.setdefault("op_calls", {})
|
155
|
+
op_calls[fn_type] = op_calls.get(fn_type, 0) + 1
|
156
|
+
except Exception: # pragma: no cover - never raise due to stats
|
157
|
+
pass
|
158
|
+
expected = len(pfunc.outs_info)
|
159
|
+
if expected == 0:
|
160
|
+
if raw in (None, (), []):
|
161
|
+
return []
|
162
|
+
raise ValueError(
|
163
|
+
f"kernel {fn_type} should return no values; got {type(raw).__name__}"
|
164
|
+
)
|
165
|
+
if expected == 1:
|
166
|
+
if isinstance(raw, (tuple, list)):
|
167
|
+
if len(raw) != 1:
|
168
|
+
raise ValueError(
|
169
|
+
f"kernel {fn_type} produced {len(raw)} outputs, expected 1"
|
170
|
+
)
|
171
|
+
return [raw[0]]
|
172
|
+
return [raw]
|
173
|
+
if not isinstance(raw, (tuple, list)):
|
174
|
+
raise TypeError(
|
175
|
+
f"kernel {fn_type} must return sequence (len={expected}), got {type(raw).__name__}"
|
176
|
+
)
|
177
|
+
if len(raw) != expected:
|
178
|
+
raise ValueError(
|
179
|
+
f"kernel {fn_type} produced {len(raw)} outputs, expected {expected}"
|
180
|
+
)
|
181
|
+
return list(raw)
|
182
|
+
|
183
|
+
def reset(self) -> None:
|
184
|
+
self.state.clear()
|
185
|
+
self.cache.clear()
|
186
|
+
|
187
|
+
# ---- explicit (re)binding API ----
|
188
|
+
def bind_op(self, op_type: str, kernel_id: str, *, force: bool = False) -> None:
|
189
|
+
"""Bind an operation to a kernel at runtime.
|
190
|
+
|
191
|
+
force=False (default) preserves any existing binding to avoid accidental
|
192
|
+
silent overrides. Use ``rebind_op`` or ``force=True`` to intentionally
|
193
|
+
change a binding.
|
194
|
+
"""
|
195
|
+
base.bind_op(op_type, kernel_id, force=force)
|
196
|
+
|
197
|
+
def rebind_op(self, op_type: str, kernel_id: str) -> None:
|
198
|
+
"""Force rebind an operation to a different kernel (shorthand)."""
|
199
|
+
base.bind_op(op_type, kernel_id, force=True)
|
200
|
+
|
201
|
+
|
202
|
+
def _validate_table_arg(
|
203
|
+
fn_type: str, arg_index: int, spec: TableType, value: Any
|
204
|
+
) -> None:
|
205
|
+
if not isinstance(value, TableLike):
|
206
|
+
raise TypeError(
|
207
|
+
f"kernel {fn_type} input[{arg_index}] expects TableLike, got {type(value).__name__}"
|
208
|
+
)
|
209
|
+
if len(value.columns) != len(spec.columns):
|
210
|
+
raise ValueError(
|
211
|
+
f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(value.columns)}, expected {len(spec.columns)}"
|
212
|
+
)
|
213
|
+
|
214
|
+
|
215
|
+
def _validate_tensor_arg(
|
216
|
+
fn_type: str, arg_index: int, spec: TensorType, value: Any
|
217
|
+
) -> None:
|
218
|
+
# Backend-only handle sentinel (e.g., PHE keys) bypasses all structural checks
|
219
|
+
if tuple(spec.shape) == (-1, 0) and spec.dtype == UINT8:
|
220
|
+
return
|
221
|
+
|
222
|
+
if isinstance(value, (int, float, bool, complex)):
|
223
|
+
val_shape: tuple[Any, ...] = ()
|
224
|
+
duck_dtype: Any = type(value)
|
225
|
+
else:
|
226
|
+
if not isinstance(value, TensorLike):
|
227
|
+
raise TypeError(
|
228
|
+
f"kernel {fn_type} input[{arg_index}] expects TensorLike, got {type(value).__name__}"
|
229
|
+
)
|
230
|
+
val_shape = getattr(value, "shape", ())
|
231
|
+
duck_dtype = getattr(value, "dtype", None)
|
232
|
+
|
233
|
+
if len(spec.shape) != len(val_shape):
|
234
|
+
raise ValueError(
|
235
|
+
f"kernel {fn_type} input[{arg_index}] rank mismatch: got {val_shape}, expected {spec.shape}"
|
236
|
+
)
|
237
|
+
|
238
|
+
for dim_idx, (spec_dim, val_dim) in enumerate(
|
239
|
+
zip(spec.shape, val_shape, strict=True)
|
240
|
+
):
|
241
|
+
if spec_dim >= 0 and spec_dim != val_dim:
|
242
|
+
raise ValueError(
|
243
|
+
f"kernel {fn_type} input[{arg_index}] shape mismatch at dim {dim_idx}: got {val_dim}, expected {spec_dim}"
|
244
|
+
)
|
245
|
+
|
246
|
+
try:
|
247
|
+
val_dtype = DType.from_any(duck_dtype)
|
248
|
+
except (ValueError, TypeError): # pragma: no cover
|
249
|
+
raise TypeError(
|
250
|
+
f"kernel {fn_type} input[{arg_index}] has unsupported dtype object {duck_dtype!r}"
|
251
|
+
) from None
|
252
|
+
if val_dtype != spec.dtype:
|
253
|
+
raise ValueError(
|
254
|
+
f"kernel {fn_type} input[{arg_index}] dtype mismatch: got {val_dtype}, expected {spec.dtype}"
|
255
|
+
)
|