mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Secure Permutation Network library.
|
|
16
|
+
|
|
17
|
+
This module implements secure permutation (shuffling) using Oblivious Transfer (OT).
|
|
18
|
+
It allows a Sender (holding data) and a Receiver (holding a permutation) to
|
|
19
|
+
cooperatively shuffle the data such that:
|
|
20
|
+
1. The Receiver obtains the shuffled data.
|
|
21
|
+
2. The Sender learns nothing about the permutation.
|
|
22
|
+
3. The Receiver obtains only the shuffled result (note: because the Receiver chooses the permutation, it can map each output item back to its original position).
|
|
23
|
+
|
|
24
|
+
The implementation uses a Bitonic sorting network to achieve oblivious permutation.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from __future__ import annotations
|
|
28
|
+
|
|
29
|
+
import math
|
|
30
|
+
from typing import Any
|
|
31
|
+
|
|
32
|
+
import jax
|
|
33
|
+
import jax.numpy as jnp
|
|
34
|
+
import numpy as np
|
|
35
|
+
|
|
36
|
+
import mplang.v2.edsl.typing as elt
|
|
37
|
+
from mplang.v2.dialects import simp, tensor
|
|
38
|
+
from mplang.v2.libs.mpc.ot import base as ot
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def secure_switch(
    x0: Any, x1: Any, control_bit: Any, sender: int, receiver: int
) -> tuple[Any, Any]:
    """Route two Sender-held values through an OT-based 2x2 switch.

    Two 1-out-of-2 OTs are run over the same message pair ``(x0, x1)``: one
    driven by the Receiver's ``control_bit`` and one by its complement
    ``1 - control_bit`` (computed locally on the Receiver).

    Args:
        x0: Input 0 (on Sender).
        x1: Input 1 (on Sender).
        control_bit: Control bit (on Receiver). 0 = straight, 1 = swap.
        sender: Rank of the sender.
        receiver: Rank of the receiver.

    Returns:
        ``(y0, y1)`` on the Receiver, where
        ``y0 = x0 if c == 0 else x1`` and ``y1 = x1 if c == 0 else x0``.

    Note:
        Both outputs are delivered to the Receiver, so it ends up holding
        both ``x0`` and ``x1``; ``control_bit`` only decides their order.
    """

    def flip(bit: Any) -> Any:
        # Complement of the control bit, evaluated on the Receiver.
        return tensor.run_jax(lambda b: 1 - b, bit)

    # Straight output: selects x1 exactly when the control bit is set.
    first = ot.transfer(x0, x1, control_bit, sender, receiver)

    # Crossed output: same OT, driven by the complemented bit.
    complement = simp.pcall_static((receiver,), flip, control_bit)
    second = ot.transfer(x0, x1, complement, sender, receiver)

    return first, second
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _compute_bitonic_sort_controls(permutation: Any, n: int) -> Any:
    """Derive the swap decisions of a bitonic sorting network for ``permutation``.

    The gather-style permutation is first inverted, so every source slot
    carries its destination index as a sort key. Bitonic sort then runs on
    those keys, and each compare-exchange outcome is recorded.

    Args:
        permutation: JAX array of target (gather) indices,
            ``output[i] = input[permutation[i]]``.
        n: Network width (must be a power of 2).

    Returns:
        JAX boolean array with one control bit per comparator. Bits are
        ordered by stage, then by step within the stage, then by ascending
        lower index of each pair (``n // 2`` bits per step).
    """

    def trace_network(perm: Any) -> Any:
        # keys[src] = dest: invert output[dest] = input[perm[dest]].
        perm_idx = perm.astype(jnp.int64)
        keys = jnp.zeros_like(perm_idx).at[perm_idx].set(
            jnp.arange(n, dtype=jnp.int64)
        )

        # Slot indices are static, so pair/direction masks are plain numpy.
        slots = np.arange(n, dtype=np.int64)
        bits = []

        for stage in range(int(math.log2(n))):
            # Sort direction depends only on which block of size 2^(stage+1)
            # a slot belongs to: even blocks ascend, odd blocks descend.
            ascending = (slots // (2 ** (stage + 1))) % 2 == 0

            for step in range(stage + 1):
                partner = slots ^ (2 ** (stage - step))
                is_lower = partner > slots
                lo = slots[is_lower]
                hi = partner[is_lower]

                key_lo = keys[lo]
                key_hi = keys[hi]
                swap = jnp.where(
                    ascending[is_lower], key_lo > key_hi, key_lo < key_hi
                )
                bits.append(swap)

                # Apply the exchange so later steps see the updated keys.
                keys = keys.at[lo].set(jnp.where(swap, key_hi, key_lo))
                keys = keys.at[hi].set(jnp.where(swap, key_lo, key_hi))

        return jnp.concatenate(bits)

    return tensor.run_jax(trace_network, permutation)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def apply_permutation(data: Any, permutation: Any, sender: int, receiver: int) -> Any:
    """Apply a secure permutation using a Bitonic sorting network.

    Args:
        data: Data items (on Sender). Can be a Tensor or list of Objects.
            The permutation acts on axis 0; any trailing axes are carried
            along unchanged.
        permutation: Tensor of indices (on Receiver). e.g. [2, 0, 1, 3]
            permutation[i] = src means output[i] comes from input[src].
        sender: Rank of sender.
        receiver: Rank of receiver.

    Returns:
        Shuffled data (on Receiver). Returns a list if input was a list.

    Raises:
        TypeError: If ``data`` (after stacking a list input) is not a tensor.

    Note:
        When ``data`` starts on the Sender, only the first network layer is
        evaluated with OT (``secure_switch``). Its output lands on the
        Receiver, and all remaining layers are applied locally there.
    """
    # Remember if input was a list so the result can be unstacked again.
    is_list_input = isinstance(data, list)

    if is_list_input:
        if len(data) == 0:
            return []

        # Stack list elements into a single tensor on the Sender.
        def stack_elements(*args: Any) -> Any:
            return tensor.run_jax(lambda *xs: jnp.stack(xs), *args)

        data = simp.pcall_static((sender,), stack_elements, *data)

    target_type = data.type
    if isinstance(target_type, elt.MPType):
        target_type = target_type.value_type
    if not isinstance(target_type, elt.TensorType):
        raise TypeError("apply_permutation expects tensor inputs")
    original_n = target_type.shape[0]

    # Bitonic sort requires a power-of-2 size (at least 2) - pad if necessary.
    n = 2 ** math.ceil(math.log2(max(original_n, 2)))

    if n != original_n:

        def pad_data(d: Any, pad_n: int) -> Any:
            def impl(x: Any) -> Any:
                # Pad the leading (permuted) axis only. A bare ``(0, k)`` pad
                # width would be broadcast by jnp.pad to every axis and grow
                # trailing dimensions as well.
                pad_width = [(0, pad_n - x.shape[0])] + [(0, 0)] * (x.ndim - 1)
                return jnp.pad(x, pad_width, mode="constant")

            return tensor.run_jax(impl, d)

        data = simp.pcall_static((sender,), pad_data, data, n)

        # Padded slots map to themselves (identity). Use the permutation's own
        # dtype so the concatenation does not promote it.
        def pad_perm(p: Any, orig: int, pad_n: int) -> Any:
            return tensor.run_jax(
                lambda x: jnp.concatenate(
                    [x, jnp.arange(orig, pad_n, dtype=x.dtype)]
                ),
                p,
            )

        permutation = simp.pcall_static(
            (receiver,), pad_perm, permutation, original_n, n
        )

    # Compute control bits for every comparator of the network (on Receiver).
    controls = simp.pcall_static(
        (receiver,), lambda p: _compute_bitonic_sort_controls(p, n), permutation
    )

    # --- Per-step helpers (they do not depend on the loop variables) ---

    def make_indices(party: tuple[int, ...], idx: Any) -> Any:
        # Materialize a static numpy index array as a constant on `party`.
        return simp.pcall_static(party, lambda: tensor.constant(idx))

    def get_step_ctrls(all_ctrls: Any, off: int, count: int) -> Any:
        # Slice `count` control bits starting at `off` (on Receiver).
        def impl(c: Any, o: int, slice_len: int) -> Any:
            # Offset is passed as a tensor (dynamic slice start) to avoid
            # recompilation per step; the slice size must stay static.
            o_tensor = tensor.constant(np.array(o, dtype=np.int64))
            return tensor.run_jax(
                lambda x, start, size: jax.lax.dynamic_slice(x, (start,), (size,)),
                c,
                o_tensor,
                slice_len,
            )

        return simp.pcall_static((receiver,), impl, all_ctrls, off, count)

    def extract_pairs_sender(d: Any, idx_i: Any, idx_j: Any) -> tuple[Any, Any]:
        # Gather both members of every comparator pair (on Sender).
        return (
            tensor.run_jax(lambda x, i: x[i], d, idx_i),
            tensor.run_jax(lambda x, j: x[j], d, idx_j),
        )

    def scatter_results(vi: Any, vj: Any, ii: Any, ij: Any, size: int) -> Any:
        # Write the switched pairs back into a length-`size` array. Each step's
        # pairs cover every index exactly once, so the zero-filled template is
        # fully overwritten.
        def impl(
            v_i: jnp.ndarray,
            v_j: jnp.ndarray,
            idx_i: jnp.ndarray,
            idx_j: jnp.ndarray,
        ) -> jnp.ndarray:
            out = jnp.zeros((size, *v_i.shape[1:]), dtype=v_i.dtype)
            out = out.at[idx_i].set(v_i)
            return out.at[idx_j].set(v_j)

        return tensor.run_jax(impl, vi, vj, ii, ij)

    def apply_local_step(d: Any, c: Any, ii: Any, ij: Any) -> Any:
        # Apply one compare-exchange layer in the clear (on Receiver).
        def impl(
            curr: jnp.ndarray,
            ctrls: jnp.ndarray,
            idx_i: jnp.ndarray,
            idx_j: jnp.ndarray,
        ) -> jnp.ndarray:
            val_i = curr[idx_i]
            val_j = curr[idx_j]
            # One control bit per pair: append singleton axes so it broadcasts
            # against rows that have trailing dimensions.
            sel = ctrls.reshape(ctrls.shape + (1,) * (val_i.ndim - 1))
            curr = curr.at[idx_i].set(jnp.where(sel, val_j, val_i))
            return curr.at[idx_j].set(jnp.where(sel, val_i, val_j))

        return tensor.run_jax(impl, d, c, ii, ij)

    # --- Apply the bitonic sorting network ---
    # Data on Sender -> vectorized OT (secure_switch), result on Receiver.
    # Data on Receiver -> local select.
    current = data
    ctrl_offset = 0
    positions = np.arange(n, dtype=np.int64)
    num_stages = int(math.log2(n))

    for stage in range(num_stages):
        for step in range(stage + 1):
            step_dist = 2 ** (stage - step)

            # Static comparator pairs (i, i ^ step_dist) with i < partner.
            partners = positions ^ step_dist
            mask = partners > positions
            idx_i_np = positions[mask]
            idx_j_np = partners[mask]
            num_pairs = len(idx_i_np)

            typ = current.type
            is_on_sender = isinstance(typ, elt.MPType) and typ.parties == (sender,)

            # Controls for this step (on Receiver).
            step_ctrls = get_step_ctrls(controls, ctrl_offset, num_pairs)

            if is_on_sender:
                idx_i_sender = make_indices((sender,), idx_i_np)
                idx_j_sender = make_indices((sender,), idx_j_np)
                val_i, val_j = simp.pcall_static(
                    (sender,),
                    extract_pairs_sender,
                    current,
                    idx_i_sender,
                    idx_j_sender,
                )

                # Secure Switch (OT) -> result on Receiver.
                res_i, res_j = secure_switch(val_i, val_j, step_ctrls, sender, receiver)

                idx_i_recv = make_indices((receiver,), idx_i_np)
                idx_j_recv = make_indices((receiver,), idx_j_np)
                current = simp.pcall_static(
                    (receiver,),
                    scatter_results,
                    res_i,
                    res_j,
                    idx_i_recv,
                    idx_j_recv,
                    n,
                )
            else:
                idx_i_recv = make_indices((receiver,), idx_i_np)
                idx_j_recv = make_indices((receiver,), idx_j_np)
                current = simp.pcall_static(
                    (receiver,),
                    apply_local_step,
                    current,
                    step_ctrls,
                    idx_i_recv,
                    idx_j_recv,
                )

            ctrl_offset += num_pairs

    # Drop the padded tail, if any.
    if n != original_n:

        def unpad(d: Any, orig: int) -> Any:
            return tensor.run_jax(lambda x: x[:orig], d)

        final_parties = (
            current.type.parties if isinstance(current.type, elt.MPType) else None
        )
        if final_parties is not None:
            current = simp.pcall_static(final_parties, unpad, current, original_n)
        else:
            current = unpad(current, original_n)

    # Convert back to a list if the input was a list.
    if is_list_input:

        def unstack_to_list(d: Any, count: int) -> list:
            # `idx=i` binds the loop value eagerly (avoids late-binding closures).
            return [tensor.run_jax(lambda x, idx=i: x[idx], d) for i in range(count)]

        return simp.pcall_static((receiver,), unstack_to_list, current, original_n)

    return current
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Common constants for MPC protocols."""
|
|
16
|
+
|
|
17
|
+
# Golden-ratio constant for 64-bit multiplicative (Fibonacci) hashing:
# approximately 2^64 / phi. The value is odd, so multiplying by it is
# invertible modulo 2^64. (SplitMix64 uses the same value as its increment.)
GOLDEN_RATIO_64 = 0x9E3779B97F4A7C15

# LCG Constants
LCG_ADDEND = 0x14650FB0739D0383
LCG_MULTIPLIER = 0x27D4EB2F165667C5

# 64-bit mixing multipliers. The GAMMA_* names are historical and kept for
# compatibility; these are not SplitMix64 increments ("gammas"):
#   GAMMA_1, GAMMA_2: multipliers of the SplitMix64 output mix function.
#   GAMMA_3, GAMMA_4: multipliers of the MurmurHash3 fmix64 finalizer.
SPLITMIX64_GAMMA_1 = 0xBF58476D1CE4E5B9
SPLITMIX64_GAMMA_2 = 0x94D049BB133111EB
SPLITMIX64_GAMMA_3 = 0xFF51AFD7ED558CCD
SPLITMIX64_GAMMA_4 = 0xC4CEB9FE1A85EC53

# Arbitrary Constants (Nothing-up-my-sleeve numbers)
# Consecutive 64-bit words of the fractional part of PI (hex).
PI_FRAC_1 = 0x243F6A8885A308D3
PI_FRAC_2 = 0x13198A2E03707344

# NOTE: despite its name, this value is NOT derived from E. It is the third
# 64-bit word of PI's fractional part (the word following PI_FRAC_2). E's
# fractional part begins 0xB7E151628AED2A6A. The value is left unchanged
# because code elsewhere may depend on it.
E_FRAC_1 = 0xA4093822299F31D0
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Oblivious Transfer protocols.
|
|
16
|
+
|
|
17
|
+
Submodules:
|
|
18
|
+
- base: Naor-Pinkas 1-out-of-2 OT
|
|
19
|
+
- extension: IKNP OT Extension
|
|
20
|
+
- silent: Silent OT via LPN
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from .base import transfer
|
|
24
|
+
from .extension import iknp_core, transfer_extension
|
|
25
|
+
from .silent import silent_vole_random_u
|
|
26
|
+
|
|
27
|
+
# Public API of the OT package: names re-exported from the `base`,
# `extension` and `silent` submodules imported above.
__all__ = [
    "iknp_core",
    "silent_vole_random_u",
    "transfer",
    "transfer_extension",
]
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Oblivious Transfer (OT) library.
|
|
16
|
+
|
|
17
|
+
This module implements OT logic using the `crypto` dialect (ECC + Hash + SymEnc).
|
|
18
|
+
It implements the Naor-Pinkas 1-out-of-2 OT protocol.
|
|
19
|
+
|
|
20
|
+
Protocol: Naor-Pinkas 1-out-of-2 OT
|
|
21
|
+
Security: Computational security based on ECDH.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from typing import Any, cast
|
|
27
|
+
|
|
28
|
+
import numpy as np
|
|
29
|
+
|
|
30
|
+
import mplang.v2.edsl as el
|
|
31
|
+
import mplang.v2.edsl.typing as elt
|
|
32
|
+
from mplang.v2.dialects import crypto, simp, tensor
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _receiver_keygen_scalar(
    C_point: el.Object, b: el.Object
) -> tuple[el.Object, el.Object]:
    """Receiver key generation for one Naor-Pinkas OT instance.

    Draws a fresh secret scalar ``k`` and sets ``PK_sigma = k * G``. It then
    publishes ``PK0`` such that the chosen branch key equals ``PK_sigma``:
    ``PK0 = PK_sigma`` when ``b == 0``, and ``PK0 = C - PK_sigma`` when
    ``b == 1`` (so that ``PK1 = C - PK0 = PK_sigma``). The selection is done
    arithmetically as ``PK0 = PK_sigma + b * (C - 2 * PK_sigma)``.

    Args:
        C_point: Sender's random point ``C``.
        b: Selection bit (0 or 1).

    Returns:
        ``(PK0, k)``: the public key to send back and the private scalar.
    """
    secret = crypto.ec_random_scalar()
    pk_chosen = crypto.ec_mul(crypto.ec_generator(), secret)

    bit = crypto.ec_scalar_from_int(b)
    two = crypto.ec_scalar_from_int(tensor.constant(np.array(2, dtype=np.int64)))

    # b * (C - 2 * PK_sigma): zero when b == 0, (C - 2*PK_sigma) when b == 1.
    offset = crypto.ec_mul(
        crypto.ec_sub(C_point, crypto.ec_mul(pk_chosen, two)), bit
    )
    return crypto.ec_add(pk_chosen, offset), secret
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _sender_derive_keys(
    C_point: el.Object, PK0_point: el.Object
) -> tuple[el.Object, el.Object, el.Object, el.Object]:
    """Sender side of Naor-Pinkas: derive one encryption key per branch.

    ``PK1`` is reconstructed as ``C - PK0``. For ``PK0`` and then ``PK1``, a
    fresh scalar ``r`` is drawn and the pair ``(U, K) = (r * G, r * PK)`` is
    formed. ``U`` is sent to the Receiver and ``K`` is the shared secret point.

    Args:
        C_point: Sender's random point ``C``.
        PK0_point: Public key ``PK0`` received from the Receiver.

    Returns:
        ``(U0, K0, U1, K1)``.
    """
    branch_keys = (PK0_point, crypto.ec_sub(C_point, PK0_point))

    derived = []
    for pk in branch_keys:
        r = crypto.ec_random_scalar()
        u = crypto.ec_mul(crypto.ec_generator(), r)  # U = r * G
        shared = crypto.ec_mul(pk, r)  # K = r * PK
        derived.append((u, shared))

    (U0, K0), (U1, K1) = derived
    return U0, K0, U1, K1
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _receiver_derive_key(
    U0: el.Object, U1: el.Object, PK0: el.Object, k: el.Object, b: el.Object
) -> el.Object:
    """Receiver side of Naor-Pinkas: recover the shared point of the chosen branch.

    Selects ``U_b`` arithmetically as ``U0 + b * (U1 - U0)`` and returns
    ``k * U_b``.

    Args:
        U0: Sender's ephemeral point for branch 0.
        U1: Sender's ephemeral point for branch 1.
        PK0: Not used by this computation.
        k: Receiver's private scalar.
        b: Selection bit (0 or 1).

    Returns:
        The shared secret point for branch ``b``.
    """
    bit = crypto.ec_scalar_from_int(b)
    chosen_u = crypto.ec_add(U0, crypto.ec_mul(crypto.ec_sub(U1, U0), bit))
    return crypto.ec_mul(chosen_u, k)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def transfer(
    m0: el.MPObject,
    m1: el.MPObject,
    choice: el.MPObject,
    sender: int,
    receiver: int,
) -> el.MPObject:
    """Perform 1-out-of-2 Oblivious Transfer (Naor-Pinkas).

    Protocol flow (all per-element work goes through ``tensor.elementwise``):
      1. Sender draws a random point ``C`` and sends it to the Receiver.
      2. Receiver derives ``(PK0, k)`` from ``C`` and its choice bit, and
         sends ``PK0`` back.
      3. Sender derives keys for both branches and encrypts ``m0``/``m1``.
         It sends ``(U0, c0, U1, c1)`` to the Receiver.
      4. Receiver recovers the key for its chosen branch and decrypts.

    Args:
        m0: Message 0 (on Sender).
        m1: Message 1 (on Sender).
        choice: Selection bit 0 or 1 (on Receiver).
        sender: Rank of the sender.
        receiver: Rank of the receiver.

    Returns:
        The selected message (on Receiver).

    Raises:
        AssertionError: If ``m0``, ``m1`` or ``choice`` is not an Object with
            an MPType (these checks are skipped under ``python -O``).

    Note:
        The decryption target type is inferred from ``m0`` only. ``m1`` is
        presumably expected to have the same type (TODO confirm).
    """
    assert isinstance(m0, el.Object) and isinstance(m0.type, elt.MPType)
    assert isinstance(m1, el.Object) and isinstance(m1.type, elt.MPType)
    assert isinstance(choice, el.Object) and isinstance(choice.type, elt.MPType)

    # --- Step 1: Sender Initialization ---
    def sender_init_fn() -> el.Object:
        # C is a random point: C = r * G
        r = crypto.ec_random_scalar()
        G = crypto.ec_generator()
        C = crypto.ec_mul(G, r)
        return C

    C = simp.pcall_static((sender,), sender_init_fn)

    # Move C to Receiver
    C_recv = simp.shuffle_static(C, {receiver: sender})

    # Infer target type from m0 (used for decryption of either message)
    m0_type = m0.type
    if isinstance(m0_type, elt.MPType):
        val_type = m0_type.value_type
    else:
        val_type = m0_type

    # Since we use tensor.elementwise, we need the element type for decryption
    if isinstance(val_type, elt.TensorType):
        target_type = val_type.element_type
    else:
        target_type = val_type

    # --- Step 2: Receiver Key Generation ---
    def receiver_keygen_fn(C_point: el.Object, b: el.Object) -> el.Object:
        res: el.Object = cast(
            el.Object, tensor.elementwise(_receiver_keygen_scalar, C_point, b)
        )
        return res  # type: ignore[no-any-return]

    # Returns (PK0, k) on receiver
    keys_recv = simp.pcall_static((receiver,), receiver_keygen_fn, C_recv, choice)

    # Extract PK0 to send back (k stays private on the Receiver)
    def get_pk0(pair: Any) -> el.Object:
        return cast(el.Object, pair[0])

    PK0_to_send = simp.pcall_static((receiver,), get_pk0, keys_recv)
    PK0_sender = simp.shuffle_static(PK0_to_send, {sender: receiver})

    # --- Step 3: Sender Encryption ---
    def sender_encrypt_fn(
        C_point: el.Object, PK0_point: el.Object, msg0: el.Object, msg1: el.Object
    ) -> Any:
        def encrypt_elementwise(
            c: Any, pk0: Any, m0: Any, m1: Any
        ) -> tuple[Any, Any, Any, Any]:
            u0, k0, u1, k1 = _sender_derive_keys(c, pk0)

            # Symmetric keys are hashes of the serialized shared points.
            kb0 = crypto.ec_point_to_bytes(k0)
            kb1 = crypto.ec_point_to_bytes(k1)

            sk0 = crypto.hash_bytes(kb0)
            sk1 = crypto.hash_bytes(kb1)

            c0 = crypto.sym_encrypt(sk0, m0)
            c1 = crypto.sym_encrypt(sk1, m1)
            return u0, c0, u1, c1

        return tensor.elementwise(encrypt_elementwise, C_point, PK0_point, msg0, msg1)

    ciphertexts = simp.pcall_static((sender,), sender_encrypt_fn, C, PK0_sender, m0, m1)

    # Move ciphertexts to Receiver
    # ciphertexts is a tuple, so we map shuffle over it
    from jax.tree_util import tree_map

    ciphertexts_recv = tree_map(
        lambda x: simp.shuffle_static(x, {receiver: sender}), ciphertexts
    )

    # --- Step 4: Receiver Decryption ---
    def receiver_decrypt_fn(c_texts: Any, keys: Any, b: el.Object) -> el.Object:
        # b is selection bit
        # keys is (PK0, k)
        PK0, k = keys
        U0, V0, U1, V1 = c_texts

        def decrypt_elementwise(
            u0: Any, v0: Any, u1: Any, v1: Any, pk0: Any, k_priv: Any, sel: Any
        ) -> Any:
            k_pt = _receiver_derive_key(u0, u1, pk0, k_priv, sel)
            kb = crypto.ec_point_to_bytes(k_pt)
            sk = crypto.hash_bytes(kb)
            # Pick the ciphertext of the chosen branch, then decrypt it.
            v = crypto.select(sel, v1, v0)
            return crypto.sym_decrypt(sk, v, target_type)

        res = tensor.elementwise(decrypt_elementwise, U0, V0, U1, V1, PK0, k, b)
        return cast(el.Object, res)

    result = simp.pcall_static(
        (receiver,), receiver_decrypt_fn, ciphertexts_recv, keys_recv, choice
    )

    res_obj: el.Object = cast(el.Object, result)
    return res_obj
|