mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
@@ -1,285 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Expression printer for debugging and visualization.
17
- """
18
-
19
- from __future__ import annotations
20
-
21
- from typing import Any
22
-
23
- from mplang.v1.core.dtypes import DType
24
- from mplang.v1.core.expr.ast import (
25
- AccessExpr,
26
- CallExpr,
27
- CondExpr,
28
- ConvExpr,
29
- EvalExpr,
30
- Expr,
31
- FuncDefExpr,
32
- ShflExpr,
33
- ShflSExpr,
34
- TupleExpr,
35
- VariableExpr,
36
- WhileExpr,
37
- )
38
- from mplang.v1.core.expr.visitor import ExprVisitor
39
- from mplang.v1.core.mptype import MPType
40
- from mplang.v1.core.pfunc import PFunction
41
- from mplang.v1.core.tensor import Shape, TensorType
42
-
43
-
44
class Printer(ExprVisitor):
    """Printer that renders an expression DAG as MLIR-style textual IR.

    Every emitted node becomes one line of the form
    ``%N = op(args) {attrs} {regions} : types``, where ``%N`` is allocated
    from a per-printer counter. Operands are resolved through ``_var_name``,
    which remembers the name each expression was printed as, so an expression
    reached several times is emitted only once. Function regions attached to
    ``pcond`` / ``pwhile`` / non-inlined ``pcall`` nodes are rendered by a
    child printer that uses the same formatting options but has its own
    counter and name table.
    """

    def __init__(
        self,
        indent_size: int = 2,
        compact_format: bool = True,
        *,
        verbose_peval: bool = False,
        inline_pcall: bool = True,
    ):
        """Create a printer.

        Args:
            indent_size: Number of spaces per indentation level.
            compact_format: When True, variables print as their bare name and
                ``AccessExpr`` nodes are folded into a reference to their
                source (``%x`` or ``%x:i``) instead of emitting separate
                ``pname`` / ``access`` lines.
            verbose_peval: When True, ``peval`` lines also carry the
                ``fn_text`` attribute of the evaluated function.
            inline_pcall: When True, a ``CallExpr`` prints as a call to its
                own name; otherwise it prints as ``pcall`` with the callee
                body attached as a region.
        """
        super().__init__()
        self.indent_size = indent_size
        self.compact_format = compact_format
        self.verbose_peval = verbose_peval
        self.inline_pcall = inline_pcall
        self._cur_indent = 0
        self._output: list[str] = []
        # Expression -> name it was printed as; lets shared sub-expressions
        # of the DAG be referenced instead of re-printed.
        self._visited: dict[Expr, str] = {}
        self._counter = 0

    def _child_printer(self) -> Printer:
        """Return a fresh printer for a nested region with identical options."""
        # Forward *every* formatting option so region bodies are printed the
        # same way as the top level (including ``verbose_peval``).
        return Printer(
            indent_size=self.indent_size,
            compact_format=self.compact_format,
            verbose_peval=self.verbose_peval,
            inline_pcall=self.inline_pcall,
        )

    def _write(self, text: str) -> None:
        """Append ``text`` to the output, prefixing each of its lines with the current indent."""
        indent = " " * (self._cur_indent * self.indent_size)
        for line in text.split("\n"):
            self._output.append(f"{indent}{line}")

    def _do_print(
        self,
        op_name: str,
        op_args: list[str],
        attrs: dict | None = None,
        regions: dict[str, FuncDefExpr] | None = None,
        mptypes: list | None = None,
    ) -> str:
        """Emit one MLIR-style line for a node and return its result name.

        Args:
            op_name: Operation name printed after ``=``.
            op_args: Already-resolved operand names.
            attrs: Optional attributes, printed as ``{k=v, ...}``.
            regions: Optional named function bodies; each is rendered by a
                child printer inside a ``{ ... }`` block.
            mptypes: Optional result types; a single type prints bare,
                several print as a parenthesised, comma-separated list.

        Returns:
            The freshly allocated result name (``%N``).
        """
        ret_name = f"%{self._counter}"
        self._counter += 1

        args_str = f"({', '.join(op_args)})"
        attrs_str = ""
        if attrs:
            attr_parts = [f"{k}={v}" for k, v in attrs.items()]
            attrs_str = f" {{{', '.join(attr_parts)}}}"

        regions_str = ""
        if regions:
            indent = " " * self.indent_size
            regions_str = " {\n"
            for r_name, func_def_expr in regions.items():
                body_printer = self._child_printer()
                func_def_expr.accept(body_printer)
                # Re-indent every line of the body under the region label.
                body_content = ("\n" + indent).join(body_printer._output)
                regions_str += f"{indent}{r_name}: {body_content}\n"
            regions_str += "}"

        type_str = ""
        if mptypes:
            type_parts = [str(mptype) for mptype in mptypes]
            if len(type_parts) == 1:
                type_str = f" : {type_parts[0]}"
            else:
                type_str = f" : ({', '.join(type_parts)})"

        self._write(
            f"{ret_name} = {op_name}{args_str}{attrs_str}{regions_str}{type_str}"
        )
        return ret_name

    def _var_name(self, expr: Expr) -> str:
        """Return the printed name of ``expr``, printing it on first use."""
        key = expr
        if key not in self._visited:
            self._visited[key] = expr.accept(self)
        return self._visited[key]

    def print_expr(self, expr: Expr) -> str:
        """Print an expression and return the formatted string.

        All per-run state (output buffer, name table, counter) is reset, so
        the same printer can be reused for several expressions.
        """
        self._output = []
        self._visited = {}
        # Kept for backward compatibility; this class never reads ``_cache``.
        self._cache: dict[str, Any] = {}
        self._counter = 0
        expr.accept(self)
        return "\n".join(self._output)

    def _get_const_data(self, dtype: DType, shape: Shape, data_bytes: bytes) -> str:
        """Render raw constant bytes as a readable value string.

        The bytes are reinterpreted with ``dtype.to_numpy()`` and reshaped
        to ``shape``. Arrays of at most 10 elements are shown in full (a
        single element as a plain scalar); larger arrays use numpy's default
        ``str`` form, which truncates long output.
        """
        import numpy as np

        np_array = np.frombuffer(data_bytes, dtype=dtype.to_numpy()).reshape(shape)

        if np_array.size <= 10:
            if np_array.size == 1:
                value_str = str(np_array.item())
            else:
                value_str = str(np_array.tolist())
        else:
            value_str = str(np_array)
        return value_str

    def _print_const(self, pfunc: PFunction, mptypes: list[MPType]) -> str:
        """Print a ``basic.constant`` evaluation as a ``pconst`` line with its data."""
        assert len(pfunc.outs_info) == 1
        out_type = pfunc.outs_info[0]
        assert isinstance(out_type, TensorType)
        attrs = {
            "data": self._get_const_data(
                out_type.dtype, out_type.shape, pfunc.attrs["data_bytes"]
            )
        }
        return self._do_print("pconst", [], attrs=attrs, mptypes=mptypes)

    def visit_eval(self, expr: EvalExpr) -> str:
        """Print an ``EvalExpr`` as ``peval`` (or ``pconst`` for constants)."""
        arg_names = [self._var_name(arg) for arg in expr.args]
        fn_type = expr.pfunc.fn_type

        # Well-known basic functions get a dedicated, more readable form.
        if fn_type == "basic.constant":
            return self._print_const(expr.pfunc, expr.mptypes)

        attrs = {"fn_type": fn_type}
        if expr.pfunc.fn_name:
            attrs["fn_name"] = str(expr.pfunc.fn_name)
        if self.verbose_peval:
            attrs["fn_text"] = str(expr.pfunc.fn_text)

        if expr.rmask is not None:
            attrs["rmask"] = f"0x{expr.rmask.value:x}"
        return self._do_print("peval", arg_names, attrs=attrs, mptypes=expr.mptypes)

    def visit_variable(self, expr: VariableExpr) -> str:
        """Print a variable: its bare name in compact mode, else a ``pname`` line."""
        if self.compact_format:
            # Compact mode references the variable by name without emitting
            # a definition line.
            return f"{expr.name}"
        else:
            return self._do_print(
                "pname", [f'"{expr.name}"'], attrs={}, mptypes=expr.mptypes
            )

    def visit_tuple(self, expr: TupleExpr) -> str:
        """Print a ``tuple`` of the operand names."""
        arg_names = [self._var_name(arg) for arg in expr.args]
        return self._do_print("tuple", arg_names, mptypes=expr.mptypes)

    def visit_cond(self, expr: CondExpr) -> str:
        """Print ``pcond`` with ``then_fn`` / ``else_fn`` bodies as regions."""
        pred_name = self._var_name(expr.pred)
        arg_names = [self._var_name(arg) for arg in expr.args]

        return self._do_print(
            "pcond",
            [pred_name, *arg_names],
            regions={
                "then_fn": expr.then_fn,
                "else_fn": expr.else_fn,
            },
            mptypes=expr.mptypes,
        )

    def visit_call(self, expr: CallExpr) -> str:
        """Print a call, inline by name or as ``pcall`` with its body as a region."""
        arg_names = [self._var_name(arg) for arg in expr.args]
        if self.inline_pcall:
            return self._do_print(
                expr.name,
                arg_names,
                mptypes=expr.mptypes,
            )
        else:
            return self._do_print(
                "pcall",
                arg_names,
                regions={"fn": expr.fn},
                mptypes=expr.mptypes,
            )

    def visit_while(self, expr: WhileExpr) -> str:
        """Print ``pwhile`` with ``cond_fn`` / ``body_fn`` bodies as regions."""
        arg_names = [self._var_name(arg) for arg in expr.args]

        return self._do_print(
            "pwhile",
            arg_names,
            regions={
                "cond_fn": expr.cond_fn,
                "body_fn": expr.body_fn,
            },
            mptypes=expr.mptypes,
        )

    def visit_conv(self, expr: ConvExpr) -> str:
        """Print a ``pconv`` of the converted variables."""
        var_names = [self._var_name(var) for var in expr.vars]
        return self._do_print("pconv", var_names, mptypes=expr.mptypes)

    def visit_shfl_s(self, expr: ShflSExpr) -> str:
        """Print ``pshfl_s`` with its ``pmask`` and ``src_ranks`` attributes."""
        src_val_name = self._var_name(expr.src_val)
        attrs = {"pmask": expr.pmask, "src_ranks": expr.src_ranks}
        return self._do_print(
            "pshfl_s", [src_val_name], attrs=attrs, mptypes=expr.mptypes
        )

    def visit_shfl(self, expr: ShflExpr) -> str:
        """Print ``pshfl`` of the source value and the index operand."""
        src_name = self._var_name(expr.src)
        index_name = self._var_name(expr.index)
        return self._do_print("pshfl", [src_name, index_name], mptypes=expr.mptypes)

    def visit_access(self, expr: AccessExpr) -> str:
        """Print element access into a multi-result expression.

        In compact mode no line is emitted: the access is folded into a
        reference to its source, e.g.::

            %x = ...                   %x = ...
            %y = %x[0]          ->     %z = some_fn(%x)          (single output)
            %z = some_fn(%y)           %z = some_fn(%x:0, %x:1)  (multiple outputs)

        Otherwise an explicit ``access`` line with an ``index`` attribute is
        printed.
        """
        expr_name = self._var_name(expr.src)
        if self.compact_format:
            if len(expr.src.mptypes) > 1:
                return f"{expr_name}:{expr.index}"
            else:
                return expr_name
        else:
            attrs = {"index": str(expr.index)}
            return self._do_print(
                "access", [expr_name], attrs=attrs, mptypes=expr.mptypes
            )

    def visit_func_def(self, expr: FuncDefExpr) -> str:
        """Print a function body as ``(params) { ... return <name> }``.

        Returns an empty string; the body is written to the output as a
        side effect.
        """
        param_names = expr.params
        self._write(f"({', '.join(param_names)}) {{")
        self._cur_indent += 1
        body_name = expr.body.accept(self)
        self._write(f"return {body_name}")
        self._cur_indent -= 1
        self._write("}")
        return ""
@@ -1,141 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Expression transformer based on visitor pattern.
17
- """
18
-
19
- from collections.abc import Callable
20
-
21
- from mplang.v1.core.expr.ast import (
22
- AccessExpr,
23
- CallExpr,
24
- CondExpr,
25
- ConvExpr,
26
- EvalExpr,
27
- Expr,
28
- FuncDefExpr,
29
- ShflExpr,
30
- ShflSExpr,
31
- TupleExpr,
32
- VariableExpr,
33
- WhileExpr,
34
- )
35
- from mplang.v1.core.expr.visitor import ExprVisitor
36
-
37
-
38
class ExprTransformer(ExprVisitor):
    """Rebuild an expression tree bottom-up, optionally rewriting nodes.

    Each ``visit_*`` hook first transforms the node's child expressions,
    builds a fresh node of the same kind from the transformed children, and
    then, when ``trans_rules`` has an entry for that node kind, returns the
    rule's result in place of the fresh node.

    Rule keys: ``"eval"``, ``"name"`` (for ``VariableExpr``), ``"tuple"``,
    ``"cond"``, ``"call"``, ``"while"``, ``"conv"``, ``"shfl_s"``,
    ``"shfl"``, ``"access"`` and ``"func_def"``.

    Function regions referenced by ``CondExpr`` / ``CallExpr`` /
    ``WhileExpr`` are passed through untouched. ``visit_func_def`` only
    rewrites a body when a ``FuncDefExpr`` is visited directly.

    NOTE(review): there is no memoization, so a sub-expression reachable
    along several paths is transformed (and its rule invoked) once per path.
    """

    def __init__(self, trans_rules: dict[str, Callable[[Expr], Expr]] | None = None):
        self.trans_rules = trans_rules or {}

    def _rewrite(self, kind: str, node: Expr) -> Expr:
        """Return ``node`` passed through the rule registered for ``kind``, if any."""
        if kind not in self.trans_rules:
            return node
        return self.trans_rules[kind](node)

    def _transform_all(self, children) -> list:
        """Transform each child expression, preserving order."""
        return [child.accept(self) for child in children]

    def visit_eval(self, expr: EvalExpr) -> Expr:
        new_args = self._transform_all(expr.args)
        return self._rewrite("eval", EvalExpr(expr.pfunc, new_args, expr.rmask))

    def visit_variable(self, expr: VariableExpr) -> Expr:
        # Leaves are not rebuilt; only the optional rule may replace them.
        return self._rewrite("name", expr)

    def visit_tuple(self, expr: TupleExpr) -> Expr:
        return self._rewrite("tuple", TupleExpr(self._transform_all(expr.args)))

    def visit_cond(self, expr: CondExpr) -> Expr:
        # The predicate is transformed before the branch arguments.
        new_pred = expr.pred.accept(self)
        new_args = self._transform_all(expr.args)
        rebuilt = CondExpr(new_pred, expr.then_fn, expr.else_fn, new_args)
        return self._rewrite("cond", rebuilt)

    def visit_call(self, expr: CallExpr) -> Expr:
        new_args = self._transform_all(expr.args)
        return self._rewrite("call", CallExpr(expr.name, expr.fn, new_args))

    def visit_while(self, expr: WhileExpr) -> Expr:
        new_args = self._transform_all(expr.args)
        return self._rewrite("while", WhileExpr(expr.cond_fn, expr.body_fn, new_args))

    def visit_conv(self, expr: ConvExpr) -> Expr:
        return self._rewrite("conv", ConvExpr(self._transform_all(expr.vars)))

    def visit_shfl_s(self, expr: ShflSExpr) -> Expr:
        new_src_val = expr.src_val.accept(self)
        return self._rewrite(
            "shfl_s", ShflSExpr(new_src_val, expr.pmask, expr.src_ranks)
        )

    def visit_shfl(self, expr: ShflExpr) -> Expr:
        # Source is transformed before the index.
        new_src = expr.src.accept(self)
        new_index = expr.index.accept(self)
        return self._rewrite("shfl", ShflExpr(new_src, new_index))

    def visit_access(self, expr: AccessExpr) -> Expr:
        new_src = expr.src.accept(self)
        return self._rewrite("access", AccessExpr(new_src, expr.index))

    def visit_func_def(self, expr: FuncDefExpr) -> Expr:
        # Params are plain strings; only the body needs transforming.
        new_body = expr.body.accept(self)
        return self._rewrite("func_def", FuncDefExpr(expr.params, new_body))
@@ -1,78 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Utility functions for expression system.
17
- """
18
-
19
- from collections.abc import Sequence
20
-
21
- from mplang.v1.core.mask import Mask
22
- from mplang.v1.core.mptype import TensorLike
23
-
24
-
25
def type_equal(*args: TensorLike) -> bool:
    """Report whether all given tensors share the same dtype and shape.

    Args:
        *args: Any number of TensorLike objects; every one is compared
            against the first.

    Returns:
        bool: True when zero or one tensor is given, or when every tensor
        matches the first in both ``dtype`` and ``shape``; False otherwise.
    """
    if len(args) <= 1:
        return True
    head = args[0]
    # Dtype is compared before shape, and the scan stops at the first mismatch.
    return not any(
        head.dtype != other.dtype or head.shape != other.shape
        for other in args[1:]
    )
40
-
41
-
42
def ensure_scalar(obj: TensorLike) -> None:
    """Raise ``TypeError`` unless ``obj`` has an empty (rank-0) shape."""
    rank = len(obj.shape)
    if rank:
        raise TypeError(f"Expected a scalar, got {obj}.")
46
-
47
-
48
def ensure_tensorlist_equal(*args: Sequence[TensorLike]) -> None:
    """Check that several tensor lists match element-wise in length and type.

    Every list is compared against the first one, using ``type_equal`` on
    each pair of elements.

    Raises:
        ValueError: If fewer than two lists are given, or a list's length
            differs from the first list's.
        TypeError: If a pair of corresponding elements differs in dtype or shape.
    """
    if len(args) < 2:
        raise ValueError(f"expect at least 2 args, got {len(args)}")
    reference = args[0]
    for candidate in args[1:]:
        if len(candidate) != len(reference):
            raise ValueError(f"Length mismatch: {len(candidate)} vs {len(reference)}")
        for ref_item, cand_item in zip(reference, candidate):
            if not type_equal(ref_item, cand_item):
                raise TypeError(f"Type mismatch: {ref_item} vs {cand_item}")
58
-
59
-
60
def deduce_mask(*pmasks: Mask | None) -> Mask | None:
    """Intersect participant masks into a single joint mask.

    Args:
        *pmasks: Masks to combine; each one is wrapped with ``Mask(...)``
            before intersecting.

    Returns:
        ``None`` when no masks are given or when any of them is ``None``
        (the joint mask cannot be determined then); otherwise the
        intersection of all given masks.
    """
    # A single unknown (None) participant makes the joint mask unknown.
    if not pmasks or any(pm is None for pm in pmasks):
        return None

    head, *tail = pmasks
    assert head is not None  # guaranteed by the guard above; narrows the type
    joint = Mask(head)
    for pm in tail:
        assert pm is not None  # guaranteed by the guard above
        joint = joint.intersection(Mask(pm))
    return joint
@@ -1,85 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Visitor pattern interface for expression system.
17
- """
18
-
19
- from __future__ import annotations
20
-
21
- from abc import ABC, abstractmethod
22
- from typing import TYPE_CHECKING, Any
23
-
24
- if TYPE_CHECKING:
25
- from mplang.v1.core.expr.ast import (
26
- AccessExpr,
27
- CallExpr,
28
- CondExpr,
29
- ConvExpr,
30
- EvalExpr,
31
- FuncDefExpr,
32
- ShflExpr,
33
- ShflSExpr,
34
- TupleExpr,
35
- VariableExpr,
36
- WhileExpr,
37
- )
38
-
39
-
40
class ExprVisitor(ABC):
    """Abstract double-dispatch interface over the expression node types.

    A concrete visitor must override every ``visit_*`` hook below. Each
    hook receives the node being visited and may return any value. The
    base implementations do nothing and return ``None``.
    """

    @abstractmethod
    def visit_eval(self, expr: EvalExpr) -> Any:
        """Handle an ``EvalExpr`` node."""

    @abstractmethod
    def visit_variable(self, expr: VariableExpr) -> Any:
        """Handle a ``VariableExpr`` node."""

    @abstractmethod
    def visit_tuple(self, expr: TupleExpr) -> Any:
        """Handle a ``TupleExpr`` node."""

    @abstractmethod
    def visit_cond(self, expr: CondExpr) -> Any:
        """Handle a ``CondExpr`` node."""

    @abstractmethod
    def visit_call(self, expr: CallExpr) -> Any:
        """Handle a ``CallExpr`` node."""

    @abstractmethod
    def visit_while(self, expr: WhileExpr) -> Any:
        """Handle a ``WhileExpr`` node."""

    @abstractmethod
    def visit_conv(self, expr: ConvExpr) -> Any:
        """Handle a ``ConvExpr`` node."""

    @abstractmethod
    def visit_shfl_s(self, expr: ShflSExpr) -> Any:
        """Handle a ``ShflSExpr`` node."""

    @abstractmethod
    def visit_shfl(self, expr: ShflExpr) -> Any:
        """Handle a ``ShflExpr`` node."""

    @abstractmethod
    def visit_access(self, expr: AccessExpr) -> Any:
        """Handle an ``AccessExpr`` node."""

    @abstractmethod
    def visit_func_def(self, expr: FuncDefExpr) -> Any:
        """Handle a ``FuncDefExpr`` node."""