mplang-nightly 0.1.dev267__py3-none-any.whl → 0.1.dev269__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mplang/py.typed ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2026 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
32
32
  from concurrent.futures import Future
33
33
 
34
34
  from mplang.v2.edsl.graph import Graph
35
- from mplang.v2.edsl.spec import ClusterSpec
35
+ from mplang.v2.libs.device import ClusterSpec
36
36
 
37
37
 
38
38
  class SimpHttpDriver(SimpDriver):
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
32
32
  from concurrent.futures import Future
33
33
 
34
34
  from mplang.v2.edsl.graph import Graph
35
- from mplang.v2.edsl.spec import ClusterSpec
35
+ from mplang.v2.libs.device import ClusterSpec
36
36
 
37
37
 
38
38
  class MemCluster:
@@ -36,6 +36,10 @@ extern "C" {
36
36
  uint64_t h1, h2, h3;
37
37
  };
38
38
 
39
+ // Declaration of the safe (robust) solver implemented in okvs.cpp
40
+ // Signature: solve_okvs(keys, values, output, n, m, seed_ptr)
41
+ void solve_okvs(uint64_t* keys, uint64_t* values, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr);
42
+
39
43
  // Stateless Bin Selection
40
44
  // Maps a key to a deterministic bin index [0, NUM_BINS).
41
45
  inline uint64_t get_bin_index(uint64_t key, __m128i seed) {
@@ -245,10 +249,21 @@ extern "C" {
245
249
  uint64_t valid_m = m_per_bin[b];
246
250
 
247
251
  if(!solve_bin(bin_keys[b], bin_vals[b], &P_vec[offset], valid_m, seed)) {
252
+ // On failure, log and fall back to the robust solver for this bin.
253
+ // The fallback is executed inside a critical section to avoid nested OpenMP
254
+ // regions and to serialize rare fallbacks.
248
255
  #pragma omp critical
249
256
  {
250
- fprintf(stderr, "[ERROR] Bin %lu failed OKVS peeling. Items: %lu / M: %lu (Ratio: %.2f). Try increasing expansion factor.\n",
257
+ fprintf(stderr, "[WARN] Bin %lu failed optimized peeling; falling back to safe solver. Items: %lu / M: %lu (Ratio: %.2f)\n",
251
258
  b, bin_keys[b].size(), valid_m, (double)valid_m / bin_keys[b].size());
259
+
260
+ // Prepare pointers for the safe solver
261
+ uint64_t* keys_ptr = &bin_keys[b][0];
262
+ uint64_t* vals_ptr = &bin_vals[b][0];
263
+ uint64_t* out_ptr = output + (offset * 2ULL); // each 128-bit slot == 2 uint64_t
264
+
265
+ // Call the safe solver implemented in okvs.cpp
266
+ solve_okvs(keys_ptr, vals_ptr, out_ptr, bin_keys[b].size(), valid_m, seed_ptr);
252
267
  }
253
268
  }
254
269
  }
@@ -40,12 +40,16 @@ def get_okvs_expansion(n: int) -> float:
40
40
  - For N → ∞: Theoretical minimum is ε ≈ 0.23 (M = 1.23N)
41
41
  - For finite N: Larger ε needed due to variance in random hash collisions
42
42
 
43
- Empirical safe thresholds (failure probability < 0.01%):
44
- - N < 1,000: ε = 4.5 (M = 5.5N) - very small sets need extra wide margin
45
- to handle worst-case hash collisions
46
- - N < 10,000: ε = 0.4 (M = 1.4N)
47
- - N < 100,000: ε = 0.3 (M = 1.3N)
48
- - N ≥ 100,000: ε = 0.35 (M = 1.35N) - large sets converge near theory
43
+ Empirical safe thresholds (failure probability < 0.001%):
44
+ - N ≤ 200: ε = 24.0 (M = 25.0N) - extremely small sets need very wide margin
45
+ - N < 1,000: ε = 11.0 (M = 12.0N) - small sets need extra wide safety margin
46
+ - N < 10,000: ε = 0.6 (M = 1.6N)
47
+ - N < 100,000: ε = 0.4 (M = 1.4N)
48
+ - N ≥ 100,000: ε = 0.35 (M = 1.35N) - large sets converge near theory
49
+
50
+ Note: These expansion factors account for the 128-byte alignment requirement
51
+ in the OKVS implementation. The factors are intentionally conservative to
52
+ ensure high success rates (>99.9%) for the probabilistic peeling algorithm.
49
53
 
50
54
  Args:
51
55
  n: Number of key-value pairs to encode
@@ -53,12 +57,14 @@ def get_okvs_expansion(n: int) -> float:
53
57
  Returns:
54
58
  Expansion factor ε such that M = (1+ε)*N is safe for peeling
55
59
  """
56
- if n < 1000:
57
- return 5.5 # Small scale: need very wide safety margin for stability
60
+ if n <= 200:
61
+ return 25.0 # Extremely small scale: need very wide margin for stability
62
+ elif n < 1000:
63
+ return 12.0 # Small scale: need wide safety margin for stability
58
64
  elif n <= 10000:
59
- return 1.4 # Medium scale
65
+ return 1.6 # Medium scale
60
66
  elif n <= 100000:
61
- return 1.3 # Large scale
67
+ return 1.4 # Large scale
62
68
  else:
63
69
  # Mega-Binning requires ~1.35 for stability with 1024 bins
64
70
  return 1.35
@@ -28,37 +28,36 @@ Phases:
28
28
  1. **Correlated Randomness (VOLE)**:
29
29
  Sender and Receiver establish a shared correlation:
30
30
  W = V + U * Delta
31
- - Sender holds U, V.
32
- - Receiver holds W, Delta.
33
- - U is random. Delta is a fixed secret scalar (Receiver's key).
31
+ - PSI Receiver holds `U` and `V` (these are generated by the OT "sender"
32
+ role in `silent_vole_random_u`).
33
+ - PSI Sender holds `W` and `Delta` (these are generated by the OT "receiver"
34
+ role).
35
+ - `U` is random; `Delta` is a fixed secret scalar held by the Sender.
34
36
 
35
37
  2. **Encoding (OKVS)**:
36
- Receiver encodes their input set Y into a structure P using OKVS, such that:
37
- Decode(P, y) = H(y) for all y in Y.
38
- Here H(y) is a Random Oracle (implemented via Davies-Meyer/AES).
38
+ The Receiver encodes its input set Y into an OKVS storage `P` such that
39
+ Decode(P, y) = H(y) for all y in Y. The function `H(y)` is implemented via
40
+ AES/Davies–Meyer expansion acting as a random oracle.
39
41
 
40
42
  3. **Masking & Exchange**:
41
- Receiver masks the structure P with their VOLE share W:
42
- Q = P ^ W
43
- Receiver sends Q to Sender.
44
-
45
- 4. **Decoding & Verification**:
46
- Sender attempts to decode Q for each of their items x in X.
47
- Since OKVS is linear:
48
- Decode(Q, x) = Decode(P, x) ^ Decode(W, x)
49
-
50
- Sender reconstructs the potential "Target" value T:
51
- T = Decode(Q, x) ^ Decode(V, x) ^ H(x)
52
-
53
- If x in Y (Intersection):
54
- Decode(P, x) = H(x)
55
- Decode(W, x) = Decode(V, x) ^ Decode(U, x) * Delta
56
- Substitute into T:
57
- T = H(x) ^ (Decode(V, x) ^ Decode(U, x) * Delta) ^ Decode(V, x) ^ H(x)
58
- T = Decode(U, x) * Delta
59
-
60
- Thus, verification becomes checking if T == U* * Delta, where U* = Decode(U, x).
61
- This check is performed securely using hashes to prevent leakage.
43
+ The Receiver masks the OKVS storage `P` with its VOLE share `U`:
44
+ Q = P + U
45
+ The masked storage `Q` is sent to the Sender (so Sender sees a masked OKVS).
46
+
47
+ 4. **Decoding & Tag Generation (Sender)**:
48
+ The Sender holds `W` and `Delta` and computes the linear combination:
49
+ K = Q * Delta + W
50
+ Using W = V + U * Delta and Q = P + U, this simplifies to
51
+ K = P * Delta + V.
52
+ The Sender decodes `K` for each of its items x to obtain P(x)*Delta + V(x),
53
+ then subtracts H(x)*Delta (computed locally) to recover `V(x)`. The value
54
+ `Tag = V(x)` serves as the sender-side tag for item x.
55
+
56
+ 5. **Verification (Receiver)**:
57
+ The Sender hashes and truncates tags and sends the truncated hashes to the
58
+ Receiver. The Receiver locally decodes `V(y)` from its OKVS, hashes it with
59
+ the same domain separation and truncation, and compares to the received
60
+ truncated hashes to determine membership in the intersection.
62
61
  """
63
62
 
64
63
  from typing import Any, cast
@@ -77,17 +76,24 @@ def psi_intersect(
77
76
  sender_items: el.Object,
78
77
  receiver_items: el.Object,
79
78
  ) -> el.Object:
80
- """Execute OKVS-based PSI Protocol.
79
+ """Execute OKVS-based PSI Protocol (Original RR22 Logic).
80
+
81
+ This implementation follows the RR22 paper's role assignment where:
82
+ - PSI Sender holds Delta (and W).
83
+ - PSI Receiver holds U and V.
84
+
85
+ This enables the "One Decode" optimization on the Sender side and prevents
86
+ offline brute-force attacks by the Receiver (though Sender could brute-force).
81
87
 
82
88
  Args:
83
89
  sender: Rank of Sender.
84
90
  receiver: Rank of Receiver.
85
- n: Number of items (must be same for now).
86
- sender_items: Object located at Sender containing (N,) u64 items.
87
- receiver_items: Object located at Receiver containing (N,) u64 items.
91
+ n: Number of items.
92
+ sender_items: Object located at Sender.
93
+ receiver_items: Object located at Receiver.
88
94
 
89
95
  Returns:
90
- Intersection verification tuple (T, U*, Delta).
96
+ Intersection mask (0/1) located at Receiver.
91
97
  """
92
98
 
93
99
  # Validation
@@ -95,52 +101,42 @@ def psi_intersect(
95
101
  raise ValueError(
96
102
  f"Sender ({sender}) and Receiver ({receiver}) must be different."
97
103
  )
98
-
99
104
  if n <= 0:
100
105
  raise ValueError(f"Input size n must be positive, got {n}.")
101
106
 
102
107
  # =========================================================================
103
- # Phase 1. Parameter Setup & Topology
108
+ # Phase 1. Parameter Setup
104
109
  # =========================================================================
105
- # OKVS Size M = expansion * N.
106
- # The expansion factor is critical for the success probability of the "Peeling"
107
- # algorithm used in OKVS encoding (Garbled Cuckoo Table).
108
- # Larger N allows smaller expansion (closer to theoretical 1.23) while maintaining safety.
109
110
  import mplang.v2.libs.mpc.psi.okvs_gct as okvs_gct
110
111
 
111
112
  expansion = okvs_gct.get_okvs_expansion(n)
112
113
  M = int(n * expansion)
113
-
114
- # Align M to 128 boundary for efficient batch processing in Silent VOLE (LPN)
115
114
  if M % 128 != 0:
116
115
  M = ((M // 128) + 1) * 128
117
116
 
118
117
  # =========================================================================
119
- # Phase 2. Correlated Randomness Generation (VOLE)
118
+ # Phase 2. Correlated Randomness (VOLE)
120
119
  # =========================================================================
121
- # Parties run Silent VOLE (based on LPN assumption) to generate:
122
- # Sender: U, V (Vectors of size M)
123
- # Receiver: W, Delta
124
- # Correlation: W = V + U * Delta
120
+ # In the original paper logic (Fig 4), the PSI Sender holds Delta.
121
+ # Therefore, we swap the roles in the OT call.
122
+ #
123
+ # silent_vole_random_u(A, B) gives:
124
+ # A (OT Sender): U, V
125
+ # B (OT Receiver): W, Delta
125
126
  #
126
- # Note: U is uniformly random. It acts as a "One-Time Pad" key for the protocol.
127
+ # We want PSI Sender to be OT Receiver.
128
+ res_tuple = silent_ot.silent_vole_random_u(receiver, sender, M, base_k=1024)
127
129
 
128
- # silent_vole_random_u returns (v, w, u, delta)
129
- res_tuple = silent_ot.silent_vole_random_u(sender, receiver, M, base_k=1024)
130
- v_sender, w_receiver, u_sender, delta_receiver = res_tuple[:4]
130
+ # PSI Receiver gets U, V
131
+ v_recv, w_sender, u_recv, delta_sender = res_tuple[:4]
131
132
 
132
133
  # =========================================================================
133
- # Phase 3. Receiver Encoding & Masking (OKVS)
134
+ # Phase 3. Receiver Encoding & Masking
134
135
  # =========================================================================
135
- # The Receiver encodes their input set Y into the OKVS structure P.
136
- # Goal: Decode(P, y) = H(y) forall y in Y.
137
- #
138
- # Then, Receiver masks P with the VOLE output W to get Q:
139
- # Q = P ^ W
140
- # This Q is sent to the Sender.
136
+ # Receiver computes P such that P(y) = H(y).
137
+ # Receiver masks P with U (Paper's A vector).
138
+ # Q = P ^ U
141
139
 
142
- # 3.1 Generate OKVS Seed (Public/Session Randomness)
143
- # Used for OKVS hashing distribution. Can be public, but generated at runtime for safety.
144
140
  from mplang.v2.dialects import crypto
145
141
  from mplang.v2.edsl import typing as elt
146
142
 
@@ -150,19 +146,11 @@ def psi_intersect(
150
146
  okvs_seed = simp.pcall_static((receiver,), _gen_seed)
151
147
  okvs_seed_sender = simp.shuffle_static(okvs_seed, {sender: receiver})
152
148
 
153
- # Instantiate OKVS Data Structure
154
149
  okvs = okvs_gct.SparseOKVS(M)
155
150
 
156
- def _recv_ops(y: Any, w: Any, delta: Any, seed: Any) -> Any:
157
- # y: (N,) Inputs
158
- # w: (M, 2) VOLE share
159
-
160
- # 3.2 Compute H(y) - The Random Oracle Target
161
- # We use Davies-Meyer construction: H(x) = E_x(0) ^ x
162
- # This is a standard, efficient, and robust way to instantiate a RO from AES.
163
-
151
+ def _recv_ops(y: Any, u: Any, seed: Any) -> Any:
152
+ # 3.1 Compute H(y)
164
153
  def _reshape_seeds(items: Any) -> Any:
165
- # Prepare items as AES keys (128-bit)
166
154
  lo = items
167
155
  hi = jnp.zeros_like(items)
168
156
  return jnp.stack([lo, hi], axis=1) # (N, 2)
@@ -176,55 +164,59 @@ def psi_intersect(
176
164
 
177
165
  h_y = tensor.run_jax(_davies_meyer, res_exp, seeds)
178
166
 
179
- # 3.3 Solve System of Linear Equations (OKVS Encode)
180
- # We find P such that: P * M_okvs(y) = h_y
167
+ # 3.2 Encode P
181
168
  p_storage = okvs.encode(y, h_y, seed)
182
169
 
183
- # 3.4 Mask with Vole Share
184
- # Q = P ^ W
185
- q_storage = field.add(p_storage, w)
170
+ # 3. Mask with U (instead of W)
171
+ # Q = P ^ U
172
+ q_storage = field.add(p_storage, u)
186
173
 
187
174
  return q_storage
188
175
 
189
- # Execute on Receiver
176
+ # Receiver uses U to mask
190
177
  q_shared = simp.pcall_static(
191
- (receiver,), _recv_ops, receiver_items, w_receiver, delta_receiver, okvs_seed
178
+ (receiver,), _recv_ops, receiver_items, u_recv, okvs_seed
192
179
  )
193
180
 
194
- # 3.5 Send Q to Sender
195
181
  q_sender_view = simp.shuffle_static(q_shared, {sender: receiver})
196
182
 
197
183
  # =========================================================================
198
- # Phase 4. Sender Decoding & Reconstruction
184
+ # Phase 4. Sender "One Decode" & Tag Generation
199
185
  # =========================================================================
200
- # Sender uses Q and their local shares (U, V) to reconstruct T.
186
+ # Sender holds W, Delta. Receives Q.
187
+ # W = V + U * Delta
201
188
  #
202
189
  # Derivation:
203
- # 1. S_decoded = Decode(Q, x) = Decode(P ^ W, x) = P(x) ^ W(x)
204
- # 2. Recall W(x) = V(x) ^ U(x)*Delta (VOLE property)
205
- # 3. So S_decoded = P(x) ^ V(x) ^ U(x)*Delta
206
- #
207
- # 4. Sender computes T = S_decoded ^ V(x) ^ H(x)
208
- # T = P(x) ^ V(x) ^ U(x)*Delta ^ V(x) ^ H(x)
209
- # T = P(x) ^ H(x) ^ U(x)*Delta
190
+ # K = Q * Delta + W
191
+ # = (P + U) * Delta + (V + U * Delta)
192
+ # = P * Delta + U * Delta + V + U * Delta
193
+ # = P * Delta + V
210
194
  #
211
- # 5. If x is in Intersection (Meanings x == y for some y):
212
- # Then P(x) == H(x) (by OKVS property)
213
- # So T = H(x) ^ H(x) ^ U(x)*Delta
214
- # T = U(x)*Delta
215
- #
216
- # This relation T == U* * Delta is what we verify in Phase 5.
195
+ # Sender computes Tag = Decode(K, x) - H(x) * Delta
196
+ # If x in Intersection: P(x) = H(x)
197
+ # Tag = (P(x) * Delta + V(x)) - P(x) * Delta
198
+ # Tag = V(x)
199
+
200
+ def _sender_ops(x: Any, q: Any, w: Any, delta: Any, seed: Any) -> Any:
201
+ # q, w: (M, 2)
202
+ # delta: (2,)
203
+
204
+ # Safe tiling assuming M is aligned
205
+ def _tile_m_simple(d: Any) -> Any:
206
+ return jnp.tile(d, (M, 1))
207
+
208
+ delta_expanded_m = tensor.run_jax(_tile_m_simple, delta)
217
209
 
218
- def _sender_ops(x: Any, q: Any, u: Any, v: Any, seed: Any) -> tuple[Any, Any]:
219
- # x: (N,) Sender Items
220
- # q: (M, 2) Received OKVS
210
+ # 4.2. Compute Global K = Q * Delta + W
211
+ # This is the O(M) multiplication mentioned in the paper
212
+ q_times_delta = field.mul(q, delta_expanded_m)
213
+ k_storage = field.add(q_times_delta, w)
221
214
 
222
- # 4.1 Decode Q and V at x
223
- # OKVS Decode is a linear combination of storage positions.
224
- s_decoded = okvs.decode(x, q, seed)
225
- v_decoded = okvs.decode(x, v, seed)
215
+ # 4.3 One Decode
216
+ # decoded_val = P(x)*Delta + V(x)
217
+ decoded_k = okvs.decode(x, k_storage, seed)
226
218
 
227
- # 4.2 Compute H(x)
219
+ # 4.4 Remove H(x)*Delta
228
220
  def _reshape_seeds(items: Any) -> Any:
229
221
  lo = items
230
222
  hi = jnp.zeros_like(items)
@@ -239,106 +231,73 @@ def psi_intersect(
239
231
 
240
232
  h_x = tensor.run_jax(_davies_meyer, res_exp_x, seeds_x)
241
233
 
242
- # 4.3 Compute T candidate
243
- # T = S ^ V ^ H(x)
244
- # Note: s_decoded is (S^V^U*Delta) effectively
245
- t_val = field.add(s_decoded, v_decoded)
246
- t_val = field.add(t_val, h_x)
234
+ # Expand delta for batch N
235
+ def _tile_n(d: Any) -> Any:
236
+ return jnp.tile(d, (n, 1))
237
+
238
+ delta_expanded_n = tensor.run_jax(_tile_n, delta)
247
239
 
248
- # 4.4 Compute U* = Decode(U, x)
249
- # This is the sender's share of the randomness for item x.
250
- s_u = field.decode_okvs(x, u, seed)
240
+ h_x_times_delta = field.mul(h_x, delta_expanded_n)
251
241
 
252
- return t_val, s_u
242
+ # Final Tag = (P*Delta + V) + H*Delta = V(x)
243
+ tag = field.add(decoded_k, h_x_times_delta)
253
244
 
254
- t_val_sender, u_star_sender = simp.pcall_static(
245
+ return tag
246
+
247
+ # Execute on Sender
248
+ sender_tags = simp.pcall_static(
255
249
  (sender,),
256
250
  _sender_ops,
257
251
  sender_items,
258
252
  q_sender_view,
259
- u_sender,
260
- v_sender,
253
+ w_sender,
254
+ delta_sender,
261
255
  okvs_seed_sender,
262
256
  )
263
-
264
257
  # =========================================================================
265
- # Phase 5. Secure Verification
258
+ # Phase 5. Verification (Receiver Side)
266
259
  # =========================================================================
267
- # The Protocol invariant is T == U* * Delta for intersection items.
268
- #
269
- # Security Risk:
270
- # We must NOT reveal T or Delta to the other party.
271
- # - If Receiver learns T, they can compute Diff = T - U*Delta = H(x) + ... and attack x.
272
- # - If Sender learns Delta, VOLE security collapses.
273
- #
274
- # Secure Verification Method:
275
- # 1. Sender sends U* (Random Mask share) to Receiver.
276
- # - U* is derived from U (random VOLE inputs) so it reveals nothing about X.
277
- #
278
- # 2. Receiver computes Target = U* * Delta.
279
- # - This allows Receiver to construct the expected value of T without knowing T's components.
280
- #
281
- # 3. Receiver Hashes the Target and sends H(Target) to Sender.
282
- # - Hashing prevents Sender from learning Delta algebraically.
283
- # - Hash function acts as a commitment.
284
- #
285
- # 4. Sender compares H(T) =? H(Target).
286
- # - Equality implies x is in Intersection.
287
-
288
- # 5.1 Sender -> Receiver: U*
289
- u_star_recv = simp.shuffle_static(u_star_sender, {receiver: sender})
290
-
291
- # 5.2 Receiver: Compute Expected Target (U* * Delta)
292
- def _recv_verify_ops(u_s: Any, delta: Any) -> Any:
293
- # u_s: (N, 2), delta: (2,)
294
-
295
- # Use tensor.run_jax to isolate JAX operations (tile is not an EDSL primitive)
296
- def _tile(d: Any) -> Any:
297
- return jnp.tile(d, (n, 1))
298
-
299
- delta_expanded = tensor.run_jax(_tile, delta)
300
-
301
- # Compute U* * Delta in GF(2^128)
302
- target = field.mul(u_s, delta_expanded)
303
- return target
304
-
305
- target_val = simp.pcall_static(
306
- (receiver,), _recv_verify_ops, u_star_recv, delta_receiver
307
- )
260
+ # Sender sends Tags (which should be V(x)) to Receiver. To reduce
261
+ # communication we hash and truncate on the sender side and only send
262
+ # the truncated hash (first 16 bytes).
308
263
 
309
- # 5.3 Hash Exchange
310
- # Use robust hashing to prevent algebraic attacks or leakage
264
+ # 5.1 Compute hashed & truncated tags on Sender
311
265
  from mplang.v2.libs.mpc.ot import extension as ot_extension
312
266
 
313
- def _hash_shares(share: el.Object, party: int) -> el.Object:
314
- """Hash the shares using domain separator for security."""
315
- return ot_extension.vec_hash(share, domain_sep=0xFEED, num_rows=n)
267
+ def _hash_and_trunc(tags: Any) -> Any:
268
+ # Compute batched hash on sender and truncate to 16 bytes
269
+ full_h = ot_extension.vec_hash(tags, domain_sep=0x1111, num_rows=n)
270
+ # Use tensor.slice_tensor to slice TraceObjects (start=(0,0), end=(n,16))
271
+ return tensor.slice_tensor(full_h, (0, 0), (n, 16))
316
272
 
317
- # Hash(Target) on Receiver
318
- h_target_recv = simp.pcall_static(
319
- (receiver,), lambda x: _hash_shares(x, receiver), target_val
320
- )
273
+ h_sender_trunc = simp.pcall_static((sender,), _hash_and_trunc, sender_tags)
321
274
 
322
- # Hash(T) on Sender
323
- h_t_sender = simp.pcall_static(
324
- (sender,), lambda x: _hash_shares(x, sender), t_val_sender
325
- )
275
+ # 5.2 Send truncated hashes to Receiver (much smaller payload)
276
+ tags_at_recv = simp.shuffle_static(h_sender_trunc, {receiver: sender})
277
+
278
+ # 5.3 Receiver computes local V(y) and compares
279
+ def _recv_verify(y: Any, v: Any, seed: Any, remote_tags: Any) -> Any:
280
+ # 1. Decode V locally: target = V(y)
281
+ local_v_y = okvs.decode(y, v, seed)
326
282
 
327
- # Send Hash to Sender for comparison
328
- h_target_at_sender = simp.shuffle_static(h_target_recv, {sender: receiver})
283
+ # 2. Hash local V(y) and compare with received truncated sender hashes
284
+ # Note: `remote_tags` here is already the truncated hash (16 bytes)
285
+ # sent from the Sender.
286
+ h_local = ot_extension.vec_hash(local_v_y, domain_sep=0x1111, num_rows=n)
329
287
 
330
- # 5.4 Final Comparison on Sender
331
- def _compare(h_t: Any, h_target: Any) -> Any:
332
- # Compare 32-byte hashes (N, 32) row-by-row
288
+ def _core(h_r16: Any, h_l_full: Any) -> Any:
289
+ # h_r16: (n, 16) truncated bytes from sender
290
+ # h_l_full: (n, k) full hash bytes; truncate to 16
291
+ h_l16 = h_l_full[:, :16]
333
292
 
334
- def _core(a: Any, b: Any) -> Any:
335
- eq = jnp.all(a == b, axis=1)
336
- return eq.astype(jnp.uint8) # (N,) 0 or 1
293
+ eq_matrix = jnp.all(h_r16[:, None, :] == h_l16[None, :, :], axis=2)
294
+ membership = jnp.any(eq_matrix, axis=0)
295
+ return membership.astype(jnp.uint8)
337
296
 
338
- return tensor.run_jax(_core, h_t, h_target)
297
+ return tensor.run_jax(_core, remote_tags, h_local)
339
298
 
340
299
  intersection_mask = simp.pcall_static(
341
- (sender,), _compare, h_t_sender, h_target_at_sender
300
+ (receiver,), _recv_verify, receiver_items, v_recv, okvs_seed, tags_at_recv
342
301
  )
343
302
 
344
303
  return cast(el.Object, intersection_mask)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev267
3
+ Version: 0.1.dev269
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -1,4 +1,5 @@
1
1
  mplang/__init__.py,sha256=CdfWOdeg-I1q6ULjdBxeLioVlplA4bgPTSn_2xSk5VY,1677
2
+ mplang/py.typed,sha256=RyhZV7Yxo8BvEoBiFz5fH3IWVqQKTcBBE-bzjx_N5GQ,583
2
3
  mplang/v1/__init__.py,sha256=m7UQeAqYwQOzFt-lYqv9eKs9kdvutW025pxEG0h7eVs,3346
3
4
  mplang/v1/_device.py,sha256=MY4OO7TJr2oxDXbvv_pnBweALxP2wvV0cvq3DMLHFuE,22971
4
5
  mplang/v1/host.py,sha256=-daviW1W4HYFhBqMzkDkJoQap8HxNDFDop2FyI_mrak,4306
@@ -96,8 +97,8 @@ mplang/v2/backends/table_impl.py,sha256=Qmd-Z_PLjSbDngWkHz0wc6VykoGHfS2-rCOk1aWu
96
97
  mplang/v2/backends/tee_impl.py,sha256=Gp-vqqJPtEMNqP7y68tLhL3a-EW3BQwpo_qCJOSHqKs,7044
97
98
  mplang/v2/backends/tensor_impl.py,sha256=8f9f4-_e-m4JWGZSbXLmSSHcgPykRBc1sAYrA3OIxEg,18906
98
99
  mplang/v2/backends/simp_driver/__init__.py,sha256=ahOPYYvtFVwqxiFxkpSNP8BCTao_MfCXmtt5zsMaJxg,1258
99
- mplang/v2/backends/simp_driver/http.py,sha256=nl7ny7f8bzhy1ubNIDXhMgA5P_WA8dhhgFNHvcmfSKk,5548
100
- mplang/v2/backends/simp_driver/mem.py,sha256=nFA-KkYx5fDh6NseI8QOd5FkNErPDq4h_QjrWD7nMrE,9126
100
+ mplang/v2/backends/simp_driver/http.py,sha256=Fm0M7BKf6Ddqec79btd-tJiuVaD92yghr1GJc84RXmg,5550
101
+ mplang/v2/backends/simp_driver/mem.py,sha256=kx3jDAYx3QkJa1UZDhhY_JjJAdT8u-r6Hsw8fYwFPKY,9128
101
102
  mplang/v2/backends/simp_driver/ops.py,sha256=UeVC3eaCUwxrkN6OsJyMYj8qMDufMFQI0YogeSbhkjM,4515
102
103
  mplang/v2/backends/simp_driver/state.py,sha256=6tQyQg_PNzHOJkjCoNm51Wvknl3XiJZzpQXuRB4qRtM,1719
103
104
  mplang/v2/backends/simp_driver/values.py,sha256=OQ_7Kt6l7Pcfx5eB6GVbpunS6CG60Lj0AS6H9Wx9sKQ,1515
@@ -136,7 +137,7 @@ mplang/v2/kernels/__init__.py,sha256=J_rDl9lAXd7QL3Nt_P3YX6j9yge7ssguSaHuafPZNKE
136
137
  mplang/v2/kernels/gf128.cpp,sha256=WIvCr3MijzwJxMi1Wnfhm8aWT8oL0fia6FeyTmFJtPQ,5975
137
138
  mplang/v2/kernels/ldpc.cpp,sha256=_zE90ZHQvrweRkBB3CEu80cXKG0a-QNJ59ZQ452gml8,2654
138
139
  mplang/v2/kernels/okvs.cpp,sha256=Z_7oGHFAdLc5d5llUNujBO8HwDBh5yd3MpfmT8ZNf1o,10347
139
- mplang/v2/kernels/okvs_opt.cpp,sha256=d_HhvMdcebYsG2x7kYzjuFgmEsh9WKLH6SHee3375Bg,10932
140
+ mplang/v2/kernels/okvs_opt.cpp,sha256=5MkI_rTfFxohIKqU5Uog0TpxokgXTjUVIJMUS0q5e2I,11870
140
141
  mplang/v2/kernels/py_kernels.py,sha256=FDsD86IHV-UBzxZLolhSOkrp24PuboHXeb1gBHLOfMo,12073
141
142
  mplang/v2/libs/collective.py,sha256=pfXq9tmFUNKjeHhWMTjtzOi-m2Fn1lLru1G6txZVyic,10683
142
143
  mplang/v2/libs/device/__init__.py,sha256=mXsSvXrWmlHu6Ch87Vcd85m4L_qdDkbSvJyHyuai2fc,1251
@@ -159,9 +160,9 @@ mplang/v2/libs/mpc/ot/silent.py,sha256=9J3sMsz3XzxPbIk91IpfAvvdGeZw-Tt0kElyPsNln
159
160
  mplang/v2/libs/mpc/psi/__init__.py,sha256=mpevlx3Z5_u9Q1McDZBBIGHApeO9julgiM09GToxxEA,1231
160
161
  mplang/v2/libs/mpc/psi/cuckoo.py,sha256=GQvLi7BtaPZyk96xwVCwpQPGlcGhOUX6kdsEn8P80l0,7752
161
162
  mplang/v2/libs/mpc/psi/okvs.py,sha256=a1Q4ILrsLII9K-BJRSX8iKkpkxJsMxFEj7cTId-XGCE,1576
162
- mplang/v2/libs/mpc/psi/okvs_gct.py,sha256=wRxBEZw-dnYXHWng-1eRsnnP6k6wKySSUxigN9eq08k,3023
163
+ mplang/v2/libs/mpc/psi/okvs_gct.py,sha256=YcN5ms8StV4ogHR5gK4-SlbIjbcA0QjITagza2EyY70,3375
163
164
  mplang/v2/libs/mpc/psi/oprf.py,sha256=YXD-I9P3t1YuqHVxOD9JUpLTZUu-HjgvOaEOQ3hhxMM,13772
164
- mplang/v2/libs/mpc/psi/rr22.py,sha256=2mN1zbjrBUgaWCsF3Lj8ohtK6gcG95PtBb3EseS-Nsg,12614
165
+ mplang/v2/libs/mpc/psi/rr22.py,sha256=fDXigUduBnHfG_8qoL4uS7EHaTmjUjJsZCgnA0-u8cQ,11209
165
166
  mplang/v2/libs/mpc/psi/unbalanced.py,sha256=hC84TVsgnlJDg6hpUrx8kbUbmFb27T9wrHG0zv3FXLc,7433
166
167
  mplang/v2/libs/mpc/vole/__init__.py,sha256=2dU4X6n73HoK-YCiCl4b36SkLRKR6rofe2xxLxBz6Rc,968
167
168
  mplang/v2/libs/mpc/vole/gilboa.py,sha256=apnKOYR4_dJ2wkzGq7PBlwauA-W5i5MPESdetCWTegU,9951
@@ -172,8 +173,8 @@ mplang/v2/runtime/dialect_state.py,sha256=HxO1i4kSOujS2tQzAF9-WmI3nChSaGgupf2_07
172
173
  mplang/v2/runtime/interpreter.py,sha256=UzrM5oepka6H0YKRZncNXhsuwKVm4pliG5J92fFRZMI,32300
173
174
  mplang/v2/runtime/object_store.py,sha256=yT6jtKG2GUEJVmpq3gnQ8mCMvUFYzgBciC5A-J5KRdk,5998
174
175
  mplang/v2/runtime/value.py,sha256=CMOxElJP78v7pjasPhEpbxWbSgB2KsLbpPmzz0mQX0E,4317
175
- mplang_nightly-0.1.dev267.dist-info/METADATA,sha256=fWTmQUUYcAYzXHVW7m_4CxvnQb7w4qbWn5vZ7yffuoU,16775
176
- mplang_nightly-0.1.dev267.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
177
- mplang_nightly-0.1.dev267.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
178
- mplang_nightly-0.1.dev267.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
179
- mplang_nightly-0.1.dev267.dist-info/RECORD,,
176
+ mplang_nightly-0.1.dev269.dist-info/METADATA,sha256=54UZyxyPNQqxRpAXBoj_1ZQxHjf5a5Gi7sBKBsEYx78,16775
177
+ mplang_nightly-0.1.dev269.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
178
+ mplang_nightly-0.1.dev269.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
179
+ mplang_nightly-0.1.dev269.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
180
+ mplang_nightly-0.1.dev269.dist-info/RECORD,,