mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright 2025 Ant Group Co., Ltd.
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#include <cstdint>
|
|
18
|
+
#include <vector>
|
|
19
|
+
#include <stack>
|
|
20
|
+
#include <random>
|
|
21
|
+
#include <immintrin.h>
|
|
22
|
+
#include <cstring>
|
|
23
|
+
#include <cstdio>
|
|
24
|
+
#include <iostream>
|
|
25
|
+
#include <omp.h>
|
|
26
|
+
#include <atomic>
|
|
27
|
+
|
|
28
|
+
extern "C" {
|
|
29
|
+
|
|
30
|
+
// Number of bins for the Mega-Binning strategy.
// Keys are split into NUM_BINS independent peeling problems so that each
// bin's working set stays small; e.g. N = 1M keys gives roughly 1000 keys per
// bin. NOTE(review): the intent is for a bin to stay cache-resident; which
// cache level it actually fits in depends on the per-key footprint and has
// not been verified here.
// Each bin receives m / NUM_BINS (+1 for the first m % NUM_BINS bins) storage
// slots, and get_bin_local_indices draws 3 distinct slots per key, so callers
// should choose m >= 3 * NUM_BINS.
static const uint64_t NUM_BINS = 1024;
|
|
34
|
+
|
|
35
|
+
// The three bin-local slot positions that one key's row touches.
struct Indices {
    uint64_t h1, h2, h3;
};
|
|
38
|
+
|
|
39
|
+
// Stateless Bin Selection
|
|
40
|
+
// Maps a key to a deterministic bin index [0, NUM_BINS).
|
|
41
|
+
inline uint64_t get_bin_index(uint64_t key, __m128i seed) {
|
|
42
|
+
__m128i k = _mm_set_epi64x(0, key);
|
|
43
|
+
__m128i h = _mm_aesenc_si128(k, seed);
|
|
44
|
+
h = _mm_aesenc_si128(h, seed);
|
|
45
|
+
uint64_t v1 = _mm_extract_epi64(h, 0);
|
|
46
|
+
return v1 % NUM_BINS;
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
// Generate 3 positions within a local bin of size m_local.
|
|
50
|
+
inline Indices get_bin_local_indices(uint64_t key, uint64_t m_local, __m128i seed) {
|
|
51
|
+
// Use a distinct seed mix to decorrelate from bin selection
|
|
52
|
+
__m128i k = _mm_set_epi64x(0, key);
|
|
53
|
+
__m128i s2 = _mm_add_epi64(seed, _mm_set_epi64x(1, 1));
|
|
54
|
+
__m128i h = _mm_aesenc_si128(k, s2);
|
|
55
|
+
h = _mm_aesenc_si128(h, s2);
|
|
56
|
+
h = _mm_aesenc_si128(h, s2);
|
|
57
|
+
|
|
58
|
+
uint64_t r = _mm_extract_epi64(h, 0);
|
|
59
|
+
Indices idx;
|
|
60
|
+
|
|
61
|
+
// Fast modulo for local indices
|
|
62
|
+
idx.h1 = r % m_local;
|
|
63
|
+
r = r * 6364136223846793005ULL + 1442695040888963407ULL; // LCG step
|
|
64
|
+
idx.h2 = r % m_local;
|
|
65
|
+
r = r * 6364136223846793005ULL + 1442695040888963407ULL;
|
|
66
|
+
idx.h3 = r % m_local;
|
|
67
|
+
|
|
68
|
+
// Ensure distinct indices
|
|
69
|
+
if(idx.h2 == idx.h1) idx.h2 = (idx.h2 + 1) % m_local;
|
|
70
|
+
if(idx.h3 == idx.h1 || idx.h3 == idx.h2) {
|
|
71
|
+
idx.h3 = (idx.h3 + 1) % m_local;
|
|
72
|
+
if(idx.h3 == idx.h1 || idx.h3 == idx.h2) idx.h3 = (idx.h3 + 1) % m_local;
|
|
73
|
+
}
|
|
74
|
+
return idx;
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
// Core peeling solver for a single bin.
//
// Looks for slot values P_local[0..m) such that, for every key i in the bin,
//   P_local[h1] ^ P_local[h2] ^ P_local[h3] == vals[i]
// where (h1, h2, h3) = get_bin_local_indices(keys[i], m, seed).
//
// Parameters:
//   keys, vals - the bin's keys and their 128-bit target values (same length).
//   P_local    - this bin's m output slots. Back-substitution XORs in the
//                current contents of all three slots, including the one being
//                assigned, so the slots are expected to start at zero
//                (solve_okvs_opt memsets the whole output first).
//   m          - number of slots in this bin (see get_bin_local_indices for
//                the m >= 3 precondition).
//   seed       - hash seed shared with the decoder.
// Returns true on success. Returns false if peeling stalls before every row is
// removed (a non-empty 2-core); in that case P_local is not written.
bool solve_bin(
    const std::vector<uint64_t>& keys,
    const std::vector<__m128i>& vals,
    __m128i* P_local,
    uint64_t m,
    __m128i seed
) {
    uint64_t n = keys.size();
    if (n == 0) return true;

    // One record per row (key): the three columns it touches.
    struct Edge {
        uint64_t h1, h2, h3;
        uint64_t key_idx;
    };
    std::vector<Edge> edges(n);
    std::vector<int> col_degree(m, 0);

    // 1. Build Local Graph
    for(uint64_t i=0; i<n; ++i) {
        Indices idx = get_bin_local_indices(keys[i], m, seed);
        edges[i] = {idx.h1, idx.h2, idx.h3, i};
        col_degree[idx.h1]++;
        col_degree[idx.h2]++;
        col_degree[idx.h3]++;
    }

    // 2. CSR Construction
    // The rows touching column j are flat_rows[col_start[j] .. col_start[j+1]).
    std::vector<int> col_start(m + 1, 0);
    for(uint64_t j=0; j<m; ++j) {
        col_start[j+1] = col_start[j] + col_degree[j];
    }
    std::vector<int> flat_rows(n * 3);
    std::vector<int> fill_ptr = col_start;
    for(uint64_t i=0; i<n; ++i) {
        flat_rows[fill_ptr[edges[i].h1]++] = i;
        flat_rows[fill_ptr[edges[i].h2]++] = i;
        flat_rows[fill_ptr[edges[i].h3]++] = i;
    }

    // 3. Peeling Process
    // Despite its name, peel_stack is consumed FIFO through `head`.
    std::vector<int> peel_stack;
    peel_stack.reserve(m);
    for(uint64_t j=0; j<m; ++j) {
        if(col_degree[j] == 1) peel_stack.push_back(j);
    }

    std::vector<bool> row_removed(n, false);
    std::vector<bool> col_removed(m, false);

    struct Assignment {
        int col;
        int row_idx;
    };
    std::vector<Assignment> assignment_stack;
    assignment_stack.reserve(n);

    int head = 0;
    while(head < peel_stack.size()) {
        int j = peel_stack[head++];
        if(col_removed[j]) continue;

        // Find the single still-active row touching column j, if any.
        int owner_row = -1;
        for(int k=col_start[j]; k<col_start[j+1]; ++k) {
            int r = flat_rows[k];
            if(!row_removed[r]) {
                owner_row = r;
                break;
            }
        }
        // The degree dropped to 0 after j was queued: nothing to peel.
        if(owner_row == -1) {
            col_removed[j] = true;
            continue;
        }

        assignment_stack.push_back({j, owner_row});
        col_removed[j] = true;
        row_removed[owner_row] = true;

        // Removing the row lowers the degree of its other live columns.
        const auto& e = edges[owner_row];
        uint64_t nbs[3] = {e.h1, e.h2, e.h3};
        for(uint64_t nb : nbs) {
            if(nb == (uint64_t)j) continue;
            if(col_removed[nb]) continue;
            col_degree[nb]--;
            if(col_degree[nb] == 1) peel_stack.push_back((int)nb);
        }
    }

    // Some row was never peeled: the system cannot be solved by peeling.
    if(assignment_stack.size() != n) return false;

    // 4. Back-Substitution
    // Rows are handled in reverse peel order. When a row is handled, its other
    // columns already hold their final values, and its peeled column a.col has
    // not been written yet. Writing target ^ current therefore makes the row's
    // XOR equal its target (given zero-initialised slots). Each column is
    // written exactly once.
    for(int i=(int)assignment_stack.size()-1; i>=0; --i) {
        auto a = assignment_stack[i];
        const auto& e = edges[a.row_idx];

        __m128i val1 = _mm_loadu_si128(&P_local[e.h1]);
        __m128i val2 = _mm_loadu_si128(&P_local[e.h2]);
        __m128i val3 = _mm_loadu_si128(&P_local[e.h3]);
        __m128i target = vals[e.key_idx];

        __m128i current = _mm_xor_si128(_mm_xor_si128(val1, val2), val3);
        __m128i diff = _mm_xor_si128(target, current);

        _mm_storeu_si128(&P_local[a.col], diff);
    }
    return true;
}
|
|
185
|
+
|
|
186
|
+
void solve_okvs_opt(uint64_t* keys, uint64_t* values, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr) {
|
|
187
|
+
__m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);
|
|
188
|
+
|
|
189
|
+
// 1. Calculate Bin Boundaries
|
|
190
|
+
// We divide M evenly among bins. The remainder is distributed to the first few bins.
|
|
191
|
+
std::vector<uint64_t> bin_offsets(NUM_BINS + 1);
|
|
192
|
+
std::vector<uint64_t> m_per_bin(NUM_BINS);
|
|
193
|
+
|
|
194
|
+
uint64_t base_m = m / NUM_BINS;
|
|
195
|
+
uint64_t remainder = m % NUM_BINS;
|
|
196
|
+
|
|
197
|
+
uint64_t current_offset = 0;
|
|
198
|
+
for(uint64_t b=0; b<NUM_BINS; ++b) {
|
|
199
|
+
bin_offsets[b] = current_offset;
|
|
200
|
+
m_per_bin[b] = base_m + (b < remainder ? 1 : 0);
|
|
201
|
+
current_offset += m_per_bin[b];
|
|
202
|
+
}
|
|
203
|
+
bin_offsets[NUM_BINS] = m;
|
|
204
|
+
|
|
205
|
+
// 2. Partition Data (Stateless)
|
|
206
|
+
// Note on "Two-Choice Hashing":
|
|
207
|
+
// While Two-Choice Hashing (selecting the lighter of 2 potential bins) would significantly
|
|
208
|
+
// reduce max bin load variance, it introduces "Statefulness".
|
|
209
|
+
// The bin assignment for Key K would depend on the load of bins, which depends on other keys.
|
|
210
|
+
// In standard PSI protocols (like RR22), the Decode step must be capable of processing keys
|
|
211
|
+
// independently or without knowledge of the full set distribution (Sender/Receiver separation).
|
|
212
|
+
// Therefore, we use **Simple Binning** (Stateless Hash) where Bin(K) = H(K) % Bins.
|
|
213
|
+
// We mitigate the resulting variance ("Balls-in-Bins" problem) by using a slightly larger
|
|
214
|
+
// expansion factor (epsilon ~ 1.35) which is bandwidth-acceptable and ensures stability.
|
|
215
|
+
|
|
216
|
+
std::vector<std::vector<uint64_t>> bin_keys(NUM_BINS);
|
|
217
|
+
std::vector<std::vector<__m128i>> bin_vals(NUM_BINS);
|
|
218
|
+
|
|
219
|
+
// Pre-allocate to reduce reallocation overhead (assume ~uniform distribution)
|
|
220
|
+
// 1.5x margin for pre-allocation safety
|
|
221
|
+
size_t est_size = (n / NUM_BINS) * 3 / 2;
|
|
222
|
+
for(int b=0; b<NUM_BINS; ++b) {
|
|
223
|
+
bin_keys[b].reserve(est_size);
|
|
224
|
+
bin_vals[b].reserve(est_size);
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
const __m128i* V_ptr = (const __m128i*)values;
|
|
228
|
+
for(uint64_t i=0; i<n; ++i) {
|
|
229
|
+
uint64_t b = get_bin_index(keys[i], seed);
|
|
230
|
+
bin_keys[b].push_back(keys[i]);
|
|
231
|
+
bin_vals[b].push_back(_mm_loadu_si128(&V_ptr[i]));
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
// 3. Parallel Solve
|
|
235
|
+
// Each bin is solved independently. This logic is perfectly parallelizable (embarrassingly parallel).
|
|
236
|
+
// The working set for each bin (~1000 items) stays hot in L1 Cache.
|
|
237
|
+
memset(output, 0, m * 16);
|
|
238
|
+
__m128i* P_vec = (__m128i*)output;
|
|
239
|
+
|
|
240
|
+
#pragma omp parallel for schedule(dynamic)
|
|
241
|
+
for(uint64_t b=0; b<NUM_BINS; ++b) {
|
|
242
|
+
if(bin_keys[b].empty()) continue;
|
|
243
|
+
|
|
244
|
+
uint64_t offset = bin_offsets[b];
|
|
245
|
+
uint64_t valid_m = m_per_bin[b];
|
|
246
|
+
|
|
247
|
+
if(!solve_bin(bin_keys[b], bin_vals[b], &P_vec[offset], valid_m, seed)) {
|
|
248
|
+
#pragma omp critical
|
|
249
|
+
{
|
|
250
|
+
fprintf(stderr, "[ERROR] Bin %lu failed OKVS peeling. Items: %lu / M: %lu (Ratio: %.2f). Try increasing expansion factor.\n",
|
|
251
|
+
b, bin_keys[b].size(), valid_m, (double)valid_m / bin_keys[b].size());
|
|
252
|
+
}
|
|
253
|
+
}
|
|
254
|
+
}
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
void decode_okvs_opt(uint64_t* keys, uint64_t* storage, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr) {
|
|
258
|
+
__m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);
|
|
259
|
+
__m128i* P_vec = (__m128i*)storage;
|
|
260
|
+
__m128i* out_vec = (__m128i*)output;
|
|
261
|
+
|
|
262
|
+
// Replicate Boundary Logic
|
|
263
|
+
std::vector<uint64_t> bin_offsets(NUM_BINS + 1);
|
|
264
|
+
std::vector<uint64_t> m_per_bin(NUM_BINS);
|
|
265
|
+
uint64_t base_m = m / NUM_BINS;
|
|
266
|
+
uint64_t remainder = m % NUM_BINS;
|
|
267
|
+
uint64_t current_offset = 0;
|
|
268
|
+
for(uint64_t b=0; b<NUM_BINS; ++b) {
|
|
269
|
+
bin_offsets[b] = current_offset;
|
|
270
|
+
m_per_bin[b] = base_m + (b < remainder ? 1 : 0);
|
|
271
|
+
current_offset += m_per_bin[b];
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
// Parallel Stateless Decode
|
|
275
|
+
#pragma omp parallel for schedule(static)
|
|
276
|
+
for(uint64_t i=0; i<n; ++i) {
|
|
277
|
+
uint64_t b = get_bin_index(keys[i], seed);
|
|
278
|
+
|
|
279
|
+
uint64_t m_local = m_per_bin[b];
|
|
280
|
+
uint64_t offset = bin_offsets[b];
|
|
281
|
+
|
|
282
|
+
Indices idx = get_bin_local_indices(keys[i], m_local, seed);
|
|
283
|
+
|
|
284
|
+
__m128i val = _mm_xor_si128(
|
|
285
|
+
_mm_xor_si128(_mm_loadu_si128(&P_vec[offset + idx.h1]), _mm_loadu_si128(&P_vec[offset + idx.h2])),
|
|
286
|
+
_mm_loadu_si128(&P_vec[offset + idx.h3])
|
|
287
|
+
);
|
|
288
|
+
_mm_storeu_si128(&out_vec[i], val);
|
|
289
|
+
}
|
|
290
|
+
}
|
|
291
|
+
}
|
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Pure Python implementations of performance-critical kernels.
|
|
16
|
+
|
|
17
|
+
These implementations provide fallback functionality when native C++ kernels
|
|
18
|
+
(libmplang_kernels.so) are not available. They are functionally correct but
|
|
19
|
+
significantly slower than the optimized C++ versions.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
|
|
26
|
+
from mplang.v2.libs.mpc.common.constants import (
|
|
27
|
+
GOLDEN_RATIO_64,
|
|
28
|
+
SPLITMIX64_GAMMA_2,
|
|
29
|
+
SPLITMIX64_GAMMA_3,
|
|
30
|
+
SPLITMIX64_GAMMA_4,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# =============================================================================
|
|
34
|
+
# GF(2^128) Arithmetic
|
|
35
|
+
# =============================================================================
|
|
36
|
+
|
|
37
|
+
# Irreducible polynomial: P(x) = x^128 + x^7 + x^2 + x + 1
# In polynomial basis, this means x^128 = x^7 + x^2 + x + 1 (mod P).
# Only the low-order terms are stored; the x^128 term is implicit. Reducing a
# 64-bit word of weight 2^(128+k) therefore means carry-less multiplying it by
# this constant and XOR-ing the result back in at weight 2^k.
_GF128_POLYNOMIAL = 0x87  # x^7 + x^2 + x + 1 = 0b10000111 = 135
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _gf128_clmul64(a: int, b: int) -> tuple[int, int]:
    """Return the carry-less (GF(2)[x]) product of ``a`` and ``b``.

    Only bits 0..63 of ``b`` are examined. The product is split into two
    64-bit words and returned as ``(lo, hi)``, so that
    ``product == (hi << 64) | lo`` for 64-bit operands.
    """
    mask64 = (1 << 64) - 1
    lo = 0
    hi = 0
    for shift in range(64):
        if not (b >> shift) & 1:
            continue
        # XOR-accumulate a * x^shift, split across the two 64-bit words.
        lo ^= (a << shift) & mask64
        if shift:
            hi ^= a >> (64 - shift)
    return lo, hi
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _gf128_clmul128(
    a_lo: int, a_hi: int, b_lo: int, b_hi: int
) -> tuple[int, int, int, int]:
    """Carry-less product of two 128-bit values given as 64-bit halves.

    Uses the schoolbook method: four 64x64 carry-less products, combined by
    XOR (there are no carries in GF(2)[x]).

    Returns ``(r0, r1, r2, r3)``: the 256-bit product as four 64-bit words,
    least significant first (``r3 * 2^192 + r2 * 2^128 + r1 * 2^64 + r0``).
    """
    low_lo, low_hi = _gf128_clmul64(a_lo, b_lo)  # weight 2^0
    high_lo, high_hi = _gf128_clmul64(a_hi, b_hi)  # weight 2^128
    cross_lo, cross_hi = _gf128_clmul64(a_lo, b_hi)  # weight 2^64
    swap_lo, swap_hi = _gf128_clmul64(a_hi, b_lo)  # weight 2^64

    word0 = low_lo
    word1 = low_hi ^ cross_lo ^ swap_lo
    word2 = high_lo ^ cross_hi ^ swap_hi
    word3 = high_hi
    return word0, word1, word2, word3
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _gf128_reduce(r0: int, r1: int, r2: int, r3: int) -> tuple[int, int]:
    """Reduce a 256-bit polynomial modulo P(x) = x^128 + x^7 + x^2 + x + 1.

    The input is ``r3 * 2^192 + r2 * 2^128 + r1 * 2^64 + r0`` (64-bit words).
    Because x^128 == x^7 + x^2 + x + 1 (mod P), a word at weight 2^(128+k)
    folds down to weight 2^k after a carry-less multiply by 0x87.

    Returns ``(lo, hi)``, the reduced 128-bit element.
    """
    # Fold r3 (weight 2^192 = x^64 * x^128). r3 * 0x87 lands at weight 2^64,
    # i.e. in words r1 and r2; the part that reaches r2 still needs folding.
    fold_lo, fold_hi = _gf128_clmul64(r3, _GF128_POLYNOMIAL)
    mid = r1 ^ fold_lo
    top = r2 ^ fold_hi

    # Fold the updated r2 (weight 2^128). top * 0x87 lands in words r0 and r1
    # and is below 2^71, so nothing further needs reducing.
    low_lo, low_hi = _gf128_clmul64(top, _GF128_POLYNOMIAL)
    return r0 ^ low_lo, mid ^ low_hi
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def gf128_mul_single(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Multiply two GF(2^128) elements.

    Args:
        a: Shape (2,) uint64 array ``[lo, hi]`` encoding a 128-bit element.
        b: Shape (2,) uint64 array ``[lo, hi]`` encoding a 128-bit element.

    Returns:
        Shape (2,) uint64 array ``[lo, hi]`` holding ``a * b mod P(x)``.
    """
    words = _gf128_clmul128(int(a[0]), int(a[1]), int(b[0]), int(b[1]))
    lo, hi = _gf128_reduce(*words)
    mask64 = (1 << 64) - 1
    return np.array([lo & mask64, hi & mask64], dtype=np.uint64)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def gf128_mul_batch(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Batch multiply GF(2^128) elements element-wise.

    Args:
        a: Shape (..., 2) uint64 array (array-likes are converted to uint64).
        b: Array holding the same number of 128-bit elements as ``a``.

    Returns:
        uint64 array with ``a``'s shape holding the element-wise products.

    Raises:
        ValueError: If ``a`` and ``b`` hold different numbers of elements.
    """
    a = np.asarray(a, dtype=np.uint64)
    b = np.asarray(b, dtype=np.uint64)
    original_shape = a.shape
    a_flat = a.reshape(-1, 2)
    b_flat = b.reshape(-1, 2)

    # Without this check a longer ``b`` would be silently truncated, and a
    # shorter one would fail part-way through the loop with an IndexError.
    if a_flat.shape[0] != b_flat.shape[0]:
        raise ValueError(
            f"gf128_mul_batch: operand size mismatch ({a.shape} vs {b.shape})"
        )

    # Every row is written below, so the buffer need not be zeroed.
    result = np.empty_like(a_flat)
    for i, (x, y) in enumerate(zip(a_flat, b_flat)):
        result[i] = gf128_mul_single(x, y)

    return result.reshape(original_shape)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
# =============================================================================
|
|
182
|
+
# OKVS (Oblivious Key-Value Store) - 3-Hash Garbled Cuckoo Table
|
|
183
|
+
# =============================================================================
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _hash_key_py(key: int, m: int, seed: tuple[int, int]) -> tuple[int, int, int]:
    """Hash a key to 3 distinct indices in ``[0, m)``.

    Uses multiplicative hashing with 64-bit mixing constants, plus an
    xor-shift/multiply mixing step. This is NOT the AES-based hash of the C++
    kernels, so storage built here is only consistent with ``okvs_solve`` and
    ``okvs_decode`` in this module, not with the native kernels.

    Args:
        key: Non-negative integer key (callers pass ``int`` of a uint64).
        m: Number of storage slots; must be at least 3.
        seed: Two integers mixed into the two underlying hashes.

    Returns:
        Three pairwise-distinct indices.

    Raises:
        ValueError: If ``m < 3``. Three distinct indices are then impossible,
            and ``m == 0`` would otherwise raise ZeroDivisionError.
    """
    if m < 3:
        raise ValueError(
            f"OKVS storage needs at least 3 slots for 3 distinct indices, got m={m}"
        )

    mask64 = (1 << 64) - 1
    s0, s1 = seed

    # Mix key with seed
    h1 = ((key * GOLDEN_RATIO_64) ^ s0) & mask64
    h2 = ((key * SPLITMIX64_GAMMA_2) ^ s1) & mask64

    # Additional mixing
    h1 = ((h1 ^ (h1 >> 33)) * SPLITMIX64_GAMMA_3) & mask64
    h2 = ((h2 ^ (h2 >> 33)) * SPLITMIX64_GAMMA_4) & mask64

    idx1 = h1 % m
    idx2 = h2 % m
    idx3 = (h1 ^ h2) % m

    # Enforce distinct indices by probing forward. With m >= 3, at most one
    # step is needed for idx2 and two for idx3.
    if idx2 == idx1:
        idx2 = (idx2 + 1) % m
    if idx3 == idx1 or idx3 == idx2:
        idx3 = (idx3 + 1) % m
        if idx3 == idx1 or idx3 == idx2:
            idx3 = (idx3 + 1) % m

    return int(idx1), int(idx2), int(idx3)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def okvs_solve(
    keys: np.ndarray,
    values: np.ndarray,
    m: int,
    seed: tuple[int, int] = (0xDEADBEEF, 0xCAFEBABE),
) -> np.ndarray:
    """Solve the OKVS system using the peeling algorithm.

    Finds storage ``S`` of ``m`` 128-bit slots such that, for every key,
    ``S[h1] ^ S[h2] ^ S[h3] == value``, where ``(h1, h2, h3)`` are the key's
    indices from ``_hash_key_py``. Slots not assigned during peeling stay zero.

    Args:
        keys: Shape (n,) uint64 array of keys
        values: Shape (n, 2) uint64 array of values (128-bit each)
        m: Size of output storage
        seed: Hash seed; ``okvs_decode`` must be called with the same seed.

    Returns:
        Shape (m, 2) uint64 array representing the OKVS storage

    Raises:
        RuntimeError: If peeling stalls before every row is removed (a
            non-empty 2-core), e.g. when ``m`` is too small relative to ``n``
            or when keys repeat.
    """
    n = len(keys)

    # Build graph: for each row, compute its 3 column indices
    rows: list[tuple[int, int, int]] = []
    col_to_rows: dict[int, list[int]] = {j: [] for j in range(m)}

    for i in range(n):
        h1, h2, h3 = _hash_key_py(int(keys[i]), m, seed)
        rows.append((h1, h2, h3))
        col_to_rows[h1].append(i)
        col_to_rows[h2].append(i)
        col_to_rows[h3].append(i)

    # Compute column degrees
    col_degree = [len(col_to_rows[j]) for j in range(m)]

    # Initialize peel queue with degree-1 columns
    peel_queue = [j for j in range(m) if col_degree[j] == 1]

    row_removed = [False] * n
    col_removed = [False] * m
    assignment_stack: list[tuple[int, int]] = []  # (col, row)

    # peel_queue is processed FIFO through `head`. This order decides which
    # column each row is assigned to, and therefore the exact storage produced.
    head = 0
    while head < len(peel_queue):
        j = peel_queue[head]
        head += 1

        if col_removed[j]:
            continue

        # Find the single active row for this column
        owner_row = -1
        for r_idx in col_to_rows[j]:
            if not row_removed[r_idx]:
                owner_row = r_idx
                break

        # The degree fell to 0 after j was queued: nothing left to peel here.
        if owner_row == -1:
            col_removed[j] = True
            continue

        # Peel this (column, row) pair
        assignment_stack.append((j, owner_row))
        col_removed[j] = True
        row_removed[owner_row] = True

        # Update neighbor column degrees
        h1, h2, h3 = rows[owner_row]
        for neighbor in (h1, h2, h3):
            if neighbor == j or col_removed[neighbor]:
                continue
            col_degree[neighbor] -= 1
            if col_degree[neighbor] == 1:
                peel_queue.append(neighbor)

    if len(assignment_stack) != n:
        raise RuntimeError(
            f"OKVS core detected. Failed to peel all rows. "
            f"n={n}, m={m}, solved={len(assignment_stack)}"
        )

    # Back substitution (solve in reverse order).
    # When a row is handled, its other columns already hold their final
    # values, and its peeled column `col` is still zero. Setting it to
    # target ^ current_sum makes the row's three-slot XOR equal its value.
    # Each column is written exactly once.
    output = np.zeros((m, 2), dtype=np.uint64)

    for col, row in reversed(assignment_stack):
        h1, h2, h3 = rows[row]
        # Current sum of columns in this row
        current_sum = output[h1] ^ output[h2] ^ output[h3]
        # Compute value needed for col to make sum equal target
        target = values[row]
        diff = target ^ current_sum
        output[col] = diff

    return output
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def okvs_decode(
    keys: np.ndarray,
    storage: np.ndarray,
    m: int,
    seed: tuple[int, int] = (0xDEADBEEF, 0xCAFEBABE),
) -> np.ndarray:
    """Look up keys in a solved OKVS.

    Each decoded value is the XOR of the three storage slots the key hashes
    to. For keys encoded by ``okvs_solve`` with the same ``m`` and ``seed``,
    this reproduces the stored value.

    Args:
        keys: Shape (n,) uint64 array of keys to query.
        storage: Shape (m, 2) uint64 array (the solved OKVS).
        m: Size of storage.
        seed: Hash seed used when the storage was solved.

    Returns:
        Shape (n, 2) uint64 array of decoded values.
    """
    decoded = np.zeros((len(keys), 2), dtype=np.uint64)
    for row, key in enumerate(keys):
        first, second, third = _hash_key_py(int(key), m, seed)
        decoded[row] = storage[first] ^ storage[second] ^ storage[third]
    return decoded
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
# =============================================================================
|
|
338
|
+
# AES-128 Expansion (PRG Fallback)
|
|
339
|
+
# =============================================================================
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def aes_expand(seeds: np.ndarray, length: int) -> np.ndarray:
    """Expand seeds to a pseudorandom sequence.

    This is a fallback that uses NumPy's PRNG (``np.random.default_rng``)
    instead of AES-NI. Its output differs from the native AES expansion.
    NumPy's bit generators are not designed to be cryptographically secure,
    so do not rely on this where the protocol needs a cryptographic PRG.

    Args:
        seeds: Shape (num_seeds, 2) uint64 array of 128-bit seeds.
        length: Number of 128-bit blocks to generate per seed.

    Returns:
        Shape (num_seeds, length, 2) uint64 array. The expansion is
        deterministic per seed, and each 64-bit word can take any value in
        [0, 2**64 - 1].
    """
    num_seeds = seeds.shape[0]
    output = np.zeros((num_seeds, length, 2), dtype=np.uint64)

    for i in range(num_seeds):
        rng = np.random.default_rng([int(seeds[i, 0]), int(seeds[i, 1])])
        # Generator.integers excludes `high` by default; endpoint=True makes
        # 0xFFFFFFFFFFFFFFFF reachable, so the full 64-bit range is covered.
        output[i] = rng.integers(
            0, 0xFFFFFFFFFFFFFFFF, size=(length, 2), dtype=np.uint64, endpoint=True
        )
    return output
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
# =============================================================================
|
|
367
|
+
# LDPC Encoding (Sparse)
|
|
368
|
+
# =============================================================================
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def ldpc_encode(
    message: np.ndarray, h_indices: np.ndarray, h_indptr: np.ndarray, m: int
) -> np.ndarray:
    """Compute syndrome S = H @ message over GF(2) using a sparse CSR matrix.

    This is the fallback when the C++ kernel is not available. Row ``i`` of
    the syndrome is the XOR of ``message[j]`` over the column indices ``j``
    stored for row ``i``. A column listed twice in a row cancels out.

    Args:
        message: (N, 2) uint64 message vector.
        h_indices: CSR indices array for H.
        h_indptr: CSR indptr array for H (length m+1).
        m: Number of rows in H (syndrome length).

    Returns:
        (m, 2) uint64 syndrome vector; rows with no entries are zero.
    """
    syndrome = np.zeros((m, 2), dtype=np.uint64)

    for i in range(m):
        start, end = int(h_indptr[i]), int(h_indptr[i + 1])
        cols = np.asarray(h_indices[start:end], dtype=np.intp)
        if cols.size:
            # One vectorized XOR-reduction per row, instead of a Python-level
            # loop over every nonzero.
            syndrome[i] = np.bitwise_xor.reduce(message[cols], axis=0)

    return syndrome
|