mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/libs/mpc/psi/rr22.py +303 -0
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang/v2/libs/mpc/psi/rr22.py +0 -344
  162. mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
  163. /mplang/{v2/backends → backends}/channel.py +0 -0
  164. /mplang/{v2/edsl → edsl}/README.md +0 -0
  165. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  166. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  167. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  168. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  169. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  171. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  172. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  175. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  177. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  178. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  179. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  180. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  181. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
@@ -1,117 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from abc import ABC, abstractmethod
18
- from typing import TYPE_CHECKING, Any
19
-
20
- from mplang.v1.core.dtypes import DType
21
- from mplang.v1.core.mask import Mask
22
- from mplang.v1.core.mptype import MPType
23
- from mplang.v1.core.table import TableType
24
- from mplang.v1.core.tensor import Shape
25
-
26
- if TYPE_CHECKING:
27
- from mplang.v1.core.cluster import ClusterSpec
28
-
29
-
30
- class MPContext:
31
- """The context of an MPObject.
32
-
33
- MPContext is the abstract base class for all execution contexts.
34
- It only holds the immutable cluster_spec plus lightweight parent/root
35
- helpers used to support stack-scoped extension state (attached lazily by
36
- external features on the root context).
37
- """
38
-
39
- def __init__(self, cluster_spec: ClusterSpec, *, parent: MPContext | None = None):
40
- if cluster_spec is None:
41
- raise ValueError("cluster_spec cannot be None")
42
- self.cluster_spec = cluster_spec
43
- # Parent link enables stack-scoped state sharing: ephemeral child contexts
44
- # (e.g. short-lived tracing) can delegate to a stable root without relying
45
- # on process-wide globals.
46
- self._parent: MPContext | None = parent
47
-
48
- # Basic topology helpers
49
- def world_size(self) -> int:
50
- return len(self.cluster_spec.nodes)
51
-
52
- @property
53
- def parent(self) -> MPContext | None:
54
- """Direct parent context or None if this is root."""
55
- return self._parent
56
-
57
- def root(self) -> MPContext:
58
- """Return the root context (follow parent chain)."""
59
- ctx: MPContext = self
60
- visited: set[int] = set()
61
- while ctx._parent is not None:
62
- if id(ctx) in visited:
63
- raise RuntimeError("Cycle detected in MPContext parent chain")
64
- visited.add(id(ctx))
65
- ctx = ctx._parent
66
- return ctx
67
-
68
-
69
- class MPObject(ABC):
70
- """The base class for all objects in mp-system."""
71
-
72
- @property
73
- @abstractmethod
74
- def mptype(self) -> MPType:
75
- """The type information of the object.
76
-
77
- This property is readonly (mandatory) and will be used for JAX compilation
78
- to determine the appropriate data type during trace and compilation phases.
79
- MPType can be passed between different MPObjects as a value.
80
- """
81
-
82
- @property
83
- def dtype(self) -> DType:
84
- return self.mptype.dtype
85
-
86
- @property
87
- def shape(self) -> Shape:
88
- return self.mptype.shape
89
-
90
- @property
91
- def schema(self) -> TableType:
92
- """The table schema of the object.
93
-
94
- Only available for table types.
95
- """
96
- return self.mptype.schema
97
-
98
- @property
99
- def pmask(self) -> Mask | None:
100
- return self.mptype.pmask
101
-
102
- @property
103
- def attrs(self) -> dict[str, Any]:
104
- return self.mptype.attrs
105
-
106
- @property
107
- @abstractmethod
108
- def ctx(self) -> MPContext:
109
- """Return the context of the object."""
110
-
111
-
112
- # Forward docstrings from MPType to MPObject
113
- MPObject.dtype.__doc__ = MPType.dtype.__doc__
114
- MPObject.shape.__doc__ = MPType.shape.__doc__
115
- MPObject.schema.__doc__ = MPType.schema.__doc__
116
- MPObject.pmask.__doc__ = MPType.pmask.__doc__
117
- MPObject.attrs.__doc__ = MPType.attrs.__doc__
mplang/v1/core/mptype.py DELETED
@@ -1,407 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import copy
18
- from typing import TYPE_CHECKING, Any
19
-
20
- import numpy as np
21
-
22
- if TYPE_CHECKING:
23
- from mplang.v1.core.mpobject import MPObject
24
-
25
- from mplang.v1.core.dtypes import STRING, DType
26
- from mplang.v1.core.mask import Mask
27
- from mplang.v1.core.table import TableLike, TableType
28
- from mplang.v1.core.tensor import ScalarType, Shape, TensorLike, TensorType
29
-
30
- # basic type aliases
31
- Rank = int
32
-
33
-
34
- class MPType:
35
- """A type that describes the type information of an MPObject."""
36
-
37
- _type: TensorType | TableType
38
- _pmask: Mask | None
39
- _attrs: dict[str, Any]
40
-
41
- def __init__(
42
- self,
43
- type_info: TensorType | TableType,
44
- pmask: Mask | None = None,
45
- attrs: dict[str, Any] | None = None,
46
- ):
47
- """Initialize MPType.
48
-
49
- Args:
50
- type_info: The type information (TensorType for tensors, TableType for tables).
51
- pmask: The party mask, used for compile/trace time determine which party holds the object.
52
- attrs: Attributes are key-value pairs that can be used to store additional information about the object.
53
- """
54
- self._type = type_info
55
- self._pmask = pmask
56
- # Ensure attrs is a copy
57
- self._attrs = copy.copy(attrs) if attrs is not None else {}
58
-
59
- @classmethod
60
- def tensor(
61
- cls,
62
- dtype: DType | Any,
63
- shape: Shape,
64
- pmask: int | Mask | None = None,
65
- **attrs: Any,
66
- ) -> MPType:
67
- """Create a tensor type.
68
-
69
- Args:
70
- dtype: The data type of the tensor.
71
- shape: The shape of the tensor.
72
- pmask: The party mask.
73
- **attrs: Additional attributes.
74
-
75
- Returns:
76
- MPType instance for tensor.
77
-
78
- Raises:
79
- ValueError: If dtype is table-only.
80
- """
81
- # Convert dtype to DType if needed and validate
82
- if not isinstance(dtype, DType):
83
- dtype = DType.from_any(dtype)
84
-
85
- # Ensure tensor types don't use table-only dtypes
86
- if dtype.is_table_only:
87
- raise ValueError(
88
- f"Data type '{dtype.name}' is only supported in tables, "
89
- f"not in tensors. Use table types for string, date, and other "
90
- f"non-numeric data types."
91
- )
92
-
93
- if isinstance(pmask, int):
94
- pmask = Mask.from_int(pmask)
95
-
96
- tensor_info = TensorType(dtype, shape)
97
- return cls(tensor_info, pmask, attrs)
98
-
99
- @classmethod
100
- def table(
101
- cls,
102
- schema: TableType | dict[str, DType],
103
- pmask: int | Mask | None = None,
104
- **attrs: Any,
105
- ) -> MPType:
106
- """Create a table type.
107
-
108
- Args:
109
- schema: The table schema or dict mapping column names to types.
110
- pmask: The party mask.
111
- **attrs: Additional attributes.
112
-
113
- Returns:
114
- MPType instance for table.
115
- """
116
- if isinstance(schema, dict):
117
- schema = TableType.from_dict(schema)
118
-
119
- if isinstance(pmask, int):
120
- pmask = Mask.from_int(pmask)
121
-
122
- return cls(schema, pmask, attrs)
123
-
124
- @property
125
- def is_tensor(self) -> bool:
126
- """Check if this is a tensor type."""
127
- return isinstance(self._type, TensorType)
128
-
129
- @property
130
- def is_table(self) -> bool:
131
- """Check if this is a table type."""
132
- return isinstance(self._type, TableType)
133
-
134
- @property
135
- def dtype(self) -> DType:
136
- """The data type of the object.
137
-
138
- This property is readonly (mandatory) and will be used for JAX compilation
139
- to determine the appropriate data type during trace and compilation phases.
140
-
141
- Only available for tensor types.
142
- """
143
- if not isinstance(self._type, TensorType):
144
- raise AttributeError("dtype is only available for tensor types")
145
- return self._type.dtype
146
-
147
- @property
148
- def shape(self) -> Shape:
149
- """The shape of the object, represented as a tuple of integers.
150
-
151
- For example, a 2D tensor with shape (3, 4) would be represented as (3, 4).
152
- The shape can be empty, which indicates a scalar.
153
-
154
- This property is readonly (mandatory) and will be used for JAX compilation
155
- to determine tensor shapes during trace and compilation phases.
156
-
157
- Only available for tensor types.
158
- """
159
- if not isinstance(self._type, TensorType):
160
- raise AttributeError("shape is only available for tensor types")
161
- return self._type.shape
162
-
163
- @property
164
- def schema(self) -> TableType:
165
- """The table schema.
166
-
167
- Only available for table types.
168
- """
169
- if not isinstance(self._type, TableType):
170
- raise AttributeError("schema is only available for table types")
171
- return self._type
172
-
173
- @property
174
- def pmask(self) -> Mask | None:
175
- """The party mask indicating which parties hold the data.
176
-
177
- Value interpretation:
178
- - When not None: A bitmask where the i'th bit is 1 if the i'th party holds
179
- the data, and 0 otherwise. For example, 0b1101 means parties 0, 2, and 3
180
- hold the data, while party 1 does not.
181
- - When None: Party ownership is unknown at compile/trace time and will be
182
- completely determined at runtime.
183
-
184
- Semantic meaning:
185
- This mask can be either manually set or deduced by primitive functions during
186
- compilation/tracing. When None, it does NOT imply either a full mask (all
187
- parties) or zero mask (no parties) - the actual ownership pattern is entirely
188
- runtime-dependent.
189
- """
190
- return self._pmask
191
-
192
- @property
193
- def attrs(self) -> dict[str, Any]:
194
- """Attributes are key-value pairs that can be used to store additional
195
- information about the object."""
196
- return self._attrs
197
-
198
- def raw_type(self) -> TensorType | TableType:
199
- """Get the raw type information (TensorType or TableType)."""
200
- return self._type
201
-
202
- def set_attr(self, key: str, value: Any) -> None:
203
- """Set an attribute for this type."""
204
- self._attrs[key] = value
205
-
206
- def get_attr(self, key: str, default: Any = None) -> Any:
207
- """Get an attribute for this type."""
208
- return self._attrs.get(key, default)
209
-
210
- def __repr__(self) -> str:
211
- """String representation of MPType.
212
-
213
- Schema:
214
- - For tensor: dtype[shape]<pmask>{other_attrs}
215
- - For table: Tbl(col1:type1, col2:type2)<pmask>{other_attrs}
216
-
217
- Examples:
218
- - u64 # scalar uint64
219
- - f32[3, 2] # 3x2 float32 tensor
220
- - f16[3]<3> # float16 vector with pmask=3
221
- - u32[5, 5]<F>{device="P0"} # uint32 matrix with pmask=15 and device attr
222
- - Tbl(id:i64, name:str) # table with id and name columns
223
- """
224
- if isinstance(self._type, TensorType):
225
- # Start with short dtype name
226
- ret = self._type.dtype.short_name()
227
-
228
- # Add shape if not scalar
229
- if self._type.shape:
230
- shape_str = ", ".join(str(d) for d in self._type.shape)
231
- ret += f"[{shape_str}]"
232
- else: # TableType
233
- cols = ", ".join(
234
- f"{name}:{dtype.short_name()}" for name, dtype in self._type.columns
235
- )
236
- ret = f"Tbl({cols})"
237
-
238
- # Add pmask in angle brackets if present
239
- if self._pmask is not None:
240
- ret += f"<{self._pmask:X}>"
241
-
242
- # Add other attributes in curly braces if any
243
- if self._attrs:
244
- attrs_list = []
245
- for key, value in self._attrs.items():
246
- if isinstance(value, str):
247
- attrs_list.append(f'{key}="{value}"')
248
- else:
249
- attrs_list.append(f"{key}={value}")
250
- ret += "{" + ", ".join(attrs_list) + "}"
251
-
252
- return ret
253
-
254
- def __eq__(self, other: object) -> bool:
255
- """Check if two MPType objects are equal."""
256
- if not isinstance(other, MPType):
257
- return False
258
- return (
259
- self._type == other._type and self._pmask == other._pmask
260
- # and self._attrs == other._attrs # TODO(jint): attrs should be optional
261
- )
262
-
263
- def __hash__(self) -> int:
264
- """Compute hash for MPType objects."""
265
- # Make attrs hashable by converting to frozenset of items
266
- attrs_hash = hash(frozenset(self._attrs.items())) if self._attrs else 0
267
- return hash((
268
- self._type,
269
- self._pmask,
270
- attrs_hash,
271
- ))
272
-
273
- def isInstance(self, obj: MPObject) -> bool:
274
- """Check if the given object is an instance of this MPType."""
275
- # Import here to avoid circular import
276
- from mplang.v1.core.mpobject import MPObject
277
-
278
- if not isinstance(obj, MPObject):
279
- return False
280
-
281
- # Check if the object's type matches this type
282
- obj_type = obj.mptype
283
- if type(self._type) is not type(obj_type._type):
284
- return False
285
-
286
- if self._type != obj_type._type:
287
- return False
288
-
289
- # Check attributes
290
- if self._attrs:
291
- if not isinstance(obj.attrs, dict):
292
- return False
293
- for k, v in self._attrs.items():
294
- if k not in obj.attrs or obj.attrs[k] != v:
295
- return False
296
- return True
297
-
298
- def to_numpy(self) -> np.dtype:
299
- """Convert to NumPy dtype for compatibility.
300
-
301
- Only available for tensor types.
302
- """
303
- if not isinstance(self._type, TensorType):
304
- raise AttributeError("to_numpy is only available for tensor types")
305
- return self._type.to_numpy()
306
-
307
- @staticmethod
308
- def _create_tensor_info(obj: TensorLike | ScalarType) -> TensorType:
309
- """Helper method to create TensorType from tensor-like objects."""
310
- if isinstance(obj, ScalarType):
311
- return TensorType(DType.from_python_type(type(obj)), ())
312
- elif isinstance(obj, TensorLike):
313
- return TensorType(DType.from_any(obj.dtype), obj.shape)
314
- elif isinstance(obj, list | tuple):
315
- # Convert lists/tuples to numpy arrays for compatibility
316
- arr = np.array(obj)
317
- return TensorType(DType.from_any(arr.dtype), arr.shape)
318
- else:
319
- raise TypeError(f"Unsupported type: {type(obj)}.")
320
-
321
- @classmethod
322
- def from_tensor(
323
- cls,
324
- obj: TensorLike | ScalarType,
325
- pmask: Mask | None = None,
326
- **kwargs: Any,
327
- ) -> MPType:
328
- """Create MPType from tensor-like object.
329
-
330
- Args:
331
- obj: Tensor-like object or scalar.
332
- pmask: The party mask.
333
- **kwargs: Additional attributes.
334
-
335
- Returns:
336
- MPType instance for tensor.
337
- """
338
- attrs = copy.copy(kwargs)
339
- tensor_info = cls._create_tensor_info(obj)
340
- return cls(tensor_info, pmask, attrs)
341
-
342
- @classmethod
343
- def from_mpobj(cls, obj: MPObject) -> MPType:
344
- """Create MPType from MPObject.
345
-
346
- Args:
347
- obj: MPObject instance.
348
-
349
- Returns:
350
- MPType instance with same type as the object.
351
- """
352
- # assume obj is MPObject-like
353
- obj_type = obj.mptype
354
- return cls(obj_type._type, obj.pmask, copy.copy(obj.attrs))
355
-
356
- @classmethod
357
- def from_obj(cls, obj: Any, pmask: Mask | None = None, **attrs: Any) -> MPType:
358
- """Create MPType from any object, automatically inferring the type.
359
-
360
- Args:
361
- obj: Object to create type from.
362
- pmask: The party mask.
363
- **attrs: Additional attributes.
364
-
365
- Returns:
366
- MPType instance.
367
-
368
- Raises:
369
- TypeError: If object type cannot be inferred.
370
- NotImplementedError: For table objects (not yet implemented).
371
- """
372
- # Check if it's a table-like object using the TableLike protocol
373
- if isinstance(obj, TableLike):
374
- # For TableLike objects, try to extract schema information
375
- try:
376
- import pandas as pd
377
-
378
- if isinstance(obj, pd.DataFrame):
379
- from mplang.v1.core.dtypes import DType
380
-
381
- schema_dict = {}
382
- for col_name in obj.columns:
383
- pandas_dtype = obj[col_name].dtype
384
- # Convert pandas dtype to DType
385
- if pandas_dtype.kind in (
386
- "O",
387
- "U",
388
- "S",
389
- ): # object, unicode, string
390
- schema_dict[col_name] = (
391
- DType.from_numpy(pandas_dtype)
392
- if pandas_dtype.kind != "O"
393
- else STRING
394
- )
395
- else:
396
- schema_dict[col_name] = DType.from_numpy(pandas_dtype)
397
- schema = TableType.from_dict(schema_dict)
398
- return cls(schema, pmask, attrs)
399
- except ImportError:
400
- pass
401
- # For other table-like objects without pandas
402
- raise NotImplementedError(
403
- "Table object detection for non-pandas objects not fully implemented yet"
404
- )
405
-
406
- # Otherwise treat as tensor-like
407
- return cls.from_tensor(obj, pmask, **attrs)
mplang/v1/core/pfunc.py DELETED
@@ -1,130 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import copy
18
- from collections.abc import Sequence
19
- from types import MappingProxyType
20
- from typing import Any
21
-
22
- from mplang.v1.core.table import TableType
23
- from mplang.v1.core.tensor import TensorType
24
-
25
- __all__ = [
26
- "PFunction",
27
- "get_fn_name",
28
- ]
29
-
30
-
31
- class PFunction:
32
- """A Party Function represents a computation unit that can be executed by a single party.
33
-
34
- PFunction serves as a unified interface for describing single-party computations
35
- in multi-party computing scenarios. It can represent both:
36
- 1. Built-in operations (e.g., "spu.makeshares", "basic.read")
37
- 2. User-defined programmable functions with custom code
38
-
39
- The PFunction accepts a list of typed inputs (TensorType/TableType). For
40
- backend-only handles (e.g., crypto keys), use a sentinel TensorType
41
- of UINT8 with shape (-1, 0) to indicate the argument should bypass
42
- structural validation at runtime. Outputs should likewise use concrete
43
- TensorType/TableType specs. PFunction can be:
44
- - Expressed and defined in the mplang frontend
45
- - Serialized for transmission between components
46
- - Interpreted and executed by backend runtime engines
47
-
48
- Args:
49
- fn_type: The type/category identifier of this PFunction, indicating which
50
- backend or handler should process it (e.g., "spu.makeshares", "basic.read",
51
- "mlir.stablehlo"). This serves as a routing mechanism for execution.
52
- ins_info: Type information for input parameters (TensorType or TableType)
53
- outs_info: Type information for output values (TensorType or TableType)
54
- fn_name: Optional name of the function. For programmable functions, this is
55
- the user-defined function name. For built-in operations, this may be
56
- None or a descriptive identifier.
57
- fn_text: Optional serialized function body. For programmable functions, this
58
- contains the actual code (e.g., MLIR, bytecode, source code). For built-in
59
- operations, this is typically None.
60
- **kwargs: Additional attributes and metadata specific to the function type.
61
- These are used to pass execution parameters, configuration, and context
62
- information to the backend handlers.
63
- """
64
-
65
- # Required fields - these define the core execution context
66
- fn_type: str # Unique identifier for backend routing
67
- ins_info: tuple[TensorType | TableType, ...]
68
- outs_info: tuple[TensorType | TableType, ...]
69
-
70
- # Optional fields for programmable functions
71
- fn_name: str | None # Function name (for programmable functions)
72
- fn_text: str | None # Function body/code (for programmable functions)
73
-
74
- # Custom attributes and metadata
75
- attrs: MappingProxyType[str, Any] # Execution parameters and metadata
76
-
77
- def __init__(
78
- self,
79
- fn_type: str,
80
- ins_info: Sequence[TensorType | TableType],
81
- outs_info: Sequence[TensorType | TableType],
82
- *,
83
- fn_name: str | None = None,
84
- fn_text: str | None = None,
85
- **kwargs: Any,
86
- ):
87
- self.fn_type = fn_type
88
- self.fn_name = fn_name
89
- self.fn_text = fn_text
90
- self.ins_info = tuple(ins_info)
91
- self.outs_info = tuple(outs_info)
92
- # Make attrs immutable to ensure PFunction immutability
93
- # Create a copy first, then wrap it in MappingProxyType
94
- self.attrs = MappingProxyType(copy.copy(kwargs))
95
-
96
- def __repr__(self) -> str:
97
- return f"{self.__class__.__name__}({self.fn_type}, {self.fn_name})"
98
-
99
- def __hash__(self) -> int:
100
- return hash((
101
- self.fn_type,
102
- self.fn_name,
103
- self.fn_text,
104
- self.ins_info,
105
- self.outs_info,
106
- frozenset(self.attrs.items()),
107
- ))
108
-
109
- def __eq__(self, other: object) -> bool:
110
- """Check equality between PFunction instances."""
111
- if not isinstance(other, PFunction):
112
- return False
113
-
114
- return (
115
- self.fn_type == other.fn_type
116
- and self.fn_name == other.fn_name
117
- and self.fn_text == other.fn_text
118
- and self.ins_info == other.ins_info
119
- and self.outs_info == other.outs_info
120
- and self.attrs == other.attrs
121
- )
122
-
123
-
124
- def get_fn_name(fn_like: Any) -> str:
125
- if hasattr(fn_like, "__name__"):
126
- return fn_like.__name__ # type: ignore[no-any-return]
127
- if hasattr(fn_like, "func"):
128
- # handle partial functions
129
- return get_fn_name(fn_like.func)
130
- return "unnamed function"