mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/libs/mpc/psi/rr22.py +303 -0
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang/v2/libs/mpc/psi/rr22.py +0 -344
  162. mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
  163. /mplang/{v2/backends → backends}/channel.py +0 -0
  164. /mplang/{v2/edsl → edsl}/README.md +0 -0
  165. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  166. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  167. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  168. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  169. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  171. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  172. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  175. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  177. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  178. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  179. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  180. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  181. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
@@ -28,10 +28,10 @@ from typing import Any, cast
28
28
 
29
29
  from jax.tree_util import tree_flatten, tree_map
30
30
 
31
- from mplang.v2.backends import load_builtins
32
- from mplang.v2.dialects import crypto, simp, spu, tee
33
- from mplang.v2.edsl.object import Object
34
- from mplang.v2.libs.device.cluster import Device
31
+ from mplang.backends import load_builtins
32
+ from mplang.dialects import crypto, simp, spu, tee
33
+ from mplang.edsl.object import Object
34
+ from mplang.libs.device.cluster import Device
35
35
 
36
36
  load_builtins()
37
37
 
@@ -43,7 +43,7 @@ def _resolve_cluster() -> Any:
43
43
  Interpreter with a _cluster_spec attribute. This allows nested contexts
44
44
  to override the cluster if needed.
45
45
  """
46
- from mplang.v2.edsl.context import find_context
46
+ from mplang.edsl.context import find_context
47
47
 
48
48
  ctx = find_context(lambda c: getattr(c, "_cluster_spec", None) is not None)
49
49
  if ctx is not None:
@@ -356,7 +356,7 @@ class DeviceContext:
356
356
  if self._is_spu_device():
357
357
  return self(fn)
358
358
  # PPU/TEE need tensor.jax_fn to compile JAX code
359
- from mplang.v2.dialects.tensor import jax_fn
359
+ from mplang.dialects.tensor import jax_fn
360
360
 
361
361
  return self(jax_fn(fn))
362
362
 
@@ -443,7 +443,7 @@ def _ensure_tee_session(
443
443
  Returns:
444
444
  Tuple of (sess_frm, sess_tee) where each is a symmetric key Object
445
445
  """
446
- import mplang.v2.edsl as el
446
+ import mplang.edsl as el
447
447
 
448
448
  # Get current context ID for cache isolation
449
449
  current_ctx = el.get_current_context()
@@ -749,11 +749,11 @@ def fetch(obj: Object) -> Any:
749
749
  Returns:
750
750
  Python value (numpy array, scalar, etc.)
751
751
  """
752
- from mplang.v2.backends.simp_driver.state import SimpDriver
753
- from mplang.v2.backends.simp_driver.values import DriverVar
754
- from mplang.v2.edsl.context import get_current_context
755
- from mplang.v2.runtime.interpreter import InterpObject, Interpreter
756
- from mplang.v2.runtime.value import WrapValue
752
+ from mplang.backends.simp_driver.state import SimpDriver
753
+ from mplang.backends.simp_driver.values import DriverVar
754
+ from mplang.edsl.context import get_current_context
755
+ from mplang.runtime.interpreter import InterpObject, Interpreter
756
+ from mplang.runtime.value import WrapValue
757
757
 
758
758
  def _unwrap_value(val: Any) -> Any:
759
759
  """Unwrap WrapValue to get the underlying data."""
@@ -14,7 +14,7 @@
14
14
 
15
15
  """Machine Learning algorithms for secure multi-party computation."""
16
16
 
17
- from mplang.v2.libs.ml.sgb import SecureBoost, Tree, TreeEnsemble
17
+ from mplang.libs.ml.sgb import SecureBoost, Tree, TreeEnsemble
18
18
 
19
19
  __all__ = [
20
20
  "SecureBoost",
@@ -14,7 +14,7 @@
14
14
 
15
15
  # mypy: disable-error-code="no-untyped-def,no-any-return,var-annotated"
16
16
 
17
- """SecureBoost v2: Optimized implementation using mplang.v2 low-level BFV APIs.
17
+ """SecureBoost v2: Optimized implementation using mplang low-level BFV APIs.
18
18
 
19
19
  This implementation improves upon v1 by leveraging BFV SIMD slots and the
20
20
  groupby primitives for efficient histogram computation.
@@ -46,8 +46,8 @@ import jax.numpy as jnp
46
46
  import numpy as np
47
47
  from jax.ops import segment_sum
48
48
 
49
- from mplang.v2.dialects import bfv, simp, tensor
50
- from mplang.v2.libs.mpc.analytics import aggregation
49
+ from mplang.dialects import bfv, simp, tensor
50
+ from mplang.libs.mpc.analytics import aggregation
51
51
 
52
52
  # ==============================================================================
53
53
  # Configuration
@@ -1687,7 +1687,7 @@ def fit_tree_ensemble(
1687
1687
 
1688
1688
 
1689
1689
  class SecureBoost:
1690
- """SecureBoost classifier using mplang.v2 low-level BFV APIs.
1690
+ """SecureBoost classifier using mplang low-level BFV APIs.
1691
1691
 
1692
1692
  This is an optimized implementation that uses BFV SIMD slots for
1693
1693
  efficient histogram computation.
@@ -21,9 +21,9 @@ Subpackages:
21
21
  - analytics: Privacy-preserving analytics
22
22
 
23
23
  Example usage:
24
- from mplang.v2.libs.mpc import ot_transfer, apply_permutation
25
- from mplang.v2.libs.mpc.vole import silver_vole
26
- from mplang.v2.libs.mpc.psi import psi_intersect
24
+ from mplang.libs.mpc import ot_transfer, apply_permutation
25
+ from mplang.libs.mpc.vole import silver_vole
26
+ from mplang.libs.mpc.psi import psi_intersect
27
27
  """
28
28
 
29
29
  from .analytics.aggregation import rotate_and_sum
@@ -20,8 +20,8 @@ from typing import Any, cast
20
20
 
21
21
  import jax.numpy as jnp
22
22
 
23
- import mplang.v2.edsl as el
24
- from mplang.v2.dialects import tensor
23
+ import mplang.edsl as el
24
+ from mplang.dialects import tensor
25
25
 
26
26
 
27
27
  def bytes_to_bits(data: el.Object) -> el.Object:
@@ -24,7 +24,7 @@ from typing import Any
24
24
 
25
25
  import numpy as np
26
26
 
27
- from mplang.v2.dialects import bfv, tensor
27
+ from mplang.dialects import bfv, tensor
28
28
 
29
29
 
30
30
  def _safe_rotate(
@@ -27,8 +27,8 @@ from typing import Any
27
27
  import jax
28
28
  import jax.numpy as jnp
29
29
 
30
- from mplang.v2.dialects import bfv, crypto, simp, tensor
31
- from mplang.v2.libs.mpc.analytics import aggregation, permutation
30
+ from mplang.dialects import bfv, crypto, simp, tensor
31
+ from mplang.libs.mpc.analytics import aggregation, permutation
32
32
 
33
33
 
34
34
  def oblivious_groupby_sum_bfv(
@@ -33,9 +33,9 @@ import jax
33
33
  import jax.numpy as jnp
34
34
  import numpy as np
35
35
 
36
- import mplang.v2.edsl.typing as elt
37
- from mplang.v2.dialects import simp, tensor
38
- from mplang.v2.libs.mpc.ot import base as ot
36
+ import mplang.edsl.typing as elt
37
+ from mplang.dialects import simp, tensor
38
+ from mplang.libs.mpc.ot import base as ot
39
39
 
40
40
 
41
41
  def secure_switch(
@@ -27,9 +27,9 @@ from typing import Any, cast
27
27
 
28
28
  import numpy as np
29
29
 
30
- import mplang.v2.edsl as el
31
- import mplang.v2.edsl.typing as elt
32
- from mplang.v2.dialects import crypto, simp, tensor
30
+ import mplang.edsl as el
31
+ import mplang.edsl.typing as elt
32
+ from mplang.dialects import crypto, simp, tensor
33
33
 
34
34
 
35
35
  def _receiver_keygen_scalar(
@@ -24,8 +24,8 @@ from typing import Any, cast
24
24
 
25
25
  import jax.numpy as jnp
26
26
 
27
- import mplang.v2.edsl as el
28
- from mplang.v2.dialects import crypto, field, simp, tensor
27
+ import mplang.edsl as el
28
+ from mplang.dialects import crypto, field, simp, tensor
29
29
 
30
30
 
31
31
  def prg_expand(seed_tensor: el.Object, length: int) -> el.Object:
@@ -22,10 +22,10 @@ from typing import Any, cast
22
22
 
23
23
  import jax.numpy as jnp
24
24
 
25
- import mplang.v2.edsl as el
26
- import mplang.v2.edsl.typing as elt
27
- import mplang.v2.libs.mpc.vole.gilboa as vole
28
- from mplang.v2.dialects import crypto, field, simp, tensor
25
+ import mplang.edsl as el
26
+ import mplang.edsl.typing as elt
27
+ import mplang.libs.mpc.vole.gilboa as vole
28
+ from mplang.dialects import crypto, field, simp, tensor
29
29
 
30
30
 
31
31
  def silent_vole_random_u(
@@ -26,9 +26,9 @@ from typing import Any, cast
26
26
 
27
27
  import jax.numpy as jnp
28
28
 
29
- import mplang.v2.edsl as el
30
- from mplang.v2.dialects import tensor
31
- from mplang.v2.libs.mpc.common.constants import (
29
+ import mplang.edsl as el
30
+ from mplang.dialects import tensor
31
+ from mplang.libs.mpc.common.constants import (
32
32
  E_FRAC_1,
33
33
  GOLDEN_RATIO_64,
34
34
  PI_FRAC_1,
@@ -16,7 +16,7 @@
16
16
 
17
17
  from abc import ABC, abstractmethod
18
18
 
19
- import mplang.v2.edsl as el
19
+ import mplang.edsl as el
20
20
 
21
21
 
22
22
  class OKVS(ABC):
@@ -18,9 +18,9 @@ This module provides the core data structures and algorithms for Sparse OKVS,
18
18
  which is a critical component in unbalanced Private Set Intersection (PSI).
19
19
  """
20
20
 
21
- import mplang.v2.edsl as el
22
- from mplang.v2.dialects import field
23
- from mplang.v2.libs.mpc.psi.okvs import OKVS
21
+ import mplang.edsl as el
22
+ from mplang.dialects import field
23
+ from mplang.libs.mpc.psi.okvs import OKVS
24
24
 
25
25
  # ============================================================================
26
26
  # Constants
@@ -40,12 +40,16 @@ def get_okvs_expansion(n: int) -> float:
40
40
  - For N → ∞: Theoretical minimum is ε ≈ 0.23 (M = 1.23N)
41
41
  - For finite N: Larger ε needed due to variance in random hash collisions
42
42
 
43
- Empirical safe thresholds (failure probability < 0.01%):
44
- - N < 1,000: ε = 4.5 (M = 5.5N) - very small sets need extra wide margin
45
- to handle worst-case hash collisions
46
- - N < 10,000: ε = 0.4 (M = 1.4N)
47
- - N < 100,000: ε = 0.3 (M = 1.3N)
48
- - N ≥ 100,000: ε = 0.35 (M = 1.35N) - large sets converge near theory
43
+ Empirical safe thresholds (failure probability < 0.001%):
44
+ - N ≤ 200: ε = 24.0 (M = 25.0N) - extremely small sets need very wide margin
45
+ - N < 1,000: ε = 11.0 (M = 12.0N) - small sets need extra wide safety margin
46
+ - N < 10,000: ε = 0.6 (M = 1.6N)
47
+ - N < 100,000: ε = 0.4 (M = 1.4N)
48
+ - N ≥ 100,000: ε = 0.35 (M = 1.35N) - large sets converge near theory
49
+
50
+ Note: These expansion factors account for the 128-byte alignment requirement
51
+ in the OKVS implementation. The factors are intentionally conservative to
52
+ ensure high success rates (>99.9%) for the probabilistic peeling algorithm.
49
53
 
50
54
  Args:
51
55
  n: Number of key-value pairs to encode
@@ -53,12 +57,14 @@ def get_okvs_expansion(n: int) -> float:
53
57
  Returns:
54
58
  Expansion factor ε such that M = (1+ε)*N is safe for peeling
55
59
  """
56
- if n < 1000:
57
- return 5.5 # Small scale: need very wide safety margin for stability
60
+ if n <= 200:
61
+ return 25.0 # Extremely small scale: need very wide margin for stability
62
+ elif n < 1000:
63
+ return 12.0 # Small scale: need wide safety margin for stability
58
64
  elif n <= 10000:
59
- return 1.4 # Medium scale
65
+ return 1.6 # Medium scale
60
66
  elif n <= 100000:
61
- return 1.3 # Large scale
67
+ return 1.4 # Large scale
62
68
  else:
63
69
  # Mega-Binning requires ~1.35 for stability with 1024 bins
64
70
  return 1.35
@@ -24,9 +24,9 @@ from typing import Any, cast
24
24
 
25
25
  import jax.numpy as jnp
26
26
 
27
- import mplang.v2.edsl as el
28
- from mplang.v2.dialects import simp, tensor
29
- from mplang.v2.libs.mpc.ot import extension as ot_extension
27
+ import mplang.edsl as el
28
+ from mplang.dialects import simp, tensor
29
+ from mplang.libs.mpc.ot import extension as ot_extension
30
30
 
31
31
 
32
32
  def eval_oprf(
@@ -0,0 +1,303 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Private Set Intersection using VOLE and OKVS (RR22-Style).
16
+
17
+ This module implements a high-performance PSI protocol based on the "Blazing Fast PSI"
18
+ (RR22) paper. The protocol relies on Vector Oblivious Linear Evaluation (VOLE) and
19
+ Oblivious Key-Value Stores (OKVS) to achieve efficient set intersection with linear
20
+ communication O(N) and computation complexity.
21
+
22
+ Protocol Overview:
23
+ The core idea is to mask a "Polynomial" (encoded via OKVS) with VOLE-correlated randomness,
24
+ such that the mask can only be removed (and the polynomial verified) if the parties share
25
+ the same element.
26
+
27
+ Phases:
28
+ 1. **Correlated Randomness (VOLE)**:
29
+ Sender and Receiver establish a shared correlation:
30
+ W = V + U * Delta
31
+ - PSI Receiver holds `U` and `V` (these are generated by the OT "sender"
32
+ role in `silent_vole_random_u`).
33
+ - PSI Sender holds `W` and `Delta` (these are generated by the OT "receiver"
34
+ role).
35
+ - `U` is random; `Delta` is a fixed secret scalar held by the Sender.
36
+
37
+ 2. **Encoding (OKVS)**:
38
+ The Receiver encodes its input set Y into an OKVS storage `P` such that
39
+ Decode(P, y) = H(y) for all y in Y. The function `H(y)` is implemented via
40
+ AES/Davies–Meyer expansion acting as a random oracle.
41
+
42
+ 3. **Masking & Exchange**:
43
+ The Receiver masks the OKVS storage `P` with its VOLE share `U`:
44
+ Q = P + U
45
+ The masked storage `Q` is sent to the Sender (so Sender sees a masked OKVS).
46
+
47
+ 4. **Decoding & Tag Generation (Sender)**:
48
+ The Sender holds `W` and `Delta` and computes the linear combination:
49
+ K = Q * Delta + W
50
+ Using W = V + U * Delta and Q = P + U, this simplifies to
51
+ K = P * Delta + V.
52
+ The Sender decodes `K` for each of its items x to obtain P(x)*Delta + V(x),
53
+ then subtracts H(x)*Delta (computed locally) to recover `V(x)`. The value
54
+ `Tag = V(x)` serves as the sender-side tag for item x.
55
+
56
+ 5. **Verification (Receiver)**:
57
+ The Sender hashes and truncates tags and sends the truncated hashes to the
58
+ Receiver. The Receiver locally decodes `V(y)` from its OKVS, hashes it with
59
+ the same domain separation and truncation, and compares to the received
60
+ truncated hashes to determine membership in the intersection.
61
+ """
62
+
63
+ from typing import Any, cast
64
+
65
+ import jax.numpy as jnp
66
+
67
+ import mplang.edsl as el
68
+ import mplang.libs.mpc.ot.silent as silent_ot
69
+ from mplang.dialects import field, simp, tensor
70
+
71
+
72
def _hash_items_to_field(items: Any) -> Any:
    """Compute the random-oracle hash H(items) shared by both PSI parties.

    Each item is placed in the low lane of a two-lane block (high lane zero),
    expanded with ``field.aes_expand`` and fed forward into the input block
    with XOR (Davies–Meyer style).

    Must be called while tracing a ``simp.pcall_static`` body, because it emits
    ``tensor.run_jax`` / ``field.aes_expand`` ops on the calling party.

    Args:
        items: Item tensor of the calling party. Assumed to be 1-D of length N
            (the stack below produces an (N, 2) block tensor) — TODO confirm.

    Returns:
        Tensor of shape (N, 2) holding H(item) for every item.
    """

    def _reshape_seeds(x: Any) -> Any:
        # (N,) -> (N, 2): item in the low lane, zeros in the high lane.
        return jnp.stack([x, jnp.zeros_like(x)], axis=1)

    seeds = tensor.run_jax(_reshape_seeds, items)
    res_exp = field.aes_expand(seeds, 1)  # (N, 1, 2)

    def _davies_meyer(enc: Any, s: Any) -> Any:
        return jnp.bitwise_xor(enc.reshape(enc.shape[0], 2), s)

    return tensor.run_jax(_davies_meyer, res_exp, seeds)


def _match_truncated_hashes(remote: Any, local: Any) -> Any:
    """Return a 0/1 mask marking which rows of ``local`` occur in ``remote``.

    ``mask[j] == 1`` iff some row of ``remote`` equals row ``j`` of ``local``
    element-for-element. This is the same result as an all-pairs comparison.

    The implementation sorts the union of both row sets once. Identical rows
    end up adjacent, so runs of equal rows can be labelled with group ids.
    A local row is a hit iff its group contains at least one remote row.
    Cost is O(m log m) time and O(m) memory for m = total rows, instead of
    materialising an (n_remote, n_local, width) comparison tensor.

    Args:
        remote: (n_remote, width) array of truncated hashes from the Sender.
        local: (n_local, width) array of truncated hashes computed locally,
            with the same width and dtype as ``remote``.

    Returns:
        (n_local,) uint8 array of 0/1 membership flags, in ``local`` row order.
    """
    from jax.ops import segment_sum

    n_remote = remote.shape[0]
    rows = jnp.concatenate([remote, local], axis=0)
    total = rows.shape[0]
    is_remote = (jnp.arange(total) < n_remote).astype(jnp.int32)

    # Any lexicographic ordering places identical rows next to each other.
    order = jnp.lexsort(tuple(rows[:, c] for c in range(rows.shape[1])))
    sorted_rows = rows[order]

    # Dense group id per run of identical rows (non-decreasing along `order`).
    starts_group = jnp.concatenate(
        [
            jnp.ones((1,), dtype=bool),
            jnp.any(sorted_rows[1:] != sorted_rows[:-1], axis=1),
        ]
    )
    group_id = jnp.cumsum(starts_group) - 1

    # A group is a hit iff at least one remote row falls into it.
    remote_per_group = segment_sum(
        is_remote[order], group_id, num_segments=total, indices_are_sorted=True
    )
    hit_sorted = remote_per_group[group_id] > 0

    # Undo the sort (order is a permutation) and keep only the local rows.
    hit = jnp.zeros((total,), dtype=bool).at[order].set(hit_sorted)
    return hit[n_remote:].astype(jnp.uint8)


def psi_intersect(
    sender: int,
    receiver: int,
    n: int,
    sender_items: el.Object,
    receiver_items: el.Object,
) -> el.Object:
    """Execute OKVS-based PSI Protocol (Original RR22 Logic).

    This implementation follows the RR22 paper's role assignment where:
    - PSI Sender holds Delta (and W).
    - PSI Receiver holds U and V.

    This enables the "One Decode" optimization on the Sender side and prevents
    offline brute-force attacks by the Receiver (though Sender could brute-force).

    Args:
        sender: Rank of Sender.
        receiver: Rank of Receiver.
        n: Number of items. Both ``sender_items`` and ``receiver_items`` are
            assumed to hold exactly ``n`` items (``n`` sizes the sender-side
            Delta tiling and the row counts of both hash batches).
        sender_items: Object located at Sender.
        receiver_items: Object located at Receiver.

    Returns:
        Intersection mask (0/1, uint8) located at Receiver. Entry j is 1 iff
        the tag for receiver item j matches some truncated tag sent by Sender.

    Raises:
        ValueError: If ``sender == receiver`` or ``n <= 0``.
    """
    # Validation
    if sender == receiver:
        raise ValueError(
            f"Sender ({sender}) and Receiver ({receiver}) must be different."
        )
    if n <= 0:
        raise ValueError(f"Input size n must be positive, got {n}.")

    # Function-scoped imports, kept local as in the original implementation.
    import mplang.libs.mpc.psi.okvs_gct as okvs_gct
    from mplang.dialects import crypto
    from mplang.edsl import typing as elt
    from mplang.libs.mpc.ot import extension as ot_extension

    # Truncated tag width (bytes) and hash domain separator. Both parties
    # must use identical values.
    tag_bytes = 16
    tag_domain_sep = 0x1111

    # =========================================================================
    # Phase 1. Parameter Setup
    # =========================================================================
    # OKVS size M = expansion * n, rounded up to a multiple of 128. The sender
    # side tiles Delta over M rows and assumes this alignment.
    okvs_alignment = 128
    M = int(n * okvs_gct.get_okvs_expansion(n))
    M = -(-M // okvs_alignment) * okvs_alignment

    # =========================================================================
    # Phase 2. Correlated Randomness (VOLE)
    # =========================================================================
    # In the original paper logic (Fig 4), the PSI Sender holds Delta, so the
    # roles are swapped in the OT call:
    #   silent_vole_random_u(A, B) gives
    #     A (OT Sender):   U, V
    #     B (OT Receiver): W, Delta
    # We want the PSI Sender to be the OT Receiver.
    res_tuple = silent_ot.silent_vole_random_u(receiver, sender, M, base_k=1024)

    # PSI Receiver gets U, V; PSI Sender gets W, Delta.
    v_recv, w_sender, u_recv, delta_sender = res_tuple[:4]

    # =========================================================================
    # Phase 3. Receiver Encoding & Masking
    # =========================================================================
    # Receiver computes P such that P(y) = H(y), then masks P with U (the
    # paper's A vector): Q = P + U. The original derivation writes this
    # addition as XOR, so field addition and subtraction coincide.

    def _gen_seed() -> Any:
        return crypto.random_tensor((2,), elt.u64)

    okvs_seed = simp.pcall_static((receiver,), _gen_seed)
    okvs_seed_sender = simp.shuffle_static(okvs_seed, {sender: receiver})

    okvs = okvs_gct.SparseOKVS(M)

    def _recv_ops(y: Any, u: Any, seed: Any) -> Any:
        # 3.1 Compute H(y)
        h_y = _hash_items_to_field(y)
        # 3.2 Encode P so that Decode(P, y) = H(y)
        p_storage = okvs.encode(y, h_y, seed)
        # 3.3 Mask with U: Q = P + U
        return field.add(p_storage, u)

    q_shared = simp.pcall_static(
        (receiver,), _recv_ops, receiver_items, u_recv, okvs_seed
    )
    q_sender_view = simp.shuffle_static(q_shared, {sender: receiver})

    # =========================================================================
    # Phase 4. Sender "One Decode" & Tag Generation
    # =========================================================================
    # Sender holds W, Delta and receives Q, where W = V + U * Delta.
    #   K = Q * Delta + W
    #     = (P + U) * Delta + (V + U * Delta)
    #     = P * Delta + V
    # Tag = Decode(K, x) - H(x) * Delta. If x is in the intersection then
    # P(x) = H(x), so Tag = V(x).

    def _sender_ops(x: Any, q: Any, w: Any, delta: Any, seed: Any) -> Any:
        # q, w: (M, 2); delta: (2,)

        # Safe tiling assuming M is aligned
        def _tile_m_simple(d: Any) -> Any:
            return jnp.tile(d, (M, 1))

        delta_expanded_m = tensor.run_jax(_tile_m_simple, delta)

        # 4.1 Global K = Q * Delta + W (the O(M) multiplication from the paper)
        q_times_delta = field.mul(q, delta_expanded_m)
        k_storage = field.add(q_times_delta, w)

        # 4.2 One Decode: decoded_k = P(x) * Delta + V(x)
        decoded_k = okvs.decode(x, k_storage, seed)

        # 4.3 Remove H(x) * Delta (addition == subtraction in this field)
        h_x = _hash_items_to_field(x)

        def _tile_n(d: Any) -> Any:
            return jnp.tile(d, (n, 1))

        delta_expanded_n = tensor.run_jax(_tile_n, delta)
        h_x_times_delta = field.mul(h_x, delta_expanded_n)

        # Final Tag = (P*Delta + V) + H*Delta = V(x)
        return field.add(decoded_k, h_x_times_delta)

    sender_tags = simp.pcall_static(
        (sender,),
        _sender_ops,
        sender_items,
        q_sender_view,
        w_sender,
        delta_sender,
        okvs_seed_sender,
    )

    # =========================================================================
    # Phase 5. Verification (Receiver Side)
    # =========================================================================
    # To reduce communication, Sender hashes its tags (which should be V(x))
    # and sends only the first `tag_bytes` bytes of each hash.

    # 5.1 Hashed & truncated tags on Sender
    def _hash_and_trunc(tags: Any) -> Any:
        full_h = ot_extension.vec_hash(tags, domain_sep=tag_domain_sep, num_rows=n)
        # tensor.slice_tensor slices traced objects (start=(0, 0), end=(n, 16)).
        return tensor.slice_tensor(full_h, (0, 0), (n, tag_bytes))

    h_sender_trunc = simp.pcall_static((sender,), _hash_and_trunc, sender_tags)

    # 5.2 Send truncated hashes to Receiver
    tags_at_recv = simp.shuffle_static(h_sender_trunc, {receiver: sender})

    # 5.3 Receiver decodes V(y) locally, hashes it identically, and compares
    def _recv_verify(y: Any, v: Any, seed: Any, remote_tags: Any) -> Any:
        local_v_y = okvs.decode(y, v, seed)
        h_local = ot_extension.vec_hash(
            local_v_y, domain_sep=tag_domain_sep, num_rows=n
        )

        def _core(h_r16: Any, h_l_full: Any) -> Any:
            # h_r16: truncated sender hashes; h_l_full: full local hashes,
            # truncated here to the same width before the sort-based match.
            return _match_truncated_hashes(h_r16, h_l_full[:, :tag_bytes])

        return tensor.run_jax(_core, remote_tags, h_local)

    intersection_mask = simp.pcall_static(
        (receiver,), _recv_verify, receiver_items, v_recv, okvs_seed, tags_at_recv
    )

    return cast(el.Object, intersection_mask)
@@ -36,10 +36,10 @@ from typing import Any, cast
36
36
 
37
37
  import jax.numpy as jnp
38
38
 
39
- import mplang.v2.edsl as el
40
- import mplang.v2.edsl.typing as elt
41
- from mplang.v2.dialects import crypto, field, simp, tensor
42
- from mplang.v2.libs.mpc.psi.okvs_gct import get_okvs_expansion
39
+ import mplang.edsl as el
40
+ import mplang.edsl.typing as elt
41
+ from mplang.dialects import crypto, field, simp, tensor
42
+ from mplang.libs.mpc.psi.okvs_gct import get_okvs_expansion
43
43
 
44
44
 
45
45
  def psi_unbalanced(
@@ -24,9 +24,9 @@ from typing import Any, cast
24
24
  import jax.numpy as jnp
25
25
  import numpy as np
26
26
 
27
- import mplang.v2.edsl as el
28
- import mplang.v2.libs.mpc.ot.extension as ot
29
- from mplang.v2.dialects import field, simp, tensor
27
+ import mplang.edsl as el
28
+ import mplang.libs.mpc.ot.extension as ot
29
+ from mplang.dialects import field, simp, tensor
30
30
 
31
31
 
32
32
  def vole(
@@ -32,8 +32,8 @@ import jax.numpy as jnp
32
32
  import numpy as np
33
33
  import scipy.sparse as sp
34
34
 
35
- import mplang.v2.edsl as el
36
- from mplang.v2.dialects import crypto, field, tensor
35
+ import mplang.edsl as el
36
+ from mplang.dialects import crypto, field, tensor
37
37
 
38
38
  # ============================================================================
39
39
  # Constants
@@ -43,11 +43,11 @@ from typing import Any, cast
43
43
 
44
44
  import jax.numpy as jnp
45
45
 
46
- import mplang.v2.edsl as el
47
- import mplang.v2.edsl.typing as elt
48
- import mplang.v2.libs.mpc.ot.extension as ot
49
- from mplang.v2.dialects import crypto, field, simp, tensor
50
- from mplang.v2.libs.mpc.vole import ldpc
46
+ import mplang.edsl as el
47
+ import mplang.edsl.typing as elt
48
+ import mplang.libs.mpc.ot.extension as ot
49
+ from mplang.dialects import crypto, field, simp, tensor
50
+ from mplang.libs.mpc.vole import ldpc
51
51
 
52
52
  # ============================================================================
53
53
  # Constants
@@ -148,7 +148,7 @@ def silver_vole(
148
148
  H_indices_r, H_indptr_r = simp.pcall_static((receiver,), _sparse_struct_provider)
149
149
 
150
150
  # 2. Base VOLE (Size K)
151
- from mplang.v2.libs.mpc.vole import gilboa
151
+ from mplang.libs.mpc.vole import gilboa
152
152
 
153
153
  def _u_base_provider() -> el.Object:
154
154
  # Generate random u_base using new API