mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,200 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Unbalanced PSI Protocol.
16
+
17
+ This module implements unbalanced PSI for scenarios where client set size n << server set size N.
18
+ Uses Seeded OKVS (via derived keys) to prevent pre-computation attacks.
19
+
20
+ Security Model:
21
+ - Session-specific random seed generated at RUNTIME on the Server.
22
+ - Both Key and Value derivations use the seed for consistent security.
23
+ - WARNING: Online dictionary attacks by active clients remain possible without OPRF.
24
+
25
+ Protocol:
26
+ 1. Server generates random Seed at runtime.
27
+ 2. Server computes K' = H_key(ServerItems, Seed) and V = H_val(ServerItems, Seed) (domain-separated hashes).
28
+ 3. Server solves OKVS: Table = Solve(K', V).
29
+ 4. Server sends Seed + Table to Client.
30
+ 5. Client computes k' = H_key(ClientItems, Seed) and v = H_val(ClientItems, Seed).
31
+ 6. Client decodes V' = Decode(k', Table).
32
+ 7. Client checks V' == v.
33
+ """
34
+
35
+ from typing import Any, cast
36
+
37
+ import jax.numpy as jnp
38
+
39
+ import mplang.v2.edsl as el
40
+ import mplang.v2.edsl.typing as elt
41
+ from mplang.v2.dialects import crypto, field, simp, tensor
42
+ from mplang.v2.libs.mpc.psi.okvs_gct import get_okvs_expansion
43
+
44
+
45
def psi_unbalanced(
    server: int,
    client: int,
    server_n: int,
    client_n: int,
    server_items: el.Object,
    client_items: el.Object,
) -> el.Object:
    """Unbalanced PSI: the server publishes a seeded OKVS, the client queries it.

    Intended for scenarios where ``client_n << server_n``. The only message in
    the protocol goes server -> client. It carries the 128-bit session seed and
    an OKVS table of size ``M = int(server_n * get_okvs_expansion(server_n))``.
    Communication therefore grows with ``server_n``, not ``client_n``. The
    client only hashes and decodes its own ``client_n`` items and sends nothing
    back.

    Security:
    - A fresh 128-bit session seed is drawn at RUNTIME on the server (not at
      trace time), so the table cannot be precomputed offline ahead of the
      session (no rainbow-table style precomputation).
    - Both the OKVS key and the check value are derived from the seed, using
      distinct domain separators.
    - WARNING: online dictionary attacks by active clients remain possible.

    > [!WARNING]
    > **Security Notice**: This protocol sends the Session Seed to the Client to
    > allow them to compute the OKVS lookups. A malicious Client can therefore
    > perform an online dictionary attack (brute-force hashing) to enumerate
    > Server items. For strict set privacy against malicious clients, use
    > OPRF-PSI (`oprf.py` based) instead of this unbalanced protocol.

    Args:
        server: Rank of the server (holds the large set).
        client: Rank of the client (holds the small set).
        server_n: Size of the server's set; also determines the OKVS table size.
        client_n: Size of the client's set.
        server_items: (server_n,) uint64 on server.
        client_items: (client_n,) uint64 on client.

    Returns:
        Intersection indicators on the client: a (client_n,) uint8 tensor. Entry
        i is 1 when the value decoded from the OKVS for client item i equals the
        client's own check value for that item, and 0 otherwise.

    Raises:
        ValueError: If ``server == client`` or either set size is not positive.
    """
    if server == client:
        raise ValueError("Server and Client must be different parties.")

    if client_n <= 0 or server_n <= 0:
        raise ValueError("Set sizes must be positive.")

    # ------------------------------------------------------------------
    # 1. Server: session seed, sampled when the graph EXECUTES on the server
    #    party rather than being baked in as a constant during tracing.
    #    (2,) u64 == 128 bits.
    # ------------------------------------------------------------------
    def _gen_runtime_seed() -> Any:
        return crypto.random_tensor((2,), elt.u64)

    server_seed = simp.pcall_static((server,), _gen_runtime_seed)

    def _compute_hashes(items: Any, seed: Any) -> tuple[Any, Any]:
        """Derive OKVS keys K' and check values V for ``items`` under ``seed``.

        Key:   K' = AES_Expand(mix(Item, Seed, KEY_DOMAIN))[:, 0, 0]   (64 bits)
        Value: V  = AES_Expand(mix(Item, Seed, VAL_DOMAIN)) as (N, 2)  (128 bits)
        """
        # Domain separators keep the key and value derivations distinct.
        KEY_DOMAIN = jnp.uint64(0xA5A5A5A5A5A5A5A5)
        VAL_DOMAIN = jnp.uint64(0x5A5A5A5A5A5A5A5A)

        def _seed_mixer(domain: Any) -> Any:
            # Returns a JAX function mapping x: (N,) u64 and s: (2,) u64 to
            # (N, 2) u64 AES seeds. uint64 addition wraps modulo 2**64.
            def _mix(x: Any, s: Any) -> Any:
                lo = (x + s[0]) ^ domain
                hi = (x ^ s[1]) + domain
                return jnp.stack([lo, hi], axis=1)

            return _mix

        # Keys: keep only the first u64 word of a single expanded block.
        key_seeds = tensor.run_jax(_seed_mixer(KEY_DOMAIN), items, seed)
        h_keys_raw = field.aes_expand(key_seeds, 1)  # (N, 1, 2)
        keys = tensor.run_jax(lambda h: h[:, 0, 0], h_keys_raw)

        # Values: the full 128-bit expanded block, flattened to (N, 2).
        val_seeds = tensor.run_jax(_seed_mixer(VAL_DOMAIN), items, seed)
        h_vals_raw = field.aes_expand(val_seeds, 1)  # (N, 1, 2)
        vals = tensor.run_jax(lambda h: h.reshape(h.shape[0], 2), h_vals_raw)

        return keys, vals

    # ------------------------------------------------------------------
    # 2. Server: derive (K', V) for its items and encode them into an OKVS.
    # ------------------------------------------------------------------
    server_derived_keys, server_values = simp.pcall_static(
        (server,), _compute_hashes, server_items, server_seed
    )

    expansion = get_okvs_expansion(server_n)
    M = int(server_n * expansion)  # OKVS table size; this is what gets sent.

    def _solve(k: Any, v: Any, s: Any) -> Any:
        return field.solve_okvs(k, v, M, s)

    okvs_table = simp.pcall_static(
        (server,), _solve, server_derived_keys, server_values, server_seed
    )

    # ------------------------------------------------------------------
    # 3. Server -> client: the table and the seed (the only communication).
    # ------------------------------------------------------------------
    okvs_table_client = simp.shuffle_static(okvs_table, {client: server})
    client_seed = simp.shuffle_static(server_seed, {client: server})

    # ------------------------------------------------------------------
    # 4. Client: derive (k', v) with the SAME function, decode, and compare.
    # ------------------------------------------------------------------
    client_derived_keys, client_expected_values = simp.pcall_static(
        (client,), _compute_hashes, client_items, client_seed
    )

    def _decode_and_compare(keys: Any, table: Any, expected: Any, s: Any) -> Any:
        decoded = field.decode_okvs(keys, table, s)

        def _compare_jax(dec: Any, exp: Any) -> Any:
            # A hit requires both u64 words of the 128-bit value to match.
            match = (dec[:, 0] == exp[:, 0]) & (dec[:, 1] == exp[:, 1])
            return match.astype(jnp.uint8)

        return tensor.run_jax(_compare_jax, decoded, expected)

    intersection_mask = simp.pcall_static(
        (client,),
        _decode_and_compare,
        client_derived_keys,
        okvs_table_client,
        client_expected_values,
        client_seed,
    )

    return cast(el.Object, intersection_mask)
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Vector Oblivious Linear Evaluation (VOLE) protocols.
16
+
17
+ Submodules:
18
+ - gilboa: Gilboa VOLE protocol
19
+ - silver: Silver VOLE (LDPC-based)
20
+ - ldpc: LDPC matrix operations
21
+ """
22
+
23
+ from .gilboa import vole
24
+ from .silver import estimate_silver_communication, silver_vole, silver_vole_ldpc
25
+
26
# Public re-exports of the VOLE package (alphabetical order).
__all__ = [
    "estimate_silver_communication",
    "silver_vole",
    "silver_vole_ldpc",
    "vole",
]
@@ -0,0 +1,327 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Vector Oblivious Linear Evaluation (VOLE) Protocol.
16
+
17
+ Implements the Gilboa protocol for VOLE over GF(2^k).
18
+ Global SIMP implementation.
19
+ """
20
+
21
+ from collections.abc import Callable
22
+ from typing import Any, cast
23
+
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+ import mplang.v2.edsl as el
28
+ import mplang.v2.libs.mpc.ot.extension as ot
29
+ from mplang.v2.dialects import field, simp, tensor
30
+
31
+
32
def vole(
    sender: int,
    receiver: int,
    n: int,
    u_provider: Callable[[], el.Object],
    delta_provider: Callable[[], el.Object],
    return_secrets: bool = False,
) -> tuple[el.Object, el.Object] | tuple[el.Object, el.Object, el.Object, el.Object]:
    """Execute the Gilboa VOLE protocol over GF(2^128).

    Protocol outline (XOR is addition in GF(2^128)):

    1. The Receiver decomposes ``delta`` into 128 bits d_i, where bit i is the
       coefficient of x^i.
    2. 128 IKNP OTs, with the d_i as choice bits, give the Sender a matrix T
       and base choices s. They give the Receiver Q with
       ``Q[i] = T[i] ^ (d_i * s)``. The Sender therefore holds both seeds
       ``T[i]`` and ``T[i] ^ s``, and the Receiver holds the one selected by d_i.
    3. Each seed is AES-expanded to a length-n vector:
       ``V0_i = PRG(T[i])``, ``V1_i = PRG(T[i] ^ s)``, ``W_i = PRG(Q[i])``.
       Hence ``W_i = V0_i`` when d_i = 0 and ``W_i = V1_i`` when d_i = 1.
    4. The Sender sends ``M_i = V0_i ^ V1_i ^ (u * x^i)`` and keeps
       ``v = XOR_i V0_i``.
    5. The Receiver computes ``w = XOR_i (W_i ^ d_i * M_i)``. Each term equals
       ``V0_i ^ d_i * (u * x^i)``, so ``w = v ^ u * sum_i d_i x^i = v ^ u * delta``
       (assuming ``field.mul`` is GF(2^128) multiplication).

    Args:
        sender: Rank of Sender.
        receiver: Rank of Receiver.
        n: Vector length.
        u_provider: Callable running on Sender returning u (N, 2).
        delta_provider: Callable running on Receiver returning delta (2,).
        return_secrets: If True, returns (v, w, u, delta).

    Returns:
        If return_secrets=False:
            v: Vector on Sender (N, 2).
            w: Vector on Receiver (N, 2).
        If return_secrets=True:
            v, w, u, delta

    Raises:
        ValueError: If ``sender`` and ``receiver`` are the same party.
    """
    if sender == receiver:
        raise ValueError("Sender and Receiver must be different parties.")

    # One base correlation per bit of a GF(2^128) element.
    K = 128

    # --- 1. Receiver: obtain delta and split it into choice bits -------------
    def _recv_prep() -> tuple[el.Object, el.Object]:
        delta = delta_provider()  # (2,) u64

        def _unpack(d: Any) -> Any:
            # View the two u64 words as 16 bytes and unpack LSB-first. This
            # assumes a little-endian host, so that bit i of the result is bit i
            # of the lo word (i < 64) or bit i-64 of the hi word, matching the
            # x^i table built in _sender_round.
            return jnp.unpackbits(d.view(jnp.uint8), bitorder="little")

        bits_u8 = tensor.run_jax(_unpack, delta)  # (128,) u8
        # Column shape (128, 1), as passed to ot.iknp_core as choices.
        bits_col = tensor.run_jax(lambda x: x.reshape(128, 1), bits_u8)
        return delta, bits_col

    # pcall_static returns the tuple directly (same pattern as _sender_round below).
    delta_recv, delta_bits = simp.pcall_static((receiver,), _recv_prep)

    # --- 2. IKNP core: 128 random OTs chosen by the delta bits ---------------
    # t_matrix_128 and s_choices stay on the Sender and q_matrix_128 on the
    # Receiver. The protocol relies on Q[i] = T[i] ^ (d_i * s).
    t_matrix_128, q_matrix_128, s_choices = ot.iknp_core(
        delta_bits, sender, receiver, K
    )

    # --- 3/4. Sender: obtain u, expand seeds, build corrections M and v ------
    def _sender_input() -> el.Object:
        return u_provider()

    u_loc_captured = simp.pcall_static((sender,), _sender_input)

    m_corrections, v_sender = simp.pcall_static(
        (sender,), _sender_round, t_matrix_128, s_choices, u_loc_captured, n
    )

    # --- 5. Send M to the Receiver, who reconstructs w -----------------------
    m_recv = simp.shuffle_static(m_corrections, {receiver: sender})

    w_receiver = simp.pcall_static(
        (receiver,), _recv_round, q_matrix_128, m_recv, delta_bits, n
    )

    if return_secrets:
        return v_sender, w_receiver, u_loc_captured, delta_recv
    return v_sender, w_receiver
192
+
193
+
194
+ # A. Expand (Sender)
195
def _sender_round(
    t_loc: el.Object, s_loc: el.Object, u_loc: el.Object, n: int
) -> tuple[el.Object, el.Object]:
    """Sender half of Gilboa VOLE: compute the corrections M and the share v.

    Args:
        t_loc: (128, 128) bit matrix from IKNP; row i is the seed T[i].
        s_loc: (128,) base-OT choice bits s.
        u_loc: (N, 2) u64 input vector u.
        n: Vector length N used for AES expansion.

    Returns:
        ``(m, v)``. ``m`` is (128, N, 2) with ``m[i] = V0_i ^ V1_i ^ (u * x^i)``.
        ``v`` is (N, 2), the XOR of all ``V0_i``.
    """

    def _seed_pair(t: Any, s: Any) -> tuple[Any, Any]:
        # Pack bit rows into 16-byte seeds: T[i] and T[i] ^ s.
        packed_t = jnp.packbits(t, axis=-1)  # (128, 16) uint8
        packed_s = jnp.packbits(s, axis=-1)[None, :]  # (1, 16) uint8
        return packed_t, jnp.bitwise_xor(packed_t, packed_s)

    raw0, raw1 = tensor.run_jax(_seed_pair, t_loc, s_loc)
    seeds0 = cast(el.Object, raw0)
    seeds1 = cast(el.Object, raw1)

    # PRG expansion of both seed families to length-n vectors.
    v0_blocks = field.aes_expand(seeds0, n)
    v1_blocks = field.aes_expand(seeds1, n)

    # Constant table of x^0 .. x^127 as (lo, hi) u64 words.
    basis = np.array(
        [[1 << i, 0] if i < 64 else [0, 1 << (i - 64)] for i in range(128)],
        dtype=np.uint64,
    )
    basis_const = tensor.constant(basis)

    def _align(u_val: Any, p_val: Any) -> tuple[Any, Any]:
        # Replicate u over the 128 powers and each power over the N rows, so
        # a single batched field.mul yields u * x^i for every i.
        rows = u_val.shape[0]
        u_rep = jnp.tile(u_val[None, :, :], (128, 1, 1))  # (128, N, 2)
        p_rep = jnp.tile(p_val[:, None, :], (1, rows, 1))  # (128, N, 2)
        return u_rep, p_rep

    u_rep, p_rep = tensor.run_jax(_align, u_loc, basis_const)
    scaled = field.mul(u_rep, p_rep)  # (128, N, 2): u * x^i

    def _combine(v0: Any, v1: Any, term: Any) -> tuple[Any, Any]:
        corrections = v0 ^ v1 ^ term
        acc = v0[0]
        for row in range(1, 128):
            acc = acc ^ v0[row]
        return corrections, acc

    m_corr, v_share = tensor.run_jax(_combine, v0_blocks, v1_blocks, scaled)
    return cast(el.Object, m_corr), cast(el.Object, v_share)
278
+
279
+
280
+ # B. Expand & Reconstruct (Receiver)
281
def _recv_round(
    q_loc: el.Object, m_loc: el.Object, d_bits: el.Object, n: int
) -> el.Object:
    """Receiver half of Gilboa VOLE: reconstruct w = XOR_i (W_i ^ d_i * M_i).

    Args:
        q_loc: (128, 128) bit matrix from IKNP; row i is the seed Q[i].
        m_loc: (128, N, 2) corrections received from the Sender.
        d_bits: (128, 1) choice bits (the bits of delta).
        n: Vector length N used for AES expansion.

    Returns:
        The (N, 2) vector w held by the Receiver.
    """
    # Pack bit rows into 16-byte seeds and expand them like the Sender does.
    packed_q = tensor.run_jax(lambda q: jnp.packbits(q, axis=-1), q_loc)
    w_blocks = field.aes_expand(packed_q, n)  # (128, N, 2)

    def _fold(w_exp: Any, m_val: Any, d_b: Any) -> Any:
        # Keep M_i only where d_i == 1; elsewhere contribute zeros.
        select = d_b.reshape(128).reshape(128, 1, 1).astype(bool)
        picked = jnp.where(select, m_val, jnp.zeros_like(m_val))

        per_row = w_exp ^ picked
        total = per_row[0]
        for idx in range(1, 128):
            total = total ^ per_row[idx]
        return total

    return cast(el.Object, tensor.run_jax(_fold, w_blocks, m_loc, d_bits))
314
+
315
+
316
+ def _gen_powers_of_x_jax(dummy: Any, k: int = 128) -> Any:
317
+ # JAX version for use inside run_jax (returns jnp.array)
318
+ # dummy is required for run_jax tracing anchor
319
+ rows = []
320
+ for i in range(k):
321
+ lo, hi = 0, 0
322
+ if i < 64:
323
+ lo = 1 << i
324
+ else:
325
+ hi = 1 << (i - 64)
326
+ rows.append([lo, hi])
327
+ return jnp.array(rows, dtype=jnp.uint64)