mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,477 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """OT Extension (IKNP).
16
+
17
+ Implements IKNP OT extension protocol to perform N OTs using k Base OTs.
18
+ Ref: https://crypto.stanford.edu/~valeria/research/2003/IKNP03.pdf
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from typing import Any, cast
24
+
25
+ import jax.numpy as jnp
26
+
27
+ import mplang.v2.edsl as el
28
+ from mplang.v2.dialects import crypto, field, simp, tensor
29
+
30
+
31
def prg_expand(seed_tensor: el.Object, length: int) -> el.Object:
    """Pseudo-Random Generator: Expand seed to `length` bits (as uint8 0/1).

    Uses AES-NI via field.aes_expand for cryptographic security.

    Args:
        seed_tensor: (K, 32) uint8 seed material; only the first 16 bytes of
            each row are used as the 128-bit AES seed.
        length: Number of keystream bits to produce per row.

    Returns:
        (K, length) tensor of 0/1 bit values.
    """
    # Number of 128-bit AES blocks needed to cover `length` bits (ceil division).
    num_blocks = (length + 127) // 128

    def _seed_bytes_to_u64(raw: Any) -> Any:
        # raw: (K, 32) u8 -> take the leading 16 bytes and reinterpret as
        # (K, 2) u64, the layout field.aes_expand expects.
        return raw[:, :16].view(jnp.uint64).reshape(-1, 2)

    packed_seeds = tensor.run_jax(_seed_bytes_to_u64, seed_tensor)

    # AES keystream: (K, num_blocks, 2) u64 blocks.
    keystream = field.aes_expand(packed_seeds, num_blocks)

    def _unpack_to_bits(blocks: Any) -> Any:
        # Reinterpret each 128-bit block as 16 bytes, then as 128 bits.
        as_bytes = blocks.view(jnp.uint8)  # (K, M, 16)
        as_bits = jnp.unpackbits(as_bytes, axis=-1, bitorder="little")  # (K, M, 128)
        # Flatten the block dimensions and trim to the requested bit count.
        return as_bits.reshape(as_bits.shape[0], -1)[:, :length]

    return cast(el.Object, tensor.run_jax(_unpack_to_bits, keystream))
70
+
71
+
72
def vec_hash(data_bytes: el.Object, domain_sep: int, num_rows: int) -> el.Object:
    """Hash rows of a (N, D) tensor independently.

    Args:
        data_bytes: (N, D) tensor to hash.
        domain_sep: Integer domain separator to mix into the hash.
        num_rows: Number of rows N. Must be provided explicitly.

    Returns:
        (N, 32) tensor of per-row digests.
    """
    # Batch strategy: prepend the domain separator to every row in one
    # vectorized step, then hash the whole matrix with a single hash_batch
    # primitive so only one graph node is emitted (avoids compiler explosion
    # from per-row hashing).
    if domain_sep == 0:
        payload = data_bytes
    else:

        def _with_domain_sep(rows: Any, ds: int) -> Any:
            # rows: (N, D). Encode ds as 8 little-endian bytes and broadcast
            # it down the first axis, yielding a (N, 8) prefix block.
            prefix = jnp.array([ds], dtype=jnp.uint64).view(jnp.uint8)  # (8,)
            tiled = jnp.broadcast_to(prefix, (rows.shape[0], 8))  # (N, 8)
            return jnp.concatenate([tiled, rows], axis=1)  # (N, D + 8)

        payload = tensor.run_jax(lambda a: _with_domain_sep(a, domain_sep), data_bytes)

    # Explicit batched hash primitive (requires rank >= 2 input):
    # (N, D_total) -> (N, 32).
    return crypto.hash_batch(payload)
109
+
110
+
111
def iknp_core(
    choice_bits: el.Object, sender: int, receiver: int, num_ots: int
) -> tuple[el.Object, el.Object, el.Object]:
    """Core IKNP Matrix Generation.

    Runs K=128 Base OTs (inlined EC-based OT) and the IKNP matrix
    extension/correction phase.

    Args:
        choice_bits: Receiver's choice vector; (N,) for standard IKNP or
            (N, K) for KKRT-style OPRF (see calc_u below).
        sender: Rank of the IKNP sender.
        receiver: Rank of the IKNP receiver.
        num_ots: Number N of extended OTs to produce.

    Returns:
        t_matrix: (N, K) bit matrix on Sender.
        q_matrix: (N, K) bit matrix on Receiver.
        s_choices: (K,) choice bits on Sender (s).
    """
    # Security parameter: number of base OTs / matrix width.
    K = 128

    # 1. Base OTs
    def gen_s() -> el.Object:
        # Sender's K random choice bits for the base OTs.
        return crypto.random_bits(K)

    s = simp.pcall_static((sender,), gen_s)

    def gen_seeds() -> tuple[el.Object, el.Object]:
        # Receiver's two seed banks (one per base-OT message slot).
        k0_bytes = crypto.random_bytes(K * 32)
        k1_bytes = crypto.random_bytes(K * 32)

        # Reshape to (K, 32) inside run_jax so XLA can fuse it.
        def _reshape_k32(b: Any) -> Any:
            return b.reshape(K, 32)

        k0 = tensor.run_jax(_reshape_k32, k0_bytes)
        k1 = tensor.run_jax(_reshape_k32, k1_bytes)
        return k0, k1

    k0_base, k1_base = simp.pcall_static((receiver,), gen_seeds)

    # Base OT logic (inlined), Bellare-Micali style with common point C.
    # SECURITY FIX: C must be generated by the Base Sender (the IKNP receiver)
    # to prevent the Base Receiver (the IKNP sender) from knowing C's discrete
    # log, which would let them decrypt both messages and recover choice bits.
    def base_param_gen() -> el.Object:
        C = crypto.ec_mul(crypto.ec_generator(), crypto.ec_random_scalar())
        return C

    C_point = simp.pcall_static((receiver,), base_param_gen)
    # Ship C to the sender; {dst: src} map.
    C_for_sender = simp.shuffle_static(C_point, {sender: receiver})

    # (A duplicate C initialization was previously here and has been removed.)

    # Base-OT receiver keygen, run on the IKNP sender.
    def base_sender_keygen(
        C: el.Object, s_base_choices: el.Object
    ) -> tuple[el.Object, list[el.Object]]:
        # s_base_choices is a (K,) tensor of 0/1 choices.
        PK0_bytes_list = []
        k_priv_list = []

        for i in range(K):
            k_priv = crypto.ec_random_scalar()
            PK_sigma = crypto.ec_mul(crypto.ec_generator(), k_priv)

            # Slice s[i]. slice_tensor is fine here since s_base_choices is
            # small; run_jax overhead is unnecessary.
            s_i = tensor.slice_tensor(s_base_choices, (i,), (i + 1,))
            s_scalar = crypto.ec_scalar_from_int(s_i)

            # PK0 = PK_sigma if s==0 else C - PK_sigma, so the base sender can
            # reconstruct both PK0 and PK1 = C - PK0 without learning s.
            diff = crypto.ec_sub(C, PK_sigma)
            PK0 = crypto.select(s_scalar, diff, PK_sigma)

            # Serialize to 65-byte uncompressed point for shuffling;
            # K is only 128 so the per-point overhead is small.
            pk0_b = crypto.ec_point_to_bytes(PK0)
            # Reshape (65,) -> (1, 65) so rows can be stacked.
            pk0_b_r = tensor.reshape(pk0_b, (1, 65))

            PK0_bytes_list.append(pk0_b_r)
            k_priv_list.append(k_priv)

        # Stack into (K, 65).
        PK0_stacked = tensor.concat(PK0_bytes_list, axis=0)

        return PK0_stacked, k_priv_list

    # base_keys_tuple -> (PK0_stacked, k_priv_list (trace-object list)).
    # C_for_sender (received from the receiver) is passed to the sender.
    base_keys_tuple = simp.pcall_static((sender,), base_sender_keygen, C_for_sender, s)

    # Extract the public part and ship it to the receiver.
    PK0_loc = simp.pcall_static((sender,), lambda x: x[0], base_keys_tuple)
    # Note: k_priv (x[1]) stays on the sender; used later in base_decrypt_rev
    # via base_keys_tuple.
    PK0_recv = simp.shuffle_static(PK0_loc, {receiver: sender})

    # Base-OT sender side (the IKNP receiver) encrypts the seed banks k0, k1.
    def base_encrypt_rev(
        C: el.Object,
        PK0_bytes_tensor: el.Object,
        m0_tensor: el.Object,
        m1_tensor: el.Object,
    ) -> tuple[el.Object, el.Object, el.Object]:
        # m0, m1 are (K, 32) tensors; PK0_bytes_tensor is (K, 65).

        U_bytes_list = []
        c0_list = []
        c1_list = []

        for i in range(K):
            # Unstack row i of PK0: slice (i:i+1, 0:65).
            pk0_b = tensor.slice_tensor(PK0_bytes_tensor, (i, 0), (i + 1, 65))
            # Reshape to (65,) for point deserialization.
            pk0_b_flat = tensor.reshape(pk0_b, (65,))
            PK0 = crypto.ec_bytes_to_point(pk0_b_flat)

            # Ephemeral key r with public part U = g^r.
            r = crypto.ec_random_scalar()
            U = crypto.ec_mul(crypto.ec_generator(), r)

            # Serialize U for stacking.
            u_b = crypto.ec_point_to_bytes(U)
            u_b_r = tensor.reshape(u_b, (1, 65))
            U_bytes_list.append(u_b_r)

            # Shared points: K0 = PK0^r, K1 = (C - PK0)^r.
            K0_point = crypto.ec_mul(PK0, r)
            PK1 = crypto.ec_sub(C, PK0)
            K1_point = crypto.ec_mul(PK1, r)

            # Symmetric keys are hashes of the shared points.
            sk0 = crypto.hash_bytes(crypto.ec_point_to_bytes(K0_point))  # (32,)
            sk1 = crypto.hash_bytes(crypto.ec_point_to_bytes(K1_point))

            # Extract row i and XOR-encrypt in a single run_jax block.
            def _slice_and_enc(
                m0_full: Any, m1_full: Any, k0: Any, k1: Any, idx: int = i
            ) -> tuple[Any, Any]:
                # Slice row idx and flatten to (32,).
                m0_row = m0_full[idx].flatten()
                m1_row = m1_full[idx].flatten()
                # One-time-pad with the derived keys.
                c0 = jnp.bitwise_xor(m0_row, k0)
                c1 = jnp.bitwise_xor(m1_row, k1)
                return c0.reshape(1, -1), c1.reshape(1, -1)

            c0, c1 = tensor.run_jax(_slice_and_enc, m0_tensor, m1_tensor, sk0, sk1)

            c0_list.append(c0)
            c1_list.append(c1)

        # Stack outputs for a single shuffle each.
        U_stacked = tensor.concat(U_bytes_list, axis=0)  # (K, 65)
        c0_stacked = tensor.concat(c0_list, axis=0)  # (K, 32) (assuming 32-byte msgs)
        c1_stacked = tensor.concat(c1_list, axis=0)  # (K, 32)

        return U_stacked, c0_stacked, c1_stacked

    base_cts_rev = simp.pcall_static(
        (receiver,), base_encrypt_rev, C_point, PK0_recv, k0_base, k1_base
    )

    # Shuffle the (Tensor, Tensor, Tensor) tuple; tree_map maps the shuffle
    # over each leaf, so only three transfers occur.
    from jax.tree_util import tree_map

    base_cts_s = tree_map(
        lambda x: simp.shuffle_static(x, {sender: receiver}), base_cts_rev
    )

    def base_decrypt_rev(
        keys: tuple[el.Object, list[el.Object]],
        cts: tuple[el.Object, el.Object, el.Object],
        s_choices: el.Object,
    ) -> el.Object:
        # keys[0] is PK0_stacked (unused here), keys[1] is k_priv_list.
        _, k_priv_list = keys
        # cts are stacked as (K, 65), (K, 32), (K, 32).
        U_packed, c0_packed, c1_packed = cts

        k_s_rows = []

        for i in range(K):
            k_priv = k_priv_list[i]

            # Unstack row i of U.
            u_b = tensor.slice_tensor(U_packed, (i, 0), (i + 1, 65))
            u_b_flat = tensor.reshape(u_b, (65,))
            U = crypto.ec_bytes_to_point(u_b_flat)

            # Unstack row i of c0, c1.
            c0 = tensor.slice_tensor(c0_packed, (i, 0), (i + 1, 32))
            c1 = tensor.slice_tensor(c1_packed, (i, 0), (i + 1, 32))
            # Flatten to (32,): _slice_dec_reshape XORs the chosen ciphertext
            # with sk, which is (32,), so the ciphertext must be (32,) too.
            c0 = tensor.reshape(c0, (32,))
            c1 = tensor.reshape(c1, (32,))

            # Recover the shared point K = U^k_priv and derive the key.
            SharedK = crypto.ec_mul(U, k_priv)
            sk = crypto.hash_bytes(crypto.ec_point_to_bytes(SharedK))

            # Combined select + decrypt + reshape in a single run_jax call.
            def _slice_dec_reshape(
                s_arr: Any, k: Any, c0_: Any, c1_: Any, idx: int = i
            ) -> Any:
                sel = s_arr[idx]
                # Pick the ciphertext matching this base-OT choice bit.
                chosen_c = jnp.where(sel == 0, c0_, c1_)
                result = jnp.bitwise_xor(chosen_c, k)
                return result.reshape(1, 32)

            res_row = tensor.run_jax(_slice_dec_reshape, s_choices, sk, c0, c1)
            k_s_rows.append(res_row)

        # Concat via tensor.concat: run_jax with many args can cause tracing
        # issues.
        return tensor.concat(k_s_rows, axis=0)

    k_s = simp.pcall_static((sender,), base_decrypt_rev, base_keys_tuple, base_cts_s, s)

    # 2. PRG Expansion & Correction
    def calc_u(k0_loc: el.Object, k1_loc: el.Object, r_loc: el.Object) -> el.Object:
        g_k0 = prg_expand(k0_loc, num_ots)  # (K, num_ots)
        g_k1 = prg_expand(k1_loc, num_ots)  # (K, num_ots)

        # choice_bits can be:
        # - (N,) 1D vector for standard IKNP
        # - (N, K) 2D matrix for KKRT OPRF
        #
        # IKNP: u^j = G(k0^j) ^ G(k1^j) ^ r, r broadcast to all K rows.
        # KKRT: u^j = G(k0^j) ^ G(k1^j) ^ r^j, r is (N, K) transposed to (K, N).

        def _compute_u(g0: Any, g1: Any, r: Any) -> Any:
            # g0, g1: (K, N) bit matrices; r: either (N,) or (N, K).
            if r.ndim == 1:
                # 1D case: broadcast (N,) -> (1, N) for XOR with (K, N).
                r_t = jnp.expand_dims(r, axis=0)  # (1, N)
            else:
                # 2D case: transpose (N, K) -> (K, N).
                r_t = jnp.transpose(r, (1, 0))  # (K, N)
            return jnp.bitwise_xor(jnp.bitwise_xor(g0, g1), r_t)

        return cast(el.Object, tensor.run_jax(_compute_u, g_k0, g_k1, r_loc))

    u = simp.pcall_static((receiver,), calc_u, k0_base, k1_base, choice_bits)
    u_recv = simp.shuffle_static(u, {sender: receiver})

    # 3. Matrix Recovery & Transpose
    def calc_t(k_s_loc: el.Object, u_loc: el.Object, s_loc: el.Object) -> el.Object:
        g_k_s = prg_expand(k_s_loc, num_ots)

        def _recover_and_transpose(g: Any, mask: Any, sel: Any) -> Any:
            # Recover t^j = G(k_s^j) ^ (s_j * u^j) and transpose, combined
            # into a single XLA block.
            sel_exp = jnp.expand_dims(sel, axis=-1)
            term = jnp.bitwise_and(mask, sel_exp)
            t_rows = jnp.bitwise_xor(g, term)
            return jnp.transpose(t_rows, (1, 0))  # (N, K)

        return cast(
            el.Object, tensor.run_jax(_recover_and_transpose, g_k_s, u_loc, s_loc)
        )

    t_matrix = simp.pcall_static((sender,), calc_t, k_s, u_recv, s)

    def calc_q(k0_loc: el.Object) -> el.Object:
        g_k0 = prg_expand(k0_loc, num_ots)
        # run_jax for the transpose so XLA can fuse it with the expansion.
        return cast(el.Object, tensor.run_jax(lambda x: jnp.transpose(x, (1, 0)), g_k0))

    q_matrix = simp.pcall_static((receiver,), calc_q, k0_base)

    # Placement: s and t_matrix live on Sender; q_matrix lives on Receiver.
    return t_matrix, q_matrix, s
384
+
385
+
386
def s_choices_sender(s: el.Object) -> el.Object:
    """Return the sender-side choice bits unchanged.

    The value was already produced via pcall on the sender, so no
    placement change or computation is required here.
    """
    return s
388
+
389
+
390
def transfer_extension(
    m0: el.Object,
    m1: el.Object,
    choice_bits: el.Object,
    sender: int,
    receiver: int,
    num_ots: int,
) -> el.Object:
    """Perform IKNP OT Extension.

    Args:
        m0: Sender's message-0 batch; indexed per-OT along axis 0.
        m1: Sender's message-1 batch, same shape as m0.
        choice_bits: Receiver's (N,) choice vector.
        sender: Rank of the sender party.
        receiver: Rank of the receiver party.
        num_ots: Number N of OTs.

    Returns:
        Tensor on the receiver holding m_{choice[i]} for each OT i.
    """

    t_matrix, q_matrix, s = iknp_core(choice_bits, sender, receiver, num_ots)

    # 4. Encryption (sender side)
    def encrypt_msgs(
        t_loc: el.Object, s_loc: el.Object, m0_loc: el.Object, m1_loc: el.Object
    ) -> el.Object:
        # t: (N, K); s: (K,).

        # Hash the rows before using them as masks to break the linear
        # correlation of IKNP: masks are H(t) and H(t ^ s).
        # domain_sep=1 marks IKNP payload masking.
        h_t = vec_hash(t_loc, domain_sep=1, num_rows=num_ots)

        def _xor_s_and_hash(t: Any, s: Any) -> Any:
            # Row-wise XOR with the broadcast (K,) vector s.
            t_xor_s = jnp.bitwise_xor(t, s)
            return t_xor_s

        # H(t ^ s) cannot be fused into one vec_hash call, so compute t ^ s
        # first, then hash it.
        t_xor_s_loc = cast(el.Object, tensor.run_jax(_xor_s_and_hash, t_loc, s_loc))
        h_t_xor_s = vec_hash(t_xor_s_loc, domain_sep=1, num_rows=num_ots)

        def _enc(ht: Any, hts: Any, msg0: Any, msg1: Any) -> Any:
            # ht, hts: (N, 32) digests; msg0, msg1: (N, D) byte payloads.
            #
            # vec_hash yields 32-byte digests while messages are typically
            # block-sized (e.g. 16 bytes), so the hash is truncated to the
            # message width before XOR.
            d = msg0.shape[1]

            ht_sliced = ht[:, :d]
            hts_sliced = hts[:, :d]

            c0 = jnp.bitwise_xor(msg0, ht_sliced)
            c1 = jnp.bitwise_xor(msg1, hts_sliced)
            return c0, c1

        return cast(el.Object, tensor.run_jax(_enc, h_t, h_t_xor_s, m0_loc, m1_loc))

    ciphertexts = simp.pcall_static((sender,), encrypt_msgs, t_matrix, s, m0, m1)

    from jax.tree_util import tree_map

    # Ship both ciphertext tensors to the receiver (tree_map over the tuple).
    ciphertexts_recv = tree_map(
        lambda x: simp.shuffle_static(x, {receiver: sender}), ciphertexts
    )

    # Decryption (receiver side)
    def decrypt_msg(
        q_loc: el.Object, r_loc: el.Object, c_texts: tuple[el.Object, el.Object]
    ) -> el.Object:
        c0, c1 = c_texts

        # Receiver's mask is H(q); by the IKNP relation this equals H(t) when
        # r=0 and H(t ^ s) when r=1.
        h_q = vec_hash(q_loc, domain_sep=1, num_rows=num_ots)

        def _dec(hq: Any, r: Any, ct0: Any, ct1: Any) -> Any:
            # hq: (N, 32); truncate to ciphertext width (see _enc).
            d = ct0.shape[1]
            hq_sliced = hq[:, :d]

            m0_cand = jnp.bitwise_xor(ct0, hq_sliced)
            m1_cand = jnp.bitwise_xor(ct1, hq_sliced)
            # Select per-row candidate by choice bit (broadcast over columns).
            r_exp = jnp.expand_dims(r, axis=-1)
            return jnp.where(r_exp == 1, m1_cand, m0_cand)

        return cast(el.Object, tensor.run_jax(_dec, h_q, r_loc, c0, c1))

    res = simp.pcall_static(
        (receiver,), decrypt_msg, q_matrix, choice_bits, ciphertexts_recv
    )
    return cast(el.Object, res)
@@ -0,0 +1,217 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Silent OT (Random VOLE) Implementation.
16
+
17
+ Implements "Silent Random VOLE" via Linear Expansion (LPN-like).
18
+ This provides O(N) local computation but O(k) communication.
19
+ """
20
+
21
+ from typing import Any, cast
22
+
23
+ import jax.numpy as jnp
24
+
25
+ import mplang.v2.edsl as el
26
+ import mplang.v2.edsl.typing as elt
27
+ import mplang.v2.libs.mpc.vole.gilboa as vole
28
+ from mplang.v2.dialects import crypto, field, simp, tensor
29
+
30
+
31
+ def silent_vole_random_u(
32
+ sender: int,
33
+ receiver: int,
34
+ n: int,
35
+ base_k: int = 1024,
36
+ ) -> tuple[el.Object, el.Object, el.Object, el.Object]:
37
+ """Execute Silent Random VOLE (Linear Expansion).
38
+
39
+ Args:
40
+ sender: Rank of Sender.
41
+ receiver: Rank of Receiver.
42
+ n: Target vector length (e.g. 10^9).
43
+ base_k: Size of Base VOLE (LPN parameter).
44
+
45
+ Returns:
46
+ v, w, u, delta
47
+ Where w = v + u * delta.
48
+ u is RANDOM.
49
+ """
50
+
51
+ # 1. Base VOLE (Standard Gilboa)
52
+ # We need providers for base_u and base_delta.
53
+
54
+ def _base_u_provider() -> el.Object:
55
+ # Random U_base (base_k, 2) using new API
56
+ return crypto.random_tensor((base_k, 2), elt.u64)
57
+
58
+ def _base_delta_provider() -> el.Object:
59
+ # Random Delta (2,) using new API
60
+ return crypto.random_tensor((2,), elt.u64)
61
+
62
+ # v_base: (k, 2), w_base: (k, 2)
63
+ # The return type is a Union, mypy complains about unpacking.
64
+ # We ignore the type error here as we know return_secrets=True returns 4 values.
65
+ v_base, w_base, u_base, delta = vole.vole( # type: ignore
66
+ sender,
67
+ receiver,
68
+ base_k,
69
+ _base_u_provider,
70
+ _base_delta_provider,
71
+ return_secrets=True,
72
+ )
73
+
74
+ # 2. Linear Expansion
75
+ # We rely on a public seed for the mixing matrix M.
76
+ seed = simp.pcall_static((sender,), lambda: crypto.random_bytes(32))
77
+ # Share seed (Receiver needs it too)
78
+ seed_recv = simp.shuffle_static(seed, {receiver: sender}) # S -> R
79
+
80
+ # Expansion Logic
81
+ # We process in chunks to avoid massive implementation limit or memory issues.
82
+ # But EDSL graph optimization might handle loops?
83
+ # For safe side, let's implement a loop over chunks in Python if N is large.
84
+ # However, N is dynamic usually? Here N is int param.
85
+
86
+ # Chunk size
87
+ CHUNK_SIZE = 100_000 # 100k items per chunk
88
+
89
+ # We need to broadcast delta and bases to expansion function?
90
+ # Actually, we can just expand v_base -> v_long, w_base -> w_long.
91
+ # u_long is implicit (u_base * M).
92
+ # Since we need to return u, we expand u_base too.
93
+
94
+ # Define expansion op
95
+ def _expand_chunk(base_vec: Any, mask_packed: Any, chunk_len: int) -> Any:
96
+ # base_vec: (K, 2) u64
97
+ # mask_packed: (K, blocks, 2) u64 (AES output)
98
+ # chunk_len: number of bits to extract
99
+
100
+ # 1. Unpack bits from mask_packed
101
+ # mask_packed is (K, blocks, 2) u64.
102
+ # View as u8: (K, blocks, 16)
103
+ mask_u8 = mask_packed.view(jnp.uint8)
104
+
105
+ # Unpack bits: (K, blocks, 16, 8)
106
+ bits = jnp.unpackbits(mask_u8, bitorder="little")
107
+
108
+ # Flatten to (K, total_bits)
109
+ bits_flat = bits.reshape(base_k, -1)
110
+
111
+ # Slice to chunk_len
112
+ mask_mat = bits_flat[:, :chunk_len].astype(jnp.uint64)
113
+
114
+ # Broadcast base_vec (K, 2)
115
+ # We want: out[c] = XOR_sum_j (base[j] * mask[j, c])
116
+
117
+ base_shuffled = base_vec.reshape(base_k, 1, 2)
118
+ mask_expanded = mask_mat.reshape(base_k, chunk_len, 1)
119
+
120
+ # term[j, c] = base[j] * mask[j, c]
121
+ # mask is 0 or 1 (uint64). Multiplication works as selection.
122
+ terms = base_shuffled * mask_expanded # (K, chunk, 2)
123
+
124
+ # XOR Reduce over K
125
+ # Use simple loop or scan.
126
+ # terms: (K, chunk, 2)
127
+
128
+ def _xor_scan(carry: Any, x: Any) -> tuple[Any, Any]:
129
+ new_carry = jnp.bitwise_xor(carry, x)
130
+ return new_carry, None
131
+
132
+ # init: (chunk, 2) zeros
133
+ init_val = jnp.zeros((chunk_len, 2), dtype=jnp.uint64)
134
+
135
+ from jax import lax
136
+
137
+ res, _ = lax.scan(_xor_scan, init_val, terms)
138
+
139
+ return res
140
+
141
+ # 3. Orchestration
142
+ # We iterate chunks on Host? Or use `scan`?
143
+ # Host loop is easier for Memory management (Streaming).
144
+ # Return a "Lazy Object" or List of Objects?
145
+ # The signature `silent_vole` usually returns full Tensor.
146
+ # User requirement: "Silent OT" to reduce communications.
147
+ # If we return a full (N,) tensor, we solved bandwidth but not RAM.
148
+ # But for Phase 2 task "Protocol Upgrade", bandwidth is key.
149
+ # Phase 2 task "Streaming" handles RAM.
150
+ # So returning full Tensor is "okay" for now, although it might OOM 1B.
151
+ # Let's implement blocked execution and stack? No, that OOMs.
152
+
153
+ # We will implement `silent_vole_random_u` to return a `BigTensor` handle?
154
+ # Or just `el.Object` (which might be huge).
155
+ # Since we are in EDSL, the `el.Object` represents the *computation*.
156
+ # If we return a graph that produces (10^9,) tensor, the Evaluator might crash trying to allocate it.
157
+
158
+ # Let's just implement loop and return concatenated for now, assume 10^7-10^8 test case.
159
+ # For 10^9, we rely on Streaming Refactor later.
160
+
161
+ num_chunks = (n + CHUNK_SIZE - 1) // CHUNK_SIZE
162
+
163
+ def _run_expansion(b: Any, seed_val: Any) -> el.Object:
164
+ # b: base (K, 2)
165
+ # seed_val: (32,) u8
166
+
167
+ # 1. Derive K seeds from master seed using combined run_jax block
168
+ def _view_slice_reshape(b: Any) -> Any:
169
+ # View as u64, slice first row, then reshape for AES expand
170
+ u64_view = b.view(jnp.uint64).reshape(-1, 2)
171
+ master_seed = u64_view[:1] # (1, 2)
172
+ return master_seed
173
+
174
+ master_seed = tensor.run_jax(_view_slice_reshape, seed_val)
175
+
176
+ # Expand to K seeds: (1, K, 2)
177
+ row_seeds_packed = field.aes_expand(master_seed, base_k)
178
+ # Reshape using run_jax for XLA optimization
179
+ row_seeds = tensor.run_jax(lambda x: x.reshape(base_k, 2), row_seeds_packed)
180
+
181
+ # Iterate chunks
182
+ local_res = []
183
+ for i in range(num_chunks):
184
+ this_len = min(CHUNK_SIZE, n - i * CHUNK_SIZE)
185
+
186
+ # Generate mask for this chunk using AES
187
+ # Need ceil(this_len / 128) blocks
188
+ num_blocks = (this_len + 127) // 128
189
+ mask_packed = field.aes_expand(row_seeds, num_blocks)
190
+
191
+ # We must use `tensor.run_jax` so logic runs on device
192
+ def _core(base: Any, mask: Any, this_len: int = this_len) -> Any:
193
+ return _expand_chunk(base, mask, this_len)
194
+
195
+ chunk_res = tensor.run_jax(_core, b, mask_packed)
196
+ local_res.append(chunk_res)
197
+
198
+ # Use run_jax for concat to enable XLA fusion
199
+ if len(local_res) == 1:
200
+ return cast(el.Object, local_res[0])
201
+
202
+ def _concat_chunks(*chunks: Any) -> Any:
203
+ return jnp.concatenate(chunks, axis=0)
204
+
205
+ return cast(el.Object, tensor.run_jax(_concat_chunks, *local_res))
206
+
207
+ # Execute on Sender
208
+ v_long = simp.pcall_static((sender,), _run_expansion, v_base, seed)
209
+ # Execute on Receiver
210
+ w_long = simp.pcall_static((receiver,), _run_expansion, w_base, seed_recv)
211
+
212
+ # U expansion
213
+ u_long = simp.pcall_static((sender,), _run_expansion, u_base, seed)
214
+
215
+ # Delta is scalar, reusable
216
+
217
+ return v_long, w_long, u_long, delta