mplang-nightly 0.1.dev142__py3-none-any.whl → 0.1.dev144__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
@@ -11,10 +11,3 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
-
- """
- Backend module for mplang.
-
- This module contains handlers that execute serialized functions on individual
- parties in a multi-party computation system.
- """
mplang/backend/base.py CHANGED
@@ -12,12 +12,20 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- """Flat backend kernel registry & per-participant runtime.
+ """Backend kernel registry & per-participant runtime (explicit op->kernel binding).

- Design revision:
- - Global, stateless kernel function catalog (fn_type -> callable).
- - BackendRuntime: per-rank state & cache; executes kernels.
- - Legacy global helpers removed after full migration to explicit runtimes.
+ This version decouples *kernel implementation registration* from *operation binding*.
+
+ Concepts:
+ * kernel_id: unique identifier of a concrete backend implementation.
+ * op_type: semantic operation name carried by ``PFunction.fn_type``.
+ * bind_op(op_type, kernel_id): performed by higher layer (see ``backend.context``)
+   to select which implementation handles an op. Runtime dispatch is now a 2-step:
+   pfunc.fn_type -> active kernel_id -> KernelSpec.fn
+
+ The previous implicit "import == register+bind" coupling is removed. Kernel
+ modules only call ``@kernel_def(kernel_id)``. Default bindings are established
+ centrally (lazy) the first time a runtime executes a kernel.
  """

  from __future__ import annotations
@@ -27,22 +35,17 @@ from collections.abc import Callable
  from dataclasses import dataclass
  from typing import Any

- from mplang.core.dtype import UINT8, DType
- from mplang.core.pfunc import PFunction
- from mplang.core.table import TableLike, TableType
- from mplang.core.tensor import TensorLike, TensorType
-
  __all__ = [
-     "BackendRuntime",
      "KernelContext",
-     "create_runtime",
+     "KernelSpec",
+     "bind_op",
      "cur_kctx",
-     "kernel_def",
-     "list_registered_kernels",
+     "get_kernel_for_op",
+     "list_kernels",
+     "list_ops",
+     "unbind_op",
  ]

- # ---------------- Context ----------------
-

  @dataclass
  class KernelContext:
@@ -99,189 +102,74 @@ def cur_kctx() -> KernelContext:

  # ---------------- Registry ----------------

- # Canonical kernel callable signature (new style): (pfunc, *args) -> Any | sequence
- # - No **kwargs (explicitly disallowed)
- # - Return normalization handled by BackendRuntime.run_kernel
+ # Kernel callable signature: (pfunc, *args) -> Any | sequence (no **kwargs)
  KernelFn = Callable[..., Any]

- _KERNELS: dict[str, KernelFn] = {}

+ @dataclass
+ class KernelSpec:
+     kernel_id: str
+     fn: KernelFn
+     meta: dict[str, Any]

- def _validate_table_arg(
-     fn_type: str, arg_index: int, spec: TableType, value: Any
- ) -> None:
-     if not isinstance(value, TableLike):
-         raise TypeError(
-             f"kernel {fn_type} input[{arg_index}] expects TableLike, got {type(value).__name__}"
-         )
-     if len(value.columns) != len(spec.columns):
-         raise ValueError(
-             f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(value.columns)}, expected {len(spec.columns)}"
-         )

+ # All registered kernel implementations: kernel_id -> spec
+ _KERNELS: dict[str, KernelSpec] = {}
+
+ # Active op bindings: op_type -> kernel_id
+ _BINDINGS: dict[str, str] = {}

- def _validate_tensor_arg(
-     fn_type: str, arg_index: int, spec: TensorType, value: Any
- ) -> None:
-     # Backend-only handle sentinel (e.g., PHE keys) bypasses all structural checks
-     if tuple(spec.shape) == (-1, 0) and spec.dtype == UINT8:
-         return

-     if isinstance(value, (int, float, bool, complex)):
-         val_shape: tuple[Any, ...] = ()
-         duck_dtype: Any = type(value)
-     else:
-         if not isinstance(value, TensorLike):
-             raise TypeError(
-                 f"kernel {fn_type} input[{arg_index}] expects TensorLike, got {type(value).__name__}"
-             )
-         val_shape = getattr(value, "shape", ())
-         duck_dtype = getattr(value, "dtype", None)
-
-     if len(spec.shape) != len(val_shape):
-         raise ValueError(
-             f"kernel {fn_type} input[{arg_index}] rank mismatch: got {val_shape}, expected {spec.shape}"
-         )
-
-     for dim_idx, (spec_dim, val_dim) in enumerate(
-         zip(spec.shape, val_shape, strict=True)
-     ):
-         if spec_dim >= 0 and spec_dim != val_dim:
-             raise ValueError(
-                 f"kernel {fn_type} input[{arg_index}] shape mismatch at dim {dim_idx}: got {val_dim}, expected {spec_dim}"
-             )
-
-     try:
-         val_dtype = DType.from_any(duck_dtype)
-     except (ValueError, TypeError):  # pragma: no cover
-         raise TypeError(
-             f"kernel {fn_type} input[{arg_index}] has unsupported dtype object {duck_dtype!r}"
-         ) from None
-     if val_dtype != spec.dtype:
-         raise ValueError(
-             f"kernel {fn_type} input[{arg_index}] dtype mismatch: got {val_dtype}, expected {spec.dtype}"
-         )
-
-
- def kernel_def(fn_type: str) -> Callable[[KernelFn], KernelFn]:
-     """Decorator to register a backend kernel (new signature).
-
-     Expected Python signature form:
-
-         @kernel_def("namespace.op")
-         def _op(pfunc: PFunction, *args): ...
-
-     Rules:
-     * First parameter MUST be the PFunction object.
-     * Positional arguments correspond 1:1 to pfunc.ins_info order.
-     * **kwargs are NOT supported (will raise at call site if used).
-     * Return value forms accepted (n = len(pfunc.outs_info)):
-         - n == 0: return None / () / []
-         - n == 1: return scalar/object OR (value,) / [value]
-         - n > 1 : return tuple/list of length n
-     Anything else raises a ValueError.
+ def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]:
+     """Decorator to register a concrete kernel implementation.
+
+     This ONLY registers the implementation (kernel_id -> fn). It does NOT bind
+     any op. Higher layer must call ``bind_op(op_type, kernel_id)`` explicitly.
      """

      def _decorator(fn: KernelFn) -> KernelFn:
-         if fn_type in _KERNELS:
-             raise ValueError(f"duplicate backend kernel fn_type={fn_type}")
-         _KERNELS[fn_type] = fn
+         if kernel_id in _KERNELS:
+             raise ValueError(f"duplicate kernel_id={kernel_id}")
+         _KERNELS[kernel_id] = KernelSpec(kernel_id=kernel_id, fn=fn, meta=dict(meta))
          return fn

      return _decorator


- def list_registered_kernels() -> list[str]:  # public API unchanged
-     return sorted(_KERNELS.keys())
+ def bind_op(op_type: str, kernel_id: str, *, force: bool = True) -> None:
+     """Bind an op_type to a registered kernel implementation.

+     Args:
+         op_type: Semantic operation name.
+         kernel_id: Previously registered kernel identifier.
+         force: If False and op_type already bound, keep existing binding.
+             If True (default), overwrite.
+     """
+     if kernel_id not in _KERNELS:
+         raise KeyError(f"kernel_id {kernel_id} not registered")
+     if not force and op_type in _BINDINGS:
+         return
+     _BINDINGS[op_type] = kernel_id

- class BackendRuntime:
-     """Per-rank backend execution environment.

-     Holds mutable backend state (namespaced pockets) and a cache. Stateless
-     kernel implementations look up their state through cur_kctx().
-     """
+ def unbind_op(op_type: str) -> None:
+     _BINDINGS.pop(op_type, None)
+
+
+ def get_kernel_for_op(op_type: str) -> KernelSpec:
+     kid = _BINDINGS.get(op_type)
+     if kid is None:
+         # Tests expect NotImplementedError for unsupported operations
+         raise NotImplementedError(f"no backend kernel registered for op {op_type}")
+     spec = _KERNELS.get(kid)
+     if spec is None:  # inconsistent state
+         raise RuntimeError(f"active kernel_id {kid} missing spec")
+     return spec
+
+
+ def list_kernels() -> list[str]:
+     return sorted(_KERNELS.keys())
+

-     def __init__(self, rank: int, world_size: int):
-         self.rank = rank
-         self.world_size = world_size
-         self.state: dict[str, dict[str, Any]] = {}
-         self.cache: dict[str, Any] = {}
-
-     # Main entry
-     def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
-         fn_type = pfunc.fn_type
-         fn = _KERNELS.get(fn_type)
-         if fn is None:
-             raise NotImplementedError(f"no backend kernel registered for {fn_type}")
-
-         # Strict positional arg count validation (no kernel-managed arity bypass)
-         if len(arg_list) != len(pfunc.ins_info):
-             raise ValueError(
-                 f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
-             )
-
-         for idx, (spec, val) in enumerate(zip(pfunc.ins_info, arg_list, strict=True)):
-             if isinstance(spec, TableType):
-                 _validate_table_arg(fn_type, idx, spec, val)
-                 continue
-
-             if isinstance(spec, TensorType):
-                 _validate_tensor_arg(fn_type, idx, spec, val)
-                 continue
-
-             # Unknown spec type: silently skip validation (legacy behavior)
-             continue
-
-         kctx = KernelContext(
-             rank=self.rank,
-             world_size=self.world_size,
-             state=self.state,
-             cache=self.cache,
-         )
-         token = _CTX_VAR.set(kctx)
-         try:
-             raw = fn(pfunc, *arg_list)
-         finally:
-             _CTX_VAR.reset(token)
-
-         # Normalize return values
-         expected = len(pfunc.outs_info)
-         if expected == 0:
-             if raw in (None, (), []):
-                 return []
-             raise ValueError(
-                 f"kernel {fn_type} should return no values; got {type(raw).__name__}"
-             )
-
-         # If multi-output expected, raw must be sequence of right length
-         if expected == 1:
-             if isinstance(raw, (tuple, list)):
-                 if len(raw) != 1:
-                     raise ValueError(
-                         f"kernel {fn_type} produced {len(raw)} outputs, expected 1"
-                     )
-                 return [raw[0]]
-             # Single object
-             return [raw]
-
-         # expected > 1
-         if not isinstance(raw, (tuple, list)):
-             raise TypeError(
-                 f"kernel {fn_type} must return sequence (len={expected}), got {type(raw).__name__}"
-             )
-         if len(raw) != expected:
-             raise ValueError(
-                 f"kernel {fn_type} produced {len(raw)} outputs, expected {expected}"
-             )
-         return list(raw)
-
-     # Optional helper
-     def reset(self) -> None:  # pragma: no cover - simple
-         self.state.clear()
-         self.cache.clear()
-
-
- def create_runtime(rank: int, world_size: int) -> BackendRuntime:
-     """Factory for BackendRuntime (allows future policy injection)."""
-     return BackendRuntime(rank, world_size)
+ def list_ops() -> list[str]:
+     return sorted(_BINDINGS.keys())
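
Taken together, these base.py changes make registration and dispatch two separate steps. Below is a minimal sketch, not part of the package, of how the new API composes; the "demo.*" identifiers are hypothetical:

    from mplang.backend.base import bind_op, get_kernel_for_op, kernel_def

    @kernel_def("demo.echo")                 # step 1: register the implementation only
    def _echo(pfunc, *args):
        return args[0]                       # a single-output kernel may return a bare value

    bind_op("demo.op", "demo.echo")          # step 2: explicitly bind op_type -> kernel_id

    spec = get_kernel_for_op("demo.op")      # dispatch: op_type -> kernel_id -> KernelSpec
    assert spec.kernel_id == "demo.echo"
    # Looking up an op_type that was never bound raises NotImplementedError.
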
@@ -0,0 +1,255 @@
+ # Copyright 2025 Ant Group Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from collections.abc import Mapping
+ from dataclasses import dataclass, field
+ from typing import Any
+
+ from mplang.backend import base
+ from mplang.backend.base import KernelContext, bind_op, get_kernel_for_op
+ from mplang.core.dtype import UINT8, DType
+ from mplang.core.pfunc import PFunction
+ from mplang.core.table import TableLike, TableType
+ from mplang.core.tensor import TensorLike, TensorType
+
+ # Default bindings
+ # Import kernel implementation modules explicitly so their @kernel_def entries
+ # register at import time. Keep imports grouped; alias with leading underscore
+ # to silence unused variable warnings without F401 pragmas.
+ _IMPL_IMPORTED = False
+
+
+ def _ensure_impl_imported() -> None:
+     global _IMPL_IMPORTED
+     if _IMPL_IMPORTED:
+         return
+     from mplang.backend import builtin as _impl_builtin  # noqa: F401
+     from mplang.backend import crypto as _impl_crypto  # noqa: F401
+     from mplang.backend import phe as _impl_phe  # noqa: F401
+     from mplang.backend import spu as _impl_spu  # noqa: F401
+     from mplang.backend import sql_duckdb as _impl_sql_duckdb  # noqa: F401
+     from mplang.backend import stablehlo as _impl_stablehlo  # noqa: F401
+     from mplang.backend import tee as _impl_tee  # noqa: F401
+
+     _IMPL_IMPORTED = True
+
+
+ # imports consolidated above
+
+ _DEFAULT_BINDINGS: dict[str, str] = {
+     # builtin
+     "builtin.identity": "builtin.identity",
+     "builtin.read": "builtin.read",
+     "builtin.write": "builtin.write",
+     "builtin.constant": "builtin.constant",
+     "builtin.rank": "builtin.rank",
+     "builtin.prand": "builtin.prand",
+     "builtin.table_to_tensor": "builtin.table_to_tensor",
+     "builtin.tensor_to_table": "builtin.tensor_to_table",
+     "builtin.debug_print": "builtin.debug_print",
+     "builtin.pack": "builtin.pack",
+     "builtin.unpack": "builtin.unpack",
+     # crypto
+     "crypto.keygen": "crypto.keygen",
+     "crypto.enc": "crypto.enc",
+     "crypto.dec": "crypto.dec",
+     "crypto.kem_keygen": "crypto.kem_keygen",
+     "crypto.kem_derive": "crypto.kem_derive",
+     "crypto.hkdf": "crypto.hkdf",
+     # phe
+     "phe.keygen": "phe.keygen",
+     "phe.encrypt": "phe.encrypt",
+     "phe.mul": "phe.mul",
+     "phe.add": "phe.add",
+     "phe.decrypt": "phe.decrypt",
+     "phe.dot": "phe.dot",
+     "phe.gather": "phe.gather",
+     "phe.scatter": "phe.scatter",
+     "phe.concat": "phe.concat",
+     "phe.reshape": "phe.reshape",
+     "phe.transpose": "phe.transpose",
+     # spu
+     "spu.seed_env": "spu.seed_env",
+     "spu.makeshares": "spu.makeshares",
+     "spu.reconstruct": "spu.reconstruct",
+     "spu.run_pphlo": "spu.run_pphlo",
+     # stablehlo
+     "mlir.stablehlo": "mlir.stablehlo",
+     # sql
+     # generic SQL op; backend-specific kernel id for duckdb
+     "sql.run": "duckdb.run_sql",
+     # tee
+     "tee.quote": "tee.quote",
+     "tee.attest": "tee.attest",
+ }
+
+
+ # --- RuntimeContext ---
+
+
+ @dataclass
+ class RuntimeContext:
+     rank: int
+     world_size: int
+     bindings: Mapping[str, str] | None = None  # optional overrides
+     state: dict[str, dict[str, Any]] = field(default_factory=dict)
+     cache: dict[str, Any] = field(default_factory=dict)
+     stats: dict[str, Any] = field(default_factory=dict)
+
+     def __post_init__(self) -> None:
+         _ensure_impl_imported()
+         if self.bindings is not None:
+             for op, kid in self.bindings.items():
+                 bind_op(op, kid)
+         else:
+             for op, kid in _DEFAULT_BINDINGS.items():
+                 bind_op(op, kid)
+         # Initialize stats pocket
+         self.stats.setdefault("op_calls", {})
+
+     def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
+         fn_type = pfunc.fn_type
+         spec = get_kernel_for_op(fn_type)
+         fn = spec.fn
+         if len(arg_list) != len(pfunc.ins_info):
+             raise ValueError(
+                 f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
+             )
+         for idx, (ins_spec, val) in enumerate(
+             zip(pfunc.ins_info, arg_list, strict=True)
+         ):
+             if isinstance(ins_spec, TableType):
+                 _validate_table_arg(fn_type, idx, ins_spec, val)
+                 continue
+             if isinstance(ins_spec, TensorType):
+                 _validate_tensor_arg(fn_type, idx, ins_spec, val)
+                 continue
+         # install kernel context
+         kctx = KernelContext(
+             rank=self.rank,
+             world_size=self.world_size,
+             state=self.state,
+             cache=self.cache,
+         )
+         token = base._CTX_VAR.set(kctx)  # type: ignore[attr-defined]
+         try:
+             raw = fn(pfunc, *arg_list)
+         finally:
+             base._CTX_VAR.reset(token)  # type: ignore[attr-defined]
+         # Stats (best effort)
+         try:
+             op_calls = self.stats.setdefault("op_calls", {})
+             op_calls[fn_type] = op_calls.get(fn_type, 0) + 1
+         except Exception:  # pragma: no cover - never raise due to stats
+             pass
+         expected = len(pfunc.outs_info)
+         if expected == 0:
+             if raw in (None, (), []):
+                 return []
+             raise ValueError(
+                 f"kernel {fn_type} should return no values; got {type(raw).__name__}"
+             )
+         if expected == 1:
+             if isinstance(raw, (tuple, list)):
+                 if len(raw) != 1:
+                     raise ValueError(
+                         f"kernel {fn_type} produced {len(raw)} outputs, expected 1"
+                     )
+                 return [raw[0]]
+             return [raw]
+         if not isinstance(raw, (tuple, list)):
+             raise TypeError(
+                 f"kernel {fn_type} must return sequence (len={expected}), got {type(raw).__name__}"
+             )
+         if len(raw) != expected:
+             raise ValueError(
+                 f"kernel {fn_type} produced {len(raw)} outputs, expected {expected}"
+             )
+         return list(raw)
+
+     def reset(self) -> None:
+         self.state.clear()
+         self.cache.clear()
+
+     # ---- explicit (re)binding API ----
+     def bind_op(self, op_type: str, kernel_id: str, *, force: bool = False) -> None:
+         """Bind an operation to a kernel at runtime.
+
+         force=False (default) preserves any existing binding to avoid accidental
+         silent overrides. Use ``rebind_op`` or ``force=True`` to intentionally
+         change a binding.
+         """
+         base.bind_op(op_type, kernel_id, force=force)
+
+     def rebind_op(self, op_type: str, kernel_id: str) -> None:
+         """Force rebind an operation to a different kernel (shorthand)."""
+         base.bind_op(op_type, kernel_id, force=True)
+
+
+ def _validate_table_arg(
+     fn_type: str, arg_index: int, spec: TableType, value: Any
+ ) -> None:
+     if not isinstance(value, TableLike):
+         raise TypeError(
+             f"kernel {fn_type} input[{arg_index}] expects TableLike, got {type(value).__name__}"
+         )
+     if len(value.columns) != len(spec.columns):
+         raise ValueError(
+             f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(value.columns)}, expected {len(spec.columns)}"
+         )
+
+
+ def _validate_tensor_arg(
+     fn_type: str, arg_index: int, spec: TensorType, value: Any
+ ) -> None:
+     # Backend-only handle sentinel (e.g., PHE keys) bypasses all structural checks
+     if tuple(spec.shape) == (-1, 0) and spec.dtype == UINT8:
+         return
+
+     if isinstance(value, (int, float, bool, complex)):
+         val_shape: tuple[Any, ...] = ()
+         duck_dtype: Any = type(value)
+     else:
+         if not isinstance(value, TensorLike):
+             raise TypeError(
+                 f"kernel {fn_type} input[{arg_index}] expects TensorLike, got {type(value).__name__}"
+             )
+         val_shape = getattr(value, "shape", ())
+         duck_dtype = getattr(value, "dtype", None)
+
+     if len(spec.shape) != len(val_shape):
+         raise ValueError(
+             f"kernel {fn_type} input[{arg_index}] rank mismatch: got {val_shape}, expected {spec.shape}"
+         )
+
+     for dim_idx, (spec_dim, val_dim) in enumerate(
+         zip(spec.shape, val_shape, strict=True)
+     ):
+         if spec_dim >= 0 and spec_dim != val_dim:
+             raise ValueError(
+                 f"kernel {fn_type} input[{arg_index}] shape mismatch at dim {dim_idx}: got {val_dim}, expected {spec_dim}"
+             )
+
+     try:
+         val_dtype = DType.from_any(duck_dtype)
+     except (ValueError, TypeError):  # pragma: no cover
+         raise TypeError(
+             f"kernel {fn_type} input[{arg_index}] has unsupported dtype object {duck_dtype!r}"
+         ) from None
+     if val_dtype != spec.dtype:
+         raise ValueError(
+             f"kernel {fn_type} input[{arg_index}] dtype mismatch: got {val_dtype}, expected {spec.dtype}"
+         )
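
The new module above (base.py's docstring refers to it as ``backend.context``; its file path is not shown in this diff) owns the default op->kernel bindings and the per-rank mutable state. A hedged usage sketch, assuming the module is importable as mplang.backend.context:

    from mplang.backend.context import RuntimeContext  # assumed import path

    # Construction imports the kernel modules and applies _DEFAULT_BINDINGS
    # (or an explicit `bindings` mapping when one is supplied).
    rt = RuntimeContext(rank=0, world_size=2)

    # Re-point an op at another registered kernel; rebind_op always forces the change.
    rt.rebind_op("sql.run", "duckdb.run_sql")

    # Executing a PFunction dispatches through base.get_kernel_for_op(pfunc.fn_type):
    # outs = rt.run_kernel(pfunc, [arg0, arg1])

Note that _KERNELS and _BINDINGS are module-level dictionaries in mplang.backend.base, so multiple RuntimeContext instances in one process share the same op-to-kernel table even though each keeps its own state and cache.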