mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/mask.py
DELETED
|
@@ -1,325 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Mask class for representing party masks in multi-party computation.
|
|
17
|
-
|
|
18
|
-
This class encapsulates mask data and operations, replacing the previous
|
|
19
|
-
int-based mask representation with a proper type-safe abstraction.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
from __future__ import annotations
|
|
23
|
-
|
|
24
|
-
from collections.abc import Iterable, Iterator
|
|
25
|
-
from typing import Literal
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
class Mask:
|
|
29
|
-
"""
|
|
30
|
-
A mask representing a set of parties in multi-party computation.
|
|
31
|
-
|
|
32
|
-
The mask uses bit positions to represent party ranks:
|
|
33
|
-
- Bit 0 represents party 0
|
|
34
|
-
- Bit 1 represents party 1
|
|
35
|
-
- And so on...
|
|
36
|
-
|
|
37
|
-
Examples:
|
|
38
|
-
>>> mask = Mask.from_ranks([0, 1]) # Parties 0 and 1
|
|
39
|
-
>>> mask = Mask.from_int(0b101) # Parties 0 and 2
|
|
40
|
-
>>> mask = Mask.all(3) # All parties 0, 1, 2
|
|
41
|
-
"""
|
|
42
|
-
|
|
43
|
-
_value: int
|
|
44
|
-
|
|
45
|
-
def __init__(self, value: Mask | int) -> None:
|
|
46
|
-
"""
|
|
47
|
-
Create a mask from an integer value.
|
|
48
|
-
|
|
49
|
-
Args:
|
|
50
|
-
value: Integer where each bit represents a party
|
|
51
|
-
|
|
52
|
-
Raises:
|
|
53
|
-
ValueError: If value is negative
|
|
54
|
-
"""
|
|
55
|
-
if isinstance(value, Mask):
|
|
56
|
-
self._value = value._value
|
|
57
|
-
else:
|
|
58
|
-
if value < 0:
|
|
59
|
-
raise ValueError("Mask value must be non-negative")
|
|
60
|
-
self._value = int(value)
|
|
61
|
-
|
|
62
|
-
@classmethod
|
|
63
|
-
def from_int(cls, value: int) -> Mask:
|
|
64
|
-
"""Create a mask from an integer."""
|
|
65
|
-
return cls(value)
|
|
66
|
-
|
|
67
|
-
@classmethod
|
|
68
|
-
def from_ranks(cls, ranks: int | Iterable[int]) -> Mask:
|
|
69
|
-
"""
|
|
70
|
-
Create a mask from one or more ranks.
|
|
71
|
-
|
|
72
|
-
Args:
|
|
73
|
-
ranks: Either a single integer rank or an iterable of integer ranks
|
|
74
|
-
|
|
75
|
-
Returns:
|
|
76
|
-
Mask with the specified ranks set
|
|
77
|
-
|
|
78
|
-
Examples:
|
|
79
|
-
>>> Mask.from_ranks(0) # Single party 0
|
|
80
|
-
>>> Mask.from_ranks([0, 1, 2]) # Multiple parties
|
|
81
|
-
>>> Mask.from_ranks((1, 3)) # Tuple of parties
|
|
82
|
-
"""
|
|
83
|
-
if isinstance(ranks, int):
|
|
84
|
-
if ranks < 0:
|
|
85
|
-
raise ValueError("Rank must be non-negative")
|
|
86
|
-
return cls(1 << ranks)
|
|
87
|
-
|
|
88
|
-
mask_value = 0
|
|
89
|
-
for rank in ranks:
|
|
90
|
-
if rank < 0:
|
|
91
|
-
raise ValueError("All ranks must be non-negative")
|
|
92
|
-
mask_value |= 1 << rank
|
|
93
|
-
return cls(mask_value)
|
|
94
|
-
|
|
95
|
-
@classmethod
|
|
96
|
-
def all(cls, num_parties: int) -> Mask:
|
|
97
|
-
"""Create a mask with all parties up to num_parties-1."""
|
|
98
|
-
if num_parties < 0:
|
|
99
|
-
raise ValueError("Number of parties must be non-negative")
|
|
100
|
-
if num_parties == 0:
|
|
101
|
-
return cls(0)
|
|
102
|
-
return cls((1 << num_parties) - 1)
|
|
103
|
-
|
|
104
|
-
@classmethod
|
|
105
|
-
def none(cls) -> Mask:
|
|
106
|
-
"""Create an empty mask."""
|
|
107
|
-
return cls(0)
|
|
108
|
-
|
|
109
|
-
@staticmethod
|
|
110
|
-
def _ensure_mask_value(value: Mask | int) -> int:
|
|
111
|
-
"""
|
|
112
|
-
Ensure a value is converted to its underlying integer mask.
|
|
113
|
-
|
|
114
|
-
Args:
|
|
115
|
-
value: Either a Mask instance or an integer
|
|
116
|
-
|
|
117
|
-
Returns:
|
|
118
|
-
The underlying integer value of the mask
|
|
119
|
-
"""
|
|
120
|
-
if isinstance(value, Mask):
|
|
121
|
-
return value._value
|
|
122
|
-
else:
|
|
123
|
-
return int(value)
|
|
124
|
-
|
|
125
|
-
@property
|
|
126
|
-
def value(self) -> int:
|
|
127
|
-
"""Get the underlying integer value."""
|
|
128
|
-
return self._value
|
|
129
|
-
|
|
130
|
-
def __int__(self) -> int:
|
|
131
|
-
"""Allow implicit conversion to int."""
|
|
132
|
-
return self._value
|
|
133
|
-
|
|
134
|
-
def __eq__(self, other: object) -> bool:
|
|
135
|
-
"""Check equality with another mask or int."""
|
|
136
|
-
if isinstance(other, Mask):
|
|
137
|
-
return self._value == other._value
|
|
138
|
-
elif isinstance(other, int):
|
|
139
|
-
return self._value == other
|
|
140
|
-
else:
|
|
141
|
-
raise TypeError("Invalid type for equal comparison")
|
|
142
|
-
|
|
143
|
-
def __hash__(self) -> int:
|
|
144
|
-
"""Make Mask hashable."""
|
|
145
|
-
return hash(self._value)
|
|
146
|
-
|
|
147
|
-
def __repr__(self) -> str:
|
|
148
|
-
"""String representation of the mask."""
|
|
149
|
-
return f"Mask({bin(self._value)})"
|
|
150
|
-
|
|
151
|
-
def __str__(self) -> str:
|
|
152
|
-
"""Human-readable string representation."""
|
|
153
|
-
ranks = list(self.ranks())
|
|
154
|
-
if not ranks:
|
|
155
|
-
return "Mask()"
|
|
156
|
-
return f"Mask({ranks})"
|
|
157
|
-
|
|
158
|
-
def __format__(self, format_spec: str) -> str:
|
|
159
|
-
"""Support formatting for hexadecimal display."""
|
|
160
|
-
return format(self._value, format_spec)
|
|
161
|
-
|
|
162
|
-
def num_parties(self) -> int:
|
|
163
|
-
"""Count the number of parties in this mask."""
|
|
164
|
-
return self._value.bit_count()
|
|
165
|
-
|
|
166
|
-
def ranks(self) -> Iterator[int]:
|
|
167
|
-
"""Iterate over the ranks in this mask."""
|
|
168
|
-
value = self._value
|
|
169
|
-
rank = 0
|
|
170
|
-
while value > 0:
|
|
171
|
-
if value & 1:
|
|
172
|
-
yield rank
|
|
173
|
-
value >>= 1
|
|
174
|
-
rank += 1
|
|
175
|
-
|
|
176
|
-
def __iter__(self) -> Iterator[int]:
|
|
177
|
-
"""Allow iteration over ranks."""
|
|
178
|
-
return self.ranks()
|
|
179
|
-
|
|
180
|
-
def __contains__(self, rank: int) -> bool:
|
|
181
|
-
"""Check if a rank is in this mask."""
|
|
182
|
-
if rank < 0:
|
|
183
|
-
return False
|
|
184
|
-
return (self._value & (1 << rank)) != 0
|
|
185
|
-
|
|
186
|
-
def is_disjoint(self, other: Mask | int) -> bool:
|
|
187
|
-
"""Check if this mask is disjoint with another."""
|
|
188
|
-
other_mask_value = self._ensure_mask_value(other)
|
|
189
|
-
return (self._value & other_mask_value) == 0
|
|
190
|
-
|
|
191
|
-
def is_subset(self, other: Mask | int) -> bool:
|
|
192
|
-
"""Check if this mask is a subset of another."""
|
|
193
|
-
other_mask_value = self._ensure_mask_value(other)
|
|
194
|
-
return (self._value & other_mask_value) == self._value
|
|
195
|
-
|
|
196
|
-
def is_superset(self, other: Mask | int) -> bool:
|
|
197
|
-
"""Check if this mask is a superset of another."""
|
|
198
|
-
other_mask_value = self._ensure_mask_value(other)
|
|
199
|
-
return (other_mask_value & self._value) == other_mask_value
|
|
200
|
-
|
|
201
|
-
def union(self, other: Mask | int) -> Mask:
|
|
202
|
-
"""Return the union of this mask with another."""
|
|
203
|
-
other_mask_value = self._ensure_mask_value(other)
|
|
204
|
-
return Mask(self._value | other_mask_value)
|
|
205
|
-
|
|
206
|
-
def intersection(self, other: Mask | int) -> Mask:
|
|
207
|
-
"""Return the intersection of this mask with another."""
|
|
208
|
-
other_mask_value = self._ensure_mask_value(other)
|
|
209
|
-
return Mask(self._value & other_mask_value)
|
|
210
|
-
|
|
211
|
-
def difference(self, other: Mask | int) -> Mask:
|
|
212
|
-
"""Return the difference of this mask with another."""
|
|
213
|
-
other_mask_value = self._ensure_mask_value(other)
|
|
214
|
-
return Mask(self._value & Mask._invert_mask_value(other_mask_value))
|
|
215
|
-
|
|
216
|
-
def __or__(self, other: Mask | int) -> Mask:
|
|
217
|
-
"""Union operator (|)."""
|
|
218
|
-
return self.union(other)
|
|
219
|
-
|
|
220
|
-
def __and__(self, other: Mask | int) -> Mask:
|
|
221
|
-
"""Intersection operator (&)."""
|
|
222
|
-
return self.intersection(other)
|
|
223
|
-
|
|
224
|
-
def __xor__(self, other: Mask | int) -> Mask:
|
|
225
|
-
"""Symmetric difference operator (^)."""
|
|
226
|
-
other_mask_value = self._ensure_mask_value(other)
|
|
227
|
-
return Mask(self._value ^ other_mask_value)
|
|
228
|
-
|
|
229
|
-
def __sub__(self, other: Mask | int) -> Mask:
|
|
230
|
-
"""Difference operator (-)."""
|
|
231
|
-
return self.difference(other)
|
|
232
|
-
|
|
233
|
-
@staticmethod
|
|
234
|
-
def _invert_mask_value(value: int) -> int:
|
|
235
|
-
# Invert the bits of the mask value
|
|
236
|
-
# Use with caution - typically you want to limit to a specific number of parties
|
|
237
|
-
# For now, we limit to 64 bits to avoid negative values
|
|
238
|
-
return ~value & ((1 << 64) - 1)
|
|
239
|
-
|
|
240
|
-
def __invert__(self) -> Mask:
|
|
241
|
-
"""Bitwise NOT operator (~)."""
|
|
242
|
-
# Note: This creates a mask with potentially infinite bits set
|
|
243
|
-
return Mask(Mask._invert_mask_value(self._value))
|
|
244
|
-
|
|
245
|
-
def global_to_relative_rank(self, global_rank: int) -> int:
|
|
246
|
-
"""Convert a global rank to relative rank within this mask."""
|
|
247
|
-
if global_rank not in self:
|
|
248
|
-
raise ValueError(f"Global rank {global_rank} not in mask")
|
|
249
|
-
|
|
250
|
-
# Count set bits up to global_rank
|
|
251
|
-
mask_up_to_rank = self._value & ((1 << (global_rank + 1)) - 1)
|
|
252
|
-
return bin(mask_up_to_rank).count("1") - 1
|
|
253
|
-
|
|
254
|
-
def relative_to_global_rank(self, relative_rank: int) -> int:
|
|
255
|
-
"""Convert a relative rank to global rank within this mask."""
|
|
256
|
-
if relative_rank < 0 or relative_rank >= self.num_parties():
|
|
257
|
-
raise ValueError(f"Relative rank {relative_rank} out of range")
|
|
258
|
-
|
|
259
|
-
count = 0
|
|
260
|
-
global_rank = 0
|
|
261
|
-
value = self._value
|
|
262
|
-
|
|
263
|
-
while value > 0 and count <= relative_rank:
|
|
264
|
-
if value & 1:
|
|
265
|
-
if count == relative_rank:
|
|
266
|
-
return global_rank
|
|
267
|
-
count += 1
|
|
268
|
-
value >>= 1
|
|
269
|
-
global_rank += 1
|
|
270
|
-
|
|
271
|
-
raise ValueError(f"Relative rank {relative_rank} not found in mask")
|
|
272
|
-
|
|
273
|
-
def copy(self) -> Mask:
|
|
274
|
-
"""Return a copy of this mask."""
|
|
275
|
-
return Mask(self._value)
|
|
276
|
-
|
|
277
|
-
def to_bytes(
|
|
278
|
-
self, length: int = 8, byteorder: Literal["little", "big"] = "big"
|
|
279
|
-
) -> bytes:
|
|
280
|
-
"""Convert mask to bytes for serialization."""
|
|
281
|
-
return self._value.to_bytes(length, byteorder=byteorder)
|
|
282
|
-
|
|
283
|
-
@property
|
|
284
|
-
def is_empty(self) -> bool:
|
|
285
|
-
"""Check if this mask is empty."""
|
|
286
|
-
return self._value == 0
|
|
287
|
-
|
|
288
|
-
@property
|
|
289
|
-
def is_single(self) -> bool:
|
|
290
|
-
"""Check if this mask contains exactly one party."""
|
|
291
|
-
return (self._value & (self._value - 1)) == 0 and self._value != 0
|
|
292
|
-
|
|
293
|
-
def to_json(self) -> int:
|
|
294
|
-
"""Serialize to JSON-compatible format."""
|
|
295
|
-
return self._value
|
|
296
|
-
|
|
297
|
-
@classmethod
|
|
298
|
-
def from_json(cls, value: int) -> Mask:
|
|
299
|
-
"""Deserialize from JSON-compatible format."""
|
|
300
|
-
return cls(value)
|
|
301
|
-
|
|
302
|
-
@classmethod
|
|
303
|
-
def from_bytes(
|
|
304
|
-
cls, data: bytes, byteorder: Literal["little", "big"] = "big"
|
|
305
|
-
) -> Mask:
|
|
306
|
-
"""
|
|
307
|
-
Create a mask from bytes for deserialization.
|
|
308
|
-
|
|
309
|
-
Args:
|
|
310
|
-
data: Bytes to convert to mask
|
|
311
|
-
byteorder: Byte order ('little' or 'big')
|
|
312
|
-
|
|
313
|
-
Returns:
|
|
314
|
-
Mask created from the bytes
|
|
315
|
-
|
|
316
|
-
Examples:
|
|
317
|
-
>>> mask = Mask.from_bytes(b"\x05", byteorder="big")
|
|
318
|
-
>>> mask.value == 5
|
|
319
|
-
True
|
|
320
|
-
>>> mask = Mask.from_bytes(b"\x05\x00", byteorder="little")
|
|
321
|
-
>>> mask.value == 5
|
|
322
|
-
True
|
|
323
|
-
"""
|
|
324
|
-
value = int.from_bytes(data, byteorder=byteorder)
|
|
325
|
-
return cls(value)
|