mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Primitive: User-facing API for building atomic operations.
|
|
16
|
+
|
|
17
|
+
Provides the Primitive class for defining operations that automatically work in
|
|
18
|
+
both trace mode (record to Graph IR) and interp mode (execute immediately).
|
|
19
|
+
|
|
20
|
+
See Primitive class documentation for detailed usage examples.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
from collections.abc import Callable, Sequence
|
|
26
|
+
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
|
27
|
+
|
|
28
|
+
from jax.tree_util import tree_map
|
|
29
|
+
|
|
30
|
+
from mplang.v2.edsl.context import get_current_context, get_default_context
|
|
31
|
+
from mplang.v2.edsl.object import Object
|
|
32
|
+
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from mplang.v2.edsl.typing import BaseType
|
|
35
|
+
|
|
36
|
+
# Result type produced by Primitive.bind(); parameterizes Primitive[T_Ret].
T_Ret = TypeVar("T_Ret")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class Primitive(Generic[T_Ret]):
    """An atomic operation definition, in the spirit of JAX's ``Primitive``.

    Each primitive has a unique ``name`` and up to three optional rules:

    1. **Abstract evaluation** (:meth:`def_abstract_eval`): infers the output
       type(s) from the input types so the operation can be recorded to Graph IR.
    2. **Custom tracing** (:meth:`def_trace`): takes full control of tracing for
       operations whose inputs/outputs are arbitrary PyTrees.
    3. **Implementation** (:meth:`def_impl`): interpreter-side execution logic,
       which is also published to the global implementation registry.

    :meth:`bind` (and calling the primitive directly) hands the arguments to
    the active context, which records or executes the operation.

    Attributes:
        name: Unique primitive name, e.g. ``"add"``, ``"mul"``, ``"encrypt"``.
        _abstract_eval: Type-inference rule, or ``None`` until defined.
        _trace: Custom trace rule, or ``None`` until defined.
        _impl: Interpreter implementation, or ``None`` until defined.

    Example::

        encrypt_p = Primitive("fhe_encrypt")

        @encrypt_p.def_abstract_eval
        def encrypt_abstract(x_type):
            from mplang.v2.edsl.typing import Vector
            return Vector[x_type.dtype, x_type.shape]

        # Execution happens via Graph IR -> backend, which dispatches to the
        # FHE library based on the operation type.
        ciphertext = encrypt_p.bind(plaintext)  # recorded to Graph IR when tracing
    """

    def __init__(self, name: str):
        """Create a primitive that has no rules attached yet.

        Args:
            name: Unique identifier for the primitive (e.g. ``"add"``).
        """
        self.name = name
        # All rules start undefined; the def_* decorators fill them in.
        self._abstract_eval: Callable[..., BaseType | Sequence[BaseType]] | None = None
        self._trace: Callable[..., Any] | None = None
        self._impl: Callable[..., Any] | None = None

    def def_impl(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Attach ``fn`` as the interpreter execution rule.

        Besides storing ``fn`` on the primitive, it is registered in the global
        implementation registry under ``self.name``. The Interpreter uses that
        entry during eager execution and when evaluating a graph.

        Args:
            fn: Implementation with signature ``(interpreter, op, *args) -> result``.

        Returns:
            ``fn`` itself, so the method works as a decorator.
        """
        self._impl = fn
        # Deferred import; presumably keeps registry loading out of this
        # module's import time (e.g. to avoid a cycle) -- NOTE(review): confirm.
        from mplang.v2.edsl.registry import register_impl

        register_impl(self.name, fn)
        return fn

    def def_abstract_eval(
        self, fn: Callable[..., BaseType | Sequence[BaseType]]
    ) -> Callable[..., BaseType | Sequence[BaseType]]:
        """Attach the type-inference rule used while tracing.

        ``fn`` receives the input types (plus keyword attributes). It returns
        the output type, or a sequence of types for multi-output primitives.
        Two calling conventions are supported:

        1. Positional: ``fn(*in_types: BaseType, **attrs) -> BaseType | Sequence[BaseType]``
        2. Flat list: ``fn(in_types: list[BaseType], **attrs) -> BaseType | Sequence[BaseType]``

        Args:
            fn: The type-inference function.

        Returns:
            ``fn`` itself, so the method works as a decorator.

        Examples:
            Positional form, single output::

                add_p = Primitive("add")

                @add_p.def_abstract_eval
                def add_abstract(x_type: BaseType, y_type: BaseType) -> BaseType:
                    assert x_type == y_type, "Inputs must have same type"
                    return x_type

            Positional form, multiple outputs::

                split_p = Primitive("split")

                @split_p.def_abstract_eval
                def split_abstract(x_type: BaseType, *, num_splits: int) -> list[BaseType]:
                    return [x_type] * num_splits

            Flat form, variable number of inputs::

                concat_p = Primitive("concat")

                @concat_p.def_abstract_eval
                def concat_abstract(in_types: list[BaseType], *, axis: int = 0) -> BaseType:
                    return in_types[0]  # concatenated type
        """
        self._abstract_eval = fn
        return fn

    def def_trace(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Attach a custom trace rule that fully controls tracing.

        Use this for operations that do not fit the abstract-eval mold, for
        example:

        - integrating external functions (JAX, FHE, ...);
        - accepting PyTree inputs that mix Objects and constants;
        - producing arbitrary PyTree outputs.

        ``fn`` receives the raw ``*args, **kwargs`` and returns the result
        PyTree: ``(*args, **kwargs) -> Object | PyTree[Object]``. The tracer
        takes care of extracting Objects from the inputs (via var_morph),
        recording the morph structure in the Operation attrs, flattening the
        output PyTree, and rebuilding that structure during interpretation.

        Args:
            fn: Custom trace function. It accepts arbitrary args/kwargs and
                returns a PyTree that may contain Objects and constants.

        Returns:
            ``fn`` itself, so the method works as a decorator.

        Examples:
            JAX integration::

                run_jax_p = Primitive("run_jax")

                @run_jax_p.def_trace
                def run_jax_trace(jax_fn: Callable, *args, **kwargs):
                    # args/kwargs can mix Objects and constants
                    result = compile_and_run(jax_fn, args, kwargs)
                    return result  # any PyTree structure

            Multiple outputs built from other primitives::

                split_p = Primitive("split")

                @split_p.def_trace
                def split_trace(x: Object, *, num_splits: int):
                    return [slice_p.bind(x, i) for i in range(num_splits)]
        """
        self._trace = fn
        return fn

    def bind(self, *args: Any, **kwargs: Any) -> T_Ret:
        """Apply the primitive to ``args``/``kwargs`` in the active context.

        The current context is used if one is set, otherwise the default
        context. Every :class:`Object` leaf anywhere in the ``args``/``kwargs``
        PyTrees is lifted into that context with ``ctx.lift``; all other
        leaves pass through unchanged. The context's ``bind_primitive`` then
        does the work: it records to Graph IR in a Tracer context, or executes
        via the backend in an Interpreter context.

        How arguments are interpreted depends on which rule was defined:

        - :meth:`def_abstract_eval`: positional arguments are the Object inputs
          and keyword arguments are plain attribute values. The result is a
          single Object or a list of Objects.
        - :meth:`def_trace`: positional and keyword arguments may freely mix
          Objects and plain values. The result may be any PyTree.

        Args:
            *args: Positional arguments.
            **kwargs: Keyword arguments.

        Returns:
            The value produced by ``ctx.bind_primitive``, typed as ``T_Ret``.

        Raises:
            RuntimeError: If neither an abstract-eval nor a trace rule is defined.
            TypeError: If the abstract-eval form is given Objects in ``kwargs``.
            This method raises neither error itself; both surface from the
            context's ``bind_primitive``.

        Example::

            z = add_p.bind(x, y)                                # abstract-eval form
            result = run_jax_p.bind(fn, obj1, 42, obj2, k=3.14)  # trace form
        """
        ctx = get_current_context()
        ctx = get_default_context() if ctx is None else ctx

        def _lift(leaf: Any) -> Any:
            # Only Objects need to be brought into the context; constants stay.
            if isinstance(leaf, Object):
                return ctx.lift(leaf)
            return leaf

        new_args, new_kwargs = tree_map(_lift, (args, kwargs))
        outcome = ctx.bind_primitive(self, new_args, new_kwargs)
        return cast(T_Ret, outcome)

    def __call__(self, *args: Any, **kwargs: Any) -> T_Ret:
        """Shorthand: ``prim(*args, **kwargs)`` is ``prim.bind(*args, **kwargs)``."""
        return self.bind(*args, **kwargs)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# ============================================================================
|
|
244
|
+
# Decorator: @primitive for defining primitives in a concise way
|
|
245
|
+
# ============================================================================
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def primitive(name: str) -> Callable[[Callable], Primitive]:
    """Build a :class:`Primitive` straight from its abstract-eval function.

    The returned decorator creates ``Primitive(name)``, installs the decorated
    function as its type-inference rule, and returns the Primitive (not the
    function). The decorated name therefore refers to the primitive itself.

    Args:
        name: Unique name for the primitive.

    Returns:
        A decorator mapping an abstract-eval function to a ``Primitive``.

    Example::

        @primitive("my_custom_op")
        def my_op_p(x_type: BaseType, y_type: BaseType) -> BaseType:
            return x_type  # type inference logic

        # my_op_p is a Primitive; execution goes via Graph IR -> backend.
        z = my_op_p.bind(x, y)
    """

    def decorator(abstract_fn: Callable) -> Primitive[Any]:
        prim: Primitive[Any] = Primitive(name)
        prim.def_abstract_eval(abstract_fn)
        return prim

    return decorator
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
__all__ = [
|
|
282
|
+
"Primitive",
|
|
283
|
+
"primitive",
|
|
284
|
+
]
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Pretty printer for the EDSL Graph IR."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from mplang.v2.edsl.graph import Graph, Operation, Value
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class GraphPrinter:
    """Render Graph IR as indented, MLIR-flavoured text.

    A graph is printed as ``(params) {`` ... ``}``. Each operation becomes
    ``lhs = opcode(inputs) {attrs} : types``, and its regions are printed as
    indented sub-graphs inside a brace pair that follows the operation line.

    Args:
        indent_size: Number of spaces per nesting level.
        show_types: Annotate graph parameters and operation results with types.
        show_attrs: Append each operation's attributes, sorted by key.
    """

    def __init__(
        self,
        *,
        indent_size: int = 2,
        show_types: bool = True,
        show_attrs: bool = True,
    ):
        self.indent_size = indent_size
        self.show_types = show_types
        self.show_attrs = show_attrs

    def format(self, graph: Graph) -> str:
        """Return ``graph`` (including nested regions) as a newline-joined string."""
        out: list[str] = []
        self._format_graph(graph, out, indent_level=0, heading=None)
        return "\n".join(out)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _write(self, lines: list[str], indent_level: int, text: str) -> None:
        """Append ``text`` to ``lines``, indented ``indent_level`` levels."""
        lines.append(" " * (self.indent_size * indent_level) + text)

    def _format_graph(
        self, graph: Graph, lines: list[str], indent_level: int, heading: str | None
    ) -> None:
        """Emit the ``(params) {`` opener, the body, an optional return, and ``}``."""
        prefix = heading or ""
        self._write(lines, indent_level, f"{prefix}{self._format_params(graph.inputs)} {{")

        body_level = indent_level + 1
        for operation in graph.operations:
            self._format_operation(operation, lines, body_level)

        if graph.outputs:
            returned = ", ".join([v.name for v in graph.outputs])
            self._write(lines, body_level, "return " + returned)

        self._write(lines, indent_level, "}")

    def _format_params(self, inputs: list[Value]) -> str:
        """Render graph inputs as ``(name: type, ...)`` or ``(name, ...)``."""
        if not inputs:
            return "()"
        if self.show_types:
            rendered = [f"{v.name}: {v.type}" for v in inputs]
        else:
            rendered = [f"{v.name}" for v in inputs]
        return "(" + ", ".join(rendered) + ")"

    def _format_operation(
        self, op: Operation, lines: list[str], indent_level: int
    ) -> None:
        """Emit one operation line, plus a braced block of regions if it has any."""
        lhs = self._format_outputs(op.outputs)
        args = ", ".join(v.name for v in op.inputs)
        suffix = self._format_attrs(op.attrs) + self._format_output_types(op.outputs)
        header = f"{lhs} = {op.opcode}({args}){suffix}"

        if not op.regions:
            self._write(lines, indent_level, header)
            return

        self._write(lines, indent_level, header + " {")
        for region in op.regions:
            self._format_graph(region, lines, indent_level + 1, heading=None)
        self._write(lines, indent_level, "}")

    def _format_outputs(self, outputs: list[Value]) -> str:
        """Render result names: bare for one result, ``[a, b]`` otherwise."""
        if not outputs:
            return "[]"
        names = [v.name for v in outputs]
        if len(names) == 1:
            return names[0]
        return "[" + ", ".join(names) + "]"

    def _format_attrs(self, attrs: dict[str, Any]) -> str:
        """Render attributes as `` {k=repr(v), ...}`` in key order, or ``""``."""
        if self.show_attrs and attrs:
            body = ", ".join(f"{k}={attrs[k]!r}" for k in sorted(attrs))
            return " {" + body + "}"
        return ""

    def _format_output_types(self, outputs: list[Value]) -> str:
        """Render `` : T`` for one result, `` : (T1, T2)`` for many, else ``""``."""
        if self.show_types and outputs:
            rendered = [str(v.type) for v in outputs]
            if len(rendered) == 1:
                return f" : {rendered[0]}"
            return " : (" + ", ".join(rendered) + ")"
        return ""
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def format_graph(graph: Graph, **kwargs: Any) -> str:
    """Format ``graph`` with a one-off :class:`GraphPrinter`.

    Keyword arguments are forwarded unchanged to ``GraphPrinter``
    (``indent_size``, ``show_types``, ``show_attrs``).
    """
    printer = GraphPrinter(**kwargs)
    return printer.format(graph)
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Registry for primitive implementations.
|
|
16
|
+
|
|
17
|
+
This module decouples the Primitive definition from the Interpreter execution.
|
|
18
|
+
Primitives register their implementations here, and the Interpreter looks them up here.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import time
|
|
24
|
+
from collections import defaultdict
|
|
25
|
+
from collections.abc import Callable
|
|
26
|
+
from dataclasses import dataclass, field
|
|
27
|
+
from typing import Any
|
|
28
|
+
|
|
29
|
+
# Global registry for primitive implementations
|
|
30
|
+
# Key: opcode (str), Value: implementation function
|
|
31
|
+
# Written by register_impl() (which Primitive.def_impl calls) and read by get_impl().
_IMPL_REGISTRY: dict[str, Callable[..., Any]] = {}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# ==============================================================================
|
|
35
|
+
# Profiler for All Primitive Operations
|
|
36
|
+
# ==============================================================================
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class OpProfiler:
    """Global profiler for tracking all primitive operation timing.

    While ``enabled`` is true, :meth:`record` appends one duration (in
    seconds, as measured by the wrappers ``get_impl`` hands out) per call
    under the operation's opcode. The collected data can then be summarized
    as numbers (:meth:`summary`) or printed as tables (:meth:`print_summary`,
    :meth:`print_leaf_summary`).
    """

    # Opcodes whose timing includes the operations nested inside them; the
    # leaf summary excludes them to avoid double-counting. Left unannotated on
    # purpose so the dataclass does not treat it as a field.
    _CONTAINER_OPS = frozenset(
        {
            "simp.pcall_static",
            "simp.pcall_dynamic",
            "simp.shuffle_static",
            "simp.shuffle",
            "simp.uniform_cond",
            "simp.while_loop",
        }
    )

    # When False, record() is a no-op.
    enabled: bool = False
    # opcode -> every recorded duration (seconds) for that opcode.
    timings: dict[str, list[float]] = field(default_factory=lambda: defaultdict(list))

    def reset(self) -> None:
        """Clear all timing data."""
        self.timings = defaultdict(list)

    def record(self, opcode: str, duration: float) -> None:
        """Append ``duration`` to ``opcode``'s timings; ignored while disabled."""
        if self.enabled:
            self.timings[opcode].append(duration)

    def summary(self) -> dict[str, dict[str, float]]:
        """Get summary statistics for all operations.

        Returns:
            ``{opcode: {"count", "total", "mean", "min", "max"}}`` with opcodes
            in sorted order. Opcodes without any recorded duration are omitted.
        """
        result: dict[str, dict[str, float]] = {}
        for opcode, times in sorted(self.timings.items()):
            if not times:
                continue
            total = sum(times)
            result[opcode] = {
                "count": len(times),
                "total": total,
                "mean": total / len(times),
                "min": min(times),
                "max": max(times),
            }
        return result

    def print_summary(self, top_n: int = 20) -> None:
        """Print a formatted summary of timing statistics for every opcode.

        Args:
            top_n: Maximum number of rows to print, by total time descending.
                Negative values are treated as 0.
        """
        stats = self.summary()
        if not stats:
            print("No timing data collected.")
            return
        self._print_table("PRIMITIVE OPERATION TIMING SUMMARY", stats, "TOTAL", top_n)

    def print_leaf_summary(
        self,
        top_n: int = 20,
        *,
        container_ops: set[str] | frozenset[str] | None = None,
    ) -> None:
        """Print summary excluding container ops (pcall, shuffle, etc.).

        This shows only 'leaf' operations that don't contain nested calls,
        giving accurate self-time without double-counting.

        Args:
            top_n: Maximum number of rows to print, by total time descending.
                Negative values are treated as 0.
            container_ops: Opcodes to exclude. Defaults to the built-in simp
                container ops in ``_CONTAINER_OPS``; pass a custom set when
                other region-bearing ops should also be excluded.
        """
        excluded = (
            self._CONTAINER_OPS if container_ops is None else frozenset(container_ops)
        )
        leaf_stats = {k: v for k, v in self.summary().items() if k not in excluded}

        if not leaf_stats:
            print("No leaf timing data collected.")
            return
        self._print_table(
            "LEAF OPERATION TIMING SUMMARY (excludes container ops)",
            leaf_stats,
            "TOTAL (leaf ops)",
            top_n,
        )

    @staticmethod
    def _print_table(
        title: str,
        stats: dict[str, dict[str, float]],
        total_label: str,
        top_n: int,
    ) -> None:
        """Print ``stats`` as a table sorted by total time, then a total row.

        Shared by print_summary() and print_leaf_summary() so both tables stay
        in the same format.
        """
        print("\n" + "=" * 80)
        print(title)
        print("=" * 80)
        # The last column labels the percentage-of-total printed on every row.
        print(
            f"{'Operation':<35} {'Count':>8} {'Total(s)':>10} "
            f"{'Mean(ms)':>10} {'Max(ms)':>10} {'Share':>8}"
        )
        print("-" * 80)

        total_time = sum(s["total"] for s in stats.values())

        # Sort by total time descending
        sorted_stats = sorted(stats.items(), key=lambda x: -x[1]["total"])
        # A negative slice bound would drop rows from the end and misreport
        # how many were hidden, so clamp to zero.
        limit = max(top_n, 0)

        for opcode, s in sorted_stats[:limit]:
            pct = s["total"] / total_time * 100 if total_time > 0 else 0
            print(
                f"{opcode:<35} {s['count']:>8} {s['total']:>10.3f} "
                f"{s['mean'] * 1000:>10.3f} {s['max'] * 1000:>10.3f} ({pct:>5.1f}%)"
            )

        if len(sorted_stats) > limit:
            print(f" ... and {len(sorted_stats) - limit} more operations")

        print("-" * 80)
        print(f"{total_label:<35} {'':<8} {total_time:>10.3f}s")
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# Global profiler instance
|
|
153
|
+
# Shared by get_profiler(), enable_profiling(), disable_profiling() and the
# timing wrappers that get_impl() returns while profiling is enabled.
_profiler = OpProfiler()
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def get_profiler() -> OpProfiler:
    """Return the module-level :class:`OpProfiler` singleton.

    It is the same instance toggled by ``enable_profiling`` and
    ``disable_profiling`` and fed by the wrappers that ``get_impl`` returns.
    """
    profiler = _profiler
    return profiler
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def enable_profiling() -> None:
    """Turn on primitive timing and discard previously collected data.

    ``get_impl`` checks the flag at lookup time, so only implementations
    fetched after this call are wrapped for timing.
    """
    profiler = _profiler
    profiler.enabled = True
    profiler.reset()
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def disable_profiling() -> None:
    """Stop recording primitive timings; data collected so far is kept.

    Wrappers that ``get_impl`` handed out earlier still run, but
    ``OpProfiler.record`` ignores their measurements while disabled.
    """
    profiler = _profiler
    profiler.enabled = False
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# ==============================================================================
|
|
173
|
+
# Registry Functions
|
|
174
|
+
# ==============================================================================
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def register_impl(opcode: str, fn: Callable[..., Any]) -> None:
    """Make ``fn`` the implementation the interpreter uses for ``opcode``.

    A later registration under the same opcode replaces the earlier one.

    Args:
        opcode: The unique name of the primitive (e.g. "add", "mul").
        fn: The function implementing the logic, with signature
            ``(interpreter, op, *args) -> result``.
    """
    _IMPL_REGISTRY.update({opcode: fn})
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def get_impl(opcode: str) -> Callable[..., Any] | None:
    """Get the registered implementation for an opcode.

    Args:
        opcode: Name the implementation was registered under via ``register_impl``.

    Returns:
        ``None`` if nothing is registered for ``opcode``. If profiling is
        disabled at lookup time, the registered function itself. If profiling
        is enabled at lookup time, a transparent wrapper instead: it forwards
        all positional and keyword arguments, carries the wrapped function's
        metadata, and records the elapsed ``time.perf_counter()`` duration
        under ``opcode`` in the global profiler.
    """
    fn = _IMPL_REGISTRY.get(opcode)
    if fn is None:
        return None

    if not _profiler.enabled:
        return fn

    # Imported here so the common, non-profiling lookup path pays nothing.
    from functools import wraps

    @wraps(fn)
    def profiled_fn(interpreter: Any, op: Any, *args: Any, **kwargs: Any) -> Any:
        t0 = time.perf_counter()
        # Forward kwargs too: the wrapper must accept exactly what fn accepts,
        # otherwise enabling profiling would change call behavior.
        result = fn(interpreter, op, *args, **kwargs)
        # Only successful calls are timed; an exception propagates unrecorded.
        _profiler.record(opcode, time.perf_counter() - t0)
        return result

    return profiled_fn
|