mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,331 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Oblivious Group-by Sum library.
16
+
17
+ This module implements algorithms to compute the sum of values grouped by bins,
18
+ where the data holder (Sender) and the bin holder (Receiver) keep their inputs private.
19
+ """
20
+
21
+ # mypy: disable-error-code="no-untyped-def"
22
+
23
+ from __future__ import annotations
24
+
25
+ from typing import Any
26
+
27
+ import jax
28
+ import jax.numpy as jnp
29
+
30
+ from mplang.v2.dialects import bfv, crypto, simp, tensor
31
+ from mplang.v2.libs.mpc.analytics import aggregation, permutation
32
+
33
+
34
+ def oblivious_groupby_sum_bfv(
35
+ data: Any,
36
+ bins: Any,
37
+ K: int,
38
+ sender: int = 0,
39
+ receiver: int = 1,
40
+ poly_modulus_degree: int = 4096,
41
+ plain_modulus: int | None = None,
42
+ ) -> Any:
43
+ """Computes group-by sum using BFV homomorphic encryption.
44
+
45
+ Best for small K (number of bins) and low bandwidth.
46
+
47
+ Args:
48
+ data: Input data tensor (on Sender). Shape (N,).
49
+ bins: Bin assignments (on Receiver). Shape (N,). Values in [0, K).
50
+ K: Number of bins.
51
+ sender: Rank of the data holder.
52
+ receiver: Rank of the bin holder.
53
+ poly_modulus_degree: BFV polynomial modulus degree (slot count).
54
+ plain_modulus: BFV plaintext modulus. If None, uses backend default.
55
+
56
+ Returns:
57
+ A tensor of shape (K,) on the Receiver containing the sums.
58
+ """
59
+
60
+ # ----------------------------------------------------------------------
61
+ # 1. KeyGen (Sender)
62
+ # ----------------------------------------------------------------------
63
+ def keygen_fn(degree, p_mod):
64
+ kwargs = {"poly_modulus_degree": degree}
65
+ if p_mod is not None:
66
+ kwargs["plain_modulus"] = p_mod
67
+
68
+ pk, sk = bfv.keygen(**kwargs)
69
+ rk = bfv.make_relin_keys(sk)
70
+ gk = bfv.make_galois_keys(sk)
71
+ encoder = bfv.create_encoder(poly_modulus_degree=degree)
72
+ return pk, sk, rk, gk, encoder
73
+
74
+ # We use a closure to capture parameters
75
+ def keygen_fn_closure():
76
+ return keygen_fn(poly_modulus_degree, plain_modulus)
77
+
78
+ pk, sk, rk, gk, encoder = simp.pcall_static((sender,), keygen_fn_closure)
79
+
80
+ # ----------------------------------------------------------------------
81
+ # 2. Encrypt Data (Sender)
82
+ # ----------------------------------------------------------------------
83
+ def encrypt_chunks_fn(d, enc, p_key):
84
+ # d is a Value (Tensor)
85
+ shape = d.type.shape
86
+ N = shape[0]
87
+ # Use half the degree to avoid column rotation issues (only row rotation supported)
88
+ B = poly_modulus_degree // 2
89
+ num_chunks = (N + B - 1) // B
90
+
91
+ ciphertexts = []
92
+ for i in range(num_chunks):
93
+ start = i * B
94
+ end = min((i + 1) * B, N)
95
+
96
+ # Bind loop variables
97
+ def get_chunk(x, s=start, e=end, b_val=B):
98
+ c = x[s:e]
99
+ if e - s < b_val:
100
+ c = jnp.pad(c, (0, b_val - (e - s)))
101
+ return c
102
+
103
+ chunk = tensor.run_jax(get_chunk, d)
104
+
105
+ pt = bfv.encode(chunk, enc)
106
+ ct = bfv.encrypt(pt, p_key)
107
+ ciphertexts.append(ct)
108
+
109
+ return tuple(ciphertexts)
110
+
111
+ encrypted_chunks = simp.pcall_static(
112
+ (sender,), encrypt_chunks_fn, data, encoder, pk
113
+ )
114
+
115
+ # Transfer data and keys to Receiver
116
+ def transfer_to_receiver(obj):
117
+ return simp.shuffle_static(obj, {receiver: sender})
118
+
119
+ # Always a tuple now
120
+ encrypted_chunks_recv = tuple(transfer_to_receiver(c) for c in encrypted_chunks)
121
+
122
+ pk_recv = transfer_to_receiver(pk)
123
+ rk_recv = transfer_to_receiver(rk)
124
+ gk_recv = transfer_to_receiver(gk)
125
+ encoder_recv = transfer_to_receiver(encoder)
126
+
127
+ # ----------------------------------------------------------------------
128
+ # 3. Aggregate (Receiver)
129
+ # ----------------------------------------------------------------------
130
+ def aggregate_fn(b_data, cts, p_key, r_key, g_key, enc):
131
+ # b_data is Value (Tensor)
132
+ # cts is list/tuple of Values (Ciphertexts)
133
+
134
+ N = b_data.type.shape[0]
135
+ # Use half the degree to avoid column rotation issues
136
+ B = poly_modulus_degree // 2
137
+ num_chunks = len(cts)
138
+
139
+ bin_sums = [None] * K
140
+
141
+ # Zero ciphertext
142
+ # Pass b_data as dummy to satisfy run_jax requirement
143
+ def make_zero(dummy, b_val=B):
144
+ return jnp.zeros((b_val,), dtype=jnp.int64)
145
+
146
+ zero_vec = tensor.run_jax(make_zero, b_data)
147
+ pt_zero = bfv.encode(zero_vec, enc)
148
+ ct_zero = bfv.encrypt(pt_zero, p_key)
149
+
150
+ for k in range(K):
151
+ current_sum = ct_zero
152
+
153
+ for i in range(num_chunks):
154
+ start = i * B
155
+ end = min((i + 1) * B, N)
156
+
157
+ def get_mask(b_chunk_full, s=start, e=end, b_val=B, k_target=k):
158
+ # b_chunk_full is the full bins tensor
159
+ c = b_chunk_full[s:e]
160
+ if e - s < b_val:
161
+ c = jnp.pad(c, (0, b_val - (e - s)), constant_values=-1)
162
+ return (c == k_target).astype(jnp.int64)
163
+
164
+ mask = tensor.run_jax(get_mask, b_data)
165
+ pt_mask = bfv.encode(mask, enc)
166
+
167
+ ct_masked = bfv.mul(cts[i], pt_mask)
168
+ ct_masked = bfv.relinearize(ct_masked, r_key)
169
+ current_sum = bfv.add(current_sum, ct_masked)
170
+
171
+ total_sum_ct = aggregation.rotate_and_sum(
172
+ current_sum, B, g_key, slot_count=poly_modulus_degree
173
+ )
174
+ bin_sums[k] = total_sum_ct
175
+
176
+ return bin_sums
177
+
178
+ encrypted_sums = simp.pcall_static(
179
+ (receiver,),
180
+ aggregate_fn,
181
+ bins,
182
+ encrypted_chunks_recv,
183
+ pk_recv,
184
+ rk_recv,
185
+ gk_recv,
186
+ encoder_recv,
187
+ )
188
+
189
+ # Transfer encrypted sums back to Sender
190
+ def transfer_to_sender(obj):
191
+ return simp.shuffle_static(obj, {sender: receiver})
192
+
193
+ # Always a tuple/list
194
+ encrypted_sums_sender = tuple(transfer_to_sender(s) for s in encrypted_sums)
195
+
196
+ # ----------------------------------------------------------------------
197
+ # 4. Decrypt (Sender)
198
+ # ----------------------------------------------------------------------
199
+ def decrypt_fn(cts, s_key, enc):
200
+ results = []
201
+ for ct in cts:
202
+ pt = bfv.decrypt(ct, s_key)
203
+ vec = bfv.decode(pt, enc)
204
+ # vec is a Tensor Value
205
+ # We need to extract the first element.
206
+ val = tensor.run_jax(lambda v: v[0], vec)
207
+ results.append(val)
208
+
209
+ # Stack results into a single tensor
210
+ def stack(*args):
211
+ return jnp.stack(args)
212
+
213
+ return tensor.run_jax(stack, *results)
214
+
215
+ final_sums_sender = simp.pcall_static(
216
+ (sender,), decrypt_fn, encrypted_sums_sender, sk, encoder
217
+ )
218
+
219
+ # ----------------------------------------------------------------------
220
+ # 5. Return to Receiver
221
+ # ----------------------------------------------------------------------
222
+ final_sums_receiver = simp.shuffle_static(final_sums_sender, {receiver: sender})
223
+
224
+ return final_sums_receiver
225
+
226
+
227
def oblivious_groupby_sum_shuffle(
    data: Any,
    bins: Any,
    K: int,
    sender: int = 0,
    receiver: int = 1,
    helper: int = 2,
) -> Any:
    """Computes group-by sum using Oblivious Shuffle.

    Note: This implementation uses secret sharing to hide the data values from
    the Receiver. It requires a Helper party (3-party protocol), so
    ``sender``, ``receiver`` and ``helper`` must be three distinct ranks.

    Protocol outline:
      1. Receiver computes ``perm = argsort(bins)`` (stable).
      2. Sender splits ``data`` additively into ``d0 = data - r`` and
         ``d1 = r``, where ``r`` is int64 built from runtime random bytes
         (correctness relies on int64 wrap-around arithmetic).
      3. ``d0`` is delivered to Receiver permuted by ``perm``
         (``permutation.apply_permutation``); Receiver computes per-bin sums
         of it.
      4. ``d1`` and ``bins`` are sent to Helper, which computes per-bin sums
         of ``d1``.
      5. Helper's per-bin sums go to Receiver, which adds both partial sums.

    Security (as implemented):
        - Sender learns nothing.
        - Receiver learns the final sums (it already holds ``bins``).
        - Helper receives the full ``bins`` vector in the clear (step 4), i.e.
          the bin assignment of every row, not just the bin sizes, plus the
          random share ``d1``.
        - No single party sees individual data values, assuming the three
          parties do not collude.

    Args:
        data: Input data tensor (on Sender). Shape (N,). Must have an integer
            (or boolean) dtype.
        bins: Bin assignments (on Receiver). Shape (N,). Values in [0, K).
        K: Number of bins.
        sender: Rank of the data holder.
        receiver: Rank of the bin holder.
        helper: Rank of the helper party.

    Returns:
        A tensor of shape (K,) on the Receiver containing the sums.

    Raises:
        ValueError: If ``sender``, ``receiver`` and ``helper`` are not three
            distinct ranks (e.g. ``helper == receiver`` would hand both shares
            to one party).
        TypeError: If ``data`` has a floating-point or complex dtype, which
            the int64 random mask would silently corrupt. Raised while the
            Sender's share-splitting function is traced.
    """
    if len({sender, receiver, helper}) != 3:
        raise ValueError(
            "sender, receiver and helper must be three distinct parties, got "
            f"sender={sender}, receiver={receiver}, helper={helper}"
        )

    # 1. Compute Permutation (Receiver): indices that stably sort the bins.
    def compute_perm_fn(b):
        return tensor.run_jax(lambda x: jnp.argsort(x, stable=True), b)

    perm = simp.pcall_static((receiver,), compute_perm_fn, bins)

    # 2. Secret Share Data (Sender)
    # The mask is generated with crypto.random_bytes at RUNTIME on Sender, so
    # it is fresh per session rather than baked in at trace time.
    def split_shares_fn(d):
        n_elements = d.type.shape[0]
        bytes_per_element = 8  # int64 = 8 bytes
        mask_bytes = crypto.random_bytes(n_elements * bytes_per_element)

        def _apply_mask(arr, m_bytes):
            # A random 64-bit mask cannot be represented exactly in floating
            # point, so float/complex inputs would decode to garbage.
            if jnp.issubdtype(arr.dtype, jnp.inexact):
                raise TypeError(
                    "oblivious_groupby_sum_shuffle requires integer data; "
                    f"got dtype {arr.dtype}"
                )
            mask = m_bytes.view(jnp.int64).reshape(arr.shape)
            return arr - mask, mask

        return tensor.run_jax(_apply_mask, d, mask_bytes)

    d0, d1 = simp.pcall_static((sender,), split_shares_fn, data)

    # 3. Shuffle Share 0 (Sender -> Receiver): Receiver gets s0 = perm(d0).
    s0 = permutation.apply_permutation(d0, perm, sender=sender, receiver=receiver)

    # 4. Compute Agg0 (Receiver)
    def agg_s0_fn(s_val, b, p, k_val):
        def _impl(s_v, b_v, p_v):
            # Reorder bins with the same permutation so they line up with s0.
            s_bins = b_v[p_v]
            return jax.ops.segment_sum(s_v, s_bins, num_segments=k_val)

        return tensor.run_jax(_impl, s_val, b, p)

    agg0 = simp.pcall_static((receiver,), agg_s0_fn, s0, bins, perm, K)

    # 5. Send Share 1 to Helper
    d1_helper = simp.shuffle_static(d1, {helper: sender})

    # 6. Send Bins to Helper (Helper sees the full bin vector in the clear).
    bins_helper = simp.shuffle_static(bins, {helper: receiver})

    # 7. Compute Agg1 (Helper)
    def agg_d1_fn(d_val, b_val, k_val):
        def _impl(d_v, b_v):
            return jax.ops.segment_sum(d_v, b_v, num_segments=k_val)

        return tensor.run_jax(_impl, d_val, b_val)

    agg1 = simp.pcall_static((helper,), agg_d1_fn, d1_helper, bins_helper, K)

    # 8. Send Agg1 to Receiver
    agg1_recv = simp.shuffle_static(agg1, {receiver: helper})

    # 9. Combine (Receiver): agg0 + agg1 = per-bin sums of d0 + d1 = data.
    def combine_fn(a0, a1):
        return tensor.run_jax(lambda x, y: x + y, a0, a1)

    return simp.pcall_static((receiver,), combine_fn, agg0, agg1_recv)