mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,284 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Primitive: User-facing API for building atomic operations.
16
+
17
+ Provides the Primitive class for defining operations that automatically work in
18
+ both trace mode (record to Graph IR) and interp mode (execute immediately).
19
+
20
+ See Primitive class documentation for detailed usage examples.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from collections.abc import Callable, Sequence
26
+ from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
27
+
28
+ from jax.tree_util import tree_map
29
+
30
+ from mplang.v2.edsl.context import get_current_context, get_default_context
31
+ from mplang.v2.edsl.object import Object
32
+
33
+ if TYPE_CHECKING:
34
+ from mplang.v2.edsl.typing import BaseType
35
+
36
T_Ret = TypeVar("T_Ret")


class Primitive(Generic[T_Ret]):
    """An atomic operation, analogous to a JAX ``Primitive``.

    A primitive is identified by ``name`` and carries up to three optional
    rules, each installed through a decorator-style ``def_*`` method:

    * ``def_abstract_eval`` -- type inference (input types -> output type(s)),
      used while tracing.
    * ``def_trace`` -- custom tracing logic for operations whose inputs or
      outputs are arbitrary PyTrees.
    * ``def_impl`` -- the execution function, which is additionally registered
      in the global implementation registry under ``name``.

    Calling :meth:`bind` (or the instance itself) lifts any ``Object``
    arguments into the active context and hands everything to that context's
    ``bind_primitive``.

    Attributes:
        name: Unique name of the primitive (e.g. "add", "mul", "encrypt").
        _abstract_eval: Type inference rule, or None if not defined.
        _trace: Custom trace rule, or None if not defined.
        _impl: Execution rule, or None if not defined.

    Example:
        >>> encrypt_p = Primitive("fhe_encrypt")
        >>>
        >>> @encrypt_p.def_abstract_eval
        >>> def encrypt_abstract(x_type):
        >>>     from mplang.v2.edsl.typing import Vector
        >>>     return Vector[x_type.dtype, x_type.shape]
        >>>
        >>> plaintext = TraceObject(...)
        >>> ciphertext = encrypt_p.bind(plaintext)  # recorded to Graph IR
    """

    def __init__(self, name: str):
        """Create a primitive with no rules attached yet.

        Args:
            name: Unique identifier of the primitive (e.g. "add", "encrypt").
                It is also the key under which ``def_impl`` registers the
                implementation.
        """
        self.name = name
        # Every rule starts unset; the def_* decorators fill them in.
        self._abstract_eval: Callable[..., BaseType | Sequence[BaseType]] | None = None
        self._trace: Callable[..., Any] | None = None
        self._impl: Callable[..., Any] | None = None

    def def_impl(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Install ``fn`` as the interpreter implementation of this primitive.

        Besides storing ``fn`` on the instance, this registers it in the global
        implementation registry (``mplang.v2.edsl.registry``) under
        ``self.name``.

        Args:
            fn: Implementation with signature ``(interpreter, op, *args) -> result``.

        Returns:
            ``fn`` itself, so the method can be used as a decorator.
        """
        self._impl = fn
        # Imported lazily (presumably to avoid an import cycle with the
        # registry module -- confirm before hoisting to module level).
        from mplang.v2.edsl.registry import register_impl as _register

        _register(self.name, fn)
        return fn

    def def_abstract_eval(
        self, fn: Callable[..., BaseType | Sequence[BaseType]]
    ) -> Callable[..., BaseType | Sequence[BaseType]]:
        """Install ``fn`` as the type inference rule of this primitive.

        The rule maps input types to one output type or a sequence of output
        types. Two calling conventions are intended to be supported:

        1. Positional: ``(*in_types: BaseType, **attrs) -> BaseType | Sequence[BaseType]``
        2. Flat: ``(in_types: list[BaseType], **attrs) -> BaseType | Sequence[BaseType]``

        Args:
            fn: The type inference function.

        Returns:
            ``fn`` itself, so the method can be used as a decorator.

        Example (positional form):
            >>> add_p = Primitive("add")
            >>>
            >>> @add_p.def_abstract_eval
            >>> def add_abstract(x_type: BaseType, y_type: BaseType) -> BaseType:
            >>>     assert x_type == y_type, "Inputs must have same type"
            >>>     return x_type

        Example (positional form, multi-output):
            >>> split_p = Primitive("split")
            >>>
            >>> @split_p.def_abstract_eval
            >>> def split_abstract(x_type: BaseType, *, num_splits: int) -> list[BaseType]:
            >>>     return [x_type] * num_splits

        Example (flat form):
            >>> concat_p = Primitive("concat")
            >>>
            >>> @concat_p.def_abstract_eval
            >>> def concat_abstract(in_types: list[BaseType], *, axis: int = 0) -> BaseType:
            >>>     return in_types[0]
        """
        self._abstract_eval = fn
        return fn

    def def_trace(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Install ``fn`` as the custom trace rule of this primitive.

        Use this when the operation needs full control over tracing. Typical
        cases are wrapping an external function (JAX, FHE, ...), accepting
        PyTrees that mix Objects and constants, or producing arbitrary PyTree
        outputs. ``fn`` receives the raw ``*args``/``**kwargs`` and returns
        the result PyTree. Object extraction and output flattening are
        described as being handled by the tracer (not by this method).

        Signature: ``(*args, **kwargs) -> Object | PyTree[Object]``

        Args:
            fn: The custom trace function.

        Returns:
            ``fn`` itself, so the method can be used as a decorator.

        Example (JAX integration):
            >>> run_jax_p = Primitive("run_jax")
            >>>
            >>> @run_jax_p.def_trace
            >>> def run_jax_trace(jax_fn: Callable, *args, **kwargs):
            >>>     return compile_and_run(jax_fn, args, kwargs)

        Example (multi-output):
            >>> split_p = Primitive("split")
            >>>
            >>> @split_p.def_trace
            >>> def split_trace(x: Object, *, num_splits: int):
            >>>     return [slice_p.bind(x, i) for i in range(num_splits)]
        """
        self._trace = fn
        return fn

    def bind(self, *args: Any, **kwargs: Any) -> T_Ret:
        """Apply the primitive to ``args``/``kwargs`` in the active context.

        The context is the current one, or the default context when no
        context is current. Every ``Object`` leaf found anywhere inside
        ``args``/``kwargs`` (walked as a JAX PyTree) is replaced by
        ``ctx.lift(leaf)``. Non-Object leaves are passed through unchanged.
        The lifted arguments go to ``ctx.bind_primitive(self, args, kwargs)``,
        whose result is returned. Whether that traces into Graph IR or
        executes is decided by the context.

        Intended usage:
        - primitives defined with ``def_abstract_eval`` take Objects
          positionally and plain attribute values as kwargs;
        - primitives defined with ``def_trace`` accept any mix of Objects and
          plain values in both.

        Args:
            *args: Positional arguments (may contain Objects).
            **kwargs: Keyword arguments (may contain Objects).

        Returns:
            Object | PyTree[Object], as produced by the context.

        Raises:
            Any error is raised by the context's ``bind_primitive``. The
            intended contract is RuntimeError when neither rule is defined,
            and TypeError when Objects appear in kwargs of a
            def_abstract_eval primitive.

        Example:
            >>> z = add_p.bind(x, y)
            >>> result = run_jax_p.bind(fn, obj1, 42, obj2, k=3.14)
        """
        active = get_current_context()
        ctx = get_default_context() if active is None else active

        def _lift(leaf: Any) -> Any:
            # Only Objects are lifted into the context; everything else
            # (numbers, strings, callables, ...) is left as-is.
            if isinstance(leaf, Object):
                return ctx.lift(leaf)
            return leaf

        call_args, call_kwargs = tree_map(_lift, (args, kwargs))
        outcome = ctx.bind_primitive(self, call_args, call_kwargs)
        return cast(T_Ret, outcome)

    def __call__(self, *args: Any, **kwargs: Any) -> T_Ret:
        """Shorthand: ``prim(*args, **kwargs)`` is ``prim.bind(*args, **kwargs)``."""
        return self.bind(*args, **kwargs)
241
+
242
+
243
+ # ============================================================================
244
+ # Decorator: @primitive for defining primitives in a concise way
245
+ # ============================================================================
246
+
247
+
248
def primitive(name: str) -> Callable[[Callable], Primitive]:
    """Return a decorator that turns a type inference function into a Primitive.

    The decorator creates ``Primitive(name)`` and installs the decorated
    function as its abstract-eval rule (via ``def_abstract_eval``). It then
    returns the new ``Primitive``, not the function, so the decorated name is
    bound to the primitive.

    Args:
        name: Unique name for the primitive.

    Returns:
        A decorator mapping an abstract-eval function to a ``Primitive``.

    Example:
        >>> @primitive("my_custom_op")
        >>> def my_op_abstract(x_type: BaseType, y_type: BaseType) -> BaseType:
        >>>     return x_type
        >>>
        >>> my_op_p = my_op_abstract  # already a Primitive instance
        >>> z = my_op_p.bind(x, y)
    """

    def _make_primitive(fn: Callable) -> Primitive[Any]:
        prim: Primitive[Any] = Primitive(name)
        prim.def_abstract_eval(fn)
        return prim

    return _make_primitive
279
+
280
+
281
# Names exported by ``from ... import *``.
__all__ = ["Primitive", "primitive"]
@@ -0,0 +1,119 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Pretty printer for the EDSL Graph IR."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Any
20
+
21
+ from mplang.v2.edsl.graph import Graph, Operation, Value
22
+
23
+
24
class GraphPrinter:
    """Render Graph IR as indented, MLIR-flavoured text.

    Each graph prints as a parenthesised parameter list followed by a braced
    body. The body has one line per operation and an optional ``return`` line.
    Operations with regions print their nested graphs one level deeper.
    """

    def __init__(
        self,
        *,
        indent_size: int = 2,
        show_types: bool = True,
        show_attrs: bool = True,
    ):
        """Configure the printer.

        Args:
            indent_size: Spaces per nesting level.
            show_types: Whether to print parameter and result types.
            show_attrs: Whether to print operation attributes.
        """
        self.indent_size = indent_size
        self.show_types = show_types
        self.show_attrs = show_attrs

    def format(self, graph: Graph) -> str:
        """Return the text rendering of ``graph`` (lines joined by newlines)."""
        out: list[str] = []
        self._format_graph(graph, out, indent_level=0, heading=None)
        return "\n".join(out)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _write(self, lines: list[str], indent_level: int, text: str) -> None:
        # Prefix with indent_level * indent_size spaces.
        pad = " " * (self.indent_size * indent_level)
        lines.append(f"{pad}{text}")

    def _format_graph(
        self, graph: Graph, lines: list[str], indent_level: int, heading: str | None
    ) -> None:
        prefix = str(heading) if heading else ""
        self._write(lines, indent_level, f"{prefix}{self._format_params(graph.inputs)} {{")

        body_level = indent_level + 1
        for operation in graph.operations:
            self._format_operation(operation, lines, body_level)

        if graph.outputs:
            returned = ", ".join(v.name for v in graph.outputs)
            self._write(lines, body_level, f"return {returned}")

        self._write(lines, indent_level, "}")

    def _format_params(self, inputs: list[Value]) -> str:
        if not inputs:
            return "()"
        if self.show_types:
            rendered = [f"{v.name}: {v.type}" for v in inputs]
        else:
            rendered = [f"{v.name}" for v in inputs]
        return "(" + ", ".join(rendered) + ")"

    def _format_operation(
        self, op: Operation, lines: list[str], indent_level: int
    ) -> None:
        results = self._format_outputs(op.outputs)
        operands = ", ".join(v.name for v in op.inputs)
        signature = (
            f"{results} = {op.opcode}({operands})"
            f"{self._format_attrs(op.attrs)}{self._format_output_types(op.outputs)}"
        )
        if not op.regions:
            self._write(lines, indent_level, signature)
            return

        # Region-carrying ops wrap their nested graphs in a braced block.
        self._write(lines, indent_level, signature + " {")
        for region in op.regions:
            self._format_graph(region, lines, indent_level + 1, heading=None)
        self._write(lines, indent_level, "}")

    def _format_outputs(self, outputs: list[Value]) -> str:
        if not outputs:
            return "[]"
        if len(outputs) == 1:
            return outputs[0].name
        return f"[{', '.join(v.name for v in outputs)}]"

    def _format_attrs(self, attrs: dict[str, Any]) -> str:
        if self.show_attrs and attrs:
            # Keys are sorted so the output is deterministic.
            body = ", ".join(f"{key}={attrs[key]!r}" for key in sorted(attrs))
            return f" {{{body}}}"
        return ""

    def _format_output_types(self, outputs: list[Value]) -> str:
        if self.show_types and outputs:
            rendered = [str(v.type) for v in outputs]
            if len(rendered) == 1:
                return f" : {rendered[0]}"
            return " : (" + ", ".join(rendered) + ")"
        return ""
115
+
116
+
117
def format_graph(graph: Graph, **kwargs: Any) -> str:
    """Render ``graph`` as text in a single call.

    Shorthand for ``GraphPrinter(**kwargs).format(graph)``. ``kwargs`` are
    forwarded to the GraphPrinter constructor (``indent_size``,
    ``show_types``, ``show_attrs``).
    """
    printer = GraphPrinter(**kwargs)
    return printer.format(graph)
@@ -0,0 +1,207 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Registry for primitive implementations.
16
+
17
+ This module decouples the Primitive definition from the Interpreter execution.
18
+ Primitives register their implementations here, and the Interpreter looks them up here.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import time
24
+ from collections import defaultdict
25
+ from collections.abc import Callable
26
+ from dataclasses import dataclass, field
27
+ from typing import Any
28
+
29
# Process-wide table mapping each primitive opcode (str) to the function that
# implements it. Written by register_impl and read by get_impl.
_IMPL_REGISTRY: dict[str, Callable[..., Any]] = dict()
32
+
33
+
34
+ # ==============================================================================
35
+ # Profiler for All Primitive Operations
36
+ # ==============================================================================
37
+
38
+
39
@dataclass
class OpProfiler:
    """Collects wall-clock timings of primitive implementations, per opcode.

    Measurements (in seconds) are appended by :meth:`record` only while
    ``enabled`` is true. :meth:`summary` aggregates them, and
    :meth:`print_summary` / :meth:`print_leaf_summary` print them as a table
    on stdout.
    """

    # Opcodes whose measured time already includes the operations nested
    # inside them. print_leaf_summary skips these so leaf totals are not
    # double-counted. Deliberately unannotated so @dataclass treats it as a
    # plain class attribute rather than a field.
    _CONTAINER_OPS = frozenset(
        {
            "simp.pcall_static",
            "simp.pcall_dynamic",
            "simp.shuffle_static",
            "simp.shuffle",
            "simp.uniform_cond",
            "simp.while_loop",
        }
    )

    enabled: bool = False
    timings: dict[str, list[float]] = field(default_factory=lambda: defaultdict(list))

    def reset(self) -> None:
        """Discard all collected timing data."""
        self.timings = defaultdict(list)

    def record(self, opcode: str, duration: float) -> None:
        """Append ``duration`` (seconds) for ``opcode``; no-op while disabled."""
        if self.enabled:
            self.timings[opcode].append(duration)

    def summary(self) -> dict[str, dict[str, float]]:
        """Aggregate the timings of every opcode that has at least one sample.

        Returns:
            Mapping, ordered by opcode name, from opcode to a dict with keys
            ``count``, ``total``, ``mean``, ``min`` and ``max``. All values
            except ``count`` are in seconds.
        """
        result: dict[str, dict[str, float]] = {}
        for opcode, times in sorted(self.timings.items()):
            if times:
                total = sum(times)
                result[opcode] = {
                    "count": len(times),
                    "total": total,
                    "mean": total / len(times),
                    "min": min(times),
                    "max": max(times),
                }
        return result

    def print_summary(self, top_n: int = 20) -> None:
        """Print a timing table for all operations (at most ``top_n`` rows)."""
        stats = self.summary()
        if not stats:
            print("No timing data collected.")
            return
        self._print_table(stats, "PRIMITIVE OPERATION TIMING SUMMARY", "TOTAL", top_n)

    def print_leaf_summary(self, top_n: int = 20) -> None:
        """Print a timing table excluding container ops (pcall, shuffle, etc.).

        Only 'leaf' operations that don't contain nested calls are shown,
        giving self-time without double-counting.
        """
        leaf_stats = {
            opcode: s
            for opcode, s in self.summary().items()
            if opcode not in self._CONTAINER_OPS
        }
        if not leaf_stats:
            print("No leaf timing data collected.")
            return
        self._print_table(
            leaf_stats,
            "LEAF OPERATION TIMING SUMMARY (excludes container ops)",
            "TOTAL (leaf ops)",
            top_n,
        )

    @staticmethod
    def _print_table(
        stats: dict[str, dict[str, float]], title: str, total_label: str, top_n: int
    ) -> None:
        """Shared renderer: rows sorted by total time, descending, plus a total line."""
        print("\n" + "=" * 80)
        print(title)
        print("=" * 80)
        print(
            f"{'Operation':<35} {'Count':>8} {'Total(s)':>10} "
            f"{'Mean(ms)':>10} {'Max(ms)':>10}"
        )
        print("-" * 80)

        total_time = sum(s["total"] for s in stats.values())
        # Python's sort is stable, so ties keep the opcode order from summary().
        sorted_stats = sorted(stats.items(), key=lambda x: -x[1]["total"])

        for opcode, s in sorted_stats[:top_n]:
            pct = s["total"] / total_time * 100 if total_time > 0 else 0
            print(
                f"{opcode:<35} {s['count']:>8} {s['total']:>10.3f} "
                f"{s['mean'] * 1000:>10.3f} {s['max'] * 1000:>10.3f} ({pct:>5.1f}%)"
            )

        if len(sorted_stats) > top_n:
            print(f" ... and {len(sorted_stats) - top_n} more operations")

        print("-" * 80)
        print(f"{total_label:<35} {'':<8} {total_time:>10.3f}s")
150
+
151
+
152
# Single process-wide profiler, created disabled. It is accessed through
# get_profiler(), enable_profiling() and disable_profiling(), and consulted
# by get_impl().
_profiler = OpProfiler(enabled=False)
154
+
155
+
156
def get_profiler() -> OpProfiler:
    """Return the module-level profiler instance.

    This is the same object that enable_profiling() / disable_profiling()
    toggle and that get_impl() records timings into.
    """
    return _profiler
159
+
160
+
161
def enable_profiling() -> None:
    """Turn on timing collection for primitive implementations.

    Previously collected timings are discarded, so each call starts a fresh
    measurement session. Profiling takes effect for implementations looked
    up through get_impl() after this call.
    """
    _profiler.reset()
    _profiler.enabled = True
165
+
166
+
167
def disable_profiling() -> None:
    """Stop collecting timings.

    Data gathered so far is kept and remains readable via get_profiler().
    It is cleared only by the profiler's reset() or by enable_profiling().
    """
    _profiler.enabled = False
170
+
171
+
172
+ # ==============================================================================
173
+ # Registry Functions
174
+ # ==============================================================================
175
+
176
+
177
def register_impl(opcode: str, fn: Callable[..., Any]) -> None:
    """Record ``fn`` as the implementation for ``opcode``.

    A later registration for the same opcode replaces the earlier one.

    Args:
        opcode: Primitive name used as the lookup key (e.g. "add", "mul").
        fn: Implementation, expected to be called as
            ``fn(interpreter, op, *args) -> result``.
    """
    _IMPL_REGISTRY.update({opcode: fn})
186
+
187
+
188
def get_impl(opcode: str) -> Callable[..., Any] | None:
    """Look up the implementation registered for ``opcode``.

    Args:
        opcode: Primitive name the implementation was registered under.

    Returns:
        ``None`` if nothing is registered. Otherwise:
        - when profiling is disabled, the registered function itself;
        - when profiling is enabled, a transparent wrapper. It forwards all
          positional and keyword arguments to the registered function,
          measures its duration with ``time.perf_counter`` and records it
          under ``opcode`` in the global profiler. The wrapper carries the
          wrapped function's metadata (``__name__``, ``__doc__``,
          ``__wrapped__``) via ``functools.wraps``.
    """
    fn = _IMPL_REGISTRY.get(opcode)
    if fn is None:
        return None

    if not _profiler.enabled:
        return fn

    # Imported here because it is only needed on the profiling path.
    import functools

    @functools.wraps(fn)
    def profiled_fn(interpreter: Any, op: Any, *args: Any, **kwargs: Any) -> Any:
        t0 = time.perf_counter()
        result = fn(interpreter, op, *args, **kwargs)
        # Only successful calls are timed; exceptions propagate unrecorded.
        _profiler.record(opcode, time.perf_counter() - t0)
        return result

    return profiled_fn