mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Oblivious Pseudorandom Function (OPRF).
|
|
16
|
+
|
|
17
|
+
Implements a simplified KKRT-style OPRF built on IKNP OT extension.
|
|
18
|
+
Ref: https://eprint.iacr.org/2016/799.pdf
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from typing import Any, cast
|
|
24
|
+
|
|
25
|
+
import jax.numpy as jnp
|
|
26
|
+
|
|
27
|
+
import mplang.v2.edsl as el
|
|
28
|
+
from mplang.v2.dialects import simp, tensor
|
|
29
|
+
from mplang.v2.libs.mpc.ot import extension as ot_extension
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def eval_oprf(
    receiver_inputs: el.Object,  # (N, 16) bytes
    sender: int,
    receiver: int,
    num_items: int,
) -> tuple[el.Object, el.Object]:
    """Evaluate OPRF on receiver's inputs using a simplified KKRT-style protocol.

    Protocol Overview:
    ──────────────────────────────────────────────────────────────────────────
    The receiver encodes each 16-byte input x_i into a K-bit code c[i]
    (K = 128) by unpacking its bytes to bits. The full (N, K) code matrix is
    the receiver's choice input to IKNP OT extension
    (``ot_extension.iknp_core``), which yields:

        Q: (N, K) bit matrix held by the receiver
        T: (N, K) bit matrix held by the sender
        s: (K,)   random secret bit vector held by the sender

    The matrices are correlated row-wise (⊕ = XOR, ∧ = bitwise AND):

        T[i] = Q[i] ⊕ (c[i] ∧ s)

    Outputs:
        Receiver: PRF(x_i) = H(pack(Q[i]))
        Sender:   PRF(y)   = H(pack(T[i] ⊕ (encode(y) ∧ s)))
                  (computed later by ``sender_eval_prf_batch``)

    Because (a ∧ s) ⊕ (b ∧ s) = (a ⊕ b) ∧ s, the sender's value equals
    H(pack(Q[i] ⊕ ((c[i] ⊕ encode(y)) ∧ s))). When y == x_i this reduces to
    the receiver's output for row i. H is ``ot_extension.vec_hash`` with
    domain separator 0x0CDF.

    Simplifications relative to full KKRT:
      * Items are mapped to rows sequentially (item i uses row i). Full KKRT
        uses cuckoo hashing to map items to rows.
      * encode() is plain bit-unpacking, not a pseudorandom code.
        NOTE(review): with this identity encoding, an input y whose code
        differs from c[i] in a single bit j gives the sender the value
        H(pack(Q[i] ⊕ (e_j ∧ s))). The receiver knows Q[i], so it can guess
        that value by guessing the single bit s[j]. KKRT relies on a code
        with large minimum distance to rule this out. Review before relying
        on this construction for security.

    Args:
        receiver_inputs: (N, 16) uint8 tensor of receiver's inputs.
        sender: Rank of sender party.
        receiver: Rank of receiver party.
        num_items: Number of items N (passed to IKNP and to the hash).

    Returns:
        Tuple of:
          - sender_key: (T, s) tuple on the sender. T is (N, K) and s is (K,).
          - receiver_outputs: hashed OPRF outputs on the receiver, one row per
            input ((N, 32) bytes per the hash's output width).
    """
    K = 128  # Security parameter (OT extension width)

    # ═════════════════════════════════════════════════════════════════════════
    # Step 1: Encode receiver's inputs to K-bit choice codes
    # ═════════════════════════════════════════════════════════════════════════
    def encode_inputs(inputs: el.Object) -> el.Object:
        """Unpack (N, 16) byte inputs into (N, K) 0/1 codes.

        Each 16-byte input becomes 128 bits, which serve as the receiver's
        OT choices.
        """

        def _encode(x: Any) -> Any:
            # (N, 16) uint8 -> (N, 128) bits, MSB first per byte.
            unpacked = jnp.unpackbits(x, axis=1)
            return unpacked[:, :K].astype(jnp.uint8)  # (N, K)

        return cast(el.Object, tensor.run_jax(_encode, inputs))

    choice_codes = simp.pcall_static((receiver,), encode_inputs, receiver_inputs)
    # choice_codes: (N, K) bit matrix on receiver

    # ═════════════════════════════════════════════════════════════════════════
    # Step 2: Run IKNP OT Extension on the full K-bit codes
    # ═════════════════════════════════════════════════════════════════════════
    # Produces T (sender), Q (receiver) and s (sender) with
    # T[i] = Q[i] ⊕ (choice_codes[i] ∧ s).
    t_matrix, q_matrix, s = ot_extension.iknp_core(
        choice_codes, sender, receiver, num_items
    )

    # ═════════════════════════════════════════════════════════════════════════
    # Step 3: Compute receiver's OPRF outputs
    # ═════════════════════════════════════════════════════════════════════════
    def compute_receiver_outputs(q: el.Object) -> el.Object:
        """Pack each Q row (K bits -> 16 bytes) and hash it.

        Args:
            q: (N, K) bit matrix Q from IKNP.

        Returns:
            Hashed OPRF output for each input row.
        """

        def _pack_rows(q_mat: Any) -> Any:
            # (N, K=128) bits -> (N, 16) uint8
            return jnp.packbits(q_mat, axis=1)

        packed_q = cast(el.Object, tensor.run_jax(_pack_rows, q))

        # Hash the raw OT rows so that outputs behave like a random oracle
        # rather than exposing the linear OT correlation directly. The sender
        # side must use the same domain separator (0x0CDF) for outputs to match.
        return ot_extension.vec_hash(packed_q, domain_sep=0x0CDF, num_rows=num_items)

    receiver_outputs = simp.pcall_static(
        (receiver,), compute_receiver_outputs, q_matrix
    )

    # ═════════════════════════════════════════════════════════════════════════
    # Step 4: Package sender's key for later PRF evaluation
    # ═════════════════════════════════════════════════════════════════════════
    # Sender keeps (T, s) to evaluate the PRF on its own items later.
    sender_key = simp.pcall_static((sender,), lambda t, s_: (t, s_), t_matrix, s)

    return sender_key, receiver_outputs
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
# =============================================================================
|
|
182
|
+
# KKRT OPRF Sender Evaluation (Vectorized)
|
|
183
|
+
# =============================================================================
|
|
184
|
+
#
|
|
185
|
+
# KKRT Formula:
|
|
186
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
187
|
+
# For sender with key (T, s) and input y:
|
|
188
|
+
# code_y = encode(y) # K bits
|
|
189
|
+
# output = pack(T[row] XOR (code_y * s))
|
|
190
|
+
#
|
|
191
|
+
# For receiver with Q matrix and input x:
|
|
192
|
+
# code_x = encode(x) # K bits
|
|
193
|
+
# output = pack(Q[row] XOR code_x)
|
|
194
|
+
#
|
|
195
|
+
# When x == y:
|
|
196
|
+
# T[row] XOR (code_x * s) == Q[row] XOR code_x ✅ (due to IKNP correlation)
|
|
197
|
+
# =============================================================================
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def sender_eval_prf_batch(
    sender_key: el.Object,  # Tuple (t_matrix, s) on sender
    sender_items: el.Object,  # (M, 16) bytes - items to evaluate
    sender: int,
    num_items: int,
) -> el.Object:
    """Evaluate the PRF on the sender's side for a batch of items.

    Item i is evaluated against row i of T:

        out[i] = H(pack(T[i] XOR (unpack(items[i]) AND s)))

    Here H is ``ot_extension.vec_hash`` with domain separator 0x0CDF, the same
    hash the receiver applies in ``eval_oprf``.

    If M exceeds the number of rows N in T, rows N..M-1 of the raw (pre-hash)
    output are all-zero bytes, independent of the item.
    NOTE(review): those rows do not follow the PRF formula. Callers should not
    rely on them.

    Args:
        sender_key: The key tuple (t_matrix, s) from eval_oprf.
        sender_items: (M, 16) byte tensor of sender's items.
        sender: Rank of sender party.
        num_items: Number of items M (must be provided). It is forwarded to
            the hash as ``num_rows``.

    Returns:
        (M, 32) byte tensor of PRF outputs on sender.
    """
    width = 128  # OT extension width K

    def _raw_rows(key_pair: Any, items_u8: Any) -> Any:
        t_mat, s_vec = key_pair
        n_items = items_u8.shape[0]
        n_rows = t_mat.shape[0]

        # (M, 16) bytes -> (M, K) bits: the sender-side encoding of each item.
        bits = jnp.unpackbits(items_u8, axis=1)[:, :width]

        # Select s wherever the item's bit is set, else 0 (i.e. bits AND s).
        masked = jnp.where(bits, s_vec, 0).astype(t_mat.dtype)

        # Sequential mapping: only the first min(M, N) items have a T row.
        usable = min(n_items, n_rows)
        combined = jnp.bitwise_xor(t_mat[:usable], masked[:usable])
        out = jnp.packbits(combined, axis=1)  # (usable, 16)

        if n_items > n_rows:
            filler = jnp.zeros((n_items - n_rows, 16), dtype=out.dtype)
            out = jnp.concatenate([out, filler], axis=0)

        return out

    def _on_sender(key: el.Object, items: el.Object) -> el.Object:
        raw = cast(el.Object, tensor.run_jax(_raw_rows, key, items))
        return ot_extension.vec_hash(raw, domain_sep=0x0CDF, num_rows=num_items)

    hashed = simp.pcall_static((sender,), _on_sender, sender_key, sender_items)
    return cast(el.Object, hashed)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def sender_eval_prf(
    sender_key: el.Object,  # Tuple (t_matrix, s) on sender
    candidate: el.Object,  # (16,) bytes to evaluate
    sender: int,
) -> el.Object:
    """Evaluate the PRF on the sender's side for a single candidate.

    Computes

        H(pack(T[0] XOR (unpack(candidate) AND s)))

    using the (T, s) key from ``eval_oprf``. H is ``ot_extension.vec_hash``
    with domain separator 0x0CDF, applied to a one-row batch.

    Note:
        Only row 0 of T is used. The result is therefore comparable with the
        receiver's OPRF output for its first input (row 0) only. Use
        ``sender_eval_prf_batch`` to evaluate item i against row i.

    Args:
        sender_key: The key tuple (t_matrix, s) from eval_oprf.
        candidate: A single 16-byte (uint8) input to evaluate.
        sender: Rank of sender party.

    Returns:
        (32,) byte tensor of PRF output on sender.
    """
    K = 128

    def _eval(key: el.Object, cand: el.Object) -> el.Object:
        def _compute(key_tuple: Any, c: Any) -> Any:
            t_matrix, s = key_tuple

            # Encode candidate to K bits (same encoding as eval_oprf).
            code = jnp.unpackbits(c)[:K]  # (K,)

            # Formula: pack(t_row XOR (code AND s)), using row 0 only.
            t_row = t_matrix[0]  # (K,)
            code_masked = jnp.bitwise_and(code, s)  # (K,)
            xored = jnp.bitwise_xor(t_row, code_masked)  # (K,)

            packed = jnp.packbits(xored)  # (16,)

            # vec_hash operates on row batches, so present a (1, 16) batch.
            return packed.reshape(1, 16)

        raw_out_batch = cast(el.Object, tensor.run_jax(_compute, key, cand))

        hashed_batch = ot_extension.vec_hash(
            raw_out_batch, domain_sep=0x0CDF, num_rows=1
        )

        # Drop the batch axis: (1, 32) -> (32,). The previous
        # slice_tensor(hashed_batch, (0, 0), (32,)) used start/end tuples of
        # different rank, and a slice cannot remove an axis in any case.
        def _flatten(h: Any) -> Any:
            return h.reshape(-1)

        return cast(el.Object, tensor.run_jax(_flatten, hashed_batch))

    return cast(el.Object, simp.pcall_static((sender,), _eval, sender_key, candidate))
|
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Private Set Intersection using VOLE and OKVS (RR22-Style).
|
|
16
|
+
|
|
17
|
+
This module implements a high-performance PSI protocol based on the "Blazing Fast PSI"
|
|
18
|
+
(RR22) paper. The protocol relies on Vector Oblivious Linear Evaluation (VOLE) and
|
|
19
|
+
Oblivious Key-Value Stores (OKVS) to achieve efficient set intersection with linear
|
|
20
|
+
communication and computation complexity (both O(N)).
|
|
21
|
+
|
|
22
|
+
Protocol Overview:
|
|
23
|
+
The core idea is to mask a "Polynomial" (encoded via OKVS) with VOLE-correlated randomness,
|
|
24
|
+
such that the mask can only be removed (and the polynomial verified) if the parties share
|
|
25
|
+
the same element.
|
|
26
|
+
|
|
27
|
+
Phases:
|
|
28
|
+
1. **Correlated Randomness (VOLE)**:
|
|
29
|
+
Sender and Receiver establish a shared correlation:
|
|
30
|
+
W = V + U * Delta    (field addition is XOR, written ^ below)
|
|
31
|
+
- Sender holds U, V.
|
|
32
|
+
- Receiver holds W, Delta.
|
|
33
|
+
- U is random. Delta is a fixed secret scalar (Receiver's key).
|
|
34
|
+
|
|
35
|
+
2. **Encoding (OKVS)**:
|
|
36
|
+
Receiver encodes their input set Y into a structure P using OKVS, such that:
|
|
37
|
+
Decode(P, y) = H(y) for all y in Y.
|
|
38
|
+
Here H(y) is a Random Oracle (implemented via Davies-Meyer/AES).
|
|
39
|
+
|
|
40
|
+
3. **Masking & Exchange**:
|
|
41
|
+
Receiver masks the structure P with their VOLE share W:
|
|
42
|
+
Q = P ^ W
|
|
43
|
+
Receiver sends Q to Sender.
|
|
44
|
+
|
|
45
|
+
4. **Decoding & Verification**:
|
|
46
|
+
Sender attempts to decode Q for each of their items x in X.
|
|
47
|
+
Since OKVS is linear:
|
|
48
|
+
Decode(Q, x) = Decode(P, x) ^ Decode(W, x)
|
|
49
|
+
|
|
50
|
+
Sender reconstructs the potential "Target" value T:
|
|
51
|
+
T = Decode(Q, x) ^ Decode(V, x) ^ H(x)
|
|
52
|
+
|
|
53
|
+
If x in Y (Intersection):
|
|
54
|
+
Decode(P, x) = H(x)
|
|
55
|
+
Decode(W, x) = Decode(V, x) ^ Decode(U, x) * Delta
|
|
56
|
+
Substitute into T:
|
|
57
|
+
T = H(x) ^ (Decode(V, x) ^ Decode(U, x) * Delta) ^ Decode(V, x) ^ H(x)
|
|
58
|
+
T = Decode(U, x) * Delta
|
|
59
|
+
|
|
60
|
+
Thus, verification becomes checking if T == U* * Delta, where U* = Decode(U, x).
|
|
61
|
+
This check is performed securely using hashes to prevent leakage.
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
from typing import Any, cast
|
|
65
|
+
|
|
66
|
+
import jax.numpy as jnp
|
|
67
|
+
|
|
68
|
+
import mplang.v2.edsl as el
|
|
69
|
+
import mplang.v2.libs.mpc.ot.silent as silent_ot
|
|
70
|
+
from mplang.v2.dialects import field, simp, tensor
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def psi_intersect(
|
|
74
|
+
sender: int,
|
|
75
|
+
receiver: int,
|
|
76
|
+
n: int,
|
|
77
|
+
sender_items: el.Object,
|
|
78
|
+
receiver_items: el.Object,
|
|
79
|
+
) -> el.Object:
|
|
80
|
+
"""Execute OKVS-based PSI Protocol.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
sender: Rank of Sender.
|
|
84
|
+
receiver: Rank of Receiver.
|
|
85
|
+
n: Number of items (must be same for now).
|
|
86
|
+
sender_items: Object located at Sender containing (N,) u64 items.
|
|
87
|
+
receiver_items: Object located at Receiver containing (N,) u64 items.
|
|
88
|
+
|
|
89
|
+
Returns:
|
|
90
|
+
Intersection verification tuple (T, U*, Delta).
|
|
91
|
+
"""
|
|
92
|
+
|
|
93
|
+
# Validation
|
|
94
|
+
if sender == receiver:
|
|
95
|
+
raise ValueError(
|
|
96
|
+
f"Sender ({sender}) and Receiver ({receiver}) must be different."
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
if n <= 0:
|
|
100
|
+
raise ValueError(f"Input size n must be positive, got {n}.")
|
|
101
|
+
|
|
102
|
+
# =========================================================================
|
|
103
|
+
# Phase 1. Parameter Setup & Topology
|
|
104
|
+
# =========================================================================
|
|
105
|
+
# OKVS Size M = expansion * N.
|
|
106
|
+
# The expansion factor is critical for the success probability of the "Peeling"
|
|
107
|
+
# algorithm used in OKVS encoding (Garbled Cuckoo Table).
|
|
108
|
+
# Larger N allows smaller expansion (closer to theoretical 1.23) while maintaining safety.
|
|
109
|
+
import mplang.v2.libs.mpc.psi.okvs_gct as okvs_gct
|
|
110
|
+
|
|
111
|
+
expansion = okvs_gct.get_okvs_expansion(n)
|
|
112
|
+
M = int(n * expansion)
|
|
113
|
+
|
|
114
|
+
# Align M to 128 boundary for efficient batch processing in Silent VOLE (LPN)
|
|
115
|
+
if M % 128 != 0:
|
|
116
|
+
M = ((M // 128) + 1) * 128
|
|
117
|
+
|
|
118
|
+
# =========================================================================
|
|
119
|
+
# Phase 2. Correlated Randomness Generation (VOLE)
|
|
120
|
+
# =========================================================================
|
|
121
|
+
# Parties run Silent VOLE (based on LPN assumption) to generate:
|
|
122
|
+
# Sender: U, V (Vectors of size M)
|
|
123
|
+
# Receiver: W, Delta
|
|
124
|
+
# Correlation: W = V + U * Delta
|
|
125
|
+
#
|
|
126
|
+
# Note: U is uniformly random. It acts as a "One-Time Pad" key for the protocol.
|
|
127
|
+
|
|
128
|
+
# silent_vole_random_u returns (v, w, u, delta)
|
|
129
|
+
res_tuple = silent_ot.silent_vole_random_u(sender, receiver, M, base_k=1024)
|
|
130
|
+
v_sender, w_receiver, u_sender, delta_receiver = res_tuple[:4]
|
|
131
|
+
|
|
132
|
+
# =========================================================================
|
|
133
|
+
# Phase 3. Receiver Encoding & Masking (OKVS)
|
|
134
|
+
# =========================================================================
|
|
135
|
+
# The Receiver encodes their input set Y into the OKVS structure P.
|
|
136
|
+
# Goal: Decode(P, y) = H(y) forall y in Y.
|
|
137
|
+
#
|
|
138
|
+
# Then, Receiver masks P with the VOLE output W to get Q:
|
|
139
|
+
# Q = P ^ W
|
|
140
|
+
# This Q is sent to the Sender.
|
|
141
|
+
|
|
142
|
+
# 3.1 Generate OKVS Seed (Public/Session Randomness)
|
|
143
|
+
# Used for OKVS hashing distribution. Can be public, but generated at runtime for safety.
|
|
144
|
+
from mplang.v2.dialects import crypto
|
|
145
|
+
from mplang.v2.edsl import typing as elt
|
|
146
|
+
|
|
147
|
+
def _gen_seed() -> Any:
|
|
148
|
+
return crypto.random_tensor((2,), elt.u64)
|
|
149
|
+
|
|
150
|
+
okvs_seed = simp.pcall_static((receiver,), _gen_seed)
|
|
151
|
+
okvs_seed_sender = simp.shuffle_static(okvs_seed, {sender: receiver})
|
|
152
|
+
|
|
153
|
+
# Instantiate OKVS Data Structure
|
|
154
|
+
okvs = okvs_gct.SparseOKVS(M)
|
|
155
|
+
|
|
156
|
+
def _recv_ops(y: Any, w: Any, delta: Any, seed: Any) -> Any:
|
|
157
|
+
# y: (N,) Inputs
|
|
158
|
+
# w: (M, 2) VOLE share
|
|
159
|
+
|
|
160
|
+
# 3.2 Compute H(y) - The Random Oracle Target
|
|
161
|
+
# We use Davies-Meyer construction: H(x) = E_x(0) ^ x
|
|
162
|
+
# This is a standard, efficient, and robust way to instantiate a RO from AES.
|
|
163
|
+
|
|
164
|
+
def _reshape_seeds(items: Any) -> Any:
|
|
165
|
+
# Prepare items as AES keys (128-bit)
|
|
166
|
+
lo = items
|
|
167
|
+
hi = jnp.zeros_like(items)
|
|
168
|
+
return jnp.stack([lo, hi], axis=1) # (N, 2)
|
|
169
|
+
|
|
170
|
+
seeds = tensor.run_jax(_reshape_seeds, y)
|
|
171
|
+
res_exp = field.aes_expand(seeds, 1) # (N, 1, 2)
|
|
172
|
+
|
|
173
|
+
def _davies_meyer(enc: Any, s: Any) -> Any:
|
|
174
|
+
enc_flat = enc.reshape(enc.shape[0], 2)
|
|
175
|
+
return jnp.bitwise_xor(enc_flat, s)
|
|
176
|
+
|
|
177
|
+
h_y = tensor.run_jax(_davies_meyer, res_exp, seeds)
|
|
178
|
+
|
|
179
|
+
# 3.3 Solve System of Linear Equations (OKVS Encode)
|
|
180
|
+
# We find P such that: P * M_okvs(y) = h_y
|
|
181
|
+
p_storage = okvs.encode(y, h_y, seed)
|
|
182
|
+
|
|
183
|
+
# 3.4 Mask with Vole Share
|
|
184
|
+
# Q = P ^ W
|
|
185
|
+
q_storage = field.add(p_storage, w)
|
|
186
|
+
|
|
187
|
+
return q_storage
|
|
188
|
+
|
|
189
|
+
# Execute on Receiver
|
|
190
|
+
q_shared = simp.pcall_static(
|
|
191
|
+
(receiver,), _recv_ops, receiver_items, w_receiver, delta_receiver, okvs_seed
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
# 3.5 Send Q to Sender
|
|
195
|
+
q_sender_view = simp.shuffle_static(q_shared, {sender: receiver})
|
|
196
|
+
|
|
197
|
+
# =========================================================================
|
|
198
|
+
# Phase 4. Sender Decoding & Reconstruction
|
|
199
|
+
# =========================================================================
|
|
200
|
+
# Sender uses Q and their local shares (U, V) to reconstruct T.
|
|
201
|
+
#
|
|
202
|
+
# Derivation:
|
|
203
|
+
# 1. S_decoded = Decode(Q, x) = Decode(P ^ W, x) = P(x) ^ W(x)
|
|
204
|
+
# 2. Recall W(x) = V(x) ^ U(x)*Delta (VOLE property)
|
|
205
|
+
# 3. So S_decoded = P(x) ^ V(x) ^ U(x)*Delta
|
|
206
|
+
#
|
|
207
|
+
# 4. Sender computes T = S_decoded ^ V(x) ^ H(x)
|
|
208
|
+
# T = P(x) ^ V(x) ^ U(x)*Delta ^ V(x) ^ H(x)
|
|
209
|
+
# T = P(x) ^ H(x) ^ U(x)*Delta
|
|
210
|
+
#
|
|
211
|
+
# 5. If x is in Intersection (Meanings x == y for some y):
|
|
212
|
+
# Then P(x) == H(x) (by OKVS property)
|
|
213
|
+
# So T = H(x) ^ H(x) ^ U(x)*Delta
|
|
214
|
+
# T = U(x)*Delta
|
|
215
|
+
#
|
|
216
|
+
# This relation T == U* * Delta is what we verify in Phase 5.
|
|
217
|
+
|
|
218
|
+
def _sender_ops(x: Any, q: Any, u: Any, v: Any, seed: Any) -> tuple[Any, Any]:
    """Sender-side decoding: compute the T candidate and U* for each sender item.

    Args:
        x: Sender items (annotated as ``(N,)``).
        q: OKVS storage ``Q = P ^ W`` received from the receiver (annotated as ``(M, 2)``).
        u: Sender's VOLE share U, used as OKVS storage.
        v: Sender's VOLE share V, used as OKVS storage.
        seed: OKVS seed shared with the receiver.

    Returns:
        ``(t_val, s_u)`` where ``t_val = Decode(Q, x) ^ Decode(V, x) ^ H(x)``
        and ``s_u = Decode(U, x)`` (the U* value sent to the receiver).
    """
    # 4.1 Decode Q and V at x.
    # OKVS decode is a linear combination of storage positions. Q, V and U must
    # all be decoded with the same map (same function, same seed) so that the
    # VOLE relation W(x) = V(x) ^ U(x)*Delta holds for the decoded values.
    s_decoded = okvs.decode(x, q, seed)
    v_decoded = okvs.decode(x, v, seed)

    # 4.2 Compute H(x): seed = (x, 0), AES-expand it, then apply the feed-forward XOR.
    def _reshape_seeds(items: Any) -> Any:
        lo = items
        hi = jnp.zeros_like(items)
        return jnp.stack([lo, hi], axis=1)

    seeds_x = tensor.run_jax(_reshape_seeds, x)
    res_exp_x = field.aes_expand(seeds_x, 1)

    def _davies_meyer(enc: Any, s: Any) -> Any:
        enc_flat = enc.reshape(enc.shape[0], 2)
        return jnp.bitwise_xor(enc_flat, s)

    h_x = tensor.run_jax(_davies_meyer, res_exp_x, seeds_x)

    # 4.3 Compute the T candidate: T = S ^ V ^ H(x).
    # s_decoded = P(x) ^ W(x) = P(x) ^ V(x) ^ U(x)*Delta, so XOR-ing in V(x)
    # and H(x) leaves T = P(x) ^ H(x) ^ U(x)*Delta.
    t_val = field.add(s_decoded, v_decoded)
    t_val = field.add(t_val, h_x)

    # 4.4 Compute U* = Decode(U, x), the sender's share of the randomness for x.
    # The same decode as 4.1 is used; a different decode map would break the
    # T == U* * Delta check for intersection items.
    s_u = okvs.decode(x, u, seed)

    return t_val, s_u
|
|
253
|
+
|
|
254
|
+
t_val_sender, u_star_sender = simp.pcall_static(
|
|
255
|
+
(sender,),
|
|
256
|
+
_sender_ops,
|
|
257
|
+
sender_items,
|
|
258
|
+
q_sender_view,
|
|
259
|
+
u_sender,
|
|
260
|
+
v_sender,
|
|
261
|
+
okvs_seed_sender,
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
# =========================================================================
|
|
265
|
+
# Phase 5. Secure Verification
|
|
266
|
+
# =========================================================================
|
|
267
|
+
# The Protocol invariant is T == U* * Delta for intersection items.
|
|
268
|
+
#
|
|
269
|
+
# Security Risk:
|
|
270
|
+
# We must NOT reveal T or Delta to the other party.
|
|
271
|
+
# - If Receiver learns T, they can compute Diff = T - U*Delta = H(x) + ... and attack x.
|
|
272
|
+
# - If Sender learns Delta, VOLE security collapses.
|
|
273
|
+
#
|
|
274
|
+
# Secure Verification Method:
|
|
275
|
+
# 1. Sender sends U* (Random Mask share) to Receiver.
|
|
276
|
+
# - U* is derived from U (random VOLE inputs) so it reveals nothing about X.
|
|
277
|
+
#
|
|
278
|
+
# 2. Receiver computes Target = U* * Delta.
|
|
279
|
+
# - This allows Receiver to construct the expected value of T without knowing T's components.
|
|
280
|
+
#
|
|
281
|
+
# 3. Receiver Hashes the Target and sends H(Target) to Sender.
|
|
282
|
+
# - Hashing prevents Sender from learning Delta algebraically.
|
|
283
|
+
# - Hash function acts as a commitment.
|
|
284
|
+
#
|
|
285
|
+
# 4. Sender compares H(T) =? H(Target).
|
|
286
|
+
# - Equality implies x is in Intersection.
|
|
287
|
+
|
|
288
|
+
# 5.1 Sender -> Receiver: U*
|
|
289
|
+
u_star_recv = simp.shuffle_static(u_star_sender, {receiver: sender})
|
|
290
|
+
|
|
291
|
+
# 5.2 Receiver: Compute Expected Target (U* * Delta)
|
|
292
|
+
def _recv_verify_ops(u_s: Any, delta: Any) -> Any:
    """Receiver side of verification: compute Target = U* * Delta in GF(2^128).

    Args:
        u_s: U* received from the sender (annotated as ``(N, 2)``).
        delta: The receiver's VOLE Delta (annotated as ``(2,)``).

    Returns:
        ``field.mul`` of ``u_s`` with ``delta`` tiled to ``n`` rows.
    """

    # Tiling is a plain JAX op, not an EDSL primitive, so it goes through tensor.run_jax.
    def _tile_delta(d: Any) -> Any:
        return jnp.tile(d, (n, 1))

    return field.mul(u_s, tensor.run_jax(_tile_delta, delta))
|
|
304
|
+
|
|
305
|
+
target_val = simp.pcall_static(
|
|
306
|
+
(receiver,), _recv_verify_ops, u_star_recv, delta_receiver
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
# 5.3 Hash Exchange
|
|
310
|
+
# Use robust hashing to prevent algebraic attacks or leakage
|
|
311
|
+
from mplang.v2.libs.mpc.ot import extension as ot_extension
|
|
312
|
+
|
|
313
|
+
def _hash_shares(share: el.Object, party: int) -> el.Object:
    """Hash ``n`` rows of ``share`` with ``ot_extension.vec_hash`` (domain separator 0xFEED).

    ``party`` is accepted but not used: both parties hash with the same
    domain separator, so the sender can compare its digests row by row.
    """
    hash_options = {"domain_sep": 0xFEED, "num_rows": n}
    return ot_extension.vec_hash(share, **hash_options)
|
|
316
|
+
|
|
317
|
+
# Hash(Target) on Receiver
|
|
318
|
+
h_target_recv = simp.pcall_static(
|
|
319
|
+
(receiver,), lambda x: _hash_shares(x, receiver), target_val
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
# Hash(T) on Sender
|
|
323
|
+
h_t_sender = simp.pcall_static(
|
|
324
|
+
(sender,), lambda x: _hash_shares(x, sender), t_val_sender
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
# Send Hash to Sender for comparison
|
|
328
|
+
h_target_at_sender = simp.shuffle_static(h_target_recv, {sender: receiver})
|
|
329
|
+
|
|
330
|
+
# 5.4 Final Comparison on Sender
|
|
331
|
+
def _compare(h_t: Any, h_target: Any) -> Any:
    """Compare the sender's H(T) with the receiver's H(Target) row by row.

    The hashes are described as 32-byte rows, ``(N, 32)``. The result is a
    uint8 vector holding 1 where every column of a row matches and 0 otherwise.
    """

    def _rows_equal(lhs: Any, rhs: Any) -> Any:
        matches = lhs == rhs
        row_match = jnp.all(matches, axis=1)
        return row_match.astype(jnp.uint8)

    return tensor.run_jax(_rows_equal, h_t, h_target)
|
|
339
|
+
|
|
340
|
+
intersection_mask = simp.pcall_static(
|
|
341
|
+
(sender,), _compare, h_t_sender, h_target_at_sender
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
return cast(el.Object, intersection_mask)
|