mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,336 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Silver VOLE Implementation (Silent VOLE from LDPC Codes).
16
+
17
+ This module implements the Silver protocol for efficient silent VOLE generation.
18
+ Silver achieves ~1300x communication reduction compared to IKNP by using
19
+ LDPC-based pseudorandom correlation generators.
20
+
21
+ Key Properties:
22
+ - Communication: O(κ) instead of O(N) - sublinear in output length
23
+ - Computation: ~30% more than IKNP (due to LDPC operations)
24
+ - Security: Based on LPN + Regular Syndrome Decoding
25
+
26
+ Reference: "Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding"
27
+ CRYPTO 2021
28
+
29
+ Usage:
30
+ v_sender, w_receiver = silver_vole(sender=0, receiver=1, n=1000000)
31
+ # W = V + U * Delta (VOLE correlation)
32
+
33
+ > [!WARNING]
34
+ > **SECURITY WARNING**: This module is a DEMONSTRATION of the Silver interface
35
+ > and does NOT implement secure LPN-based correlation generation.
36
+ > `silver_vole` expands a Gilboa base VOLE through a public, fixed-seed LDPC matrix
37
+ > and XORs sparse noise into the receiver's W only, so W = V + U * Delta + e.
38
+ > `silver_vole_ldpc` relies on AES expansion, which is NOT homomorphic, so its
39
+ > correlations are incorrect; its LDPC matrix is unused. DO NOT USE IN PRODUCTION.
40
+ """
41
+
42
+ from typing import Any, cast
43
+
44
+ import jax.numpy as jnp
45
+
46
+ import mplang.v2.edsl as el
47
+ import mplang.v2.edsl.typing as elt
48
+ import mplang.v2.libs.mpc.ot.extension as ot
49
+ from mplang.v2.dialects import crypto, field, simp, tensor
50
+ from mplang.v2.libs.mpc.vole import ldpc
51
+
52
# ============================================================================
# Constants
# ============================================================================

# Number of base OTs (the security parameter). Used as the IKNP base size in
# ``silver_vole_ldpc`` and in ``estimate_silver_communication``.
SILVER_BASE_OT = 128

# Hamming weight of the sparse LPN noise vector that ``silver_vole`` XORs into
# the receiver's output (lower = faster but less secure).
SILVER_NOISE_WEIGHT = 64

# LDPC compression factor, i.e. code_length / syndrome_length: the
# communication estimate models the syndrome as ``n // SILVER_COMPRESSION``
# elements for ``n`` correlations.
SILVER_COMPRESSION = 10
64
+
65
+
66
+ # ============================================================================
67
+ # Silver VOLE Core
68
+ # ============================================================================
69
+
70
+
71
def silver_vole(
    sender: int,
    receiver: int,
    n: int,
    return_secrets: bool = False,
) -> tuple[el.Object, el.Object] | tuple[el.Object, el.Object, el.Object, el.Object]:
    """Generate ``n`` (noisy) VOLE correlations with a Silver-style pipeline.

    Steps:

    1. Take the syndrome length from ``ldpc.get_silver_params(n)`` as the base
       dimension ``K`` (capped at 2048).
    2. Build a sparse ``N x K`` expansion matrix ``H'`` from the fixed seed 42
       and materialize its CSR index arrays as constants on both parties.
    3. Run a size-``K`` Gilboa base VOLE (``gilboa.vole``).
    4. Expand each party's base vector through ``H'`` with a local pcall
       (``ldpc.ldpc_encode_sparse``).
    5. XOR sparse noise of weight ``SILVER_NOISE_WEIGHT`` into the receiver's
       expanded output.

    Because of step 5 the outputs satisfy ``W = V + U * Delta + e`` (a noisy
    VOLE) rather than the exact relation ``W = V + U * Delta``. This is a
    demonstration implementation; do not use it in production.

    Args:
        sender: Rank of the Sender party.
        receiver: Rank of the Receiver party.
        n: Number of correlations to generate; must be positive.
        return_secrets: When True, additionally return the expanded ``U``
            (computed on the sender) and ``Delta`` from the base VOLE.

    Returns:
        ``(V, W)``, or ``(V, W, U, Delta)`` when ``return_secrets`` is True.

    Raises:
        ValueError: If ``sender == receiver`` or ``n <= 0``.
    """
    if sender == receiver:
        raise ValueError("Sender and Receiver must be different parties.")
    if n <= 0:
        raise ValueError("n must be positive.")

    # Primal LPN: the syndrome length plays the role of the base dimension K.
    # The code length and the suggested noise weight are not used here.
    _, syndrome_len, _ = ldpc.get_silver_params(n)
    k = min(syndrome_len, 2048)

    # generate_silver_ldpc(a, b) returns a (b x a) matrix, so passing (K, N)
    # yields H' = G^T with N rows and K columns. Then V^T = H' @ v_base^T,
    # matching the sparse kernel's Output = Matrix @ Input convention.
    # SECURITY: the fixed seed makes the matrix public and static; a real
    # deployment should derive it via coin tossing.
    h_prime = ldpc.generate_silver_ldpc(k, n, seed=42)
    col_indices = h_prime.indices.astype(jnp.uint64)
    row_ptr = h_prime.indptr.astype(jnp.uint64)

    def _sparse_struct_provider() -> tuple[el.Object, el.Object]:
        return tensor.constant(col_indices), tensor.constant(row_ptr)

    # The matrix is public (semi-honest setting): each party builds its own copy.
    idx_s, ptr_s = simp.pcall_static((sender,), _sparse_struct_provider)
    idx_r, ptr_r = simp.pcall_static((receiver,), _sparse_struct_provider)

    from mplang.v2.libs.mpc.vole import gilboa

    def _u_base_provider() -> el.Object:
        # K random rows of two u64 words each (presumably one 128-bit element per row).
        return crypto.random_tensor((k, 2), elt.u64)

    def _delta_provider() -> el.Object:
        # A single random Delta stored as two u64 words.
        return crypto.random_tensor((2,), elt.u64)

    v_base, w_base, u_base, delta = gilboa.vole(  # type: ignore[misc]
        sender, receiver, k, _u_base_provider, _delta_provider, return_secrets=True
    )

    def _encode(vec_base: el.Object, idx: el.Object, ptr: el.Object) -> el.Object:
        # Output(N) = H'(N, K) @ Input(K) via the sparse C++ kernel.
        return ldpc.ldpc_encode_sparse(vec_base, idx, ptr, n, k)

    v_long = simp.pcall_static((sender,), _encode, v_base, idx_s, ptr_s)
    w_clean = simp.pcall_static((receiver,), _encode, w_base, idx_r, ptr_r)

    def _add_noise(w: el.Object) -> el.Object:
        # Fresh sparse LPN noise e, XOR-ed into the receiver's output.
        e = ldpc.generate_sparse_noise(n, SILVER_NOISE_WEIGHT)

        def _xor(a: Any, b: Any) -> Any:
            return jnp.bitwise_xor(a, b)

        return cast(el.Object, tensor.run_jax(_xor, w, e))

    w_long = simp.pcall_static((receiver,), _add_noise, w_clean)

    if not return_secrets:
        return v_long, w_long

    # Expand the sender's base U through the same matrix so callers can check
    # the (noisy) relation.
    u_long = simp.pcall_static((sender,), _encode, u_base, idx_s, ptr_s)
    return v_long, w_long, u_long, delta
201
+
202
+
203
+ # ============================================================================
204
+ # LDPC-Based Syndrome Expansion (Alternative Implementation)
205
+ # ============================================================================
206
+
207
+
208
def silver_vole_ldpc(
    sender: int,
    receiver: int,
    n: int,
    return_secrets: bool = False,
) -> tuple[el.Object, el.Object] | tuple[el.Object, el.Object, el.Object, el.Object]:
    """Silver VOLE variant built on IKNP-derived seeds and AES expansion.

    Despite the name, this path does not perform LDPC syndrome encoding. The
    LDPC matrix is generated and then discarded, and both outputs come from
    AES-expanding the IKNP ``t``/``q`` matrices. AES expansion is not
    homomorphic, so ``w`` does not satisfy a VOLE relation with ``v``.
    Demonstration only; do not use in production.

    Args:
        sender: Rank of the Sender.
        receiver: Rank of the Receiver.
        n: Number of outputs per party; must be positive.
        return_secrets: When True, also return a freshly sampled ``u`` (on the
            sender, ``n * 16`` random bytes) and ``delta`` (on the receiver,
            16 random bytes). These are sampled independently of the outputs.

    Returns:
        ``(v, w)``, or ``(v, w, u, delta)`` when ``return_secrets`` is True.

    Raises:
        ValueError: If ``sender == receiver`` or ``n <= 0``.
    """
    if sender == receiver:
        raise ValueError("Sender and Receiver must be different parties.")
    # Same contract as silver_vole. Without this check a negative n would make
    # the final ``v[:n]`` slice silently drop the last |n| rows instead of failing.
    if n <= 0:
        raise ValueError("n must be positive.")

    # 1. Setup parameters
    code_length, syndrome_length, _noise_weight = ldpc.get_silver_params(n)

    # 2. Generate the shared LDPC matrix.
    # NOTE(review): the result is discarded, so this step does not influence
    # the outputs below.
    ldpc.generate_silver_ldpc(code_length, syndrome_length, seed=42)

    # 3. Base OT setup (IKNP with SILVER_BASE_OT base OTs)
    base_k = SILVER_BASE_OT
    # Expansion factor handed to aes_expand; the expanded rows are then
    # truncated to code_length (presumably enough to cover code_length rows).
    blocks_per_seed = code_length // base_k + 1

    # Random choice bits on the receiver, shaped as a column vector.
    def _gen_choice_bits_ldpc() -> el.Object:
        rand = crypto.random_bytes(base_k)

        def _to_bits(r: Any) -> Any:
            bits = r % 2
            return bits.reshape(-1, 1)

        return cast(el.Object, tensor.run_jax(_to_bits, rand))

    choice_bits = simp.pcall_static((receiver,), _gen_choice_bits_ldpc)

    t_matrix, q_matrix, _s_choices = ot.iknp_core(choice_bits, sender, receiver, base_k)

    # 4./5. Both parties run the same local expansion on their IKNP matrix:
    # view it as rows of two u64 words, AES-expand, keep code_length rows.
    def _expand_to_code_length(mat: Any) -> el.Object:
        seeds = tensor.run_jax(lambda m: m.reshape(-1, 2), mat)
        expanded = field.aes_expand(seeds, blocks_per_seed)

        def _take_code_length(exp: Any) -> Any:
            return exp.reshape(-1, 2)[:code_length]

        return cast(el.Object, tensor.run_jax(_take_code_length, expanded))

    v_sender = simp.pcall_static((sender,), _expand_to_code_length, t_matrix)
    w_receiver = simp.pcall_static((receiver,), _expand_to_code_length, q_matrix)

    # 6. Truncate to n
    def _truncate(vec: Any) -> el.Object:
        return cast(el.Object, tensor.run_jax(lambda v: v[:n], vec))

    v_final = simp.pcall_static((sender,), _truncate, v_sender)
    w_final = simp.pcall_static((receiver,), _truncate, w_receiver)

    if not return_secrets:
        return v_final, w_final

    # NOTE(review): u and delta are sampled independently of v_final/w_final,
    # so they do not satisfy w = v + u * delta.
    def _gen_u() -> el.Object:
        return crypto.random_bytes(n * 16)

    def _gen_delta() -> el.Object:
        return crypto.random_bytes(16)

    u_sender = simp.pcall_static((sender,), _gen_u)
    delta_receiver = simp.pcall_static((receiver,), _gen_delta)

    return v_final, w_final, u_sender, delta_receiver
301
+
302
+
303
+ # ============================================================================
304
+ # Communication Estimation
305
+ # ============================================================================
306
+
307
+
308
def estimate_silver_communication(n: int) -> dict:
    """Estimate bytes exchanged by Silver VOLE versus a Gilboa/IKNP baseline.

    The model is intentionally coarse. Every transferred element is counted
    as 16 bytes (128 bits):

    * base OTs: ``SILVER_BASE_OT`` OTs with two messages each;
    * syndrome: ``max(n // SILVER_COMPRESSION, 128)`` elements;
    * baseline (full IKNP / Gilboa): one element per correlation.

    Args:
        n: Number of VOLE correlations.

    Returns:
        A dict with ``silver_bytes``, ``gilboa_bytes``, ``compression_ratio``
        (``gilboa_bytes / silver_bytes``), ``base_ot_bytes`` and
        ``syndrome_bytes``.
    """
    elem_bytes = 16  # one 128-bit element

    ot_bytes = 2 * elem_bytes * SILVER_BASE_OT
    # The syndrome never shrinks below 128 elements.
    syn_bytes = elem_bytes * max(n // SILVER_COMPRESSION, 128)
    total_silver = ot_bytes + syn_bytes

    baseline = elem_bytes * n  # O(n) traffic for full IKNP

    return {
        "silver_bytes": total_silver,
        "gilboa_bytes": baseline,
        "compression_ratio": baseline / total_silver,
        "base_ot_bytes": ot_bytes,
        "syndrome_bytes": syn_bytes,
    }
@@ -0,0 +1,15 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Runtime module for MPLang v2."""
@@ -0,0 +1,41 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Dialect State Protocol definitions.
16
+
17
+ Dialects can be stateful or stateless. Stateful dialects maintain
18
+ runtime state that is attached to an Interpreter instance.
19
+
20
+ Usage:
21
+ # Attach state to interpreter
22
+ interpreter.set_dialect_state("simp", simp_state)
23
+
24
+ # Retrieve in kernel
25
+ state = interpreter.get_dialect_state("simp")
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ from typing import Protocol, runtime_checkable
31
+
32
+
33
@runtime_checkable
class DialectState(Protocol):
    """Structural interface for runtime state owned by a stateful dialect.

    A stateful dialect attaches an object satisfying this protocol to an
    interpreter; stateless dialects (such as tensor or field) need none.
    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    DialectState)`` only checks that ``obj`` has a ``dialect_name``
    attribute; it does not verify the attribute's type.
    """

    # Name of the owning dialect (presumably the same key passed to
    # set_dialect_state / get_dialect_state; confirm against callers).
    dialect_name: str