mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Private Set Intersection (PSI) protocols.
|
|
16
|
+
|
|
17
|
+
Submodules:
|
|
18
|
+
- rr22: VOLE-masked PSI protocol (formerly okvs.py)
|
|
19
|
+
- unbalanced: Unbalanced PSI (O(n) communication)
|
|
20
|
+
- oprf: KKRT OPRF protocol
|
|
21
|
+
- cuckoo: Cuckoo hashing
|
|
22
|
+
- okvs_gct: Sparse OKVS data structure (Garbled Cuckoo Table)
|
|
23
|
+
- okvs: OKVS Abstract Base Class
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from .oprf import eval_oprf, sender_eval_prf, sender_eval_prf_batch
|
|
27
|
+
from .rr22 import psi_intersect
|
|
28
|
+
from .unbalanced import psi_unbalanced
|
|
29
|
+
|
|
30
|
+
# Alias for backward compatibility.
# NOTE: this rebinds the name `eval` inside this package's namespace, so within
# this module (and for anyone doing a star-import, since "eval" is listed in
# __all__ below) the builtin `eval` is shadowed by `psi_intersect`.
eval = psi_intersect

# Public API of the package: the alias above plus the names imported from the
# oprf, rr22 and unbalanced submodules.
__all__ = [
    "eval",
    "eval_oprf",
    "psi_intersect",
    "psi_unbalanced",
    "sender_eval_prf",
    "sender_eval_prf_batch",
]
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Cuckoo Hashing for OPRF-PSI.
|
|
16
|
+
|
|
17
|
+
Implements JAX-compatible Cuckoo hashing for mapping items to table positions.
|
|
18
|
+
Each item hashes to K candidate positions; during lookup, check all K positions.
|
|
19
|
+
|
|
20
|
+
Reference: KKRT OPRF-PSI uses Cuckoo hashing for row mapping.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
from typing import Any, cast
|
|
26
|
+
|
|
27
|
+
import jax.numpy as jnp
|
|
28
|
+
|
|
29
|
+
import mplang.v2.edsl as el
|
|
30
|
+
from mplang.v2.dialects import tensor
|
|
31
|
+
from mplang.v2.libs.mpc.common.constants import (
|
|
32
|
+
E_FRAC_1,
|
|
33
|
+
GOLDEN_RATIO_64,
|
|
34
|
+
PI_FRAC_1,
|
|
35
|
+
PI_FRAC_2,
|
|
36
|
+
SPLITMIX64_GAMMA_1,
|
|
37
|
+
SPLITMIX64_GAMMA_2,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# =============================================================================
|
|
41
|
+
# Cuckoo Hash Parameters
|
|
42
|
+
# =============================================================================
|
|
43
|
+
|
|
44
|
+
NUM_HASH_FUNCTIONS = 3 # Standard: 3 hash functions
|
|
45
|
+
STASH_SIZE = 0 # Simple version: no stash (higher failure rate)
|
|
46
|
+
MASK64 = 0xFFFFFFFFFFFFFFFF
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def hash_to_positions(items: Any, table_size: int, seed: tuple[int, int]) -> Any:
    """Map every item to its K candidate slots in the Cuckoo table.

    K is NUM_HASH_FUNCTIONS. Each hash function is affine in a 64-bit key:

        h_i(x) = (a_i * x + b_i) mod table_size

    The key x is the first 8 bytes of an item XOR-ed with seed[0]. The
    multipliers a_i are fixed base constants XOR-ed with seed[0] and the
    offsets b_i are fixed base constants XOR-ed with seed[1], so both
    coefficients of every hash function depend on the seed.

    Args:
        items: (N, 16) uint8 array - items to hash. Only the leading 8 bytes
            of each row take part in the hash.
        table_size: Size of the Cuckoo hash table (the modulus).
        seed: Pair of integers used as uint64 seed words.

    Returns:
        (N, K) int32 array - column i holds h_i for every item.
    """
    num_items = items.shape[0]

    seed_lo = jnp.uint64(seed[0])
    seed_hi = jnp.uint64(seed[1])

    # Reinterpret the leading 8 bytes of each row as one uint64 key, then fold
    # the first seed word into it.
    raw_keys = items[:, :8].view(jnp.uint64).reshape(num_items)
    mixed_keys = raw_keys ^ seed_lo

    # Deterministic base coefficients, each mixed with a seed word so that
    # neither the multiplier nor the offset is a public constant.
    multipliers = (
        jnp.array(
            [GOLDEN_RATIO_64, SPLITMIX64_GAMMA_1, SPLITMIX64_GAMMA_2],
            dtype=jnp.uint64,
        )
        ^ seed_lo
    )
    offsets = jnp.array([PI_FRAC_1, PI_FRAC_2, E_FRAC_1], dtype=jnp.uint64) ^ seed_hi

    # One (N,) int32 column per hash function, stacked side by side into (N, K).
    columns = [
        ((mixed_keys * multipliers[i] + offsets[i]) % table_size).astype(jnp.int32)
        for i in range(NUM_HASH_FUNCTIONS)
    ]
    return jnp.stack(columns, axis=1)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def cuckoo_insert_batch(
    items: Any,
    table_size: int,
    seed: tuple[int, int],
    max_iters: int = 100,
) -> tuple[Any, Any, Any]:
    """Batch Cuckoo insertion using vectorized logic (JAX-compatible).

    Uses multi-choice parallel insertion with K = NUM_HASH_FUNCTIONS rounds:
      1. All items try their 1st candidate slot. When several items target
         the same free slot, exactly one of them keeps it (the one whose
         index is found in the slot after the scatter).
      2. Items that did not get a slot try their 2nd candidate.
      3. Items that still have no slot try their 3rd candidate.

    There is no eviction: a slot claimed in an earlier round is never
    overwritten, and an item that loses all K rounds is reported as failed.

    Args:
        items: (N, 16) uint8 array - items to insert
        table_size: Size of Cuckoo hash table (should be ~1.3-1.5N)
        seed: (2,) uint64 seed, forwarded to hash_to_positions
        max_iters: Ignored in this vectorized version (uses K fixed passes);
            the parameter is never read in the body.

    Returns:
        Tuple of:
        - table: (table_size, 16) uint8 - Cuckoo hash table; rows of empty
          slots are zero
        - item_to_pos: (N,) int32 - position of each item in table, -1 if the
          item was not placed
        - success: (N,) bool - whether each item was successfully inserted
    """
    N = items.shape[0]
    K = NUM_HASH_FUNCTIONS

    # (N, K) candidate slots; column k is used in round k.
    positions = hash_to_positions(items, table_size, seed)
    # -1 means "not placed yet".
    item_to_pos = jnp.full(N, -1, dtype=jnp.int32)
    # True while an item still needs a slot.
    active_mask = jnp.ones(N, dtype=jnp.bool_)

    # We track which item "owns" each table slot (-1 = empty)
    table_slots = jnp.full(table_size, -1, dtype=jnp.int32)

    # Track occupied status to forbid overwriting previous successes
    table_occupied = jnp.zeros(table_size, dtype=jnp.bool_)

    item_indices = jnp.arange(N, dtype=jnp.int32)

    for k in range(K):
        # 1. Propose positions for active items
        # Inactive items get -1 proposal
        cand_pos = jnp.where(active_mask, positions[:, k], -1)

        # 2. Filter out already occupied slots
        # Map -1 to safe index 0 for lookup (result discarded via mask)
        safe_lookup = jnp.maximum(cand_pos, 0)
        is_occupied = table_occupied[safe_lookup]
        # Valid proposal: not -1 AND not occupied
        cand_pos_valid = jnp.where((cand_pos >= 0) & (~is_occupied), cand_pos, -1)

        # 3. Attempt write to table_slots using Scatter
        # Extend table to handle -1 dump index (at index table_size)
        ext_slots = jnp.pad(table_slots, (0, 1), constant_values=-1)

        # Map -1 to dump index
        write_pos = jnp.where(cand_pos_valid >= 0, cand_pos_valid, table_size)

        # Write item indices for ALL items: those without a valid proposal
        # (inactive, or targeting an occupied slot) land in the dump entry,
        # so only valid proposals can touch real slots.
        # NOTE(review): when several items share a write_pos, which index
        # survives is decided by the scatter itself; the read-back in step 4
        # is what makes the outcome consistent - confirm against JAX's
        # documented semantics for duplicate indices in .at[].set.
        ext_slots_updated = ext_slots.at[write_pos].set(item_indices)

        # 4. Verify winners: read back what actually ended up in each
        # targeted slot.
        winner_indices = ext_slots_updated[write_pos]

        # Success if:
        # a) We had a valid proposal (cand_pos_valid != -1)
        # b) Our index matches the winner
        # Condition (a) also rules out whoever happens to "win" the dump entry.
        success_round = (cand_pos_valid >= 0) & (winner_indices == item_indices)

        # 5. Commit state
        # Update global state based on success
        item_to_pos = jnp.where(success_round, cand_pos_valid, item_to_pos)
        active_mask = active_mask & (~success_round)

        # Update table slots (truncate dump)
        table_slots = ext_slots_updated[:table_size]
        table_occupied = table_slots >= 0

    # Construct final table: gather the owning item for every slot (empty
    # slots temporarily read row 0), then zero out the rows of empty slots.
    safe_indices = jnp.maximum(table_slots, 0)
    final_table = items[safe_indices]
    final_table = jnp.where(table_slots[:, None] >= 0, final_table, 0)

    success_total = item_to_pos >= 0
    return final_table, item_to_pos, success_total
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def cuckoo_lookup_positions(items: Any, table_size: int, seed: tuple[int, int]) -> Any:
    """Get Cuckoo lookup positions for each item.

    Returns the K candidate positions where each item could be located
    in a Cuckoo hash table. This is a direct call to hash_to_positions, so
    with the same table_size and seed it yields the same candidates that
    cuckoo_insert_batch uses when placing items.

    Args:
        items: (M, 16) uint8 array - items to lookup
        table_size: Size of Cuckoo hash table
        seed: (2,) uint64 seed

    Returns:
        (M, K) int32 array - K positions to check for each item
    """
    return hash_to_positions(items, table_size, seed)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
# =============================================================================
|
|
205
|
+
# EDSL Wrappers
|
|
206
|
+
# =============================================================================
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def compute_positions(
    items: el.Object,
    table_size: int,
    seed: el.Object,  # (2,) uint64
) -> el.Object:
    """Compute Cuckoo hash positions for items (EDSL wrapper).

    Wraps hash_to_positions in a local closure and hands it, together with
    the items and seed objects, to tensor.run_jax. table_size is captured by
    the closure as a plain Python int rather than passed as a tensor argument.

    Args:
        items: (N, 16) byte tensor of items
        table_size: Size of Cuckoo hash table
        seed: (2,) uint64 seed

    Returns:
        (N, K) int32 tensor of candidate positions
    """

    def _hash(x: Any, s: Any) -> Any:
        # hash_to_positions reads seed[0] and seed[1]; tuple(s) turns the
        # seed value into an indexable pair.
        # NOTE(review): presumably s arrives here as a length-2 JAX array -
        # confirm against tensor.run_jax.
        return hash_to_positions(x, table_size, tuple(s))

    # cast() only informs the type checker; the runtime value is whatever
    # tensor.run_jax returns.
    return cast(el.Object, tensor.run_jax(_hash, items, seed))
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Abstract Base Class for OKVS (Oblivious Key-Value Store)."""
|
|
16
|
+
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
|
|
19
|
+
import mplang.v2.edsl as el
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class OKVS(ABC):
    """Abstract interface for Oblivious Key-Value Store.

    Concrete subclasses must implement both encode() and decode(); the
    methods here carry only the documented contract and no behavior.
    """

    @abstractmethod
    def encode(self, keys: el.Object, values: el.Object, seed: el.Object) -> el.Object:
        """Encode items into OKVS storage.

        Args:
            keys: (N,) uint64 tensor of keys
            values: (N, D) uint64 tensor of values
            seed: (2,) uint64 tensor seed

        Returns:
            (M, D) uint64 tensor OKVS storage
        """

    @abstractmethod
    def decode(self, keys: el.Object, storage: el.Object, seed: el.Object) -> el.Object:
        """Decode items from OKVS storage.

        Args:
            keys: (N,) uint64 tensor of keys to query
            storage: (M, D) uint64 tensor OKVS storage
            seed: (2,) uint64 tensor seed

        Returns:
            (N, D) uint64 tensor of recovered values
        """
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Sparse OKVS (Oblivious Key-Value Store) Implementation.
|
|
16
|
+
|
|
17
|
+
This module provides the core data structures and algorithms for Sparse OKVS,
|
|
18
|
+
which is a critical component in unbalanced Private Set Intersection (PSI).
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import mplang.v2.edsl as el
|
|
22
|
+
from mplang.v2.dialects import field
|
|
23
|
+
from mplang.v2.libs.mpc.psi.okvs import OKVS
|
|
24
|
+
|
|
25
|
+
# ============================================================================
|
|
26
|
+
# Constants
|
|
27
|
+
# ============================================================================
|
|
28
|
+
|
|
29
|
+
# Number of hash functions for Cuckoo hashing
|
|
30
|
+
NUM_HASHES = 3
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_okvs_expansion(n: int) -> float:
    """Return the OKVS table-size multiplier M/N for a dataset of ``n`` pairs.

    The 3-hash Garbled Cuckoo Table algorithm requires table size M > N for
    the peeling algorithm to successfully solve the system. Writing
    M = (1+ε)*N, the asymptotic minimum is ε ≈ 0.23 (M = 1.23N); finite N
    needs a larger margin because of variance in random hash collisions.

    The value returned here is the full multiplier M/N = 1 + ε, NOT ε itself
    (e.g. 1.4 means M = 1.4N, i.e. ε = 0.4).
    NOTE(review): earlier documentation described the return value as ε, but
    the returned numbers have always been the M/N figures; confirm callers
    multiply N by this value directly.

    Tiers (bounds exactly as implemented):
      - n < 1,000:               5.5   (ε = 4.5) - very small sets need an
                                        extra wide margin for stability
      - 1,000 ≤ n ≤ 10,000:      1.4   (ε = 0.4)
      - 10,000 < n ≤ 100,000:    1.3   (ε = 0.3)
      - n > 100,000:             1.35  (ε = 0.35) - slightly larger than the
                                        tier below because Mega-Binning
                                        requires ~1.35 for stability with
                                        1024 bins

    Args:
        n: Number of key-value pairs to encode

    Returns:
        Multiplier f = M/N such that a table of size M = f*N is considered
        safe for peeling
    """
    # Guard clauses, smallest tier first; every bound is inclusive on the
    # upper side except the first one.
    if n < 1000:
        return 5.5  # Small scale: need very wide safety margin for stability
    if n <= 10000:
        return 1.4  # Medium scale
    if n <= 100000:
        return 1.3  # Large scale
    # Mega-Binning requires ~1.35 for stability with 1024 bins
    return 1.35
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class SparseOKVS(OKVS):
    """Sparse OKVS Implementation using 3-Hash Garbled Cuckoo Table.

    Both operations delegate to the `field` dialect: encode() calls
    field.solve_okvs and decode() calls field.decode_okvs.
    """

    def __init__(self, m: int) -> None:
        # m is stored once and forwarded to field.solve_okvs by encode().
        # NOTE(review): presumably this is the OKVS table size M (number of
        # storage rows) - confirm against field.solve_okvs.
        self.m = m

    def encode(self, keys: el.Object, values: el.Object, seed: el.Object) -> el.Object:
        """Encode items into OKVS storage using C++ Kernel."""
        return field.solve_okvs(keys, values, self.m, seed)

    def decode(self, keys: el.Object, storage: el.Object, seed: el.Object) -> el.Object:
        """Decode items from OKVS storage using C++ Kernel."""
        # self.m is not passed here; only keys, storage and seed are forwarded.
        return field.decode_okvs(keys, storage, seed)
|