mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Oblivious Group-by Sum library.
|
|
16
|
+
|
|
17
|
+
This module implements algorithms to compute the sum of values grouped by bins,
|
|
18
|
+
where the data holder (Sender) and the bin holder (Receiver) keep their inputs private.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
# mypy: disable-error-code="no-untyped-def"
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
from typing import Any
|
|
26
|
+
|
|
27
|
+
import jax
|
|
28
|
+
import jax.numpy as jnp
|
|
29
|
+
|
|
30
|
+
from mplang.v2.dialects import bfv, crypto, simp, tensor
|
|
31
|
+
from mplang.v2.libs.mpc.analytics import aggregation, permutation
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def oblivious_groupby_sum_bfv(
    data: Any,
    bins: Any,
    K: int,
    sender: int = 0,
    receiver: int = 1,
    poly_modulus_degree: int = 4096,
    plain_modulus: int | None = None,
) -> Any:
    """Computes group-by sum using BFV homomorphic encryption.

    Best for small K (number of bins) and low bandwidth.

    Protocol outline (as implemented below):

    1. The Sender generates BFV keys (public, secret, relinearization,
       Galois) and an encoder.
    2. The Sender splits ``data`` into chunks of ``poly_modulus_degree // 2``
       elements (the last chunk is zero-padded), encodes and encrypts each.
    3. Ciphertexts, public/relin/Galois keys and the encoder are moved to the
       Receiver. The secret key is never transferred.
    4. For every bin ``k`` the Receiver multiplies each ciphertext by the
       plaintext indicator mask ``bins == k``, accumulates the products and
       reduces the slots with ``aggregation.rotate_and_sum``. This costs
       ``K * num_chunks`` ciphertext-plaintext multiplications.
    5. The ``K`` encrypted totals go back to the Sender, who decrypts them,
       keeps slot 0 of each and stacks them into one tensor.
    6. The stacked tensor is sent to the Receiver and returned.

    Note:
        The Sender performs the decryption in step 5, so the Sender observes
        the per-bin sums before they are forwarded to the Receiver.

    Args:
        data: Input data tensor (on Sender). Shape (N,).
        bins: Bin assignments (on Receiver). Shape (N,). Values in [0, K).
        K: Number of bins. Must be at least 1.
        sender: Rank of the data holder.
        receiver: Rank of the bin holder.
        poly_modulus_degree: BFV polynomial modulus degree (slot count).
            Must be at least 2 because only half of the slots are used.
        plain_modulus: BFV plaintext modulus. If None, uses backend default.

    Returns:
        A tensor of shape (K,) on the Receiver containing the sums.

    Raises:
        ValueError: If ``K`` is smaller than 1, if ``poly_modulus_degree`` is
            smaller than 2, or if the number of ciphertexts received does not
            match the number of chunks implied by the length of ``bins``
            (i.e. ``data`` and ``bins`` disagree on N).
    """
    if K < 1:
        raise ValueError(f"K must be a positive number of bins, got {K}")

    # Use half the degree to avoid column rotation issues (only row rotation
    # supported). Computed once and shared by the Sender and Receiver stages.
    chunk_size = poly_modulus_degree // 2
    if chunk_size < 1:
        raise ValueError(
            f"poly_modulus_degree must be at least 2, got {poly_modulus_degree}"
        )

    # ----------------------------------------------------------------------
    # 1. KeyGen (Sender)
    # ----------------------------------------------------------------------
    def keygen_fn():
        # Parameters are captured from the enclosing scope so that the
        # function can be invoked by pcall_static without arguments.
        kwargs = {"poly_modulus_degree": poly_modulus_degree}
        if plain_modulus is not None:
            kwargs["plain_modulus"] = plain_modulus

        pk, sk = bfv.keygen(**kwargs)
        rk = bfv.make_relin_keys(sk)
        gk = bfv.make_galois_keys(sk)
        encoder = bfv.create_encoder(poly_modulus_degree=poly_modulus_degree)
        return pk, sk, rk, gk, encoder

    pk, sk, rk, gk, encoder = simp.pcall_static((sender,), keygen_fn)

    # ----------------------------------------------------------------------
    # 2. Encrypt Data (Sender)
    # ----------------------------------------------------------------------
    def encrypt_chunks_fn(d, enc, p_key):
        # d is a Value (Tensor); its static shape drives the chunking.
        n_rows = d.type.shape[0]
        num_chunks = (n_rows + chunk_size - 1) // chunk_size

        ciphertexts = []
        for i in range(num_chunks):
            start = i * chunk_size
            end = min(start + chunk_size, n_rows)

            # Bind loop variables as defaults so each closure keeps its own
            # slice bounds.
            def get_chunk(x, s=start, e=end, b_val=chunk_size):
                c = x[s:e]
                if e - s < b_val:
                    # Zero-pad the trailing chunk up to the full chunk size.
                    c = jnp.pad(c, (0, b_val - (e - s)))
                return c

            chunk = tensor.run_jax(get_chunk, d)
            pt = bfv.encode(chunk, enc)
            ciphertexts.append(bfv.encrypt(pt, p_key))

        return tuple(ciphertexts)

    encrypted_chunks = simp.pcall_static(
        (sender,), encrypt_chunks_fn, data, encoder, pk
    )

    # Transfer data and keys to Receiver
    def transfer_to_receiver(obj):
        return simp.shuffle_static(obj, {receiver: sender})

    # Always a tuple now
    encrypted_chunks_recv = tuple(transfer_to_receiver(c) for c in encrypted_chunks)

    pk_recv = transfer_to_receiver(pk)
    rk_recv = transfer_to_receiver(rk)
    gk_recv = transfer_to_receiver(gk)
    encoder_recv = transfer_to_receiver(encoder)

    # ----------------------------------------------------------------------
    # 3. Aggregate (Receiver)
    # ----------------------------------------------------------------------
    def aggregate_fn(b_data, cts, p_key, r_key, g_key, enc):
        # b_data is Value (Tensor)
        # cts is list/tuple of Values (Ciphertexts)
        n_rows = b_data.type.shape[0]
        num_chunks = len(cts)

        # The masks below are sliced with the same chunk layout the Sender
        # used; a different chunk count means data and bins disagree on N.
        expected_chunks = (n_rows + chunk_size - 1) // chunk_size
        if expected_chunks != num_chunks:
            raise ValueError(
                f"bins has {n_rows} rows ({expected_chunks} chunks) but "
                f"{num_chunks} encrypted data chunks were received"
            )

        # Zero ciphertext used as the accumulator seed for every bin.
        # Pass b_data as dummy to satisfy run_jax requirement
        def make_zero(dummy, b_val=chunk_size):
            return jnp.zeros((b_val,), dtype=jnp.int64)

        zero_vec = tensor.run_jax(make_zero, b_data)
        pt_zero = bfv.encode(zero_vec, enc)
        ct_zero = bfv.encrypt(pt_zero, p_key)

        bin_sums = []
        for k in range(K):
            current_sum = ct_zero

            for i in range(num_chunks):
                start = i * chunk_size
                end = min(start + chunk_size, n_rows)

                def get_mask(
                    b_chunk_full, s=start, e=end, b_val=chunk_size, k_target=k
                ):
                    # b_chunk_full is the full bins tensor
                    c = b_chunk_full[s:e]
                    if e - s < b_val:
                        # Pad with -1, which never equals a bin id in [0, K).
                        c = jnp.pad(c, (0, b_val - (e - s)), constant_values=-1)
                    return (c == k_target).astype(jnp.int64)

                mask = tensor.run_jax(get_mask, b_data)
                pt_mask = bfv.encode(mask, enc)

                ct_masked = bfv.mul(cts[i], pt_mask)
                ct_masked = bfv.relinearize(ct_masked, r_key)
                current_sum = bfv.add(current_sum, ct_masked)

            total_sum_ct = aggregation.rotate_and_sum(
                current_sum, chunk_size, g_key, slot_count=poly_modulus_degree
            )
            bin_sums.append(total_sum_ct)

        return bin_sums

    encrypted_sums = simp.pcall_static(
        (receiver,),
        aggregate_fn,
        bins,
        encrypted_chunks_recv,
        pk_recv,
        rk_recv,
        gk_recv,
        encoder_recv,
    )

    # Transfer encrypted sums back to Sender
    def transfer_to_sender(obj):
        return simp.shuffle_static(obj, {sender: receiver})

    # Always a tuple/list
    encrypted_sums_sender = tuple(transfer_to_sender(s) for s in encrypted_sums)

    # ----------------------------------------------------------------------
    # 4. Decrypt (Sender)
    # ----------------------------------------------------------------------
    def decrypt_fn(cts, s_key, enc):
        results = []
        for ct in cts:
            pt = bfv.decrypt(ct, s_key)
            vec = bfv.decode(pt, enc)
            # vec is a Tensor Value; only its first element is kept as the
            # bin total.
            results.append(tensor.run_jax(lambda v: v[0], vec))

        # Stack results into a single tensor
        def stack(*args):
            return jnp.stack(args)

        return tensor.run_jax(stack, *results)

    final_sums_sender = simp.pcall_static(
        (sender,), decrypt_fn, encrypted_sums_sender, sk, encoder
    )

    # ----------------------------------------------------------------------
    # 5. Return to Receiver
    # ----------------------------------------------------------------------
    final_sums_receiver = simp.shuffle_static(final_sums_sender, {receiver: sender})

    return final_sums_receiver
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def oblivious_groupby_sum_shuffle(
    data: Any,
    bins: Any,
    K: int,
    sender: int = 0,
    receiver: int = 1,
    helper: int = 2,
) -> Any:
    """Computes group-by sum using Oblivious Shuffle.

    Note: This implementation uses secret sharing to hide the data values from the Receiver.
    It requires a Helper party (3-party protocol).

    Protocol outline (as implemented below):

    1. The Receiver computes the stable sorting permutation of ``bins``.
    2. The Sender splits ``data`` into two additive shares ``d0 = data - mask``
       and ``d1 = mask``, where ``mask`` is built from ``crypto.random_bytes``.
    3. ``d0`` is delivered to the Receiver in permuted order through
       ``permutation.apply_permutation``; the Receiver sums it per bin.
    4. ``d1`` (from the Sender) and ``bins`` (from the Receiver) are sent to
       the Helper, which sums ``d1`` per bin and sends the result to the
       Receiver.
    5. The Receiver adds the two per-bin share sums.

    Security (what each party receives in this implementation):
        - Sender receives nothing from the steps in this function; whatever
          it observes inside ``permutation.apply_permutation`` is determined
          by that helper (not visible here).
        - Receiver learns the final sums; it also holds the permuted share
          ``perm(d0)`` and the Helper's per-bin sums of ``d1``.
        - Helper receives the complete ``bins`` tensor (the full per-row bin
          assignment, not only the bin sizes) and the share ``d1``.
        - The shares only hide the data if the three ranks are distinct and
          Receiver and Helper do not pool their views, so distinct ranks are
          enforced.

    Args:
        data: Input data tensor (on Sender). Shape (N,). The mask is built as
            int64, so int64 data is assumed.
        bins: Bin assignments (on Receiver). Shape (N,). Values in [0, K).
        K: Number of bins.
        sender: Rank of the data holder.
        receiver: Rank of the bin holder.
        helper: Rank of the helper party.

    Returns:
        A tensor of shape (K,) on the Receiver containing the sums.

    Raises:
        ValueError: If ``sender``, ``receiver`` and ``helper`` are not three
            distinct ranks.
    """
    # With helper == receiver one party would hold d1, perm(d0) and perm and
    # could rebuild every data value; with helper == sender the Sender would
    # receive bins. Either case defeats the purpose of the protocol.
    if len({sender, receiver, helper}) != 3:
        raise ValueError(
            "sender, receiver and helper must be three distinct ranks, got "
            f"sender={sender}, receiver={receiver}, helper={helper}"
        )

    # 1. Compute Permutation (Receiver)
    def compute_perm_fn(b):
        # b is the bins tensor; we want the indices that stably sort it.
        return tensor.run_jax(lambda x: jnp.argsort(x, stable=True), b)

    perm = simp.pcall_static((receiver,), compute_perm_fn, bins)

    # 2. Secret Share Data (Sender)
    # The mask comes from crypto.random_bytes, an EDSL primitive evaluated at
    # runtime on the Sender rather than a constant baked in while tracing.
    def split_shares_fn(d):
        n_elements = d.type.shape[0]
        bytes_per_element = 8  # int64 = 8 bytes
        total_bytes = n_elements * bytes_per_element

        mask_bytes = crypto.random_bytes(total_bytes)

        def _apply_mask(arr, m_bytes):
            # The random bytes are always reinterpreted as int64 (the assumed
            # input dtype); other dtypes are not handled specially here.
            mask = m_bytes.view(jnp.int64).reshape(arr.shape)
            d0 = arr - mask
            d1 = mask
            return d0, d1

        return tensor.run_jax(_apply_mask, d, mask_bytes)

    d0, d1 = simp.pcall_static((sender,), split_shares_fn, data)

    # 3. Shuffle Share 0 (Sender -> Receiver)
    # Receiver gets s0 = perm(d0)
    s0 = permutation.apply_permutation(d0, perm, sender=sender, receiver=receiver)

    # 4. Compute Agg0 (Receiver)
    def agg_s0_fn(s_val, b, p, k_val):
        def _impl(s_v, b_v, p_v):
            # Reorder bins with the same permutation so they line up with s0.
            s_bins = b_v[p_v]
            # Per-bin sums of share 0.
            return jax.ops.segment_sum(s_v, s_bins, num_segments=k_val)

        return tensor.run_jax(_impl, s_val, b, p)

    agg0 = simp.pcall_static((receiver,), agg_s0_fn, s0, bins, perm, K)

    # 5. Send Share 1 to Helper
    d1_helper = simp.shuffle_static(d1, {helper: sender})

    # 6. Send Bins to Helper (the full, unpermuted bin assignments)
    bins_helper = simp.shuffle_static(bins, {helper: receiver})

    # 7. Compute Agg1 (Helper)
    def agg_d1_fn(d_val, b_val, k_val):
        def _impl(d_v, b_v):
            # Per-bin sums of share 1, in the original row order.
            return jax.ops.segment_sum(d_v, b_v, num_segments=k_val)

        return tensor.run_jax(_impl, d_val, b_val)

    agg1 = simp.pcall_static((helper,), agg_d1_fn, d1_helper, bins_helper, K)

    # 8. Send Agg1 to Receiver
    agg1_recv = simp.shuffle_static(agg1, {receiver: helper})

    # 9. Combine (Receiver): adding the share sums cancels the mask.
    def combine_fn(a0, a1):
        return tensor.run_jax(lambda x, y: x + y, a0, a1)

    final_sums = simp.pcall_static((receiver,), combine_fn, agg0, agg1_recv)

    return final_sums
|