mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/libs/mpc/psi/rr22.py +303 -0
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang/v2/libs/mpc/psi/rr22.py +0 -344
- mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
|
@@ -28,10 +28,10 @@ from typing import Any, cast
|
|
|
28
28
|
|
|
29
29
|
from jax.tree_util import tree_flatten, tree_map
|
|
30
30
|
|
|
31
|
-
from mplang.
|
|
32
|
-
from mplang.
|
|
33
|
-
from mplang.
|
|
34
|
-
from mplang.
|
|
31
|
+
from mplang.backends import load_builtins
|
|
32
|
+
from mplang.dialects import crypto, simp, spu, tee
|
|
33
|
+
from mplang.edsl.object import Object
|
|
34
|
+
from mplang.libs.device.cluster import Device
|
|
35
35
|
|
|
36
36
|
load_builtins()
|
|
37
37
|
|
|
@@ -43,7 +43,7 @@ def _resolve_cluster() -> Any:
|
|
|
43
43
|
Interpreter with a _cluster_spec attribute. This allows nested contexts
|
|
44
44
|
to override the cluster if needed.
|
|
45
45
|
"""
|
|
46
|
-
from mplang.
|
|
46
|
+
from mplang.edsl.context import find_context
|
|
47
47
|
|
|
48
48
|
ctx = find_context(lambda c: getattr(c, "_cluster_spec", None) is not None)
|
|
49
49
|
if ctx is not None:
|
|
@@ -356,7 +356,7 @@ class DeviceContext:
|
|
|
356
356
|
if self._is_spu_device():
|
|
357
357
|
return self(fn)
|
|
358
358
|
# PPU/TEE need tensor.jax_fn to compile JAX code
|
|
359
|
-
from mplang.
|
|
359
|
+
from mplang.dialects.tensor import jax_fn
|
|
360
360
|
|
|
361
361
|
return self(jax_fn(fn))
|
|
362
362
|
|
|
@@ -443,7 +443,7 @@ def _ensure_tee_session(
|
|
|
443
443
|
Returns:
|
|
444
444
|
Tuple of (sess_frm, sess_tee) where each is a symmetric key Object
|
|
445
445
|
"""
|
|
446
|
-
import mplang.
|
|
446
|
+
import mplang.edsl as el
|
|
447
447
|
|
|
448
448
|
# Get current context ID for cache isolation
|
|
449
449
|
current_ctx = el.get_current_context()
|
|
@@ -749,11 +749,11 @@ def fetch(obj: Object) -> Any:
|
|
|
749
749
|
Returns:
|
|
750
750
|
Python value (numpy array, scalar, etc.)
|
|
751
751
|
"""
|
|
752
|
-
from mplang.
|
|
753
|
-
from mplang.
|
|
754
|
-
from mplang.
|
|
755
|
-
from mplang.
|
|
756
|
-
from mplang.
|
|
752
|
+
from mplang.backends.simp_driver.state import SimpDriver
|
|
753
|
+
from mplang.backends.simp_driver.values import DriverVar
|
|
754
|
+
from mplang.edsl.context import get_current_context
|
|
755
|
+
from mplang.runtime.interpreter import InterpObject, Interpreter
|
|
756
|
+
from mplang.runtime.value import WrapValue
|
|
757
757
|
|
|
758
758
|
def _unwrap_value(val: Any) -> Any:
|
|
759
759
|
"""Unwrap WrapValue to get the underlying data."""
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
# mypy: disable-error-code="no-untyped-def,no-any-return,var-annotated"
|
|
16
16
|
|
|
17
|
-
"""SecureBoost v2: Optimized implementation using mplang
|
|
17
|
+
"""SecureBoost v2: Optimized implementation using mplang low-level BFV APIs.
|
|
18
18
|
|
|
19
19
|
This implementation improves upon v1 by leveraging BFV SIMD slots and the
|
|
20
20
|
groupby primitives for efficient histogram computation.
|
|
@@ -46,8 +46,8 @@ import jax.numpy as jnp
|
|
|
46
46
|
import numpy as np
|
|
47
47
|
from jax.ops import segment_sum
|
|
48
48
|
|
|
49
|
-
from mplang.
|
|
50
|
-
from mplang.
|
|
49
|
+
from mplang.dialects import bfv, simp, tensor
|
|
50
|
+
from mplang.libs.mpc.analytics import aggregation
|
|
51
51
|
|
|
52
52
|
# ==============================================================================
|
|
53
53
|
# Configuration
|
|
@@ -1687,7 +1687,7 @@ def fit_tree_ensemble(
|
|
|
1687
1687
|
|
|
1688
1688
|
|
|
1689
1689
|
class SecureBoost:
|
|
1690
|
-
"""SecureBoost classifier using mplang
|
|
1690
|
+
"""SecureBoost classifier using mplang low-level BFV APIs.
|
|
1691
1691
|
|
|
1692
1692
|
This is an optimized implementation that uses BFV SIMD slots for
|
|
1693
1693
|
efficient histogram computation.
|
|
@@ -21,9 +21,9 @@ Subpackages:
|
|
|
21
21
|
- analytics: Privacy-preserving analytics
|
|
22
22
|
|
|
23
23
|
Example usage:
|
|
24
|
-
from mplang.
|
|
25
|
-
from mplang.
|
|
26
|
-
from mplang.
|
|
24
|
+
from mplang.libs.mpc import ot_transfer, apply_permutation
|
|
25
|
+
from mplang.libs.mpc.vole import silver_vole
|
|
26
|
+
from mplang.libs.mpc.psi import psi_intersect
|
|
27
27
|
"""
|
|
28
28
|
|
|
29
29
|
from .analytics.aggregation import rotate_and_sum
|
|
@@ -20,8 +20,8 @@ from typing import Any, cast
|
|
|
20
20
|
|
|
21
21
|
import jax.numpy as jnp
|
|
22
22
|
|
|
23
|
-
import mplang.
|
|
24
|
-
from mplang.
|
|
23
|
+
import mplang.edsl as el
|
|
24
|
+
from mplang.dialects import tensor
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
def bytes_to_bits(data: el.Object) -> el.Object:
|
|
@@ -27,8 +27,8 @@ from typing import Any
|
|
|
27
27
|
import jax
|
|
28
28
|
import jax.numpy as jnp
|
|
29
29
|
|
|
30
|
-
from mplang.
|
|
31
|
-
from mplang.
|
|
30
|
+
from mplang.dialects import bfv, crypto, simp, tensor
|
|
31
|
+
from mplang.libs.mpc.analytics import aggregation, permutation
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
def oblivious_groupby_sum_bfv(
|
|
@@ -33,9 +33,9 @@ import jax
|
|
|
33
33
|
import jax.numpy as jnp
|
|
34
34
|
import numpy as np
|
|
35
35
|
|
|
36
|
-
import mplang.
|
|
37
|
-
from mplang.
|
|
38
|
-
from mplang.
|
|
36
|
+
import mplang.edsl.typing as elt
|
|
37
|
+
from mplang.dialects import simp, tensor
|
|
38
|
+
from mplang.libs.mpc.ot import base as ot
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
def secure_switch(
|
|
@@ -27,9 +27,9 @@ from typing import Any, cast
|
|
|
27
27
|
|
|
28
28
|
import numpy as np
|
|
29
29
|
|
|
30
|
-
import mplang.
|
|
31
|
-
import mplang.
|
|
32
|
-
from mplang.
|
|
30
|
+
import mplang.edsl as el
|
|
31
|
+
import mplang.edsl.typing as elt
|
|
32
|
+
from mplang.dialects import crypto, simp, tensor
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
def _receiver_keygen_scalar(
|
|
@@ -24,8 +24,8 @@ from typing import Any, cast
|
|
|
24
24
|
|
|
25
25
|
import jax.numpy as jnp
|
|
26
26
|
|
|
27
|
-
import mplang.
|
|
28
|
-
from mplang.
|
|
27
|
+
import mplang.edsl as el
|
|
28
|
+
from mplang.dialects import crypto, field, simp, tensor
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
def prg_expand(seed_tensor: el.Object, length: int) -> el.Object:
|
|
@@ -22,10 +22,10 @@ from typing import Any, cast
|
|
|
22
22
|
|
|
23
23
|
import jax.numpy as jnp
|
|
24
24
|
|
|
25
|
-
import mplang.
|
|
26
|
-
import mplang.
|
|
27
|
-
import mplang.
|
|
28
|
-
from mplang.
|
|
25
|
+
import mplang.edsl as el
|
|
26
|
+
import mplang.edsl.typing as elt
|
|
27
|
+
import mplang.libs.mpc.vole.gilboa as vole
|
|
28
|
+
from mplang.dialects import crypto, field, simp, tensor
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
def silent_vole_random_u(
|
|
@@ -26,9 +26,9 @@ from typing import Any, cast
|
|
|
26
26
|
|
|
27
27
|
import jax.numpy as jnp
|
|
28
28
|
|
|
29
|
-
import mplang.
|
|
30
|
-
from mplang.
|
|
31
|
-
from mplang.
|
|
29
|
+
import mplang.edsl as el
|
|
30
|
+
from mplang.dialects import tensor
|
|
31
|
+
from mplang.libs.mpc.common.constants import (
|
|
32
32
|
E_FRAC_1,
|
|
33
33
|
GOLDEN_RATIO_64,
|
|
34
34
|
PI_FRAC_1,
|
|
@@ -18,9 +18,9 @@ This module provides the core data structures and algorithms for Sparse OKVS,
|
|
|
18
18
|
which is a critical component in unbalanced Private Set Intersection (PSI).
|
|
19
19
|
"""
|
|
20
20
|
|
|
21
|
-
import mplang.
|
|
22
|
-
from mplang.
|
|
23
|
-
from mplang.
|
|
21
|
+
import mplang.edsl as el
|
|
22
|
+
from mplang.dialects import field
|
|
23
|
+
from mplang.libs.mpc.psi.okvs import OKVS
|
|
24
24
|
|
|
25
25
|
# ============================================================================
|
|
26
26
|
# Constants
|
|
@@ -40,12 +40,16 @@ def get_okvs_expansion(n: int) -> float:
|
|
|
40
40
|
- For N → ∞: Theoretical minimum is ε ≈ 0.23 (M = 1.23N)
|
|
41
41
|
- For finite N: Larger ε needed due to variance in random hash collisions
|
|
42
42
|
|
|
43
|
-
Empirical safe thresholds (failure probability < 0.
|
|
44
|
-
- N
|
|
45
|
-
|
|
46
|
-
- N < 10,000: ε = 0.
|
|
47
|
-
- N < 100,000: ε = 0.
|
|
48
|
-
- N ≥ 100,000: ε = 0.35 (M = 1.35N)
|
|
43
|
+
Empirical safe thresholds (failure probability < 0.001%):
|
|
44
|
+
- N ≤ 200: ε = 24.0 (M = 25.0N) - extremely small sets need very wide margin
|
|
45
|
+
- N < 1,000: ε = 11.0 (M = 12.0N) - small sets need extra wide safety margin
|
|
46
|
+
- N ≤ 10,000: ε = 0.6 (M = 1.6N)
|
|
47
|
+
- N ≤ 100,000: ε = 0.4 (M = 1.4N)
|
|
48
|
+
- N > 100,000: ε = 0.35 (M = 1.35N) - large sets converge near theory
|
|
49
|
+
|
|
50
|
+
Note: These expansion factors account for the 128-byte alignment requirement
|
|
51
|
+
in the OKVS implementation. The factors are intentionally conservative to
|
|
52
|
+
ensure high success rates (>99.9%) for the probabilistic peeling algorithm.
|
|
49
53
|
|
|
50
54
|
Args:
|
|
51
55
|
n: Number of key-value pairs to encode
|
|
@@ -53,12 +57,14 @@ def get_okvs_expansion(n: int) -> float:
|
|
|
53
57
|
Returns:
|
|
54
58
|
Expansion multiplier (1+ε), i.e. the value k such that M = k*N = (1+ε)*N is safe for peeling
|
|
55
59
|
"""
|
|
56
|
-
if n
|
|
57
|
-
return
|
|
60
|
+
if n <= 200:
|
|
61
|
+
return 25.0 # Extremely small scale: need very wide margin for stability
|
|
62
|
+
elif n < 1000:
|
|
63
|
+
return 12.0 # Small scale: need wide safety margin for stability
|
|
58
64
|
elif n <= 10000:
|
|
59
|
-
return 1.
|
|
65
|
+
return 1.6 # Medium scale
|
|
60
66
|
elif n <= 100000:
|
|
61
|
-
return 1.
|
|
67
|
+
return 1.4 # Large scale
|
|
62
68
|
else:
|
|
63
69
|
# Mega-Binning requires ~1.35 for stability with 1024 bins
|
|
64
70
|
return 1.35
|
|
@@ -24,9 +24,9 @@ from typing import Any, cast
|
|
|
24
24
|
|
|
25
25
|
import jax.numpy as jnp
|
|
26
26
|
|
|
27
|
-
import mplang.
|
|
28
|
-
from mplang.
|
|
29
|
-
from mplang.
|
|
27
|
+
import mplang.edsl as el
|
|
28
|
+
from mplang.dialects import simp, tensor
|
|
29
|
+
from mplang.libs.mpc.ot import extension as ot_extension
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
def eval_oprf(
|
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Private Set Intersection using VOLE and OKVS (RR22-Style).
|
|
16
|
+
|
|
17
|
+
This module implements a high-performance PSI protocol based on the "Blazing Fast PSI"
|
|
18
|
+
(RR22) paper. The protocol relies on Vector Oblivious Linear Evaluation (VOLE) and
|
|
19
|
+
Oblivious Key-Value Stores (OKVS) to achieve efficient set intersection with linear
|
|
20
|
+
communication O(N) and computation complexity.
|
|
21
|
+
|
|
22
|
+
Protocol Overview:
|
|
23
|
+
The core idea is to mask a "Polynomial" (encoded via OKVS) with VOLE-correlated randomness,
|
|
24
|
+
such that the mask can only be removed (and the polynomial verified) if the parties share
|
|
25
|
+
the same element.
|
|
26
|
+
|
|
27
|
+
Phases:
|
|
28
|
+
1. **Correlated Randomness (VOLE)**:
|
|
29
|
+
Sender and Receiver establish a shared correlation:
|
|
30
|
+
W = V + U * Delta
|
|
31
|
+
- PSI Receiver holds `U` and `V` (these are generated by the OT "sender"
|
|
32
|
+
role in `silent_vole_random_u`).
|
|
33
|
+
- PSI Sender holds `W` and `Delta` (these are generated by the OT "receiver"
|
|
34
|
+
role).
|
|
35
|
+
- `U` is random; `Delta` is a fixed secret scalar held by the Sender.
|
|
36
|
+
|
|
37
|
+
2. **Encoding (OKVS)**:
|
|
38
|
+
The Receiver encodes its input set Y into an OKVS storage `P` such that
|
|
39
|
+
Decode(P, y) = H(y) for all y in Y. The function `H(y)` is implemented via
|
|
40
|
+
AES/Davies–Meyer expansion acting as a random oracle.
|
|
41
|
+
|
|
42
|
+
3. **Masking & Exchange**:
|
|
43
|
+
The Receiver masks the OKVS storage `P` with its VOLE share `U`:
|
|
44
|
+
Q = P + U
|
|
45
|
+
The masked storage `Q` is sent to the Sender (so Sender sees a masked OKVS).
|
|
46
|
+
|
|
47
|
+
4. **Decoding & Tag Generation (Sender)**:
|
|
48
|
+
The Sender holds `W` and `Delta` and computes the linear combination:
|
|
49
|
+
K = Q * Delta + W
|
|
50
|
+
Using W = V + U * Delta and Q = P + U, this simplifies to
|
|
51
|
+
K = P * Delta + V.
|
|
52
|
+
The Sender decodes `K` for each of its items x to obtain P(x)*Delta + V(x),
|
|
53
|
+
then subtracts H(x)*Delta (computed locally) to recover `V(x)`. The value
|
|
54
|
+
`Tag = V(x)` serves as the sender-side tag for item x.
|
|
55
|
+
|
|
56
|
+
5. **Verification (Receiver)**:
|
|
57
|
+
The Sender hashes and truncates tags and sends the truncated hashes to the
|
|
58
|
+
Receiver. The Receiver locally decodes `V(y)` from its OKVS, hashes it with
|
|
59
|
+
the same domain separation and truncation, and compares to the received
|
|
60
|
+
truncated hashes to determine membership in the intersection.
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
from typing import Any, cast
|
|
64
|
+
|
|
65
|
+
import jax.numpy as jnp
|
|
66
|
+
|
|
67
|
+
import mplang.edsl as el
|
|
68
|
+
import mplang.libs.mpc.ot.silent as silent_ot
|
|
69
|
+
from mplang.dialects import field, simp, tensor
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def psi_intersect(
|
|
73
|
+
sender: int,
|
|
74
|
+
receiver: int,
|
|
75
|
+
n: int,
|
|
76
|
+
sender_items: el.Object,
|
|
77
|
+
receiver_items: el.Object,
|
|
78
|
+
) -> el.Object:
|
|
79
|
+
"""Execute OKVS-based PSI Protocol (Original RR22 Logic).
|
|
80
|
+
|
|
81
|
+
This implementation follows the RR22 paper's role assignment where:
|
|
82
|
+
- PSI Sender holds Delta (and W).
|
|
83
|
+
- PSI Receiver holds U and V.
|
|
84
|
+
|
|
85
|
+
This enables the "One Decode" optimization on the Sender side and prevents
|
|
86
|
+
offline brute-force attacks by the Receiver (though Sender could brute-force).
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
sender: Rank of Sender.
|
|
90
|
+
receiver: Rank of Receiver.
|
|
91
|
+
n: Number of items.
|
|
92
|
+
sender_items: Object located at Sender.
|
|
93
|
+
receiver_items: Object located at Receiver.
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
Intersection mask (0/1) located at Receiver.
|
|
97
|
+
"""
|
|
98
|
+
|
|
99
|
+
# Validation
|
|
100
|
+
if sender == receiver:
|
|
101
|
+
raise ValueError(
|
|
102
|
+
f"Sender ({sender}) and Receiver ({receiver}) must be different."
|
|
103
|
+
)
|
|
104
|
+
if n <= 0:
|
|
105
|
+
raise ValueError(f"Input size n must be positive, got {n}.")
|
|
106
|
+
|
|
107
|
+
# =========================================================================
|
|
108
|
+
# Phase 1. Parameter Setup
|
|
109
|
+
# =========================================================================
|
|
110
|
+
import mplang.libs.mpc.psi.okvs_gct as okvs_gct
|
|
111
|
+
|
|
112
|
+
expansion = okvs_gct.get_okvs_expansion(n)
|
|
113
|
+
M = int(n * expansion)
|
|
114
|
+
if M % 128 != 0:
|
|
115
|
+
M = ((M // 128) + 1) * 128
|
|
116
|
+
|
|
117
|
+
# =========================================================================
|
|
118
|
+
# Phase 2. Correlated Randomness (VOLE)
|
|
119
|
+
# =========================================================================
|
|
120
|
+
# In the original paper logic (Fig 4), the PSI Sender holds Delta.
|
|
121
|
+
# Therefore, we swap the roles in the OT call.
|
|
122
|
+
#
|
|
123
|
+
# silent_vole_random_u(A, B) gives:
|
|
124
|
+
# A (OT Sender): U, V
|
|
125
|
+
# B (OT Receiver): W, Delta
|
|
126
|
+
#
|
|
127
|
+
# We want PSI Sender to be OT Receiver.
|
|
128
|
+
res_tuple = silent_ot.silent_vole_random_u(receiver, sender, M, base_k=1024)
|
|
129
|
+
|
|
130
|
+
# PSI Receiver gets U, V
|
|
131
|
+
v_recv, w_sender, u_recv, delta_sender = res_tuple[:4]
|
|
132
|
+
|
|
133
|
+
# =========================================================================
|
|
134
|
+
# Phase 3. Receiver Encoding & Masking
|
|
135
|
+
# =========================================================================
|
|
136
|
+
# Receiver computes P such that P(y) = H(y).
|
|
137
|
+
# Receiver masks P with U (Paper's A vector).
|
|
138
|
+
# Q = P + U (field addition, computed with field.add)
|
|
139
|
+
|
|
140
|
+
from mplang.dialects import crypto
|
|
141
|
+
from mplang.edsl import typing as elt
|
|
142
|
+
|
|
143
|
+
def _gen_seed() -> Any:
|
|
144
|
+
return crypto.random_tensor((2,), elt.u64)
|
|
145
|
+
|
|
146
|
+
okvs_seed = simp.pcall_static((receiver,), _gen_seed)
|
|
147
|
+
okvs_seed_sender = simp.shuffle_static(okvs_seed, {sender: receiver})
|
|
148
|
+
|
|
149
|
+
okvs = okvs_gct.SparseOKVS(M)
|
|
150
|
+
|
|
151
|
+
def _recv_ops(y: Any, u: Any, seed: Any) -> Any:
|
|
152
|
+
# 3.1 Compute H(y)
|
|
153
|
+
def _reshape_seeds(items: Any) -> Any:
|
|
154
|
+
lo = items
|
|
155
|
+
hi = jnp.zeros_like(items)
|
|
156
|
+
return jnp.stack([lo, hi], axis=1) # (N, 2)
|
|
157
|
+
|
|
158
|
+
seeds = tensor.run_jax(_reshape_seeds, y)
|
|
159
|
+
res_exp = field.aes_expand(seeds, 1) # (N, 1, 2)
|
|
160
|
+
|
|
161
|
+
def _davies_meyer(enc: Any, s: Any) -> Any:
|
|
162
|
+
enc_flat = enc.reshape(enc.shape[0], 2)
|
|
163
|
+
return jnp.bitwise_xor(enc_flat, s)
|
|
164
|
+
|
|
165
|
+
h_y = tensor.run_jax(_davies_meyer, res_exp, seeds)
|
|
166
|
+
|
|
167
|
+
# 3.2 Encode P
|
|
168
|
+
p_storage = okvs.encode(y, h_y, seed)
|
|
169
|
+
|
|
170
|
+
# 3.3 Mask with U (instead of W)
|
|
171
|
+
# Q = P + U (field addition)
|
|
172
|
+
q_storage = field.add(p_storage, u)
|
|
173
|
+
|
|
174
|
+
return q_storage
|
|
175
|
+
|
|
176
|
+
# Receiver uses U to mask
|
|
177
|
+
q_shared = simp.pcall_static(
|
|
178
|
+
(receiver,), _recv_ops, receiver_items, u_recv, okvs_seed
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
q_sender_view = simp.shuffle_static(q_shared, {sender: receiver})
|
|
182
|
+
|
|
183
|
+
# =========================================================================
|
|
184
|
+
# Phase 4. Sender "One Decode" & Tag Generation
|
|
185
|
+
# =========================================================================
|
|
186
|
+
# Sender holds W, Delta. Receives Q.
|
|
187
|
+
# W = V + U * Delta
|
|
188
|
+
#
|
|
189
|
+
# Derivation:
|
|
190
|
+
# K = Q * Delta + W
|
|
191
|
+
# = (P + U) * Delta + (V + U * Delta)
|
|
192
|
+
# = P * Delta + U * Delta + V + U * Delta
|
|
193
|
+
# = P * Delta + V
|
|
194
|
+
#
|
|
195
|
+
# Sender computes Tag = Decode(K, x) - H(x) * Delta
|
|
196
|
+
# If x in Intersection: P(x) = H(x)
|
|
197
|
+
# Tag = (P(x) * Delta + V(x)) - P(x) * Delta
|
|
198
|
+
# Tag = V(x)
|
|
199
|
+
|
|
200
|
+
def _sender_ops(x: Any, q: Any, w: Any, delta: Any, seed: Any) -> Any:
|
|
201
|
+
# q, w: (M, 2)
|
|
202
|
+
# delta: (2,)
|
|
203
|
+
|
|
204
|
+
# Safe tiling assuming M is aligned
|
|
205
|
+
def _tile_m_simple(d: Any) -> Any:
|
|
206
|
+
return jnp.tile(d, (M, 1))
|
|
207
|
+
|
|
208
|
+
delta_expanded_m = tensor.run_jax(_tile_m_simple, delta)
|
|
209
|
+
|
|
210
|
+
# 4.2. Compute Global K = Q * Delta + W
|
|
211
|
+
# This is the O(M) multiplication mentioned in the paper
|
|
212
|
+
q_times_delta = field.mul(q, delta_expanded_m)
|
|
213
|
+
k_storage = field.add(q_times_delta, w)
|
|
214
|
+
|
|
215
|
+
# 4.3 One Decode
|
|
216
|
+
# decoded_val = P(x)*Delta + V(x)
|
|
217
|
+
decoded_k = okvs.decode(x, k_storage, seed)
|
|
218
|
+
|
|
219
|
+
# 4.4 Remove H(x)*Delta
|
|
220
|
+
def _reshape_seeds(items: Any) -> Any:
|
|
221
|
+
lo = items
|
|
222
|
+
hi = jnp.zeros_like(items)
|
|
223
|
+
return jnp.stack([lo, hi], axis=1)
|
|
224
|
+
|
|
225
|
+
seeds_x = tensor.run_jax(_reshape_seeds, x)
|
|
226
|
+
res_exp_x = field.aes_expand(seeds_x, 1)
|
|
227
|
+
|
|
228
|
+
def _davies_meyer(enc: Any, s: Any) -> Any:
|
|
229
|
+
enc_flat = enc.reshape(enc.shape[0], 2)
|
|
230
|
+
return jnp.bitwise_xor(enc_flat, s)
|
|
231
|
+
|
|
232
|
+
h_x = tensor.run_jax(_davies_meyer, res_exp_x, seeds_x)
|
|
233
|
+
|
|
234
|
+
# Expand delta for batch N
|
|
235
|
+
def _tile_n(d: Any) -> Any:
|
|
236
|
+
return jnp.tile(d, (n, 1))
|
|
237
|
+
|
|
238
|
+
delta_expanded_n = tensor.run_jax(_tile_n, delta)
|
|
239
|
+
|
|
240
|
+
h_x_times_delta = field.mul(h_x, delta_expanded_n)
|
|
241
|
+
|
|
242
|
+
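# Final Tag = (P(x)*Delta + V(x)) + H(x)*Delta, which equals V(x) when P(x) = H(x) (adding H*Delta removes it, as in the U*Delta cancellation above)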
|
|
243
|
+
tag = field.add(decoded_k, h_x_times_delta)
|
|
244
|
+
|
|
245
|
+
return tag
|
|
246
|
+
|
|
247
|
+
# Execute on Sender
|
|
248
|
+
sender_tags = simp.pcall_static(
|
|
249
|
+
(sender,),
|
|
250
|
+
_sender_ops,
|
|
251
|
+
sender_items,
|
|
252
|
+
q_sender_view,
|
|
253
|
+
w_sender,
|
|
254
|
+
delta_sender,
|
|
255
|
+
okvs_seed_sender,
|
|
256
|
+
)
|
|
257
|
+
# =========================================================================
|
|
258
|
+
# Phase 5. Verification (Receiver Side)
|
|
259
|
+
# =========================================================================
|
|
260
|
+
# Sender sends Tags (which should be V(x)) to Receiver. To reduce
|
|
261
|
+
# communication we hash and truncate on the sender side and only send
|
|
262
|
+
# the truncated hash (first 16 bytes).
|
|
263
|
+
|
|
264
|
+
# 5.1 Compute hashed & truncated tags on Sender
|
|
265
|
+
from mplang.libs.mpc.ot import extension as ot_extension
|
|
266
|
+
|
|
267
|
+
def _hash_and_trunc(tags: Any) -> Any:
|
|
268
|
+
# Compute batched hash on sender and truncate to 16 bytes
|
|
269
|
+
full_h = ot_extension.vec_hash(tags, domain_sep=0x1111, num_rows=n)
|
|
270
|
+
# Use tensor.slice_tensor to slice TraceObjects (start=(0,0), end=(n,16))
|
|
271
|
+
return tensor.slice_tensor(full_h, (0, 0), (n, 16))
|
|
272
|
+
|
|
273
|
+
h_sender_trunc = simp.pcall_static((sender,), _hash_and_trunc, sender_tags)
|
|
274
|
+
|
|
275
|
+
# 5.2 Send truncated hashes to Receiver (much smaller payload)
|
|
276
|
+
tags_at_recv = simp.shuffle_static(h_sender_trunc, {receiver: sender})
|
|
277
|
+
|
|
278
|
+
# 5.3 Receiver computes local V(y) and compares
|
|
279
|
+
def _recv_verify(y: Any, v: Any, seed: Any, remote_tags: Any) -> Any:
|
|
280
|
+
# 1. Decode V locally: target = V(y)
|
|
281
|
+
local_v_y = okvs.decode(y, v, seed)
|
|
282
|
+
|
|
283
|
+
# 2. Hash local V(y) and compare with received truncated sender hashes
|
|
284
|
+
# Note: `remote_tags` here is already the truncated hash (16 bytes)
|
|
285
|
+
# sent from the Sender.
|
|
286
|
+
h_local = ot_extension.vec_hash(local_v_y, domain_sep=0x1111, num_rows=n)
|
|
287
|
+
|
|
288
|
+
def _core(h_r16: Any, h_l_full: Any) -> Any:
|
|
289
|
+
# h_r16: (n, 16) truncated bytes from sender
|
|
290
|
+
# h_l_full: (n, k) full hash bytes; truncate to 16
|
|
291
|
+
h_l16 = h_l_full[:, :16]
|
|
292
|
+
|
|
293
|
+
eq_matrix = jnp.all(h_r16[:, None, :] == h_l16[None, :, :], axis=2)
|
|
294
|
+
membership = jnp.any(eq_matrix, axis=0)
|
|
295
|
+
return membership.astype(jnp.uint8)
|
|
296
|
+
|
|
297
|
+
return tensor.run_jax(_core, remote_tags, h_local)
|
|
298
|
+
|
|
299
|
+
intersection_mask = simp.pcall_static(
|
|
300
|
+
(receiver,), _recv_verify, receiver_items, v_recv, okvs_seed, tags_at_recv
|
|
301
|
+
)
|
|
302
|
+
|
|
303
|
+
return cast(el.Object, intersection_mask)
|
|
@@ -36,10 +36,10 @@ from typing import Any, cast
|
|
|
36
36
|
|
|
37
37
|
import jax.numpy as jnp
|
|
38
38
|
|
|
39
|
-
import mplang.
|
|
40
|
-
import mplang.
|
|
41
|
-
from mplang.
|
|
42
|
-
from mplang.
|
|
39
|
+
import mplang.edsl as el
|
|
40
|
+
import mplang.edsl.typing as elt
|
|
41
|
+
from mplang.dialects import crypto, field, simp, tensor
|
|
42
|
+
from mplang.libs.mpc.psi.okvs_gct import get_okvs_expansion
|
|
43
43
|
|
|
44
44
|
|
|
45
45
|
def psi_unbalanced(
|
|
@@ -24,9 +24,9 @@ from typing import Any, cast
|
|
|
24
24
|
import jax.numpy as jnp
|
|
25
25
|
import numpy as np
|
|
26
26
|
|
|
27
|
-
import mplang.
|
|
28
|
-
import mplang.
|
|
29
|
-
from mplang.
|
|
27
|
+
import mplang.edsl as el
|
|
28
|
+
import mplang.libs.mpc.ot.extension as ot
|
|
29
|
+
from mplang.dialects import field, simp, tensor
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
def vole(
|
|
@@ -32,8 +32,8 @@ import jax.numpy as jnp
|
|
|
32
32
|
import numpy as np
|
|
33
33
|
import scipy.sparse as sp
|
|
34
34
|
|
|
35
|
-
import mplang.
|
|
36
|
-
from mplang.
|
|
35
|
+
import mplang.edsl as el
|
|
36
|
+
from mplang.dialects import crypto, field, tensor
|
|
37
37
|
|
|
38
38
|
# ============================================================================
|
|
39
39
|
# Constants
|
|
@@ -43,11 +43,11 @@ from typing import Any, cast
|
|
|
43
43
|
|
|
44
44
|
import jax.numpy as jnp
|
|
45
45
|
|
|
46
|
-
import mplang.
|
|
47
|
-
import mplang.
|
|
48
|
-
import mplang.
|
|
49
|
-
from mplang.
|
|
50
|
-
from mplang.
|
|
46
|
+
import mplang.edsl as el
|
|
47
|
+
import mplang.edsl.typing as elt
|
|
48
|
+
import mplang.libs.mpc.ot.extension as ot
|
|
49
|
+
from mplang.dialects import crypto, field, simp, tensor
|
|
50
|
+
from mplang.libs.mpc.vole import ldpc
|
|
51
51
|
|
|
52
52
|
# ============================================================================
|
|
53
53
|
# Constants
|
|
@@ -148,7 +148,7 @@ def silver_vole(
|
|
|
148
148
|
H_indices_r, H_indptr_r = simp.pcall_static((receiver,), _sparse_struct_provider)
|
|
149
149
|
|
|
150
150
|
# 2. Base VOLE (Size K)
|
|
151
|
-
from mplang.
|
|
151
|
+
from mplang.libs.mpc.vole import gilboa
|
|
152
152
|
|
|
153
153
|
def _u_base_provider() -> el.Object:
|
|
154
154
|
# Generate random u_base using new API
|