mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,41 @@
+ # Copyright 2025 Ant Group Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """MPC (Multi-Party Computation) library for MPLang2.
+
+ Subpackages:
+     - ot: Oblivious Transfer protocols
+     - vole: Vector OLE protocols
+     - psi: Private Set Intersection
+     - analytics: Privacy-preserving analytics
+
+ Example usage:
+     from mplang.v2.libs.mpc import ot_transfer, apply_permutation
+     from mplang.v2.libs.mpc.vole import silver_vole
+     from mplang.v2.libs.mpc.psi import psi_intersect
+ """
+
+ from .analytics.aggregation import rotate_and_sum
+ from .analytics.groupby import oblivious_groupby_sum_bfv, oblivious_groupby_sum_shuffle
+ from .analytics.permutation import apply_permutation, secure_switch
+ from .ot.base import transfer as ot_transfer
+
+ __all__ = [
+     "apply_permutation",
+     "oblivious_groupby_sum_bfv",
+     "oblivious_groupby_sum_shuffle",
+     "ot_transfer",
+     "rotate_and_sum",
+     "secure_switch",
+ ]
@@ -0,0 +1,99 @@
+ # Copyright 2025 Ant Group Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Utilities for MPC protocols."""
+
+ from __future__ import annotations
+
+ from typing import Any, cast
+
+ import jax.numpy as jnp
+
+ import mplang.v2.edsl as el
+ from mplang.v2.dialects import tensor
+
+
+ def bytes_to_bits(data: el.Object) -> el.Object:
+     """Convert bytes (uint8 tensor) to bits (uint8 tensor of 0s and 1s).
+
+     Output shape logic: (..., N) -> (..., N * 8)
+     """
+
+     def _to_bits(arr: Any) -> Any:
+         # View as u8
+         y_u8 = arr.view(jnp.uint8)
+         # Unpack produces Big Endian bits [b7, b6, ..., b0] per byte
+         bits = jnp.unpackbits(y_u8)
+         # Reshape to (N, 8) and flip to get [b0, ..., b7]
+         bits = bits.reshape(-1, 8)
+         bits = jnp.fliplr(bits)
+         return bits.reshape(-1)
+
+     return cast(el.Object, tensor.run_jax(_to_bits, data))
+
+
+ def bits_to_bytes(bits: el.Object) -> el.Object:
+     """Convert bits to bytes.
+
+     Output shape logic: (..., N * 8) -> (..., N)
+     """
+
+     def _to_bytes(arr: Any) -> Any:
+         return jnp.packbits(arr, axis=-1)
+
+     return cast(el.Object, tensor.run_jax(_to_bytes, bits))
+
+
+ def transpose_128(matrix_bits: el.Object) -> el.Object:
+     """Transpose a bit matrix.
+
+     Just a wrapper for tensor.transpose currently.
+     """
+     return tensor.transpose(matrix_bits, perm=(1, 0))
+
+
+ class CuckooHash:
+     """Simple Cuckoo Hashing simulation."""
+
+     def __init__(self, num_bins: int, num_hash_functions: int = 3, stash_size: int = 0):
+         self.num_bins = num_bins
+         self.num_functions = num_hash_functions
+         self.stash_size = stash_size
+
+     def hash(self, items: el.Object, seed: int) -> el.Object:
+         """Hash items to bin indices."""
+
+         # We perform hashing.
+         # Note: We return hashes for each function?
+         # Usually simplest cuckoo uses 3 hash functions.
+         # We can return (num_funcs, N) or (N, num_funcs)
+
+         def _hash_fn(xs: Any, s: int) -> Any:
+             # xs: array of items
+
+             # Simple hash: (x * s + s) % bins
+             # We want multiple hashes?
+             # For now, let's just return one hash per seed provided (assuming call per seed)
+             # Or if seed is a single int, we might mix it.
+
+             # Let's assume this function handles one hash instance.
+             res = (xs * s + s) % self.num_bins
+             return res.astype(jnp.int32)
+
+         # Passing self.num_bins as constant implementation detail inside _hash_fn closure is fine
+         # if using run_jax (as it's compiled).
+         # Actually run_jax recompiles if closure changes?
+         # run_jax supports closures.
+
+         return cast(el.Object, tensor.run_jax(_hash_fn, items, seed))
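In the helpers above, `_to_bits` emits each byte least-significant bit first and flattens its output, while `_to_bytes` calls `packbits` with its default big-endian bit order. The minimal sketch below uses NumPy as a stand-in for the `jnp` calls, which mirror NumPy's `unpackbits`/`packbits`. It shows that the pair round-trips only when packing also uses `bitorder="little"`, and that a 2-D input comes back 1-D. The function names `to_bits_like_source`, `to_bytes_default`, and `to_bytes_little` are hypothetical.

```python
import numpy as np


def to_bits_like_source(arr: np.ndarray) -> np.ndarray:
    # Mirrors `_to_bits`: unpackbits yields [b7..b0] per byte, fliplr turns that
    # into [b0..b7], and the final reshape flattens any leading dimensions.
    bits = np.unpackbits(arr.view(np.uint8))
    bits = np.fliplr(bits.reshape(-1, 8))
    return bits.reshape(-1)


def to_bytes_default(bits: np.ndarray) -> np.ndarray:
    # Mirrors `_to_bytes`: packbits with its default big-endian bit order.
    return np.packbits(bits, axis=-1)


def to_bytes_little(bits: np.ndarray) -> np.ndarray:
    # Hypothetical variant that packs with the same LSB-first order `_to_bits` emits.
    return np.packbits(bits, axis=-1, bitorder="little")


data = np.array([1, 2, 200], dtype=np.uint8)
bits = to_bits_like_source(data)
restored_default = to_bytes_default(bits)  # [128, 64, 19]: every byte bit-reversed
restored_little = to_bytes_little(bits)    # [1, 2, 200]: the original bytes
flat_shape = to_bits_like_source(np.zeros((2, 3), dtype=np.uint8)).shape  # (48,), not (2, 24)
```

Whether callers rely on the big-endian packing is not visible from this diff, so the sketch only highlights the asymmetry and does not claim which order is intended.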
@@ -0,0 +1,35 @@
+ # Copyright 2025 Ant Group Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Privacy-preserving analytics operations.
+
+ Submodules:
+     - aggregation: BFV homomorphic aggregation
+     - groupby: Oblivious Group-By operations
+     - permutation: Secure permutation (Bitonic Sort)
+ """
+
+ from .aggregation import aggregate_sparse, batch_bucket_aggregate, rotate_and_sum
+ from .groupby import oblivious_groupby_sum_bfv, oblivious_groupby_sum_shuffle
+ from .permutation import apply_permutation, secure_switch
+
+ __all__ = [
+     "aggregate_sparse",
+     "apply_permutation",
+     "batch_bucket_aggregate",
+     "oblivious_groupby_sum_bfv",
+     "oblivious_groupby_sum_shuffle",
+     "rotate_and_sum",
+     "secure_switch",
+ ]
@@ -0,0 +1,372 @@
+ # Copyright 2025 Ant Group Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Homomorphic Aggregation library.
+
+ This module implements efficient aggregation algorithms using BFV rotation.
+ """
+
+ from __future__ import annotations
+
+ import math
+ from typing import Any
+
+ import numpy as np
+
+ from mplang.v2.dialects import bfv, tensor
+
+
+ def _safe_rotate(
+     ciphertext: Any, step: int, galois_keys: Any, max_step: int = 1024
+ ) -> Any:
+     """Rotate ciphertext by step, decomposing large steps if needed.
+
+     SEAL's rotate_rows requires step to be in range (-slot_count/2, slot_count/2).
+     For poly_modulus_degree=4096, slot_count=4096, max valid step is 2047.
+     We use a conservative max_step=1024 by default for safety.
+
+     For large steps, we decompose into multiple rotations:
+         - rotate(x, 3000) = rotate(rotate(rotate(x, 1024), 1024), 952)
+     """
+     if step == 0:
+         return ciphertext
+     if abs(step) <= max_step:
+         return bfv.rotate(ciphertext, step, galois_keys)
+
+     # Decompose large step into multiple rotations
+     current = ciphertext
+     remaining = abs(step)
+     sign = 1 if step > 0 else -1
+
+     while remaining > 0:
+         rot = min(remaining, max_step)
+         current = bfv.rotate(current, sign * rot, galois_keys)
+         remaining -= rot
+
+     return current
+
+
+ def _rotate_and_sum_row(
+     ciphertext: Any, k: int, galois_keys: Any, max_step: int = 1024
+ ) -> Any:
+     """Sum first k elements within a single row (k <= row_size).
+
+     Uses the recursive doubling algorithm with O(log k) rotations.
+     """
+     if k <= 1:
+         return ciphertext
+
+     num_steps = math.ceil(math.log2(k))
+     current = ciphertext
+
+     for i in range(num_steps):
+         step = 1 << i
+         if step >= k:
+             break
+         rotated = _safe_rotate(current, step, galois_keys, max_step)
+         current = bfv.add(current, rotated)
+
+     return current
+
+
+ def rotate_and_sum(
+     ciphertext: Any, k: int, galois_keys: Any, slot_count: int = 4096
+ ) -> Any:
+     """Aggregate the first k elements of a ciphertext using O(log k) rotations.
+
+     The result is placed in the 0-th slot.
+     This assumes the input ciphertext has relevant data in slots 0..k-1
+     and zeros (or irrelevant data) elsewhere, OR that the caller will mask the result.
+
+     Args:
+         ciphertext: The BFV ciphertext.
+         k: The number of elements to sum.
+         galois_keys: Keys required for rotation.
+         slot_count: Total number of slots (default 4096 for poly_degree=4096).
+
+     Returns:
+         Ciphertext where slot 0 contains sum(ciphertext[0..k-1]).
+
+     Note:
+         SEAL batching arranges slots as 2 rows of slot_count/2 each.
+         - rotate_rows rotates within each row (circular)
+         - rotate_columns swaps the two rows
+
+         For k <= row_size (2048), only row rotations are needed.
+         For k > row_size, we use rotate_columns to aggregate across rows.
+     """
+     row_size = slot_count // 2
+
+     if k <= row_size:
+         # Simple case: all elements in row 0
+         return _rotate_and_sum_row(ciphertext, k, galois_keys)
+
+     # k > row_size: data spans both rows
+     # Strategy:
+     # 1. Sum row 0 completely (row_size elements)
+     # 2. rotate_columns to bring row 1 to row 0 position
+     # 3. Sum the first (k - row_size) elements of what was row 1
+     # 4. Add the two partial sums
+
+     # Sum row 0 completely
+     row0_sum = _rotate_and_sum_row(ciphertext, row_size, galois_keys)
+
+     # Rotate columns: swap row 0 <-> row 1
+     # Now row 1's data is in row 0 position
+     col_rotated = bfv.rotate_columns(ciphertext, galois_keys)
+
+     # Sum the first (k - row_size) elements (originally in row 1)
+     row1_count = k - row_size
+     row1_sum = _rotate_and_sum_row(col_rotated, row1_count, galois_keys)
+
+     # Both row0_sum and row1_sum have their results in slot 0
+     # Add them together
+     return bfv.add(row0_sum, row1_sum)
+
+
+ def aggregate_sparse(
+     ciphertext: Any,
+     aggregations: list[tuple[int, list[int]]],
+     galois_keys: Any,
+     encoder: Any,
+     vector_size: int,
+ ) -> Any:
+     """Perform sparse aggregation.
+
+     Args:
+         ciphertext: Input ciphertext.
+         aggregations: List of (target_slot, [source_slots]).
+             e.g. [(0, [0, 3, 8]), (1, [1, 5])]
+         galois_keys: Rotation keys.
+         encoder: BFV encoder for encoding masks.
+         vector_size: Total size of the vector (slots).
+
+     Returns:
+         Ciphertext with aggregated results in target slots.
+     """
+     # Naive approach: For each target, sum sources.
+     # Optimized approach:
+     # 1. Decompose into rotations.
+     #    For target t, source s: need rotation by (t-s).
+     #    Group by rotation amount.
+     # 2. Apply rotations and accumulate.
+
+     # Map: rotation_amount -> mask
+     # We want to compute: Result = Sum( Rotate(Input, r) * Mask_r )
+     # where Mask_r has 1 at slot t if (t - r) is a source for t.
+
+     # Example: t=0, s={0, 3, 8}.
+     # s=0: rot=0. Mask[0]=1.
+     # s=3: rot=-3. Mask[0]=1.
+     # s=8: rot=-8. Mask[0]=1.
+     # Example: t=1, s={1, 5}.
+     # s=1: rot=0. Mask[1]=1.
+     # s=5: rot=-4. Mask[1]=1.
+
+     # Combined:
+     # Rot 0: Mask[0]=1, Mask[1]=1. -> Mask = [1, 1, 0...]
+     # Rot -3: Mask[0]=1. -> Mask = [1, 0...]
+     # Rot -8: Mask[0]=1. -> Mask = [1, 0...]
+     # Rot -4: Mask[1]=1. -> Mask = [0, 1, 0...]
+
+     rotations = {}  # shift -> mask_list
+
+     for target, sources in aggregations:
+         for src in sources:
+             shift = src - target
+             if shift not in rotations:
+                 rotations[shift] = [0] * vector_size
+             rotations[shift][target] = 1
+
+     final_result = None
+
+     for shift, mask_list in rotations.items():
+         # Optimization: Skip if mask is all zeros (no contribution from this rotation)
+         if not any(mask_list):
+             continue
+
+         # Create mask plaintext
+         # In a real implementation, we encode this list to a Plaintext
+         mask_tensor = tensor.constant(np.array(mask_list, dtype=np.int64))
+         mask_pt = bfv.encode(mask_tensor, encoder)
+
+         # Rotate
+         if shift == 0:
+             rotated_c = ciphertext
+         else:
+             rotated_c = bfv.rotate(ciphertext, shift, galois_keys)
+
+         # Mask
+         masked_c = bfv.mul(rotated_c, mask_pt)
+
+         # Accumulate
+         if final_result is None:
+             final_result = masked_c
+         else:
+             final_result = bfv.add(final_result, masked_c)
+
+     return final_result
+
+
+ def masked_aggregate(ciphertexts: list[Any], masks: list[Any]) -> Any:
+     """Aggregate multiple partial results using masks.
+
+     Args:
+         ciphertexts: List of ciphertexts.
+         masks: List of plaintexts (masks).
+
+     Returns:
+         Sum(ct * mask)
+     """
+     if not ciphertexts or not masks:
+         raise ValueError("Empty input lists")
+     if len(ciphertexts) != len(masks):
+         raise ValueError("Mismatch in ciphertexts and masks length")
+
+     total = None
+
+     for ct, mask in zip(ciphertexts, masks, strict=True):
+         # ct * mask
+         masked = bfv.mul(ct, mask)
+
+         if total is None:
+             total = masked
+         else:
+             total = bfv.add(total, masked)
+
+     return total
+
+
+ # ==============================================================================
+ # SIMD Bucket Packing for Histogram Computation
+ # ==============================================================================
+
+
+ def strided_rotate_and_sum(
+     ciphertext: Any,
+     stride: int,
+     n_elements: int,
+     galois_keys: Any,
+     max_step: int = 1024,
+ ) -> Any:
+     """Aggregate elements at positions [0, stride, 2*stride, ...] into slot 0.
+
+     This is used for SIMD bucket packing where each bucket's values are
+     placed at strided positions.
+
+     Args:
+         ciphertext: The BFV ciphertext with values at strided positions.
+         stride: Distance between consecutive elements to sum.
+         n_elements: Number of elements to aggregate (at positions 0, stride, ..., (n-1)*stride).
+         galois_keys: Rotation keys.
+         max_step: Maximum rotation step for safe_rotate.
+
+     Returns:
+         Ciphertext where slot 0 contains sum of strided elements.
+
+     Example:
+         stride=64, n_elements=47 (bucket has 47 samples)
+         Values at slots: 0, 64, 128, 192, ...
+         Result: slot[0] = sum of all these values
+     """
+     if n_elements <= 1:
+         return ciphertext
+
+     # Use recursive doubling with strided rotations
+     # Step 1: rotate by stride, add -> pairs summed at even positions
+     # Step 2: rotate by 2*stride, add -> quads summed at positions 0, 4*stride, ...
+     # ...
+     num_steps = math.ceil(math.log2(n_elements))
+     current = ciphertext
+
+     for i in range(num_steps):
+         step = stride * (1 << i)
+         if step >= n_elements * stride:
+             break
+         rotated = _safe_rotate(current, step, galois_keys, max_step)
+         current = bfv.add(current, rotated)
+
+     return current
+
+
+ def batch_bucket_aggregate(
+     ciphertext: Any,
+     n_buckets: int,
+     samples_per_bucket: int,
+     galois_keys: Any,
+     slot_count: int = 4096,
+ ) -> Any:
+     """Aggregate samples within each bucket region in a packed ciphertext.
+
+     Assumes the ciphertext has the following layout:
+         - slot_count is divided into n_buckets regions of size `stride = slot_count // n_buckets`
+         - Each bucket b occupies slots [b*stride, b*stride + samples_per_bucket)
+         - Samples are placed at consecutive positions within their bucket region
+
+     After aggregation, slot[b * stride] contains sum of bucket b.
+
+     Args:
+         ciphertext: Packed ciphertext with samples in bucket regions.
+         n_buckets: Number of buckets.
+         samples_per_bucket: Max samples per bucket (for rotation count).
+         galois_keys: Rotation keys.
+         slot_count: Total BFV slots.
+
+     Returns:
+         Ciphertext where slot[b * stride] = sum of bucket b's values.
+     """
+     if samples_per_bucket <= 1:
+         return ciphertext
+
+     # Use recursive doubling within each bucket region
+     # Since all buckets use the same relative positions, one set of rotations
+     # aggregates ALL buckets simultaneously!
+     num_steps = math.ceil(math.log2(samples_per_bucket))
+     current = ciphertext
+
+     for i in range(num_steps):
+         step = 1 << i
+         if step >= samples_per_bucket:
+             break
+         # Rotating by `step` shifts values within each bucket region
+         # Add original + rotated to sum pairs/quads/etc.
+         rotated = _safe_rotate(current, step, galois_keys)
+         current = bfv.add(current, rotated)
+
+     return current
+
+
+ def extract_bucket_results(
+     vector: Any,
+     n_buckets: int,
+     slot_count: int = 4096,
+ ) -> Any:
+     """Extract bucket sums from a packed result vector.
+
+     After batch_bucket_aggregate, each bucket's sum is at slot[b * stride].
+     This function extracts those values.
+
+     Args:
+         vector: Decoded vector from packed ciphertext.
+         n_buckets: Number of buckets.
+         slot_count: Total slots.
+
+     Returns:
+         (n_buckets,) array of bucket sums.
+     """
+     import jax.numpy as jnp
+
+     stride = slot_count // n_buckets
+     indices = jnp.arange(n_buckets) * stride
+     return vector[indices]
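The rotation tricks in the aggregation module can be checked without any encryption by simulating one batching row as a plain vector. The sketch below assumes that `bfv.rotate(ct, step, ...)` with a positive `step` moves slot `i + step` into slot `i`, the left-rotation convention of SEAL's `rotate_rows`. It models this as `np.roll(x, -step)` and ignores the two-row layout. Under that convention the code's `shift = src - target` (for example +3 for target 0, source 3) is what places source slot `s` into target slot `t`, even though the worked comments in `aggregate_sparse` list these shifts with the opposite sign. All function names here are hypothetical plaintext stand-ins.

```python
import math

import numpy as np


def rotate(x: np.ndarray, step: int) -> np.ndarray:
    # Assumed convention: slot i receives slot (i + step), a cyclic left rotation.
    return np.roll(x, -step)


def rotate_and_sum_row(x: np.ndarray, k: int) -> np.ndarray:
    # Recursive doubling: after round i, slot j holds x[j] + ... + x[j + 2**(i + 1) - 1].
    if k <= 1:
        return x
    current = x
    for i in range(math.ceil(math.log2(k))):
        current = current + rotate(current, 1 << i)
    return current


def aggregate_sparse(x: np.ndarray, aggregations: list[tuple[int, list[int]]]) -> np.ndarray:
    # One rotation per distinct shift (src - target), masked to the target slots.
    masks: dict[int, np.ndarray] = {}
    for target, sources in aggregations:
        for src in sources:
            masks.setdefault(src - target, np.zeros_like(x))[target] = 1
    result = np.zeros_like(x)
    for shift, mask in masks.items():
        result = result + rotate(x, shift) * mask
    return result


def batch_bucket_aggregate(x: np.ndarray, n_buckets: int, samples_per_bucket: int) -> np.ndarray:
    # Same doubling loop; one set of rotations sums every bucket region at once.
    # Precondition: 2**ceil(log2(samples_per_bucket)) <= stride and unused slots are
    # zero, otherwise a bucket's sum bleeds into the next region.
    stride = len(x) // n_buckets
    summed = rotate_and_sum_row(x, samples_per_bucket)
    return summed[np.arange(n_buckets) * stride]  # combines the extract step


x = np.zeros(16, dtype=np.int64)
x[:5] = [3, 1, 4, 1, 5]
row_total = rotate_and_sum_row(x, 5)[0]  # 14: sums slots 0..7, zero beyond k

y = np.arange(10, dtype=np.int64)
sparse = aggregate_sparse(y, [(0, [0, 3, 8]), (1, [1, 5])])  # slot 0: 0+3+8, slot 1: 1+5

z = np.zeros(16, dtype=np.int64)
for b in range(4):
    z[b * 4 : b * 4 + 3] = b + 1  # 3 samples of value b+1 in each stride-4 region
buckets = batch_bucket_aggregate(z, n_buckets=4, samples_per_bucket=3)  # [3, 6, 9, 12]
```

The `k = 5` case shows why `rotate_and_sum` requires zeros (or a later mask) outside slots `0..k-1`. Doubling always sums a power-of-two window, here slots 0 through 7.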
@@ -0,0 +1,99 @@
+ # Oblivious Group-by Sum Design
+
+ This document outlines the design for Oblivious Group-by Sum algorithms in MPLang. The goal is to compute the sum of values in `data` (held by P0) grouped by `bins` (held by P1), such that:
+ - P0 learns nothing about the `bins` (permutation/grouping).
+ - P1 learns nothing about the `data` values (except the final aggregated sums).
+ - The result is revealed to P1 (or shared).
+
+ We propose two approaches based on the trade-off between communication and computation, and the cardinality of groups ($K$).
+
+ ## Interface
+
+ ```python
+ def oblivious_groupby_sum(
+     data: Plaintext[P0],
+     bins: Plaintext[P1],
+     K: int,
+     method: str = "auto"
+ ) -> Plaintext[P1]:
+     """
+     Args:
+         data: Input data vector held by P0.
+         bins: Bin assignments for each data element held by P1.
+             Values must be in [0, K).
+         K: The number of bins (groups).
+         method: "bfv" (HE-based) or "shuffle" (OT-based).
+
+     Returns:
+         A vector of length K held by P1 containing the sum of data for each bin.
+     """
+ ```
+
+ ## Approach 1: HE-based (BFV SIMD)
+
+ Best for: **Small K** (e.g., $K < 1000$), Low Bandwidth.
+
+ ### Algorithm
+
+ 1. **Encryption (P0)**:
+    - P0 encrypts `data` using a BFV scheme with SIMD packing.
+    - Sends ciphertext(s) `Enc(data)` to P1.
+
+ 2. **Aggregation (P1)**:
+    - P1 holds `bins`. For each bin $k \in [0, K)$:
+      - Construct a plaintext mask vector $M_k$ where $M_k[i] = 1$ if $bins[i] == k$, else $0$.
+      - Compute homomorphic multiplication: $Enc(Sum_k) = Enc(data) \otimes M_k$.
+      - Sum the slots in $Enc(Sum_k)$ to get the total sum for bin $k$.
+    - *Optimization*: Instead of full slot summation for every bin (which is expensive), P1 can just compute the element-wise product and accumulate. The final reduction can be done by sending back to P0 or using rotations if $K$ is small enough to pack into result ciphertexts.
+    - *Simplified Flow*: P1 computes $Enc(Partial_k) = Enc(data) \cdot M_k$. P1 sends these $K$ ciphertexts (or batched versions) back to P0.
+
+ 3. **Decryption & Finalize (P0 -> P1)**:
+    - P0 decrypts the partial sums.
+    - P0 computes the sum of the vector for each bin.
+    - P0 sends the final $K$ sums to P1.
+    - *Privacy Note*: To prevent P0 from learning the partial sums (which reveals data distribution), P1 should add a random mask to the result before sending to P0, or use a proper threshold decryption if available. For the "Simplified Flow" above, P0 sees the masked data values. This might leak info.
+    - *Refined Privacy Flow*:
+      - P1 computes $Enc(V_k) = Enc(data) \cdot M_k$.
+      - P1 computes $Enc(S_k) = \text{TotalSum}(Enc(V_k))$ using rotations and additions.
+      - P1 masks $Enc(S_k)$ with a random value $r_k$: $Enc(O_k) = Enc(S_k) + Enc(r_k)$.
+      - P1 sends $Enc(O_k)$ to P0.
+      - P0 decrypts to get $O_k = S_k + r_k$ and sends back to P1.
+      - P1 subtracts $r_k$ to get $S_k$.
+
+ ### Complexity
+ - **Comm**: $O(N/B)$ ciphertexts (P0->P1) + $O(K)$ ciphertexts (P1->P0). ($B$ is batch size).
+ - **Comp**: $O(K \cdot N/B)$ homomorphic multiplications and additions.
+
+ ## Approach 2: OT-based (Shuffle + Prefix Sum)
+
+ Best for: **Large K**, High Bandwidth.
+
+ ### Algorithm
+
+ 1. **Sort Permutation (P1)**:
+    - P1 calculates a permutation $\pi$ that sorts `data` according to `bins`.
+    - P1 calculates boundary indices for each bin.
+
+ 2. **Oblivious Shuffle (P0, P1)**:
+    - Use a Benes network or similar switching network.
+    - P0 inputs `data`. P1 inputs control bits derived from $\pi$.
+    - Output: Secret shares of permuted data $\langle D' \rangle_0, \langle D' \rangle_1$.
+
+ 3. **Secret Shared Prefix Sum (P0, P1)**:
+    - Locally compute prefix sums of shares: $\langle S \rangle_0 = \text{cumsum}(\langle D' \rangle_0)$, $\langle S \rangle_1 = \text{cumsum}(\langle D' \rangle_1)$.
+
+ 4. **Oblivious Gather (P0, P1)**:
+    - P1 knows the boundary indices $idx_k$.
+    - P1 needs $S[idx_k] = \langle S \rangle_0[idx_k] + \langle S \rangle_1[idx_k]$.
+    - P1 has $\langle S \rangle_1[idx_k]$ locally.
+    - To get $\langle S \rangle_0[idx_k]$ obliviously:
+      - Use another permutation network or ORAM to fetch these values without revealing $idx_k$ to P0.
+      - Or, since P1 is the result receiver, we can use a simpler selection protocol if we don't hide the access pattern from P0 (but we must hide it to protect bin sizes).
+      - A second shuffle network mapping $idx_k \to k$ is secure.
+
+ 5. **Difference (P1)**:
+    - P1 computes $Result[k] = S[idx_k] - S[idx_{k-1}]$.
+
+ ### Complexity
+ - **Comm**: $O(N \log N)$ bits for shuffle.
+ - **Comp**: Low (symmetric crypto).
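To make the two approaches in the design doc concrete, here is a plaintext NumPy reference that performs no encryption, OT, or switching network. It reproduces only what each approach computes and reveals to P1. Approach 1 is modeled as per-bin 0/1 masks followed by the masked reveal `O_k = S_k + r_k`. Approach 2 is modeled as a stable sort permutation, additive shares, local prefix sums, and boundary differences. The function names, the plaintext modulus `t = 65537`, and the share ring Z_{2^32} are assumptions for illustration, not MPLang APIs.

```python
import numpy as np

SHARE_MOD = 1 << 32  # assumed ring Z_{2^32} for the additive shares


def groupby_sum_masks(
    data: np.ndarray, bins: np.ndarray, K: int, t: int = 65537, seed: int = 0
) -> np.ndarray:
    """Approach 1 in the clear: per-bin 0/1 masks M_k, then the masked reveal."""
    rng = np.random.default_rng(seed)
    sums = np.array([int((data * (bins == k)).sum()) for k in range(K)]) % t  # S_k
    r = rng.integers(0, t, size=K)  # P1's random masks r_k
    opened = (sums + r) % t  # O_k = S_k + r_k, the only value P0 decrypts
    return (opened - r) % t  # P1 subtracts r_k


def groupby_sum_shuffle(data: np.ndarray, bins: np.ndarray, K: int, seed: int = 1) -> np.ndarray:
    """Approach 2 in the clear: sort permutation, shares, prefix sums, differences."""
    rng = np.random.default_rng(seed)
    pi = np.argsort(bins, kind="stable")  # P1's permutation that sorts data by bin
    permuted = data[pi] % SHARE_MOD  # D' (a shuffle network would output only shares)
    share0 = rng.integers(0, SHARE_MOD, size=len(data))
    share1 = (permuted - share0) % SHARE_MOD  # <D'>_0 + <D'>_1 = D' (mod 2^32)
    zero = np.zeros(1, dtype=np.int64)
    s0 = np.concatenate([zero, np.cumsum(share0) % SHARE_MOD])  # local prefix sums,
    s1 = np.concatenate([zero, np.cumsum(share1) % SHARE_MOD])  # with S[-1] = 0 prepended
    counts = np.bincount(bins, minlength=K)
    ends = np.cumsum(counts)  # idx_k + 1, shifted by the prepended zero
    starts = ends - counts  # equals ends of bin k-1, so empty bins give 0
    opened = (s0 + s1) % SHARE_MOD  # the protocol would open only these boundary slots
    return (opened[ends] - opened[starts]) % SHARE_MOD


data = np.array([5, 1, 7, 3, 2, 8])
bins = np.array([2, 0, 2, 1, 0, 2])
K = 4  # bin 3 is deliberately empty
masked = groupby_sum_masks(data, bins, K)
shuffled = groupby_sum_shuffle(data, bins, K)  # both give [3, 3, 20, 0]
```

Prepending a zero to each prefix-sum array is one way to give step 5's `S[idx_{k-1}]` a defined value for `k = 0`. Computing boundaries from bin counts also makes empty bins come out as 0 without special cases.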