mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""OT Extension (IKNP).
|
|
16
|
+
|
|
17
|
+
Implements IKNP OT extension protocol to perform N OTs using k Base OTs.
|
|
18
|
+
Ref: https://crypto.stanford.edu/~valeria/research/2003/IKNP03.pdf
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from typing import Any, cast
|
|
24
|
+
|
|
25
|
+
import jax.numpy as jnp
|
|
26
|
+
|
|
27
|
+
import mplang.v2.edsl as el
|
|
28
|
+
from mplang.v2.dialects import crypto, field, simp, tensor
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def prg_expand(seed_tensor: el.Object, length: int) -> el.Object:
    """Expand per-row seeds into `length` pseudo-random bits (uint8 0/1 each).

    Cryptographic expansion is delegated to field.aes_expand (AES-NI backed).

    Args:
        seed_tensor: (K, 32) uint8 seed bytes; only the first 16 bytes of each
            row are used as the 128-bit AES seed.
        length: Number of output bits per row.

    Returns:
        (K, length) bit matrix as an el.Object.
    """
    # field.aes_expand emits (K, M, 2) uint64 blocks, i.e. M * 128 bits per
    # row; round up so enough blocks cover the requested bit count.
    num_blocks = -(-length // 128)

    def _seeds_as_u64(raw: Any) -> Any:
        # raw: (K, 32) uint8 -> take the leading 16 bytes and reinterpret as
        # (K, 2) uint64, the layout field.aes_expand expects.
        return raw[:, :16].view(jnp.uint64).reshape(-1, 2)

    aes_seeds = tensor.run_jax(_seeds_as_u64, seed_tensor)

    blocks = field.aes_expand(aes_seeds, num_blocks)  # (K, M, 2) uint64

    def _to_bit_rows(packed: Any) -> Any:
        # Reinterpret each 128-bit block as 16 bytes, fan out to bits,
        # then collapse the block dims and trim any padding bits.
        as_bytes = packed.view(jnp.uint8)  # (K, M, 16)
        as_bits = jnp.unpackbits(as_bytes, axis=-1, bitorder="little")  # (K, M, 128)
        return as_bits.reshape(as_bits.shape[0], -1)[:, :length]

    return cast(el.Object, tensor.run_jax(_to_bit_rows, blocks))
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def vec_hash(data_bytes: el.Object, domain_sep: int, num_rows: int) -> el.Object:
    """Hash each row of an (N, D) tensor independently.

    Args:
        data_bytes: (N, D) tensor whose rows are hashed.
        domain_sep: Integer domain separator mixed into every row; 0 disables it.
        num_rows: Row count N, supplied by the caller (not referenced in the
            body at present).

    Returns:
        (N, 32) per-row digests.
    """
    # Batched strategy: tag every row with the domain separator in one
    # vectorized concatenation, then hash the whole tensor with a single
    # crypto.hash_batch node — one graph node instead of N, which avoids
    # compiler blow-up.
    if domain_sep == 0:
        tagged = data_bytes
    else:

        def _tag_rows(rows: Any) -> Any:
            # Encode the separator as 8 little-endian bytes and prefix it
            # onto every row: (N, D) -> (N, D + 8).
            sep_bytes = jnp.array([domain_sep], dtype=jnp.uint64).view(jnp.uint8)
            prefix = jnp.broadcast_to(sep_bytes, (rows.shape[0], 8))
            return jnp.concatenate([prefix, rows], axis=1)

        tagged = tensor.run_jax(_tag_rows, data_bytes)

    # crypto.hash_batch hashes rank >= 2 inputs row-wise: (N, *) -> (N, 32).
    return crypto.hash_batch(tagged)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def iknp_core(
    choice_bits: el.Object, sender: int, receiver: int, num_ots: int
) -> tuple[el.Object, el.Object, el.Object]:
    """Core IKNP matrix generation (base OTs + PRG expansion + correction).

    Runs K = 128 base OTs (an inlined EC-based OT with the roles reversed, as
    in IKNP), then expands and corrects the seed matrices.

    Args:
        choice_bits: Receiver's choices; (N,) for standard IKNP or (N, K)
            for KKRT-style OPRF (see `calc_u` below).
        sender: Rank of the IKNP Sender.
        receiver: Rank of the IKNP Receiver.
        num_ots: Number N of extended OTs.

    Returns:
        t_matrix: (N, K) bit matrix, resident on Sender.
        q_matrix: (N, K) bit matrix, resident on Receiver.
        s_choices: (K,) Sender correlation bits s, resident on Sender.
    """
    # Security parameter: number of base OTs / matrix width.
    K = 128

    # 1. Base OTs
    def gen_s() -> el.Object:
        # Sender's correlation vector s: K random bits, drawn at runtime.
        return crypto.random_bits(K)

    s = simp.pcall_static((sender,), gen_s)

    def gen_seeds() -> tuple[el.Object, el.Object]:
        # Receiver's two seed banks k0/k1 (one pair per base OT),
        # 32 random bytes each.
        k0_bytes = crypto.random_bytes(K * 32)
        k1_bytes = crypto.random_bytes(K * 32)

        # Reshape to (K, 32) inside run_jax so XLA can fuse it.
        def _reshape_k32(b: Any) -> Any:
            return b.reshape(K, 32)

        k0 = tensor.run_jax(_reshape_k32, k0_bytes)
        k1 = tensor.run_jax(_reshape_k32, k1_bytes)
        return k0, k1

    k0_base, k1_base = simp.pcall_static((receiver,), gen_seeds)

    # Base OT logic (inlined Bellare-Micali-style EC OT).
    # Common point C initialization.
    # SECURITY FIX: C must be generated by the base-OT sender (the IKNP
    # receiver) so the base-OT receiver (the IKNP sender) never learns its
    # discrete log; otherwise it could decrypt both messages and recover
    # the choice bits.
    def base_param_gen() -> el.Object:
        C = crypto.ec_mul(crypto.ec_generator(), crypto.ec_random_scalar())
        return C

    C_point = simp.pcall_static((receiver,), base_param_gen)
    C_for_sender = simp.shuffle_static(C_point, {sender: receiver})

    # Duplicate initialization removed

    # Base-OT receiver keygen (runs on the IKNP sender).
    def base_sender_keygen(
        C: el.Object, s_base_choices: el.Object
    ) -> tuple[el.Object, list[el.Object]]:
        # s_base_choices is the (K,) tensor of the sender's bits s.
        PK0_bytes_list = []
        k_priv_list = []

        for i in range(K):
            # Fresh EC keypair for base OT i.
            k_priv = crypto.ec_random_scalar()
            PK_sigma = crypto.ec_mul(crypto.ec_generator(), k_priv)

            # Slice s[i].
            # slice_tensor is fine here: s_base_choices is tiny, so the
            # extra run_jax machinery would not pay off.
            s_i = tensor.slice_tensor(s_base_choices, (i,), (i + 1,))
            s_scalar = crypto.ec_scalar_from_int(s_i)

            diff = crypto.ec_sub(C, PK_sigma)
            # If s[i] == 1 publish C - PK_sigma as PK0, else PK_sigma itself,
            # so that PK_{s[i]} is always the key we know the secret for.
            PK0 = crypto.select(s_scalar, diff, PK_sigma)

            # Serialize to 65-byte uncompressed point; K is only 128, so the
            # per-point overhead is small. Points are stacked so a single
            # shuffle moves them all.
            pk0_b = crypto.ec_point_to_bytes(PK0)
            # Reshape for stacking: (65,) -> (1, 65).
            pk0_b_r = tensor.reshape(pk0_b, (1, 65))

            PK0_bytes_list.append(pk0_b_r)
            k_priv_list.append(k_priv)

        # Stack into (K, 65).
        PK0_stacked = tensor.concat(PK0_bytes_list, axis=0)

        return PK0_stacked, k_priv_list

    # base_keys_tuple -> (PK0_stacked, k_priv_list of trace objects).
    # C_for_sender was just received from the receiver.
    base_keys_tuple = simp.pcall_static((sender,), base_sender_keygen, C_for_sender, s)

    # Extract the public part only; k_priv (x[1]) stays on the sender and is
    # consumed later by base_decrypt_rev via base_keys_tuple.
    PK0_loc = simp.pcall_static((sender,), lambda x: x[0], base_keys_tuple)
    PK0_recv = simp.shuffle_static(PK0_loc, {receiver: sender})

    # Base-OT sender (the IKNP receiver) encrypts its seed pair (k0, k1).
    def base_encrypt_rev(
        C: el.Object,
        PK0_bytes_tensor: el.Object,
        m0_tensor: el.Object,
        m1_tensor: el.Object,
    ) -> tuple[el.Object, el.Object, el.Object]:
        # m0_tensor, m1_tensor: (K, 32) seed banks.
        # PK0_bytes_tensor: (K, 65) serialized points.

        U_bytes_list = []
        c0_list = []
        c1_list = []

        for i in range(K):
            # Unstack PK0: row i, all 65 bytes.
            pk0_b = tensor.slice_tensor(PK0_bytes_tensor, (i, 0), (i + 1, 65))
            # Flatten to (65,) for point deserialization.
            pk0_b_flat = tensor.reshape(pk0_b, (65,))
            PK0 = crypto.ec_bytes_to_point(pk0_b_flat)

            # Ephemeral value U = g^r for this base OT.
            r = crypto.ec_random_scalar()
            U = crypto.ec_mul(crypto.ec_generator(), r)

            # Stack U as bytes for the later shuffle.
            u_b = crypto.ec_point_to_bytes(U)
            u_b_r = tensor.reshape(u_b, (1, 65))
            U_bytes_list.append(u_b_r)

            # Shared keys: K0 = PK0^r, K1 = (C - PK0)^r. Only one of the two
            # is recoverable by the other party (the one matching s[i]).
            K0_point = crypto.ec_mul(PK0, r)
            PK1 = crypto.ec_sub(C, PK0)
            K1_point = crypto.ec_mul(PK1, r)

            sk0 = crypto.hash_bytes(crypto.ec_point_to_bytes(K0_point))  # (32,)
            sk1 = crypto.hash_bytes(crypto.ec_point_to_bytes(K1_point))

            # Extract row i and encrypt inside a single run_jax block.
            # NOTE: idx=i default binds the loop variable at definition time
            # (avoids the late-binding closure pitfall).
            def _slice_and_enc(
                m0_full: Any, m1_full: Any, k0: Any, k1: Any, idx: int = i
            ) -> tuple[Any, Any]:
                # Slice row idx and flatten to (32,).
                m0_row = m0_full[idx].flatten()
                m1_row = m1_full[idx].flatten()
                # One-time-pad with the derived keys.
                c0 = jnp.bitwise_xor(m0_row, k0)
                c1 = jnp.bitwise_xor(m1_row, k1)
                return c0.reshape(1, -1), c1.reshape(1, -1)

            c0, c1 = tensor.run_jax(_slice_and_enc, m0_tensor, m1_tensor, sk0, sk1)

            c0_list.append(c0)
            c1_list.append(c1)

        # Stack outputs.
        U_stacked = tensor.concat(U_bytes_list, axis=0)  # (K, 65)
        c0_stacked = tensor.concat(c0_list, axis=0)  # (K, 32) (assuming 32-byte msgs)
        c1_stacked = tensor.concat(c1_list, axis=0)  # (K, 32)

        return U_stacked, c0_stacked, c1_stacked

    base_cts_rev = simp.pcall_static(
        (receiver,), base_encrypt_rev, C_point, PK0_recv, k0_base, k1_base
    )

    # Shuffle the (Tensor, Tensor, Tensor) tuple to the sender in one go;
    # tree_map walks the tuple structure for us.
    from jax.tree_util import tree_map

    base_cts_s = tree_map(
        lambda x: simp.shuffle_static(x, {sender: receiver}), base_cts_rev
    )

    # Base-OT receiver (the IKNP sender) decrypts exactly one seed per OT,
    # selected by its bit s[i].
    def base_decrypt_rev(
        keys: tuple[el.Object, list[el.Object]],
        cts: tuple[el.Object, el.Object, el.Object],
        s_choices: el.Object,
    ) -> el.Object:
        # keys[0] is PK0_stacked (unused here); keys[1] is k_priv_list.
        _, k_priv_list = keys
        # cts are stacked tensors of shape (K, 65), (K, 32), (K, 32).
        U_packed, c0_packed, c1_packed = cts

        k_s_rows = []

        for i in range(K):
            k_priv = k_priv_list[i]

            # Unstack U for base OT i.
            u_b = tensor.slice_tensor(U_packed, (i, 0), (i + 1, 65))
            u_b_flat = tensor.reshape(u_b, (65,))
            U = crypto.ec_bytes_to_point(u_b_flat)

            # Unstack the ciphertext pair for base OT i.
            c0 = tensor.slice_tensor(c0_packed, (i, 0), (i + 1, 32))
            c1 = tensor.slice_tensor(c1_packed, (i, 0), (i + 1, 32))
            # _slice_dec_reshape XORs the chosen ciphertext with sk, and
            # sk is (32,), so flatten both ciphertexts to (32,) first.
            c0 = tensor.reshape(c0, (32,))
            c1 = tensor.reshape(c1, (32,))

            # Recover the shared key: SharedK = U^k_priv = g^(r*k_priv).
            SharedK = crypto.ec_mul(U, k_priv)
            sk = crypto.hash_bytes(crypto.ec_point_to_bytes(SharedK))

            # Combined select + decrypt + reshape in a single run_jax block.
            # idx=i default again pins the loop index at definition time.
            def _slice_dec_reshape(
                s_arr: Any, k: Any, c0_: Any, c1_: Any, idx: int = i
            ) -> Any:
                sel = s_arr[idx]
                chosen_c = jnp.where(sel == 0, c0_, c1_)
                result = jnp.bitwise_xor(chosen_c, k)
                return result.reshape(1, 32)

            res_row = tensor.run_jax(_slice_dec_reshape, s_choices, sk, c0, c1)
            k_s_rows.append(res_row)

        # Concat via tensor.concat — run_jax with many args can cause
        # tracing issues.
        return tensor.concat(k_s_rows, axis=0)

    k_s = simp.pcall_static((sender,), base_decrypt_rev, base_keys_tuple, base_cts_s, s)

    # 2. PRG expansion & correction (receiver side): u = G(k0) ^ G(k1) ^ r.
    def calc_u(k0_loc: el.Object, k1_loc: el.Object, r_loc: el.Object) -> el.Object:
        g_k0 = prg_expand(k0_loc, num_ots)  # (K, num_ots)
        g_k1 = prg_expand(k1_loc, num_ots)  # (K, num_ots)

        # choice_bits can be:
        # - (N,) 1D vector for standard IKNP
        # - (N, K) 2D matrix for KKRT OPRF
        #
        # IKNP: u^j = G(k0^j) ^ G(k1^j) ^ r, r broadcast over all K rows.
        # KKRT: u^j = G(k0^j) ^ G(k1^j) ^ r^j, r is (N, K) transposed to (K, N).
        def _compute_u(g0: Any, g1: Any, r: Any) -> Any:
            # g0, g1: (K, N) bit matrices; r: (N,) or (N, K).
            if r.ndim == 1:
                # 1D case: (N,) -> (1, N) for broadcasting against (K, N).
                r_t = jnp.expand_dims(r, axis=0)  # (1, N)
            else:
                # 2D case: transpose (N, K) -> (K, N).
                r_t = jnp.transpose(r, (1, 0))  # (K, N)
            return jnp.bitwise_xor(jnp.bitwise_xor(g0, g1), r_t)

        return cast(el.Object, tensor.run_jax(_compute_u, g_k0, g_k1, r_loc))

    u = simp.pcall_static((receiver,), calc_u, k0_base, k1_base, choice_bits)
    u_recv = simp.shuffle_static(u, {sender: receiver})

    # 3. Matrix recovery & transpose (sender side):
    # t^j = G(k_{s[j]}^j) ^ (s[j] * u^j), then transpose to (N, K).
    def calc_t(k_s_loc: el.Object, u_loc: el.Object, s_loc: el.Object) -> el.Object:
        g_k_s = prg_expand(k_s_loc, num_ots)

        def _recover_and_transpose(g: Any, mask: Any, sel: Any) -> Any:
            # Fused recover + transpose in one XLA block.
            sel_exp = jnp.expand_dims(sel, axis=-1)
            term = jnp.bitwise_and(mask, sel_exp)
            t_rows = jnp.bitwise_xor(g, term)
            return jnp.transpose(t_rows, (1, 0))  # (N, K)

        return cast(
            el.Object, tensor.run_jax(_recover_and_transpose, g_k_s, u_loc, s_loc)
        )

    t_matrix = simp.pcall_static((sender,), calc_t, k_s, u_recv, s)

    # Receiver's matrix: q = transpose(G(k0)).
    def calc_q(k0_loc: el.Object) -> el.Object:
        g_k0 = prg_expand(k0_loc, num_ots)
        # Transpose inside run_jax to enable XLA fusion.
        return cast(el.Object, tensor.run_jax(lambda x: jnp.transpose(x, (1, 0)), g_k0))

    q_matrix = simp.pcall_static((receiver,), calc_q, k0_base)

    # Residency: s and t_matrix on Sender; q_matrix on Receiver.
    return t_matrix, q_matrix, s
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def s_choices_sender(s: el.Object) -> el.Object:
    """Return `s` unchanged — it already resides on the sender via pcall."""
    return s
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def transfer_extension(
    m0: el.Object,
    m1: el.Object,
    choice_bits: el.Object,
    sender: int,
    receiver: int,
    num_ots: int,
) -> el.Object:
    """Perform IKNP OT Extension.

    Runs iknp_core to derive the correlated matrices, then encrypts the
    sender's message pairs with hashed rows and lets the receiver decrypt
    the message matching each choice bit.
    """
    t_matrix, q_matrix, s = iknp_core(choice_bits, sender, receiver, num_ots)

    # 4. Encryption (sender side).
    def encrypt_msgs(
        t_loc: el.Object, s_loc: el.Object, m0_loc: el.Object, m1_loc: el.Object
    ) -> el.Object:
        # t_loc: (N, K) bit rows; s_loc: (K,) correlation bits.
        # Hash rows before using them as pads — breaks the linear
        # correlation between t and t^s. domain_sep=1 marks IKNP payload
        # masking.
        pad0 = vec_hash(t_loc, domain_sep=1, num_rows=num_ots)

        def _xor_correlation(rows: Any, corr: Any) -> Any:
            return jnp.bitwise_xor(rows, corr)

        # vec_hash needs t^s materialized first, so compute it separately.
        t_corr = cast(el.Object, tensor.run_jax(_xor_correlation, t_loc, s_loc))
        pad1 = vec_hash(t_corr, domain_sep=1, num_rows=num_ots)

        def _mask(p0: Any, p1: Any, msg0: Any, msg1: Any) -> Any:
            # Pads come back as (N, 32) digests while messages are (N, D);
            # truncate the pads to the message width before XOR-masking.
            width = msg0.shape[1]
            return (
                jnp.bitwise_xor(msg0, p0[:, :width]),
                jnp.bitwise_xor(msg1, p1[:, :width]),
            )

        return cast(el.Object, tensor.run_jax(_mask, pad0, pad1, m0_loc, m1_loc))

    ciphertexts = simp.pcall_static((sender,), encrypt_msgs, t_matrix, s, m0, m1)

    # Ship the (c0, c1) tuple to the receiver; tree_map walks the structure.
    from jax.tree_util import tree_map

    ciphertexts_recv = tree_map(
        lambda x: simp.shuffle_static(x, {receiver: sender}), ciphertexts
    )

    # 5. Decryption (receiver side).
    def decrypt_msg(
        q_loc: el.Object, r_loc: el.Object, c_texts: tuple[el.Object, el.Object]
    ) -> el.Object:
        ct_zero, ct_one = c_texts

        # Same hashed pad as the sender computed for the chosen branch.
        pad = vec_hash(q_loc, domain_sep=1, num_rows=num_ots)

        def _unmask(p: Any, choices: Any, c0_: Any, c1_: Any) -> Any:
            width = c0_.shape[1]
            p_cut = p[:, :width]
            cand0 = jnp.bitwise_xor(c0_, p_cut)
            cand1 = jnp.bitwise_xor(c1_, p_cut)
            # Pick per-row candidate by the choice bit.
            sel = jnp.expand_dims(choices, axis=-1)
            return jnp.where(sel == 1, cand1, cand0)

        return cast(el.Object, tensor.run_jax(_unmask, pad, r_loc, ct_zero, ct_one))

    res = simp.pcall_static(
        (receiver,), decrypt_msg, q_matrix, choice_bits, ciphertexts_recv
    )
    return cast(el.Object, res)
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Silent OT (Random VOLE) Implementation.
|
|
16
|
+
|
|
17
|
+
Implements "Silent Random VOLE" via Linear Expansion (LPN-like).
|
|
18
|
+
This provides O(N) local computation but O(k) communication.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from typing import Any, cast
|
|
22
|
+
|
|
23
|
+
import jax.numpy as jnp
|
|
24
|
+
|
|
25
|
+
import mplang.v2.edsl as el
|
|
26
|
+
import mplang.v2.edsl.typing as elt
|
|
27
|
+
import mplang.v2.libs.mpc.vole.gilboa as vole
|
|
28
|
+
from mplang.v2.dialects import crypto, field, simp, tensor
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def silent_vole_random_u(
|
|
32
|
+
sender: int,
|
|
33
|
+
receiver: int,
|
|
34
|
+
n: int,
|
|
35
|
+
base_k: int = 1024,
|
|
36
|
+
) -> tuple[el.Object, el.Object, el.Object, el.Object]:
|
|
37
|
+
"""Execute Silent Random VOLE (Linear Expansion).
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
sender: Rank of Sender.
|
|
41
|
+
receiver: Rank of Receiver.
|
|
42
|
+
n: Target vector length (e.g. 10^9).
|
|
43
|
+
base_k: Size of Base VOLE (LPN parameter).
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
v, w, u, delta
|
|
47
|
+
Where w = v + u * delta.
|
|
48
|
+
u is RANDOM.
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
# 1. Base VOLE (Standard Gilboa)
|
|
52
|
+
# We need providers for base_u and base_delta.
|
|
53
|
+
|
|
54
|
+
def _base_u_provider() -> el.Object:
|
|
55
|
+
# Random U_base (base_k, 2) using new API
|
|
56
|
+
return crypto.random_tensor((base_k, 2), elt.u64)
|
|
57
|
+
|
|
58
|
+
def _base_delta_provider() -> el.Object:
|
|
59
|
+
# Random Delta (2,) using new API
|
|
60
|
+
return crypto.random_tensor((2,), elt.u64)
|
|
61
|
+
|
|
62
|
+
# v_base: (k, 2), w_base: (k, 2)
|
|
63
|
+
# The return type is a Union, mypy complains about unpacking.
|
|
64
|
+
# We ignore the type error here as we know return_secrets=True returns 4 values.
|
|
65
|
+
v_base, w_base, u_base, delta = vole.vole( # type: ignore
|
|
66
|
+
sender,
|
|
67
|
+
receiver,
|
|
68
|
+
base_k,
|
|
69
|
+
_base_u_provider,
|
|
70
|
+
_base_delta_provider,
|
|
71
|
+
return_secrets=True,
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
# 2. Linear Expansion
|
|
75
|
+
# We rely on a public seed for the mixing matrix M.
|
|
76
|
+
seed = simp.pcall_static((sender,), lambda: crypto.random_bytes(32))
|
|
77
|
+
# Share seed (Receiver needs it too)
|
|
78
|
+
seed_recv = simp.shuffle_static(seed, {receiver: sender}) # S -> R
|
|
79
|
+
|
|
80
|
+
# Expansion Logic
|
|
81
|
+
# We process in chunks to avoid massive implementation limit or memory issues.
|
|
82
|
+
# But EDSL graph optimization might handle loops?
|
|
83
|
+
# For safe side, let's implement a loop over chunks in Python if N is large.
|
|
84
|
+
# However, N is dynamic usually? Here N is int param.
|
|
85
|
+
|
|
86
|
+
# Chunk size
|
|
87
|
+
CHUNK_SIZE = 100_000 # 100k items per chunk
|
|
88
|
+
|
|
89
|
+
# We need to broadcast delta and bases to expansion function?
|
|
90
|
+
# Actually, we can just expand v_base -> v_long, w_base -> w_long.
|
|
91
|
+
# u_long is implicit (u_base * M).
|
|
92
|
+
# Since we need to return u, we expand u_base too.
|
|
93
|
+
|
|
94
|
+
# Define expansion op
|
|
95
|
+
def _expand_chunk(base_vec: Any, mask_packed: Any, chunk_len: int) -> Any:
    """XOR-combine the base vector under a pseudorandom bit mask.

    base_vec:    (base_k, 2) u64 base correlation values.
    mask_packed: (base_k, blocks, 2) u64 AES output, treated as a packed
                 bit matrix once unpacked.
    chunk_len:   number of output positions (bits) to keep.

    Returns a (chunk_len, 2) u64 tensor where
    out[c] = XOR over j of base_vec[j] * bit[j, c].
    """
    from jax import lax

    # Reinterpret the packed u64 mask as raw bytes, explode every byte
    # into bits (little-endian bit order), and lay the result out as one
    # row of bits per base index: (base_k, total_bits).
    byte_view = mask_packed.view(jnp.uint8)
    bit_rows = jnp.unpackbits(byte_view, bitorder="little").reshape(base_k, -1)

    # Keep only the first chunk_len bit columns; promoting the 0/1 bits
    # to u64 lets plain multiplication act as a per-element select.
    selector = bit_rows[:, :chunk_len].astype(jnp.uint64)

    # terms[j, c] = base[j] * bit[j, c]  ->  shape (base_k, chunk_len, 2)
    selected = base_vec.reshape(base_k, 1, 2) * selector.reshape(
        base_k, chunk_len, 1
    )

    # XOR-reduce over the base_k axis. A scan keeps the traced graph
    # compact instead of unrolling base_k XOR ops.
    def _fold(acc: Any, row: Any) -> tuple[Any, Any]:
        return jnp.bitwise_xor(acc, row), None

    zero_acc = jnp.zeros((chunk_len, 2), dtype=jnp.uint64)
    reduced, _ = lax.scan(_fold, zero_acc, selected)
    return reduced
|
|
140
|
+
|
|
141
|
+
# 3. Orchestration
|
|
142
|
+
# We iterate chunks on Host? Or use `scan`?
|
|
143
|
+
# Host loop is easier for Memory management (Streaming).
|
|
144
|
+
# Return a "Lazy Object" or List of Objects?
|
|
145
|
+
# The signature `silent_vole` usually returns full Tensor.
|
|
146
|
+
# User requirement: "Silent OT" to reduce communications.
|
|
147
|
+
# If we return a full (N,) tensor, we solved bandwidth but not RAM.
|
|
148
|
+
# But for Phase 2 task "Protocol Upgrade", bandwidth is key.
|
|
149
|
+
# Phase 2 task "Streaming" handles RAM.
|
|
150
|
+
# So returning full Tensor is "okay" for now, although it might OOM 1B.
|
|
151
|
+
# Let's implement blocked execution and stack? No, that OOMs.
|
|
152
|
+
|
|
153
|
+
# We will implement `silent_vole_random_u` to return a `BigTensor` handle?
|
|
154
|
+
# Or just `el.Object` (which might be huge).
|
|
155
|
+
# Since we are in EDSL, the `el.Object` represents the *computation*.
|
|
156
|
+
# If we return a graph that produces (10^9,) tensor, the Evaluator might crash trying to allocate it.
|
|
157
|
+
|
|
158
|
+
# Let's just implement loop and return concatenated for now, assume 10^7-10^8 test case.
|
|
159
|
+
# For 10^9, we rely on Streaming Refactor later.
|
|
160
|
+
|
|
161
|
+
num_chunks = (n + CHUNK_SIZE - 1) // CHUNK_SIZE
|
|
162
|
+
|
|
163
|
+
def _run_expansion(b: Any, seed_val: Any) -> el.Object:
    """Locally expand a (base_k, 2) base share to the full length-n output.

    b:        (base_k, 2) u64 base VOLE share held by this party.
    seed_val: (32,) u8 master seed — presumably identical on both parties
              so they derive the same mixing matrix (shared via shuffle
              upstream); used to drive the AES expansion.

    Returns an el.Object representing the concatenated (n, 2) expansion.
    """

    # 1. Derive the per-row seeds from the master seed.
    # NOTE: the inner parameter was previously also named `b`, shadowing
    # the base-vector argument while actually receiving the seed bytes;
    # renamed to `raw_seed` for clarity.
    def _seed_as_u64_key(raw_seed: Any) -> Any:
        # Reinterpret the 32 seed bytes as u64 and keep the first (1, 2)
        # row as the AES master key.
        return raw_seed.view(jnp.uint64).reshape(-1, 2)[:1]

    master_seed = tensor.run_jax(_seed_as_u64_key, seed_val)

    # Expand the master key into base_k row seeds; reshape inside run_jax
    # so XLA can fuse the layout change with downstream work.
    row_seeds_packed = field.aes_expand(master_seed, base_k)
    row_seeds = tensor.run_jax(lambda x: x.reshape(base_k, 2), row_seeds_packed)

    # 2. Expand chunk by chunk to bound the per-step AES/mask tensor size.
    local_res = []
    for i in range(num_chunks):
        this_len = min(CHUNK_SIZE, n - i * CHUNK_SIZE)

        # ceil(this_len / 128): each AES block yields 128 mask bits.
        num_blocks = (this_len + 127) // 128
        mask_packed = field.aes_expand(row_seeds, num_blocks)

        # Bind this_len as a default argument to sidestep Python's
        # late-binding closure pitfall inside the loop; the work itself
        # runs on device via run_jax.
        def _core(base: Any, mask: Any, this_len: int = this_len) -> Any:
            return _expand_chunk(base, mask, this_len)

        local_res.append(tensor.run_jax(_core, b, mask_packed))

    # 3. Stitch chunks together; a single chunk needs no concatenation.
    if len(local_res) == 1:
        return cast(el.Object, local_res[0])

    def _concat_chunks(*chunks: Any) -> Any:
        return jnp.concatenate(chunks, axis=0)

    return cast(el.Object, tensor.run_jax(_concat_chunks, *local_res))
|
|
206
|
+
|
|
207
|
+
# Execute on Sender
|
|
208
|
+
v_long = simp.pcall_static((sender,), _run_expansion, v_base, seed)
|
|
209
|
+
# Execute on Receiver
|
|
210
|
+
w_long = simp.pcall_static((receiver,), _run_expansion, w_base, seed_recv)
|
|
211
|
+
|
|
212
|
+
# U expansion
|
|
213
|
+
u_long = simp.pcall_static((sender,), _run_expansion, u_base, seed)
|
|
214
|
+
|
|
215
|
+
# Delta is scalar, reusable
|
|
216
|
+
|
|
217
|
+
return v_long, w_long, u_long, delta
|