mplang-nightly 0.1.dev147__py3-none-any.whl → 0.1.dev149__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/backend/base.py +21 -47
- mplang/backend/context.py +67 -26
- mplang/backend/stablehlo.py +8 -1
- mplang/frontend/jax_cc.py +39 -7
- {mplang_nightly-0.1.dev147.dist-info → mplang_nightly-0.1.dev149.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev147.dist-info → mplang_nightly-0.1.dev149.dist-info}/RECORD +9 -9
- {mplang_nightly-0.1.dev147.dist-info → mplang_nightly-0.1.dev149.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev147.dist-info → mplang_nightly-0.1.dev149.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev147.dist-info → mplang_nightly-0.1.dev149.dist-info}/licenses/LICENSE +0 -0
mplang/backend/base.py
CHANGED
@@ -12,20 +12,21 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
"""Backend kernel registry
|
15
|
+
"""Backend kernel registry: mapping kernel_id -> implementation.
|
16
16
|
|
17
|
-
This
|
17
|
+
This module provides a lightweight registry for backend kernel implementations.
|
18
|
+
It does not track or decide which kernel handles a given semantic operation;
|
19
|
+
that policy (op -> kernel_id) is managed externally by each ``RuntimeContext``.
|
18
20
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
21
|
+
Exposed primitives:
|
22
|
+
* ``@kernel_def(kernel_id)``: decorator to register a kernel implementation.
|
23
|
+
* ``get_kernel_spec(kernel_id)``: look up a previously registered kernel.
|
24
|
+
* ``cur_kctx()`` / ``KernelContext``: execution context available only
|
25
|
+
inside a kernel body (rank, world_size, per-backend state pockets, and a
|
26
|
+
runtime-wide cache shared by kernels of the same runtime instance).
|
25
27
|
|
26
|
-
|
27
|
-
|
28
|
-
centrally (lazy) the first time a runtime executes a kernel.
|
28
|
+
No global op binding table exists here; callers resolve an op to a kernel_id
|
29
|
+
before invoking the kernel function.
|
29
30
|
"""
|
30
31
|
|
31
32
|
from __future__ import annotations
|
@@ -38,12 +39,10 @@ from typing import Any
|
|
38
39
|
__all__ = [
|
39
40
|
"KernelContext",
|
40
41
|
"KernelSpec",
|
41
|
-
"bind_op",
|
42
42
|
"cur_kctx",
|
43
|
-
"
|
43
|
+
"get_kernel_spec",
|
44
|
+
"kernel_exists",
|
44
45
|
"list_kernels",
|
45
|
-
"list_ops",
|
46
|
-
"unbind_op",
|
47
46
|
]
|
48
47
|
|
49
48
|
|
@@ -116,9 +115,6 @@ class KernelSpec:
|
|
116
115
|
# All registered kernel implementations: kernel_id -> spec
|
117
116
|
_KERNELS: dict[str, KernelSpec] = {}
|
118
117
|
|
119
|
-
# Active op bindings: op_type -> kernel_id
|
120
|
-
_BINDINGS: dict[str, str] = {}
|
121
|
-
|
122
118
|
|
123
119
|
def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]:
|
124
120
|
"""Decorator to register a concrete kernel implementation.
|
@@ -136,34 +132,11 @@ def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]
|
|
136
132
|
return _decorator
|
137
133
|
|
138
134
|
|
139
|
-
def
|
140
|
-
"""
|
141
|
-
|
142
|
-
|
143
|
-
op_type: Semantic operation name.
|
144
|
-
kernel_id: Previously registered kernel identifier.
|
145
|
-
force: If False and op_type already bound, keep existing binding.
|
146
|
-
If True (default), overwrite.
|
147
|
-
"""
|
148
|
-
if kernel_id not in _KERNELS:
|
135
|
+
def get_kernel_spec(kernel_id: str) -> KernelSpec:
|
136
|
+
"""Return KernelSpec for a registered kernel_id (no op binding lookup)."""
|
137
|
+
spec = _KERNELS.get(kernel_id)
|
138
|
+
if spec is None:
|
149
139
|
raise KeyError(f"kernel_id {kernel_id} not registered")
|
150
|
-
if not force and op_type in _BINDINGS:
|
151
|
-
return
|
152
|
-
_BINDINGS[op_type] = kernel_id
|
153
|
-
|
154
|
-
|
155
|
-
def unbind_op(op_type: str) -> None:
|
156
|
-
_BINDINGS.pop(op_type, None)
|
157
|
-
|
158
|
-
|
159
|
-
def get_kernel_for_op(op_type: str) -> KernelSpec:
|
160
|
-
kid = _BINDINGS.get(op_type)
|
161
|
-
if kid is None:
|
162
|
-
# Tests expect NotImplementedError for unsupported operations
|
163
|
-
raise NotImplementedError(f"no backend kernel registered for op {op_type}")
|
164
|
-
spec = _KERNELS.get(kid)
|
165
|
-
if spec is None: # inconsistent state
|
166
|
-
raise RuntimeError(f"active kernel_id {kid} missing spec")
|
167
140
|
return spec
|
168
141
|
|
169
142
|
|
@@ -171,5 +144,6 @@ def list_kernels() -> list[str]:
|
|
171
144
|
return sorted(_KERNELS.keys())
|
172
145
|
|
173
146
|
|
174
|
-
def
|
175
|
-
|
147
|
+
def kernel_exists(kernel_id: str) -> bool:
|
148
|
+
"""Return True if a kernel_id has been registered."""
|
149
|
+
return kernel_id in _KERNELS
|
mplang/backend/context.py
CHANGED
@@ -15,11 +15,10 @@
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
17
|
from collections.abc import Mapping
|
18
|
-
from dataclasses import dataclass, field
|
19
18
|
from typing import Any
|
20
19
|
|
21
20
|
from mplang.backend import base
|
22
|
-
from mplang.backend.base import KernelContext,
|
21
|
+
from mplang.backend.base import KernelContext, get_kernel_spec, kernel_exists
|
23
22
|
from mplang.core.dtype import UINT8, DType
|
24
23
|
from mplang.core.pfunc import PFunction
|
25
24
|
from mplang.core.table import TableLike, TableType
|
@@ -100,30 +99,57 @@ _DEFAULT_BINDINGS: dict[str, str] = {
|
|
100
99
|
# --- RuntimeContext ---
|
101
100
|
|
102
101
|
|
103
|
-
@dataclass
|
104
102
|
class RuntimeContext:
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
103
|
+
"""Per-runtime execution context with isolated op->kernel bindings.
|
104
|
+
|
105
|
+
Parameters
|
106
|
+
----------
|
107
|
+
rank : int
|
108
|
+
Local rank of this participant.
|
109
|
+
world_size : int
|
110
|
+
Total number of participants.
|
111
|
+
initial_bindings : Mapping[str, str] | None, optional
|
112
|
+
Optional partial overrides applied on top of the default binding table
|
113
|
+
during construction (override semantics, not replace). After
|
114
|
+
initialization, all (re)binding must go through ``bind_op`` /
|
115
|
+
``rebind_op``.
|
116
|
+
state / cache / stats : dict, optional
|
117
|
+
Mutable pockets reused across kernel invocations. If omitted, new
|
118
|
+
dictionaries are created.
|
119
|
+
"""
|
120
|
+
|
121
|
+
__slots__ = ("_ibindings", "cache", "rank", "state", "stats", "world_size")
|
122
|
+
|
123
|
+
def __init__(
|
124
|
+
self,
|
125
|
+
rank: int,
|
126
|
+
world_size: int,
|
127
|
+
initial_bindings: Mapping[str, str] | None = None,
|
128
|
+
*,
|
129
|
+
state: dict[str, dict[str, Any]] | None = None,
|
130
|
+
cache: dict[str, Any] | None = None,
|
131
|
+
stats: dict[str, Any] | None = None,
|
132
|
+
) -> None:
|
113
133
|
_ensure_impl_imported()
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
134
|
+
self.rank = rank
|
135
|
+
self.world_size = world_size
|
136
|
+
# Merge defaults with user overrides (override semantics)
|
137
|
+
self._ibindings: dict[str, str] = {
|
138
|
+
**_DEFAULT_BINDINGS,
|
139
|
+
**(initial_bindings or {}),
|
140
|
+
}
|
141
|
+
self.state = state if state is not None else {}
|
142
|
+
self.cache = cache if cache is not None else {}
|
143
|
+
self.stats = stats if stats is not None else {}
|
121
144
|
self.stats.setdefault("op_calls", {})
|
122
145
|
|
123
146
|
def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
|
124
147
|
fn_type = pfunc.fn_type
|
125
|
-
|
126
|
-
|
148
|
+
kid = self._ibindings.get(fn_type)
|
149
|
+
if kid is None:
|
150
|
+
raise NotImplementedError(f"no backend kernel registered for op {fn_type}")
|
151
|
+
spec = get_kernel_spec(kid)
|
152
|
+
fn = spec.fn # kernel implementation
|
127
153
|
if len(arg_list) != len(pfunc.ins_info):
|
128
154
|
raise ValueError(
|
129
155
|
f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
|
@@ -186,17 +212,32 @@ class RuntimeContext:
|
|
186
212
|
|
187
213
|
# ---- explicit (re)binding API ----
|
188
214
|
def bind_op(self, op_type: str, kernel_id: str, *, force: bool = False) -> None:
|
189
|
-
"""Bind an operation to a kernel
|
215
|
+
"""Bind an operation to a kernel for THIS context only.
|
190
216
|
|
191
|
-
force=False (default)
|
192
|
-
silent overrides. Use ``rebind_op`` or ``force=True`` to intentionally
|
193
|
-
change a binding.
|
217
|
+
force=False (default) keeps existing binding (no silent override).
|
194
218
|
"""
|
195
|
-
|
219
|
+
if not kernel_exists(kernel_id):
|
220
|
+
raise KeyError(f"kernel_id {kernel_id} not registered")
|
221
|
+
if not force and op_type in self._ibindings:
|
222
|
+
return
|
223
|
+
self._ibindings[op_type] = kernel_id
|
196
224
|
|
197
225
|
def rebind_op(self, op_type: str, kernel_id: str) -> None:
|
198
226
|
"""Force rebind an operation to a different kernel (shorthand)."""
|
199
|
-
|
227
|
+
self.bind_op(op_type, kernel_id, force=True)
|
228
|
+
|
229
|
+
# Introspection helpers
|
230
|
+
def list_bound_ops(self) -> list[str]: # pragma: no cover - convenience
|
231
|
+
return sorted(self._ibindings.keys())
|
232
|
+
|
233
|
+
def get_binding(self, op_type: str) -> str | None: # pragma: no cover
|
234
|
+
return self._ibindings.get(op_type)
|
235
|
+
|
236
|
+
def __repr__(self) -> str: # pragma: no cover - debug aid
|
237
|
+
return (
|
238
|
+
f"RuntimeContext(rank={self.rank}, world_size={self.world_size}, "
|
239
|
+
f"bound_ops={len(self._ibindings)})"
|
240
|
+
)
|
200
241
|
|
201
242
|
|
202
243
|
def _validate_table_arg(
|
mplang/backend/stablehlo.py
CHANGED
@@ -51,8 +51,15 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
51
51
|
raise RuntimeError(f"StableHLO compile failed: {e}") from e
|
52
52
|
cache[mlir_text] = compiled
|
53
53
|
|
54
|
+
# Handle JAX's unused parameter elimination via arg_keep_map
|
55
|
+
runtime_args = args
|
56
|
+
if "arg_keep_map" in pfunc.attrs:
|
57
|
+
keep_indices = pfunc.attrs["arg_keep_map"]
|
58
|
+
# Filter out arguments that were eliminated by JAX during compilation
|
59
|
+
runtime_args = tuple(args[i] for i in keep_indices)
|
60
|
+
|
54
61
|
jax_args = []
|
55
|
-
for arg in
|
62
|
+
for arg in runtime_args:
|
56
63
|
if hasattr(arg, "numpy"):
|
57
64
|
jax_arg = jnp.array(arg.numpy()) # type: ignore
|
58
65
|
else:
|
mplang/frontend/jax_cc.py
CHANGED
@@ -106,14 +106,46 @@ def jax2stablehlo(
|
|
106
106
|
out_info_flat, out_tree = tree_flatten(lowered.out_info)
|
107
107
|
out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]
|
108
108
|
|
109
|
+
# Extract argument keep mapping to handle JAX's unused parameter elimination
|
110
|
+
# JAX can eliminate unused parameters during compilation, but the runtime still
|
111
|
+
# receives all original arguments. We need the mapping to filter them correctly.
|
112
|
+
arg_keep_map = None
|
113
|
+
original_arg_count = len(in_vars)
|
114
|
+
|
115
|
+
try:
|
116
|
+
# Access JAX internal kept_var_idx - the authoritative source
|
117
|
+
# This tells us exactly which original parameters survived compilation
|
118
|
+
compile_args = lowered._lowering.compile_args
|
119
|
+
kept_var_idx = compile_args["kept_var_idx"]
|
120
|
+
|
121
|
+
kept_indices = sorted(kept_var_idx)
|
122
|
+
if len(kept_indices) < original_arg_count:
|
123
|
+
arg_keep_map = kept_indices
|
124
|
+
|
125
|
+
except (AttributeError, KeyError, TypeError) as e:
|
126
|
+
# JAX internal API is not available or changed
|
127
|
+
# This is a hard error - we cannot reliably handle unused parameters
|
128
|
+
# without knowing exactly which ones were kept
|
129
|
+
raise RuntimeError(
|
130
|
+
f"Cannot access JAX's kept_var_idx to handle unused parameter elimination. "
|
131
|
+
f"This function may have unused parameters that JAX optimized away, "
|
132
|
+
f"but we cannot determine which ones without the internal API. "
|
133
|
+
f"Original error: {e}"
|
134
|
+
) from e
|
135
|
+
|
109
136
|
# This format tells JaxRT how to handle the compiled result
|
110
|
-
|
111
|
-
fn_type
|
112
|
-
ins_info
|
113
|
-
outs_info
|
114
|
-
fn_name
|
115
|
-
fn_text
|
116
|
-
|
137
|
+
pfn_kwargs: dict[str, Any] = {
|
138
|
+
"fn_type": "mlir.stablehlo", # Key: specify StableHLO MLIR format
|
139
|
+
"ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
|
140
|
+
"outs_info": tuple(out_info_flat),
|
141
|
+
"fn_name": get_fn_name(flat_fn),
|
142
|
+
"fn_text": mlir_text, # MLIR text, serializable for transmission
|
143
|
+
}
|
144
|
+
|
145
|
+
if arg_keep_map is not None:
|
146
|
+
pfn_kwargs["arg_keep_map"] = arg_keep_map
|
147
|
+
|
148
|
+
pfn = PFunction(**pfn_kwargs)
|
117
149
|
return pfn, in_vars, out_tree
|
118
150
|
|
119
151
|
|
@@ -4,14 +4,14 @@ mplang/device.py,sha256=Iz_YFKkrbTFKtTQdGqkQZfc0NQH9dIxXP7-fUkIQOa4,12568
|
|
4
4
|
mplang/analysis/__init__.py,sha256=CTHFvRsi-nFngojqjn08UaR3RY9i7CJ7T2UdR95kCrk,1056
|
5
5
|
mplang/analysis/diagram.py,sha256=ffwgD12gL1_KH1uJ_EYkjmIlDrfxYJJkWj-wHl09_Xk,19520
|
6
6
|
mplang/backend/__init__.py,sha256=2WE4cmW96Xkzyq2yRRYNww4kZ5o6u6NbPV0BxqZG698,581
|
7
|
-
mplang/backend/base.py,sha256=
|
7
|
+
mplang/backend/base.py,sha256=eizxj16sWkUvBvXWS0Zl-S9uuqalJmNUzB1xLhBg8S8,4920
|
8
8
|
mplang/backend/builtin.py,sha256=Mk1uUO2Vpw3meqZ0B7B0hG-wndea6cmFv2Uk1vM_uTg,7052
|
9
|
-
mplang/backend/context.py,sha256=
|
9
|
+
mplang/backend/context.py,sha256=fVJ0w0cw15JEqJO048dncWg7DGNWqbHSUjq42Jsyvos,10952
|
10
10
|
mplang/backend/crypto.py,sha256=H_s5HI7lUP7g0xz-a9qMbSn6dhJStUilKbn3-7SIh0I,3812
|
11
11
|
mplang/backend/phe.py,sha256=uNqmrbDAbd97TWS_O6D5sopastHy6J20R7knFE4M4uc,65247
|
12
12
|
mplang/backend/spu.py,sha256=QT1q5uv-5P_nBGtTvtA_yI2h3h3zIqNSnvzGT7Shua4,9307
|
13
13
|
mplang/backend/sql_duckdb.py,sha256=U_KzEUinxrBRDoUz2Vh597-N4I3hPOBT0RT3tX-ZqKE,1502
|
14
|
-
mplang/backend/stablehlo.py,sha256=
|
14
|
+
mplang/backend/stablehlo.py,sha256=RhKf6TUvjLrRvgtdVY2HxcRDGtjpKBobuBFMfsvZQOI,2937
|
15
15
|
mplang/backend/tee.py,sha256=6kc7qTe8nWc3pr6iYtozEGLO8Umg-UBQLDiz6p3pdVg,1918
|
16
16
|
mplang/core/__init__.py,sha256=lWxlEKfRwX7FNDzgyKZ1fiDMaCiqkyg0j5mKlZD_v7g,2244
|
17
17
|
mplang/core/cluster.py,sha256=gqMJenvXUfHhE181Dd5JiUkD4nT07RLoicBnvsGmRkE,8598
|
@@ -41,7 +41,7 @@ mplang/frontend/base.py,sha256=rGtfBejcDh9mTRxOdJK5VUlG5vYiVJSir8X72X0Huvc,18264
|
|
41
41
|
mplang/frontend/builtin.py,sha256=8qrlbe_SSy6QTXTnMG6_ADB8jSklVZGFBrkoR-p02FE,9368
|
42
42
|
mplang/frontend/crypto.py,sha256=Nf8zT4Eko7MIs4R2tgZecKVd7d6Hvd_CGGmANhs3Ghs,3651
|
43
43
|
mplang/frontend/ibis_cc.py,sha256=CTTbPPZ9hFnHuFDDIfgJHie1EdNnHmi5Ha1KsX0iYh8,4235
|
44
|
-
mplang/frontend/jax_cc.py,sha256=
|
44
|
+
mplang/frontend/jax_cc.py,sha256=lMqaYD1tyM5DsStTNYifAXzhzsNM5nDiG3a61ygbWyc,7807
|
45
45
|
mplang/frontend/phe.py,sha256=tDsCvStjVJ1Fs07yF3idkFnugUCA1zdFApPx7Uuulik,6795
|
46
46
|
mplang/frontend/spu.py,sha256=7G6DaEfC5APSDhfeWSISTG_8tEcVbWth3XmjL8QUrVA,4994
|
47
47
|
mplang/frontend/sql.py,sha256=DFdvjEPQX28VCRgUMeHYR0rwwOaoCH15bpvvlclLtHA,1999
|
@@ -70,8 +70,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
|
|
70
70
|
mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
|
71
71
|
mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
|
72
72
|
mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
|
73
|
-
mplang_nightly-0.1.
|
74
|
-
mplang_nightly-0.1.
|
75
|
-
mplang_nightly-0.1.
|
76
|
-
mplang_nightly-0.1.
|
77
|
-
mplang_nightly-0.1.
|
73
|
+
mplang_nightly-0.1.dev149.dist-info/METADATA,sha256=eZ_qGx1500gbGY5Ms8smHjQsyZ2Wxfa11z2mSeF0CSk,16547
|
74
|
+
mplang_nightly-0.1.dev149.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
75
|
+
mplang_nightly-0.1.dev149.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
76
|
+
mplang_nightly-0.1.dev149.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
77
|
+
mplang_nightly-0.1.dev149.dist-info/RECORD,,
|
File without changes
|
{mplang_nightly-0.1.dev147.dist-info → mplang_nightly-0.1.dev149.dist-info}/entry_points.txt
RENAMED
File without changes
|
{mplang_nightly-0.1.dev147.dist-info → mplang_nightly-0.1.dev149.dist-info}/licenses/LICENSE
RENAMED
File without changes
|