mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,310 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Oblivious Pseudorandom Function (OPRF).
16
+
17
+ Implements KKRT-style OPRF based on OT Extension.
18
+ Ref: https://eprint.iacr.org/2016/799.pdf
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from typing import Any, cast
24
+
25
+ import jax.numpy as jnp
26
+
27
+ import mplang.v2.edsl as el
28
+ from mplang.v2.dialects import simp, tensor
29
+ from mplang.v2.libs.mpc.ot import extension as ot_extension
30
+
31
+
32
def eval_oprf(
    receiver_inputs: el.Object,  # (N, 16) bytes
    sender: int,
    receiver: int,
    num_items: int,
) -> tuple[el.Object, el.Object]:
    """Evaluate the OPRF on the receiver's inputs (KKRT-style, IKNP-based).

    Steps performed by this function:

    1. Receiver unpacks every 16-byte input into K = 128 bits. The full
       (N, K) bit matrix is used as the choice codes; no single-bit
       shortcut is taken.
    2. The codes are handed to ``ot_extension.iknp_core``, whose three
       results are bound to ``t_matrix`` (used on the sender),
       ``q_matrix`` (used on the receiver) and ``s`` (used on the sender).
    3. Receiver packs each row of ``q_matrix`` back into 16 bytes and
       hashes the packed rows with ``ot_extension.vec_hash`` under domain
       separator ``0x0CDF``.
    4. Sender bundles ``(t_matrix, s)`` as its key.

    Why the two sides agree: ``sender_eval_prf_batch`` hashes
    ``pack(T[i] XOR (code(y) AND s))`` with the same domain separator.
    That equals the receiver's ``pack(Q[i])`` whenever ``y == x_i`` and
    ``T[i] == Q[i] XOR (code(x_i) AND s)``.
    NOTE(review): this per-bit correlation is presumed to be what
    ``iknp_core`` returns for (N, K) choice codes - confirm against
    ``ot_extension``.

    Row mapping is sequential (item i uses row i). Full KKRT would map
    items to rows with Cuckoo hashing; that is not done here.

    Args:
        receiver_inputs: (N, 16) byte tensor of the receiver's inputs.
        sender: Rank of the sender party.
        receiver: Rank of the receiver party.
        num_items: Number of items N. Forwarded to ``iknp_core`` and used
            as ``num_rows`` for the hash.

    Returns:
        Tuple ``(sender_key, receiver_outputs)``:
        - sender_key: ``(t_matrix, s)`` produced on the sender.
        - receiver_outputs: hashed OPRF outputs produced on the receiver,
          one row per input (presumably (N, 32) bytes; the exact width is
          determined by ``vec_hash``).
    """
    K = 128  # Security parameter (OT extension width)

    # ---------------------------------------------------------------------
    # Step 1: encode receiver inputs into K-bit choice codes
    # ---------------------------------------------------------------------
    def encode_inputs(inputs: el.Object) -> el.Object:
        """Unpack (N, 16) byte inputs into (N, K) bit codes."""

        def _encode(x: Any) -> Any:
            # Each byte becomes 8 bits: (N, 16) -> (N, 128).
            unpacked = jnp.unpackbits(x, axis=1)
            # Truncate to exactly K bits per row and normalise the dtype.
            return unpacked[:, :K].astype(jnp.uint8)

        return cast(el.Object, tensor.run_jax(_encode, inputs))

    choice_codes = simp.pcall_static((receiver,), encode_inputs, receiver_inputs)

    # ---------------------------------------------------------------------
    # Step 2: run IKNP OT extension with the full K-bit codes as choices
    # ---------------------------------------------------------------------
    t_matrix, q_matrix, s = ot_extension.iknp_core(
        choice_codes, sender, receiver, num_items
    )

    # ---------------------------------------------------------------------
    # Step 3: receiver output = Hash(pack(Q[i]))
    # ---------------------------------------------------------------------
    def compute_receiver_outputs(q: el.Object) -> el.Object:
        """Pack every Q row to bytes and hash it.

        The choice codes are not needed here, so they are no longer fed
        into the JAX computation.
        """

        def _process(q_mat: Any) -> Any:
            # (N, K=128) bits -> (N, 16) bytes.
            return jnp.packbits(q_mat, axis=1)

        packed_q = cast(el.Object, tensor.run_jax(_process, q))

        # Hash the raw OT output so the result behaves like a random-oracle
        # output; the domain separator must match the sender-side helpers.
        return ot_extension.vec_hash(packed_q, domain_sep=0x0CDF, num_rows=num_items)

    receiver_outputs = simp.pcall_static(
        (receiver,), compute_receiver_outputs, q_matrix
    )

    # ---------------------------------------------------------------------
    # Step 4: sender keeps (T, s) for later PRF evaluation
    # ---------------------------------------------------------------------
    sender_key = simp.pcall_static((sender,), lambda t, s_: (t, s_), t_matrix, s)

    return sender_key, receiver_outputs
179
+
180
+
181
+ # =============================================================================
182
+ # KKRT OPRF Sender Evaluation (Vectorized)
183
+ # =============================================================================
184
+ #
185
+ # KKRT Formula:
186
+ # ─────────────────────────────────────────────────────────────────────────────
187
+ # For sender with key (T, s) and input y:
188
+ # code_y = encode(y) # K bits
189
+ # output = pack(T[row] XOR (code_y * s))
190
+ #
191
+ # For receiver with Q matrix and input x:
192
+ # code_x = encode(x) # K bits
193
+ # output = pack(Q[row])
194
+ #
195
+ # When x == y:
196
+ # T[row] XOR (code_x * s) == Q[row] ✅ (due to IKNP correlation)
197
+ # =============================================================================
198
+
199
+
200
def sender_eval_prf_batch(
    sender_key: el.Object,  # Tuple (t_matrix, s) on sender
    sender_items: el.Object,  # (M, 16) bytes - items to evaluate
    sender: int,
    num_items: int,
) -> el.Object:
    """Compute the sender-side PRF value for every item in a batch.

    Item i is combined with row i of the sender's T matrix:
    ``pack(T[i] XOR (code(item_i) AND s))``, and the packed rows are then
    hashed with ``ot_extension.vec_hash`` (domain separator ``0x0CDF``).

    If there are more items than T rows (M > N), only the first N items
    are evaluated and the remaining raw rows are all-zero padding, so
    those rows hash a constant rather than a real PRF value.

    Args:
        sender_key: The key tuple (t_matrix, s) from eval_oprf.
        sender_items: (M, 16) byte tensor of sender's items.
        sender: Rank of sender party.
        num_items: Number of items M; used as ``num_rows`` for the hash.

    Returns:
        Hashed PRF outputs on the sender, one row per item (presumably
        (M, 32) bytes; the width is determined by ``vec_hash``).
    """
    K = 128

    def _eval(key_tuple: Any, x: Any) -> Any:
        t_matrix, s = key_tuple
        item_count = x.shape[0]
        row_count = t_matrix.shape[0]

        # (M, 16) bytes -> (M, K) bits.
        item_bits = jnp.unpackbits(x, axis=1)[:, :K]

        # Keep s[j] where the item's bit j is set, 0 elsewhere; s broadcasts
        # across the M rows.
        masked_secret = jnp.where(item_bits, s, 0).astype(t_matrix.dtype)

        # Only rows that exist in T can be evaluated.
        usable = min(item_count, row_count)
        mixed = jnp.bitwise_xor(t_matrix[:usable], masked_secret[:usable])

        # (usable, K) bits -> (usable, 16) bytes.
        packed_rows = jnp.packbits(mixed, axis=1)

        if item_count <= row_count:
            return packed_rows

        # More items than rows: zero-fill so the output still has M rows.
        filler = jnp.zeros((item_count - row_count, 16), dtype=packed_rows.dtype)
        return jnp.concatenate([packed_rows, filler], axis=0)

    def compute_sender_outputs(key: el.Object, items: el.Object) -> el.Object:
        """Run the JAX evaluation, then hash the packed rows."""
        raw_rows = cast(el.Object, tensor.run_jax(_eval, key, items))
        return ot_extension.vec_hash(raw_rows, domain_sep=0x0CDF, num_rows=num_items)

    prf_outputs = simp.pcall_static(
        (sender,), compute_sender_outputs, sender_key, sender_items
    )
    return cast(el.Object, prf_outputs)
261
+
262
+
263
def sender_eval_prf(
    sender_key: el.Object,  # Tuple (t_matrix, s) on sender
    candidate: el.Object,  # (16,) bytes to evaluate
    sender: int,
) -> el.Object:
    """Compute the sender-side PRF value for one candidate input.

    The candidate is always combined with row 0 of the T matrix:
    ``pack(T[0] XOR (code(candidate) AND s))``. The packed bytes are then
    hashed through ``ot_extension.vec_hash`` as a batch of one row.

    Args:
        sender_key: The key tuple from eval_oprf.
        candidate: A single 16-byte input to evaluate.
        sender: Rank of sender party.

    Returns:
        Hashed PRF output on the sender (documented as a (32,) byte
        tensor; the final shape depends on ``tensor.slice_tensor``).
    """
    K = 128

    def _compute(key_tuple: Any, c: Any) -> Any:
        t_matrix, s = key_tuple

        # (16,) bytes -> (K,) bits.
        bits = jnp.unpackbits(c)[:K]

        # Only the first T row is used for single-candidate evaluation.
        first_row = t_matrix[0]
        mixed = jnp.bitwise_xor(first_row, jnp.bitwise_and(bits, s))

        # Pack to 16 bytes and add a batch axis so vec_hash sees one row.
        return jnp.packbits(mixed).reshape(1, 16)

    def _eval(key: el.Object, cand: el.Object) -> el.Object:
        single_row = cast(el.Object, tensor.run_jax(_compute, key, cand))

        # Reuse the batched hash with exactly one row.
        hashed_row = ot_extension.vec_hash(single_row, domain_sep=0x0CDF, num_rows=1)

        # Slice instead of reshaping to avoid an extra run_jax node.
        # NOTE(review): the start tuple has two entries but the end tuple
        # has one - confirm this matches the slice_tensor signature and
        # really yields a (32,) result.
        return tensor.slice_tensor(hashed_row, (0, 0), (32,))

    prf_value = simp.pcall_static((sender,), _eval, sender_key, candidate)
    return cast(el.Object, prf_value)
@@ -0,0 +1,344 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Private Set Intersection using VOLE and OKVS (RR22-Style).
16
+
17
+ This module implements a high-performance PSI protocol based on the "Blazing Fast PSI"
18
+ (RR22) paper. The protocol relies on Vector Oblivious Linear Evaluation (VOLE) and
19
+ Oblivious Key-Value Stores (OKVS) to achieve efficient set intersection with linear
20
+ communication O(N) and computation complexity.
21
+
22
+ Protocol Overview:
23
+ The core idea is to mask a "Polynomial" (encoded via OKVS) with VOLE-correlated randomness,
24
+ such that the mask can only be removed (and the polynomial verified) if the parties share
25
+ the same element.
26
+
27
+ Phases:
28
+ 1. **Correlated Randomness (VOLE)**:
29
+ Sender and Receiver establish a shared correlation:
30
+ W = V + U * Delta
31
+ - Sender holds U, V.
32
+ - Receiver holds W, Delta.
33
+ - U is random. Delta is a fixed secret scalar (Receiver's key).
34
+
35
+ 2. **Encoding (OKVS)**:
36
+ Receiver encodes their input set Y into a structure P using OKVS, such that:
37
+ Decode(P, y) = H(y) for all y in Y.
38
+ Here H(y) is a Random Oracle (implemented via Davies-Meyer/AES).
39
+
40
+ 3. **Masking & Exchange**:
41
+ Receiver masks the structure P with their VOLE share W:
42
+ Q = P ^ W
43
+ Receiver sends Q to Sender.
44
+
45
+ 4. **Decoding & Verification**:
46
+ Sender attempts to decode Q for each of their items x in X.
47
+ Since OKVS is linear:
48
+ Decode(Q, x) = Decode(P, x) ^ Decode(W, x)
49
+
50
+ Sender reconstructs the potential "Target" value T:
51
+ T = Decode(Q, x) ^ Decode(V, x) ^ H(x)
52
+
53
+ If x in Y (Intersection):
54
+ Decode(P, x) = H(x)
55
+ Decode(W, x) = Decode(V, x) ^ Decode(U, x) * Delta
56
+ Substitute into T:
57
+ T = H(x) ^ (Decode(V, x) ^ Decode(U, x) * Delta) ^ Decode(V, x) ^ H(x)
58
+ T = Decode(U, x) * Delta
59
+
60
+ Thus, verification becomes checking if T == U* * Delta, where U* = Decode(U, x).
61
+ This check is performed securely using hashes to prevent leakage.
62
+ """
63
+
64
+ from typing import Any, cast
65
+
66
+ import jax.numpy as jnp
67
+
68
+ import mplang.v2.edsl as el
69
+ import mplang.v2.libs.mpc.ot.silent as silent_ot
70
+ from mplang.v2.dialects import field, simp, tensor
71
+
72
+
73
def psi_intersect(
    sender: int,
    receiver: int,
    n: int,
    sender_items: el.Object,
    receiver_items: el.Object,
) -> el.Object:
    """Execute the OKVS/VOLE-based PSI protocol and return the match mask.

    Args:
        sender: Rank of Sender.
        receiver: Rank of Receiver.
        n: Number of items (both parties must currently hold the same count).
        sender_items: Object located at Sender containing (N,) u64 items.
        receiver_items: Object located at Receiver containing (N,) u64 items.

    Returns:
        Object computed on the Sender holding an (N,) uint8 mask; entry i
        is 1 when the hash of the sender's T value for item i equals the
        hash the receiver derived from U* * Delta (i.e. item i is judged
        to be in the intersection), and 0 otherwise.

    Raises:
        ValueError: If ``sender == receiver`` or ``n <= 0``.
    """
    if sender == receiver:
        raise ValueError(
            f"Sender ({sender}) and Receiver ({receiver}) must be different."
        )

    if n <= 0:
        raise ValueError(f"Input size n must be positive, got {n}.")

    # Function-scope imports are kept where the original placed them
    # (presumably to avoid import cycles between the mpc sub-packages).
    import mplang.v2.libs.mpc.psi.okvs_gct as okvs_gct
    from mplang.v2.dialects import crypto
    from mplang.v2.edsl import typing as elt
    from mplang.v2.libs.mpc.ot import extension as ot_extension

    # =====================================================================
    # Phase 1. Parameter setup
    # =====================================================================
    # OKVS storage size M = expansion * N. The expansion factor governs the
    # success probability of OKVS encoding (peeling).
    expansion = okvs_gct.get_okvs_expansion(n)
    M = int(n * expansion)

    # Round M up to a multiple of 128 for batch processing in Silent VOLE.
    remainder = M % 128
    if remainder != 0:
        M += 128 - remainder

    # =====================================================================
    # Phase 2. Correlated randomness (VOLE)
    # =====================================================================
    # Sender gets U, V; Receiver gets W, Delta; with W = V + U * Delta.
    # silent_vole_random_u returns (v, w, u, delta, ...).
    res_tuple = silent_ot.silent_vole_random_u(sender, receiver, M, base_k=1024)
    v_sender, w_receiver, u_sender, delta_receiver = res_tuple[:4]

    # =====================================================================
    # Shared helper: H(item) via Davies-Meyer, H(x) = E_x(0) ^ x
    # =====================================================================
    def _random_oracle(items: Any) -> Any:
        """Hash (N,) items to (N, 2) values; used identically by both parties."""

        def _reshape_seeds(vals: Any) -> Any:
            # Widen each item to a 128-bit AES key: (lo=item, hi=0) -> (N, 2).
            lo = vals
            hi = jnp.zeros_like(vals)
            return jnp.stack([lo, hi], axis=1)

        seeds = tensor.run_jax(_reshape_seeds, items)
        expanded = field.aes_expand(seeds, 1)  # (N, 1, 2)

        def _davies_meyer(enc: Any, s: Any) -> Any:
            # Drop the expansion axis and feed the key forward (XOR).
            enc_flat = enc.reshape(enc.shape[0], 2)
            return jnp.bitwise_xor(enc_flat, s)

        return tensor.run_jax(_davies_meyer, expanded, seeds)

    # =====================================================================
    # Phase 3. Receiver encoding & masking (OKVS)
    # =====================================================================
    # 3.1 OKVS seed: generated by the receiver at runtime and then shared
    # with the sender (it only drives the OKVS hashing distribution).
    def _gen_seed() -> Any:
        return crypto.random_tensor((2,), elt.u64)

    okvs_seed = simp.pcall_static((receiver,), _gen_seed)
    okvs_seed_sender = simp.shuffle_static(okvs_seed, {sender: receiver})

    okvs = okvs_gct.SparseOKVS(M)

    def _recv_ops(y: Any, w: Any, seed: Any) -> Any:
        # y: (N,) receiver inputs; w: (M, 2) VOLE share.
        # 3.2 Target values H(y).
        h_y = _random_oracle(y)

        # 3.3 OKVS encode: find P such that Decode(P, y) = H(y).
        p_storage = okvs.encode(y, h_y, seed)

        # 3.4 Mask with the VOLE share: Q = P ^ W (field addition).
        return field.add(p_storage, w)

    q_shared = simp.pcall_static(
        (receiver,), _recv_ops, receiver_items, w_receiver, okvs_seed
    )

    # 3.5 Send Q to the sender.
    q_sender_view = simp.shuffle_static(q_shared, {sender: receiver})

    # =====================================================================
    # Phase 4. Sender decoding & reconstruction
    # =====================================================================
    # Decode(Q, x) = P(x) ^ W(x) = P(x) ^ V(x) ^ U(x)*Delta, so
    #   T = Decode(Q, x) ^ V(x) ^ H(x) = P(x) ^ H(x) ^ U(x)*Delta.
    # If x is in the intersection (meaning x == y for some y) then
    # P(x) == H(x) and T collapses to U(x)*Delta, which Phase 5 verifies.
    def _sender_ops(x: Any, q: Any, u: Any, v: Any, seed: Any) -> tuple[Any, Any]:
        # x: (N,) sender items; q: (M, 2) received OKVS.
        # 4.1 Decode Q and V at x (decode is linear in the storage).
        s_decoded = okvs.decode(x, q, seed)
        v_decoded = okvs.decode(x, v, seed)

        # 4.2 H(x), computed exactly as the receiver computed H(y).
        h_x = _random_oracle(x)

        # 4.3 T = Decode(Q, x) ^ Decode(V, x) ^ H(x).
        t_val = field.add(field.add(s_decoded, v_decoded), h_x)

        # 4.4 U* = Decode(U, x).
        # NOTE(review): this calls field.decode_okvs directly while Q and V
        # go through okvs.decode - confirm the two are equivalent.
        s_u = field.decode_okvs(x, u, seed)

        return t_val, s_u

    t_val_sender, u_star_sender = simp.pcall_static(
        (sender,),
        _sender_ops,
        sender_items,
        q_sender_view,
        u_sender,
        v_sender,
        okvs_seed_sender,
    )

    # =====================================================================
    # Phase 5. Secure verification
    # =====================================================================
    # Invariant for intersection items: T == U* * Delta.
    # Neither T nor Delta may be revealed to the other party, so:
    #   1. Sender sends U* to the Receiver.
    #   2. Receiver computes Target = U* * Delta.
    #   3. Receiver sends Hash(Target) to the Sender (hashing keeps Delta
    #      from being recovered algebraically).
    #   4. Sender compares Hash(T) with Hash(Target) row by row.

    # 5.1 Sender -> Receiver: U*
    u_star_recv = simp.shuffle_static(u_star_sender, {receiver: sender})

    # 5.2 Receiver: expected target U* * Delta.
    def _recv_verify_ops(u_s: Any, delta: Any) -> Any:
        # u_s: (N, 2), delta: (2,) per the protocol notes.
        # tile is not an EDSL primitive, so it runs inside run_jax.
        def _tile(d: Any) -> Any:
            return jnp.tile(d, (n, 1))

        delta_expanded = tensor.run_jax(_tile, delta)

        # Row-wise product in GF(2^128).
        return field.mul(u_s, delta_expanded)

    target_val = simp.pcall_static(
        (receiver,), _recv_verify_ops, u_star_recv, delta_receiver
    )

    # 5.3 Hash exchange (same domain separator on both sides).
    def _hash_shares(share: el.Object) -> el.Object:
        """Hash n rows with the PSI verification domain separator."""
        return ot_extension.vec_hash(share, domain_sep=0xFEED, num_rows=n)

    h_target_recv = simp.pcall_static((receiver,), _hash_shares, target_val)
    h_t_sender = simp.pcall_static((sender,), _hash_shares, t_val_sender)

    # Ship the receiver's hash to the sender for comparison.
    h_target_at_sender = simp.shuffle_static(h_target_recv, {sender: receiver})

    # 5.4 Final comparison on the sender.
    def _compare(h_t: Any, h_target: Any) -> Any:
        def _core(a: Any, b: Any) -> Any:
            # A row matches only if every hash byte is equal.
            eq = jnp.all(a == b, axis=1)
            return eq.astype(jnp.uint8)  # (N,) 0 or 1

        return tensor.run_jax(_core, h_t, h_target)

    intersection_mask = simp.pcall_static(
        (sender,), _compare, h_t_sender, h_target_at_sender
    )

    return cast(el.Object, intersection_mask)