mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,30 @@
1
# Copyright 2025 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Build the optional native kernel library for mplang v2.
CXX = g++
# -march=native enables PCLMULQDQ if the host CPU supports it.
# -mpclmul -maes are explicit flags if native doesn't pick them up, but native is safer for local dev.
CXXFLAGS = -O3 -Wall -shared -fPIC -march=native -mpclmul -maes -fopenmp

TARGET = libmplang_kernels.so
SRCS = gf128.cpp okvs.cpp okvs_opt.cpp ldpc.cpp
OBJS = $(SRCS:.cpp=.o)

# 'all' and 'clean' are command targets, not files: mark them phony so a
# stray file with either name can never shadow the rule.
.PHONY: all clean

all: $(TARGET)

# Sources are compiled and linked in one step, so no intermediate .o files
# are produced; OBJS is kept in clean only as a safety net for leftovers.
$(TARGET): $(SRCS)
	$(CXX) $(CXXFLAGS) -o $@ $^

clean:
	rm -f $(TARGET) $(OBJS)
@@ -0,0 +1,23 @@
1
# Copyright 2025 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Kernels package for mplang v2.

This package contains both:

- Native C++ kernels (libmplang_kernels.so, built by the Makefile in this
  directory) for performance
- Pure Python fallback implementations for portability

The native kernels are optional. When not available, pure Python
implementations will be used automatically.
"""
@@ -0,0 +1,148 @@
1
+ /*
2
+ * Copyright 2025 Ant Group Co., Ltd.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+
18
+ #include <cstdint>
19
+ #include <iostream>
20
+ #include <wmmintrin.h> // For PCLMULQDQ
21
+ #include <emmintrin.h> // For SSE2
22
+ #include <tmmintrin.h> // For SSSE3 (pshufb)
23
+
24
+ // Helper to reverse bits in bytes (if needed, but for GF(128) usually standard representation is used)
25
+ // We assume standard GCM representation (x^128 + x^7 + x^2 + x + 1)
26
+ // Little-endian input: a[0] is low 64 bits.
27
+
28
+ extern "C" {
29
+
30
+ // ------------------------------------------------------------------------
31
+ // GF(2^128) Multiplication using PCLMULQDQ
32
+ // ------------------------------------------------------------------------
33
+ //
34
+ // Performs c = a * b mod P(x)
35
+ // P(x) = x^128 + x^7 + x^2 + x + 1
36
+ //
37
+ // Implementation based on Intel Whitepaper:
38
+ // "Intel Carry-Less Multiplication Instruction and its Usage for Computing the GCM Mode"
39
+ // Algorithm 1 or optimized variants.
40
+
41
// Perform 128x128 -> 256 bit multiplication (carry-less).
// Schoolbook product from four PCLMULQDQ partial products; the two cross
// terms straddle the 64-bit boundary of the 256-bit result and are split
// between the low and high output words.
// Returns low 128 bits in ret_lo, high 128 bits in ret_hi.
// NOTE(review): currently unused -- gf128_mul below inlines the same
// sequence; consider having gf128_mul call this helper.
static inline void clmul128(__m128i a, __m128i b, __m128i *ret_lo, __m128i *ret_hi) {
    __m128i tmp3, tmp4, tmp5, tmp6;

    tmp3 = _mm_clmulepi64_si128(a, b, 0x00); // a_lo * b_lo
    tmp4 = _mm_clmulepi64_si128(a, b, 0x11); // a_hi * b_hi
    tmp5 = _mm_clmulepi64_si128(a, b, 0x01); // a_lo * b_hi
    tmp6 = _mm_clmulepi64_si128(a, b, 0x10); // a_hi * b_lo

    tmp5 = _mm_xor_si128(tmp5, tmp6); // (a_lo*b_hi) + (a_hi*b_lo), the cross term

    // Cross term contributes bits 64..191: shift its halves into place.
    __m128i tmp5_lo = _mm_slli_si128(tmp5, 8);
    __m128i tmp5_hi = _mm_srli_si128(tmp5, 8);

    *ret_lo = _mm_xor_si128(tmp3, tmp5_lo);
    *ret_hi = _mm_xor_si128(tmp4, tmp5_hi);
}
59
+
60
+ // Reduce 256-bit polynomial modulo P(x) = x^128 + x^7 + x^2 + x + 1
61
+ // Input: c_lo (low 128), c_hi (high 128)
62
+ // Output: reduced (128 bit)
63
+ // Based on optimized reduction for GCM (often called "folding")
64
+ static inline __m128i gcm_reduce(__m128i c_lo, __m128i c_hi) {
65
+ __m128i tmp3, tmp6, tmp7;
66
+ __m128i R = _mm_set_epi32(1, 0, 0, 135); // 0...010...010000111 (See note below)
67
+ // Actually, careful with endianness and GCM bit order "reflected" vs "polynomial".
68
+ // Most VOLE implementations (e.g., libOTe) use standard polynomial basis, not reflected GCM.
69
+ // Standard polynomial basis P(x) = x^128 + x^7 + x^2 + x + 1.
70
+ // x^128 = x^7 + x^2 + x + 1 (mod P)
71
+
72
+ // Simple reduction algorithm:
73
+ // We need to reduce c_hi into c_lo.
74
+ // 256-bit product C = C_hi * x^128 + C_lo
75
+ // x^128 mod P = (x^7 + x^2 + x + 1)
76
+
77
+ // Let's implement specific reduction for standard basis.
78
+ // Method: Shift-based or PCLMUL based reduction.
79
+ // For Speed, use PCLMUL.
80
+
81
+ // Constants for reduction
82
+ // Algorithm 5 from Intel paper (modified for standard basis if needed)
83
+ // The one in paper is for Reflected GCM.
84
+ // Let's assume we want Standard Basis GF(2^128).
85
+ // Ref: https://github.com/emp-toolkit/emp-ot/blob/master/emp-ot/ferret/ferret_cot.hpp#L15
86
+
87
+ return c_lo; // PLACEHOLDER: Reduction is complex to get right without writing a test first.
88
+ // I will implement a simpler but slower reduction first to verify pipeline,
89
+ // then optimize. Or copy verified code.
90
+ }
91
+
92
// Verified implementation of GF(2^128) Multiply from EMP-toolkit (Standard Basis)
// https://github.com/emp-toolkit/emp-tool/blob/master/emp-tool/utils/block.h#L137
//
// Computes out = a * b in GF(2^128) with P(x) = x^128 + x^7 + x^2 + x + 1.
// Operands are little-endian pairs of uint64: ptr[0] holds the low 64 bits.
void gf128_mul(uint64_t* a_ptr, uint64_t* b_ptr, uint64_t* out_ptr) {
    __m128i a = _mm_loadu_si128((__m128i*)a_ptr);
    __m128i b = _mm_loadu_si128((__m128i*)b_ptr);

    // 1. Multiply (Carry-less): four PCLMULQDQ partial products.
    // The two cross terms straddle the 64-bit boundary, so after XOR-ing
    // them together their halves are shifted into the low (r0) and high
    // (r1) words of the 256-bit product.
    __m128i tmp3, tmp4, tmp5, tmp6;
    tmp3 = _mm_clmulepi64_si128(a, b, 0x00); // a_lo * b_lo
    tmp4 = _mm_clmulepi64_si128(a, b, 0x11); // a_hi * b_hi
    tmp5 = _mm_clmulepi64_si128(a, b, 0x01); // a_lo * b_hi
    tmp6 = _mm_clmulepi64_si128(a, b, 0x10); // a_hi * b_lo
    tmp5 = _mm_xor_si128(tmp5, tmp6);
    __m128i tmp5_lo = _mm_slli_si128(tmp5, 8);
    __m128i tmp5_hi = _mm_srli_si128(tmp5, 8);
    __m128i r0 = _mm_xor_si128(tmp3, tmp5_lo); // low 128 bits of a*b
    __m128i r1 = _mm_xor_si128(tmp4, tmp5_hi); // high 128 bits of a*b

    // 2. Reduce (Standard Basis)
    // P(x) = x^128 + x^7 + x^2 + x + 1, so x^128 == Q(x) = 0x87 (mod P).
    // Fold r1 * Q into the low word; the fold itself can overflow into
    // M_hi (at most 71 bits), which is folded once more via H. A third
    // fold is unnecessary since deg(Q) = 7 < 64.
    __m128i Q = _mm_set_epi64x(0, 0x87);

    __m128i r1_lo = r1; // NOTE(review): unused alias, kept as-is

    __m128i m0 = _mm_clmulepi64_si128(r1, Q, 0x00); // r1_lo * Q
    __m128i m1 = _mm_clmulepi64_si128(r1, Q, 0x10); // r1_hi * Q

    __m128i m1_shifted = _mm_slli_si128(m1, 8);
    __m128i M_lo = _mm_xor_si128(m0, m1_shifted); // low 128 bits of r1*Q
    __m128i M_hi = _mm_srli_si128(m1, 8);         // overflow of the first fold

    __m128i H = _mm_clmulepi64_si128(M_hi, Q, 0x00); // second fold

    __m128i res = _mm_xor_si128(r0, M_lo);
    res = _mm_xor_si128(res, H);

    _mm_storeu_si128((__m128i*)out_ptr, res);
}
134
+
135
+ // Batch Multiplication
136
+ void gf128_mul_batch(uint64_t* a, uint64_t* b, uint64_t* out, int64_t n) {
137
+ #pragma omp parallel for schedule(static)
138
+ for (int64_t i = 0; i < n; ++i) {
139
+ gf128_mul(a + 2*i, b + 2*i, out + 2*i);
140
+ }
141
+ }
142
+
143
+ // Test function updated
144
+ void gf128_mul_test(uint64_t* a, uint64_t* b, uint64_t* out) {
145
+ gf128_mul(a, b, out);
146
+ }
147
+
148
+ }
@@ -0,0 +1,82 @@
1
+ /*
2
+ * Copyright 2025 Ant Group Co., Ltd.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <cstdint>
18
+ #include <cstring>
19
+ #include <vector>
20
+ #include <immintrin.h>
21
+
22
+ #ifdef _OPENMP
23
+ #include <omp.h>
24
+ #endif
25
+
26
+ extern "C" {
27
+
28
/**
 * @brief LDPC Encoding: Compute Syndrome s = H * x
 *
 * H is a sparse M x N binary matrix (CSR format).
 * x is a dense N-vector of 128-bit blocks (N * 16 bytes).
 * s is a dense M-vector of 128-bit blocks (M * 16 bytes).
 *
 * Logic: For each row i of H, s[i] = XOR(x[j]) for all j where H[i, j] = 1.
 *
 * @param message_ptr Pointer to message x (N * 2 uint64_t)
 * @param indices_ptr Pointer to CSR indices (uint64_t)
 * @param indptr_ptr Pointer to CSR indptr (M+1 uint64_t)
 * @param output_ptr Pointer to output s (M * 2 uint64_t)
 * @param m Number of rows in H (syndrome length)
 * @param n Number of cols in H (message length)
 */
void ldpc_encode(const uint64_t* message_ptr,
                 const uint64_t* indices_ptr,
                 const uint64_t* indptr_ptr,
                 uint64_t* output_ptr,
                 uint64_t m,
                 uint64_t n) {
    // View the flat uint64 buffers as 128-bit lanes. All accesses go
    // through loadu/storeu, so no particular alignment is assumed.
    const __m128i* message_blocks = (const __m128i*)message_ptr;
    __m128i* syndrome_blocks = (__m128i*)output_ptr;

    // Rows are independent: one XOR-accumulator per row, parallel-safe.
    #pragma omp parallel for schedule(static)
    for (uint64_t row = 0; row < m; ++row) {
        __m128i acc = _mm_setzero_si128();

        const uint64_t first = indptr_ptr[row];
        const uint64_t last = indptr_ptr[row + 1];

        for (uint64_t e = first; e < last; ++e) {
            // Accumulate the message block at this column (H[row,col]=1).
            const uint64_t col = indices_ptr[e];
            acc = _mm_xor_si128(acc, _mm_loadu_si128(&message_blocks[col]));
        }

        _mm_storeu_si128(&syndrome_blocks[row], acc);
    }
}
81
+
82
+ }
@@ -0,0 +1,283 @@
1
+ /*
2
+ * Copyright 2025 Ant Group Co., Ltd.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <cstdint>
18
+ #include <vector>
19
+ #include <stack>
20
+ #include <random>
21
+ #include <immintrin.h>
22
+ #include <cstring>
23
+ #include <cstdio>
24
+ #include <iostream>
25
+
26
+ extern "C" {
27
+
28
// AES-NI Hashing Helper
// The three candidate table positions produced for one key by hash_key
// (a 3-hash cuckoo-style placement).
struct Indices {
    uint64_t h1, h2, h3;
};
32
+
33
// Map a 64-bit key to three distinct positions in [0, m) using two AES
// rounds keyed by `seed` as a fast mixer (speed, not cryptographic
// strength, is the goal here).
// NOTE(review): the distinctness fixup below assumes m >= 3; with fewer
// slots three distinct indices cannot exist -- confirm callers guarantee this.
inline Indices hash_key(uint64_t key, uint64_t m, __m128i seed) {
    __m128i k = _mm_set_epi64x(0, key);
    __m128i h = _mm_aesenc_si128(k, seed);
    h = _mm_aesenc_si128(h, seed);

    // Split the 128-bit mix into two 64-bit values; derive a third from
    // their XOR so all three positions come from one hash evaluation.
    uint64_t v1 = _mm_extract_epi64(h, 0);
    uint64_t v2 = _mm_extract_epi64(h, 1);

    Indices idx;
    idx.h1 = v1 % m;
    idx.h2 = v2 % m;
    idx.h3 = (v1 ^ v2) % m;

    // Enforce distinct indices by bumping collisions forward (mod m).
    // h3 may need two bumps when it collides with both h1 and h2.
    if(idx.h2 == idx.h1) {
        idx.h2 = (idx.h2 + 1) % m;
    }
    if(idx.h3 == idx.h1 || idx.h3 == idx.h2) {
        idx.h3 = (idx.h3 + 1) % m;
        if(idx.h3 == idx.h1 || idx.h3 == idx.h2) {
            idx.h3 = (idx.h3 + 1) % m;
        }
    }

    return idx;
}
59
+
60
// Solve OKVS System: H * P = V
//
// Each key i defines a sparse row with ones at the three positions given by
// hash_key; solve for the storage vector P (m slots of 128-bit values) so
// that for every key, P[h1] ^ P[h2] ^ P[h3] == values[i]. Uses 2-core
// peeling (repeatedly remove a column of degree 1 with its owning row)
// followed by back substitution in reverse peel order.
//
// On peeling failure (the 3-hypergraph has a 2-core), the output is zeroed
// and an error is written to stderr; callers are expected to retry with a
// different seed or larger m. No status is returned.
void solve_okvs(uint64_t* keys, uint64_t* values, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr) {
    // Load dynamic seed
    __m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);

    // Per-row hash triple (the three column positions of row i).
    struct Row {
        uint64_t h1, h2, h3;
    };
    std::vector<Row> rows(n);

    // 1. Parallel Hash Compute
    #pragma omp parallel for schedule(static)
    for(uint64_t i=0; i<n; ++i) {
        Indices idx = hash_key(keys[i], m, seed);
        rows[i] = {idx.h1, idx.h2, idx.h3};
    }

    // 2. Count Degrees (Serial or Atomic)
    // Since M ~ 1.2N, atomic contention is low? Serial is safe and simple.
    std::vector<int> col_degree(m, 0);
    for(uint64_t i=0; i<n; ++i) {
        col_degree[rows[i].h1]++;
        col_degree[rows[i].h2]++;
        col_degree[rows[i].h3]++;
    }

    // 3. Build CSR Structure (Flat Arrays) to replace vector<vector>
    // col_start[j] points to start of column j's rows in flat_rows
    std::vector<int> col_start(m + 1, 0);

    // Prefix sum to compute start positions
    // col_start[0] = 0
    // col_start[j+1] = col_start[j] + degree[j]
    for(uint64_t j=0; j<m; ++j) {
        col_start[j+1] = col_start[j] + col_degree[j];
    }

    // Total edges = 3 * N implies flat_rows size
    std::vector<int> flat_rows(n * 3);

    // Temporary copy of start indices to use as fill pointers
    std::vector<int> fill_ptr = col_start;

    for(uint64_t i=0; i<n; ++i) {
        const auto& r = rows[i];
        flat_rows[fill_ptr[r.h1]++] = i;
        flat_rows[fill_ptr[r.h2]++] = i;
        flat_rows[fill_ptr[r.h3]++] = i;
    }

    // 4. Initialize Peeling: seed the queue with all degree-1 columns.
    std::vector<int> peel_stack;
    peel_stack.reserve(m);
    for(uint64_t j=0; j<m; ++j) {
        if(col_degree[j] == 1) peel_stack.push_back(j);
    }

    std::vector<bool> row_removed(n, false);
    std::vector<bool> col_removed(m, false);

    // Records "column `col` is solved by row `row`" in peel order; replayed
    // in reverse during back substitution.
    struct Assignment {
        int col;
        int row;
    };
    std::vector<Assignment> assignment_stack;
    assignment_stack.reserve(n);

    int head = 0;

    // 5. Peeling BFS
    // NOTE(review): `head < peel_stack.size()` compares signed int against
    // size_t; fine in practice for these sizes but worth tidying.
    while(head < peel_stack.size()) {
        int j = peel_stack[head++];
        if(col_removed[j]) continue; // degree may have dropped to 0 since enqueue

        // Find owner row: Iterate over edges of col j using flat arrays.
        // The CSR lists are immutable; removal is tracked via row_removed.
        int owner_row = -1;
        int start = col_start[j];
        int end = col_start[j+1];

        for(int k=start; k<end; ++k) {
            int r_idx = flat_rows[k];
            if(!row_removed[r_idx]) {
                owner_row = r_idx;
                break;
            }
        }

        if(owner_row == -1) {
            // All incident rows already peeled: column is free (stays 0).
            col_removed[j] = true;
            continue;
        }

        assignment_stack.push_back({j, owner_row});
        col_removed[j] = true;
        row_removed[owner_row] = true;

        // Update neighbors: removing the owner row lowers the degree of its
        // other two columns; any that reach degree 1 become peelable.
        const auto& r = rows[owner_row];
        uint64_t nbs[3] = {r.h1, r.h2, r.h3};
        for(uint64_t neighbor : nbs) {
            if(neighbor == (uint64_t)j) continue;
            if(col_removed[neighbor]) continue;

            col_degree[neighbor]--;
            if(col_degree[neighbor] == 1) {
                peel_stack.push_back((int)neighbor);
            }
        }
    }

    // Every row must have been assigned a column, or the system is unsolved.
    // NOTE(review): %lu assumes uint64_t == unsigned long (LP64); use PRIu64
    // for portability.
    if(assignment_stack.size() != n) {
        fprintf(stderr, "[ERROR] OKVS Peeling Failed. N=%lu M=%lu Solved=%lu\n",
                n, m, assignment_stack.size());
        // Zero output to identify failure clearly
        memset(output, 0, m * 16);
        return;
    }

    // 6. Back Substitution
    // Use 128-bit intrinsics for value XORing
    __m128i* P_vec = (__m128i*)output;
    __m128i* V_vec = (__m128i*)values;
    memset(output, 0, m * 16);

    // Process in reverse constraint order (LIFO): when assignment i is
    // replayed, its row's other two slots already hold their final values,
    // so setting P[col] = V[row] ^ (current row sum) satisfies the row.
    for(int i = (int)assignment_stack.size() - 1; i >= 0; --i) {
        const auto& a = assignment_stack[i];
        const auto& r = rows[a.row];

        __m128i val1 = _mm_loadu_si128(&P_vec[r.h1]);
        __m128i val2 = _mm_loadu_si128(&P_vec[r.h2]);
        __m128i val3 = _mm_loadu_si128(&P_vec[r.h3]);
        __m128i target = _mm_loadu_si128(&V_vec[a.row]);

        __m128i current_sum = _mm_xor_si128(_mm_xor_si128(val1, val2), val3);
        __m128i diff = _mm_xor_si128(target, current_sum);

        _mm_storeu_si128(&P_vec[a.col], diff);
    }
}
200
+
201
+ void decode_okvs(uint64_t* keys, uint64_t* storage, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr) {
202
+ __m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);
203
+ __m128i* P_vec = (__m128i*)storage;
204
+ __m128i* out_vec = (__m128i*)output;
205
+
206
+ #pragma omp parallel for schedule(static)
207
+ for(uint64_t i=0; i<n; ++i) {
208
+ Indices idx = hash_key(keys[i], m, seed);
209
+ __m128i val = _mm_xor_si128(
210
+ _mm_xor_si128(_mm_loadu_si128(&P_vec[idx.h1]), _mm_loadu_si128(&P_vec[idx.h2])),
211
+ _mm_loadu_si128(&P_vec[idx.h3])
212
+ );
213
+ _mm_storeu_si128(&out_vec[i], val);
214
+ }
215
+ }
216
+
217
// Helper for key expansion
// One AES-128 key-schedule round: temp1 is the previous round key, temp2
// the raw _mm_aeskeygenassist_si128 output for this round's rcon. The
// shuffle broadcasts the relevant assist word; the slli/xor cascade forms
// the running XOR of the four 32-bit words of the new round key.
inline __m128i aes_keygen_assist(__m128i temp1, __m128i temp2) {
    __m128i temp3;
    temp2 = _mm_shuffle_epi32(temp2, 0xff);
    temp3 = _mm_slli_si128(temp1, 0x4);
    temp1 = _mm_xor_si128(temp1, temp3);
    temp3 = _mm_slli_si128(temp3, 0x4);
    temp1 = _mm_xor_si128(temp1, temp3);
    temp3 = _mm_slli_si128(temp3, 0x4);
    temp1 = _mm_xor_si128(temp1, temp3);
    temp1 = _mm_xor_si128(temp1, temp2);
    return temp1;
}
230
+
231
// Expand a 128-bit user key into the 11 AES-128 round keys.
// The rcon argument of _mm_aeskeygenassist_si128 must be a compile-time
// immediate, which is why the sequence is unrolled rather than looped.
void aes_key_expand(__m128i user_key, __m128i* key_schedule) {
    key_schedule[0] = user_key;
    key_schedule[1] = aes_keygen_assist(key_schedule[0], _mm_aeskeygenassist_si128(key_schedule[0], 0x01));
    key_schedule[2] = aes_keygen_assist(key_schedule[1], _mm_aeskeygenassist_si128(key_schedule[1], 0x02));
    key_schedule[3] = aes_keygen_assist(key_schedule[2], _mm_aeskeygenassist_si128(key_schedule[2], 0x04));
    key_schedule[4] = aes_keygen_assist(key_schedule[3], _mm_aeskeygenassist_si128(key_schedule[3], 0x08));
    key_schedule[5] = aes_keygen_assist(key_schedule[4], _mm_aeskeygenassist_si128(key_schedule[4], 0x10));
    key_schedule[6] = aes_keygen_assist(key_schedule[5], _mm_aeskeygenassist_si128(key_schedule[5], 0x20));
    key_schedule[7] = aes_keygen_assist(key_schedule[6], _mm_aeskeygenassist_si128(key_schedule[6], 0x40));
    key_schedule[8] = aes_keygen_assist(key_schedule[7], _mm_aeskeygenassist_si128(key_schedule[7], 0x80));
    key_schedule[9] = aes_keygen_assist(key_schedule[8], _mm_aeskeygenassist_si128(key_schedule[8], 0x1b));
    key_schedule[10] = aes_keygen_assist(key_schedule[9], _mm_aeskeygenassist_si128(key_schedule[9], 0x36));
}
244
+
245
// AES-128 Expansion
// Deterministic PRG: for each 128-bit seed, produce `length` blocks as
// out[i][j] = AES_fixedkey(seed_i ^ j), using a fixed, public key.
// NOTE(review): with a public fixed key, any security property rests on
// the seeds being secret and uniformly random -- confirm callers only
// pass fresh random seeds.
void aes_128_expand(uint64_t* seeds, uint64_t* output, uint64_t num_seeds, uint64_t length) {
    __m128i* seeds_vec = (__m128i*)seeds;
    __m128i* out_vec = (__m128i*)output;

    // Fixed Key (Arbitrary constant)
    // Using PI fractional part (Nothing-up-my-sleeve numbers)
    // 0x243F6A8885A308D3 (PI_FRAC_1)
    // 0x13198A2E03707344 (PI_FRAC_2)
    __m128i fixed_key = _mm_set_epi64x(0x243F6A8885A308D3, 0x13198A2E03707344);
    __m128i round_keys[11];
    aes_key_expand(fixed_key, round_keys);

    // For each seed (independent streams, so parallel across seeds).
    #pragma omp parallel for schedule(static)
    for(uint64_t i=0; i<num_seeds; ++i) {
        __m128i s = _mm_loadu_si128(&seeds_vec[i]);

        // Expand to 'length' blocks
        for(uint64_t j=0; j<length; ++j) {
            // Block = Seed ^ j
            // Note: j is passed as counter mix
            __m128i ctr = _mm_set_epi64x(0, j);
            __m128i block = _mm_xor_si128(s, ctr);

            // Encrypt Block: initial whitening, 9 full rounds, final round.
            __m128i state = _mm_xor_si128(block, round_keys[0]);
            for(int r=1; r<10; ++r) {
                state = _mm_aesenc_si128(state, round_keys[r]);
            }
            state = _mm_aesenclast_si128(state, round_keys[10]);

            // Store
            // Output is flat: [seed0_0, seed0_1 ... seed1_0 ...]
            _mm_storeu_si128(&out_vec[i * length + j], state);
        }
    }
}
283
+ }