mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev269__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/v2/kernels/okvs_opt.cpp +16 -1
- mplang/v2/libs/mpc/psi/okvs_gct.py +16 -10
- mplang/v2/libs/mpc/psi/rr22.py +136 -177
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev269.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev269.dist-info}/RECORD +8 -8
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev269.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev269.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev269.dist-info}/licenses/LICENSE +0 -0
mplang/v2/kernels/okvs_opt.cpp
CHANGED
|
@@ -36,6 +36,10 @@ extern "C" {
|
|
|
36
36
|
uint64_t h1, h2, h3;
|
|
37
37
|
};
|
|
38
38
|
|
|
39
|
+
// Declaration of the safe (robust) solver implemented in okvs.cpp
|
|
40
|
+
// Signature: solve_okvs(keys, values, output, n, m, seed_ptr)
|
|
41
|
+
void solve_okvs(uint64_t* keys, uint64_t* values, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr);
|
|
42
|
+
|
|
39
43
|
// Stateless Bin Selection
|
|
40
44
|
// Maps a key to a deterministic bin index [0, NUM_BINS).
|
|
41
45
|
inline uint64_t get_bin_index(uint64_t key, __m128i seed) {
|
|
@@ -245,10 +249,21 @@ extern "C" {
|
|
|
245
249
|
uint64_t valid_m = m_per_bin[b];
|
|
246
250
|
|
|
247
251
|
if(!solve_bin(bin_keys[b], bin_vals[b], &P_vec[offset], valid_m, seed)) {
|
|
252
|
+
// On failure, log and fall back to the robust solver for this bin.
|
|
253
|
+
// The fallback is executed inside a critical section to avoid nested OpenMP
|
|
254
|
+
// regions and to serialize rare fallbacks.
|
|
248
255
|
#pragma omp critical
|
|
249
256
|
{
|
|
250
|
-
fprintf(stderr, "[
|
|
257
|
+
fprintf(stderr, "[WARN] Bin %lu failed optimized peeling; falling back to safe solver. Items: %lu / M: %lu (Ratio: %.2f)\n",
|
|
251
258
|
b, bin_keys[b].size(), valid_m, (double)valid_m / bin_keys[b].size());
|
|
259
|
+
|
|
260
|
+
// Prepare pointers for the safe solver
|
|
261
|
+
uint64_t* keys_ptr = &bin_keys[b][0];
|
|
262
|
+
uint64_t* vals_ptr = &bin_vals[b][0];
|
|
263
|
+
uint64_t* out_ptr = output + (offset * 2ULL); // each 128-bit slot == 2 uint64_t
|
|
264
|
+
|
|
265
|
+
// Call the safe solver implemented in okvs.cpp
|
|
266
|
+
solve_okvs(keys_ptr, vals_ptr, out_ptr, bin_keys[b].size(), valid_m, seed_ptr);
|
|
252
267
|
}
|
|
253
268
|
}
|
|
254
269
|
}
|
|
@@ -40,12 +40,16 @@ def get_okvs_expansion(n: int) -> float:
|
|
|
40
40
|
- For N → ∞: Theoretical minimum is ε ≈ 0.23 (M = 1.23N)
|
|
41
41
|
- For finite N: Larger ε needed due to variance in random hash collisions
|
|
42
42
|
|
|
43
|
-
Empirical safe thresholds (failure probability < 0.
|
|
44
|
-
- N
|
|
45
|
-
|
|
46
|
-
- N < 10,000: ε = 0.
|
|
47
|
-
- N < 100,000: ε = 0.
|
|
48
|
-
- N ≥ 100,000: ε = 0.35 (M = 1.35N)
|
|
43
|
+
Empirical safe thresholds (failure probability < 0.001%):
|
|
44
|
+
- N ≤ 200: ε = 24.0 (M = 25.0N) - extremely small sets need very wide margin
|
|
45
|
+
- N < 1,000: ε = 11.0 (M = 12.0N) - small sets need extra wide safety margin
|
|
46
|
+
- N < 10,000: ε = 0.6 (M = 1.6N)
|
|
47
|
+
- N < 100,000: ε = 0.4 (M = 1.4N)
|
|
48
|
+
- N ≥ 100,000: ε = 0.35 (M = 1.35N) - large sets converge near theory
|
|
49
|
+
|
|
50
|
+
Note: These expansion factors account for the 128-byte alignment requirement
|
|
51
|
+
in the OKVS implementation. The factors are intentionally conservative to
|
|
52
|
+
ensure high success rates (>99.9%) for the probabilistic peeling algorithm.
|
|
49
53
|
|
|
50
54
|
Args:
|
|
51
55
|
n: Number of key-value pairs to encode
|
|
@@ -53,12 +57,14 @@ def get_okvs_expansion(n: int) -> float:
|
|
|
53
57
|
Returns:
|
|
54
58
|
Expansion factor ε such that M = (1+ε)*N is safe for peeling
|
|
55
59
|
"""
|
|
56
|
-
if n
|
|
57
|
-
return
|
|
60
|
+
if n <= 200:
|
|
61
|
+
return 25.0 # Extremely small scale: need very wide margin for stability
|
|
62
|
+
elif n < 1000:
|
|
63
|
+
return 12.0 # Small scale: need wide safety margin for stability
|
|
58
64
|
elif n <= 10000:
|
|
59
|
-
return 1.
|
|
65
|
+
return 1.6 # Medium scale
|
|
60
66
|
elif n <= 100000:
|
|
61
|
-
return 1.
|
|
67
|
+
return 1.4 # Large scale
|
|
62
68
|
else:
|
|
63
69
|
# Mega-Binning requires ~1.35 for stability with 1024 bins
|
|
64
70
|
return 1.35
|
mplang/v2/libs/mpc/psi/rr22.py
CHANGED
|
@@ -28,37 +28,36 @@ Phases:
|
|
|
28
28
|
1. **Correlated Randomness (VOLE)**:
|
|
29
29
|
Sender and Receiver establish a shared correlation:
|
|
30
30
|
W = V + U * Delta
|
|
31
|
-
-
|
|
32
|
-
|
|
33
|
-
-
|
|
31
|
+
- PSI Receiver holds `U` and `V` (these are generated by the OT "sender"
|
|
32
|
+
role in `silent_vole_random_u`).
|
|
33
|
+
- PSI Sender holds `W` and `Delta` (these are generated by the OT "receiver"
|
|
34
|
+
role).
|
|
35
|
+
- `U` is random; `Delta` is a fixed secret scalar held by the Sender.
|
|
34
36
|
|
|
35
37
|
2. **Encoding (OKVS)**:
|
|
36
|
-
Receiver encodes
|
|
37
|
-
Decode(P, y) = H(y) for all y in Y.
|
|
38
|
-
|
|
38
|
+
The Receiver encodes its input set Y into an OKVS storage `P` such that
|
|
39
|
+
Decode(P, y) = H(y) for all y in Y. The function `H(y)` is implemented via
|
|
40
|
+
AES/Davies–Meyer expansion acting as a random oracle.
|
|
39
41
|
|
|
40
42
|
3. **Masking & Exchange**:
|
|
41
|
-
Receiver masks the
|
|
42
|
-
Q = P
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
4. **Decoding &
|
|
46
|
-
Sender
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
Sender
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
Thus, verification becomes checking if T == U* * Delta, where U* = Decode(U, x).
|
|
61
|
-
This check is performed securely using hashes to prevent leakage.
|
|
43
|
+
The Receiver masks the OKVS storage `P` with its VOLE share `U`:
|
|
44
|
+
Q = P + U
|
|
45
|
+
The masked storage `Q` is sent to the Sender (so Sender sees a masked OKVS).
|
|
46
|
+
|
|
47
|
+
4. **Decoding & Tag Generation (Sender)**:
|
|
48
|
+
The Sender holds `W` and `Delta` and computes the linear combination:
|
|
49
|
+
K = Q * Delta + W
|
|
50
|
+
Using W = V + U * Delta and Q = P + U, this simplifies to
|
|
51
|
+
K = P * Delta + V.
|
|
52
|
+
The Sender decodes `K` for each of its items x to obtain P(x)*Delta + V(x),
|
|
53
|
+
then subtracts H(x)*Delta (computed locally) to recover `V(x)`. The value
|
|
54
|
+
`Tag = V(x)` serves as the sender-side tag for item x.
|
|
55
|
+
|
|
56
|
+
5. **Verification (Receiver)**:
|
|
57
|
+
The Sender hashes and truncates tags and sends the truncated hashes to the
|
|
58
|
+
Receiver. The Receiver locally decodes `V(y)` from its OKVS, hashes it with
|
|
59
|
+
the same domain separation and truncation, and compares to the received
|
|
60
|
+
truncated hashes to determine membership in the intersection.
|
|
62
61
|
"""
|
|
63
62
|
|
|
64
63
|
from typing import Any, cast
|
|
@@ -77,17 +76,24 @@ def psi_intersect(
|
|
|
77
76
|
sender_items: el.Object,
|
|
78
77
|
receiver_items: el.Object,
|
|
79
78
|
) -> el.Object:
|
|
80
|
-
"""Execute OKVS-based PSI Protocol.
|
|
79
|
+
"""Execute OKVS-based PSI Protocol (Original RR22 Logic).
|
|
80
|
+
|
|
81
|
+
This implementation follows the RR22 paper's role assignment where:
|
|
82
|
+
- PSI Sender holds Delta (and W).
|
|
83
|
+
- PSI Receiver holds U and V.
|
|
84
|
+
|
|
85
|
+
This enables the "One Decode" optimization on the Sender side and prevents
|
|
86
|
+
offline brute-force attacks by the Receiver (though Sender could brute-force).
|
|
81
87
|
|
|
82
88
|
Args:
|
|
83
89
|
sender: Rank of Sender.
|
|
84
90
|
receiver: Rank of Receiver.
|
|
85
|
-
n: Number of items
|
|
86
|
-
sender_items: Object located at Sender
|
|
87
|
-
receiver_items: Object located at Receiver
|
|
91
|
+
n: Number of items.
|
|
92
|
+
sender_items: Object located at Sender.
|
|
93
|
+
receiver_items: Object located at Receiver.
|
|
88
94
|
|
|
89
95
|
Returns:
|
|
90
|
-
Intersection
|
|
96
|
+
Intersection mask (0/1) located at Receiver.
|
|
91
97
|
"""
|
|
92
98
|
|
|
93
99
|
# Validation
|
|
@@ -95,52 +101,42 @@ def psi_intersect(
|
|
|
95
101
|
raise ValueError(
|
|
96
102
|
f"Sender ({sender}) and Receiver ({receiver}) must be different."
|
|
97
103
|
)
|
|
98
|
-
|
|
99
104
|
if n <= 0:
|
|
100
105
|
raise ValueError(f"Input size n must be positive, got {n}.")
|
|
101
106
|
|
|
102
107
|
# =========================================================================
|
|
103
|
-
# Phase 1. Parameter Setup
|
|
108
|
+
# Phase 1. Parameter Setup
|
|
104
109
|
# =========================================================================
|
|
105
|
-
# OKVS Size M = expansion * N.
|
|
106
|
-
# The expansion factor is critical for the success probability of the "Peeling"
|
|
107
|
-
# algorithm used in OKVS encoding (Garbled Cuckoo Table).
|
|
108
|
-
# Larger N allows smaller expansion (closer to theoretical 1.23) while maintaining safety.
|
|
109
110
|
import mplang.v2.libs.mpc.psi.okvs_gct as okvs_gct
|
|
110
111
|
|
|
111
112
|
expansion = okvs_gct.get_okvs_expansion(n)
|
|
112
113
|
M = int(n * expansion)
|
|
113
|
-
|
|
114
|
-
# Align M to 128 boundary for efficient batch processing in Silent VOLE (LPN)
|
|
115
114
|
if M % 128 != 0:
|
|
116
115
|
M = ((M // 128) + 1) * 128
|
|
117
116
|
|
|
118
117
|
# =========================================================================
|
|
119
|
-
# Phase 2. Correlated Randomness
|
|
118
|
+
# Phase 2. Correlated Randomness (VOLE)
|
|
120
119
|
# =========================================================================
|
|
121
|
-
#
|
|
122
|
-
#
|
|
123
|
-
#
|
|
124
|
-
#
|
|
120
|
+
# In the original paper logic (Fig 4), the PSI Sender holds Delta.
|
|
121
|
+
# Therefore, we swap the roles in the OT call.
|
|
122
|
+
#
|
|
123
|
+
# silent_vole_random_u(A, B) gives:
|
|
124
|
+
# A (OT Sender): U, V
|
|
125
|
+
# B (OT Receiver): W, Delta
|
|
125
126
|
#
|
|
126
|
-
#
|
|
127
|
+
# We want PSI Sender to be OT Receiver.
|
|
128
|
+
res_tuple = silent_ot.silent_vole_random_u(receiver, sender, M, base_k=1024)
|
|
127
129
|
|
|
128
|
-
#
|
|
129
|
-
|
|
130
|
-
v_sender, w_receiver, u_sender, delta_receiver = res_tuple[:4]
|
|
130
|
+
# PSI Receiver gets U, V
|
|
131
|
+
v_recv, w_sender, u_recv, delta_sender = res_tuple[:4]
|
|
131
132
|
|
|
132
133
|
# =========================================================================
|
|
133
|
-
# Phase 3. Receiver Encoding & Masking
|
|
134
|
+
# Phase 3. Receiver Encoding & Masking
|
|
134
135
|
# =========================================================================
|
|
135
|
-
#
|
|
136
|
-
#
|
|
137
|
-
#
|
|
138
|
-
# Then, Receiver masks P with the VOLE output W to get Q:
|
|
139
|
-
# Q = P ^ W
|
|
140
|
-
# This Q is sent to the Sender.
|
|
136
|
+
# Receiver computes P such that P(y) = H(y).
|
|
137
|
+
# Receiver masks P with U (Paper's A vector).
|
|
138
|
+
# Q = P ^ U
|
|
141
139
|
|
|
142
|
-
# 3.1 Generate OKVS Seed (Public/Session Randomness)
|
|
143
|
-
# Used for OKVS hashing distribution. Can be public, but generated at runtime for safety.
|
|
144
140
|
from mplang.v2.dialects import crypto
|
|
145
141
|
from mplang.v2.edsl import typing as elt
|
|
146
142
|
|
|
@@ -150,19 +146,11 @@ def psi_intersect(
|
|
|
150
146
|
okvs_seed = simp.pcall_static((receiver,), _gen_seed)
|
|
151
147
|
okvs_seed_sender = simp.shuffle_static(okvs_seed, {sender: receiver})
|
|
152
148
|
|
|
153
|
-
# Instantiate OKVS Data Structure
|
|
154
149
|
okvs = okvs_gct.SparseOKVS(M)
|
|
155
150
|
|
|
156
|
-
def _recv_ops(y: Any,
|
|
157
|
-
#
|
|
158
|
-
# w: (M, 2) VOLE share
|
|
159
|
-
|
|
160
|
-
# 3.2 Compute H(y) - The Random Oracle Target
|
|
161
|
-
# We use Davies-Meyer construction: H(x) = E_x(0) ^ x
|
|
162
|
-
# This is a standard, efficient, and robust way to instantiate a RO from AES.
|
|
163
|
-
|
|
151
|
+
def _recv_ops(y: Any, u: Any, seed: Any) -> Any:
|
|
152
|
+
# 3.1 Compute H(y)
|
|
164
153
|
def _reshape_seeds(items: Any) -> Any:
|
|
165
|
-
# Prepare items as AES keys (128-bit)
|
|
166
154
|
lo = items
|
|
167
155
|
hi = jnp.zeros_like(items)
|
|
168
156
|
return jnp.stack([lo, hi], axis=1) # (N, 2)
|
|
@@ -176,55 +164,59 @@ def psi_intersect(
|
|
|
176
164
|
|
|
177
165
|
h_y = tensor.run_jax(_davies_meyer, res_exp, seeds)
|
|
178
166
|
|
|
179
|
-
# 3.
|
|
180
|
-
# We find P such that: P * M_okvs(y) = h_y
|
|
167
|
+
# 3.2 Encode P
|
|
181
168
|
p_storage = okvs.encode(y, h_y, seed)
|
|
182
169
|
|
|
183
|
-
# 3.
|
|
184
|
-
# Q = P ^
|
|
185
|
-
q_storage = field.add(p_storage,
|
|
170
|
+
# 3. Mask with U (instead of W)
|
|
171
|
+
# Q = P ^ U
|
|
172
|
+
q_storage = field.add(p_storage, u)
|
|
186
173
|
|
|
187
174
|
return q_storage
|
|
188
175
|
|
|
189
|
-
#
|
|
176
|
+
# Receiver uses U to mask
|
|
190
177
|
q_shared = simp.pcall_static(
|
|
191
|
-
(receiver,), _recv_ops, receiver_items,
|
|
178
|
+
(receiver,), _recv_ops, receiver_items, u_recv, okvs_seed
|
|
192
179
|
)
|
|
193
180
|
|
|
194
|
-
# 3.5 Send Q to Sender
|
|
195
181
|
q_sender_view = simp.shuffle_static(q_shared, {sender: receiver})
|
|
196
182
|
|
|
197
183
|
# =========================================================================
|
|
198
|
-
# Phase 4. Sender
|
|
184
|
+
# Phase 4. Sender "One Decode" & Tag Generation
|
|
199
185
|
# =========================================================================
|
|
200
|
-
# Sender
|
|
186
|
+
# Sender holds W, Delta. Receives Q.
|
|
187
|
+
# W = V + U * Delta
|
|
201
188
|
#
|
|
202
189
|
# Derivation:
|
|
203
|
-
#
|
|
204
|
-
#
|
|
205
|
-
#
|
|
206
|
-
#
|
|
207
|
-
# 4. Sender computes T = S_decoded ^ V(x) ^ H(x)
|
|
208
|
-
# T = P(x) ^ V(x) ^ U(x)*Delta ^ V(x) ^ H(x)
|
|
209
|
-
# T = P(x) ^ H(x) ^ U(x)*Delta
|
|
190
|
+
# K = Q * Delta + W
|
|
191
|
+
# = (P + U) * Delta + (V + U * Delta)
|
|
192
|
+
# = P * Delta + U * Delta + V + U * Delta
|
|
193
|
+
# = P * Delta + V
|
|
210
194
|
#
|
|
211
|
-
#
|
|
212
|
-
#
|
|
213
|
-
#
|
|
214
|
-
#
|
|
215
|
-
|
|
216
|
-
|
|
195
|
+
# Sender computes Tag = Decode(K, x) - H(x) * Delta
|
|
196
|
+
# If x in Intersection: P(x) = H(x)
|
|
197
|
+
# Tag = (P(x) * Delta + V(x)) - P(x) * Delta
|
|
198
|
+
# Tag = V(x)
|
|
199
|
+
|
|
200
|
+
def _sender_ops(x: Any, q: Any, w: Any, delta: Any, seed: Any) -> Any:
|
|
201
|
+
# q, w: (M, 2)
|
|
202
|
+
# delta: (2,)
|
|
203
|
+
|
|
204
|
+
# Safe tiling assuming M is aligned
|
|
205
|
+
def _tile_m_simple(d: Any) -> Any:
|
|
206
|
+
return jnp.tile(d, (M, 1))
|
|
207
|
+
|
|
208
|
+
delta_expanded_m = tensor.run_jax(_tile_m_simple, delta)
|
|
217
209
|
|
|
218
|
-
|
|
219
|
-
#
|
|
220
|
-
|
|
210
|
+
# 4.2. Compute Global K = Q * Delta + W
|
|
211
|
+
# This is the O(M) multiplication mentioned in the paper
|
|
212
|
+
q_times_delta = field.mul(q, delta_expanded_m)
|
|
213
|
+
k_storage = field.add(q_times_delta, w)
|
|
221
214
|
|
|
222
|
-
# 4.
|
|
223
|
-
#
|
|
224
|
-
|
|
225
|
-
v_decoded = okvs.decode(x, v, seed)
|
|
215
|
+
# 4.3 One Decode
|
|
216
|
+
# decoded_val = P(x)*Delta + V(x)
|
|
217
|
+
decoded_k = okvs.decode(x, k_storage, seed)
|
|
226
218
|
|
|
227
|
-
# 4.
|
|
219
|
+
# 4.4 Remove H(x)*Delta
|
|
228
220
|
def _reshape_seeds(items: Any) -> Any:
|
|
229
221
|
lo = items
|
|
230
222
|
hi = jnp.zeros_like(items)
|
|
@@ -239,106 +231,73 @@ def psi_intersect(
|
|
|
239
231
|
|
|
240
232
|
h_x = tensor.run_jax(_davies_meyer, res_exp_x, seeds_x)
|
|
241
233
|
|
|
242
|
-
#
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
234
|
+
# Expand delta for batch N
|
|
235
|
+
def _tile_n(d: Any) -> Any:
|
|
236
|
+
return jnp.tile(d, (n, 1))
|
|
237
|
+
|
|
238
|
+
delta_expanded_n = tensor.run_jax(_tile_n, delta)
|
|
247
239
|
|
|
248
|
-
|
|
249
|
-
# This is the sender's share of the randomness for item x.
|
|
250
|
-
s_u = field.decode_okvs(x, u, seed)
|
|
240
|
+
h_x_times_delta = field.mul(h_x, delta_expanded_n)
|
|
251
241
|
|
|
252
|
-
|
|
242
|
+
# Final Tag = (P*Delta + V) + H*Delta = V(x)
|
|
243
|
+
tag = field.add(decoded_k, h_x_times_delta)
|
|
253
244
|
|
|
254
|
-
|
|
245
|
+
return tag
|
|
246
|
+
|
|
247
|
+
# Execute on Sender
|
|
248
|
+
sender_tags = simp.pcall_static(
|
|
255
249
|
(sender,),
|
|
256
250
|
_sender_ops,
|
|
257
251
|
sender_items,
|
|
258
252
|
q_sender_view,
|
|
259
|
-
|
|
260
|
-
|
|
253
|
+
w_sender,
|
|
254
|
+
delta_sender,
|
|
261
255
|
okvs_seed_sender,
|
|
262
256
|
)
|
|
263
|
-
|
|
264
257
|
# =========================================================================
|
|
265
|
-
# Phase 5.
|
|
258
|
+
# Phase 5. Verification (Receiver Side)
|
|
266
259
|
# =========================================================================
|
|
267
|
-
#
|
|
268
|
-
#
|
|
269
|
-
#
|
|
270
|
-
# We must NOT reveal T or Delta to the other party.
|
|
271
|
-
# - If Receiver learns T, they can compute Diff = T - U*Delta = H(x) + ... and attack x.
|
|
272
|
-
# - If Sender learns Delta, VOLE security collapses.
|
|
273
|
-
#
|
|
274
|
-
# Secure Verification Method:
|
|
275
|
-
# 1. Sender sends U* (Random Mask share) to Receiver.
|
|
276
|
-
# - U* is derived from U (random VOLE inputs) so it reveals nothing about X.
|
|
277
|
-
#
|
|
278
|
-
# 2. Receiver computes Target = U* * Delta.
|
|
279
|
-
# - This allows Receiver to construct the expected value of T without knowing T's components.
|
|
280
|
-
#
|
|
281
|
-
# 3. Receiver Hashes the Target and sends H(Target) to Sender.
|
|
282
|
-
# - Hashing prevents Sender from learning Delta algebraically.
|
|
283
|
-
# - Hash function acts as a commitment.
|
|
284
|
-
#
|
|
285
|
-
# 4. Sender compares H(T) =? H(Target).
|
|
286
|
-
# - Equality implies x is in Intersection.
|
|
287
|
-
|
|
288
|
-
# 5.1 Sender -> Receiver: U*
|
|
289
|
-
u_star_recv = simp.shuffle_static(u_star_sender, {receiver: sender})
|
|
290
|
-
|
|
291
|
-
# 5.2 Receiver: Compute Expected Target (U* * Delta)
|
|
292
|
-
def _recv_verify_ops(u_s: Any, delta: Any) -> Any:
|
|
293
|
-
# u_s: (N, 2), delta: (2,)
|
|
294
|
-
|
|
295
|
-
# Use tensor.run_jax to isolate JAX operations (tile is not an EDSL primitive)
|
|
296
|
-
def _tile(d: Any) -> Any:
|
|
297
|
-
return jnp.tile(d, (n, 1))
|
|
298
|
-
|
|
299
|
-
delta_expanded = tensor.run_jax(_tile, delta)
|
|
300
|
-
|
|
301
|
-
# Compute U* * Delta in GF(2^128)
|
|
302
|
-
target = field.mul(u_s, delta_expanded)
|
|
303
|
-
return target
|
|
304
|
-
|
|
305
|
-
target_val = simp.pcall_static(
|
|
306
|
-
(receiver,), _recv_verify_ops, u_star_recv, delta_receiver
|
|
307
|
-
)
|
|
260
|
+
# Sender sends Tags (which should be V(x)) to Receiver. To reduce
|
|
261
|
+
# communication we hash and truncate on the sender side and only send
|
|
262
|
+
# the truncated hash (first 16 bytes).
|
|
308
263
|
|
|
309
|
-
# 5.
|
|
310
|
-
# Use robust hashing to prevent algebraic attacks or leakage
|
|
264
|
+
# 5.1 Compute hashed & truncated tags on Sender
|
|
311
265
|
from mplang.v2.libs.mpc.ot import extension as ot_extension
|
|
312
266
|
|
|
313
|
-
def
|
|
314
|
-
|
|
315
|
-
|
|
267
|
+
def _hash_and_trunc(tags: Any) -> Any:
|
|
268
|
+
# Compute batched hash on sender and truncate to 16 bytes
|
|
269
|
+
full_h = ot_extension.vec_hash(tags, domain_sep=0x1111, num_rows=n)
|
|
270
|
+
# Use tensor.slice_tensor to slice TraceObjects (start=(0,0), end=(n,16))
|
|
271
|
+
return tensor.slice_tensor(full_h, (0, 0), (n, 16))
|
|
316
272
|
|
|
317
|
-
|
|
318
|
-
h_target_recv = simp.pcall_static(
|
|
319
|
-
(receiver,), lambda x: _hash_shares(x, receiver), target_val
|
|
320
|
-
)
|
|
273
|
+
h_sender_trunc = simp.pcall_static((sender,), _hash_and_trunc, sender_tags)
|
|
321
274
|
|
|
322
|
-
#
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
)
|
|
275
|
+
# 5.2 Send truncated hashes to Receiver (much smaller payload)
|
|
276
|
+
tags_at_recv = simp.shuffle_static(h_sender_trunc, {receiver: sender})
|
|
277
|
+
|
|
278
|
+
# 5.3 Receiver computes local V(y) and compares
|
|
279
|
+
def _recv_verify(y: Any, v: Any, seed: Any, remote_tags: Any) -> Any:
|
|
280
|
+
# 1. Decode V locally: target = V(y)
|
|
281
|
+
local_v_y = okvs.decode(y, v, seed)
|
|
326
282
|
|
|
327
|
-
|
|
328
|
-
|
|
283
|
+
# 2. Hash local V(y) and compare with received truncated sender hashes
|
|
284
|
+
# Note: `remote_tags` here is already the truncated hash (16 bytes)
|
|
285
|
+
# sent from the Sender.
|
|
286
|
+
h_local = ot_extension.vec_hash(local_v_y, domain_sep=0x1111, num_rows=n)
|
|
329
287
|
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
288
|
+
def _core(h_r16: Any, h_l_full: Any) -> Any:
|
|
289
|
+
# h_r16: (n, 16) truncated bytes from sender
|
|
290
|
+
# h_l_full: (n, k) full hash bytes; truncate to 16
|
|
291
|
+
h_l16 = h_l_full[:, :16]
|
|
333
292
|
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
return
|
|
293
|
+
eq_matrix = jnp.all(h_r16[:, None, :] == h_l16[None, :, :], axis=2)
|
|
294
|
+
membership = jnp.any(eq_matrix, axis=0)
|
|
295
|
+
return membership.astype(jnp.uint8)
|
|
337
296
|
|
|
338
|
-
return tensor.run_jax(_core,
|
|
297
|
+
return tensor.run_jax(_core, remote_tags, h_local)
|
|
339
298
|
|
|
340
299
|
intersection_mask = simp.pcall_static(
|
|
341
|
-
(
|
|
300
|
+
(receiver,), _recv_verify, receiver_items, v_recv, okvs_seed, tags_at_recv
|
|
342
301
|
)
|
|
343
302
|
|
|
344
303
|
return cast(el.Object, intersection_mask)
|
|
@@ -137,7 +137,7 @@ mplang/v2/kernels/__init__.py,sha256=J_rDl9lAXd7QL3Nt_P3YX6j9yge7ssguSaHuafPZNKE
|
|
|
137
137
|
mplang/v2/kernels/gf128.cpp,sha256=WIvCr3MijzwJxMi1Wnfhm8aWT8oL0fia6FeyTmFJtPQ,5975
|
|
138
138
|
mplang/v2/kernels/ldpc.cpp,sha256=_zE90ZHQvrweRkBB3CEu80cXKG0a-QNJ59ZQ452gml8,2654
|
|
139
139
|
mplang/v2/kernels/okvs.cpp,sha256=Z_7oGHFAdLc5d5llUNujBO8HwDBh5yd3MpfmT8ZNf1o,10347
|
|
140
|
-
mplang/v2/kernels/okvs_opt.cpp,sha256=
|
|
140
|
+
mplang/v2/kernels/okvs_opt.cpp,sha256=5MkI_rTfFxohIKqU5Uog0TpxokgXTjUVIJMUS0q5e2I,11870
|
|
141
141
|
mplang/v2/kernels/py_kernels.py,sha256=FDsD86IHV-UBzxZLolhSOkrp24PuboHXeb1gBHLOfMo,12073
|
|
142
142
|
mplang/v2/libs/collective.py,sha256=pfXq9tmFUNKjeHhWMTjtzOi-m2Fn1lLru1G6txZVyic,10683
|
|
143
143
|
mplang/v2/libs/device/__init__.py,sha256=mXsSvXrWmlHu6Ch87Vcd85m4L_qdDkbSvJyHyuai2fc,1251
|
|
@@ -160,9 +160,9 @@ mplang/v2/libs/mpc/ot/silent.py,sha256=9J3sMsz3XzxPbIk91IpfAvvdGeZw-Tt0kElyPsNln
|
|
|
160
160
|
mplang/v2/libs/mpc/psi/__init__.py,sha256=mpevlx3Z5_u9Q1McDZBBIGHApeO9julgiM09GToxxEA,1231
|
|
161
161
|
mplang/v2/libs/mpc/psi/cuckoo.py,sha256=GQvLi7BtaPZyk96xwVCwpQPGlcGhOUX6kdsEn8P80l0,7752
|
|
162
162
|
mplang/v2/libs/mpc/psi/okvs.py,sha256=a1Q4ILrsLII9K-BJRSX8iKkpkxJsMxFEj7cTId-XGCE,1576
|
|
163
|
-
mplang/v2/libs/mpc/psi/okvs_gct.py,sha256=
|
|
163
|
+
mplang/v2/libs/mpc/psi/okvs_gct.py,sha256=YcN5ms8StV4ogHR5gK4-SlbIjbcA0QjITagza2EyY70,3375
|
|
164
164
|
mplang/v2/libs/mpc/psi/oprf.py,sha256=YXD-I9P3t1YuqHVxOD9JUpLTZUu-HjgvOaEOQ3hhxMM,13772
|
|
165
|
-
mplang/v2/libs/mpc/psi/rr22.py,sha256=
|
|
165
|
+
mplang/v2/libs/mpc/psi/rr22.py,sha256=fDXigUduBnHfG_8qoL4uS7EHaTmjUjJsZCgnA0-u8cQ,11209
|
|
166
166
|
mplang/v2/libs/mpc/psi/unbalanced.py,sha256=hC84TVsgnlJDg6hpUrx8kbUbmFb27T9wrHG0zv3FXLc,7433
|
|
167
167
|
mplang/v2/libs/mpc/vole/__init__.py,sha256=2dU4X6n73HoK-YCiCl4b36SkLRKR6rofe2xxLxBz6Rc,968
|
|
168
168
|
mplang/v2/libs/mpc/vole/gilboa.py,sha256=apnKOYR4_dJ2wkzGq7PBlwauA-W5i5MPESdetCWTegU,9951
|
|
@@ -173,8 +173,8 @@ mplang/v2/runtime/dialect_state.py,sha256=HxO1i4kSOujS2tQzAF9-WmI3nChSaGgupf2_07
|
|
|
173
173
|
mplang/v2/runtime/interpreter.py,sha256=UzrM5oepka6H0YKRZncNXhsuwKVm4pliG5J92fFRZMI,32300
|
|
174
174
|
mplang/v2/runtime/object_store.py,sha256=yT6jtKG2GUEJVmpq3gnQ8mCMvUFYzgBciC5A-J5KRdk,5998
|
|
175
175
|
mplang/v2/runtime/value.py,sha256=CMOxElJP78v7pjasPhEpbxWbSgB2KsLbpPmzz0mQX0E,4317
|
|
176
|
-
mplang_nightly-0.1.
|
|
177
|
-
mplang_nightly-0.1.
|
|
178
|
-
mplang_nightly-0.1.
|
|
179
|
-
mplang_nightly-0.1.
|
|
180
|
-
mplang_nightly-0.1.
|
|
176
|
+
mplang_nightly-0.1.dev269.dist-info/METADATA,sha256=54UZyxyPNQqxRpAXBoj_1ZQxHjf5a5Gi7sBKBsEYx78,16775
|
|
177
|
+
mplang_nightly-0.1.dev269.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
178
|
+
mplang_nightly-0.1.dev269.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
|
179
|
+
mplang_nightly-0.1.dev269.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
180
|
+
mplang_nightly-0.1.dev269.dist-info/RECORD,,
|
|
File without changes
|
{mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev269.dist-info}/entry_points.txt
RENAMED
|
File without changes
|
{mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev269.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|