mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/libs/mpc/psi/rr22.py +303 -0
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang/v2/libs/mpc/psi/rr22.py +0 -344
  162. mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
  163. /mplang/{v2/backends → backends}/channel.py +0 -0
  164. /mplang/{v2/edsl → edsl}/README.md +0 -0
  165. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  166. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  167. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  168. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  169. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  171. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  172. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  175. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  177. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  178. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  179. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  180. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  181. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/mask.py DELETED
@@ -1,325 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Mask class for representing party masks in multi-party computation.
17
-
18
- This class encapsulates mask data and operations, replacing the previous
19
- int-based mask representation with a proper type-safe abstraction.
20
- """
21
-
22
- from __future__ import annotations
23
-
24
- from collections.abc import Iterable, Iterator
25
- from typing import Literal
26
-
27
-
28
- class Mask:
29
- """
30
- A mask representing a set of parties in multi-party computation.
31
-
32
- The mask uses bit positions to represent party ranks:
33
- - Bit 0 represents party 0
34
- - Bit 1 represents party 1
35
- - And so on...
36
-
37
- Examples:
38
- >>> mask = Mask.from_ranks([0, 1]) # Parties 0 and 1
39
- >>> mask = Mask.from_int(0b101) # Parties 0 and 2
40
- >>> mask = Mask.all(3) # All parties 0, 1, 2
41
- """
42
-
43
- _value: int
44
-
45
- def __init__(self, value: Mask | int) -> None:
46
- """
47
- Create a mask from an integer value.
48
-
49
- Args:
50
- value: Integer where each bit represents a party
51
-
52
- Raises:
53
- ValueError: If value is negative
54
- """
55
- if isinstance(value, Mask):
56
- self._value = value._value
57
- else:
58
- if value < 0:
59
- raise ValueError("Mask value must be non-negative")
60
- self._value = int(value)
61
-
62
- @classmethod
63
- def from_int(cls, value: int) -> Mask:
64
- """Create a mask from an integer."""
65
- return cls(value)
66
-
67
- @classmethod
68
- def from_ranks(cls, ranks: int | Iterable[int]) -> Mask:
69
- """
70
- Create a mask from one or more ranks.
71
-
72
- Args:
73
- ranks: Either a single integer rank or an iterable of integer ranks
74
-
75
- Returns:
76
- Mask with the specified ranks set
77
-
78
- Examples:
79
- >>> Mask.from_ranks(0) # Single party 0
80
- >>> Mask.from_ranks([0, 1, 2]) # Multiple parties
81
- >>> Mask.from_ranks((1, 3)) # Tuple of parties
82
- """
83
- if isinstance(ranks, int):
84
- if ranks < 0:
85
- raise ValueError("Rank must be non-negative")
86
- return cls(1 << ranks)
87
-
88
- mask_value = 0
89
- for rank in ranks:
90
- if rank < 0:
91
- raise ValueError("All ranks must be non-negative")
92
- mask_value |= 1 << rank
93
- return cls(mask_value)
94
-
95
- @classmethod
96
- def all(cls, num_parties: int) -> Mask:
97
- """Create a mask with all parties up to num_parties-1."""
98
- if num_parties < 0:
99
- raise ValueError("Number of parties must be non-negative")
100
- if num_parties == 0:
101
- return cls(0)
102
- return cls((1 << num_parties) - 1)
103
-
104
- @classmethod
105
- def none(cls) -> Mask:
106
- """Create an empty mask."""
107
- return cls(0)
108
-
109
- @staticmethod
110
- def _ensure_mask_value(value: Mask | int) -> int:
111
- """
112
- Ensure a value is converted to its underlying integer mask.
113
-
114
- Args:
115
- value: Either a Mask instance or an integer
116
-
117
- Returns:
118
- The underlying integer value of the mask
119
- """
120
- if isinstance(value, Mask):
121
- return value._value
122
- else:
123
- return int(value)
124
-
125
- @property
126
- def value(self) -> int:
127
- """Get the underlying integer value."""
128
- return self._value
129
-
130
- def __int__(self) -> int:
131
- """Allow implicit conversion to int."""
132
- return self._value
133
-
134
- def __eq__(self, other: object) -> bool:
135
- """Check equality with another mask or int."""
136
- if isinstance(other, Mask):
137
- return self._value == other._value
138
- elif isinstance(other, int):
139
- return self._value == other
140
- else:
141
- raise TypeError("Invalid type for equal comparison")
142
-
143
- def __hash__(self) -> int:
144
- """Make Mask hashable."""
145
- return hash(self._value)
146
-
147
- def __repr__(self) -> str:
148
- """String representation of the mask."""
149
- return f"Mask({bin(self._value)})"
150
-
151
- def __str__(self) -> str:
152
- """Human-readable string representation."""
153
- ranks = list(self.ranks())
154
- if not ranks:
155
- return "Mask()"
156
- return f"Mask({ranks})"
157
-
158
- def __format__(self, format_spec: str) -> str:
159
- """Support formatting for hexadecimal display."""
160
- return format(self._value, format_spec)
161
-
162
- def num_parties(self) -> int:
163
- """Count the number of parties in this mask."""
164
- return self._value.bit_count()
165
-
166
- def ranks(self) -> Iterator[int]:
167
- """Iterate over the ranks in this mask."""
168
- value = self._value
169
- rank = 0
170
- while value > 0:
171
- if value & 1:
172
- yield rank
173
- value >>= 1
174
- rank += 1
175
-
176
- def __iter__(self) -> Iterator[int]:
177
- """Allow iteration over ranks."""
178
- return self.ranks()
179
-
180
- def __contains__(self, rank: int) -> bool:
181
- """Check if a rank is in this mask."""
182
- if rank < 0:
183
- return False
184
- return (self._value & (1 << rank)) != 0
185
-
186
- def is_disjoint(self, other: Mask | int) -> bool:
187
- """Check if this mask is disjoint with another."""
188
- other_mask_value = self._ensure_mask_value(other)
189
- return (self._value & other_mask_value) == 0
190
-
191
- def is_subset(self, other: Mask | int) -> bool:
192
- """Check if this mask is a subset of another."""
193
- other_mask_value = self._ensure_mask_value(other)
194
- return (self._value & other_mask_value) == self._value
195
-
196
- def is_superset(self, other: Mask | int) -> bool:
197
- """Check if this mask is a superset of another."""
198
- other_mask_value = self._ensure_mask_value(other)
199
- return (other_mask_value & self._value) == other_mask_value
200
-
201
- def union(self, other: Mask | int) -> Mask:
202
- """Return the union of this mask with another."""
203
- other_mask_value = self._ensure_mask_value(other)
204
- return Mask(self._value | other_mask_value)
205
-
206
- def intersection(self, other: Mask | int) -> Mask:
207
- """Return the intersection of this mask with another."""
208
- other_mask_value = self._ensure_mask_value(other)
209
- return Mask(self._value & other_mask_value)
210
-
211
- def difference(self, other: Mask | int) -> Mask:
212
- """Return the difference of this mask with another."""
213
- other_mask_value = self._ensure_mask_value(other)
214
- return Mask(self._value & Mask._invert_mask_value(other_mask_value))
215
-
216
- def __or__(self, other: Mask | int) -> Mask:
217
- """Union operator (|)."""
218
- return self.union(other)
219
-
220
- def __and__(self, other: Mask | int) -> Mask:
221
- """Intersection operator (&)."""
222
- return self.intersection(other)
223
-
224
- def __xor__(self, other: Mask | int) -> Mask:
225
- """Symmetric difference operator (^)."""
226
- other_mask_value = self._ensure_mask_value(other)
227
- return Mask(self._value ^ other_mask_value)
228
-
229
- def __sub__(self, other: Mask | int) -> Mask:
230
- """Difference operator (-)."""
231
- return self.difference(other)
232
-
233
- @staticmethod
234
- def _invert_mask_value(value: int) -> int:
235
- # Invert the bits of the mask value
236
- # Use with caution - typically you want to limit to a specific number of parties
237
- # For now, we limit to 64 bits to avoid negative values
238
- return ~value & ((1 << 64) - 1)
239
-
240
- def __invert__(self) -> Mask:
241
- """Bitwise NOT operator (~)."""
242
- # Note: This creates a mask with potentially infinite bits set
243
- return Mask(Mask._invert_mask_value(self._value))
244
-
245
- def global_to_relative_rank(self, global_rank: int) -> int:
246
- """Convert a global rank to relative rank within this mask."""
247
- if global_rank not in self:
248
- raise ValueError(f"Global rank {global_rank} not in mask")
249
-
250
- # Count set bits up to global_rank
251
- mask_up_to_rank = self._value & ((1 << (global_rank + 1)) - 1)
252
- return bin(mask_up_to_rank).count("1") - 1
253
-
254
- def relative_to_global_rank(self, relative_rank: int) -> int:
255
- """Convert a relative rank to global rank within this mask."""
256
- if relative_rank < 0 or relative_rank >= self.num_parties():
257
- raise ValueError(f"Relative rank {relative_rank} out of range")
258
-
259
- count = 0
260
- global_rank = 0
261
- value = self._value
262
-
263
- while value > 0 and count <= relative_rank:
264
- if value & 1:
265
- if count == relative_rank:
266
- return global_rank
267
- count += 1
268
- value >>= 1
269
- global_rank += 1
270
-
271
- raise ValueError(f"Relative rank {relative_rank} not found in mask")
272
-
273
- def copy(self) -> Mask:
274
- """Return a copy of this mask."""
275
- return Mask(self._value)
276
-
277
- def to_bytes(
278
- self, length: int = 8, byteorder: Literal["little", "big"] = "big"
279
- ) -> bytes:
280
- """Convert mask to bytes for serialization."""
281
- return self._value.to_bytes(length, byteorder=byteorder)
282
-
283
- @property
284
- def is_empty(self) -> bool:
285
- """Check if this mask is empty."""
286
- return self._value == 0
287
-
288
- @property
289
- def is_single(self) -> bool:
290
- """Check if this mask contains exactly one party."""
291
- return (self._value & (self._value - 1)) == 0 and self._value != 0
292
-
293
- def to_json(self) -> int:
294
- """Serialize to JSON-compatible format."""
295
- return self._value
296
-
297
- @classmethod
298
- def from_json(cls, value: int) -> Mask:
299
- """Deserialize from JSON-compatible format."""
300
- return cls(value)
301
-
302
- @classmethod
303
- def from_bytes(
304
- cls, data: bytes, byteorder: Literal["little", "big"] = "big"
305
- ) -> Mask:
306
- """
307
- Create a mask from bytes for deserialization.
308
-
309
- Args:
310
- data: Bytes to convert to mask
311
- byteorder: Byte order ('little' or 'big')
312
-
313
- Returns:
314
- Mask created from the bytes
315
-
316
- Examples:
317
- >>> mask = Mask.from_bytes(b"\x05", byteorder="big")
318
- >>> mask.value == 5
319
- True
320
- >>> mask = Mask.from_bytes(b"\x05\x00", byteorder="little")
321
- >>> mask.value == 5
322
- True
323
- """
324
- value = int.from_bytes(data, byteorder=byteorder)
325
- return cls(value)