mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,386 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Secure Permutation Network library.
16
+
17
+ This module implements secure permutation (shuffling) using Oblivious Transfer (OT).
18
+ It allows a Sender (holding data) and a Receiver (holding a permutation) to
19
+ cooperatively shuffle the data such that:
20
+ 1. The Receiver obtains the shuffled data.
21
+ 2. The Sender learns nothing about the permutation.
22
+ 3. Caveat: the Receiver knows the permutation, so from the shuffled result it can always recover the original data order; only the permutation is hidden (from the Sender).
23
+
24
+ The implementation uses a Bitonic sorting network to achieve oblivious permutation.
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import math
30
+ from typing import Any
31
+
32
+ import jax
33
+ import jax.numpy as jnp
34
+ import numpy as np
35
+
36
+ import mplang.v2.edsl.typing as elt
37
+ from mplang.v2.dialects import simp, tensor
38
+ from mplang.v2.libs.mpc.ot import base as ot
39
+
40
+
41
def secure_switch(
    x0: Any, x1: Any, control_bit: Any, sender: int, receiver: int
) -> tuple[Any, Any]:
    """Route two Sender-held values through a 2x2 switch steered by the Receiver.

    Two 1-out-of-2 OTs are run back to back: the first selects with
    ``control_bit``, the second with its complement ``1 - control_bit``
    (computed locally on the Receiver). Both outputs land on the Receiver.

    Args:
        x0: First input (held by the Sender).
        x1: Second input (held by the Sender).
        control_bit: Receiver-held bit(s); 0 keeps the order, 1 swaps it.
        sender: Rank of the party holding ``x0`` / ``x1``.
        receiver: Rank of the party holding ``control_bit``.

    Returns:
        ``(y0, y1)`` on the Receiver, where
        ``y0 = x0 if c == 0 else x1`` and ``y1 = x1 if c == 0 else x0``.

    NOTE(review): the Receiver obtains both inputs and knows ``control_bit``,
    so it can label which output came from ``x0`` and which from ``x1``; the
    OTs only keep the control bit hidden from the Sender.
    """
    first = ot.transfer(x0, x1, control_bit, sender, receiver)

    def _complement(bit: Any) -> Any:
        # Local (Receiver-side) 1 - c.
        return tensor.run_jax(lambda v: 1 - v, bit)

    flipped_bit = simp.pcall_static((receiver,), _complement, control_bit)
    second = ot.transfer(x0, x1, flipped_bit, sender, receiver)

    return first, second
70
+
71
+
72
def _compute_bitonic_sort_controls(permutation: Any, n: int) -> Any:
    """Derive the swap bits that make a bitonic network realise ``permutation``.

    Each input index ``src`` is tagged with its destination (the inverse of
    the gather-style ``permutation``). The tags are bitonic-sorted ascending,
    and at every compare-exchange the code records whether the pair swapped.
    Replaying those bits on data moves ``input[permutation[i]]`` to slot ``i``.

    Args:
        permutation: Gather-style index array,
            ``output[i] = input[permutation[i]]``.
        n: Network width; must be a power of 2. With ``n == 1`` there are no
            comparators, and ``jnp.concatenate`` of the empty list fails.

    Returns:
        1-D boolean array of ``n/2 * k*(k+1)/2`` control bits
        (``k = log2(n)``), ordered stage by stage, then step by step, then
        pair by pair in ascending order of the lower index.
    """
    positions = np.arange(n, dtype=np.int64)

    def layers() -> Any:
        # Static (trace-time) schedule: yields (lower, upper, ascending) index
        # arrays for each compare-exchange layer of the bitonic network.
        for stage in range(int(math.log2(n))):
            block_size = 2 ** (stage + 1)
            ascending = (positions // block_size) % 2 == 0
            for step in range(stage + 1):
                partners = positions ^ (2 ** (stage - step))
                lower = partners > positions
                yield positions[lower], partners[lower], ascending[lower]

    def impl(perm: Any) -> Any:
        perm64 = perm.astype(jnp.int64)
        # key[src] = destination slot of input element ``src``.
        key = jnp.zeros_like(perm64).at[perm64].set(jnp.arange(n, dtype=jnp.int64))

        bits = []
        for lo, hi, ascending in layers():
            a, b = key[lo], key[hi]
            # Ascending blocks swap when out of order upward, descending ones
            # when out of order downward.
            swap = jnp.where(ascending, a > b, a < b)
            bits.append(swap)
            key = key.at[lo].set(jnp.where(swap, b, a))
            key = key.at[hi].set(jnp.where(swap, a, b))
        return jnp.concatenate(bits)

    return tensor.run_jax(impl, permutation)
136
+
137
+
138
def apply_permutation(data: Any, permutation: Any, sender: int, receiver: int) -> Any:
    """Apply a Receiver-held permutation to Sender-held data via a Bitonic network.

    The Receiver derives one swap bit per compare-exchange from its
    permutation. The first network layer is evaluated with the OT-based
    :func:`secure_switch`, which moves the data to the Receiver. Every later
    layer is a plain local swap on the Receiver.

    Args:
        data: Data items (on Sender). Either a tensor whose leading axis is
            permuted, or a list of objects that are stacked along a new
            leading axis first.
        permutation: Tensor of indices (on Receiver), e.g. ``[2, 0, 1, 3]``.
            ``permutation[i] = src`` means ``output[i]`` comes from
            ``input[src]``.
        sender: Rank of sender.
        receiver: Rank of receiver.

    Returns:
        Shuffled data (on Receiver). Returns a list if the input was a list
        (``[]`` for an empty list).

    Raises:
        TypeError: If ``data`` (after stacking) is not a tensor.

    NOTE(review): after the first layer the whole data tensor is held by the
    Receiver in the clear (later layers are local ``jnp.where`` swaps). This
    hides the permutation from the Sender but not the data from the Receiver.
    NOTE(review): multi-dimensional elements additionally depend on
    ``ot.transfer`` / ``tensor.elementwise`` pairing an ``(N/2,)`` choice
    vector with ``(N/2, ...)`` messages — not verified here.
    """
    is_list_input = isinstance(data, list)

    if is_list_input:
        if len(data) == 0:
            return []

        # Stack list elements along a new leading axis (on Sender).
        def stack_elements(*args: Any) -> Any:
            return tensor.run_jax(lambda *xs: jnp.stack(xs), *args)

        data = simp.pcall_static((sender,), stack_elements, *data)

    target_type = data.type
    if isinstance(target_type, elt.MPType):
        target_type = target_type.value_type
    if not isinstance(target_type, elt.TensorType):
        raise TypeError("apply_permutation expects tensor inputs")
    original_n = target_type.shape[0]

    # Bitonic sort requires a power-of-2 size (at least 2) - pad if necessary.
    n = 2 ** math.ceil(math.log2(max(original_n, 2)))
    needs_padding = n != original_n

    if needs_padding:

        def pad_data(d: Any, pad_n: int) -> Any:
            def impl(x: Any) -> Any:
                # Pad only the leading (permuted) axis. A bare ``(0, k)`` pad
                # width is broadcast by ``jnp.pad`` to *every* axis, which
                # would also grow trailing dimensions of multi-dimensional data.
                widths = [(0, pad_n - x.shape[0])] + [(0, 0)] * (x.ndim - 1)
                return jnp.pad(x, widths, mode="constant")

            return tensor.run_jax(impl, d)

        data = simp.pcall_static((sender,), pad_data, data, n)

        # Padded slots map to themselves (identity tail of the permutation).
        def pad_perm(p: Any, orig: int, pad_n: int) -> Any:
            extra = jnp.arange(orig, pad_n, dtype=jnp.int64)
            return tensor.run_jax(lambda x: jnp.concatenate([x, extra]), p)

        permutation = simp.pcall_static(
            (receiver,), pad_perm, permutation, original_n, n
        )

    # Receiver derives every swap bit of the network from its permutation.
    controls = simp.pcall_static(
        (receiver,), lambda p: _compute_bitonic_sort_controls(p, n), permutation
    )

    def get_step_ctrls(all_ctrls: Any, off: int, count: int) -> Any:
        """Slice ``count`` control bits starting at ``off`` (on Receiver)."""

        def impl(c: Any, o: int, size: int) -> Any:
            # Convert offset to tensor to avoid recompilation (dynamic slice start);
            # the slice size must stay static for dynamic_slice.
            o_tensor = tensor.constant(np.array(o, dtype=np.int64))
            return tensor.run_jax(
                lambda x, start, size: jax.lax.dynamic_slice(x, (start,), (size,)),
                c,
                o_tensor,
                size,
            )

        return simp.pcall_static((receiver,), impl, all_ctrls, off, count)

    def make_indices(party: tuple[int, ...], idx: Any) -> Any:
        """Materialise a static index array as a constant on ``party``."""
        return simp.pcall_static(party, lambda: tensor.constant(idx))

    positions = np.arange(n, dtype=np.int64)
    current = data
    ctrl_offset = 0

    for stage in range(int(math.log2(n))):
        for step in range(stage + 1):
            # Pairs (i, i ^ dist) with i < partner; each step covers every
            # index exactly once.
            partners = positions ^ (2 ** (stage - step))
            mask = partners > positions
            idx_i_np = positions[mask]
            idx_j_np = partners[mask]
            num_pairs = len(idx_i_np)

            typ = current.type
            is_on_sender = isinstance(typ, elt.MPType) and typ.parties == (sender,)

            # Controls for this step (on Receiver).
            step_ctrls = get_step_ctrls(controls, ctrl_offset, num_pairs)
            ctrl_offset += num_pairs

            if is_on_sender:
                # Data on Sender: route every pair through an OT-based switch;
                # the switched values arrive on the Receiver.
                idx_i_sender = make_indices((sender,), idx_i_np)
                idx_j_sender = make_indices((sender,), idx_j_np)

                def extract_pairs_sender(
                    d: Any, idx_i: Any, idx_j: Any
                ) -> tuple[Any, Any]:
                    return (
                        tensor.run_jax(lambda x, i: x[i], d, idx_i),
                        tensor.run_jax(lambda x, j: x[j], d, idx_j),
                    )

                val_i, val_j = simp.pcall_static(
                    (sender,),
                    extract_pairs_sender,
                    current,
                    idx_i_sender,
                    idx_j_sender,
                )

                res_i, res_j = secure_switch(val_i, val_j, step_ctrls, sender, receiver)

                idx_i_recv = make_indices((receiver,), idx_i_np)
                idx_j_recv = make_indices((receiver,), idx_j_np)

                def scatter_results(
                    vi: Any, vj: Any, ii: Any, ij: Any, size: int
                ) -> Any:
                    def impl(
                        v_i: jnp.ndarray,
                        v_j: jnp.ndarray,
                        idx_i: jnp.ndarray,
                        idx_j: jnp.ndarray,
                    ) -> jnp.ndarray:
                        # Both halves together cover every index exactly once,
                        # so the zero template is fully overwritten.
                        out = jnp.zeros((size, *v_i.shape[1:]), dtype=v_i.dtype)
                        out = out.at[idx_i].set(v_i)
                        return out.at[idx_j].set(v_j)

                    return tensor.run_jax(impl, vi, vj, ii, ij)

                current = simp.pcall_static(
                    (receiver,),
                    scatter_results,
                    res_i,
                    res_j,
                    idx_i_recv,
                    idx_j_recv,
                    n,
                )
            else:
                # Data already on Receiver: apply the swaps locally.
                idx_i_recv = make_indices((receiver,), idx_i_np)
                idx_j_recv = make_indices((receiver,), idx_j_np)

                def apply_local_step(d: Any, c: Any, ii: Any, ij: Any) -> Any:
                    def impl(
                        curr: jnp.ndarray,
                        ctrls: jnp.ndarray,
                        idx_i: jnp.ndarray,
                        idx_j: jnp.ndarray,
                    ) -> jnp.ndarray:
                        val_i = curr[idx_i]
                        val_j = curr[idx_j]
                        # One control bit per pair: align it with the leading
                        # axis so it broadcasts over any trailing element dims
                        # (a bare 1-D mask would align with the LAST axis).
                        swap = ctrls.reshape((-1,) + (1,) * (val_i.ndim - 1))
                        curr = curr.at[idx_i].set(jnp.where(swap, val_j, val_i))
                        return curr.at[idx_j].set(jnp.where(swap, val_i, val_j))

                    return tensor.run_jax(impl, d, c, ii, ij)

                current = simp.pcall_static(
                    (receiver,),
                    apply_local_step,
                    current,
                    step_ctrls,
                    idx_i_recv,
                    idx_j_recv,
                )

    if needs_padding:

        def unpad(d: Any, orig: int) -> Any:
            return tensor.run_jax(lambda x: x[:orig], d)

        final_parties = (
            current.type.parties if isinstance(current.type, elt.MPType) else None
        )
        if final_parties is not None:
            current = simp.pcall_static(final_parties, unpad, current, original_n)
        else:
            current = unpad(current, original_n)

    if is_list_input:

        def unstack_to_list(d: Any, count: int) -> list:
            # ``idx=i`` binds the loop value now (avoids late-binding closures).
            return [tensor.run_jax(lambda x, idx=i: x[idx], d) for i in range(count)]

        return simp.pcall_static((receiver,), unstack_to_list, current, original_n)

    return current
@@ -0,0 +1,39 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Common constants for MPC protocols."""
16
+
17
# Golden-ratio constant for 64-bit multiplicative (Fibonacci) hashing:
# 2^64 / phi, rounded to an odd integer.
GOLDEN_RATIO_64 = 0x9E3779B97F4A7C15

# LCG Constants
LCG_ADDEND = 0x14650FB0739D0383
LCG_MULTIPLIER = 0x27D4EB2F165667C5

# 64-bit mixing multipliers. Despite the "GAMMA" names, none of these is a
# SplitMix64 gamma (increment) value:
#   GAMMA_1, GAMMA_2 are the multipliers of the SplitMix64 output mixer;
#   GAMMA_3, GAMMA_4 are the multipliers of MurmurHash3's fmix64 finalizer.
SPLITMIX64_GAMMA_1 = 0xBF58476D1CE4E5B9
SPLITMIX64_GAMMA_2 = 0x94D049BB133111EB
SPLITMIX64_GAMMA_3 = 0xFF51AFD7ED558CCD
SPLITMIX64_GAMMA_4 = 0xC4CEB9FE1A85EC53

# Arbitrary Constants (Nothing-up-my-sleeve numbers)
# Fractional part of PI (Hex): first and second 64-bit words.
PI_FRAC_1 = 0x243F6A8885A308D3
PI_FRAC_2 = 0x13198A2E03707344

# NOTE: despite its name, this value is the *third* 64-bit word of PI's
# fractional part (it continues PI_FRAC_2), not a digit block of e; e's
# fractional part in hex begins 0xB7E151628AED2A6A. The value is left
# unchanged so that anything already derived from it stays the same.
E_FRAC_1 = 0xA4093822299F31D0
@@ -0,0 +1,32 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Oblivious Transfer protocols.
16
+
17
+ Submodules:
18
+ - base: Naor-Pinkas 1-out-of-2 OT
19
+ - extension: IKNP OT Extension
20
+ - silent: Silent OT via LPN
21
+ """
22
+
23
+ from .base import transfer
24
+ from .extension import iknp_core, transfer_extension
25
+ from .silent import silent_vole_random_u
26
+
27
# Public API re-exported from the ``base``, ``extension`` and ``silent``
# submodules (see the imports above).
__all__ = [
    "iknp_core",
    "silent_vole_random_u",
    "transfer",
    "transfer_extension",
]
@@ -0,0 +1,222 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Oblivious Transfer (OT) library.
16
+
17
+ This module implements OT logic using the `crypto` dialect (ECC + Hash + SymEnc).
18
+ It implements the Naor-Pinkas 1-out-of-2 OT protocol.
19
+
20
+ Protocol: Naor-Pinkas 1-out-of-2 OT
21
+ Security: Computational security based on ECDH.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from typing import Any, cast
27
+
28
+ import numpy as np
29
+
30
+ import mplang.v2.edsl as el
31
+ import mplang.v2.edsl.typing as elt
32
+ from mplang.v2.dialects import crypto, simp, tensor
33
+
34
+
35
def _receiver_keygen_scalar(
    C_point: el.Object, b: el.Object
) -> tuple[el.Object, el.Object]:
    """Receiver-side Naor-Pinkas key generation for a single OT instance.

    Samples a secret scalar ``k`` and the point ``P = k*G``. It then outputs
    ``PK0`` without branching on ``b``::

        PK0 = P + b * (C - 2P)      # P if b == 0, C - P if b == 1

    With ``PK1 = C - PK0``, the public key ``PK_b`` equals ``P``, whose
    discrete log ``k`` only the Receiver knows.

    Args:
        C_point: The Sender's random point ``C``.
        b: Selection bit; assumed to be 0 or 1.

    Returns:
        ``(PK0, k)``.
    """
    secret = crypto.ec_random_scalar()
    generator = crypto.ec_generator()
    known_pk = crypto.ec_mul(generator, secret)

    bit = crypto.ec_scalar_from_int(b)
    two = crypto.ec_scalar_from_int(tensor.constant(np.array(2, dtype=np.int64)))

    doubled = crypto.ec_mul(known_pk, two)
    offset = crypto.ec_mul(crypto.ec_sub(C_point, doubled), bit)
    return crypto.ec_add(known_pk, offset), secret
61
+
62
+
63
def _sender_derive_keys(
    C_point: el.Object, PK0_point: el.Object
) -> tuple[el.Object, el.Object, el.Object, el.Object]:
    """Sender-side key derivation for a single OT instance.

    Completes the receiver's key pair with ``PK1 = C - PK0``. For each
    ``PK_i`` in turn (``PK0`` first), it samples a fresh ephemeral scalar
    ``r_i`` and produces the announcement ``U_i = r_i*G`` and the shared
    point ``K_i = r_i*PK_i``.

    Args:
        C_point: The Sender's random point ``C``.
        PK0_point: Public key ``PK0`` received from the Receiver.

    Returns:
        ``(U0, K0, U1, K1)``.
    """
    PK1_point = crypto.ec_sub(C_point, PK0_point)

    derived = []
    for public_key in (PK0_point, PK1_point):
        ephemeral = crypto.ec_random_scalar()
        base = crypto.ec_generator()
        derived.append(crypto.ec_mul(base, ephemeral))
        derived.append(crypto.ec_mul(public_key, ephemeral))

    U0, K0, U1, K1 = derived
    return U0, K0, U1, K1
83
+
84
+
85
def _receiver_derive_key(
    U0: el.Object, U1: el.Object, PK0: el.Object, k: el.Object, b: el.Object
) -> el.Object:
    """Receiver-side recovery of the shared point for the chosen message.

    Selects ``U_b`` without branching, as ``U0 + b*(U1 - U0)``, and returns
    ``k*U_b``. This matches the Sender's ``K_b`` because ``PK_b = k*G``.
    ``PK0`` is accepted but not used by this function.

    Args:
        U0: Sender announcement for message 0.
        U1: Sender announcement for message 1.
        PK0: Receiver's public key 0 (unused).
        k: Receiver's secret scalar.
        b: Selection bit; assumed to be 0 or 1.

    Returns:
        The shared point ``K_b``.
    """
    bit = crypto.ec_scalar_from_int(b)
    chosen_u = crypto.ec_add(U0, crypto.ec_mul(crypto.ec_sub(U1, U0), bit))
    return crypto.ec_mul(chosen_u, k)
99
+
100
+
101
def transfer(
    m0: el.MPObject,
    m1: el.MPObject,
    choice: el.MPObject,
    sender: int,
    receiver: int,
) -> el.MPObject:
    """Perform 1-out-of-2 Oblivious Transfer (Naor-Pinkas).

    Per-element work is mapped with ``tensor.elementwise``. A single random
    point ``C`` is generated once by the Sender and reused for every element.

    Protocol outline:
      1. Sender samples ``C = r*G`` and sends it to the Receiver.
      2. Receiver generates ``(PK0, k)`` per element and sends ``PK0`` back.
      3. Sender derives ``PK1 = C - PK0`` and, for each public key, an
         ephemeral ``U_i`` and shared point ``K_i``. It encrypts ``m_i`` under
         ``hash(K_i)`` and sends ``(U0, c0, U1, c1)``.
      4. Receiver recomputes ``K_b = k*U_b`` and decrypts ``c_b``.

    Args:
        m0: Message 0 (on Sender).
        m1: Message 1 (on Sender).
        choice: Selection bit 0 or 1 (on Receiver).
        sender: Rank of the sender.
        receiver: Rank of the receiver.

    Returns:
        The selected message (on Receiver).

    Raises:
        TypeError: If ``m0``, ``m1`` or ``choice`` is not an ``el.Object``
            whose type is an ``MPType``.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    for arg_name, arg in (("m0", m0), ("m1", m1), ("choice", choice)):
        if not (isinstance(arg, el.Object) and isinstance(arg.type, elt.MPType)):
            raise TypeError(
                f"transfer: {arg_name} must be an el.Object with an MPType, "
                f"got {type(arg).__name__}"
            )

    # --- Step 1: Sender Initialization ---
    def sender_init_fn() -> el.Object:
        # C is a random point: C = r * G (r is not returned).
        r = crypto.ec_random_scalar()
        G = crypto.ec_generator()
        C = crypto.ec_mul(G, r)
        return C

    C = simp.pcall_static((sender,), sender_init_fn)

    # Move C to Receiver
    C_recv = simp.shuffle_static(C, {receiver: sender})

    # Decryption needs the plaintext type. Because tensor.elementwise maps
    # per element, a tensor message contributes its element type.
    val_type = cast(elt.MPType, m0.type).value_type
    if isinstance(val_type, elt.TensorType):
        target_type = val_type.element_type
    else:
        target_type = val_type

    # --- Step 2: Receiver Key Generation ---
    def receiver_keygen_fn(C_point: el.Object, b: el.Object) -> el.Object:
        res: el.Object = cast(
            el.Object, tensor.elementwise(_receiver_keygen_scalar, C_point, b)
        )
        return res  # type: ignore[no-any-return]

    # Returns (PK0, k) on receiver
    keys_recv = simp.pcall_static((receiver,), receiver_keygen_fn, C_recv, choice)

    # Extract PK0 to send back; k never leaves the Receiver.
    def get_pk0(pair: Any) -> el.Object:
        return cast(el.Object, pair[0])

    PK0_to_send = simp.pcall_static((receiver,), get_pk0, keys_recv)
    PK0_sender = simp.shuffle_static(PK0_to_send, {sender: receiver})

    # --- Step 3: Sender Encryption ---
    def sender_encrypt_fn(
        C_point: el.Object, PK0_point: el.Object, msg0: el.Object, msg1: el.Object
    ) -> Any:
        def encrypt_elementwise(
            c: Any, pk0: Any, m0: Any, m1: Any
        ) -> tuple[Any, Any, Any, Any]:
            u0, k0, u1, k1 = _sender_derive_keys(c, pk0)

            kb0 = crypto.ec_point_to_bytes(k0)
            kb1 = crypto.ec_point_to_bytes(k1)

            sk0 = crypto.hash_bytes(kb0)
            sk1 = crypto.hash_bytes(kb1)

            c0 = crypto.sym_encrypt(sk0, m0)
            c1 = crypto.sym_encrypt(sk1, m1)
            return u0, c0, u1, c1

        return tensor.elementwise(encrypt_elementwise, C_point, PK0_point, msg0, msg1)

    ciphertexts = simp.pcall_static((sender,), sender_encrypt_fn, C, PK0_sender, m0, m1)

    # Move ciphertexts to Receiver.
    # ciphertexts is a tuple, so we map shuffle over it.
    from jax.tree_util import tree_map

    ciphertexts_recv = tree_map(
        lambda x: simp.shuffle_static(x, {receiver: sender}), ciphertexts
    )

    # --- Step 4: Receiver Decryption ---
    def receiver_decrypt_fn(c_texts: Any, keys: Any, b: el.Object) -> el.Object:
        # b is selection bit; keys is (PK0, k)
        PK0, k = keys
        U0, V0, U1, V1 = c_texts

        def decrypt_elementwise(
            u0: Any, v0: Any, u1: Any, v1: Any, pk0: Any, k_priv: Any, sel: Any
        ) -> Any:
            k_pt = _receiver_derive_key(u0, u1, pk0, k_priv, sel)
            kb = crypto.ec_point_to_bytes(k_pt)
            sk = crypto.hash_bytes(kb)
            v = crypto.select(sel, v1, v0)
            return crypto.sym_decrypt(sk, v, target_type)

        res = tensor.elementwise(decrypt_elementwise, U0, V0, U1, V1, PK0, k, b)
        return cast(el.Object, res)

    result = simp.pcall_static(
        (receiver,), receiver_decrypt_fn, ciphertexts_recv, keys_recv, choice
    )

    res_obj: el.Object = cast(el.Object, result)
    return res_obj