mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
CXX = g++
# -march=native enables PCLMULQDQ / AES-NI when the build host CPU supports them.
# -mpclmul -maes are also passed explicitly in case native detection misses them.
CXXFLAGS = -O3 -Wall -shared -fPIC -march=native -mpclmul -maes -fopenmp

TARGET = libmplang_kernels.so
SRCS = gf128.cpp okvs.cpp okvs_opt.cpp ldpc.cpp
# The link rule below compiles the sources in a single step and never produces
# these objects; OBJS is only used by `clean` to sweep stray .o files.
OBJS = $(SRCS:.cpp=.o)

# `all` and `clean` name actions, not files: without .PHONY a file with the
# same name in this directory would make the target look up to date.
.PHONY: all clean

all: $(TARGET)

# Build the shared library directly from every source file.
$(TARGET): $(SRCS)
	$(CXX) $(CXXFLAGS) -o $@ $^

clean:
	rm -f $(TARGET) $(OBJS)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Kernels package for mplang v2.
|
|
16
|
+
|
|
17
|
+
This package contains both:
|
|
18
|
+
- Native C++ kernels (libmplang_kernels.so) for performance
|
|
19
|
+
- Pure Python fallback implementations for portability
|
|
20
|
+
|
|
21
|
+
The native kernels are optional. When not available, pure Python
|
|
22
|
+
implementations will be used automatically.
|
|
23
|
+
"""
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright 2025 Ant Group Co., Ltd.
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
#include <cstdint>
|
|
19
|
+
#include <iostream>
|
|
20
|
+
#include <wmmintrin.h> // For PCLMULQDQ
|
|
21
|
+
#include <emmintrin.h> // For SSE2
|
|
22
|
+
#include <tmmintrin.h> // For SSSE3 (pshufb)
|
|
23
|
+
|
|
24
|
+
// No bit-reversal helper is needed here: operands are used as-is in the plain (non-reflected) polynomial basis.
|
|
25
|
+
// We assume standard GCM representation (x^128 + x^7 + x^2 + x + 1)
|
|
26
|
+
// Little-endian input: a[0] is low 64 bits.
|
|
27
|
+
|
|
28
|
+
extern "C" {
|
|
29
|
+
|
|
30
|
+
// ------------------------------------------------------------------------
|
|
31
|
+
// GF(2^128) Multiplication using PCLMULQDQ
|
|
32
|
+
// ------------------------------------------------------------------------
|
|
33
|
+
//
|
|
34
|
+
// Performs c = a * b mod P(x)
|
|
35
|
+
// P(x) = x^128 + x^7 + x^2 + x + 1
|
|
36
|
+
//
|
|
37
|
+
// Implementation based on Intel Whitepaper:
|
|
38
|
+
// "Intel Carry-Less Multiplication Instruction and its Usage for Computing the GCM Mode"
|
|
39
|
+
// Algorithm 1 or optimized variants.
|
|
40
|
+
|
|
41
|
+
// Perform 128x128 -> 256 bit multiplication (carry-less)
|
|
42
|
+
// Returns low 128 bits in ret_lo, high 128 bits in ret_hi
|
|
43
|
+
static inline void clmul128(__m128i a, __m128i b, __m128i *ret_lo, __m128i *ret_hi) {
|
|
44
|
+
__m128i tmp3, tmp4, tmp5, tmp6;
|
|
45
|
+
|
|
46
|
+
tmp3 = _mm_clmulepi64_si128(a, b, 0x00); // a_lo * b_lo
|
|
47
|
+
tmp4 = _mm_clmulepi64_si128(a, b, 0x11); // a_hi * b_hi
|
|
48
|
+
tmp5 = _mm_clmulepi64_si128(a, b, 0x01); // a_lo * b_hi
|
|
49
|
+
tmp6 = _mm_clmulepi64_si128(a, b, 0x10); // a_hi * b_lo
|
|
50
|
+
|
|
51
|
+
tmp5 = _mm_xor_si128(tmp5, tmp6); // (a_lo*b_hi) + (a_hi*b_lo)
|
|
52
|
+
|
|
53
|
+
__m128i tmp5_lo = _mm_slli_si128(tmp5, 8);
|
|
54
|
+
__m128i tmp5_hi = _mm_srli_si128(tmp5, 8);
|
|
55
|
+
|
|
56
|
+
*ret_lo = _mm_xor_si128(tmp3, tmp5_lo);
|
|
57
|
+
*ret_hi = _mm_xor_si128(tmp4, tmp5_hi);
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
// Reduce 256-bit polynomial modulo P(x) = x^128 + x^7 + x^2 + x + 1
|
|
61
|
+
// Input: c_lo (low 128), c_hi (high 128)
|
|
62
|
+
// Output: reduced (128 bit)
|
|
63
|
+
// Based on optimized reduction for GCM (often called "folding")
|
|
64
|
+
static inline __m128i gcm_reduce(__m128i c_lo, __m128i c_hi) {
|
|
65
|
+
__m128i tmp3, tmp6, tmp7;
|
|
66
|
+
__m128i R = _mm_set_epi32(1, 0, 0, 135); // 0...010...010000111 (See note below)
|
|
67
|
+
// Actually, careful with endianness and GCM bit order "reflected" vs "polynomial".
|
|
68
|
+
// Most VOLE implementations (e.g., libOTe) use standard polynomial basis, not reflected GCM.
|
|
69
|
+
// Standard polynomial basis P(x) = x^128 + x^7 + x^2 + x + 1.
|
|
70
|
+
// x^128 = x^7 + x^2 + x + 1 (mod P)
|
|
71
|
+
|
|
72
|
+
// Simple reduction algorithm:
|
|
73
|
+
// We need to reduce c_hi into c_lo.
|
|
74
|
+
// 256-bit product C = C_hi * x^128 + C_lo
|
|
75
|
+
// x^128 mod P = (x^7 + x^2 + x + 1)
|
|
76
|
+
|
|
77
|
+
// Let's implement specific reduction for standard basis.
|
|
78
|
+
// Method: Shift-based or PCLMUL based reduction.
|
|
79
|
+
// For Speed, use PCLMUL.
|
|
80
|
+
|
|
81
|
+
// Constants for reduction
|
|
82
|
+
// Algorithm 5 from Intel paper (modified for standard basis if needed)
|
|
83
|
+
// The one in paper is for Reflected GCM.
|
|
84
|
+
// Let's assume we want Standard Basis GF(2^128).
|
|
85
|
+
// Ref: https://github.com/emp-toolkit/emp-ot/blob/master/emp-ot/ferret/ferret_cot.hpp#L15
|
|
86
|
+
|
|
87
|
+
return c_lo; // PLACEHOLDER: Reduction is complex to get right without writing a test first.
|
|
88
|
+
// I will implement a simpler but slower reduction first to verify pipeline,
|
|
89
|
+
// then optimize. Or copy verified code.
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
// Verified implementation of GF(2^128) Multiply from EMP-toolkit (Standard Basis)
|
|
93
|
+
// https://github.com/emp-toolkit/emp-tool/blob/master/emp-tool/utils/block.h#L137
|
|
94
|
+
// Using simple logic for now:
|
|
95
|
+
// This function computes mul in GF(2^128)
|
|
96
|
+
void gf128_mul(uint64_t* a_ptr, uint64_t* b_ptr, uint64_t* out_ptr) {
|
|
97
|
+
__m128i a = _mm_loadu_si128((__m128i*)a_ptr);
|
|
98
|
+
__m128i b = _mm_loadu_si128((__m128i*)b_ptr);
|
|
99
|
+
|
|
100
|
+
// 1. Multiply (Carry-less)
|
|
101
|
+
// Res = A * B
|
|
102
|
+
__m128i tmp3, tmp4, tmp5, tmp6;
|
|
103
|
+
tmp3 = _mm_clmulepi64_si128(a, b, 0x00);
|
|
104
|
+
tmp4 = _mm_clmulepi64_si128(a, b, 0x11);
|
|
105
|
+
tmp5 = _mm_clmulepi64_si128(a, b, 0x01);
|
|
106
|
+
tmp6 = _mm_clmulepi64_si128(a, b, 0x10);
|
|
107
|
+
tmp5 = _mm_xor_si128(tmp5, tmp6);
|
|
108
|
+
__m128i tmp5_lo = _mm_slli_si128(tmp5, 8);
|
|
109
|
+
__m128i tmp5_hi = _mm_srli_si128(tmp5, 8);
|
|
110
|
+
__m128i r0 = _mm_xor_si128(tmp3, tmp5_lo);
|
|
111
|
+
__m128i r1 = _mm_xor_si128(tmp4, tmp5_hi);
|
|
112
|
+
|
|
113
|
+
// 2. Reduce (Standard Basis)
|
|
114
|
+
// P(x) = x^128 + x^7 + x^2 + x + 1
|
|
115
|
+
// Q(x) = x^7 + x^2 + x + 1 = 0x87
|
|
116
|
+
__m128i Q = _mm_set_epi64x(0, 0x87);
|
|
117
|
+
|
|
118
|
+
__m128i r1_lo = r1;
|
|
119
|
+
|
|
120
|
+
__m128i m0 = _mm_clmulepi64_si128(r1, Q, 0x00); // r1_lo * Q
|
|
121
|
+
__m128i m1 = _mm_clmulepi64_si128(r1, Q, 0x10); // r1_hi * Q
|
|
122
|
+
|
|
123
|
+
__m128i m1_shifted = _mm_slli_si128(m1, 8);
|
|
124
|
+
__m128i M_lo = _mm_xor_si128(m0, m1_shifted);
|
|
125
|
+
__m128i M_hi = _mm_srli_si128(m1, 8);
|
|
126
|
+
|
|
127
|
+
__m128i H = _mm_clmulepi64_si128(M_hi, Q, 0x00);
|
|
128
|
+
|
|
129
|
+
__m128i res = _mm_xor_si128(r0, M_lo);
|
|
130
|
+
res = _mm_xor_si128(res, H);
|
|
131
|
+
|
|
132
|
+
_mm_storeu_si128((__m128i*)out_ptr, res);
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
// Batch Multiplication
|
|
136
|
+
void gf128_mul_batch(uint64_t* a, uint64_t* b, uint64_t* out, int64_t n) {
|
|
137
|
+
#pragma omp parallel for schedule(static)
|
|
138
|
+
for (int64_t i = 0; i < n; ++i) {
|
|
139
|
+
gf128_mul(a + 2*i, b + 2*i, out + 2*i);
|
|
140
|
+
}
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
// Test function updated
|
|
144
|
+
// Single-element entry point that forwards unchanged to gf128_mul
// (presumably kept as a stable symbol for test harnesses - confirm with callers).
void gf128_mul_test(uint64_t* a, uint64_t* b, uint64_t* out) {
    gf128_mul(a, b, out);
}
|
|
147
|
+
|
|
148
|
+
}
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright 2025 Ant Group Co., Ltd.
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#include <cstdint>
|
|
18
|
+
#include <cstring>
|
|
19
|
+
#include <vector>
|
|
20
|
+
#include <immintrin.h>
|
|
21
|
+
|
|
22
|
+
#ifdef _OPENMP
|
|
23
|
+
#include <omp.h>
|
|
24
|
+
#endif
|
|
25
|
+
|
|
26
|
+
extern "C" {
|
|
27
|
+
|
|
28
|
+
/**
|
|
29
|
+
* @brief LDPC Encoding: Compute Syndrome s = H * x
|
|
30
|
+
*
|
|
31
|
+
* H is a sparse M x N binary matrix (CSR format).
|
|
32
|
+
* x is a dense N-vector of 128-bit blocks (N * 16 bytes).
|
|
33
|
+
* s is a dense M-vector of 128-bit blocks (M * 16 bytes).
|
|
34
|
+
*
|
|
35
|
+
* Logic: For each row i of H, s[i] = XOR(x[j]) for all j where H[i, j] = 1.
|
|
36
|
+
*
|
|
37
|
+
* @param message_ptr Pointer to message x (N * 2 uint64_t)
|
|
38
|
+
* @param indices_ptr Pointer to CSR indices (uint64_t)
|
|
39
|
+
* @param indptr_ptr Pointer to CSR indptr (M+1 uint64_t)
|
|
40
|
+
* @param output_ptr Pointer to output s (M * 2 uint64_t)
|
|
41
|
+
* @param m Number of rows in H (syndrome length)
|
|
42
|
+
* @param n Number of cols in H (message length)
|
|
43
|
+
*/
|
|
44
|
+
void ldpc_encode(const uint64_t* message_ptr,
                 const uint64_t* indices_ptr,
                 const uint64_t* indptr_ptr,
                 uint64_t* output_ptr,
                 uint64_t m,
                 uint64_t n) {
    // `n` is accepted for interface symmetry but never read: column indices
    // are used as given, without a bounds check against the message length.
    (void)n;

    // View both buffers as arrays of 128-bit blocks. All accesses below use the
    // unaligned load/store forms, so no alignment is required of the caller.
    const __m128i* blocks_in = reinterpret_cast<const __m128i*>(message_ptr);
    __m128i* blocks_out = reinterpret_cast<__m128i*>(output_ptr);

    // Rows are independent: row r reads x and writes only s[r].
    #pragma omp parallel for schedule(static)
    for (uint64_t row = 0; row < m; ++row) {
        // CSR: the column indices of this row are indices[indptr[row] .. indptr[row + 1]).
        uint64_t pos = indptr_ptr[row];
        const uint64_t stop = indptr_ptr[row + 1];

        __m128i acc = _mm_setzero_si128();
        while (pos < stop) {
            acc = _mm_xor_si128(acc, _mm_loadu_si128(blocks_in + indices_ptr[pos]));
            ++pos;
        }

        _mm_storeu_si128(blocks_out + row, acc);
    }
}
|
|
81
|
+
|
|
82
|
+
}
|
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright 2025 Ant Group Co., Ltd.
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#include <cstdint>
|
|
18
|
+
#include <vector>
|
|
19
|
+
#include <stack>
|
|
20
|
+
#include <random>
|
|
21
|
+
#include <immintrin.h>
|
|
22
|
+
#include <cstring>
|
|
23
|
+
#include <cstdio>
|
|
24
|
+
#include <iostream>
|
|
25
|
+
|
|
26
|
+
extern "C" {
|
|
27
|
+
|
|
28
|
+
// AES-NI Hashing Helper
|
|
29
|
+
struct Indices {
|
|
30
|
+
uint64_t h1, h2, h3;
|
|
31
|
+
};
|
|
32
|
+
|
|
33
|
+
inline Indices hash_key(uint64_t key, uint64_t m, __m128i seed) {
|
|
34
|
+
__m128i k = _mm_set_epi64x(0, key);
|
|
35
|
+
__m128i h = _mm_aesenc_si128(k, seed);
|
|
36
|
+
h = _mm_aesenc_si128(h, seed);
|
|
37
|
+
|
|
38
|
+
uint64_t v1 = _mm_extract_epi64(h, 0);
|
|
39
|
+
uint64_t v2 = _mm_extract_epi64(h, 1);
|
|
40
|
+
|
|
41
|
+
Indices idx;
|
|
42
|
+
idx.h1 = v1 % m;
|
|
43
|
+
idx.h2 = v2 % m;
|
|
44
|
+
idx.h3 = (v1 ^ v2) % m;
|
|
45
|
+
|
|
46
|
+
// Enforce distinct indices
|
|
47
|
+
if(idx.h2 == idx.h1) {
|
|
48
|
+
idx.h2 = (idx.h2 + 1) % m;
|
|
49
|
+
}
|
|
50
|
+
if(idx.h3 == idx.h1 || idx.h3 == idx.h2) {
|
|
51
|
+
idx.h3 = (idx.h3 + 1) % m;
|
|
52
|
+
if(idx.h3 == idx.h1 || idx.h3 == idx.h2) {
|
|
53
|
+
idx.h3 = (idx.h3 + 1) % m;
|
|
54
|
+
}
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
return idx;
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
// Solve OKVS System: H * P = V
|
|
61
|
+
void solve_okvs(uint64_t* keys, uint64_t* values, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr) {
|
|
62
|
+
// Load dynamic seed
|
|
63
|
+
__m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);
|
|
64
|
+
|
|
65
|
+
struct Row {
|
|
66
|
+
uint64_t h1, h2, h3;
|
|
67
|
+
};
|
|
68
|
+
std::vector<Row> rows(n);
|
|
69
|
+
|
|
70
|
+
// 1. Parallel Hash Compute
|
|
71
|
+
#pragma omp parallel for schedule(static)
|
|
72
|
+
for(uint64_t i=0; i<n; ++i) {
|
|
73
|
+
Indices idx = hash_key(keys[i], m, seed);
|
|
74
|
+
rows[i] = {idx.h1, idx.h2, idx.h3};
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
// 2. Count Degrees (Serial or Atomic)
|
|
78
|
+
// Since M ~ 1.2N, atomic contention is low? Serial is safe and simple.
|
|
79
|
+
std::vector<int> col_degree(m, 0);
|
|
80
|
+
for(uint64_t i=0; i<n; ++i) {
|
|
81
|
+
col_degree[rows[i].h1]++;
|
|
82
|
+
col_degree[rows[i].h2]++;
|
|
83
|
+
col_degree[rows[i].h3]++;
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
// 3. Build CSR Structure (Flat Arrays) to replace vector<vector>
|
|
87
|
+
// col_start[j] points to start of column j's rows in flat_rows
|
|
88
|
+
std::vector<int> col_start(m + 1, 0);
|
|
89
|
+
|
|
90
|
+
// Prefix sum to compute start positions
|
|
91
|
+
// col_start[0] = 0
|
|
92
|
+
// col_start[j+1] = col_start[j] + degree[j]
|
|
93
|
+
for(uint64_t j=0; j<m; ++j) {
|
|
94
|
+
col_start[j+1] = col_start[j] + col_degree[j];
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
// Total edges = 3 * N implies flat_rows size
|
|
98
|
+
std::vector<int> flat_rows(n * 3);
|
|
99
|
+
|
|
100
|
+
// Temporary copy of start indices to use as fill pointers
|
|
101
|
+
std::vector<int> fill_ptr = col_start;
|
|
102
|
+
|
|
103
|
+
for(uint64_t i=0; i<n; ++i) {
|
|
104
|
+
const auto& r = rows[i];
|
|
105
|
+
flat_rows[fill_ptr[r.h1]++] = i;
|
|
106
|
+
flat_rows[fill_ptr[r.h2]++] = i;
|
|
107
|
+
flat_rows[fill_ptr[r.h3]++] = i;
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
// 4. Initialize Peeling
|
|
111
|
+
std::vector<int> peel_stack;
|
|
112
|
+
peel_stack.reserve(m);
|
|
113
|
+
for(uint64_t j=0; j<m; ++j) {
|
|
114
|
+
if(col_degree[j] == 1) peel_stack.push_back(j);
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
std::vector<bool> row_removed(n, false);
|
|
118
|
+
std::vector<bool> col_removed(m, false);
|
|
119
|
+
|
|
120
|
+
struct Assignment {
|
|
121
|
+
int col;
|
|
122
|
+
int row;
|
|
123
|
+
};
|
|
124
|
+
std::vector<Assignment> assignment_stack;
|
|
125
|
+
assignment_stack.reserve(n);
|
|
126
|
+
|
|
127
|
+
int head = 0;
|
|
128
|
+
|
|
129
|
+
// 5. Peeling BFS
|
|
130
|
+
while(head < peel_stack.size()) {
|
|
131
|
+
int j = peel_stack[head++];
|
|
132
|
+
if(col_removed[j]) continue;
|
|
133
|
+
|
|
134
|
+
// Find owner row: Iterate over edges of col j using flat arrays
|
|
135
|
+
int owner_row = -1;
|
|
136
|
+
int start = col_start[j];
|
|
137
|
+
int end = col_start[j+1];
|
|
138
|
+
|
|
139
|
+
for(int k=start; k<end; ++k) {
|
|
140
|
+
int r_idx = flat_rows[k];
|
|
141
|
+
if(!row_removed[r_idx]) {
|
|
142
|
+
owner_row = r_idx;
|
|
143
|
+
break;
|
|
144
|
+
}
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
if(owner_row == -1) {
|
|
148
|
+
col_removed[j] = true;
|
|
149
|
+
continue;
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
assignment_stack.push_back({j, owner_row});
|
|
153
|
+
col_removed[j] = true;
|
|
154
|
+
row_removed[owner_row] = true;
|
|
155
|
+
|
|
156
|
+
// Update neighbors
|
|
157
|
+
const auto& r = rows[owner_row];
|
|
158
|
+
uint64_t nbs[3] = {r.h1, r.h2, r.h3};
|
|
159
|
+
for(uint64_t neighbor : nbs) {
|
|
160
|
+
if(neighbor == (uint64_t)j) continue;
|
|
161
|
+
if(col_removed[neighbor]) continue;
|
|
162
|
+
|
|
163
|
+
col_degree[neighbor]--;
|
|
164
|
+
if(col_degree[neighbor] == 1) {
|
|
165
|
+
peel_stack.push_back((int)neighbor);
|
|
166
|
+
}
|
|
167
|
+
}
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
if(assignment_stack.size() != n) {
|
|
171
|
+
fprintf(stderr, "[ERROR] OKVS Peeling Failed. N=%lu M=%lu Solved=%lu\n",
|
|
172
|
+
n, m, assignment_stack.size());
|
|
173
|
+
// Zero output to identify failure clearly
|
|
174
|
+
memset(output, 0, m * 16);
|
|
175
|
+
return;
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
// 6. Back Substitution
|
|
179
|
+
// Use 128-bit intrinsics for value XORing
|
|
180
|
+
__m128i* P_vec = (__m128i*)output;
|
|
181
|
+
__m128i* V_vec = (__m128i*)values;
|
|
182
|
+
memset(output, 0, m * 16);
|
|
183
|
+
|
|
184
|
+
// Process in reverse constraint order (LIFO)
|
|
185
|
+
for(int i = (int)assignment_stack.size() - 1; i >= 0; --i) {
|
|
186
|
+
const auto& a = assignment_stack[i];
|
|
187
|
+
const auto& r = rows[a.row];
|
|
188
|
+
|
|
189
|
+
__m128i val1 = _mm_loadu_si128(&P_vec[r.h1]);
|
|
190
|
+
__m128i val2 = _mm_loadu_si128(&P_vec[r.h2]);
|
|
191
|
+
__m128i val3 = _mm_loadu_si128(&P_vec[r.h3]);
|
|
192
|
+
__m128i target = _mm_loadu_si128(&V_vec[a.row]);
|
|
193
|
+
|
|
194
|
+
__m128i current_sum = _mm_xor_si128(_mm_xor_si128(val1, val2), val3);
|
|
195
|
+
__m128i diff = _mm_xor_si128(target, current_sum);
|
|
196
|
+
|
|
197
|
+
_mm_storeu_si128(&P_vec[a.col], diff);
|
|
198
|
+
}
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
void decode_okvs(uint64_t* keys, uint64_t* storage, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr) {
|
|
202
|
+
__m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);
|
|
203
|
+
__m128i* P_vec = (__m128i*)storage;
|
|
204
|
+
__m128i* out_vec = (__m128i*)output;
|
|
205
|
+
|
|
206
|
+
#pragma omp parallel for schedule(static)
|
|
207
|
+
for(uint64_t i=0; i<n; ++i) {
|
|
208
|
+
Indices idx = hash_key(keys[i], m, seed);
|
|
209
|
+
__m128i val = _mm_xor_si128(
|
|
210
|
+
_mm_xor_si128(_mm_loadu_si128(&P_vec[idx.h1]), _mm_loadu_si128(&P_vec[idx.h2])),
|
|
211
|
+
_mm_loadu_si128(&P_vec[idx.h3])
|
|
212
|
+
);
|
|
213
|
+
_mm_storeu_si128(&out_vec[i], val);
|
|
214
|
+
}
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
// Helper for key expansion
|
|
218
|
+
inline __m128i aes_keygen_assist(__m128i temp1, __m128i temp2) {
|
|
219
|
+
__m128i temp3;
|
|
220
|
+
temp2 = _mm_shuffle_epi32(temp2, 0xff);
|
|
221
|
+
temp3 = _mm_slli_si128(temp1, 0x4);
|
|
222
|
+
temp1 = _mm_xor_si128(temp1, temp3);
|
|
223
|
+
temp3 = _mm_slli_si128(temp3, 0x4);
|
|
224
|
+
temp1 = _mm_xor_si128(temp1, temp3);
|
|
225
|
+
temp3 = _mm_slli_si128(temp3, 0x4);
|
|
226
|
+
temp1 = _mm_xor_si128(temp1, temp3);
|
|
227
|
+
temp1 = _mm_xor_si128(temp1, temp2);
|
|
228
|
+
return temp1;
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
// Expand a 128-bit key into the 11 AES-128 round keys.
// key_schedule must have room for 11 entries; entry 0 is the key itself.
void aes_key_expand(__m128i user_key, __m128i* key_schedule) {
    // _mm_aeskeygenassist_si128 takes its round constant as an immediate, so
    // the ten steps are written out through a local macro instead of a loop.
#define MPLANG_AES128_NEXT_ROUND_KEY(round, rcon)                          \
    key_schedule[(round)] = aes_keygen_assist(                             \
        key_schedule[(round) - 1],                                         \
        _mm_aeskeygenassist_si128(key_schedule[(round) - 1], (rcon)))

    key_schedule[0] = user_key;
    MPLANG_AES128_NEXT_ROUND_KEY(1, 0x01);
    MPLANG_AES128_NEXT_ROUND_KEY(2, 0x02);
    MPLANG_AES128_NEXT_ROUND_KEY(3, 0x04);
    MPLANG_AES128_NEXT_ROUND_KEY(4, 0x08);
    MPLANG_AES128_NEXT_ROUND_KEY(5, 0x10);
    MPLANG_AES128_NEXT_ROUND_KEY(6, 0x20);
    MPLANG_AES128_NEXT_ROUND_KEY(7, 0x40);
    MPLANG_AES128_NEXT_ROUND_KEY(8, 0x80);
    MPLANG_AES128_NEXT_ROUND_KEY(9, 0x1b);
    MPLANG_AES128_NEXT_ROUND_KEY(10, 0x36);

#undef MPLANG_AES128_NEXT_ROUND_KEY
}
|
|
244
|
+
|
|
245
|
+
// AES-128 Expansion
|
|
246
|
+
// Stretch each 128-bit seed into `length` 128-bit blocks with fixed-key AES-128.
//
// seeds  : num_seeds blocks (2 uint64_t words each).
// output : num_seeds * length blocks, laid out seed-major:
//          [seed0 block0, seed0 block1, ..., seed1 block0, ...].
//
// Block j of seed s is AES-128_K(s ^ j), where j is XOR-ed into the low 64 bits
// of the seed and K is the hard-coded key below.
// NOTE(review): K is public and there is no feed-forward, so every output
// block can be decrypted back to (seed ^ j) - confirm that callers do not rely
// on the expansion being one-way.
void aes_128_expand(uint64_t* seeds, uint64_t* output, uint64_t num_seeds, uint64_t length) {
    const __m128i* seed_blocks = reinterpret_cast<const __m128i*>(seeds);
    __m128i* out_blocks = reinterpret_cast<__m128i*>(output);

    // Fixed key taken from the fractional digits of pi
    // (0x243F6A8885A308D3, 0x13198A2E03707344).
    __m128i rk[11];
    aes_key_expand(_mm_set_epi64x(0x243F6A8885A308D3, 0x13198A2E03707344), rk);

    // Seeds are independent; each one fills its own contiguous output range.
    #pragma omp parallel for schedule(static)
    for (uint64_t s = 0; s < num_seeds; ++s) {
        const __m128i seed = _mm_loadu_si128(seed_blocks + s);
        __m128i* dst = out_blocks + s * length;

        for (uint64_t ctr = 0; ctr < length; ++ctr) {
            // Plaintext: seed with the counter mixed into its low 64 bits.
            __m128i st = _mm_xor_si128(seed, _mm_set_epi64x(0, ctr));

            // Standard AES-128: whitening, nine full rounds, one final round.
            st = _mm_xor_si128(st, rk[0]);
            int round = 1;
            while (round < 10) {
                st = _mm_aesenc_si128(st, rk[round]);
                ++round;
            }
            st = _mm_aesenclast_si128(st, rk[10]);

            _mm_storeu_si128(dst + ctr, st);
        }
    }
}
|
|
283
|
+
}
|