mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,40 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Private Set Intersection (PSI) protocols.
16
+
17
+ Submodules:
18
+ - rr22: VOLE-masked PSI protocol (formerly okvs.py)
19
+ - unbalanced: Unbalanced PSI (O(n) communication)
20
+ - oprf: KKRT OPRF protocol
21
+ - cuckoo: Cuckoo hashing
22
+ - okvs_gct: Sparse OKVS data structure (Garbled Cuckoo Table)
23
+ - okvs: OKVS Abstract Base Class
24
+ """
25
+
26
+ from .oprf import eval_oprf, sender_eval_prf, sender_eval_prf_batch
27
+ from .rr22 import psi_intersect
28
+ from .unbalanced import psi_unbalanced
29
+
30
+ # Alias for backward compatibility
31
+ eval = psi_intersect
32
+
33
+ __all__ = [
34
+ "eval",
35
+ "eval_oprf",
36
+ "psi_intersect",
37
+ "psi_unbalanced",
38
+ "sender_eval_prf",
39
+ "sender_eval_prf_batch",
40
+ ]
@@ -0,0 +1,228 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Cuckoo Hashing for OPRF-PSI.
16
+
17
+ Implements JAX-compatible Cuckoo hashing for mapping items to table positions.
18
+ Each item hashes to K candidate positions; during lookup, check all K positions.
19
+
20
+ Reference: KKRT OPRF-PSI uses Cuckoo hashing for row mapping.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from typing import Any, cast
26
+
27
+ import jax.numpy as jnp
28
+
29
+ import mplang.v2.edsl as el
30
+ from mplang.v2.dialects import tensor
31
+ from mplang.v2.libs.mpc.common.constants import (
32
+ E_FRAC_1,
33
+ GOLDEN_RATIO_64,
34
+ PI_FRAC_1,
35
+ PI_FRAC_2,
36
+ SPLITMIX64_GAMMA_1,
37
+ SPLITMIX64_GAMMA_2,
38
+ )
39
+
40
+ # =============================================================================
41
+ # Cuckoo Hash Parameters
42
+ # =============================================================================
43
+
44
+ NUM_HASH_FUNCTIONS = 3 # K: number of candidate positions computed per item (standard 3-hash setup)
45
+ STASH_SIZE = 0 # Simple version: no stash (higher failure rate)
46
+ MASK64 = 0xFFFFFFFFFFFFFFFF # All-ones 64-bit mask
47
+
48
+
49
def hash_to_positions(items: Any, table_size: int, seed: tuple[int, int]) -> Any:
    """Map every item to its K candidate slots in the hash table.

    Each item is reduced to a 64-bit key (its first 8 bytes), XOR-masked with
    ``seed[0]``, and then pushed through K seeded affine hashes:
        h_i(x) = (a_i * x + b_i) mod table_size

    The multipliers ``a_i`` are fixed base constants XOR-ed with ``seed[0]``
    and the offsets ``b_i`` are fixed base constants XOR-ed with ``seed[1]``,
    so both coefficients of the hash family depend on the seed.

    Args:
        items: (N, 16) uint8 array - items to hash (only bytes 0..7 are read)
        table_size: Size of Cuckoo hash table
        seed: (2,) tuple of uint64 - random seed

    Returns:
        (N, K) int32 array - K candidate positions for each item
    """
    num_items = items.shape[0]

    # Both seed halves as uint64 scalars.
    key_mask = jnp.uint64(seed[0])
    offset_mask = jnp.uint64(seed[1])

    # Reinterpret the leading 8 bytes of every row as one uint64 key, then
    # fold the first seed half into it.
    raw_keys = items[:, :8].view(jnp.uint64).reshape(num_items)
    masked_keys = raw_keys ^ key_mask

    # Seeded coefficients: base constants XOR the matching seed half.
    multipliers = (
        jnp.array(
            [GOLDEN_RATIO_64, SPLITMIX64_GAMMA_1, SPLITMIX64_GAMMA_2],
            dtype=jnp.uint64,
        )
        ^ key_mask
    )
    offsets = (
        jnp.array([PI_FRAC_1, PI_FRAC_2, E_FRAC_1], dtype=jnp.uint64) ^ offset_mask
    )

    # One (N,) int32 column per hash function, joined into an (N, K) array.
    columns = [
        ((masked_keys * multipliers[i] + offsets[i]) % table_size).astype(jnp.int32)
        for i in range(NUM_HASH_FUNCTIONS)
    ]
    return jnp.stack(columns, axis=1)
95
+
96
+
97
def cuckoo_insert_batch(
    items: Any,
    table_size: int,
    seed: tuple[int, int],
    max_iters: int = 100,
) -> tuple[Any, Any, Any]:
    """Place items into a hash table using K vectorized, non-evicting passes.

    In pass k every still-unplaced item proposes its k-th candidate slot.
    Proposals that target a slot claimed in an earlier pass are dropped. When
    several items propose the same free slot, a single scatter keeps one of
    them and a read-back tells each proposer whether it was the one kept.
    Items that did not win stay active and try their next candidate in the
    following pass; items that never win are reported as unsuccessful.

    Args:
        items: (N, 16) uint8 array - items to insert
        table_size: Size of Cuckoo hash table (should be ~1.3-1.5N)
        seed: (2,) uint64 seed
        max_iters: Ignored in this vectorized version (uses K=3 fixed passes)

    Returns:
        Tuple of:
        - table: (table_size, 16) uint8 - Cuckoo hash table
        - item_to_pos: (N,) int32 - position of each item in table (-1 if
          the item was not placed)
        - success: (N,) bool - whether each item was successfully inserted
    """
    num_items = items.shape[0]
    # Index of the extra trailing slot that absorbs writes we want discarded.
    dump_slot = table_size

    candidates = hash_to_positions(items, table_size, seed)
    item_ids = jnp.arange(num_items, dtype=jnp.int32)

    placed_at = jnp.full(num_items, -1, dtype=jnp.int32)
    pending = jnp.ones(num_items, dtype=jnp.bool_)
    # slot_owner[p] is the index of the item stored at slot p, or -1 if free.
    slot_owner = jnp.full(table_size, -1, dtype=jnp.int32)
    slot_taken = jnp.zeros(table_size, dtype=jnp.bool_)

    for choice in range(NUM_HASH_FUNCTIONS):
        # Only pending items propose; everyone else proposes -1.
        proposal = jnp.where(pending, candidates[:, choice], -1)

        # Look up occupancy with -1 clamped to 0; the clamped result is
        # masked out by the `proposal >= 0` test right after.
        already_taken = slot_taken[jnp.maximum(proposal, 0)]
        has_target = (proposal >= 0) & (~already_taken)

        # Items without a usable target write into the dump slot instead.
        target = jnp.where(has_target, proposal, dump_slot)

        # Scatter item ids into a copy of the table extended by the dump slot.
        padded_owner = jnp.pad(slot_owner, (0, 1), constant_values=-1)
        padded_owner = padded_owner.at[target].set(item_ids)

        # An item wins if it had a usable target and reads back its own id.
        won = has_target & (padded_owner[target] == item_ids)

        placed_at = jnp.where(won, target, placed_at)
        pending = pending & (~won)

        # Drop the dump slot again and refresh the occupancy mask.
        slot_owner = padded_owner[:table_size]
        slot_taken = slot_owner >= 0

    # Materialize the table: owner rows where occupied, zeros elsewhere.
    gathered = items[jnp.maximum(slot_owner, 0)]
    table = jnp.where(slot_owner[:, None] >= 0, gathered, 0)

    return table, placed_at, placed_at >= 0
185
+
186
+
187
def cuckoo_lookup_positions(items: Any, table_size: int, seed: tuple[int, int]) -> Any:
    """Return the table slots that must be probed to find each item.

    This delegates directly to ``hash_to_positions``, so a lookup probes
    exactly the candidate positions computed for an insertion with the same
    ``table_size`` and ``seed``.

    Args:
        items: (M, 16) uint8 array - items to lookup
        table_size: Size of Cuckoo hash table
        seed: (2,) uint64 seed

    Returns:
        (M, K) int32 array - K positions to check for each item
    """
    candidate_slots = hash_to_positions(items, table_size, seed)
    return candidate_slots
202
+
203
+
204
+ # =============================================================================
205
+ # EDSL Wrappers
206
+ # =============================================================================
207
+
208
+
209
def compute_positions(
    items: el.Object,
    table_size: int,
    seed: el.Object,  # (2,) uint64
) -> el.Object:
    """EDSL entry point: compute Cuckoo candidate positions for ``items``.

    Runs ``hash_to_positions`` through ``tensor.run_jax``. ``table_size`` is
    captured by the inner function as a plain Python int, while ``items`` and
    ``seed`` are passed to ``run_jax`` as operands.

    Args:
        items: (N, 16) byte tensor of items
        table_size: Size of Cuckoo hash table
        seed: (2,) uint64 seed

    Returns:
        (N, K) int32 tensor of candidate positions
    """

    def _hash(x: Any, s: Any) -> Any:
        # hash_to_positions indexes seed[0] / seed[1], so unpack the seed
        # operand into a tuple first.
        seed_pair = tuple(s)
        return hash_to_positions(x, table_size, seed_pair)

    positions = tensor.run_jax(_hash, items, seed)
    return cast(el.Object, positions)
@@ -0,0 +1,49 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Abstract Base Class for OKVS (Oblivious Key-Value Store)."""
16
+
17
+ from abc import ABC, abstractmethod
18
+
19
+ import mplang.v2.edsl as el
20
+
21
+
22
class OKVS(ABC):
    """Interface every Oblivious Key-Value Store implementation must provide.

    Concrete subclasses supply ``encode`` (build the storage from key/value
    pairs) and ``decode`` (query the storage with keys). Both take the same
    kind of seed argument.
    """

    @abstractmethod
    def encode(
        self,
        keys: el.Object,
        values: el.Object,
        seed: el.Object,
    ) -> el.Object:
        """Build the OKVS storage from key/value pairs.

        Args:
            keys: (N,) uint64 tensor of keys
            values: (N, D) uint64 tensor of values
            seed: (2,) uint64 tensor seed

        Returns:
            (M, D) uint64 tensor OKVS storage
        """

    @abstractmethod
    def decode(
        self,
        keys: el.Object,
        storage: el.Object,
        seed: el.Object,
    ) -> el.Object:
        """Query the OKVS storage for the values associated with ``keys``.

        Args:
            keys: (N,) uint64 tensor of keys to query
            storage: (M, D) uint64 tensor OKVS storage
            seed: (2,) uint64 tensor seed

        Returns:
            (N, D) uint64 tensor of recovered values
        """
+ """
@@ -0,0 +1,79 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Sparse OKVS (Oblivious Key-Value Store) Implementation.
16
+
17
+ This module provides the core data structures and algorithms for Sparse OKVS,
18
+ which is a critical component in unbalanced Private Set Intersection (PSI).
19
+ """
20
+
21
+ import mplang.v2.edsl as el
22
+ from mplang.v2.dialects import field
23
+ from mplang.v2.libs.mpc.psi.okvs import OKVS
24
+
25
+ # ============================================================================
26
+ # Constants
27
+ # ============================================================================
28
+
29
+ # Number of hash functions for Cuckoo hashing (NOTE(review): not referenced elsewhere in this module - confirm importers still use it)
30
+ NUM_HASHES = 3
31
+
32
+
33
def get_okvs_expansion(n: int) -> float:
    """Return the OKVS table-size multiplier M/N for a dataset of ``n`` pairs.

    The 3-hash Garbled Cuckoo Table algorithm requires table size M > N for
    the peeling algorithm to successfully solve the system. Writing
    M = (1+ε)*N, the minimum safe expansion ε depends on N:

    - For N → ∞: Theoretical minimum is ε ≈ 0.23 (M = 1.23N)
    - For finite N: Larger ε needed due to variance in random hash collisions

    NOTE: the value returned is the full multiplier (1+ε), i.e. M/N, and
    NOT ε itself.

    Empirical safe thresholds (failure probability < 0.01%), with the range
    boundaries exactly as implemented:

    - n < 1,000:             5.5  (ε = 4.5) - very small sets need an extra
                                  wide margin to handle worst-case hash
                                  collisions
    - 1,000 ≤ n ≤ 10,000:    1.4  (ε = 0.4)
    - 10,000 < n ≤ 100,000:  1.3  (ε = 0.3)
    - n > 100,000:           1.35 (ε = 0.35) - Mega-Binning requires ~1.35
                                  for stability with 1024 bins

    Args:
        n: Number of key-value pairs to encode

    Returns:
        Multiplier (1+ε) such that M = multiplier * N is safe for peeling
    """
    if n < 1000:
        return 5.5  # Small scale: need very wide safety margin for stability
    if n <= 10000:
        return 1.4  # Medium scale
    if n <= 100000:
        return 1.3  # Large scale
    # Mega-Binning requires ~1.35 for stability with 1024 bins
    return 1.35
65
+
66
+
67
class SparseOKVS(OKVS):
    """OKVS implemented as a 3-Hash Garbled Cuckoo Table.

    Encoding and decoding are delegated to the ``field`` dialect
    (``field.solve_okvs`` / ``field.decode_okvs``). The ``m`` given at
    construction is forwarded to ``field.solve_okvs`` on every ``encode``.
    """

    def __init__(self, m: int):
        self.m = m

    def encode(self, keys: el.Object, values: el.Object, seed: el.Object) -> el.Object:
        """Build OKVS storage for ``keys``/``values`` via ``field.solve_okvs``."""
        storage = field.solve_okvs(keys, values, self.m, seed)
        return storage

    def decode(self, keys: el.Object, storage: el.Object, seed: el.Object) -> el.Object:
        """Recover values for ``keys`` from ``storage`` via ``field.decode_okvs``."""
        recovered = field.decode_okvs(keys, storage, seed)
        return recovered
+ return field.decode_okvs(keys, storage, seed)