mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Unbalanced PSI Protocol.
|
|
16
|
+
|
|
17
|
+
This module implements unbalanced PSI for scenarios where client set size n << server set size N.
|
|
18
|
+
Uses Seeded OKVS (via derived keys) to prevent pre-computation attacks.
|
|
19
|
+
|
|
20
|
+
Security Model:
|
|
21
|
+
- Session-specific random seed generated at RUNTIME on the Server.
|
|
22
|
+
- Both Key and Value derivations use the seed for consistent security.
|
|
23
|
+
- WARNING: Online dictionary attacks by active clients remain possible without OPRF.
|
|
24
|
+
|
|
25
|
+
Protocol:
|
|
26
|
+
1. Server generates random Seed at runtime.
|
|
27
|
+
2. Server computes K' = H_key(ServerItems, Seed) and V = H_val(ServerItems, Seed), using domain-separated hashes.
|
|
28
|
+
3. Server solves OKVS: Table = Solve(K', V, Seed).
|
|
29
|
+
4. Server sends Seed + Table to Client.
|
|
30
|
+
5. Client computes k' = H_key(ClientItems, Seed) and v = H_val(ClientItems, Seed).
|
|
31
|
+
6. Client decodes V' = Decode(k', Table, Seed).
|
|
32
|
+
7. Client checks V' == v.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
from typing import Any, cast
|
|
36
|
+
|
|
37
|
+
import jax.numpy as jnp
|
|
38
|
+
|
|
39
|
+
import mplang.v2.edsl as el
|
|
40
|
+
import mplang.v2.edsl.typing as elt
|
|
41
|
+
from mplang.v2.dialects import crypto, field, simp, tensor
|
|
42
|
+
from mplang.v2.libs.mpc.psi.okvs_gct import get_okvs_expansion
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def psi_unbalanced(
    server: int,
    client: int,
    server_n: int,
    client_n: int,
    server_items: el.Object,
    client_items: el.Object,
) -> el.Object:
    """Unbalanced PSI: the server publishes a seeded OKVS, the client probes it.

    Intended for scenarios where client_n << server_n.

    Communication is a single server-to-client message consisting of the
    session seed plus an OKVS table sized
    ``int(server_n * get_okvs_expansion(server_n))``, i.e. O(server_n). The
    client sends nothing back; its work is hashing and decoding its own
    ``client_n`` items.

    Security:
        - Uses a cryptographically random Session Seed (128-bit) generated at
          RUNTIME on the server (not during tracing).
        - Both Key and Value derivations include the Seed, with different
          domain separators.
        - Prevents offline pre-computation (Rainbow Table) attacks.
        - WARNING: Online dictionary attacks by active clients remain possible.

    > [!WARNING]
    > **Security Notice**: This protocol sends the Session Seed to the Client to allow
    > them to compute the OKVS lookups. A malicious Client can perform an online
    > dictionary attack (brute-force hashing) to enumerate Server items.
    > For strict set privacy against malicious clients, use OPRF-PSI (`oprf.py` based)
    > instead of this unbalanced protocol.

    Args:
        server: Rank of server (holds the large set).
        client: Rank of client (holds the small set).
        server_n: Size of server's set; determines the OKVS table size.
        client_n: Size of client's set. Only validated here; the client-side
            shapes come from ``client_items`` itself.
        server_items: (server_n,) uint64 on server.
        client_items: (client_n,) uint64 on client.

    Returns:
        Intersection indicators on client: (client_n,) uint8, 1 where the
        client item was found in the server set.

    Raises:
        ValueError: If ``server == client`` or either set size is not positive.
    """
    if server == client:
        raise ValueError("Server and Client must be different parties.")

    if client_n <= 0 or server_n <= 0:
        raise ValueError("Set sizes must be positive.")

    # =========================================================================
    # 1. Server Setup: Generate Runtime Random Seed
    # =========================================================================

    # 128 bits of cryptographically secure randomness, drawn AT RUNTIME on the
    # server party inside pcall (not at trace time, which would bake it in).
    def _gen_runtime_seed() -> Any:
        return crypto.random_tensor((2,), elt.u64)

    server_seed = simp.pcall_static((server,), _gen_runtime_seed)

    # =========================================================================
    # Hashing Helpers (both Key and Value use the Seed)
    # =========================================================================

    def _compute_hashes(items: Any, seed: Any) -> tuple[Any, Any]:
        """Derive OKVS keys K' and validation values V for ``items``.

        With s = seed (2,) u64 and a per-purpose domain constant D:
            aes_seed = ((x + s[0]) ^ D, (x ^ s[1]) + D)       # (N, 2) u64
            K' = AES_Expand(aes_seed_key)[:, 0, 0]            # low 64 bits
            V  = AES_Expand(aes_seed_val) reshaped to (N, 2)  # full 128 bits
        Mixing in the seed prevents pre-computation attacks; distinct D values
        keep the key and value derivations separate.
        """
        # Domain separator for Key derivation
        KEY_DOMAIN = jnp.uint64(0xA5A5A5A5A5A5A5A5)
        # Domain separator for Value derivation
        VAL_DOMAIN = jnp.uint64(0x5A5A5A5A5A5A5A5A)

        def _prepare_key_seed(x: Any, s: Any) -> Any:
            # x: (N,) u64, s: (2,) u64 -> (N, 2) u64
            k_lo = (x + s[0]) ^ KEY_DOMAIN
            k_hi = (x ^ s[1]) + KEY_DOMAIN
            return jnp.stack([k_lo, k_hi], axis=1)

        def _prepare_val_seed(x: Any, s: Any) -> Any:
            # Same mixing as the key seed, but with the VAL domain separator.
            v_lo = (x + s[0]) ^ VAL_DOMAIN
            v_hi = (x ^ s[1]) + VAL_DOMAIN
            return jnp.stack([v_lo, v_hi], axis=1)

        # Derive Keys
        key_seeds = tensor.run_jax(_prepare_key_seed, items, seed)
        h_keys_raw = field.aes_expand(key_seeds, 1)  # (N, 1, 2)

        def _extract_key(h: Any) -> Any:
            # Keep only the low 64-bit word as the OKVS key.
            return h[:, 0, 0]

        keys = tensor.run_jax(_extract_key, h_keys_raw)

        # Derive Values (also seeded, so values cannot be precomputed either)
        val_seeds = tensor.run_jax(_prepare_val_seed, items, seed)
        h_vals_raw = field.aes_expand(val_seeds, 1)  # (N, 1, 2)

        def _flatten(h: Any) -> Any:
            return h.reshape(h.shape[0], 2)

        vals = tensor.run_jax(_flatten, h_vals_raw)

        return keys, vals

    # Server computes K' and V
    server_derived_keys, server_values = simp.pcall_static(
        (server,), _compute_hashes, server_items, server_seed
    )

    # Server solves the OKVS; the table size scales with the SERVER set.
    expansion = get_okvs_expansion(server_n)
    M = int(server_n * expansion)

    def _solve(k: Any, v: Any, s: Any) -> Any:
        return field.solve_okvs(k, v, M, s)

    okvs_table = simp.pcall_static(
        (server,), _solve, server_derived_keys, server_values, server_seed
    )

    # Send table (O(server_n)) and seed to the client.
    okvs_table_client = simp.shuffle_static(okvs_table, {client: server})
    client_seed = simp.shuffle_static(server_seed, {client: server})

    # =========================================================================
    # 2. Client Operations
    # =========================================================================

    # Client computes k' and expected V using the SAME hash functions.
    client_derived_keys, client_expected_values = simp.pcall_static(
        (client,), _compute_hashes, client_items, client_seed
    )

    # Client decodes the OKVS and compares against its expected values.
    def _decode_and_compare(keys: Any, table: Any, expected: Any, s: Any) -> Any:
        decoded = field.decode_okvs(keys, table, s)

        def _compare_jax(dec: Any, exp: Any) -> Any:
            # A match requires both 64-bit words of the 128-bit value to agree.
            match = (dec[:, 0] == exp[:, 0]) & (dec[:, 1] == exp[:, 1])
            return match.astype(jnp.uint8)

        return tensor.run_jax(_compare_jax, decoded, expected)

    intersection_mask = simp.pcall_static(
        (client,),
        _decode_and_compare,
        client_derived_keys,
        okvs_table_client,
        client_expected_values,
        client_seed,
    )

    return cast(el.Object, intersection_mask)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Vector Oblivious Linear Evaluation (VOLE) protocols.
|
|
16
|
+
|
|
17
|
+
Submodules:
|
|
18
|
+
- gilboa: Gilboa VOLE protocol
|
|
19
|
+
- silver: Silver VOLE (LDPC-based)
|
|
20
|
+
- ldpc: LDPC matrix operations
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from .gilboa import vole
|
|
24
|
+
from .silver import estimate_silver_communication, silver_vole, silver_vole_ldpc
|
|
25
|
+
|
|
26
|
+
# Public names re-exported by the VOLE package (alphabetical).
__all__ = ["estimate_silver_communication", "silver_vole", "silver_vole_ldpc", "vole"]
|
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Vector Oblivious Linear Evaluation (VOLE) Protocol.
|
|
16
|
+
|
|
17
|
+
Implements the Gilboa protocol for VOLE over GF(2^128), built on IKNP OT extension.
|
|
18
|
+
Written as a global SIMP program (``simp.pcall_static`` / ``simp.shuffle_static``).
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from collections.abc import Callable
|
|
22
|
+
from typing import Any, cast
|
|
23
|
+
|
|
24
|
+
import jax.numpy as jnp
|
|
25
|
+
import numpy as np
|
|
26
|
+
|
|
27
|
+
import mplang.v2.edsl as el
|
|
28
|
+
import mplang.v2.libs.mpc.ot.extension as ot
|
|
29
|
+
from mplang.v2.dialects import field, simp, tensor
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def vole(
    sender: int,
    receiver: int,
    n: int,
    u_provider: Callable[[], el.Object],
    delta_provider: Callable[[], el.Object],
    return_secrets: bool = False,
) -> tuple[el.Object, el.Object] | tuple[el.Object, el.Object, el.Object, el.Object]:
    """Execute the VOLE protocol (Gilboa) over GF(2^128).

    After the protocol, ``w = v + u * delta`` (field addition is XOR), where
    the Sender holds ``u`` and ``v`` and the Receiver holds ``delta`` and ``w``.

    Args:
        sender: Rank of Sender (provides ``u``).
        receiver: Rank of Receiver (provides ``delta``).
        n: Vector length N.
        u_provider: Callable running on Sender returning u (N, 2).
        delta_provider: Callable running on Receiver returning delta (2,) u64.
        return_secrets: If True, also return ``u`` and ``delta``.

    Returns:
        If return_secrets=False: ``(v, w)`` with v on Sender (N, 2) and
        w on Receiver (N, 2).
        If return_secrets=True: ``(v, w, u, delta)``.

    Raises:
        ValueError: If ``sender == receiver`` or ``n`` is not positive.
    """
    from jax.tree_util import tree_map

    if sender == receiver:
        raise ValueError("Sender and Receiver must be different parties.")
    if n <= 0:
        raise ValueError("Vector length n must be positive.")

    K = 128

    # 1. Receiver: obtain delta and decompose it into 128 bits.
    def _recv_prep() -> tuple[el.Object, el.Object]:
        delta = delta_provider()

        def _unpack(d: Any) -> Any:
            # (2,) u64 -> (16,) u8 -> (128,) bits, least-significant bit first.
            # NOTE(review): bit i == bit i of delta assumes a little-endian host
            # byte order for the u64 -> u8 view.
            return jnp.unpackbits(d.view(jnp.uint8), bitorder="little")

        def _to_column(x: Any) -> Any:
            return x.reshape(128, 1)

        bits_u8 = tensor.run_jax(_unpack, delta)  # (128,) u8
        return delta, tensor.run_jax(_to_column, bits_u8)  # (128, 1)

    # Unpack the tuple result directly, as done for the other pcalls below,
    # instead of issuing extra pcalls just to index into it.
    delta_recv, delta_bits = simp.pcall_static((receiver,), _recv_prep)

    # 2. IKNP OT core, with the VOLE Receiver as OT receiver choosing by the
    #    delta bits. As used below: Sender gets T (128 x 128 bits) and the base
    #    choice vector s; Receiver gets Q with Q[i] = T[i] ^ (delta_i * s).
    t_matrix_128, q_matrix_128, s_choices = ot.iknp_core(
        delta_bits, sender, receiver, K
    )

    # 3. Gilboa correction, treating each IKNP output as a random OT:
    #    Sender seeds: T[i] (for delta_i = 0) and T[i] ^ s (for delta_i = 1).
    #    Receiver seed: Q[i], i.e. the seed matching its own bit delta_i.
    #    PRG-expanding to N elements gives V0_i, V1_i on the Sender and
    #    W_i = V_{delta_i} on the Receiver. The Sender sends
    #        M_i = V0_i ^ V1_i ^ (u * x^i)
    #    and the Receiver computes res_i = W_i ^ (delta_i * M_i), which equals
    #    V0_i if delta_i = 0 and V0_i ^ (u * x^i) if delta_i = 1. Hence
    #        v = XOR_i V0_i
    #        w = XOR_i res_i = v ^ u * (sum_i delta_i x^i) = v + u * delta.

    def _sender_wrapper() -> el.Object:
        return u_provider()

    u_loc_captured = simp.pcall_static((sender,), _sender_wrapper)

    m_corrections, v_sender = simp.pcall_static(
        (sender,), _sender_round, t_matrix_128, s_choices, u_loc_captured, n
    )

    # Shuffle M to Receiver.
    m_recv = tree_map(
        lambda x: simp.shuffle_static(x, {receiver: sender}), m_corrections
    )

    w_receiver = simp.pcall_static(
        (receiver,), _recv_round, q_matrix_128, m_recv, delta_bits, n
    )

    if return_secrets:
        return v_sender, w_receiver, u_loc_captured, delta_recv
    return v_sender, w_receiver
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# A. Expand (Sender)
|
|
195
|
+
def _sender_round(
    t_loc: el.Object, s_loc: el.Object, u_loc: el.Object, n: int
) -> tuple[el.Object, el.Object]:
    """Sender side of Gilboa VOLE: expand both seed sets and build corrections.

    Args:
        t_loc: IKNP matrix T held by the sender, (128, 128) bits.
        s_loc: IKNP base choice vector s, (128,) bits.
        u_loc: Sender input u, (N, 2).
        n: Vector length N passed to ``field.aes_expand``.

    Returns:
        ``(m_corr, v)``: the corrections M_i = V0_i ^ V1_i ^ (u * x^i) with shape
        (128, N, 2), and the sender's output v = XOR_i V0_i with shape (N, 2).
    """

    # Seeds for bit value 0 are the packed rows of T; for bit value 1 they are
    # those rows XORed with the packed s.
    def _seed_pair(t: Any, s: Any) -> tuple[Any, Any]:
        base = jnp.packbits(t, axis=-1)  # (128, 16) uint8
        shift = jnp.packbits(s, axis=-1)[None, :]  # (1, 16) uint8
        return base, jnp.bitwise_xor(base, shift)

    zero_seeds, one_seeds = tensor.run_jax(_seed_pair, t_loc, s_loc)
    v0_all = field.aes_expand(cast(el.Object, zero_seeds), n)
    v1_all = field.aes_expand(cast(el.Object, one_seeds), n)

    # Constant basis x^0 .. x^127 as (lo, hi) u64 words: row i sets bit
    # (i mod 64) of word (i // 64).
    basis = np.zeros((128, 2), dtype=np.uint64)
    for bit in range(128):
        basis[bit, bit // 64] = 1 << (bit % 64)
    basis_const = tensor.constant(basis)

    # Lay u and the basis out on a common (128, N, 2) grid for one batched mul.
    def _align_operands(u_val: Any, p_val: Any) -> tuple[Any, Any]:
        rows = p_val.shape[0]
        count = u_val.shape[0]
        u_grid = jnp.broadcast_to(u_val[None, :, :], (rows,) + u_val.shape)
        p_grid = jnp.broadcast_to(p_val[:, None, :], (rows, count, p_val.shape[-1]))
        return u_grid, p_grid

    u_grid, p_grid = tensor.run_jax(_align_operands, u_loc, basis_const)
    scaled = field.mul(u_grid, p_grid)  # (128, N, 2): u * x^i per row i

    def _corrections_and_v(v0: Any, v1: Any, term: Any) -> tuple[Any, Any]:
        corrections = v0 ^ v1 ^ term
        acc = v0[0]
        for row in range(1, 128):
            acc = acc ^ v0[row]
        return corrections, acc

    corrections, v_out = tensor.run_jax(_corrections_and_v, v0_all, v1_all, scaled)
    return cast(el.Object, corrections), cast(el.Object, v_out)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
# B. Expand & Reconstruct (Receiver)
|
|
281
|
+
def _recv_round(
    q_loc: el.Object, m_loc: el.Object, d_bits: el.Object, n: int
) -> el.Object:
    """Receiver side of Gilboa VOLE: expand Q and fold in the selected corrections.

    Args:
        q_loc: IKNP matrix Q held by the receiver, (128, 128) bits.
        m_loc: Corrections M received from the sender, (128, N, 2).
        d_bits: Receiver's delta bits, (128, 1).
        n: Vector length N passed to ``field.aes_expand``.

    Returns:
        The receiver's output w, (N, 2): the XOR over i of
        ``W_i ^ (M_i if delta_i else 0)``.
    """

    def _pack_q_rows(q: Any) -> Any:
        return jnp.packbits(q, axis=-1)

    packed_q = tensor.run_jax(_pack_q_rows, q_loc)
    w_all = field.aes_expand(packed_q, n)  # (128, N, 2)

    def _fold_corrections(w_exp: Any, m_val: Any, d_b: Any) -> Any:
        # Keep M_i only where delta_i is set; zero it elsewhere.
        selected = d_b.reshape(128, 1, 1).astype(bool)
        kept = jnp.where(selected, m_val, jnp.zeros_like(m_val))
        terms = w_exp ^ kept
        acc = terms[0]
        for row in range(1, 128):
            acc = acc ^ terms[row]
        return acc

    w_out = tensor.run_jax(_fold_corrections, w_all, m_loc, d_bits)
    return cast(el.Object, w_out)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def _gen_powers_of_x_jax(dummy: Any, k: int = 128) -> Any:
|
|
317
|
+
# JAX version for use inside run_jax (returns jnp.array)
|
|
318
|
+
# dummy is required for run_jax tracing anchor
|
|
319
|
+
rows = []
|
|
320
|
+
for i in range(k):
|
|
321
|
+
lo, hi = 0, 0
|
|
322
|
+
if i < 64:
|
|
323
|
+
lo = 1 << i
|
|
324
|
+
else:
|
|
325
|
+
hi = 1 << (i - 64)
|
|
326
|
+
rows.append([lo, hi])
|
|
327
|
+
return jnp.array(rows, dtype=jnp.uint64)
|