mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/ops/base.py DELETED
@@ -1,424 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from abc import ABC, abstractmethod
18
- from collections.abc import Callable
19
- from typing import Any
20
-
21
- from jax.tree_util import PyTreeDef, tree_flatten
22
-
23
- from mplang.v1.core import MPContext, MPObject, PFunction, TableType, TensorType
24
-
# -----------------------------------------------------------------------------
# Triad ABI
# The standard return contract for frontend operations (FeOperation.trace).
#
# Triad := (PFunction, list[MPObject], PyTreeDef)
#   - PFunction: Captures fn_type (routing key, e.g., "mlir.stablehlo", "sql.run"),
#     input/output MPTypes and optional attributes.
#   - list[MPObject]: The flat positional MPObjects captured under the current
#     context (Trace/Interp). Order matches pfunc.ins_info.
#   - PyTreeDef: The output pytree structure to unflatten results after execution.
#
# Error modes:
#   - Type errors if MPObjects are passed to simple ops by keyword, or if
#     non-MPObject positional values can be neither bound to the kernel
#     signature nor mapped onto its keyword-only parameters.
#   - Kernel/type builder must produce TensorType/TableType leaves for outs.
#   - Context errors propagate from cur_ctx() usage if called outside capture.
# -----------------------------------------------------------------------------
Triad = tuple[PFunction, list[MPObject], PyTreeDef]
42
-
43
-
44
- # -----------------------------------------------------------------------------
45
- # Lightweight fe module/feop system (new FeOperation based)
46
- # -----------------------------------------------------------------------------
47
-
48
-
# Global registry for frontend modules and operations
class FeRegistry:
    """Central lookup table for frontend modules and their operations.

    Two independent maps are kept:

    * ``_modules``: module name -> :class:`FeModule`.
    * ``_ops``: ``(module_name, op_name)`` -> :class:`FeOperation`, whose
      ``trace`` returns a Triad.

    Registering a key that is already present raises ``ValueError`` unless
    ``replace=True``; looking up an unknown key raises ``KeyError``. The
    ``list_*`` helpers return fresh dicts, so callers may mutate them freely.
    """

    __slots__ = ("_modules", "_ops")

    def __init__(self) -> None:
        self._modules: dict[str, FeModule] = {}
        self._ops: dict[tuple[str, str], FeOperation] = {}

    # ----------------------------- Modules -----------------------------
    def register_module(self, mod: FeModule, *, replace: bool = False) -> None:
        """Store ``mod`` under ``mod.name``; duplicates need ``replace=True``."""
        if mod.name in self._modules and not replace:
            raise ValueError(f"Module already registered: {mod.name}")
        self._modules[mod.name] = mod

    def get_module(self, name: str) -> FeModule:
        """Return the module registered as ``name``."""
        if name in self._modules:
            return self._modules[name]
        raise KeyError(f"Unknown module: {name}")

    def has_module(self, name: str) -> bool:
        """Report whether a module called ``name`` is registered."""
        return name in self._modules

    def list_modules(self) -> dict[str, FeModule]:
        """Return a shallow copy of the module table."""
        return self._modules.copy()

    # ------------------------------ Ops -------------------------------
    def register_op(
        self, module: str, name: str, op: FeOperation, *, replace: bool = False
    ) -> None:
        """Store ``op`` under ``(module, name)``; duplicates need ``replace=True``."""
        if (module, name) in self._ops and not replace:
            raise ValueError(f"Op already registered: {module}.{name}")
        self._ops[(module, name)] = op

    def get_op(self, module: str, name: str) -> FeOperation:
        """Return the operation registered as ``(module, name)``."""
        op_key = (module, name)
        if op_key in self._ops:
            return self._ops[op_key]
        raise KeyError(f"Unknown op: {module}.{name}")

    def list_ops(self, module: str | None = None) -> dict[tuple[str, str], FeOperation]:
        """Return a copy of the op table, optionally restricted to one module."""
        return {
            (mod_name, op_name): op
            for (mod_name, op_name), op in self._ops.items()
            if module is None or mod_name == module
        }
101
-
102
-
_REGISTRY = FeRegistry()


def get_registry() -> FeRegistry:
    """Return the module-level :class:`FeRegistry` instance.

    Every call returns the same object, so registrations made through one
    caller are visible to all others.
    """
    registry = _REGISTRY
    return registry
108
-
109
-
def is_feop(x: Any) -> bool:
    """Tell whether ``x`` is a frontend operation.

    Returns:
        ``True`` when ``x`` is an instance of :class:`FeOperation` (including
        any subclass), otherwise ``False``.
    """
    matches = isinstance(x, FeOperation)
    return matches
113
-
114
-
class FeModule(ABC):
    """Frontend module exposing the ``op_def`` / ``simple_op`` decorators.

    Constructing a FeModule registers it in the global FeRegistry under
    ``name``; registering a name that is already taken raises ``ValueError``.

    When to use which:
      - Use ``simple_op`` (SimpleFeOperation) when:
          - You know the backend routing key up front via ``pfunc_name`` (or
            accept the default ``"<module>.<kernel_name>"``), and the kernel is
            pure type logic.
          - Inputs are MPObjects passed positionally. Attributes are simple
            Python values (int/float/str/bytes/tuples/lists of primitives)
            passed as keywords.
          - The kernel returns TensorType/TableType (or a PyTree thereof); no
            IR construction inside.
      - Use ``op_def`` (InlineFeOperation) when:
          - You already build and return the Triad explicitly, or need custom
            packing/attrs/multi-output composition.
      - Subclass FeOperation when:
          - You need compilation/stateful behavior/dynamic routing, multiple
            PFunctions, or complex capture flows.

    Tips:
      - Keep routing information in PFunction.fn_type (e.g., "basic.read",
        "sql.run", "mlir.stablehlo").
      - Avoid backend-specific logic in kernels; only validate and shape types.
      - Prefer keyword-only attributes in ``simple_op`` kernels for clarity
        (``def op(x: TensorType, *, attr: int)``).
    """

    def __init__(self, name: str):
        self.name = name
        # Self-register so the module is discoverable through get_registry().
        get_registry().register_module(self)

    @abstractmethod
    def initialize(self, ctx: MPContext) -> None:
        """Set up whatever per-context state this module needs for ``ctx``."""

    def op_def(self) -> Callable[[Callable[..., Triad]], FeOperation]:
        """Decorator for inline/complex ops whose function already returns a Triad.

        The op name is the decorated function's ``__name__``. The function is
        wrapped in an InlineFeOperation, registered under ``(self.name, name)``
        and the operation object (not the function) is returned.

        Usage:
            @mymod.op_def()
            def scale(x: MPObject, factor: int) -> Triad:
                # build PFunction and return triad directly
                ...
                return pfunc, [x], out_tree

        Raises:
            ValueError: (when decorating) if an op with the same name is already
                registered for this module.
        """

        def _decorator(trace_fn: Callable[..., Triad]) -> FeOperation:
            name = trace_fn.__name__
            op = InlineFeOperation(self, name, trace_fn)
            get_registry().register_op(self.name, name, op)
            return op

        return _decorator

    def simple_op(
        self, pfunc_name: str | None = None
    ) -> Callable[[Callable[..., Any]], FeOperation]:
        """Decorator for type-driven ops whose kernel returns only types/schemas.

        The decorated kernel must compute and return a TensorType/TableType (or a
        PyTree thereof). Positional call arguments may be MPObjects (captured as
        PFunction inputs; the kernel receives their underlying types, not the
        MPObjects) or data-like values (TableLike/TensorLike) used for type
        inference/validation. Keyword arguments become PFunction attributes and
        must be plain Python values (int/float/str/bytes/tuples/lists of
        primitives). Passing MPObjects via kwargs is not allowed.

        SSOT naming: the operation name is the kernel function name
        (``kernel.__name__``), so there is a single source of truth. Use clear,
        concise function names to define the public op names.

        Args:
            pfunc_name: PFunction routing key (``fn_type``). Defaults to
                ``"<module>.<kernel_name>"`` when omitted.

        Example:
            @mymod.simple_op(pfunc_name="builtin.add")
            def add(x: TensorType, y: TensorType) -> TensorType:
                return x  # same shape/type as x

        Bad vs Good (signatures and calls):
          - Bad:  def op(x, **kwargs): ...        # disallowed: **kwargs
            Good: def op(x, *, attr: int): ...

          - Bad:  def op(*args, **kwargs): ...    # disallowed: *args/**kwargs
            Good: def op(x, y, *, k: str): ...

          - Bad:  enc(plaintext=pt, key=mp_key)   # MPObject via kwargs (disallowed)
            Good: enc(pt, mp_key)                 # pass MPObjects positionally

          - Good: hkdf(secret, "info")  # non-MPObject positional mapped to kw-only attr
            Also good: hkdf(secret, info="info")

          - Good: phe.mul(jnp.array(...), jnp.array(...))  # data-like positionals allowed for type inference

        Raises:
            TypeError: (when decorating) if the kernel declares ``*args`` or
                ``**kwargs``.
            ValueError: (when decorating) if an op with the same name is already
                registered for this module.
        """

        def _decorator(kernel: Callable[..., Any]) -> FeOperation:
            # Default PFunction routing when not provided: "<module>.<kernel_name>"
            final_pfunc_name = pfunc_name or f"{self.name}.{kernel.__name__}"
            op = SimpleFeOperation(self, final_pfunc_name, kernel)
            # Use kernel function name as SSOT for op name
            get_registry().register_op(self.name, op.name, op)
            return op

        return _decorator
204
-
205
-
class StatelessFeModule(FeModule):
    """Frontend module that keeps no per-context state.

    ``initialize`` is intentionally a no-op: there is nothing to prepare for
    any ``ctx``.
    """

    def initialize(self, ctx: MPContext) -> None:
        # No ctx-level state to set up.
        return None
211
-
212
-
213
- # -----------------------------------------------------------------------------
214
- # Class-based contracts and adapters
215
- # -----------------------------------------------------------------------------
216
-
217
-
class FeOperation(ABC):
    """Abstract base for class-based frontend operations.

    Concrete subclasses implement :meth:`trace`, which must produce a Triad
    ``(PFunction, list[MPObject], PyTreeDef)``. Instances are callable, and
    calling one is exactly the same as calling :meth:`trace`.

    Attributes:
        module: The :class:`FeModule` this operation belongs to.
        name: The operation's name.
    """

    module: FeModule
    name: str

    def __init__(self, module: FeModule, name: str):
        self.name = name
        self.module = module

    @abstractmethod
    def trace(self, *args: Any, **kwargs: Any) -> Triad:
        """Build and return the Triad for this operation."""

    def __call__(self, *args: Any, **kwargs: Any) -> Triad:
        # Sugar: ``op(...)`` is shorthand for ``op.trace(...)``.
        tracer = self.trace
        return tracer(*args, **kwargs)
238
-
239
-
class InlineFeOperation(FeOperation):
    """Frontend operation backed by a user-supplied tracing function.

    ``trace_fn`` builds the Triad itself; this class only stores it and
    forwards every ``trace`` call to it with the arguments unchanged.
    """

    def __init__(self, module: FeModule, name: str, trace_fn: Callable[..., Triad]):
        super().__init__(module, name)
        self._trace_fn = trace_fn

    # override
    def trace(self, *args: Any, **kwargs: Any) -> Triad:
        delegate = self._trace_fn
        return delegate(*args, **kwargs)
250
-
251
-
class SimpleFeOperation(FeOperation):
    """FeOperation that builds a Triad from a type-only kernel.

    Contract (keep it simple):
      - The kernel computes and returns TensorType/TableType or a PyTree thereof.
      - Positional inputs may be MPObjects (captured as PFunction inputs) or
        data-like values (TableLike/TensorLike) used for type inference or
        validation. The kernel receives each MPObject as its underlying
        ``mptype._type``; any other positional value is passed through unchanged.
      - Keyword arguments are attributes and must be plain Python values
        (TensorType/TableType values reach the kernel but are excluded from the
        PFunction attrs). MPObject kwargs are rejected.
      - Prefer keyword-only attributes in the kernel signature for explicitness.
        For convenience, when the call does not bind to the kernel signature as
        given, non-MPObject positional values are mapped onto keyword-only
        parameters by order.
      - No IR building inside the kernel; the PFunction is assembled here with
        ``fn_type=pfunc_name``.
    """

    def __init__(
        self,
        module: FeModule,
        pfunc_name: str,
        kernel: Callable[..., Any],
    ):
        # Derive operation name from kernel function name for SSOT
        super().__init__(module, kernel.__name__)
        self.pfunc_name = pfunc_name
        self._kernel = kernel

        # Validate kernel signature: kernels must not use *args/**kwargs, since the
        # positional-to-keyword mapping in trace() relies on explicit names.
        import inspect

        sig = inspect.signature(kernel)
        for p in sig.parameters.values():
            if p.kind == inspect.Parameter.VAR_KEYWORD:
                raise TypeError(
                    f"typed_op kernel '{module.name}.{kernel.__name__}' must not use **kwargs; define explicit keywords instead"
                )
            if p.kind == inspect.Parameter.VAR_POSITIONAL:
                raise TypeError(
                    f"typed_op kernel '{module.name}.{kernel.__name__}' must not use *args; define explicit parameters instead"
                )

        # Cache signature and kw-only parameter names for fast path in trace
        self._kernel_sig = sig
        self._kwonly_names = [
            p.name
            for p in sig.parameters.values()
            if p.kind == inspect.Parameter.KEYWORD_ONLY
        ]

    def _bind_call_args(
        self, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        """Resolve the (positional, keyword) arguments to pass to the kernel.

        If the call binds to the kernel signature as given, it is kept as-is so
        that data-like positionals still reach the kernel. Otherwise non-MPObject
        positionals are mapped onto not-yet-supplied keyword-only parameters by
        order (e.g. ``crypto.keygen(32)`` for ``def keygen(*, length: int)``),
        and only the MPObjects remain positional.

        Raises:
            TypeError: if some non-MPObject positional values cannot be mapped.
        """
        try:
            self._kernel_sig.bind_partial(*args, **kwargs)
        except TypeError:
            pass
        else:
            return tuple(args), dict(kwargs)

        non_mp_positional = [a for a in args if not isinstance(a, MPObject)]
        call_kwargs = dict(kwargs)
        filled = 0
        for name in self._kwonly_names:
            if filled >= len(non_mp_positional):
                break
            # Explicit keywords win; only fill parameters the caller left out.
            if name not in call_kwargs:
                call_kwargs[name] = non_mp_positional[filled]
                filled += 1
        if filled < len(non_mp_positional):
            leftover = non_mp_positional[filled:]
            raise TypeError(
                f"too many non-MPObject positional values for typed_op '{self.module.name}.{self.name}': {leftover}. "
                "Pass attributes explicitly by keyword (e.g., foo(x, *, attr=...))."
            ) from None
        mp_positional = tuple(a for a in args if isinstance(a, MPObject))
        return mp_positional, call_kwargs

    # override
    def trace(self, *args: Any, **kwargs: Any) -> Triad:
        """Run the kernel on input types and assemble the Triad.

        Returns:
            ``(pfunc, mp_inputs, out_tree)`` where ``mp_inputs`` are the
            positional MPObjects in call order.

        Raises:
            TypeError: on MPObject kwargs, unmappable positionals, or kernel
                outputs that are not TensorType/TableType.
        """
        # Inputs at PFunction layer are MPObjects captured from positional args only.
        pos_mp_inputs: list[MPObject] = [a for a in args if isinstance(a, MPObject)]

        # Enforce: no MPObject kwargs per simplified contract
        for k, v in kwargs.items():
            if isinstance(v, MPObject):
                raise TypeError(
                    f"typed_op does not accept MPObject kwargs: {k}; pass MPObjects positionally"
                )

        call_pos, call_kwargs = self._bind_call_args(args, kwargs)

        # Compute PFunction attrs from the call kwargs (exclude MPObject and type objects)
        attr_kwargs: dict[str, Any] = {
            k: v
            for k, v in call_kwargs.items()
            if not isinstance(v, MPObject)
            and not isinstance(v, (TensorType, TableType))
        }

        # Unwrap MPObjects so the kernel sees TensorType/TableType (never
        # TraceVar/InterpVar). Data-like positionals kept on the direct-binding
        # path have no `mptype` and are passed through for type inference.
        call_pos_types = tuple(
            a.mptype._type if isinstance(a, MPObject) else a for a in call_pos
        )

        # Sanity: no MPObject should appear in kwargs (enforced earlier), but be safe.
        if any(isinstance(v, MPObject) for v in call_kwargs.values()):
            raise TypeError("kernel kwargs should not be MPObject")

        # Execute kernel to compute return types
        result = self._kernel(*call_pos_types, **call_kwargs)

        outs_info, out_tree = tree_flatten(result)

        # Ensure all outputs are TensorType or TableType.
        # TODO(jint): theoretically we could also allow Python constants here.
        for o in outs_info:
            if not isinstance(o, (TensorType, TableType)):
                raise TypeError(
                    f"simple op kernel must return TensorType or TableType, got {type(o).__name__}"
                )

        # Build input types from positional MPObjects only
        ins_info = tuple(a.mptype._type for a in pos_mp_inputs)

        # Compose PFunction and return triad
        pfunc = PFunction(
            fn_type=self.pfunc_name,
            ins_info=ins_info,
            outs_info=tuple(outs_info),
            **attr_kwargs,
        )
        return pfunc, pos_mp_inputs, out_tree
379
-
380
-
def stateless_mod(mod_name: str) -> FeModule:
    """Create a :class:`StatelessFeModule` called ``mod_name``.

    The new module registers itself in the global registry during construction.
    """
    module = StatelessFeModule(mod_name)
    return module
383
-
384
-
def list_ops(module: str | None = None) -> dict[tuple[str, str], FeOperation]:
    """Return a snapshot of the registered frontend operations.

    Args:
        module: When given, keep only operations registered under this module
            name; when ``None`` (the default), include every operation.

    Returns:
        A new dict keyed by ``(module_name, op_name)``.
    """
    registry = get_registry()
    return registry.list_ops(module)
388
-
389
-
390
- # -----------------------------------------------------------------------------
391
- # Guidance: complex ops via subclassing
392
- # -----------------------------------------------------------------------------
393
-
394
- # Example pattern (non-executable) showing how a complex op (e.g., jax_cc) could
395
- # capture a Python callable and compile it to a Triad by subclassing FeOperation.
396
- #
397
- # class JaxCompileOp(FeOperation):
398
- # def __init__(self, module: FeModule, name: str, func: Callable[..., Any], *,
399
- # fn_type: str = "mlir.stablehlo", **options: Any) -> None:
400
- # super().__init__(module, name)
401
- # self.func = func
402
- # self.fn_type = fn_type
403
- # self.options = dict(options)
404
- #
405
- # def trace(self, *args: MPObject, **kwargs: Any) -> Triad:
406
- # # 1) Infer output types from func and args, respecting current ctx/masks.
407
- # # 2) Build PFunction with fn_type=self.fn_type and any attributes.
408
- # # 3) Return (pfunc, list(args), out_tree)
409
- # raise NotImplementedError
410
-
411
-
# -----------------------------------------------------------------------------
# Migration notes (checklist)
# -----------------------------------------------------------------------------

# - Replace any isinstance(FEOp)/metadata checks with isinstance(x, FeOperation)
#   (or is_feop(x)).
# - Define a module via stateless_mod("module_name") or a FeModule subclass;
#   constructing a FeModule registers it in the FeRegistry automatically.
# - For inline ops that already produce a triad, decorate the trace function with
#   @module.op_def(). The op name is derived from the function name.
# - For type-only kernels, decorate the kernel with @module.simple_op(pfunc_name).
#   The op name is derived from the kernel function name.
# - For complex ops (with Python callables/closures), subclass FeOperation and
#   register it via get_registry().register_op(module, name, op_instance), or use
#   @module.op_def(), which wraps the function in an InlineFeOperation.
# - Ensure PFunction.fn_type is set as the routing key (e.g., "mlir.stablehlo", "sql.run").
# - Keep device selection/routing out of frontend code; only set fn_type and attributes.
# - Avoid moving MPObjects across contexts directly; capture within current ctx in trace().