mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/table.py
DELETED
|
@@ -1,218 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from collections.abc import Iterator
|
|
18
|
-
from dataclasses import dataclass, field
|
|
19
|
-
from typing import Any, Protocol, runtime_checkable
|
|
20
|
-
|
|
21
|
-
from mplang.v1.core.dtypes import DType
|
|
22
|
-
|
|
23
|
-
# Names exported by ``from ... import *`` for this module.
__all__ = ["TableLike", "TableType"]
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
@runtime_checkable
class PandasTableLike(Protocol):
    """Structural type for DataFrame-style tables.

    Any object that exposes both a ``columns`` and a ``dtypes`` attribute
    (for instance a pandas or polars DataFrame) satisfies an ``isinstance``
    check against this protocol.
    """

    @property
    def columns(self) -> Any: ...

    @property
    def dtypes(self) -> Any: ...
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
@runtime_checkable
class ArrowSchema(Protocol):
    """Structural type for Arrow-style schemas.

    Matches objects exposing ``names`` and ``types`` attributes; the two are
    consumed pairwise (one type per column name).
    """

    @property
    def types(self) -> list[Any]: ...

    @property
    def names(self) -> list[str]: ...
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
@runtime_checkable
class ArrowTableLike(Protocol):
    """Structural type for Arrow-style tables.

    Matches objects exposing ``column_names`` and ``schema`` attributes. The
    ``schema`` is annotated as :class:`ArrowSchema`; a runtime ``isinstance``
    check only verifies that both attributes are present.
    """

    @property
    def schema(self) -> ArrowSchema: ...

    @property
    def column_names(self) -> list[str]: ...
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
# Any object matching either structural table protocol (DataFrame-style or
# Arrow-style); accepted by TableType.from_tablelike.
TableLike = PandasTableLike | ArrowTableLike
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
@dataclass(frozen=True)
class TableType:
    """Table schema: ordered list of column name-type pairs.

    Represents table structure in relational algebra, containing column names
    and their corresponding data types. Instances are immutable; equality and
    hashing consider only ``columns``.

    Examples:
        >>> schema = TableType.from_dict({
        ...     "id": DType.i64(),
        ...     "name": DType.string(),
        ... })
        >>> schema = TableType((("id", DType.i64()), ("name", DType.string())))
    """

    columns: tuple[tuple[str, DType], ...]
    # Name -> dtype index for O(1) lookups; derived in __post_init__ and excluded
    # from __init__, repr, and comparisons.
    _column_map: dict[str, DType] = field(init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Normalize and validate the table schema.

        ``columns`` is coerced to a tuple of ``(name, dtype)`` tuples so the
        instance stays hashable even when a list (or other iterable) is passed.

        Raises:
            ValueError: If the schema is empty, a column name is not a non-empty
                string, a column type is not a ``DType``, or names repeat.
        """
        # ``or ()`` keeps ``None`` on the "cannot be empty" ValueError path
        # instead of failing with a TypeError during iteration.
        columns = tuple((name, dtype) for name, dtype in (self.columns or ()))
        if not columns:
            raise ValueError("TableType cannot be empty")

        # Validate every entry before hashing the names, so a non-string (and
        # possibly unhashable) name is reported as a ValueError.
        for name, dtype in columns:
            if not name or not isinstance(name, str):
                raise ValueError("Column names must be non-empty strings")
            if not isinstance(dtype, DType):
                raise ValueError(f"Column type must be DType, got {type(dtype)}")

        column_map = dict(columns)
        if len(column_map) != len(columns):
            raise ValueError("Column names must be unique")

        # The dataclass is frozen, so assign through object.__setattr__.
        object.__setattr__(self, "columns", columns)
        object.__setattr__(self, "_column_map", column_map)

    @classmethod
    def from_dict(cls, schema_dict: dict[str, DType]) -> TableType:
        """Create table schema from dictionary.

        Args:
            schema_dict: Mapping from column names to data types; column order
                follows the dictionary's iteration order.

        Returns:
            TableType instance
        """
        return cls(tuple(schema_dict.items()))

    @classmethod
    def from_pairs(cls, pairs: list[tuple[str, DType]]) -> TableType:
        """Create table schema from list of name-type pairs.

        Args:
            pairs: List of tuples containing column name and data type

        Returns:
            TableType instance
        """
        return cls(tuple(pairs))

    @classmethod
    def from_tablelike(cls, table: TableLike) -> TableType:
        """Create table schema from a table-like object.

        Args:
            table: Either a DataFrame-style object exposing ``columns`` and
                ``dtypes`` (e.g. a pandas DataFrame), or an Arrow-style object
                exposing ``column_names`` and ``schema``.

        Returns:
            TableType instance with each dtype converted via ``DType.from_any``.

        Raises:
            TypeError: If ``table`` matches neither supported protocol.
            ValueError: If the column names and dtypes differ in length, or the
                resulting schema is invalid.
        """
        if isinstance(table, PandasTableLike):
            names, dtypes = table.columns, table.dtypes
        elif isinstance(table, ArrowTableLike):
            schema = table.schema
            names, dtypes = schema.names, schema.types
        else:
            # Previously this fell through and silently returned None.
            raise TypeError(
                f"Unsupported table-like object: {type(table).__name__}; expected "
                "an object exposing 'columns'/'dtypes' or 'column_names'/'schema'"
            )
        return cls(
            tuple(
                (name, DType.from_any(dtype))
                for name, dtype in zip(names, dtypes, strict=True)
            )
        )

    def column_names(self) -> tuple[str, ...]:
        """Get all column names."""
        return tuple(name for name, _ in self.columns)

    def column_types(self) -> tuple[DType, ...]:
        """Get all column data types."""
        return tuple(dtype for _, dtype in self.columns)

    def get_column_type(self, name: str) -> DType:
        """Get data type by column name.

        Args:
            name: Column name

        Returns:
            Corresponding data type

        Raises:
            KeyError: If column name does not exist
        """
        try:
            return self._column_map[name]
        except KeyError:
            raise KeyError(f"Column '{name}' not found in schema") from None

    def has_column(self, name: str) -> bool:
        """Check if contains specified column name.

        Args:
            name: Column name

        Returns:
            True if contains the column, False otherwise
        """
        # O(1) lookup in the prebuilt index instead of scanning a fresh tuple.
        return name in self._column_map

    def num_columns(self) -> int:
        """Get number of columns."""
        return len(self.columns)

    def to_dict(self) -> dict[str, DType]:
        """Convert to dictionary form (a new dict; safe to mutate)."""
        return dict(self.columns)

    def __repr__(self) -> str:
        """String representation, e.g. ``TableType<id:..., name:...>``."""
        cols = ", ".join(f"{name}:{dtype.short_name()}" for name, dtype in self.columns)
        return f"TableType<{cols}>"

    def __len__(self) -> int:
        """Get number of columns."""
        return len(self.columns)

    def __iter__(self) -> Iterator[tuple[str, DType]]:
        """Iterate over ``(name, dtype)`` pairs in column order."""
        return iter(self.columns)

    def __getitem__(self, index: int | str) -> tuple[str, DType] | DType:
        """Support index access.

        Args:
            index: Integer index or column name

        Returns:
            If integer index, returns (column name, data type) tuple
            If column name, returns corresponding data type

        Raises:
            IndexError: If an integer index is out of range.
            KeyError: If a column name does not exist.
            TypeError: If ``index`` is neither int nor str.
        """
        if isinstance(index, int):
            return self.columns[index]
        elif isinstance(index, str):
            return self.get_column_type(index)
        else:
            raise TypeError(f"Index must be int or str, got {type(index)}")
|
mplang/v1/core/tensor.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from dataclasses import dataclass
|
|
18
|
-
from typing import Any, Protocol, runtime_checkable
|
|
19
|
-
|
|
20
|
-
import numpy as np
|
|
21
|
-
|
|
22
|
-
from mplang.v1.core.dtypes import DType
|
|
23
|
-
|
|
24
|
-
# basic type aliases
|
|
25
|
-
# A tensor shape: one int per dimension; ``()`` denotes a scalar.
Shape = tuple[int, ...]
|
|
26
|
-
# Python scalar types treated as rank-0 tensors by TensorType.from_obj
# (matched with ``isinstance``).
ScalarType = int | float | bool | complex
|
|
27
|
-
|
|
28
|
-
# Names exported by ``from ... import *`` for this module.
__all__ = ["ScalarType", "Shape", "TensorLike", "TensorType"]
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
@runtime_checkable
class TensorLike(Protocol):
    """Structural type for array-like values (NumPy, PyTorch, JAX, ...).

    Anything exposing both a ``shape`` and a ``dtype`` attribute passes an
    ``isinstance`` check against this protocol.
    """

    @property
    def shape(self) -> Shape: ...

    @property
    def dtype(self) -> Any: ...
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
@dataclass(frozen=True)
class TensorType:
    """A data class that describes the type information of a tensor.

    Attributes:
        dtype: Element type, always stored as a ``DType``.
        shape: Dimension sizes, always stored as a tuple (``()`` for scalars).
    """

    dtype: DType
    shape: Shape

    def __init__(self, dtype: DType | Any, shape: Shape):
        """Build a tensor type, coercing both fields to canonical forms.

        Args:
            dtype: A ``DType``, or any value accepted by ``DType.from_any``.
            shape: An iterable of dimension sizes, e.g. a tuple or list.
        """
        # Convert dtype to DType if needed
        if not isinstance(dtype, DType):
            dtype = DType.from_any(dtype)
        # The dataclass is frozen, so bypass __setattr__. The shape is stored as
        # a plain tuple so the generated __eq__/__hash__ behave consistently:
        # a list shape would make the instance unhashable and unequal to the
        # same shape given as a tuple.
        object.__setattr__(self, "dtype", dtype)
        object.__setattr__(self, "shape", tuple(shape))

    @classmethod
    def from_obj(cls, obj: TensorLike | ScalarType) -> TensorType:
        """Infer the tensor type of a Python scalar or tensor-like object.

        Args:
            obj: A Python ``int``/``float``/``bool``/``complex`` (treated as a
                rank-0 tensor), or any object exposing ``dtype`` and ``shape``.

        Returns:
            The corresponding TensorType.

        Raises:
            TypeError: If ``obj`` is neither a scalar nor tensor-like.
        """
        if isinstance(obj, ScalarType):
            return cls(DType.from_python_type(type(obj)), ())
        if isinstance(obj, TensorLike):
            return cls(DType.from_any(obj.dtype), obj.shape)
        raise TypeError(f"Unsupported type: {type(obj)}.")

    def to_numpy(self) -> np.dtype:
        """Convert the element type to a NumPy dtype for compatibility."""
        return self.dtype.to_numpy()

    def __repr__(self) -> str:
        """Render as ``Tensor<2x3x<dtype>>``, or just the dtype for scalars."""
        shape_str = "x".join(str(d) for d in self.shape)
        dtype_name = str(self.dtype)
        return f"Tensor<{shape_str}x{dtype_name}>" if shape_str else f"{dtype_name}"
|
mplang/v1/core/tracer.py
DELETED
|
@@ -1,383 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Trace context and TraceVar implementation.
|
|
17
|
-
|
|
18
|
-
This module provides the trace context for lazy evaluation and TraceVar
|
|
19
|
-
which stores expressions for deferred computation.
|
|
20
|
-
|
|
21
|
-
Design Philosophy (inspired by JAX):
|
|
22
|
-
====================================
|
|
23
|
-
The tracing mechanism converts Python functions operating on data into a static,
|
|
24
|
-
dataflow graph representation (Expr) for analysis and multi-party execution.
|
|
25
|
-
This follows a "closed-world" design, similar to JAX's JIT, with a core
|
|
26
|
-
principle: functions are for data transformation ("Tensor in, Tensor out").
|
|
27
|
-
|
|
28
|
-
This imposes several intentional limitations:
|
|
29
|
-
- **Data-Centric Boundaries**: Only MPObjects (tensors or their pytrees) and
|
|
30
|
-
immediate values can be passed as arguments to or be returned from a traced
|
|
31
|
-
function.
|
|
32
|
-
- **No Function Outputs**: A traced function cannot return a Python function that
|
|
33
|
-
has captured tracers, as this would violate the static nature of the graph.
|
|
34
|
-
- **Limited Function Inputs**: Arbitrary Python functions are not supported as
|
|
35
|
-
arguments. However, for structured control flow (e.g., `cond`, `while_loop`),
|
|
36
|
-
`mplang` allows passing Python functions. These are not true first-class
|
|
37
|
-
functions; they are immediately traced into sub-graphs (`FuncDefExpr`) and
|
|
38
|
-
embedded into the IR, never existing as runtime values within the graph.
|
|
39
|
-
|
|
40
|
-
Rationale for TracedFunction vs. First-Class Functions:
|
|
41
|
-
-------------------------------------------------------
|
|
42
|
-
Instead of representing functions as `TraceVar(expr=FuncDefExpr)`, a dedicated
|
|
43
|
-
`TracedFunction` class is used. This is crucial for:
|
|
44
|
-
|
|
45
|
-
1. **Type Safety & Clear Boundaries**: `TracedFunction` represents a callable
|
|
46
|
-
computation, while `TraceVar` represents data. This separation prevents
|
|
47
|
-
treating computation as data within the graph.
|
|
48
|
-
2. **Preserving Metadata**: It holds essential metadata for marshalling arguments
|
|
49
|
-
and results, such as pytree structures (`in_struct`/`out_struct`) and
|
|
50
|
-
captured variables, which a simple `Expr` would not retain.
|
|
51
|
-
|
|
52
|
-
This design avoids the complexities of dynamic dispatch and higher-order functions
|
|
53
|
-
in the IR, making the resulting graph simpler, more analyzable, and easier to
|
|
54
|
-
compile for a multi-party setting.
|
|
55
|
-
"""
|
|
56
|
-
|
|
57
|
-
from __future__ import annotations
|
|
58
|
-
|
|
59
|
-
from collections.abc import Callable
|
|
60
|
-
from dataclasses import dataclass
|
|
61
|
-
from typing import Any, cast
|
|
62
|
-
|
|
63
|
-
from mplang.v1.core.cluster import ClusterSpec
|
|
64
|
-
from mplang.v1.core.context_mgr import with_ctx
|
|
65
|
-
from mplang.v1.core.expr.ast import Expr, FuncDefExpr, TupleExpr, VariableExpr
|
|
66
|
-
from mplang.v1.core.expr.printer import Printer
|
|
67
|
-
from mplang.v1.core.mask import Mask
|
|
68
|
-
from mplang.v1.core.mpobject import MPContext, MPObject
|
|
69
|
-
from mplang.v1.core.mptype import MPType
|
|
70
|
-
from mplang.v1.core.pfunc import get_fn_name
|
|
71
|
-
from mplang.v1.utils.func_utils import MorphStruct, var_demorph, var_morph
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
class VarNamer:
    """Hand out sequential, unique variable names that share a common prefix.

    Names are ``f"{prefix}{k}"`` for k = 0, 1, 2, ... in call order.
    """

    def __init__(self, prefix: str = "$"):
        self._prefix = prefix
        self._counter = 0

    def next_name(self) -> str:
        """Return the next name in the sequence and advance the counter."""
        current = self._counter
        self._counter = current + 1
        return f"{self._prefix}{current}"
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
class TraceContext(MPContext):
    """Context for lazy evaluation using expressions.

    TraceContext builds computation graphs by creating TraceVar objects
    that store expressions instead of executing them immediately.
    """

    def __init__(
        self,
        cluster_spec: ClusterSpec,
        *,
        mask: Mask | None = None,
        capture_namer: VarNamer | None = None,
        parent: MPContext | None = None,
    ):
        """Initialize TraceContext with a cluster specification.

        Args:
            cluster_spec: The cluster specification defining the physical nodes
                and logical devices available for computation.
            mask: The default mask for this context. If None, defaults to all parties.
            capture_namer: Optional VarNamer for naming captured variables. If
                None, a fresh VarNamer with the default "$" prefix is created.
            parent: Optional parent context, forwarded to ``MPContext.__init__``.
        """
        super().__init__(cluster_spec, parent=parent)

        # Compare with None explicitly: ``mask or ...`` would also discard an
        # explicitly passed mask that happens to be falsy (e.g. an empty mask).
        self._mask = Mask.all(self.world_size()) if mask is None else mask
        self._capture_namer = VarNamer() if capture_namer is None else capture_namer

        # Separate namer using the "%" prefix (distinct from the default "$"
        # prefix of capture names).
        self._var_namer = VarNamer(prefix="%")
        # Captured outer objects -> the TraceVar standing in for each of them.
        self._captures: dict[MPObject, TraceVar] = {}

    @property
    def mask(self) -> Mask:
        """The default mask for this context."""
        return self._mask

    def _gen_name(self) -> str:
        """Generate a unique name for a captured variable."""
        return self._capture_namer.next_name()

    def fork(self, mask: Mask | None = None) -> TraceContext:
        """Create a new TraceContext with the same cluster spec and parent.

        The new context starts with no captures and its own capture namer; only
        the cluster spec, the parent, and the (optionally narrowed) mask carry over.

        Args:
            mask: Mask for the new context; must be a subset of this context's
                mask. If None, this context's mask is reused.

        Returns:
            The new TraceContext.

        Raises:
            ValueError: If ``mask`` is not a subset of the current mask.
        """
        if mask is None:
            mask = self._mask
        else:
            # Normalize once and store the normalized Mask in the new context,
            # rather than validating Mask(mask) but forwarding the raw value.
            new_mask = Mask(mask)
            if not new_mask.is_subset(self._mask):
                raise ValueError(
                    f"New mask {mask} must be a subset of the current mask {self._mask}"
                )
            mask = new_mask

        # NOTE(review): the capture namer is not forwarded, so the fork names its
        # captures independently starting from "$0"; confirm this is intended.
        return TraceContext(
            cluster_spec=self.cluster_spec,
            mask=mask,
            parent=self._parent,
        )

    def capture(self, obj: MPObject) -> TraceVar:
        """Create or reuse a variable that represents a captured MPObject.

        This method ensures that the same captured object always maps to
        the same variable in the traced function.

        Args:
            obj: The MPObject being captured from another context

        Returns:
            TraceVar representing the captured variable in this context
        """
        # If we've seen this object before, return the existing variable
        if obj in self._captures:
            return self._captures[obj]

        # Always generate a fresh name from the capture namer; the object's own
        # name (if any) is not used.
        capture_name = self._gen_name()
        var = TraceVar(self, VariableExpr(capture_name, obj.mptype))
        self._captures[obj] = var

        return var

    def get_captures(self) -> dict[MPObject, TraceVar]:
        """Return the live mapping of captured objects to their variables.

        The internal dict is returned directly (not a copy), in capture order.
        """
        return self._captures
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
class TraceVar(MPObject):
    """Lazily evaluated variable backed by an expression.

    A TraceVar stands for a value that has not been computed yet: it keeps the
    expression tree that would produce that value. Only single-output
    expressions are accepted; this is enforced in ``__init__``.
    """

    def __init__(self, ctx: TraceContext, expr: Expr):
        output_count = len(expr.mptypes)
        if output_count != 1:
            raise ValueError(
                "TraceVar requires single-output expression, "
                f"but expression has {output_count} outputs"
            )
        self._ctx = ctx
        self._expr = expr

    @property
    def ctx(self) -> MPContext:
        """Context that owns this variable."""
        return self._ctx

    @property
    def expr(self) -> Expr:
        """Expression this variable stands for."""
        return self._expr

    @property
    def mptype(self) -> MPType:
        """Type of this variable, taken from its expression."""
        return self._expr.mptype

    def __repr__(self) -> str:
        kind = self.expr.__class__.__name__
        return f"TraceVar(expr={kind})"
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
@dataclass
|
|
212
|
-
class TracedFunction:
|
|
213
|
-
func_name: str
|
|
214
|
-
"""The name of the traced function."""
|
|
215
|
-
|
|
216
|
-
in_vars: list[TraceVar]
|
|
217
|
-
"""List of free (input) variables in the traced function."""
|
|
218
|
-
in_struct: MorphStruct
|
|
219
|
-
in_imms: list[Any]
|
|
220
|
-
|
|
221
|
-
capture_map: dict[MPObject, TraceVar]
|
|
222
|
-
"""Map of captured MPObjects to their traced values."""
|
|
223
|
-
|
|
224
|
-
out_vars: list[TraceVar]
|
|
225
|
-
"""List of output TraceVars."""
|
|
226
|
-
out_struct: MorphStruct
|
|
227
|
-
out_imms: list[Any]
|
|
228
|
-
|
|
229
|
-
def in_names(self) -> list[str]:
|
|
230
|
-
"""Get the parameter names of the traced function."""
|
|
231
|
-
return [cast(VariableExpr, var.expr).name for var in self.in_vars]
|
|
232
|
-
|
|
233
|
-
def capture_names(self, captures: list[MPObject] | None = None) -> list[str]:
|
|
234
|
-
if captures is None:
|
|
235
|
-
captures = list(self.capture_map.keys())
|
|
236
|
-
|
|
237
|
-
def var_name(var: TraceVar | None) -> str:
|
|
238
|
-
return cast(VariableExpr, var.expr).name if var is not None else ""
|
|
239
|
-
|
|
240
|
-
return [var_name(self.capture_map.get(var, None)) for var in captures]
|
|
241
|
-
|
|
242
|
-
def make_expr(self, freevar_names: list[str] | None = None) -> FuncDefExpr:
|
|
243
|
-
"""Create a FuncDefExpr from the traced function data."""
|
|
244
|
-
arg_names = [cast(VariableExpr, var.expr).name for var in self.in_vars]
|
|
245
|
-
capture_names = [
|
|
246
|
-
cast(VariableExpr, var.expr).name for var in self.capture_map.values()
|
|
247
|
-
]
|
|
248
|
-
if freevar_names is None:
|
|
249
|
-
# If no freevar_names provided, use default names
|
|
250
|
-
freevar_names = arg_names + capture_names
|
|
251
|
-
else:
|
|
252
|
-
# Ensure freevar_names is superset of arg_names and capture_names
|
|
253
|
-
if not set(arg_names).issubset(freevar_names):
|
|
254
|
-
raise ValueError(
|
|
255
|
-
f"Provided freevar_names {freevar_names} must include all input variable names {arg_names}"
|
|
256
|
-
)
|
|
257
|
-
if not set(capture_names).issubset(freevar_names):
|
|
258
|
-
raise ValueError(
|
|
259
|
-
f"Provided freevar_names {freevar_names} must include all capture variable names {capture_names}"
|
|
260
|
-
)
|
|
261
|
-
|
|
262
|
-
if len(self.out_vars) == 0:
|
|
263
|
-
# No outputs - use empty tuple
|
|
264
|
-
body_expr: Expr = TupleExpr([])
|
|
265
|
-
return FuncDefExpr(freevar_names, body_expr)
|
|
266
|
-
elif len(self.out_vars) == 1:
|
|
267
|
-
body_expr = self.out_vars[0].expr
|
|
268
|
-
return FuncDefExpr(freevar_names, body_expr)
|
|
269
|
-
else:
|
|
270
|
-
# Multiple outputs - use tuple (ensures all vars are single-output)
|
|
271
|
-
body_expr = TupleExpr([var.expr for var in self.out_vars])
|
|
272
|
-
return FuncDefExpr(freevar_names, body_expr)
|
|
273
|
-
|
|
274
|
-
def is_signature_match(
    self,
    other: TracedFunction,
    check_captures: bool = True,
) -> bool:
    """Tell whether ``other`` has the same signature as this traced function.

    Args:
        other: The object to compare against. Anything that is not a
            TracedFunction never matches.
        check_captures: When true, the two capture maps must also have the
            same size, the same keys, and equal ``mptype`` for each key.

    Returns:
        True only if the input/output structures and immediate values are
        equal, the input and output variables agree pairwise on ``mptype``,
        and, if requested, the captures match. False otherwise.
    """
    if not isinstance(other, TracedFunction):
        return False

    # Structure and immediate (non-MPObject) values on both sides.
    layout_differs = (
        self.in_struct != other.in_struct
        or self.in_imms != other.in_imms
        or self.out_struct != other.out_struct
        or self.out_imms != other.out_imms
    )
    if layout_differs:
        return False

    def mptypes_differ(mine, theirs) -> bool:
        # Positional comparison; a length mismatch also counts as a difference.
        if len(mine) != len(theirs):
            return True
        return any(
            lhs.mptype != rhs.mptype
            for lhs, rhs in zip(mine, theirs, strict=False)
        )

    if mptypes_differ(self.in_vars, other.in_vars):
        return False

    if check_captures:
        their_caps = other.capture_map
        if len(self.capture_map) != len(their_caps):
            return False
        if any(
            key not in their_caps or tv.mptype != their_caps[key].mptype
            for key, tv in self.capture_map.items()
        ):
            return False

    return not mptypes_differ(self.out_vars, other.out_vars)
|
|
317
|
-
|
|
318
|
-
def compiler_ir(self, verbose_peval: bool = False) -> str:
    """Render this traced function as compiler IR text.

    Args:
        verbose_peval: Forwarded unchanged to ``Printer`` as its
            ``verbose_peval`` option.

    Returns:
        The text produced by ``Printer.print_expr`` for ``self.make_expr()``.
    """
    return Printer(verbose_peval=verbose_peval).print_expr(self.make_expr())
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
def trace(
    tracer: TraceContext,
    mpfn: Callable,
    *args: Any,
    **kwargs: Any,
) -> TracedFunction:
    """Trace a Python function into a TracedFunction.

    ``mpfn`` is called once under ``tracer``. Every MPObject in the
    arguments is replaced by a fresh TraceVar. The traced inputs, the
    variables captured from outer scopes, and the outputs are packaged into
    a TracedFunction, from which a FuncDefExpr can be built. It handles:

    - Function arguments (including pytree structures)
    - Captured variables from outer scopes
    - Output structures

    Args:
        tracer: The tracing context the function is executed in.
        mpfn: The Python function to trace.
        *args, **kwargs: Arguments to the function. MPObjects become traced
            parameters; all other values are kept as immediates.

    Returns:
        A TracedFunction holding the traced inputs, captures and outputs of
        ``mpfn``.
    """
    assert isinstance(tracer, TraceContext), f"Expect TraceContext, got {tracer}"

    def is_mpobj(x: Any) -> bool:
        return isinstance(x, MPObject)

    # Separate MPObjects (traced parameters) from immediate values in inputs.
    in_params, in_imms, in_struct = var_morph((args, kwargs), is_mpobj)

    # One fresh tracer-generated name per MPObject parameter.
    param_names = [tracer._gen_name() for _ in in_params]
    in_vars = [
        TraceVar(tracer, VariableExpr(name, param.mptype))
        for name, param in zip(param_names, in_params, strict=False)
    ]

    with with_ctx(tracer):
        # Rebuild the call arguments with TraceVars in place of MPObjects.
        vargs, vkwargs = var_demorph(in_vars, in_imms, in_struct)

        # Execute the function - this will capture any external variables
        # through switch_ctx.
        outs = mpfn(*vargs, **vkwargs)

        # Split outputs into MPObjects, immediates, and their structure.
        out_vars, out_imms, out_struct = var_morph(outs, is_mpobj)

        captures = tracer.get_captures()

        return TracedFunction(
            func_name=get_fn_name(mpfn),
            in_vars=in_vars,
            in_struct=in_struct,
            in_imms=in_imms,
            capture_map=captures,
            out_vars=out_vars,
            out_struct=out_struct,
            out_imms=out_imms,
        )
|