mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/ops/base.py
DELETED
|
@@ -1,424 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from abc import ABC, abstractmethod
|
|
18
|
-
from collections.abc import Callable
|
|
19
|
-
from typing import Any
|
|
20
|
-
|
|
21
|
-
from jax.tree_util import PyTreeDef, tree_flatten
|
|
22
|
-
|
|
23
|
-
from mplang.v1.core import MPContext, MPObject, PFunction, TableType, TensorType
|
|
24
|
-
|
|
25
|
-
# -----------------------------------------------------------------------------
|
|
26
|
-
# Triad ABI
|
|
27
|
-
# The standard return contract for frontend operations (FeOperation.trace).
|
|
28
|
-
#
|
|
29
|
-
# Triad := (PFunction, list[MPObject], PyTreeDef)
|
|
30
|
-
# - PFunction: Captures fn_type (routing key, e.g., "mlir.stablehlo", "sql.run"),
|
|
31
|
-
# input/output MPTypes and optional attributes.
|
|
32
|
-
# - list[MPObject]: The flat positional MPObjects captured under the current
|
|
33
|
-
# context (Trace/Interp). Order matches pfunc.ins_info.
|
|
34
|
-
# - PyTreeDef: The output pytree structure to unflatten results after execution.
|
|
35
|
-
#
|
|
36
|
-
# Error modes:
|
|
37
|
-
# - TypeError if a simple op receives MPObject kwargs or non-MPObject positional values it cannot bind.
|
|
38
|
-
# - Kernel/type builder must produce TensorType/TableType leaves for outs.
|
|
39
|
-
# - Context errors propagate from cur_ctx() usage if called outside capture.
|
|
40
|
-
# -----------------------------------------------------------------------------
|
|
41
|
-
# Return type of FeOperation.trace: (PFunction, captured positional MPObject
# inputs in pfunc.ins_info order, PyTreeDef used to unflatten the outputs).
Triad = tuple[PFunction, list[MPObject], PyTreeDef]
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
# -----------------------------------------------------------------------------
|
|
45
|
-
# Lightweight fe module/feop system (new FeOperation based)
|
|
46
|
-
# -----------------------------------------------------------------------------
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
# Global registry for frontend modules and operations
|
|
50
|
-
class FeRegistry:
    """Central lookup table for frontend modules and their operations.

    Two mappings are kept:
      * ``_modules``: module name -> FeModule
      * ``_ops``: (module name, op name) -> FeOperation (a callable returning a Triad)
    """

    __slots__ = ("_modules", "_ops")

    def __init__(self) -> None:
        # Both tables start empty; entries arrive via register_module/register_op.
        self._modules: dict[str, FeModule] = {}
        self._ops: dict[tuple[str, str], FeOperation] = {}

    # ----------------------------------------------------------------- modules
    def register_module(self, mod: FeModule, *, replace: bool = False) -> None:
        """Store *mod* under ``mod.name``.

        Raises:
            ValueError: if that name is already taken and ``replace`` is False.
        """
        clash = (not replace) and (mod.name in self._modules)
        if clash:
            raise ValueError(f"Module already registered: {mod.name}")
        self._modules[mod.name] = mod

    def get_module(self, name: str) -> FeModule:
        """Return the module registered as *name*.

        Raises:
            KeyError: if no module of that name exists.
        """
        if name in self._modules:
            return self._modules[name]
        raise KeyError(f"Unknown module: {name}")

    def has_module(self, name: str) -> bool:
        """Tell whether a module called *name* has been registered."""
        return name in self._modules

    def list_modules(self) -> dict[str, FeModule]:
        """Return a shallow copy of the name -> module mapping."""
        return self._modules.copy()

    # --------------------------------------------------------------------- ops
    def register_op(
        self, module: str, name: str, op: FeOperation, *, replace: bool = False
    ) -> None:
        """Store *op* under the key ``(module, name)``.

        Raises:
            ValueError: if that key is already taken and ``replace`` is False.
        """
        op_key = (module, name)
        clash = (not replace) and (op_key in self._ops)
        if clash:
            raise ValueError(f"Op already registered: {module}.{name}")
        self._ops[op_key] = op

    def get_op(self, module: str, name: str) -> FeOperation:
        """Return the op registered under ``(module, name)``.

        Raises:
            KeyError: if no such op exists.
        """
        op_key = (module, name)
        if op_key in self._ops:
            return self._ops[op_key]
        raise KeyError(f"Unknown op: {module}.{name}")

    def list_ops(self, module: str | None = None) -> dict[tuple[str, str], FeOperation]:
        """Return a copy of the op table, optionally restricted to one module."""
        if module is None:
            return self._ops.copy()
        return {
            op_key: op for op_key, op in self._ops.items() if op_key[0] == module
        }
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
# Module-level registry instance: get_registry() returns it, and FeModule
# construction registers new modules into it.
_REGISTRY = FeRegistry()
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
def get_registry() -> FeRegistry:
    """Return the shared module-level FeRegistry instance."""
    registry = _REGISTRY
    return registry
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
def is_feop(x: Any) -> bool:
    """Tell whether *x* is a frontend operation.

    Args:
        x: Any object.

    Returns:
        True for instances of FeOperation (or a subclass), otherwise False.
    """
    result = isinstance(x, FeOperation)
    return result
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
class FeModule(ABC):
    """Frontend module grouping related ops, with decorators to define them.

    Constructing a module registers it (by ``name``) in the global FeRegistry.

    Choosing how to define an op:
      - ``simple_op`` (creates a SimpleFeOperation) when:
          * the backend routing key is known up front (``pfunc_name``, which
            defaults to ``"<module name>.<kernel name>"``) and the kernel is
            pure type logic;
          * MPObject inputs are passed positionally (they are rejected as
            keyword arguments); the kernel sees their TensorType/TableType;
          * attributes are simple Python values (int/float/str/bytes/tuples/
            lists of primitives) passed as keywords;
          * the kernel returns TensorType/TableType (or a PyTree thereof) and
            builds no IR.
      - ``op_def`` (creates an InlineFeOperation) when the function already
        builds and returns the Triad itself, or needs custom packing/attrs/
        multi-output composition.
      - Subclass FeOperation when you need compilation, stateful behavior,
        dynamic routing, multiple PFunctions, or complex capture flows.

    Tips:
      - Keep routing information in PFunction.fn_type (e.g. "basic.read",
        "sql.run", "mlir.stablehlo").
      - Avoid backend-specific logic in kernels; only validate and shape types.
      - Prefer keyword-only attributes in simple_op kernels for clarity
        (``def op(x: TensorType, *, attr: int)``).
    """

    def __init__(self, name: str):
        self.name = name
        # Registration happens on construction; the registry raises ValueError
        # if a module with the same name is already registered.
        get_registry().register_module(self)

    @abstractmethod
    def initialize(self, ctx: MPContext) -> None:
        """Hook for setting up module state on *ctx* (a no-op for stateless modules)."""

    def op_def(self) -> Callable[[Callable[..., Triad]], FeOperation]:
        """Decorator for inline/complex ops whose function already returns a Triad.

        The decorated function is wrapped in an InlineFeOperation, registered
        under ``(self.name, fn.__name__)``, and the op replaces the function.

        Usage:
            @mymod.op_def()
            def scale(x: MPObject, factor: int) -> Triad:
                # build the PFunction and return the triad directly
                ...
                return pfunc, [x], out_tree
        """

        def _decorator(trace_fn: Callable[..., Triad]) -> FeOperation:
            name = trace_fn.__name__
            op = InlineFeOperation(self, name, trace_fn)
            get_registry().register_op(self.name, name, op)
            return op

        return _decorator

    def simple_op(
        self, pfunc_name: str | None = None
    ) -> Callable[[Callable[..., Any]], FeOperation]:
        """Decorator for type-driven ops whose kernel returns only types/schemas.

        The decorated kernel computes and returns a TensorType/TableType (or a
        PyTree thereof). Positional MPObjects are captured as PFunction inputs
        and handed to the kernel as their underlying TensorType/TableType;
        positional data-like values (TableLike/TensorLike) are meant for type
        inference/validation. Keyword arguments become PFunction attributes and
        must be plain Python values (int/float/str/bytes/tuples/lists of
        primitives); passing MPObjects via kwargs is not allowed.

        Naming: the op name is the kernel's ``__name__`` (single source of
        truth), so use clear, concise function names for public ops.
        ``pfunc_name`` is the PFunction routing key and defaults to
        ``"<module name>.<kernel name>"``.

        Example:
            @mymod.simple_op(pfunc_name="builtin.add")
            def add(x: TensorType, y: TensorType) -> TensorType:
                return x  # same shape/dtype as x

        Bad vs Good (signatures and calls):
          - Bad:  def op(x, **kwargs): ...        # disallowed: **kwargs
            Good: def op(x, *, attr: int): ...
          - Bad:  def op(*args, **kwargs): ...    # disallowed: *args/**kwargs
            Good: def op(x, y, *, k: str): ...
          - Bad:  enc(plaintext=pt, key=mp_key)   # MPObject via kwargs (disallowed)
            Good: enc(pt, mp_key)                 # pass MPObjects positionally
          - Good: hkdf(secret, "info")            # non-MPObject positional mapped to kw-only attr
            Also good: hkdf(secret, info="info")
          - Good: phe.mul(jnp.array(...), jnp.array(...))  # data-like positionals for type inference
        """

        def _decorator(kernel: Callable[..., Any]) -> FeOperation:
            # Default PFunction routing when not provided: "<module>.<kernel_name>"
            final_pfunc_name = pfunc_name or f"{self.name}.{kernel.__name__}"
            op = SimpleFeOperation(self, final_pfunc_name, kernel)
            # The op name is the kernel function name (see SimpleFeOperation).
            get_registry().register_op(self.name, op.name, op)
            return op

        return _decorator
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
class StatelessFeModule(FeModule):
    """FeModule with no per-context state; initialize() does nothing."""

    def initialize(self, ctx: MPContext) -> None:
        # No ctx-level state to create for a stateless module.
        return None
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
# -----------------------------------------------------------------------------
|
|
214
|
-
# Class-based contracts and adapters
|
|
215
|
-
# -----------------------------------------------------------------------------
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
class FeOperation(ABC):
    """Abstract base for class-based frontend operations.

    Concrete subclasses implement :meth:`trace`, which yields a Triad
    ``(PFunction, inputs, out_tree)``. Instances are callable, and calling one
    is exactly the same as calling :meth:`trace` with the same arguments.
    """

    # Owning module and operation name; both are set in __init__.
    module: FeModule
    name: str

    def __init__(self, module: FeModule, name: str):
        self.module = module
        self.name = name

    @abstractmethod
    def trace(self, *args: Any, **kwargs: Any) -> Triad:
        """Build and return the Triad for this operation."""

    def __call__(self, *args: Any, **kwargs: Any) -> Triad:
        """Invoke :meth:`trace` so an operation can be used like a function."""
        tracer = self.trace
        return tracer(*args, **kwargs)
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
class InlineFeOperation(FeOperation):
    """FeOperation whose trace() forwards to a user-supplied function.

    The wrapped function is responsible for building and returning the Triad.
    """

    def __init__(self, module: FeModule, name: str, trace_fn: Callable[..., Triad]):
        super().__init__(module, name)
        # Called verbatim on every trace() invocation.
        self._trace_fn = trace_fn

    def trace(self, *args: Any, **kwargs: Any) -> Triad:
        """Delegate to the wrapped function and return its result unchanged."""
        fn = self._trace_fn
        return fn(*args, **kwargs)
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
class SimpleFeOperation(FeOperation):
    """FeOperation that builds a Triad from a type-only kernel.

    Contract (keep it simple):
      - The kernel computes and returns TensorType/TableType or a PyTree thereof.
      - Positional MPObjects are captured as the PFunction inputs; the kernel
        receives their underlying TensorType/TableType instead.
      - Other positional values (e.g. data-like TableLike/TensorLike values) are
        forwarded to the kernel unchanged for type inference/validation and are
        not captured as inputs.
      - Keyword arguments become PFunction attributes and must be plain Python
        values (TensorType/TableType are excluded from attrs). MPObject kwargs
        are rejected.
      - Prefer keyword-only attributes in the kernel signature. For convenience,
        when the call does not bind to the kernel signature as given,
        non-MPObject positional values are mapped in order onto keyword-only
        parameters not passed explicitly (``keygen(32)`` for
        ``def keygen(*, length: int)``).
      - No IR building inside the kernel; the PFunction is assembled here with
        ``fn_type=pfunc_name``.
    """

    def __init__(
        self,
        module: FeModule,
        pfunc_name: str,
        kernel: Callable[..., Any],
    ):
        # Derive the operation name from the kernel function name (single source of truth).
        super().__init__(module, kernel.__name__)
        self.pfunc_name = pfunc_name
        self._kernel = kernel

        # Validate kernel signature: kernels must not use *args/**kwargs.
        import inspect

        sig = inspect.signature(kernel)
        for p in sig.parameters.values():
            if p.kind == inspect.Parameter.VAR_KEYWORD:
                raise TypeError(
                    f"typed_op kernel '{module.name}.{kernel.__name__}' must not use **kwargs; define explicit keywords instead"
                )
            if p.kind == inspect.Parameter.VAR_POSITIONAL:
                raise TypeError(
                    f"typed_op kernel '{module.name}.{kernel.__name__}' must not use *args; define explicit parameters instead"
                )

        # Cached for trace(): the signature, and keyword-only parameter names in
        # declaration order (targets of the positional -> keyword fallback).
        self._kernel_sig = sig
        self._kwonly_names = [
            p.name
            for p in sig.parameters.values()
            if p.kind == inspect.Parameter.KEYWORD_ONLY
        ]

    def _bind_call(
        self,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        mp_inputs: list[MPObject],
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """Resolve the (positional, keyword) arguments for the kernel call.

        If the call binds to the kernel signature as given, it is kept as-is
        (so data-like positionals reach the kernel). Otherwise non-MPObject
        positional values are mapped, in order, onto keyword-only parameters
        not already supplied, and only the MPObjects stay positional.

        Raises:
            TypeError: if some non-MPObject positional values cannot be mapped.
        """
        try:
            self._kernel_sig.bind_partial(*args, **kwargs)
        except TypeError:
            pass
        else:
            return args, kwargs

        # Fallback, e.g. `crypto.keygen(32)` for `def keygen(*, length: int)`:
        # direct binding fails, so map the positional 32 onto `length`.
        extras = [a for a in args if not isinstance(a, MPObject)]
        call_kwargs = dict(kwargs)
        used = 0
        for name in self._kwonly_names:
            if used < len(extras) and name not in call_kwargs:
                call_kwargs[name] = extras[used]
                used += 1
        if used < len(extras):
            leftover = extras[used:]
            raise TypeError(
                f"too many non-MPObject positional values for typed_op '{self.module.name}.{self.name}': {leftover}. "
                "Pass attributes explicitly by keyword (e.g., foo(x, *, attr=...))."
            ) from None
        return tuple(mp_inputs), call_kwargs

    # override
    def trace(self, *args: Any, **kwargs: Any) -> Triad:
        """Run the kernel on argument types and package the result as a Triad.

        Returns:
            ``(pfunc, mp_inputs, out_tree)``: ``mp_inputs`` are the positional
            MPObjects in call order, matching ``pfunc.ins_info``.

        Raises:
            TypeError: on MPObject keyword arguments, unmappable non-MPObject
                positional values, or kernel outputs that are not
                TensorType/TableType leaves.
        """
        # PFunction inputs are the MPObjects among the positional args, in order.
        pos_mp_inputs: list[MPObject] = [a for a in args if isinstance(a, MPObject)]

        # Enforce: no MPObject kwargs per simplified contract.
        for k, v in kwargs.items():
            if isinstance(v, MPObject):
                raise TypeError(
                    f"typed_op does not accept MPObject kwargs: {k}; pass MPObjects positionally"
                )

        call_pos, call_kwargs = self._bind_call(args, kwargs, pos_mp_inputs)

        # PFunction attrs: keyword values minus MPObjects and type objects.
        attr_kwargs: dict[str, Any] = {
            k: v
            for k, v in call_kwargs.items()
            if not isinstance(v, (MPObject, TensorType, TableType))
        }

        # MPObjects are replaced by their underlying type so the kernel never sees
        # TraceVar/InterpVar. Other positional values (data-like inputs) have no
        # `mptype` and are forwarded unchanged for type inference.
        kernel_pos = tuple(
            a.mptype._type if isinstance(a, MPObject) else a for a in call_pos
        )

        # Execute kernel to compute return types.
        result = self._kernel(*kernel_pos, **call_kwargs)
        outs_info, out_tree = tree_flatten(result)

        # Ensure all outputs are TensorType or TableType.
        # TODO(jint): theoretically python constants could also be allowed here.
        for o in outs_info:
            if not isinstance(o, (TensorType, TableType)):
                raise TypeError(
                    f"simple op kernel must return TensorType or TableType, got {type(o).__name__}"
                )

        # Input types come from the positional MPObjects only.
        pfunc = PFunction(
            fn_type=self.pfunc_name,
            ins_info=tuple(a.mptype._type for a in pos_mp_inputs),
            outs_info=tuple(outs_info),
            **attr_kwargs,
        )
        return pfunc, pos_mp_inputs, out_tree
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
def stateless_mod(mod_name: str) -> FeModule:
    """Create a StatelessFeModule named *mod_name*.

    The module registers itself in the global registry while being
    constructed, so reusing an existing name raises ValueError there.
    """
    module = StatelessFeModule(mod_name)
    return module
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
def list_ops(module: str | None = None) -> dict[tuple[str, str], FeOperation]:
    """Return a copy of the registered ops table, optionally filtered by module.

    Args:
        module: When given, keep only ops whose key's module part equals it.

    Returns:
        A new dict mapping ``(module, op_name)`` to the FeOperation.
    """
    registry = get_registry()
    return registry.list_ops(module)
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
# -----------------------------------------------------------------------------
|
|
391
|
-
# Guidance: complex ops via subclassing
|
|
392
|
-
# -----------------------------------------------------------------------------
|
|
393
|
-
|
|
394
|
-
# Example pattern (non-executable) showing how a complex op (e.g., jax_cc) could
|
|
395
|
-
# capture a Python callable and compile it to a Triad by subclassing FeOperation.
|
|
396
|
-
#
|
|
397
|
-
# class JaxCompileOp(FeOperation):
|
|
398
|
-
# def __init__(self, module: FeModule, name: str, func: Callable[..., Any], *,
|
|
399
|
-
# fn_type: str = "mlir.stablehlo", **options: Any) -> None:
|
|
400
|
-
# super().__init__(module, name)
|
|
401
|
-
# self.func = func
|
|
402
|
-
# self.fn_type = fn_type
|
|
403
|
-
# self.options = dict(options)
|
|
404
|
-
#
|
|
405
|
-
# def trace(self, *args: MPObject, **kwargs: Any) -> Triad:
|
|
406
|
-
# # 1) Infer output types from func and args, respecting current ctx/masks.
|
|
407
|
-
# # 2) Build PFunction with fn_type=self.fn_type and any attributes.
|
|
408
|
-
# # 3) Return (pfunc, list(args), out_tree)
|
|
409
|
-
# raise NotImplementedError
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
# -----------------------------------------------------------------------------
|
|
413
|
-
# Migration notes (checklist)
|
|
414
|
-
# -----------------------------------------------------------------------------
|
|
415
|
-
|
|
416
|
-
# - Replace any isinstance(FEOp)/metadata checks with isinstance(x, FeOperation).
|
|
417
|
-
# - Define a FeModule via stateless_mod("module_name") (or a FeModule subclass); construction registers it in FeRegistry automatically.
|
|
418
|
-
# - For inline ops that already produce a triad, decorate the trace function with @module.op_def(). The op name is derived from the function name.
|
|
419
|
-
# - For type-only kernels, decorate the kernel with @module.simple_op(pfunc_name). The op name is derived from the kernel function name.
|
|
420
|
-
# - For complex ops (with Python callables/closures), subclass FeOperation and register
|
|
421
|
-
# using get_registry().register_op(module, name, op_instance), or use @module.op_def(), which wraps the function in an InlineFeOperation.
|
|
422
|
-
# - Ensure PFunction.fn_type is set as the routing key (e.g., "mlir.stablehlo", "sql.run").
|
|
423
|
-
# - Keep device selection/routing out of frontend code; only set fn_type and attributes.
|
|
424
|
-
# - Avoid moving MPObjects across contexts directly; capture within current ctx in trace().
|