mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191):
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,291 @@
1
+ /*
2
+ * Copyright 2025 Ant Group Co., Ltd.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <cstdint>
18
+ #include <vector>
19
+ #include <stack>
20
+ #include <random>
21
+ #include <immintrin.h>
22
+ #include <cstring>
23
+ #include <cstdio>
24
+ #include <iostream>
25
+ #include <omp.h>
26
+ #include <atomic>
27
+
28
+ extern "C" {
29
+
30
// Bin count for the "mega-binning" strategy.
// With 1024 bins, N = 1M keys yields roughly 1000 items per bin; the count was
// chosen so that one bin's working set fits in L1 cache (32KB/48KB).
static constexpr uint64_t NUM_BINS = 1024;

// The three bin-local storage positions that a single key maps to.
struct Indices {
    uint64_t h1;
    uint64_t h2;
    uint64_t h3;
};
38
+
39
+ // Stateless Bin Selection
40
+ // Maps a key to a deterministic bin index [0, NUM_BINS).
41
+ inline uint64_t get_bin_index(uint64_t key, __m128i seed) {
42
+ __m128i k = _mm_set_epi64x(0, key);
43
+ __m128i h = _mm_aesenc_si128(k, seed);
44
+ h = _mm_aesenc_si128(h, seed);
45
+ uint64_t v1 = _mm_extract_epi64(h, 0);
46
+ return v1 % NUM_BINS;
47
+ }
48
+
49
+ // Generate 3 positions within a local bin of size m_local.
50
+ inline Indices get_bin_local_indices(uint64_t key, uint64_t m_local, __m128i seed) {
51
+ // Use a distinct seed mix to decorrelate from bin selection
52
+ __m128i k = _mm_set_epi64x(0, key);
53
+ __m128i s2 = _mm_add_epi64(seed, _mm_set_epi64x(1, 1));
54
+ __m128i h = _mm_aesenc_si128(k, s2);
55
+ h = _mm_aesenc_si128(h, s2);
56
+ h = _mm_aesenc_si128(h, s2);
57
+
58
+ uint64_t r = _mm_extract_epi64(h, 0);
59
+ Indices idx;
60
+
61
+ // Fast modulo for local indices
62
+ idx.h1 = r % m_local;
63
+ r = r * 6364136223846793005ULL + 1442695040888963407ULL; // LCG step
64
+ idx.h2 = r % m_local;
65
+ r = r * 6364136223846793005ULL + 1442695040888963407ULL;
66
+ idx.h3 = r % m_local;
67
+
68
+ // Ensure distinct indices
69
+ if(idx.h2 == idx.h1) idx.h2 = (idx.h2 + 1) % m_local;
70
+ if(idx.h3 == idx.h1 || idx.h3 == idx.h2) {
71
+ idx.h3 = (idx.h3 + 1) % m_local;
72
+ if(idx.h3 == idx.h1 || idx.h3 == idx.h2) idx.h3 = (idx.h3 + 1) % m_local;
73
+ }
74
+ return idx;
75
+ }
76
+
77
+ // Core Peeling Solver for a single Bin
78
+ bool solve_bin(
79
+ const std::vector<uint64_t>& keys,
80
+ const std::vector<__m128i>& vals,
81
+ __m128i* P_local,
82
+ uint64_t m,
83
+ __m128i seed
84
+ ) {
85
+ uint64_t n = keys.size();
86
+ if (n == 0) return true;
87
+
88
+ struct Edge {
89
+ uint64_t h1, h2, h3;
90
+ uint64_t key_idx;
91
+ };
92
+ std::vector<Edge> edges(n);
93
+ std::vector<int> col_degree(m, 0);
94
+
95
+ // 1. Build Local Graph
96
+ for(uint64_t i=0; i<n; ++i) {
97
+ Indices idx = get_bin_local_indices(keys[i], m, seed);
98
+ edges[i] = {idx.h1, idx.h2, idx.h3, i};
99
+ col_degree[idx.h1]++;
100
+ col_degree[idx.h2]++;
101
+ col_degree[idx.h3]++;
102
+ }
103
+
104
+ // 2. CSR Construction
105
+ std::vector<int> col_start(m + 1, 0);
106
+ for(uint64_t j=0; j<m; ++j) {
107
+ col_start[j+1] = col_start[j] + col_degree[j];
108
+ }
109
+ std::vector<int> flat_rows(n * 3);
110
+ std::vector<int> fill_ptr = col_start;
111
+ for(uint64_t i=0; i<n; ++i) {
112
+ flat_rows[fill_ptr[edges[i].h1]++] = i;
113
+ flat_rows[fill_ptr[edges[i].h2]++] = i;
114
+ flat_rows[fill_ptr[edges[i].h3]++] = i;
115
+ }
116
+
117
+ // 3. Peeling Process
118
+ std::vector<int> peel_stack;
119
+ peel_stack.reserve(m);
120
+ for(uint64_t j=0; j<m; ++j) {
121
+ if(col_degree[j] == 1) peel_stack.push_back(j);
122
+ }
123
+
124
+ std::vector<bool> row_removed(n, false);
125
+ std::vector<bool> col_removed(m, false);
126
+
127
+ struct Assignment {
128
+ int col;
129
+ int row_idx;
130
+ };
131
+ std::vector<Assignment> assignment_stack;
132
+ assignment_stack.reserve(n);
133
+
134
+ int head = 0;
135
+ while(head < peel_stack.size()) {
136
+ int j = peel_stack[head++];
137
+ if(col_removed[j]) continue;
138
+
139
+ int owner_row = -1;
140
+ for(int k=col_start[j]; k<col_start[j+1]; ++k) {
141
+ int r = flat_rows[k];
142
+ if(!row_removed[r]) {
143
+ owner_row = r;
144
+ break;
145
+ }
146
+ }
147
+ if(owner_row == -1) {
148
+ col_removed[j] = true;
149
+ continue;
150
+ }
151
+
152
+ assignment_stack.push_back({j, owner_row});
153
+ col_removed[j] = true;
154
+ row_removed[owner_row] = true;
155
+
156
+ const auto& e = edges[owner_row];
157
+ uint64_t nbs[3] = {e.h1, e.h2, e.h3};
158
+ for(uint64_t nb : nbs) {
159
+ if(nb == (uint64_t)j) continue;
160
+ if(col_removed[nb]) continue;
161
+ col_degree[nb]--;
162
+ if(col_degree[nb] == 1) peel_stack.push_back((int)nb);
163
+ }
164
+ }
165
+
166
+ if(assignment_stack.size() != n) return false;
167
+
168
+ // 4. Back-Substitution
169
+ for(int i=(int)assignment_stack.size()-1; i>=0; --i) {
170
+ auto a = assignment_stack[i];
171
+ const auto& e = edges[a.row_idx];
172
+
173
+ __m128i val1 = _mm_loadu_si128(&P_local[e.h1]);
174
+ __m128i val2 = _mm_loadu_si128(&P_local[e.h2]);
175
+ __m128i val3 = _mm_loadu_si128(&P_local[e.h3]);
176
+ __m128i target = vals[e.key_idx];
177
+
178
+ __m128i current = _mm_xor_si128(_mm_xor_si128(val1, val2), val3);
179
+ __m128i diff = _mm_xor_si128(target, current);
180
+
181
+ _mm_storeu_si128(&P_local[a.col], diff);
182
+ }
183
+ return true;
184
+ }
185
+
186
+ void solve_okvs_opt(uint64_t* keys, uint64_t* values, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr) {
187
+ __m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);
188
+
189
+ // 1. Calculate Bin Boundaries
190
+ // We divide M evenly among bins. The remainder is distributed to the first few bins.
191
+ std::vector<uint64_t> bin_offsets(NUM_BINS + 1);
192
+ std::vector<uint64_t> m_per_bin(NUM_BINS);
193
+
194
+ uint64_t base_m = m / NUM_BINS;
195
+ uint64_t remainder = m % NUM_BINS;
196
+
197
+ uint64_t current_offset = 0;
198
+ for(uint64_t b=0; b<NUM_BINS; ++b) {
199
+ bin_offsets[b] = current_offset;
200
+ m_per_bin[b] = base_m + (b < remainder ? 1 : 0);
201
+ current_offset += m_per_bin[b];
202
+ }
203
+ bin_offsets[NUM_BINS] = m;
204
+
205
+ // 2. Partition Data (Stateless)
206
+ // Note on "Two-Choice Hashing":
207
+ // While Two-Choice Hashing (selecting the lighter of 2 potential bins) would significantly
208
+ // reduce max bin load variance, it introduces "Statefulness".
209
+ // The bin assignment for Key K would depend on the load of bins, which depends on other keys.
210
+ // In standard PSI protocols (like RR22), the Decode step must be capable of processing keys
211
+ // independently or without knowledge of the full set distribution (Sender/Receiver separation).
212
+ // Therefore, we use **Simple Binning** (Stateless Hash) where Bin(K) = H(K) % Bins.
213
+ // We mitigate the resulting variance ("Balls-in-Bins" problem) by using a slightly larger
214
+ // expansion factor (epsilon ~ 1.35) which is bandwidth-acceptable and ensures stability.
215
+
216
+ std::vector<std::vector<uint64_t>> bin_keys(NUM_BINS);
217
+ std::vector<std::vector<__m128i>> bin_vals(NUM_BINS);
218
+
219
+ // Pre-allocate to reduce reallocation overhead (assume ~uniform distribution)
220
+ // 1.5x margin for pre-allocation safety
221
+ size_t est_size = (n / NUM_BINS) * 3 / 2;
222
+ for(int b=0; b<NUM_BINS; ++b) {
223
+ bin_keys[b].reserve(est_size);
224
+ bin_vals[b].reserve(est_size);
225
+ }
226
+
227
+ const __m128i* V_ptr = (const __m128i*)values;
228
+ for(uint64_t i=0; i<n; ++i) {
229
+ uint64_t b = get_bin_index(keys[i], seed);
230
+ bin_keys[b].push_back(keys[i]);
231
+ bin_vals[b].push_back(_mm_loadu_si128(&V_ptr[i]));
232
+ }
233
+
234
+ // 3. Parallel Solve
235
+ // Each bin is solved independently. This logic is perfectly parallelizable (embarrassingly parallel).
236
+ // The working set for each bin (~1000 items) stays hot in L1 Cache.
237
+ memset(output, 0, m * 16);
238
+ __m128i* P_vec = (__m128i*)output;
239
+
240
+ #pragma omp parallel for schedule(dynamic)
241
+ for(uint64_t b=0; b<NUM_BINS; ++b) {
242
+ if(bin_keys[b].empty()) continue;
243
+
244
+ uint64_t offset = bin_offsets[b];
245
+ uint64_t valid_m = m_per_bin[b];
246
+
247
+ if(!solve_bin(bin_keys[b], bin_vals[b], &P_vec[offset], valid_m, seed)) {
248
+ #pragma omp critical
249
+ {
250
+ fprintf(stderr, "[ERROR] Bin %lu failed OKVS peeling. Items: %lu / M: %lu (Ratio: %.2f). Try increasing expansion factor.\n",
251
+ b, bin_keys[b].size(), valid_m, (double)valid_m / bin_keys[b].size());
252
+ }
253
+ }
254
+ }
255
+ }
256
+
257
+ void decode_okvs_opt(uint64_t* keys, uint64_t* storage, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr) {
258
+ __m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);
259
+ __m128i* P_vec = (__m128i*)storage;
260
+ __m128i* out_vec = (__m128i*)output;
261
+
262
+ // Replicate Boundary Logic
263
+ std::vector<uint64_t> bin_offsets(NUM_BINS + 1);
264
+ std::vector<uint64_t> m_per_bin(NUM_BINS);
265
+ uint64_t base_m = m / NUM_BINS;
266
+ uint64_t remainder = m % NUM_BINS;
267
+ uint64_t current_offset = 0;
268
+ for(uint64_t b=0; b<NUM_BINS; ++b) {
269
+ bin_offsets[b] = current_offset;
270
+ m_per_bin[b] = base_m + (b < remainder ? 1 : 0);
271
+ current_offset += m_per_bin[b];
272
+ }
273
+
274
+ // Parallel Stateless Decode
275
+ #pragma omp parallel for schedule(static)
276
+ for(uint64_t i=0; i<n; ++i) {
277
+ uint64_t b = get_bin_index(keys[i], seed);
278
+
279
+ uint64_t m_local = m_per_bin[b];
280
+ uint64_t offset = bin_offsets[b];
281
+
282
+ Indices idx = get_bin_local_indices(keys[i], m_local, seed);
283
+
284
+ __m128i val = _mm_xor_si128(
285
+ _mm_xor_si128(_mm_loadu_si128(&P_vec[offset + idx.h1]), _mm_loadu_si128(&P_vec[offset + idx.h2])),
286
+ _mm_loadu_si128(&P_vec[offset + idx.h3])
287
+ );
288
+ _mm_storeu_si128(&out_vec[i], val);
289
+ }
290
+ }
291
+ }
@@ -0,0 +1,398 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Pure Python implementations of performance-critical kernels.
16
+
17
+ These implementations provide fallback functionality when native C++ kernels
18
+ (libmplang_kernels.so) are not available. They are functionally correct but
19
+ significantly slower than the optimized C++ versions.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import numpy as np
25
+
26
+ from mplang.v2.libs.mpc.common.constants import (
27
+ GOLDEN_RATIO_64,
28
+ SPLITMIX64_GAMMA_2,
29
+ SPLITMIX64_GAMMA_3,
30
+ SPLITMIX64_GAMMA_4,
31
+ )
32
+
33
+ # =============================================================================
34
+ # GF(2^128) Arithmetic
35
+ # =============================================================================
36
+
37
+ # Irreducible polynomial: P(x) = x^128 + x^7 + x^2 + x + 1
38
+ # In polynomial basis, this means x^128 = x^7 + x^2 + x + 1 (mod P)
39
+ _GF128_POLYNOMIAL = 0x87 # x^7 + x^2 + x + 1 = 0b10000111 = 135
40
+
41
+
42
def _gf128_clmul64(a: int, b: int) -> tuple[int, int]:
    """Carry-less (XOR) product of two 64-bit integers.

    Only the low 64 bits of ``b`` are examined; ``a`` is expected to fit in
    64 bits.

    Returns:
        ``(lo, hi)`` such that the product equals ``hi * 2**64 + lo``.
    """
    mask64 = (1 << 64) - 1
    lo = 0
    hi = 0

    for shift in range(64):
        if not (b >> shift) & 1:
            continue
        # XOR in ``a << shift``, split at the 64-bit boundary.
        lo ^= (a << shift) & mask64
        if shift:
            hi ^= a >> (64 - shift)

    return lo, hi
59
+
60
+
61
def _gf128_clmul128(
    a_lo: int, a_hi: int, b_lo: int, b_hi: int
) -> tuple[int, int, int, int]:
    """Carry-less product of two 128-bit values given as 64-bit halves.

    Uses the schoolbook split into four 64x64 carry-less products.

    Returns:
        ``(r0, r1, r2, r3)``, the 64-bit limbs of the 256-bit product, i.e.
        ``r3 * 2**192 + r2 * 2**128 + r1 * 2**64 + r0``.
    """
    # Partial products and the bit offset at which each one lands.
    low_lo, low_hi = _gf128_clmul64(a_lo, b_lo)  # offset 0
    high_lo, high_hi = _gf128_clmul64(a_hi, b_hi)  # offset 128
    cross_a_lo, cross_a_hi = _gf128_clmul64(a_lo, b_hi)  # offset 64
    cross_b_lo, cross_b_hi = _gf128_clmul64(a_hi, b_lo)  # offset 64

    # Carry-less arithmetic has no carries: overlapping limbs are just XORed.
    return (
        low_lo,
        low_hi ^ cross_a_lo ^ cross_b_lo,
        high_lo ^ cross_a_hi ^ cross_b_hi,
        high_hi,
    )
94
+
95
+
96
def _gf128_reduce(r0: int, r1: int, r2: int, r3: int) -> tuple[int, int]:
    """Reduce a 256-bit polynomial modulo P(x) = x^128 + x^7 + x^2 + x + 1.

    The input is given as four 64-bit limbs (``r0`` lowest). Since
    x^128 = x^7 + x^2 + x + 1 (mod P), a limb sitting 128 bits too high can be
    folded down by carry-less multiplying it with ``_GF128_POLYNOMIAL``.

    Returns:
        ``(lo, hi)``, the two 64-bit limbs of the reduced 128-bit value.
    """
    # Fold r3 (bits 192..255): r3 * x^192 = (r3 * 0x87) * x^64, which lands on
    # r1 and spills a few bits into r2.
    fold_lo, fold_hi = _gf128_clmul64(r3, _GF128_POLYNOMIAL)
    r1 ^= fold_lo
    r2 ^= fold_hi

    # Fold r2 (bits 128..191, including the spill): r2 * x^128 = r2 * 0x87,
    # which lands on r0 and r1.
    fold_lo, fold_hi = _gf128_clmul64(r2, _GF128_POLYNOMIAL)
    return r0 ^ fold_lo, r1 ^ fold_hi
136
+
137
+
138
def gf128_mul_single(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Multiply two GF(2^128) elements.

    Args:
        a: Shape (2,) uint64 array holding one 128-bit element as [lo, hi].
        b: Shape (2,) uint64 array holding one 128-bit element as [lo, hi].

    Returns:
        Shape (2,) uint64 array [lo, hi] holding the product.
    """
    mask64 = (1 << 64) - 1

    # Work on Python ints: the 256-bit intermediate does not fit NumPy scalars.
    limbs = _gf128_clmul128(int(a[0]), int(a[1]), int(b[0]), int(b[1]))
    lo, hi = _gf128_reduce(*limbs)

    return np.array([lo & mask64, hi & mask64], dtype=np.uint64)
157
+
158
+
159
def gf128_mul_batch(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Multiply GF(2^128) elements pairwise.

    Args:
        a: Shape (..., 2) uint64 array; the last axis is [lo, hi].
        b: Shape (..., 2) uint64 array with the same number of elements.

    Returns:
        Array with the shape of ``a`` holding the element-wise products.
    """
    lhs = a.reshape(-1, 2)
    rhs = b.reshape(-1, 2)

    products = np.zeros_like(lhs)
    for row in range(lhs.shape[0]):
        products[row] = gf128_mul_single(lhs[row], rhs[row])

    return products.reshape(a.shape)
179
+
180
+
181
+ # =============================================================================
182
+ # OKVS (Oblivious Key-Value Store) - 3-Hash Garbled Cuckoo Table
183
+ # =============================================================================
184
+
185
+
186
def _hash_key_py(key: int, m: int, seed: tuple[int, int]) -> tuple[int, int, int]:
    """Map a key to three table positions in ``[0, m)``.

    Pure-Python stand-in for the AES-based hash of the C++ kernel: the key is
    mixed with the two seed words using 64-bit multiply/xor-shift steps. The
    result is deterministic for a given ``(key, m, seed)``.

    ``m`` must be non-zero (positions are reduced with ``% m``); the three
    positions are pairwise distinct only when ``m >= 3``.
    """
    mask64 = (1 << 64) - 1
    seed_lo, seed_hi = seed

    # Mix the key with each seed word, then apply one more xor-shift/multiply.
    mixed_a = ((key * GOLDEN_RATIO_64) ^ seed_lo) & mask64
    mixed_b = ((key * SPLITMIX64_GAMMA_2) ^ seed_hi) & mask64
    mixed_a = ((mixed_a ^ (mixed_a >> 33)) * SPLITMIX64_GAMMA_3) & mask64
    mixed_b = ((mixed_b ^ (mixed_b >> 33)) * SPLITMIX64_GAMMA_4) & mask64

    first = mixed_a % m
    second = mixed_b % m
    third = (mixed_a ^ mixed_b) % m

    # Resolve collisions by stepping forward with wrap-around: the second
    # position moves at most once, the third at most twice.
    if second == first:
        second = (second + 1) % m
    for _ in range(2):
        if third != first and third != second:
            break
        third = (third + 1) % m

    return int(first), int(second), int(third)
216
+
217
+
218
def okvs_solve(
    keys: np.ndarray,
    values: np.ndarray,
    m: int,
    seed: tuple[int, int] = (0xDEADBEEF, 0xCAFEBABE),
) -> np.ndarray:
    """Solve the OKVS system with the peeling algorithm.

    Every key selects three storage positions (via ``_hash_key_py``); the
    returned storage is built so that the XOR of those three entries equals the
    key's value.

    Args:
        keys: Shape (n,) uint64 array of keys.
        values: Shape (n, 2) uint64 array of values (128 bits each).
        m: Number of storage entries.
        seed: Two seed words forwarded to the hash; decoding must use the same
            seed and the same ``m``.

    Returns:
        Shape (m, 2) uint64 array representing the OKVS storage. Entries that
        are not needed by any equation stay zero.

    Raises:
        ValueError: If ``m`` is negative, or zero while keys are present.
        RuntimeError: If peeling gets stuck before every row is assigned.
    """
    n = len(keys)
    if m < 0 or (n > 0 and m == 0):
        # Without this check the hash fails with an opaque ZeroDivisionError.
        raise ValueError(
            f"OKVS storage size m must be positive when keys are given, got m={m}"
        )

    # Build the graph: each row (key) touches three columns (storage entries).
    rows: list[tuple[int, int, int]] = []
    col_to_rows: list[list[int]] = [[] for _ in range(m)]
    for i in range(n):
        cols = _hash_key_py(int(keys[i]), m, seed)
        rows.append(cols)
        for col in cols:
            col_to_rows[col].append(i)

    col_degree = [len(touching) for touching in col_to_rows]

    # Start from the columns that are touched by exactly one row.
    peel_queue = [j for j in range(m) if col_degree[j] == 1]

    row_removed = [False] * n
    col_removed = [False] * m
    assignment_stack: list[tuple[int, int]] = []  # (col, row)

    # The queue grows while it is consumed, so walk it by index.
    head = 0
    while head < len(peel_queue):
        j = peel_queue[head]
        head += 1

        if col_removed[j]:
            continue

        # Find the single live row of this column, if it still has one.
        owner_row = next(
            (r_idx for r_idx in col_to_rows[j] if not row_removed[r_idx]), -1
        )
        if owner_row == -1:
            # Its only row was already peeled through another column.
            col_removed[j] = True
            continue

        # Peel this (column, row) pair.
        assignment_stack.append((j, owner_row))
        col_removed[j] = True
        row_removed[owner_row] = True

        # Removing the row lowers the degree of its other live columns.
        for neighbor in rows[owner_row]:
            if neighbor == j or col_removed[neighbor]:
                continue
            col_degree[neighbor] -= 1
            if col_degree[neighbor] == 1:
                peel_queue.append(neighbor)

    if len(assignment_stack) != n:
        raise RuntimeError(
            f"OKVS core detected. Failed to peel all rows. "
            f"n={n}, m={m}, solved={len(assignment_stack)}"
        )

    # Back substitution in reverse peel order, so the other columns of a row
    # are final before its pivot column is computed.
    output = np.zeros((m, 2), dtype=np.uint64)
    for col, row in reversed(assignment_stack):
        h1, h2, h3 = rows[row]
        # output[col] is still zero here, so this is the XOR of the other
        # entries of the row.
        current_sum = output[h1] ^ output[h2] ^ output[h3]
        output[col] = values[row] ^ current_sum

    return output
309
+
310
+
311
def okvs_decode(
    keys: np.ndarray,
    storage: np.ndarray,
    m: int,
    seed: tuple[int, int] = (0xDEADBEEF, 0xCAFEBABE),
) -> np.ndarray:
    """Look up keys in an OKVS storage.

    Each result is the XOR of the three storage entries the key hashes to.

    Args:
        keys: Shape (n,) uint64 array of keys to query.
        storage: Shape (m, 2) uint64 array (the solved OKVS).
        m: Number of storage entries used for hashing.
        seed: Two seed words forwarded to the hash.

    Returns:
        Shape (n, 2) uint64 array of decoded values.
    """
    decoded = np.zeros((len(keys), 2), dtype=np.uint64)

    for slot, key in enumerate(keys):
        c1, c2, c3 = _hash_key_py(int(key), m, seed)
        decoded[slot] = storage[c1] ^ storage[c2] ^ storage[c3]

    return decoded
335
+
336
+
337
+ # =============================================================================
338
+ # AES-128 Expansion (PRG Fallback)
339
+ # =============================================================================
340
+
341
+
342
def aes_expand(seeds: np.ndarray, length: int) -> np.ndarray:
    """Expand each 128-bit seed into a pseudorandom sequence of blocks.

    Fallback that seeds a NumPy ``Generator`` per seed instead of using AES-NI.

    Args:
        seeds: Shape (num_seeds, 2) uint64 array of 128-bit seeds.
        length: Number of 128-bit blocks to generate per seed.

    Returns:
        Shape (num_seeds, length, 2) uint64 array.
    """
    # NOTE(review): NumPy's default generator is not a cryptographic PRG;
    # confirm that this fallback is acceptable wherever AES expansion is
    # required for security.
    upper = 0xFFFFFFFFFFFFFFFF
    output = np.zeros((seeds.shape[0], length, 2), dtype=np.uint64)

    for row in range(seeds.shape[0]):
        entropy = [int(seeds[row, 0]), int(seeds[row, 1])]
        generator = np.random.default_rng(entropy)
        # NOTE(review): `upper` is an exclusive bound here, so the value
        # 2**64 - 1 is never produced. Left as is, because changing the bound
        # would change the generated stream.
        output[row] = generator.integers(0, upper, size=(length, 2), dtype=np.uint64)

    return output
364
+
365
+
366
+ # =============================================================================
367
+ # LDPC Encoding (Sparse)
368
+ # =============================================================================
369
+
370
+
371
def ldpc_encode(
    message: np.ndarray, h_indices: np.ndarray, h_indptr: np.ndarray, m: int
) -> np.ndarray:
    """Compute the syndrome S = H @ message over XOR, with H in CSR form.

    Row ``i`` of the result is the XOR of the message rows whose indices are
    listed in ``h_indices[h_indptr[i]:h_indptr[i + 1]]``. This is the fallback
    used when the C++ kernel is not available.

    Args:
        message: (N, 2) uint64 message vector.
        h_indices: CSR column-index array of H.
        h_indptr: CSR row-pointer array of H (length m + 1).
        m: Number of rows in H (syndrome length).

    Returns:
        (m, 2) uint64 syndrome vector; rows of H without entries give zeros.
    """
    message = np.asarray(message)
    syndrome = np.zeros((m, 2), dtype=np.uint64)
    xor_reduce = np.bitwise_xor.reduce

    for i in range(m):
        start, end = int(h_indptr[i]), int(h_indptr[i + 1])
        if end <= start:
            continue  # empty row of H: the syndrome entry stays zero

        # Gather the selected message rows at once and XOR them in NumPy
        # instead of looping over them in Python.
        cols = np.asarray(h_indices[start:end]).astype(np.intp)
        syndrome[i] ^= xor_reduce(message[cols], axis=0)

    return syndrome