mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/expr/printer.py
DELETED
|
@@ -1,285 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Expression printer for debugging and visualization.
|
|
17
|
-
"""
|
|
18
|
-
|
|
19
|
-
from __future__ import annotations
|
|
20
|
-
|
|
21
|
-
from typing import Any
|
|
22
|
-
|
|
23
|
-
from mplang.v1.core.dtypes import DType
|
|
24
|
-
from mplang.v1.core.expr.ast import (
|
|
25
|
-
AccessExpr,
|
|
26
|
-
CallExpr,
|
|
27
|
-
CondExpr,
|
|
28
|
-
ConvExpr,
|
|
29
|
-
EvalExpr,
|
|
30
|
-
Expr,
|
|
31
|
-
FuncDefExpr,
|
|
32
|
-
ShflExpr,
|
|
33
|
-
ShflSExpr,
|
|
34
|
-
TupleExpr,
|
|
35
|
-
VariableExpr,
|
|
36
|
-
WhileExpr,
|
|
37
|
-
)
|
|
38
|
-
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
39
|
-
from mplang.v1.core.mptype import MPType
|
|
40
|
-
from mplang.v1.core.pfunc import PFunction
|
|
41
|
-
from mplang.v1.core.tensor import Shape, TensorType
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
class Printer(ExprVisitor):
    """Printer that prints an expression DAG in an MLIR-like IR style.

    Every printed node becomes one line of the form
    ``%N = op(args) {attrs} {regions} : types`` and the corresponding
    ``visit_*`` method returns the name given to that node (``%N``) so that
    parent nodes can refer to it.

    Args:
        indent_size: Number of spaces per indentation level.
        compact_format: If True, variables are referenced by their bare name
            (no ``pname`` line) and ``access`` nodes are folded into the
            reference to their source (``%x`` or ``%x:<index>``).
        verbose_peval: If True, ``peval`` lines also carry the ``fn_text``
            attribute. The setting also applies inside nested regions.
        inline_pcall: If True, call nodes are printed as ``<name>(args)``;
            otherwise as ``pcall`` with the callee printed as a region.
    """

    def __init__(
        self,
        indent_size: int = 2,
        compact_format: bool = True,
        *,
        verbose_peval: bool = False,
        inline_pcall: bool = True,
    ):
        super().__init__()  # cooperative init of the ExprVisitor base
        self.indent_size = indent_size
        self.compact_format = compact_format
        self.verbose_peval = verbose_peval
        self.inline_pcall = inline_pcall
        # Current indentation depth (in levels, not spaces).
        self._cur_indent = 0
        # Lines emitted so far; joined by print_expr().
        self._output: list[str] = []
        # Already-printed expressions -> their emitted name, so that a node
        # shared by several parents is printed only once.
        self._visited: dict[Expr, str] = {}
        # Next numeric suffix for generated value names (%0, %1, ...).
        self._counter = 0

    def _write(self, text: str) -> None:
        """Append ``text`` to the output, indenting each of its lines."""
        indent = " " * (self._cur_indent * self.indent_size)
        for line in text.split("\n"):
            self._output.append(f"{indent}{line}")

    def _do_print(
        self,
        op_name: str,
        op_args: list[str],
        attrs: dict | None = None,
        regions: dict[str, FuncDefExpr] | None = None,
        mptypes: list | None = None,
    ) -> str:
        """Emit one MLIR-style line for a node and return its new name.

        Args:
            op_name: Operation name printed before the argument list.
            op_args: Already-printed names of the operand expressions.
            attrs: Optional attributes, rendered as `` {k=v, ...}``.
            regions: Optional named function bodies. Each one is printed by a
                fresh nested Printer that inherits this printer's display
                settings. The nested printer numbers its values from ``%0``.
            mptypes: Optional result types, rendered after ``:`` (wrapped in
                parentheses when there is more than one).

        Returns:
            The generated name (``%N``) for this node.
        """
        ret_name = f"%{self._counter}"
        self._counter += 1

        args_str = f"({', '.join(op_args)})"
        attrs_str = ""
        if attrs:
            attr_parts = [f"{k}={v}" for k, v in attrs.items()]
            attrs_str = f" {{{', '.join(attr_parts)}}}"

        regions_str = ""
        if regions:
            regions_str = " {\n"
            indent = " " * self.indent_size
            for r_name, func_def_expr in regions.items():
                # Forward every display option (including verbose_peval) so
                # region bodies render with the same settings as the
                # enclosing expression.
                body_printer = Printer(
                    indent_size=self.indent_size,
                    compact_format=self.compact_format,
                    verbose_peval=self.verbose_peval,
                    inline_pcall=self.inline_pcall,
                )
                func_def_expr.accept(body_printer)
                regions_str += f"{indent}{r_name}: "
                body_content = ("\n" + indent).join(body_printer._output)
                regions_str += f"{body_content}\n"
            regions_str += "}"

        type_str = ""
        if mptypes:
            type_parts = [str(mptype) for mptype in mptypes]
            if len(type_parts) == 1:
                type_str = f" : {type_parts[0]}"
            else:
                type_str = f" : ({', '.join(type_parts)})"

        self._write(
            f"{ret_name} = {op_name}{args_str}{attrs_str}{regions_str}{type_str}"
        )
        return ret_name

    def _var_name(self, expr: Expr) -> str:
        """Return the printed name of ``expr``, printing it on first use.

        Names are cached in ``self._visited`` so a node reachable from
        several parents is emitted only once.
        """
        if expr not in self._visited:
            self._visited[expr] = expr.accept(self)
        return self._visited[expr]

    def print_expr(self, expr: Expr) -> str:
        """Print ``expr`` and return the formatted IR text.

        All per-print state is reset first, so one Printer can be reused.
        """
        self._output = []
        self._visited = {}
        # NOTE(review): nothing in this class reads ``_cache``; the reset is
        # kept unchanged in case a subclass or caller relies on it.
        self._cache: dict[str, Any] = {}
        self._counter = 0
        expr.accept(self)
        return "\n".join(self._output)

    def _get_const_data(self, dtype: DType, shape: Shape, data_bytes: bytes) -> str:
        """Render raw constant bytes as a human-readable value string.

        The bytes are decoded with ``dtype.to_numpy()`` and reshaped to
        ``shape``. A single-element array prints as the bare item, arrays
        with at most 10 elements print as (nested) Python lists, and larger
        arrays use numpy's ``str()``, which summarizes big arrays.
        """
        import numpy as np

        np_array = np.frombuffer(data_bytes, dtype=dtype.to_numpy()).reshape(shape)

        if np_array.size == 1:
            return str(np_array.item())
        if np_array.size <= 10:
            return str(np_array.tolist())
        return str(np_array)

    def _print_const(self, pfunc: PFunction, mptypes: list[MPType]) -> str:
        """Print a ``basic.constant`` evaluation as a ``pconst`` line."""
        assert len(pfunc.outs_info) == 1
        out_type = pfunc.outs_info[0]
        assert isinstance(out_type, TensorType)
        attrs = {
            "data": self._get_const_data(
                out_type.dtype, out_type.shape, pfunc.attrs["data_bytes"]
            )
        }
        return self._do_print("pconst", [], attrs=attrs, mptypes=mptypes)

    def visit_eval(self, expr: EvalExpr) -> str:
        """Print an evaluation node (``peval``, or ``pconst`` for constants)."""
        arg_names = [self._var_name(arg) for arg in expr.args]
        fn_type = expr.pfunc.fn_type

        # Well-known basic functions get a dedicated, more readable form.
        if fn_type == "basic.constant":
            return self._print_const(expr.pfunc, expr.mptypes)

        attrs = {"fn_type": fn_type}
        if expr.pfunc.fn_name:
            attrs["fn_name"] = str(expr.pfunc.fn_name)
        if self.verbose_peval:
            attrs["fn_text"] = str(expr.pfunc.fn_text)

        if expr.rmask is not None:
            attrs["rmask"] = f"0x{expr.rmask.value:x}"
        return self._do_print("peval", arg_names, attrs=attrs, mptypes=expr.mptypes)

    def visit_variable(self, expr: VariableExpr) -> str:
        """Reference a variable.

        In compact mode, return the bare variable name without emitting a
        line. Otherwise, emit a ``pname`` definition.
        """
        if self.compact_format:
            return f"{expr.name}"
        return self._do_print(
            "pname", [f'"{expr.name}"'], attrs={}, mptypes=expr.mptypes
        )

    def visit_tuple(self, expr: TupleExpr) -> str:
        """Print a tuple construction."""
        arg_names = [self._var_name(arg) for arg in expr.args]
        return self._do_print("tuple", arg_names, mptypes=expr.mptypes)

    def visit_cond(self, expr: CondExpr) -> str:
        """Print a conditional with its then/else functions as regions."""
        pred_name = self._var_name(expr.pred)
        arg_names = [self._var_name(arg) for arg in expr.args]

        return self._do_print(
            "pcond",
            [pred_name, *arg_names],
            regions={
                "then_fn": expr.then_fn,
                "else_fn": expr.else_fn,
            },
            mptypes=expr.mptypes,
        )

    def visit_call(self, expr: CallExpr) -> str:
        """Print a call.

        With ``inline_pcall``, the call is printed as ``<name>(args)``.
        Otherwise it is printed as ``pcall`` with the callee as a region.
        """
        arg_names = [self._var_name(arg) for arg in expr.args]
        if self.inline_pcall:
            return self._do_print(
                expr.name,
                arg_names,
                mptypes=expr.mptypes,
            )
        return self._do_print(
            "pcall",
            arg_names,
            regions={"fn": expr.fn},
            mptypes=expr.mptypes,
        )

    def visit_while(self, expr: WhileExpr) -> str:
        """Print a loop with its condition and body functions as regions."""
        arg_names = [self._var_name(arg) for arg in expr.args]

        return self._do_print(
            "pwhile",
            arg_names,
            regions={
                "cond_fn": expr.cond_fn,
                "body_fn": expr.body_fn,
            },
            mptypes=expr.mptypes,
        )

    def visit_conv(self, expr: ConvExpr) -> str:
        """Print a ``pconv`` over the converted variables."""
        var_names = [self._var_name(var) for var in expr.vars]
        return self._do_print("pconv", var_names, mptypes=expr.mptypes)

    def visit_shfl_s(self, expr: ShflSExpr) -> str:
        """Print a static shuffle with its ``pmask``/``src_ranks`` attributes."""
        src_val_name = self._var_name(expr.src_val)
        attrs = {"pmask": expr.pmask, "src_ranks": expr.src_ranks}
        return self._do_print(
            "pshfl_s", [src_val_name], attrs=attrs, mptypes=expr.mptypes
        )

    def visit_shfl(self, expr: ShflExpr) -> str:
        """Print a dynamic shuffle of ``src`` driven by ``index``."""
        src_name = self._var_name(expr.src)
        index_name = self._var_name(expr.index)
        return self._do_print("pshfl", [src_name, index_name], mptypes=expr.mptypes)

    def visit_access(self, expr: AccessExpr) -> str:
        """Print an element access, folded into the source in compact mode."""
        expr_name = self._var_name(expr.src)
        if self.compact_format:
            # Original:
            #   %x = ...
            #   %y = %x[0]
            #   %z = some_fn(%y)
            # Single output (optimized):
            #   %x = ...
            #   %z = some_fn(%x)
            # Multiple outputs (optimized):
            #   %x = ...
            #   %z = some_fn(%x:0, %x:1)
            if len(expr.src.mptypes) > 1:
                return f"{expr_name}:{expr.index}"
            return expr_name
        attrs = {"index": str(expr.index)}
        return self._do_print(
            "access", [expr_name], attrs=attrs, mptypes=expr.mptypes
        )

    def visit_func_def(self, expr: FuncDefExpr) -> str:
        """Print a function body as ``(params) { ... return <name> }``.

        Returns an empty string: a function definition has no value name.
        """
        param_names = expr.params
        self._write(f"({', '.join(param_names)}) {{")
        self._cur_indent += 1
        body_name = expr.body.accept(self)
        self._write(f"return {body_name}")
        self._cur_indent -= 1
        self._write("}")
        return ""
|
|
mplang/v1/core/expr/transformer.py
DELETED
|
@@ -1,141 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Expression transformer based on visitor pattern.
|
|
17
|
-
"""
|
|
18
|
-
|
|
19
|
-
from collections.abc import Callable
|
|
20
|
-
|
|
21
|
-
from mplang.v1.core.expr.ast import (
|
|
22
|
-
AccessExpr,
|
|
23
|
-
CallExpr,
|
|
24
|
-
CondExpr,
|
|
25
|
-
ConvExpr,
|
|
26
|
-
EvalExpr,
|
|
27
|
-
Expr,
|
|
28
|
-
FuncDefExpr,
|
|
29
|
-
ShflExpr,
|
|
30
|
-
ShflSExpr,
|
|
31
|
-
TupleExpr,
|
|
32
|
-
VariableExpr,
|
|
33
|
-
WhileExpr,
|
|
34
|
-
)
|
|
35
|
-
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
class ExprTransformer(ExprVisitor):
    """Rebuild an expression bottom-up, optionally rewriting each node.

    Each ``visit_*`` method does three things:

    - transforms the node's child expressions;
    - builds a fresh node of the same kind from those children;
    - returns that node, or, if ``trans_rules`` has an entry for the node
      kind, the result of calling that rule on it.

    Rule keys are ``"eval"``, ``"name"`` (variables), ``"tuple"``,
    ``"cond"``, ``"call"``, ``"while"``, ``"conv"``, ``"shfl_s"``,
    ``"shfl"``, ``"access"`` and ``"func_def"``.

    Notes:
        Function bodies held by cond/while/call nodes (``then_fn``,
        ``else_fn``, ``cond_fn``, ``body_fn``, ``fn``) are reused unchanged
        and are not traversed. Results are not memoized, so a
        sub-expression reachable along several paths is transformed once per
        path.
    """

    def __init__(self, trans_rules: dict[str, Callable[[Expr], Expr]] | None = None):
        self.trans_rules = trans_rules or {}

    def _rewrite(self, kind: str, new_expr: Expr) -> Expr:
        """Apply the rule registered under ``kind`` to ``new_expr``, if any."""
        if kind not in self.trans_rules:
            return new_expr
        return self.trans_rules[kind](new_expr)

    def _transform_each(self, children) -> list[Expr]:
        """Transform every child expression, preserving order."""
        return [child.accept(self) for child in children]

    def visit_eval(self, expr: EvalExpr) -> Expr:
        rebuilt = EvalExpr(expr.pfunc, self._transform_each(expr.args), expr.rmask)
        return self._rewrite("eval", rebuilt)

    def visit_variable(self, expr: VariableExpr) -> Expr:
        # Leaf node: nothing to rebuild, the rule sees the original node.
        return self._rewrite("name", expr)

    def visit_tuple(self, expr: TupleExpr) -> Expr:
        rebuilt = TupleExpr(self._transform_each(expr.args))
        return self._rewrite("tuple", rebuilt)

    def visit_cond(self, expr: CondExpr) -> Expr:
        # The predicate is transformed before the arguments.
        new_pred = expr.pred.accept(self)
        new_args = self._transform_each(expr.args)
        rebuilt = CondExpr(new_pred, expr.then_fn, expr.else_fn, new_args)
        return self._rewrite("cond", rebuilt)

    def visit_call(self, expr: CallExpr) -> Expr:
        rebuilt = CallExpr(expr.name, expr.fn, self._transform_each(expr.args))
        return self._rewrite("call", rebuilt)

    def visit_while(self, expr: WhileExpr) -> Expr:
        rebuilt = WhileExpr(expr.cond_fn, expr.body_fn, self._transform_each(expr.args))
        return self._rewrite("while", rebuilt)

    def visit_conv(self, expr: ConvExpr) -> Expr:
        rebuilt = ConvExpr(self._transform_each(expr.vars))
        return self._rewrite("conv", rebuilt)

    def visit_shfl_s(self, expr: ShflSExpr) -> Expr:
        new_src_val = expr.src_val.accept(self)
        rebuilt = ShflSExpr(new_src_val, expr.pmask, expr.src_ranks)
        return self._rewrite("shfl_s", rebuilt)

    def visit_shfl(self, expr: ShflExpr) -> Expr:
        # Source is transformed before the index expression.
        new_src = expr.src.accept(self)
        new_index = expr.index.accept(self)
        return self._rewrite("shfl", ShflExpr(new_src, new_index))

    def visit_access(self, expr: AccessExpr) -> Expr:
        new_src = expr.src.accept(self)
        return self._rewrite("access", AccessExpr(new_src, expr.index))

    def visit_func_def(self, expr: FuncDefExpr) -> Expr:
        # Params are plain strings; only the body needs transforming.
        new_body = expr.body.accept(self)
        return self._rewrite("func_def", FuncDefExpr(expr.params, new_body))
|
mplang/v1/core/expr/utils.py
DELETED
|
@@ -1,78 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Utility functions for expression system.
|
|
17
|
-
"""
|
|
18
|
-
|
|
19
|
-
from collections.abc import Sequence
|
|
20
|
-
|
|
21
|
-
from mplang.v1.core.mask import Mask
|
|
22
|
-
from mplang.v1.core.mptype import TensorLike
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
def type_equal(*args: TensorLike) -> bool:
    """Report whether all tensors share the first one's dtype and shape.

    Zero or one argument is trivially considered equal.

    Args:
        *args: TensorLike objects to compare.

    Returns:
        bool: True if every argument has the same dtype and shape as the
        first one, False otherwise.
    """
    if len(args) <= 1:
        return True
    head, *tail = args
    return not any(
        head.dtype != other.dtype or head.shape != other.shape for other in tail
    )
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def ensure_scalar(obj: TensorLike) -> None:
    """Raise unless ``obj`` is a scalar, i.e. its shape has no dimensions.

    Raises:
        TypeError: If ``obj.shape`` has one or more dimensions.
    """
    rank = len(obj.shape)
    if rank == 0:
        return
    raise TypeError(f"Expected a scalar, got {obj}.")
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def ensure_tensorlist_equal(*args: Sequence[TensorLike]) -> None:
    """Check that several tensor lists match element-wise in dtype and shape.

    Every list after the first is compared with the first one: first by
    length, then element by element via ``type_equal``. Lists are checked in
    order, and the first mismatch found raises.

    Raises:
        ValueError: If fewer than two lists are given, or a list's length
            differs from the first list's length.
        TypeError: If a pair of corresponding elements differs in type.
    """
    if len(args) < 2:
        raise ValueError(f"expect at least 2 args, got {len(args)}")

    reference, *others = args
    for other in others:
        if len(other) != len(reference):
            raise ValueError(f"Length mismatch: {len(other)} vs {len(reference)}")
        for ref_item, item in zip(reference, other):
            if not type_equal(ref_item, item):
                raise TypeError(f"Type mismatch: {ref_item} vs {item}")
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def deduce_mask(*pmasks: Mask | None) -> Mask | None:
    """Return the joint (intersected) mask of all participant masks.

    Returns None when no masks are given or when any of them is None, since
    no specific joint mask can be deduced then. Otherwise each argument is
    wrapped with ``Mask(...)`` and the results are intersected left to right.
    """
    if not pmasks or any(pmask is None for pmask in pmasks):
        return None

    first, *rest = pmasks
    # The None case was handled above; these asserts only narrow the
    # Optional type for static checkers such as mypy.
    assert first is not None
    joint = Mask(first)
    for pmask in rest:
        assert pmask is not None
        joint = joint.intersection(Mask(pmask))
    return joint
|
mplang/v1/core/expr/visitor.py
DELETED
|
@@ -1,85 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Visitor pattern interface for expression system.
|
|
17
|
-
"""
|
|
18
|
-
|
|
19
|
-
from __future__ import annotations
|
|
20
|
-
|
|
21
|
-
from abc import ABC, abstractmethod
|
|
22
|
-
from typing import TYPE_CHECKING, Any
|
|
23
|
-
|
|
24
|
-
if TYPE_CHECKING:
|
|
25
|
-
from mplang.v1.core.expr.ast import (
|
|
26
|
-
AccessExpr,
|
|
27
|
-
CallExpr,
|
|
28
|
-
CondExpr,
|
|
29
|
-
ConvExpr,
|
|
30
|
-
EvalExpr,
|
|
31
|
-
FuncDefExpr,
|
|
32
|
-
ShflExpr,
|
|
33
|
-
ShflSExpr,
|
|
34
|
-
TupleExpr,
|
|
35
|
-
VariableExpr,
|
|
36
|
-
WhileExpr,
|
|
37
|
-
)
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class ExprVisitor(ABC):
    """Abstract visitor over the expression node types.

    A concrete visitor must implement one ``visit_*`` method per node kind.
    Each method receives the node and may return any value. For example, a
    printer may return strings and a transformer may return new expressions.
    """

    @abstractmethod
    def visit_eval(self, expr: EvalExpr) -> Any:
        """Handle an ``EvalExpr`` node."""

    @abstractmethod
    def visit_variable(self, expr: VariableExpr) -> Any:
        """Handle a ``VariableExpr`` node."""

    @abstractmethod
    def visit_tuple(self, expr: TupleExpr) -> Any:
        """Handle a ``TupleExpr`` node."""

    @abstractmethod
    def visit_cond(self, expr: CondExpr) -> Any:
        """Handle a ``CondExpr`` node."""

    @abstractmethod
    def visit_call(self, expr: CallExpr) -> Any:
        """Handle a ``CallExpr`` node."""

    @abstractmethod
    def visit_while(self, expr: WhileExpr) -> Any:
        """Handle a ``WhileExpr`` node."""

    @abstractmethod
    def visit_conv(self, expr: ConvExpr) -> Any:
        """Handle a ``ConvExpr`` node."""

    @abstractmethod
    def visit_shfl_s(self, expr: ShflSExpr) -> Any:
        """Handle a ``ShflSExpr`` node."""

    @abstractmethod
    def visit_shfl(self, expr: ShflExpr) -> Any:
        """Handle a ``ShflExpr`` node."""

    @abstractmethod
    def visit_access(self, expr: AccessExpr) -> Any:
        """Handle an ``AccessExpr`` node."""

    @abstractmethod
    def visit_func_def(self, expr: FuncDefExpr) -> Any:
        """Handle a ``FuncDefExpr`` node."""
|