mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/table.py DELETED
@@ -1,218 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from collections.abc import Iterator
18
- from dataclasses import dataclass, field
19
- from typing import Any, Protocol, runtime_checkable
20
-
21
- from mplang.v1.core.dtypes import DType
22
-
23
- __all__ = ["TableLike", "TableType"]
24
-
25
-
26
@runtime_checkable
class PandasTableLike(Protocol):
    """Structural type for DataFrame-style tables.

    Any object exposing both a ``columns`` and a ``dtypes`` attribute
    satisfies this protocol (pandas and polars DataFrames are the intended
    examples). The protocol is runtime-checkable, so ``isinstance`` only
    verifies that both attributes exist, not what they contain.
    """

    @property
    def columns(self) -> Any:
        """Column labels of the table."""
        ...

    @property
    def dtypes(self) -> Any:
        """Per-column data types of the table."""
        ...
38
-
39
-
40
@runtime_checkable
class ArrowSchema(Protocol):
    """Structural type for an Arrow-style schema object.

    Satisfied by any object that exposes ``names`` and ``types``
    attributes; ``isinstance`` checks only that both are present.
    """

    @property
    def names(self) -> list[str]:
        """Field names of the schema."""
        ...

    @property
    def types(self) -> list[Any]:
        """Field types of the schema, in the same order as ``names``."""
        ...
46
-
47
-
48
@runtime_checkable
class ArrowTableLike(Protocol):
    """Structural type for Arrow-style tables.

    Satisfied by any object that exposes ``column_names`` and ``schema``
    attributes; ``isinstance`` checks only that both are present.
    """

    @property
    def column_names(self) -> list[str]:
        """Names of the table's columns."""
        ...

    @property
    def schema(self) -> ArrowSchema:
        """Schema object describing the table's fields."""
        ...
55
-
56
-
57
- TableLike = PandasTableLike | ArrowTableLike
58
-
59
-
60
@dataclass(frozen=True)
class TableType:
    """Table schema: ordered list of column name-type pairs.

    Represents table structure in relational algebra, containing column names
    and their corresponding data types. Instances are immutable; equality and
    hashing depend on ``columns`` only.

    Examples:
        >>> schema = TableType.from_dict({
        ...     "id": DType.i64(),
        ...     "name": DType.string(),
        ... })
        >>> schema = TableType((("id", DType.i64()), ("name", DType.string())))
    """

    columns: tuple[tuple[str, DType], ...]
    # Name -> dtype index built in __post_init__. Excluded from __init__,
    # repr and comparison so it never influences equality or hashing.
    _column_map: dict[str, DType] = field(init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Validate the schema and build the name lookup table.

        Raises:
            ValueError: If there are no columns, a column name is duplicated,
                a name is not a non-empty string, or a type is not a DType.
        """
        if not self.columns:
            raise ValueError("TableType cannot be empty")

        # Validate column name uniqueness
        names = [name for name, _ in self.columns]
        if len(names) != len(set(names)):
            raise ValueError("Column names must be unique")

        # Validate each (name, dtype) pair
        for name, dtype in self.columns:
            if not name or not isinstance(name, str):
                raise ValueError("Column names must be non-empty strings")
            if not isinstance(dtype, DType):
                raise ValueError(f"Column type must be DType, got {type(dtype)}")

        # The dataclass is frozen, so the regular attribute assignment is
        # blocked; object.__setattr__ writes the lookup table directly.
        object.__setattr__(self, "_column_map", dict(self.columns))

    @classmethod
    def from_dict(cls, schema_dict: dict[str, DType]) -> TableType:
        """Create table schema from dictionary.

        Args:
            schema_dict: Mapping from column names to data types; the
                mapping's iteration order becomes the column order.

        Returns:
            TableType instance
        """
        return cls(tuple(schema_dict.items()))

    @classmethod
    def from_pairs(cls, pairs: list[tuple[str, DType]]) -> TableType:
        """Create table schema from list of name-type pairs.

        Args:
            pairs: List of tuples containing column name and data type

        Returns:
            TableType instance
        """
        return cls(tuple(pairs))

    @classmethod
    def from_tablelike(cls, table: TableLike) -> TableType:
        """Create table schema from a table-like object.

        Args:
            table: An object satisfying ``PandasTableLike`` (has ``columns``
                and ``dtypes``) or ``ArrowTableLike`` (has ``column_names``
                and ``schema``). The pandas-style protocol is tried first.

        Returns:
            TableType instance

        Raises:
            TypeError: If ``table`` satisfies neither protocol.
            ValueError: If names and types differ in length (``zip`` is
                called with ``strict=True``) or the schema is invalid.
        """
        if isinstance(table, PandasTableLike):
            pairs = zip(table.columns, table.dtypes, strict=True)
        elif isinstance(table, ArrowTableLike):
            schema = table.schema
            pairs = zip(schema.names, schema.types, strict=True)
        else:
            # Previously this fell through and returned None, contradicting
            # the declared return type and deferring the failure to callers.
            raise TypeError(f"Unsupported table type: {type(table)}.")

        return cls(tuple((name, DType.from_any(dtype)) for name, dtype in pairs))

    def column_names(self) -> tuple[str, ...]:
        """Get all column names."""
        return tuple(name for name, _ in self.columns)

    def column_types(self) -> tuple[DType, ...]:
        """Get all column data types."""
        return tuple(dtype for _, dtype in self.columns)

    def get_column_type(self, name: str) -> DType:
        """Get data type by column name.

        Args:
            name: Column name

        Returns:
            Corresponding data type

        Raises:
            KeyError: If column name does not exist
        """
        try:
            return self._column_map[name]
        except KeyError:
            # Replace the bare key with a descriptive message; the original
            # lookup failure adds nothing, so the chain is suppressed.
            raise KeyError(f"Column '{name}' not found in schema") from None

    def has_column(self, name: str) -> bool:
        """Check if contains specified column name.

        Args:
            name: Column name

        Returns:
            True if contains the column, False otherwise
        """
        # Use the prebuilt mapping instead of rebuilding a tuple of names.
        return name in self._column_map

    def num_columns(self) -> int:
        """Get number of columns."""
        return len(self.columns)

    def to_dict(self) -> dict[str, DType]:
        """Convert to dictionary form (a new dict on every call)."""
        return dict(self.columns)

    def __repr__(self) -> str:
        """String representation of the form ``TableType<name:dtype, ...>``."""
        cols = ", ".join(f"{name}:{dtype.short_name()}" for name, dtype in self.columns)
        return f"TableType<{cols}>"

    def __len__(self) -> int:
        """Get number of columns."""
        return len(self.columns)

    def __iter__(self) -> Iterator[tuple[str, DType]]:
        """Iterate over (column name, data type) pairs in column order."""
        return iter(self.columns)

    def __getitem__(self, index: int | str) -> tuple[str, DType] | DType:
        """Support index access.

        Args:
            index: Integer index or column name

        Returns:
            If integer index, returns (column name, data type) tuple
            If column name, returns corresponding data type

        Raises:
            IndexError: If an integer index is out of range.
            KeyError: If a column name does not exist.
            TypeError: If ``index`` is neither int nor str.
        """
        if isinstance(index, int):
            return self.columns[index]
        if isinstance(index, str):
            return self.get_column_type(index)
        raise TypeError(f"Index must be int or str, got {type(index)}")
mplang/v1/core/tensor.py DELETED
@@ -1,75 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from dataclasses import dataclass
18
- from typing import Any, Protocol, runtime_checkable
19
-
20
- import numpy as np
21
-
22
- from mplang.v1.core.dtypes import DType
23
-
24
- # basic type aliases
25
- Shape = tuple[int, ...]
26
- ScalarType = int | float | bool | complex
27
-
28
- __all__ = ["ScalarType", "Shape", "TensorLike", "TensorType"]
29
-
30
-
31
@runtime_checkable
class TensorLike(Protocol):
    """Structural type for tensor-like objects.

    Satisfied by any object that exposes ``dtype`` and ``shape`` attributes
    (NumPy, PyTorch and JAX arrays are the intended examples). The protocol
    is runtime-checkable, so ``isinstance`` only verifies that both
    attributes exist.
    """

    @property
    def shape(self) -> Shape:
        """Dimensions of the tensor."""
        ...

    @property
    def dtype(self) -> Any:
        """Element type of the tensor, in the source library's own form."""
        ...
43
-
44
-
45
@dataclass(frozen=True)
class TensorType:
    """Immutable description of a tensor: its element dtype and its shape."""

    dtype: DType
    shape: Shape

    def __init__(self, dtype: DType | Any, shape: Shape):
        """Store ``shape`` as given and ``dtype`` coerced to a ``DType``.

        Args:
            dtype: A ``DType``, or any value accepted by ``DType.from_any``.
            shape: The tensor's shape.
        """
        resolved = dtype if isinstance(dtype, DType) else DType.from_any(dtype)
        # Frozen dataclass: fields must be written through object.__setattr__.
        object.__setattr__(self, "dtype", resolved)
        object.__setattr__(self, "shape", shape)

    @classmethod
    def from_obj(cls, obj: TensorLike | ScalarType) -> TensorType:
        """Derive a TensorType from a Python scalar or a tensor-like object.

        Scalars are checked first and get the empty shape ``()``.

        Raises:
            TypeError: If ``obj`` is neither a scalar nor tensor-like.
        """
        if isinstance(obj, ScalarType):
            return cls(DType.from_python_type(type(obj)), ())
        if isinstance(obj, TensorLike):
            return cls(DType.from_any(obj.dtype), obj.shape)
        raise TypeError(f"Unsupported type: {type(obj)}.")

    def to_numpy(self) -> np.dtype:
        """Return the result of ``to_numpy()`` on the stored dtype."""
        return self.dtype.to_numpy()

    def __repr__(self) -> str:
        """Render as ``Tensor<d0x...xdtype>``, or just the dtype when the shape is empty."""
        dtype_name = str(self.dtype)
        dims = "x".join(map(str, self.shape))
        if dims:
            return f"Tensor<{dims}x{dtype_name}>"
        return f"{dtype_name}"
mplang/v1/core/tracer.py DELETED
@@ -1,383 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Trace context and TraceVar implementation.
17
-
18
- This module provides the trace context for lazy evaluation and TraceVar
19
- which stores expressions for deferred computation.
20
-
21
- Design Philosophy (inspired by JAX):
22
- ====================================
23
- The tracing mechanism converts Python functions operating on data into a static,
24
- dataflow graph representation (Expr) for analysis and multi-party execution.
25
- This follows a "closed-world" design, similar to JAX's JIT, with a core
26
- principle: functions are for data transformation ("Tensor in, Tensor out").
27
-
28
- This imposes several intentional limitations:
29
- - **Data-Centric Boundaries**: Only MPObjects (tensors or their pytrees) and
30
- immediate values can be passed as arguments to or be returned from a traced
31
- function.
32
- - **No Function Outputs**: A traced function cannot return a Python function that
33
- has captured tracers, as this would violate the static nature of the graph.
34
- - **Limited Function Inputs**: Arbitrary Python functions are not supported as
35
- arguments. However, for structured control flow (e.g., `cond`, `while_loop`),
36
- `mplang` allows passing Python functions. These are not true first-class
37
- functions; they are immediately traced into sub-graphs (`FuncDefExpr`) and
38
- embedded into the IR, never existing as runtime values within the graph.
39
-
40
- Rationale for TracedFunction vs. First-Class Functions:
41
- -------------------------------------------------------
42
- Instead of representing functions as `TraceVar(expr=FuncDefExpr)`, a dedicated
43
- `TracedFunction` class is used. This is crucial for:
44
-
45
- 1. **Type Safety & Clear Boundaries**: `TracedFunction` represents a callable
46
- computation, while `TraceVar` represents data. This separation prevents
47
- treating computation as data within the graph.
48
- 2. **Preserving Metadata**: It holds essential metadata for marshalling arguments
49
- and results, such as pytree structures (`in_struct`/`out_struct`) and
50
- captured variables, which a simple `Expr` would not retain.
51
-
52
- This design avoids the complexities of dynamic dispatch and higher-order functions
53
- in the IR, making the resulting graph simpler, more analyzable, and easier to
54
- compile for a multi-party setting.
55
- """
56
-
57
- from __future__ import annotations
58
-
59
- from collections.abc import Callable
60
- from dataclasses import dataclass
61
- from typing import Any, cast
62
-
63
- from mplang.v1.core.cluster import ClusterSpec
64
- from mplang.v1.core.context_mgr import with_ctx
65
- from mplang.v1.core.expr.ast import Expr, FuncDefExpr, TupleExpr, VariableExpr
66
- from mplang.v1.core.expr.printer import Printer
67
- from mplang.v1.core.mask import Mask
68
- from mplang.v1.core.mpobject import MPContext, MPObject
69
- from mplang.v1.core.mptype import MPType
70
- from mplang.v1.core.pfunc import get_fn_name
71
- from mplang.v1.utils.func_utils import MorphStruct, var_demorph, var_morph
72
-
73
-
74
class VarNamer:
    """Hands out sequentially numbered names that share a common prefix."""

    def __init__(self, prefix: str = "$"):
        self._prefix = prefix
        self._counter = 0

    def next_name(self) -> str:
        """Return ``<prefix><n>`` for the current counter value, then advance it."""
        index = self._counter
        self._counter = index + 1
        return f"{self._prefix}{index}"
86
-
87
-
88
class TraceContext(MPContext):
    """Context for lazy evaluation using expressions.

    TraceContext builds computation graphs by creating TraceVar objects
    that store expressions instead of executing them immediately.
    """

    def __init__(
        self,
        cluster_spec: ClusterSpec,
        *,
        mask: Mask | None = None,
        capture_namer: VarNamer | None = None,
        parent: MPContext | None = None,
    ):
        """Initialize TraceContext with a cluster specification.

        Args:
            cluster_spec: The cluster specification defining the physical nodes
                and logical devices available for computation.
            mask: The default mask for this context. If None, defaults to all parties.
            capture_namer: Optional VarNamer used by ``_gen_name``. If None, a
                fresh ``VarNamer()`` is created.
            parent: Optional parent context, forwarded to ``MPContext.__init__``.
        """
        super().__init__(cluster_spec, parent=parent)

        # Test against None explicitly (as fork() does) so a caller-supplied
        # mask is never swapped for the all-parties default merely because
        # it happens to evaluate as falsy.
        self._mask = Mask.all(self.world_size()) if mask is None else mask
        self._capture_namer = VarNamer() if capture_namer is None else capture_namer

        # NOTE(review): _var_namer is not used by any method of this class;
        # _gen_name draws from _capture_namer instead.
        self._var_namer = VarNamer(prefix="%")
        # Captured outer objects mapped to their variable in this context.
        self._captures: dict[MPObject, TraceVar] = {}

    @property
    def mask(self) -> Mask:
        """The default mask for this context."""
        return self._mask

    def _gen_name(self) -> str:
        """Generate a unique variable name from the capture namer."""
        return self._capture_namer.next_name()

    def fork(self, mask: Mask | None = None) -> TraceContext:
        """Create a new TraceContext with the same cluster spec and parent.

        Args:
            mask: Mask for the new context. If None, this context's mask is
                reused; otherwise it must be a subset of this context's mask.

        Returns:
            A new TraceContext. Captures are not carried over.

        Raises:
            ValueError: If ``mask`` is not a subset of the current mask.
        """
        if mask is None:
            mask = self._mask
        elif not Mask(mask).is_subset(self._mask):
            raise ValueError(
                f"New mask {mask} must be a subset of the current mask {self._mask}"
            )

        # NOTE: capture_namer is not forwarded, so the forked context
        # creates its own VarNamer and numbers its names from 0 again.
        return TraceContext(
            cluster_spec=self.cluster_spec,
            mask=mask,
            parent=self._parent,
        )

    def capture(self, obj: MPObject) -> TraceVar:
        """Create or reuse a variable that represents a captured MPObject.

        This method ensures that the same captured object always maps to
        the same variable in the traced function.

        Args:
            obj: The MPObject being captured from another context

        Returns:
            TraceVar representing the captured variable in this context
        """
        # If we've seen this object before, return the existing variable
        if obj in self._captures:
            return self._captures[obj]

        # Always generate a fresh name; only the object's mptype is reused.
        capture_name = self._gen_name()
        var = TraceVar(self, VariableExpr(capture_name, obj.mptype))
        self._captures[obj] = var

        return var

    def get_captures(self) -> dict[MPObject, TraceVar]:
        """Return the capture mapping itself (the live dict, not a copy)."""
        return self._captures
171
-
172
-
173
class TraceVar(MPObject):
    """A lazily evaluated value, represented by the expression that produces it.

    No computation happens when a TraceVar is created; it only records the
    expression tree. The expression must have exactly one output, which is
    verified in the constructor.
    """

    def __init__(self, ctx: TraceContext, expr: Expr):
        """Bind ``expr`` to ``ctx``.

        Raises:
            ValueError: If ``expr`` does not have exactly one output type.
        """
        output_count = len(expr.mptypes)
        if output_count != 1:
            raise ValueError(
                f"TraceVar requires single-output expression, "
                f"but expression has {output_count} outputs"
            )

        self._expr = expr
        self._ctx = ctx

    @property
    def ctx(self) -> MPContext:
        """The context this variable was created in."""
        return self._ctx

    @property
    def expr(self) -> Expr:
        """The expression this variable stands for."""
        return self._expr

    @property
    def mptype(self) -> MPType:
        """The variable's type, read from the expression's ``mptype``."""
        return self._expr.mptype

    def __repr__(self) -> str:
        expr_kind = self.expr.__class__.__name__
        return f"TraceVar(expr={expr_kind})"
209
-
210
-
211
- @dataclass
212
- class TracedFunction:
213
- func_name: str
214
- """The name of the traced function."""
215
-
216
- in_vars: list[TraceVar]
217
- """List of free (input) variables in the traced function."""
218
- in_struct: MorphStruct
219
- in_imms: list[Any]
220
-
221
- capture_map: dict[MPObject, TraceVar]
222
- """Map of captured MPObjects to their traced values."""
223
-
224
- out_vars: list[TraceVar]
225
- """List of output TraceVars."""
226
- out_struct: MorphStruct
227
- out_imms: list[Any]
228
-
229
- def in_names(self) -> list[str]:
230
- """Get the parameter names of the traced function."""
231
- return [cast(VariableExpr, var.expr).name for var in self.in_vars]
232
-
233
- def capture_names(self, captures: list[MPObject] | None = None) -> list[str]:
234
- if captures is None:
235
- captures = list(self.capture_map.keys())
236
-
237
- def var_name(var: TraceVar | None) -> str:
238
- return cast(VariableExpr, var.expr).name if var is not None else ""
239
-
240
- return [var_name(self.capture_map.get(var, None)) for var in captures]
241
-
242
- def make_expr(self, freevar_names: list[str] | None = None) -> FuncDefExpr:
243
- """Create a FuncDefExpr from the traced function data."""
244
- arg_names = [cast(VariableExpr, var.expr).name for var in self.in_vars]
245
- capture_names = [
246
- cast(VariableExpr, var.expr).name for var in self.capture_map.values()
247
- ]
248
- if freevar_names is None:
249
- # If no freevar_names provided, use default names
250
- freevar_names = arg_names + capture_names
251
- else:
252
- # Ensure freevar_names is superset of arg_names and capture_names
253
- if not set(arg_names).issubset(freevar_names):
254
- raise ValueError(
255
- f"Provided freevar_names {freevar_names} must include all input variable names {arg_names}"
256
- )
257
- if not set(capture_names).issubset(freevar_names):
258
- raise ValueError(
259
- f"Provided freevar_names {freevar_names} must include all capture variable names {capture_names}"
260
- )
261
-
262
- if len(self.out_vars) == 0:
263
- # No outputs - use empty tuple
264
- body_expr: Expr = TupleExpr([])
265
- return FuncDefExpr(freevar_names, body_expr)
266
- elif len(self.out_vars) == 1:
267
- body_expr = self.out_vars[0].expr
268
- return FuncDefExpr(freevar_names, body_expr)
269
- else:
270
- # Multiple outputs - use tuple (ensures all vars are single-output)
271
- body_expr = TupleExpr([var.expr for var in self.out_vars])
272
- return FuncDefExpr(freevar_names, body_expr)
273
-
274
- def is_signature_match(
275
- self,
276
- other: TracedFunction,
277
- check_captures: bool = True,
278
- ) -> bool:
279
- """Check if this function's signature matches another."""
280
- if not isinstance(other, TracedFunction):
281
- return False
282
- # Check input structures and immutables
283
- if (
284
- self.in_struct != other.in_struct
285
- or self.in_imms != other.in_imms
286
- or self.out_struct != other.out_struct
287
- or self.out_imms != other.out_imms
288
- ):
289
- return False
290
-
291
- # Check input type match
292
- if len(self.in_vars) != len(other.in_vars):
293
- return False
294
- for var, other_var in zip(self.in_vars, other.in_vars, strict=False):
295
- if var.mptype != other_var.mptype:
296
- return False
297
-
298
- # Check captures if required
299
- if check_captures:
300
- if len(self.capture_map) != len(other.capture_map):
301
- return False
302
- for key, var in self.capture_map.items():
303
- if (
304
- key not in other.capture_map
305
- or var.mptype != other.capture_map[key].mptype
306
- ):
307
- return False
308
-
309
- # check output type match
310
- if len(self.out_vars) != len(other.out_vars):
311
- return False
312
- for var, other_var in zip(self.out_vars, other.out_vars, strict=False):
313
- if var.mptype != other_var.mptype:
314
- return False
315
-
316
- return True
317
-
318
- def compiler_ir(self, verbose_peval: bool = False) -> str:
319
- """Get the compiler IR representation of this traced function."""
320
- printer = Printer(verbose_peval=verbose_peval)
321
- func_expr = self.make_expr()
322
- return printer.print_expr(func_expr)
323
-
324
-
325
def trace(
    tracer: TraceContext,
    mpfn: Callable,
    *args: Any,
    **kwargs: Any,
) -> TracedFunction:
    """Trace a Python function into an expression representation.

    The function is called once, inside ``tracer``, with every MPObject
    argument replaced by a fresh TraceVar. The call handles:
    - Function arguments (including pytree structures)
    - Captured variables from outer scopes
    - Output structures

    Args:
        tracer: The tracing context
        mpfn: The Python function to trace
        *args, **kwargs: Arguments to the function

    Returns:
        A TracedFunction holding the input/output variables, their structure
        and immediates, and the captures recorded by ``tracer``.
    """
    assert isinstance(tracer, TraceContext), f"Expect TraceContext, got {tracer}"

    def is_mpobj(candidate: Any) -> bool:
        return isinstance(candidate, MPObject)

    # Split the call arguments into MPObjects, immediates and structure.
    in_params, in_imms, in_struct = var_morph((args, kwargs), is_mpobj)

    # Reserve all parameter names first, then build one placeholder variable
    # per MPObject argument, carrying over that argument's mptype.
    param_names = [tracer._gen_name() for _ in range(len(in_params))]
    in_vars = []
    for name, param in zip(param_names, in_params, strict=False):
        in_vars.append(TraceVar(tracer, VariableExpr(name, param.mptype)))

    with with_ctx(tracer):
        # Rebuild the formal (args, kwargs) with placeholders substituted.
        vargs, vkwargs = var_demorph(in_vars, in_imms, in_struct)

        # Run the function under the tracing context.
        # NOTE(review): outer-scope MPObjects are presumably captured through
        # the context machinery during this call; that path is not visible here.
        outs = mpfn(*vargs, **vkwargs)

    # Split the result the same way the arguments were split.
    out_vars, out_imms, out_struct = var_morph(outs, is_mpobj)

    return TracedFunction(
        func_name=get_fn_name(mpfn),
        in_vars=in_vars,
        in_struct=in_struct,
        in_imms=in_imms,
        capture_map=tracer.get_captures(),
        out_vars=out_vars,
        out_struct=out_struct,
        out_imms=out_imms,
    )