mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/libs/mpc/psi/rr22.py +303 -0
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang/v2/libs/mpc/psi/rr22.py +0 -344
- mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/mpobject.py
DELETED
|
@@ -1,117 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from abc import ABC, abstractmethod
|
|
18
|
-
from typing import TYPE_CHECKING, Any
|
|
19
|
-
|
|
20
|
-
from mplang.v1.core.dtypes import DType
|
|
21
|
-
from mplang.v1.core.mask import Mask
|
|
22
|
-
from mplang.v1.core.mptype import MPType
|
|
23
|
-
from mplang.v1.core.table import TableType
|
|
24
|
-
from mplang.v1.core.tensor import Shape
|
|
25
|
-
|
|
26
|
-
if TYPE_CHECKING:
|
|
27
|
-
from mplang.v1.core.cluster import ClusterSpec
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class MPContext:
    """Execution context that an MPObject is bound to.

    MPContext is the base class for all execution contexts. It stores only
    the cluster specification plus an optional parent link. The parent/root
    helpers let extension state be attached lazily to the root context and
    shared by short-lived child contexts.
    """

    def __init__(self, cluster_spec: ClusterSpec, *, parent: MPContext | None = None):
        """Create a context.

        Args:
            cluster_spec: Cluster specification; must not be None.
            parent: Optional parent context (keyword-only).

        Raises:
            ValueError: If ``cluster_spec`` is None.
        """
        if cluster_spec is None:
            raise ValueError("cluster_spec cannot be None")
        self.cluster_spec = cluster_spec
        # Ephemeral child contexts (e.g. short-lived tracing) delegate to a
        # stable root through this link instead of relying on process-wide
        # globals.
        self._parent: MPContext | None = parent

    # Basic topology helpers
    def world_size(self) -> int:
        """Return the number of nodes listed in ``cluster_spec.nodes``."""
        nodes = self.cluster_spec.nodes
        return len(nodes)

    @property
    def parent(self) -> MPContext | None:
        """The immediate parent context, or None for a root context."""
        return self._parent

    def root(self) -> MPContext:
        """Walk the parent chain upward and return the topmost context.

        Raises:
            RuntimeError: If the parent chain loops back on itself.
        """
        node: MPContext = self
        seen_ids: set[int] = set()
        while (parent := node._parent) is not None:
            node_id = id(node)
            if node_id in seen_ids:
                raise RuntimeError("Cycle detected in MPContext parent chain")
            seen_ids.add(node_id)
            node = parent
        return node
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
class MPObject(ABC):
    """Abstract base class for every object in the mp-system.

    Subclasses must provide :attr:`mptype` and :attr:`ctx`. The other
    properties are convenience views that read the corresponding attribute
    of ``self.mptype``.
    """

    @property
    @abstractmethod
    def mptype(self) -> MPType:
        """The MPType describing this object.

        Read-only and mandatory. It is used during JAX trace/compilation to
        pick the appropriate data type. MPType values can be passed between
        different MPObjects.
        """

    @property
    def dtype(self) -> DType:
        """Data type, read from ``self.mptype.dtype``."""
        type_info = self.mptype
        return type_info.dtype

    @property
    def shape(self) -> Shape:
        """Shape, read from ``self.mptype.shape``."""
        type_info = self.mptype
        return type_info.shape

    @property
    def schema(self) -> TableType:
        """Table schema, read from ``self.mptype.schema``.

        Only meaningful for table-typed objects.
        """
        type_info = self.mptype
        return type_info.schema

    @property
    def pmask(self) -> Mask | None:
        """Party mask, read from ``self.mptype.pmask``."""
        type_info = self.mptype
        return type_info.pmask

    @property
    def attrs(self) -> dict[str, Any]:
        """Attribute dict, read from ``self.mptype.attrs``."""
        type_info = self.mptype
        return type_info.attrs

    @property
    @abstractmethod
    def ctx(self) -> MPContext:
        """The MPContext this object belongs to."""
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
# Reuse MPType's property docstrings on MPObject's forwarding properties so
# help()/IDE tooltips show the same documentation for both.
for _forwarded_prop in ("dtype", "shape", "schema", "pmask", "attrs"):
    getattr(MPObject, _forwarded_prop).__doc__ = getattr(
        MPType, _forwarded_prop
    ).__doc__
del _forwarded_prop
|
mplang/v1/core/mptype.py
DELETED
|
@@ -1,407 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import copy
|
|
18
|
-
from typing import TYPE_CHECKING, Any
|
|
19
|
-
|
|
20
|
-
import numpy as np
|
|
21
|
-
|
|
22
|
-
if TYPE_CHECKING:
|
|
23
|
-
from mplang.v1.core.mpobject import MPObject
|
|
24
|
-
|
|
25
|
-
from mplang.v1.core.dtypes import STRING, DType
|
|
26
|
-
from mplang.v1.core.mask import Mask
|
|
27
|
-
from mplang.v1.core.table import TableLike, TableType
|
|
28
|
-
from mplang.v1.core.tensor import ScalarType, Shape, TensorLike, TensorType
|
|
29
|
-
|
|
30
|
-
# Basic type aliases.
# NOTE(review): presumably the integer index of a party in the cluster —
# confirm against call sites.
Rank = int
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
class MPType:
    """A type that describes the type information of an MPObject.

    An MPType pairs a raw type (``TensorType`` or ``TableType``) with an
    optional party mask and a dict of free-form attributes.

    Equality and hashing both consider only the raw type and the party mask.
    Attributes are deliberately ignored by ``__eq__``, so they are ignored by
    ``__hash__`` as well to keep the two consistent.
    """

    # Raw type information: TensorType for tensors, TableType for tables.
    _type: TensorType | TableType
    # Party mask; None means ownership is unknown until runtime.
    _pmask: Mask | None
    # Free-form attributes (an owned shallow copy of the caller's dict).
    _attrs: dict[str, Any]

    def __init__(
        self,
        type_info: TensorType | TableType,
        pmask: Mask | None = None,
        attrs: dict[str, Any] | None = None,
    ):
        """Initialize MPType.

        Args:
            type_info: The type information (TensorType for tensors, TableType for tables).
            pmask: The party mask, used at compile/trace time to determine which party holds the object.
            attrs: Key-value pairs storing additional information about the object.
                A shallow copy is stored, so later changes to the caller's dict
                do not affect this type.
        """
        self._type = type_info
        self._pmask = pmask
        # Shallow-copy so the caller's dict is never aliased.
        self._attrs = copy.copy(attrs) if attrs is not None else {}

    @classmethod
    def tensor(
        cls,
        dtype: DType | Any,
        shape: Shape,
        pmask: int | Mask | None = None,
        **attrs: Any,
    ) -> MPType:
        """Create a tensor type.

        Args:
            dtype: The data type of the tensor; non-DType values are converted
                with ``DType.from_any``.
            shape: The shape of the tensor.
            pmask: The party mask; a plain int is converted with ``Mask.from_int``.
            **attrs: Additional attributes.

        Returns:
            MPType instance for tensor.

        Raises:
            ValueError: If dtype is table-only.
        """
        # Convert dtype to DType if needed and validate
        if not isinstance(dtype, DType):
            dtype = DType.from_any(dtype)

        # Ensure tensor types don't use table-only dtypes
        if dtype.is_table_only:
            raise ValueError(
                f"Data type '{dtype.name}' is only supported in tables, "
                f"not in tensors. Use table types for string, date, and other "
                f"non-numeric data types."
            )

        if isinstance(pmask, int):
            pmask = Mask.from_int(pmask)

        tensor_info = TensorType(dtype, shape)
        return cls(tensor_info, pmask, attrs)

    @classmethod
    def table(
        cls,
        schema: TableType | dict[str, DType],
        pmask: int | Mask | None = None,
        **attrs: Any,
    ) -> MPType:
        """Create a table type.

        Args:
            schema: The table schema, or a dict mapping column names to types
                (converted with ``TableType.from_dict``).
            pmask: The party mask; a plain int is converted with ``Mask.from_int``.
            **attrs: Additional attributes.

        Returns:
            MPType instance for table.
        """
        if isinstance(schema, dict):
            schema = TableType.from_dict(schema)

        if isinstance(pmask, int):
            pmask = Mask.from_int(pmask)

        return cls(schema, pmask, attrs)

    @property
    def is_tensor(self) -> bool:
        """Check if this is a tensor type."""
        return isinstance(self._type, TensorType)

    @property
    def is_table(self) -> bool:
        """Check if this is a table type."""
        return isinstance(self._type, TableType)

    @property
    def dtype(self) -> DType:
        """The data type of the object.

        This property is readonly (mandatory) and will be used for JAX compilation
        to determine the appropriate data type during trace and compilation phases.

        Only available for tensor types.
        """
        if not isinstance(self._type, TensorType):
            raise AttributeError("dtype is only available for tensor types")
        return self._type.dtype

    @property
    def shape(self) -> Shape:
        """The shape of the object, represented as a tuple of integers.

        For example, a 2D tensor with shape (3, 4) would be represented as (3, 4).
        The shape can be empty, which indicates a scalar.

        This property is readonly (mandatory) and will be used for JAX compilation
        to determine tensor shapes during trace and compilation phases.

        Only available for tensor types.
        """
        if not isinstance(self._type, TensorType):
            raise AttributeError("shape is only available for tensor types")
        return self._type.shape

    @property
    def schema(self) -> TableType:
        """The table schema.

        Only available for table types.
        """
        if not isinstance(self._type, TableType):
            raise AttributeError("schema is only available for table types")
        return self._type

    @property
    def pmask(self) -> Mask | None:
        """The party mask indicating which parties hold the data.

        Value interpretation:
        - When not None: A bitmask where the i'th bit is 1 if the i'th party holds
          the data, and 0 otherwise. For example, 0b1101 means parties 0, 2, and 3
          hold the data, while party 1 does not.
        - When None: Party ownership is unknown at compile/trace time and will be
          completely determined at runtime.

        Semantic meaning:
        This mask can be either manually set or deduced by primitive functions during
        compilation/tracing. When None, it does NOT imply either a full mask (all
        parties) or zero mask (no parties) - the actual ownership pattern is entirely
        runtime-dependent.
        """
        return self._pmask

    @property
    def attrs(self) -> dict[str, Any]:
        """Attributes are key-value pairs that can be used to store additional
        information about the object."""
        return self._attrs

    def raw_type(self) -> TensorType | TableType:
        """Get the raw type information (TensorType or TableType)."""
        return self._type

    def set_attr(self, key: str, value: Any) -> None:
        """Set an attribute for this type (mutates this instance in place)."""
        self._attrs[key] = value

    def get_attr(self, key: str, default: Any = None) -> Any:
        """Get an attribute for this type, or ``default`` if it is absent."""
        return self._attrs.get(key, default)

    def __repr__(self) -> str:
        """String representation of MPType.

        Schema:
            - For tensor: dtype[shape]<pmask>{other_attrs}
            - For table: Tbl(col1:type1, col2:type2)<pmask>{other_attrs}

        Examples:
            - u64                             # scalar uint64
            - f32[3, 2]                       # 3x2 float32 tensor
            - f16[3]<3>                       # float16 vector with pmask=3
            - u32[5, 5]<F>{device="P0"}       # uint32 matrix with pmask=15 and device attr
            - Tbl(id:i64, name:str)           # table with id and name columns
        """
        if isinstance(self._type, TensorType):
            # Start with short dtype name
            ret = self._type.dtype.short_name()

            # Add shape if not scalar
            if self._type.shape:
                shape_str = ", ".join(str(d) for d in self._type.shape)
                ret += f"[{shape_str}]"
        else:  # TableType
            cols = ", ".join(
                f"{name}:{dtype.short_name()}" for name, dtype in self._type.columns
            )
            ret = f"Tbl({cols})"

        # Add pmask in angle brackets (hex) if present
        if self._pmask is not None:
            ret += f"<{self._pmask:X}>"

        # Add other attributes in curly braces if any; strings are quoted
        if self._attrs:
            attrs_list = []
            for key, value in self._attrs.items():
                if isinstance(value, str):
                    attrs_list.append(f'{key}="{value}"')
                else:
                    attrs_list.append(f"{key}={value}")
            ret += "{" + ", ".join(attrs_list) + "}"

        return ret

    def __eq__(self, other: object) -> bool:
        """Check if two MPType objects are equal (raw type and pmask only)."""
        if not isinstance(other, MPType):
            return False
        return (
            self._type == other._type and self._pmask == other._pmask
            # and self._attrs == other._attrs  # TODO(jint): attrs should be optional
        )

    def __hash__(self) -> int:
        """Compute a hash consistent with ``__eq__``.

        ``__eq__`` compares only the raw type and the party mask. Objects that
        compare equal must hash equal, so attrs must not contribute here
        either. Leaving attrs out also keeps MPType hashable when an
        attribute value is itself unhashable (e.g. a list).
        """
        return hash((self._type, self._pmask))

    def isInstance(self, obj: MPObject) -> bool:
        """Check if the given object is an instance of this MPType.

        ``obj`` matches when it is an MPObject whose raw type has the same
        class as, and compares equal to, this type's raw type, and whose attrs
        contain every key/value pair in this type's attrs. The party mask is
        not compared.
        """
        # Import here to avoid circular import
        from mplang.v1.core.mpobject import MPObject

        if not isinstance(obj, MPObject):
            return False

        # Check if the object's type matches this type
        obj_type = obj.mptype
        if type(self._type) is not type(obj_type._type):
            return False

        if self._type != obj_type._type:
            return False

        # Every attribute required by this type must be present and equal.
        if self._attrs:
            if not isinstance(obj.attrs, dict):
                return False
            for k, v in self._attrs.items():
                if k not in obj.attrs or obj.attrs[k] != v:
                    return False
        return True

    def to_numpy(self) -> np.dtype:
        """Convert to NumPy dtype for compatibility.

        Only available for tensor types.
        """
        if not isinstance(self._type, TensorType):
            raise AttributeError("to_numpy is only available for tensor types")
        return self._type.to_numpy()

    @staticmethod
    def _create_tensor_info(obj: TensorLike | ScalarType) -> TensorType:
        """Helper method to create TensorType from tensor-like objects.

        Scalars become rank-0 tensors; objects exposing ``dtype``/``shape``
        (TensorLike) are used directly; lists/tuples go through ``np.array``.

        Raises:
            TypeError: If ``obj`` is none of the supported kinds.
        """
        if isinstance(obj, ScalarType):
            return TensorType(DType.from_python_type(type(obj)), ())
        elif isinstance(obj, TensorLike):
            return TensorType(DType.from_any(obj.dtype), obj.shape)
        elif isinstance(obj, list | tuple):
            # Convert lists/tuples to numpy arrays for compatibility
            arr = np.array(obj)
            return TensorType(DType.from_any(arr.dtype), arr.shape)
        else:
            raise TypeError(f"Unsupported type: {type(obj)}.")

    @classmethod
    def from_tensor(
        cls,
        obj: TensorLike | ScalarType,
        pmask: Mask | None = None,
        **kwargs: Any,
    ) -> MPType:
        """Create MPType from tensor-like object.

        Args:
            obj: Tensor-like object or scalar.
            pmask: The party mask.
            **kwargs: Additional attributes.

        Returns:
            MPType instance for tensor.
        """
        # ``kwargs`` is already a fresh dict and __init__ copies it again.
        tensor_info = cls._create_tensor_info(obj)
        return cls(tensor_info, pmask, kwargs)

    @classmethod
    def from_mpobj(cls, obj: MPObject) -> MPType:
        """Create MPType from MPObject.

        Args:
            obj: MPObject instance.

        Returns:
            MPType instance with same raw type, pmask and (copied) attrs as the object.
        """
        # assume obj is MPObject-like
        obj_type = obj.mptype
        return cls(obj_type._type, obj.pmask, copy.copy(obj.attrs))

    @classmethod
    def from_obj(cls, obj: Any, pmask: Mask | None = None, **attrs: Any) -> MPType:
        """Create MPType from any object, automatically inferring the type.

        Table-like objects (per the TableLike protocol) are supported only
        when they are pandas DataFrames. Every other object is treated as
        tensor-like.

        Args:
            obj: Object to create type from.
            pmask: The party mask.
            **attrs: Additional attributes.

        Returns:
            MPType instance.

        Raises:
            TypeError: If object type cannot be inferred.
            NotImplementedError: For table objects that are not pandas
                DataFrames, or when pandas is not installed.
        """
        # Check if it's a table-like object using the TableLike protocol
        if isinstance(obj, TableLike):
            # Only the optional pandas import is guarded. Errors raised while
            # building the schema must surface instead of being masked as
            # NotImplementedError.
            try:
                import pandas as pd
            except ImportError:
                pass
            else:
                if isinstance(obj, pd.DataFrame):
                    schema_dict: dict[str, DType] = {}
                    for col_name in obj.columns:
                        pandas_dtype = obj[col_name].dtype
                        # Object-kind columns are mapped to STRING; every
                        # other kind (including numpy "U"/"S") goes through
                        # DType.from_numpy.
                        if pandas_dtype.kind == "O":
                            schema_dict[col_name] = STRING
                        else:
                            schema_dict[col_name] = DType.from_numpy(pandas_dtype)
                    schema = TableType.from_dict(schema_dict)
                    return cls(schema, pmask, attrs)
            # For other table-like objects without pandas
            raise NotImplementedError(
                "Table object detection for non-pandas objects not fully implemented yet"
            )

        # Otherwise treat as tensor-like
        return cls.from_tensor(obj, pmask, **attrs)
|
mplang/v1/core/pfunc.py
DELETED
|
@@ -1,130 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import copy
|
|
18
|
-
from collections.abc import Sequence
|
|
19
|
-
from types import MappingProxyType
|
|
20
|
-
from typing import Any
|
|
21
|
-
|
|
22
|
-
from mplang.v1.core.table import TableType
|
|
23
|
-
from mplang.v1.core.tensor import TensorType
|
|
24
|
-
|
|
25
|
-
# Public names exported by ``from ... import *``.
__all__ = ["PFunction", "get_fn_name"]
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
class PFunction:
    """A Party Function represents a computation unit that can be executed by a single party.

    PFunction serves as a unified interface for describing single-party computations
    in multi-party computing scenarios. It can represent both:
    1. Built-in operations (e.g., "spu.makeshares", "basic.read")
    2. User-defined programmable functions with custom code

    The PFunction accepts a list of typed inputs (TensorType/TableType). For
    backend-only handles (e.g., crypto keys), use a sentinel TensorType
    of UINT8 with shape (-1, 0) to indicate the argument should bypass
    structural validation at runtime. Outputs should likewise use concrete
    TensorType/TableType specs. PFunction can be:
    - Expressed and defined in the mplang frontend
    - Serialized for transmission between components
    - Interpreted and executed by backend runtime engines

    Args:
        fn_type: The type/category identifier of this PFunction, indicating which
            backend or handler should process it (e.g., "spu.makeshares", "basic.read",
            "mlir.stablehlo"). This serves as a routing mechanism for execution.
        ins_info: Type information for input parameters (TensorType or TableType)
        outs_info: Type information for output values (TensorType or TableType)
        fn_name: Optional name of the function. For programmable functions, this is
            the user-defined function name. For built-in operations, this may be
            None or a descriptive identifier.
        fn_text: Optional serialized function body. For programmable functions, this
            contains the actual code (e.g., MLIR, bytecode, source code). For built-in
            operations, this is typically None.
        **kwargs: Additional attributes and metadata specific to the function type.
            These are used to pass execution parameters, configuration, and context
            information to the backend handlers. Values need not be hashable.
    """

    # Required fields - these define the core execution context
    fn_type: str  # Unique identifier for backend routing
    ins_info: tuple[TensorType | TableType, ...]
    outs_info: tuple[TensorType | TableType, ...]

    # Optional fields for programmable functions
    fn_name: str | None  # Function name (for programmable functions)
    fn_text: str | None  # Function body/code (for programmable functions)

    # Custom attributes and metadata (read-only view over a private copy)
    attrs: MappingProxyType[str, Any]  # Execution parameters and metadata

    def __init__(
        self,
        fn_type: str,
        ins_info: Sequence[TensorType | TableType],
        outs_info: Sequence[TensorType | TableType],
        *,
        fn_name: str | None = None,
        fn_text: str | None = None,
        **kwargs: Any,
    ):
        self.fn_type = fn_type
        self.fn_name = fn_name
        self.fn_text = fn_text
        # Normalize sequences to tuples so they are immutable and hashable.
        self.ins_info = tuple(ins_info)
        self.outs_info = tuple(outs_info)
        # Make attrs immutable to ensure PFunction immutability
        # Create a copy first, then wrap it in MappingProxyType
        self.attrs = MappingProxyType(copy.copy(kwargs))

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.fn_type}, {self.fn_name})"

    def __hash__(self) -> int:
        """Hash consistent with ``__eq__`` that tolerates unhashable attr values.

        Only attribute *names* are hashed, not their values. Values may
        legitimately be unhashable (lists, dicts, bytearrays, ...), which made
        ``frozenset(attrs.items())`` raise TypeError. Hashing the names alone
        still satisfies the hash/eq contract, because equal attrs mappings
        always have equal key sets.
        """
        return hash((
            self.fn_type,
            self.fn_name,
            self.fn_text,
            self.ins_info,
            self.outs_info,
            frozenset(self.attrs),
        ))

    def __eq__(self, other: object) -> bool:
        """Check equality between PFunction instances (all fields, including attrs)."""
        if not isinstance(other, PFunction):
            return False

        return (
            self.fn_type == other.fn_type
            and self.fn_name == other.fn_name
            and self.fn_text == other.fn_text
            and self.ins_info == other.ins_info
            and self.outs_info == other.outs_info
            and self.attrs == other.attrs
        )
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
def get_fn_name(fn_like: Any) -> str:
    """Return a human-readable name for a callable-like object.

    Uses ``fn_like.__name__`` when present. Otherwise, if the object has a
    ``func`` attribute (as ``functools.partial`` does), resolves the name of
    that wrapped object recursively. Falls back to ``"unnamed function"``.
    """
    missing = object()

    name = getattr(fn_like, "__name__", missing)
    if name is not missing:
        return name  # type: ignore[no-any-return]

    # handle partial functions
    wrapped = getattr(fn_like, "func", missing)
    if wrapped is missing:
        return "unnamed function"
    return get_fn_name(wrapped)
|