mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/expr/__init__.py
DELETED
|
@@ -1,80 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Expression system for multi-party computation graph construction.
|
|
17
|
-
|
|
18
|
-
This package provides a modern, extensible expression-based architecture for building
|
|
19
|
-
multi-party computation graphs using the visitor pattern.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
# Core expression types
|
|
23
|
-
from mplang.v1.core.expr.ast import (
|
|
24
|
-
AccessExpr,
|
|
25
|
-
CallExpr,
|
|
26
|
-
CondExpr,
|
|
27
|
-
ConvExpr,
|
|
28
|
-
EvalExpr,
|
|
29
|
-
Expr,
|
|
30
|
-
FuncDefExpr,
|
|
31
|
-
ShflExpr,
|
|
32
|
-
ShflSExpr,
|
|
33
|
-
TupleExpr,
|
|
34
|
-
VariableExpr,
|
|
35
|
-
WhileExpr,
|
|
36
|
-
)
|
|
37
|
-
|
|
38
|
-
# Built-in evaluator engines
|
|
39
|
-
from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
|
|
40
|
-
from mplang.v1.core.expr.printer import Printer
|
|
41
|
-
from mplang.v1.core.expr.transformer import ExprTransformer
|
|
42
|
-
|
|
43
|
-
# Utility functions
|
|
44
|
-
from mplang.v1.core.expr.utils import (
|
|
45
|
-
deduce_mask,
|
|
46
|
-
ensure_scalar,
|
|
47
|
-
ensure_tensorlist_equal,
|
|
48
|
-
type_equal,
|
|
49
|
-
)
|
|
50
|
-
|
|
51
|
-
# Visitor pattern interface
|
|
52
|
-
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
53
|
-
from mplang.v1.core.expr.walk import walk, walk_dataflow, walk_structural
|
|
54
|
-
|
|
55
|
-
# Public API of this package: the AST node classes, the visitor /
# transformer / printer / evaluator interfaces, and the helper functions that
# are re-exported from the submodules imported above. Entries are kept in
# sorted order.
__all__ = [
    "AccessExpr",
    "CallExpr",
    "CondExpr",
    "ConvExpr",
    "EvalExpr",
    "Expr",
    "ExprTransformer",
    "ExprVisitor",
    "FuncDefExpr",
    "IEvaluator",
    "Printer",
    "ShflExpr",
    "ShflSExpr",
    "TupleExpr",
    "VariableExpr",
    "WhileExpr",
    "create_evaluator",
    "deduce_mask",
    "ensure_scalar",
    "ensure_tensorlist_equal",
    "type_equal",
    "walk",
    "walk_dataflow",
    "walk_structural",
]
|
mplang/v1/core/expr/ast.py
DELETED
|
@@ -1,542 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Abstract Syntax Tree (AST) nodes for multi-party computation expressions.
|
|
17
|
-
|
|
18
|
-
This module defines the AST nodes for representing multi-party computation expressions.
|
|
19
|
-
Each node type represents a different kind of operation or construct in the multi-party
|
|
20
|
-
computation language, following the visitor pattern for extensible processing.
|
|
21
|
-
"""
|
|
22
|
-
|
|
23
|
-
from __future__ import annotations
|
|
24
|
-
|
|
25
|
-
import logging
|
|
26
|
-
from abc import ABC, abstractmethod
|
|
27
|
-
from typing import TYPE_CHECKING, Any
|
|
28
|
-
|
|
29
|
-
from mplang.v1.core.expr.utils import deduce_mask
|
|
30
|
-
from mplang.v1.core.mask import Mask
|
|
31
|
-
from mplang.v1.core.mptype import MPType, Rank
|
|
32
|
-
from mplang.v1.core.pfunc import PFunction
|
|
33
|
-
from mplang.v1.core.table import TableType
|
|
34
|
-
from mplang.v1.core.tensor import TensorType
|
|
35
|
-
|
|
36
|
-
if TYPE_CHECKING:
|
|
37
|
-
from mplang.v1.core.expr.visitor import ExprVisitor
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class Expr(ABC):
    """Abstract base for every node of the multi-party computation graph.

    Nodes follow a Multi-Input Multi-Output (MIMO) model: any expression may
    conceptually yield several values. This is what makes multi-output
    PFunctions and compact dataflow graphs possible.

    Attributes:
        mptypes (list[MPType]): Output types of this node, one per output.
            Derived on first access through ``_compute_mptypes`` and cached.
        mptype (MPType): Shortcut for single-output nodes. Raises ValueError
            when the node does not have exactly one output, which doubles as
            a cheap runtime sanity check.
    """

    def __init__(self) -> None:
        # Cache backing ``mptypes``; ``None`` means "not computed yet".
        self._mptypes: list[MPType] | None = None

    @property
    def num_outputs(self) -> int:
        """Return how many values this expression yields."""
        return len(self.mptypes)

    @property
    def mptypes(self) -> list[MPType]:
        """Return the list of output types, computing and caching it lazily."""
        cached = self._mptypes
        if cached is None:
            cached = self._compute_mptypes()
            self._mptypes = cached
        return cached

    @property
    def mptype(self) -> MPType:
        """Return the only output type of a single-output expression.

        Raises:
            ValueError: If the expression does not have exactly one output.
        """
        outputs = self.mptypes
        count = len(outputs)
        if count == 1:
            return outputs[0]
        raise ValueError(f"Expression has {count} outputs, expected 1")

    @abstractmethod
    def _compute_mptypes(self) -> list[MPType]:
        """Derive this node's output types; the result is cached by ``mptypes``."""

    @abstractmethod
    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch this node to the matching ``visit_*`` method of ``visitor``."""
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
# ============================================================================
|
|
89
|
-
# Concrete Expression Classes
|
|
90
|
-
# ============================================================================
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
class EvalExpr(Expr):
    """Node that evaluates a PFunction over a list of argument expressions."""

    def __init__(
        self, pfunc: PFunction, args: list[Expr], rmask: Mask | int | None = None
    ):
        """Create an evaluation node.

        Args:
            pfunc: Function to evaluate; ``len(pfunc.ins_info)`` fixes the arity.
            args: Argument expressions, one per entry of ``pfunc.ins_info``.
            rmask: Optional explicit output mask; normalised to ``Mask``.

        Raises:
            ValueError: If ``len(args)`` differs from ``len(pfunc.ins_info)``.
        """
        super().__init__()
        expected_arity = len(pfunc.ins_info)
        if len(args) != expected_arity:
            raise ValueError(f"Expected {expected_arity} arguments, got {len(args)}")
        normalized_rmask = None if rmask is None else Mask(rmask)

        self.pfunc = pfunc
        self.args = args
        self.rmask = normalized_rmask

    def _compute_mptypes(self) -> list[MPType]:
        """Compute output MPTypes from the PFunction signature and the masks.

        Steps:
        1. Deduce a pmask from the arguments with ``deduce_mask``. The result
           is ``None`` when any argument pmask is unknown at trace time.
        2. If ``rmask`` was given explicitly, it takes precedence. When the
           deduced pmask is known, ``rmask`` must be a subset of it.
        3. Otherwise the deduced pmask is used.
        4. Build one MPType per entry of ``pfunc.outs_info``, as a tensor or a
           table.

        Raises:
            ValueError: If ``rmask`` is not a subset of a known deduced pmask.
            TypeError: If an output info is neither TensorType nor TableType.
        """
        # Deduce from the args even when rmask is set. ``arg.mptype`` also
        # rejects multi-output arguments.
        deduced_pmask = deduce_mask(*(arg.mptype.pmask for arg in self.args))

        effective_pmask: Mask | None
        if self.rmask is None:
            effective_pmask = deduced_pmask
        else:
            if deduced_pmask is not None and not Mask(self.rmask).is_subset(
                deduced_pmask
            ):
                raise ValueError(
                    f"Specified rmask {self.rmask} is not a subset of deduced pmask {deduced_pmask}."
                )
            # Unknown deduced pmask: trust the caller-provided rmask.
            effective_pmask = self.rmask

        def _to_mptype(out_info: Any) -> MPType:
            if isinstance(out_info, TensorType):
                return MPType.tensor(out_info.dtype, out_info.shape, effective_pmask)
            if isinstance(out_info, TableType):
                return MPType.table(out_info, effective_pmask)
            raise TypeError(f"Unsupported output type: {type(out_info)}")

        return [_to_mptype(info) for info in self.pfunc.outs_info]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_eval``."""
        return visitor.visit_eval(self)
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
class TupleExpr(Expr):
    """Pack several single-output expressions into one multi-output node.

    This is the "tuple construction" primitive of the MIMO expression system.
    Unlike the former FlattenExpr, every element must have exactly one output.
    The node's outputs are the elements' outputs in order. For example, if
    ``expr1`` yields ``[A]`` and ``expr2`` yields ``[B]``, then
    ``TupleExpr([expr1, expr2])`` yields ``[A, B]``.

    AccessExpr performs the opposite operation, extracting one element from a
    multi-output expression.
    """

    def __init__(self, args: list[Expr]):
        """Create a tuple node.

        Raises:
            ValueError: If any element does not have exactly one output.
        """
        super().__init__()
        for position, element in enumerate(args):
            arity = element.num_outputs
            if arity != 1:
                raise ValueError(
                    f"TupleExpr requires all arguments to be single-output expressions, "
                    f"but argument {position} has {arity} outputs"
                )
        self.args = args

    def _compute_mptypes(self) -> list[MPType]:
        # Arity was validated in __init__, so ``mptype`` is safe on every element.
        return [element.mptype for element in self.args]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_tuple``."""
        return visitor.visit_tuple(self)
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
class CondExpr(Expr):
    """Expression for conditional execution.

    Attributes:
        pred: Predicate expression.
        then_fn: Function definition for the "then" branch.
        else_fn: Function definition for the "else" branch; it must produce
            the same output types as ``then_fn``.
        args: Operand expressions of the conditional.
        verify_uniform: Whether the runtime should assert that the predicate
            is uniform across parties.
    """

    def __init__(
        self,
        pred: Expr,
        then_fn: FuncDefExpr,
        else_fn: FuncDefExpr,
        args: list[Expr],
        verify_uniform: bool = False,
    ):
        super().__init__()
        self.pred = pred
        self.then_fn = then_fn
        self.else_fn = else_fn
        self.args = args
        self.verify_uniform = verify_uniform

    def _compute_mptypes(self) -> list[MPType]:
        """Return the output types shared by both branches.

        Raises:
            TypeError: If the branches yield a different number of outputs, or
                if any pair of corresponding output types differs.
        """
        then_types = self.then_fn.mptypes
        else_types = self.else_fn.mptypes
        # Check the arity first. zip() on its own would silently drop the
        # trailing outputs of the longer branch and accept mismatched branches.
        if len(then_types) != len(else_types):
            raise TypeError(
                f"Then branch has {len(then_types)} outputs but else branch has "
                f"{len(else_types)} outputs"
            )
        for t_type, e_type in zip(then_types, else_types, strict=True):
            if t_type != e_type:
                raise TypeError(
                    f"Then branch type {t_type} does not match else branch type {e_type}"
                )
        return then_types

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_cond``."""
        return visitor.visit_cond(self)
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
class WhileExpr(Expr):
    """While-loop node.

    It is built from a condition function, a body function and the initial
    loop-carried values.
    """

    def __init__(
        self,
        cond_fn: FuncDefExpr,
        body_fn: FuncDefExpr,
        args: list[Expr],
    ):
        """Create a loop node.

        Raises:
            ValueError: If no initial value is supplied in ``args``.
        """
        super().__init__()
        if not args:
            raise ValueError("WhileExpr requires at least one argument (init value)")
        self.cond_fn, self.body_fn, self.args = cond_fn, body_fn, args

    def _compute_mptypes(self) -> list[MPType]:
        # The loop yields whatever the body yields. This supports multi-value
        # loop state (PyTree leaves) and tells evaluators how many values the
        # loop produces.
        return self.body_fn.mptypes

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_while``."""
        return visitor.visit_while(self)
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
class ConvExpr(Expr):
    """Converge several single-output variables held by disjoint party sets.

    The result has the common raw type (and the attrs of the first variable).
    Its pmask is the union of the input pmasks, or ``None`` when any input
    pmask is unknown at trace time.
    """

    def __init__(self, vars: list[Expr]):
        """Create a convergence node.

        Args:
            vars: Single-output expressions to converge (at least one).

        Raises:
            ValueError: If ``vars`` is empty or any variable is not single-output.
        """
        super().__init__()
        # Reject the empty case up front. Otherwise type computation would
        # later fail with an opaque IndexError.
        if not vars:
            raise ValueError("ConvExpr requires at least one variable.")
        for i, v in enumerate(vars):
            if v.num_outputs != 1:
                raise ValueError(
                    "All variables in ConvExpr must have the same arity "
                    f"(exactly one output); variable {i} has {v.num_outputs} outputs."
                )
        self.vars = vars

    def _compute_mptypes(self) -> list[MPType]:
        """Compute the single converged output type.

        Raises:
            TypeError: If the variables' raw types differ.
            ValueError: If any two known pmasks overlap.
        """
        types = [v.mptype for v in self.vars]

        # All variables must agree on dtype/shape (raw type).
        first = types[0]
        for other in types[1:]:
            if other.raw_type() != first.raw_type():
                raise TypeError(f"Inconsistent type in pconv: {other} vs {first}")

        pmasks = [t.pmask for t in types]
        known = [pm for pm in pmasks if pm is not None]
        dynamic_pmask = len(known) != len(pmasks)
        if dynamic_pmask:
            logging.warning("pconv called with None pmask.")

        # Known pmasks are pairwise disjoint exactly when each one is disjoint
        # from the union of those before it, so a single running union both
        # validates the masks and builds the output mask.
        union: Mask | None = None
        for pm in known:
            if union is not None and not union.is_disjoint(pm):
                raise ValueError(f"pconv called with non-disjoint pmasks: {pmasks}.")
            union = Mask(pm) if union is None else union.union(pm)

        # Any unknown input pmask makes the output pmask unknown as well.
        out_pmask = None if dynamic_pmask else union
        return [MPType(first.raw_type(), out_pmask, first.attrs)]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_conv``."""
        return visitor.visit_conv(self)
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
class ShflSExpr(Expr):
    """Expression for static shuffle operation.

    Redistributes data from source ranks to target ranks based on a specified
    mapping. Each party in the output mask (`pmask`) receives data from a
    corresponding source rank specified in `src_ranks`.

    Rationale for Design (Pull vs. Push Model):
    This operation uses a "pull" model: each receiving party explicitly states
    its data source (`src_ranks`). In a "push" model, each sending party would
    instead specify a destination.

    The pull model is chosen because it guarantees that every party in the
    output `pmask` receives exactly one value, which preserves the semantic
    integrity of the computation graph.

    A push model would be semantically ambiguous. For example, two source
    parties could send data to the same destination, or some parties might
    receive no data at all. That would break the Single Instruction, Multiple
    Programs (SIMP) paradigm by creating an unpredictable number of outputs at
    each party.

    The pull model can hurt performance when several receivers pull from the
    same source, since that source may become a network bottleneck. This is a
    performance consideration rather than a correctness issue; the design
    prioritizes semantic predictability and correctness.
    """

    def __init__(self, src_val: Expr, pmask: Mask, src_ranks: list[Rank]):
        """Initialize static shuffle expression.

        Args:
            src_val (Expr): The single-output value to be shuffled.
            pmask (Mask): The mask indicating which parties will hold the
                output. Only parties with non-zero bits in pmask receive output.
            src_ranks (list[Rank]): List of source ranks. The i-th output party
                (the i-th non-zero bit in pmask) receives data from
                src_ranks[i].

        Raises:
            ValueError: If src_val has multiple outputs, if the length of
                src_ranks doesn't match the pmask bit count, or if any rank in
                src_ranks is not present in a known src_val pmask.

        Example:
            If pmask indicates parties [0, 2] should receive output and
            src_ranks = [1, 3], then:
            - Party 0 receives data from rank 1
            - Party 2 receives data from rank 3
        """
        super().__init__()
        if src_val.num_outputs != 1:
            raise ValueError(
                f"ShflSExpr requires a single output source, got {src_val.num_outputs}"
            )

        self.src_val = src_val
        self.pmask = pmask
        self.src_ranks = src_ranks

        if len(src_ranks) != Mask(pmask).num_parties():
            raise ValueError(
                f"src_ranks length ({len(src_ranks)}) not match {pmask}"
            )

        # The source pmask does not change inside the loop, so resolve it once.
        # ``None`` means it is unknown at trace time; membership cannot be
        # checked then.
        src_pmask = src_val.mptype.pmask
        if src_pmask is not None:
            src_mask = Mask(src_pmask)
            for i, rank in enumerate(src_ranks):
                if rank not in src_mask:
                    raise ValueError(
                        f"Source rank {rank} at index {i} is not present in src {src_mask}"
                    )

    def _compute_mptypes(self) -> list[MPType]:
        # Same raw type and attrs as the source, relocated to the new pmask.
        # Use the public ``raw_type()`` accessor, as ConvExpr does, rather than
        # the private ``_type`` attribute.
        src_type = self.src_val.mptype
        return [MPType(src_type.raw_type(), self.pmask, src_type.attrs)]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_shfl_s``."""
        return visitor.visit_shfl_s(self)
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
class ShflExpr(Expr):
    """Expression for dynamic shuffle operation.

    The shuffle is driven by the ``index`` expression. The resulting pmask is
    treated as unknown at trace time.
    """

    def __init__(self, src: Expr, index: Expr):
        super().__init__()
        self.src = src
        self.index = index

    def _compute_mptypes(self) -> list[MPType]:
        # Where each value ends up is only decided at runtime. Every output
        # therefore keeps its source's raw type and attrs but gets an unknown
        # (None) pmask. Rebuilding from ``raw_type()``, as ConvExpr and
        # ShflSExpr do, also covers table-typed sources. The previous
        # ``MPType.tensor(dtype, shape, ...)`` path required tensor
        # dtype/shape.
        return [
            MPType(src_type.raw_type(), None, src_type.attrs)
            for src_type in self.src.mptypes
        ]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_shfl``."""
        return visitor.visit_shfl(self)
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
class AccessExpr(Expr):
    """Select a single output of a (possibly multi-output) expression.

    This is the "un-packing" or "selection" counterpart of TupleExpr in the
    MIMO system. It turns a source expression plus an index into a
    single-output node that stands for the chosen output. It is used to route
    individual outputs of multi-output functions, or of flattened streams, to
    consumers that expect a single input.
    """

    def __init__(self, src: Expr, index: int):
        super().__init__()
        self.src = src
        self.index = index

    def _compute_mptypes(self) -> list[MPType]:
        """Return a one-element list holding the selected output type.

        Raises:
            IndexError: If ``index`` is outside ``[0, src.num_outputs)``.
        """
        available = self.src.mptypes
        count = len(available)
        # Negative indices are rejected on purpose (no Python-style wrap-around).
        if not 0 <= self.index < count:
            raise IndexError(
                f"Index {self.index} out of range for expression with {count} outputs"
            )
        return [available[self.index]]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_access``."""
        return visitor.visit_access(self)
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
class VariableExpr(Expr):
    """Named variable reference carrying an explicitly declared type."""

    def __init__(self, name: str, mptype: MPType):
        super().__init__()
        self.name = name
        # Stored under a different name because ``mptype`` is a read-only
        # property inherited from Expr.
        self.mptype_value = mptype

    def _compute_mptypes(self) -> list[MPType]:
        # A variable has exactly one output: its declared type.
        return [self.mptype_value]

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_variable``."""
        return visitor.visit_variable(self)
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
class FuncDefExpr(Expr):
    """Expression representing a function definition with parameters and body.

    This class captures lambda abstraction. The body expression tree may
    contain free variables (VariableExpr nodes) that reference parameter
    names. When the function is called, arguments are bound to parameters
    positionally, which resolves those free variables.

    Example:
        Consider a function that adds two variables:
        ```
        # Body expression tree contains free variables "x" and "y"
        body = EvalExpr(
            add_pfunc, [VariableExpr("x", int_type), VariableExpr("y", int_type)]
        )

        # Parameters define the binding order - note "y" comes before "x"
        params = ["z", "y", "x"]  # extra parameter "z", different order

        func_def = FuncDefExpr(params, body)

        # When called with [expr0, expr1, expr2]:
        # - "z" binds to expr0 (unused in body, but valid)
        # - "y" binds to expr1 (resolves VariableExpr("y") in body)
        # - "x" binds to expr2 (resolves VariableExpr("x") in body)
        call = CallExpr("add", func_def, [expr0, expr1, expr2])
        ```

    Key insights:
    - Free variables in the body are placeholders waiting for concrete expressions.
    - Parameters act as a "binding contract": they define which argument maps
      to which variable.
    - Parameter order matters for positional binding, not alphabetical or usage order.
    - Parameters can include names not used in the body (dead parameters).
    - Every free variable in the body should have a corresponding parameter
      for the function to be well-formed.
    """

    def __init__(self, params: list[str], body: Expr):
        super().__init__()
        self.params = params
        self.body = body

    def _compute_mptypes(self) -> list[MPType]:
        # The types of a function are the types of its body.
        return self.body.mptypes

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_func_def``."""
        return visitor.visit_func_def(self)
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
class CallExpr(Expr):
    """Named invocation of a FuncDefExpr with positional argument expressions."""

    def __init__(self, name: str, fn: FuncDefExpr, args: list[Expr]):
        super().__init__()
        self.name, self.fn, self.args = name, fn, args

    def _compute_mptypes(self) -> list[MPType]:
        # Simplification: report the callee's declared (body) output types
        # without substituting argument types for parameters. A complete
        # implementation would need type-substitution logic.
        return self.fn.mptypes

    def accept(self, visitor: ExprVisitor) -> Any:
        """Dispatch to ``visitor.visit_call``."""
        return visitor.visit_call(self)
|