mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
@@ -1,80 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Expression system for multi-party computation graph construction.
17
-
18
- This package provides a modern, extensible expression-based architecture for building
19
- multi-party computation graphs using the visitor pattern.
20
- """
21
-
22
- # Core expression types
23
- from mplang.v1.core.expr.ast import (
24
- AccessExpr,
25
- CallExpr,
26
- CondExpr,
27
- ConvExpr,
28
- EvalExpr,
29
- Expr,
30
- FuncDefExpr,
31
- ShflExpr,
32
- ShflSExpr,
33
- TupleExpr,
34
- VariableExpr,
35
- WhileExpr,
36
- )
37
-
38
- # Built-in evaluator engines
39
- from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
40
- from mplang.v1.core.expr.printer import Printer
41
- from mplang.v1.core.expr.transformer import ExprTransformer
42
-
43
- # Utility functions
44
- from mplang.v1.core.expr.utils import (
45
- deduce_mask,
46
- ensure_scalar,
47
- ensure_tensorlist_equal,
48
- type_equal,
49
- )
50
-
51
- # Visitor pattern interface
52
- from mplang.v1.core.expr.visitor import ExprVisitor
53
- from mplang.v1.core.expr.walk import walk, walk_dataflow, walk_structural
54
-
55
- __all__ = [
56
- "AccessExpr",
57
- "CallExpr",
58
- "CondExpr",
59
- "ConvExpr",
60
- "EvalExpr",
61
- "Expr",
62
- "ExprTransformer",
63
- "ExprVisitor",
64
- "FuncDefExpr",
65
- "IEvaluator",
66
- "Printer",
67
- "ShflExpr",
68
- "ShflSExpr",
69
- "TupleExpr",
70
- "VariableExpr",
71
- "WhileExpr",
72
- "create_evaluator",
73
- "deduce_mask",
74
- "ensure_scalar",
75
- "ensure_tensorlist_equal",
76
- "type_equal",
77
- "walk",
78
- "walk_dataflow",
79
- "walk_structural",
80
- ]
@@ -1,542 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Abstract Syntax Tree (AST) nodes for multi-party computation expressions.
17
-
18
- This module defines the AST nodes for representing multi-party computation expressions.
19
- Each node type represents a different kind of operation or construct in the multi-party
20
- computation language, following the visitor pattern for extensible processing.
21
- """
22
-
23
- from __future__ import annotations
24
-
25
- import logging
26
- from abc import ABC, abstractmethod
27
- from typing import TYPE_CHECKING, Any
28
-
29
- from mplang.v1.core.expr.utils import deduce_mask
30
- from mplang.v1.core.mask import Mask
31
- from mplang.v1.core.mptype import MPType, Rank
32
- from mplang.v1.core.pfunc import PFunction
33
- from mplang.v1.core.table import TableType
34
- from mplang.v1.core.tensor import TensorType
35
-
36
- if TYPE_CHECKING:
37
- from mplang.v1.core.expr.visitor import ExprVisitor
38
-
39
-
40
class Expr(ABC):
    """Abstract root of every node in the multi-party computation graph.

    Nodes are Multi-Input Multi-Output (MIMO): a single node may expose
    several outputs, which is what allows multi-output PFunctions and richer
    dataflow graphs to be represented directly.

    Attributes:
        mptypes (list[MPType]): Output types of this node. They are obtained
            from ``_compute_mptypes()`` on first access and cached afterwards.
        mptype (MPType): Shortcut for single-output nodes. Raises ValueError
            when the node does not have exactly one output.
    """

    def __init__(self) -> None:
        # Filled in on the first access to ``mptypes``.
        self._mptypes: list[MPType] | None = None

    @property
    def num_outputs(self) -> int:
        """Number of outputs produced by this expression."""
        return len(self.mptypes)

    @property
    def mptypes(self) -> list[MPType]:
        """Output types of this expression, computed once and then reused."""
        cached = self._mptypes
        if cached is None:
            cached = self._mptypes = self._compute_mptypes()
        return cached

    @property
    def mptype(self) -> MPType:
        """The sole output type; only valid for single-output expressions."""
        outputs = self.mptypes
        if len(outputs) == 1:
            return outputs[0]
        raise ValueError(f"Expression has {len(outputs)} outputs, expected 1")

    @abstractmethod
    def _compute_mptypes(self) -> list[MPType]:
        """Return the list of output types for this expression."""

    @abstractmethod
    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to the matching ``visit_*`` method of ``visitor``."""
86
-
87
-
88
- # ============================================================================
89
- # Concrete Expression Classes
90
- # ============================================================================
91
-
92
-
93
class EvalExpr(Expr):
    """Apply a PFunction to a list of argument expressions."""

    def __init__(
        self, pfunc: PFunction, args: list[Expr], rmask: Mask | int | None = None
    ):
        """Build the evaluation node.

        Args:
            pfunc: Function to evaluate; ``ins_info`` fixes the argument count.
            args: Argument expressions, one per entry of ``pfunc.ins_info``.
            rmask: Optional explicit result mask; wrapped in ``Mask`` when given.

        Raises:
            ValueError: If the number of arguments differs from
                ``len(pfunc.ins_info)``.
        """
        super().__init__()
        # Only the arity is validated here; argument types are not inspected.
        given, wanted = len(args), len(pfunc.ins_info)
        if given != wanted:
            raise ValueError(f"Expected {wanted} arguments, got {given}")
        normalized = None if rmask is None else Mask(rmask)

        self.pfunc = pfunc
        self.args = args
        self.rmask = normalized

    def _compute_mptypes(self) -> list[MPType]:
        """Derive one MPType per PFunction output.

        The output pmask is chosen as follows:
        - ``deduce_mask`` combines the pmasks of all arguments into a deduced
          pmask (which may be None).
        - Without an explicit rmask, the deduced pmask is used as is.
        - With an explicit rmask and a known deduced pmask, the rmask must be
          a subset of the deduced pmask, otherwise ValueError is raised.
        - With an explicit rmask and no deduced pmask, the rmask is used.

        Every entry of ``pfunc.outs_info`` then becomes a tensor or table
        MPType carrying that pmask; any other entry type raises TypeError.
        """
        deduced = deduce_mask(*[operand.mptype.pmask for operand in self.args])

        out_pmask: Mask | None
        if self.rmask is None:
            # Caller left the mask to be inferred from the arguments.
            out_pmask = deduced
        else:
            # Caller pinned the mask; it may only narrow a known deduced mask.
            if deduced is not None and not Mask(self.rmask).is_subset(deduced):
                raise ValueError(
                    f"Specified rmask {self.rmask} is not a subset of deduced pmask {deduced}."
                )
            out_pmask = self.rmask

        def to_mptype(info: Any) -> MPType:
            if isinstance(info, TensorType):
                return MPType.tensor(info.dtype, info.shape, out_pmask)
            if isinstance(info, TableType):
                return MPType.table(info, out_pmask)
            raise TypeError(f"Unsupported output type: {type(info)}")

        return [to_mptype(info) for info in self.pfunc.outs_info]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_eval``."""
        return visitor.visit_eval(self)
163
-
164
-
165
class TupleExpr(Expr):
    """Bundle several single-output expressions into one multi-output node.

    This is the tuple-construction primitive of the MIMO expression system.
    Each input must have exactly one output, and the outputs of the resulting
    node are those single outputs, in argument order.

    For example, if expr1 has output [A] and expr2 has output [B], then
    TupleExpr([expr1, expr2]) has outputs [A, B].

    AccessExpr performs the opposite operation: it selects a single element
    out of a multi-output expression.
    """

    def __init__(self, args: list[Expr]):
        """Store ``args`` after checking that each one is single-output.

        Raises:
            ValueError: If any argument does not have exactly one output.
        """
        super().__init__()
        for position, operand in enumerate(args):
            count = operand.num_outputs
            if count != 1:
                raise ValueError(
                    f"TupleExpr requires all arguments to be single-output expressions, "
                    f"but argument {position} has {count} outputs"
                )
        self.args = args

    def _compute_mptypes(self) -> list[MPType]:
        """Return the single output type of each argument, in order."""
        # ``mptype`` is safe here: __init__ verified every argument has one output.
        return [operand.mptype for operand in self.args]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_tuple``."""
        return visitor.visit_tuple(self)
206
-
207
-
208
class CondExpr(Expr):
    """Expression for conditional execution.

    Attributes:
        pred: Predicate expression selecting the branch.
        then_fn: Function definition used for the "then" branch.
        else_fn: Function definition used for the "else" branch.
        args: Argument expressions stored alongside the two branch functions.
        verify_uniform: Whether runtime should assert the predicate is uniform
            across parties.
    """

    def __init__(
        self,
        pred: Expr,
        then_fn: FuncDefExpr,
        else_fn: FuncDefExpr,
        args: list[Expr],
        verify_uniform: bool = False,
    ):
        super().__init__()
        self.pred = pred
        self.then_fn = then_fn
        self.else_fn = else_fn
        self.args = args
        self.verify_uniform = verify_uniform

    def _compute_mptypes(self) -> list[MPType]:
        """Return the then-branch output types after checking both branches agree.

        Raises:
            TypeError: If the branches produce a different number of outputs,
                or if any pair of corresponding output types differs.
        """
        then_types = self.then_fn.mptypes
        else_types = self.else_fn.mptypes
        # A non-strict zip would silently drop the surplus outputs of the
        # longer branch, so the output counts are compared explicitly first.
        if len(then_types) != len(else_types):
            raise TypeError(
                f"Then branch has {len(then_types)} outputs but else branch has "
                f"{len(else_types)} outputs"
            )
        for t_type, e_type in zip(then_types, else_types, strict=True):
            if t_type != e_type:
                raise TypeError(
                    f"Then branch type {t_type} does not match else branch type {e_type}"
                )
        return then_types

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_cond``."""
        return visitor.visit_cond(self)
242
-
243
-
244
class WhileExpr(Expr):
    """While-loop node built from a condition function and a body function."""

    def __init__(
        self,
        cond_fn: FuncDefExpr,
        body_fn: FuncDefExpr,
        args: list[Expr],
    ):
        """Store the loop functions and the initial loop arguments.

        Raises:
            ValueError: If ``args`` is empty.
        """
        super().__init__()
        if not args:
            raise ValueError("WhileExpr requires at least one argument (init value)")
        self.args = args
        self.body_fn = body_fn
        self.cond_fn = cond_fn

    def _compute_mptypes(self) -> list[MPType]:
        """Return the output types of the body function.

        Taking the types from the body lets the loop carry several values
        (PyTree leaves) and tells an evaluator how many values the loop yields.
        """
        return self.body_fn.mptypes

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_while``."""
        return visitor.visit_while(self)
268
-
269
-
270
class ConvExpr(Expr):
    """Expression for convergence of multiple variables.

    Every variable must be a single-output expression. The result has the raw
    type and attrs of the first variable, and a pmask that is the union of all
    input pmasks, or None when any input pmask is None.
    """

    def __init__(self, vars: list[Expr]):
        """Store ``vars`` after validating them.

        Raises:
            ValueError: If ``vars`` is empty or any variable does not have
                exactly one output.
        """
        super().__init__()

        # Reject an empty list up front; otherwise type computation would fail
        # later with a bare IndexError on the first element.
        if not vars:
            raise ValueError("ConvExpr requires at least one variable")

        # Every variable must be single-output so that ``mptype`` is defined.
        for i, v in enumerate(vars):
            if v.num_outputs != 1:
                raise ValueError(
                    f"All variables in ConvExpr must be single-output expressions, "
                    f"but variable {i} has {v.num_outputs} outputs"
                )

        self.vars = vars

    def _compute_mptypes(self) -> list[MPType]:
        """Compute the single converged output type.

        Raises:
            TypeError: If the variables do not all share the same raw type.
            ValueError: If two known pmasks are not disjoint.
        """
        types = [v.mptype for v in self.vars]

        # All inputs must agree on the raw type.
        first = types[0]
        for c in types[1:]:
            if c.raw_type() != first.raw_type():
                raise TypeError(f"Inconsistent type in pconv: {c} vs {first}")

        pmasks = [t.pmask for t in types]
        dynamic_pmask = any(pmask is None for pmask in pmasks)
        if dynamic_pmask:
            logging.warning("pconv called with None pmask.")

        # Known pmasks must be pairwise disjoint. Checking each mask against
        # the union of the previous ones is equivalent and yields the combined
        # mask as a by-product.
        combined: Mask | None = None
        for pmask in pmasks:
            if pmask is None:
                continue
            if combined is None:
                combined = Mask(pmask)
                continue
            if not combined.is_disjoint(pmask):
                raise ValueError(f"pconv called with non-disjoint pmasks: {pmasks}.")
            combined = combined.union(pmask)

        # An unknown input pmask makes the output pmask unknown as well.
        out_pmask = None if dynamic_pmask else combined

        return [MPType(first.raw_type(), out_pmask, first.attrs)]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_conv``."""
        return visitor.visit_conv(self)
323
-
324
-
325
class ShflSExpr(Expr):
    """Static shuffle: move data from given source ranks to given target parties.

    Every party selected by the output mask (`pmask`) receives the value held
    by the matching source rank in `src_ranks`.

    Why a "pull" model rather than a "push" model:
        Receivers name their source (`src_ranks`) instead of senders naming a
        destination.

        Pulling guarantees that each party in the output `pmask` ends up with
        exactly one value, which keeps the computation graph semantically
        well defined.

        Pushing would be ambiguous: two sources could target the same
        destination, or a destination could receive nothing. That would break
        the Single Instruction, Multiple Programs (SIMP) paradigm, because the
        number of outputs per party would no longer be predictable.

        Several receivers pulling from one source may create a network
        bottleneck at that source. That is a performance concern, not a
        correctness one; the design favours predictable, correct semantics.
    """

    def __init__(self, src_val: Expr, pmask: Mask, src_ranks: list[Rank]):
        """Create the static shuffle node.

        Args:
            src_val (Expr): Input value to shuffle; must be single-output.
            pmask (Mask): Parties that hold the output. Only parties whose bit
                is set in pmask receive a value.
            src_ranks (list[Rank]): Source ranks. The i-th output party (the
                i-th set bit of pmask) receives its data from src_ranks[i].

        Raises:
            ValueError: If src_val is not single-output, if the length of
                src_ranks differs from the number of parties in pmask, or if
                src_val has a known pmask that does not contain one of the
                ranks in src_ranks.

        Example:
            With pmask selecting parties [0, 2] and src_ranks = [1, 3]:
            - Party 0 receives data from rank 1
            - Party 2 receives data from rank 3
        """
        super().__init__()
        source_outputs = src_val.num_outputs
        if source_outputs != 1:
            raise ValueError(
                f"ShflSExpr requires a single output source, got {source_outputs}"
            )

        # Attributes are set before the remaining checks run.
        self.src_val = src_val
        self.pmask = pmask
        self.src_ranks = src_ranks

        if len(src_ranks) != Mask(pmask).num_parties():
            raise ValueError(f"src_ranks length ({len(src_ranks)}) not match {pmask}")

        # A source whose pmask is unknown (None) cannot be checked for membership.
        src_pmask = src_val.mptype.pmask
        if src_pmask is None:
            return
        for position, rank in enumerate(src_ranks):
            if rank in Mask(src_pmask):
                continue
            raise ValueError(
                f"Source rank {rank} at index {position} is not present in src {Mask(src_pmask)}"
            )

    def _compute_mptypes(self) -> list[MPType]:
        """Return the source type re-labelled with this node's pmask."""
        source = self.src_val.mptype
        return [MPType(source._type, self.pmask, source.attrs)]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_shfl_s``."""
        return visitor.visit_shfl_s(self)
405
-
406
-
407
class ShflExpr(Expr):
    """Dynamic shuffle of ``src`` driven by an ``index`` expression."""

    def __init__(self, src: Expr, index: Expr):
        """Store the source and index expressions without further validation."""
        super().__init__()
        self.index = index
        self.src = src

    def _compute_mptypes(self) -> list[MPType]:
        """Return one tensor type per source output, each with pmask None.

        The resulting pmask of a dynamic shuffle is often unknown at compile
        time, so dtype, shape and attrs are carried over from the source and
        the pmask is left as None (runtime-determined).
        """
        # NOTE(review): every output is rebuilt through MPType.tensor, so this
        # appears to assume tensor-typed sources; confirm behaviour for tables.
        return [
            MPType.tensor(source.dtype, source.shape, None, **source.attrs)
            for source in self.src.mptypes
        ]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_shfl``."""
        return visitor.visit_shfl(self)
429
-
430
-
431
class AccessExpr(Expr):
    """Select one output of a (possibly multi-output) expression.

    This is the counterpart of TupleExpr: the "un-packing" primitive of the
    MIMO system. Given a source expression and an index, it yields a
    single-output expression standing for just that output, so individual
    outputs of a multi-output function can be routed to operations that expect
    single inputs.
    """

    def __init__(self, src: Expr, index: int):
        """Store the source expression and the output index to select."""
        super().__init__()
        self.index = index
        self.src = src

    def _compute_mptypes(self) -> list[MPType]:
        """Return a one-element list holding the selected output type.

        Raises:
            IndexError: If the index is negative or not below the number of
                outputs of the source expression.
        """
        available = self.src.mptypes
        position = self.index
        if not (0 <= position < len(available)):
            raise IndexError(
                f"Index {position} out of range for expression with {len(available)} outputs"
            )
        return [available[position]]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_access``."""
        return visitor.visit_access(self)
459
-
460
-
461
class VariableExpr(Expr):
    """Reference to a named variable with an explicitly supplied type."""

    def __init__(self, name: str, mptype: MPType):
        """Record the variable name and the type given by the caller."""
        super().__init__()
        # Stored as ``mptype_value``; ``mptype`` itself is a read-only
        # property on the base class.
        self.mptype_value = mptype
        self.name = name

    def _compute_mptypes(self) -> list[MPType]:
        """Return the caller-supplied type as the only output type."""
        declared = self.mptype_value
        return [declared]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_variable``."""
        return visitor.visit_variable(self)
475
-
476
-
477
class FuncDefExpr(Expr):
    """A function definition: an ordered parameter list plus a body expression.

    This models lambda abstraction. The body may contain free variables
    (VariableExpr nodes) whose names refer to parameters; when the function is
    called, arguments are bound to parameters by position, which resolves
    those free variables.

    Example:
        A function adding two variables:
        ```
        # The body refers to the free variables "x" and "y"
        body = EvalExpr(
            add_pfunc, [VariableExpr("x", int_type), VariableExpr("y", int_type)]
        )

        # The parameter list fixes the binding order - "y" precedes "x",
        # and "z" is an extra parameter
        params = ["z", "y", "x"]

        func_def = FuncDefExpr(params, body)

        # Called with [expr0, expr1, expr2]:
        # - "z" binds to expr0 (unused by the body, still valid)
        # - "y" binds to expr1 (resolves VariableExpr("y") in the body)
        # - "x" binds to expr2 (resolves VariableExpr("x") in the body)
        call = CallExpr("add", func_def, [expr0, expr1, expr2])
        ```

    Key points:
        - Free variables in the body are placeholders for concrete expressions
        - Parameters are the binding contract between arguments and variables
        - Binding is positional; alphabetical or usage order is irrelevant
        - Parameters may name variables the body never uses (dead parameters)
        - A well-formed function has a parameter for every free variable of
          its body
    """

    def __init__(self, params: list[str], body: Expr):
        """Store the parameter names and the body; nothing is validated here."""
        super().__init__()
        self.body = body
        self.params = params

    def _compute_mptypes(self) -> list[MPType]:
        """Return the output types of the body expression."""
        return self.body.mptypes

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_func_def``."""
        return visitor.visit_func_def(self)
524
-
525
-
526
class CallExpr(Expr):
    """Call of a function definition with a list of argument expressions."""

    def __init__(self, name: str, fn: FuncDefExpr, args: list[Expr]):
        """Store the call name, the callee and its arguments unchanged."""
        super().__init__()
        self.args = args
        self.fn = fn
        self.name = name

    def _compute_mptypes(self) -> list[MPType]:
        """Return the callee's output types.

        No substitution of parameter types by argument types is performed;
        the types reported by the function definition are returned as they are.
        """
        callee = self.fn
        return callee.mptypes

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_call``."""
        return visitor.visit_call(self)