mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""LDPC (Low-Density Parity-Check) Code Implementation for Silver VOLE.
|
|
16
|
+
|
|
17
|
+
This module provides LDPC matrix generation and encoding functions used in
|
|
18
|
+
the Silver protocol for efficient silent VOLE generation.
|
|
19
|
+
|
|
20
|
+
Silver uses a specific LDPC structure optimized for:
|
|
21
|
+
1. Fast encoding (quasi-cyclic structure)
|
|
22
|
+
2. Efficient syndrome computation
|
|
23
|
+
3. Low density (few non-zeros per row) for cheap encoding
|
|
24
|
+
|
|
25
|
+
Reference: "Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding"
|
|
26
|
+
CRYPTO 2021
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from typing import Any, cast
|
|
30
|
+
|
|
31
|
+
import jax.numpy as jnp
|
|
32
|
+
import numpy as np
|
|
33
|
+
import scipy.sparse as sp
|
|
34
|
+
|
|
35
|
+
import mplang.v2.edsl as el
|
|
36
|
+
from mplang.v2.dialects import crypto, field, tensor
|
|
37
|
+
|
|
38
|
+
# ============================================================================
|
|
39
|
+
# Constants
|
|
40
|
+
# ============================================================================
|
|
41
|
+
|
|
42
|
+
# Default Silver LDPC parameters.
# Target number of non-zero entries per parity-check row.
SILVER_WEIGHT: int = 5
# Largest random column offset (in either direction) applied around each
# base position when a row's non-zeros are placed.
SILVER_GAP: int = 16
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# ============================================================================
|
|
48
|
+
# LDPC Matrix Generation
|
|
49
|
+
# ============================================================================
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def generate_silver_ldpc(n: int, m: int, seed: int = 42) -> sp.csr_matrix:
    """Generate a Silver-style sparse LDPC parity-check matrix.

    Row ``i`` places up to ``SILVER_WEIGHT`` ones at base positions spread
    evenly around the column circle ``[0, n)``. The bases are cyclically
    shifted by ``i * n // m`` and jittered by a random offset in
    ``[-SILVER_GAP, SILVER_GAP]``. The per-row shift makes successive rows
    sweep the whole column range, which is a quasi-cyclic-like layout.
    Without it, every row would stack its non-zeros onto the same few
    column windows.

    Duplicate positions inside a row are merged, so a row can end up with
    fewer than ``SILVER_WEIGHT`` ones. A row is topped up with random columns
    until it holds at least ``min(3, row_weight)`` entries.

    Args:
        n: Number of columns (message length).
        m: Number of rows (syndrome length, typically n/10 to n/5).
        seed: Seed for ``np.random.RandomState``. The same ``(n, m, seed)``
            always yields the same matrix, so independent parties derive an
            identical ``H``.

    Returns:
        Sparse CSR matrix ``H`` of shape ``(m, n)`` and dtype ``uint8`` whose
        stored entries are all 1.
    """
    rng = np.random.RandomState(seed)
    row_weight = min(SILVER_WEIGHT, n)
    min_weight = min(3, row_weight)

    # Evenly spaced on the circle Z_n (endpoint excluded). After the modular
    # wrap below, the first and last bases are therefore not adjacent.
    base_positions = np.linspace(0, n, row_weight, endpoint=False, dtype=np.int64)

    rows: list[int] = []
    cols: list[int] = []
    for i in range(m):
        # Per-row cyclic shift: across all rows the windows cover [0, n).
        shift = (i * n) // m
        offsets = rng.randint(-SILVER_GAP, SILVER_GAP + 1, size=row_weight)
        positions = np.unique((base_positions + shift + offsets) % n)

        # Top up rows that lost too many entries to de-duplication.
        while len(positions) < min_weight:
            extra = rng.randint(0, n, size=row_weight - len(positions))
            positions = np.unique(np.concatenate([positions, extra]))

        rows.extend([i] * len(positions))
        cols.extend(positions.tolist())

    data = np.ones(len(rows), dtype=np.uint8)
    H = sp.coo_matrix((data, (rows, cols)), shape=(m, n), dtype=np.uint8)
    return H.tocsr()
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def generate_silver_ldpc_systematic(
    n: int, k: int, seed: int = 42
) -> tuple[sp.csr_matrix, None]:
    """Generate the parity-check matrix for an ``[n, k]`` Silver code.

    Only the parity-check matrix ``H`` of shape ``(n - k, n)`` is built,
    via ``generate_silver_ldpc``. Silver relies on syndrome encoding, so no
    generator matrix is computed and ``None`` stands in its place. The
    ``(H, G)`` tuple shape is kept for API compatibility.

    Args:
        n: Codeword length.
        k: Message length, ``0 <= k <= n``.
        seed: Random seed forwarded to ``generate_silver_ldpc``.

    Returns:
        Tuple ``(H, None)``, where ``H`` is a CSR matrix of shape ``(n - k, n)``.

    Raises:
        ValueError: If ``k`` is negative or larger than ``n``.
    """
    if not 0 <= k <= n:
        raise ValueError(f"k must satisfy 0 <= k <= n, got k={k}, n={n}")

    m = n - k  # number of parity checks
    H = generate_silver_ldpc(n, m, seed)

    # Silver only needs H for syndrome computation; G is deliberately skipped.
    return H, None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# ============================================================================
|
|
128
|
+
# LDPC Decoding (For Testing / Verification)
|
|
129
|
+
# ============================================================================
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def ldpc_decode_syndrome(
    syndrome: np.ndarray, H: sp.csr_matrix, noise_weight: int
) -> np.ndarray:
    """Greedily decode a syndrome back to a sparse error vector (testing only).

    Intended for round-trip checks of ``H`` and the encoders, i.e.
    ``encode(error) -> syndrome -> decode(syndrome) == error``, for
    low-weight errors whose column supports do not overlap too much.

    Each step visits every column ``j`` of ``H`` and its support ``S_j``
    (rows where ``H[:, j]`` is odd). It proposes the most frequent non-zero
    syndrome row inside ``S_j`` as the error value. The flip is scored as the
    number of syndrome rows it clears minus the number of zero rows it would
    dirty. The best strictly positive flip is XORed into both the remaining
    syndrome and the error. The syndrome weight therefore drops every step,
    and the loop stops early once no flip helps.

    Args:
        syndrome: Either a bit vector of shape ``(m,)`` (non-zero entries
            count as 1) or an ``(m, L)`` array of XOR-able limbs, e.g.
            ``(m, 2)`` uint64 for 128-bit values.
        H: Parity-check matrix of shape ``(m, n)``. Entries are reduced
            mod 2. Any SciPy sparse format or a dense array is accepted.
        noise_weight: Maximum number of error positions to recover.

    Returns:
        Estimated error vector. Its shape is ``(n,)`` with dtype uint8 for a
        1-D syndrome; otherwise it is ``(n, L)`` with the syndrome's dtype.
    """
    m, n = H.shape
    syndrome = np.asarray(syndrome)
    is_vector = syndrome.ndim == 1

    # Column supports of H over GF(2), computed once. CSC column slices are
    # cheap, unlike repeated getcol() calls on a CSR matrix.
    H_csc = sp.csc_matrix(H, dtype=np.int64, copy=True)
    H_csc.sum_duplicates()
    H_csc.data %= 2
    H_csc.eliminate_zeros()
    supports = [
        H_csc.indices[H_csc.indptr[j] : H_csc.indptr[j + 1]] for j in range(n)
    ]

    # Work on 2-D (rows, limbs) arrays so both layouts share one code path.
    if is_vector:
        remaining = (syndrome != 0).astype(np.uint8).reshape(m, 1)
        error = np.zeros((n, 1), dtype=np.uint8)
    else:
        remaining = syndrome.copy()
        error = np.zeros((n, syndrome.shape[1]), dtype=syndrome.dtype)

    for _ in range(noise_weight):
        best_col, best_gain, best_value = -1, 0, None

        for j, support in enumerate(supports):
            if support.size == 0:
                continue
            block = remaining[support]
            zero_rows = ~block.any(axis=1)
            nonzero = block[~zero_rows]
            if nonzero.shape[0] == 0:
                continue
            values, counts = np.unique(nonzero, axis=0, return_counts=True)
            top = int(np.argmax(counts))
            # Rows equal to the value are cleared and zero rows become dirty;
            # any other non-zero rows stay non-zero.
            gain = int(counts[top]) - int(zero_rows.sum())
            if gain > best_gain:
                best_col, best_gain, best_value = j, gain, values[top]

        if best_col < 0:
            break

        support = supports[best_col]
        remaining[support] ^= best_value
        # XOR rather than assign, so the error stays consistent with the
        # syndrome updates even if a column is selected more than once.
        error[best_col] ^= best_value

    return error.reshape(n) if is_vector else error
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
# ============================================================================
|
|
202
|
+
# Silver-specific Parameters
|
|
203
|
+
# ============================================================================
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def get_silver_params(n: int) -> tuple[int, int, int]:
    """Return recommended Silver parameters for ``n`` VOLE correlations.

    The code length equals ``n``. The syndrome is about a tenth of that
    (roughly 10:1 compression) but never shorter than 128. The noise weight
    is fixed at 64.

    Args:
        n: Desired number of VOLE correlations.

    Returns:
        ``(code_length, syndrome_length, noise_weight)``.
    """
    min_syndrome = 128  # lower bound kept for security
    fixed_noise = 64  # low noise keeps decoding cheap
    return n, max(n // 10, min_syndrome), fixed_noise
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
# ============================================================================
|
|
224
|
+
# Utility Functions
|
|
225
|
+
# ============================================================================
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def matrix_to_sparse_repr(H: sp.csr_matrix) -> tuple[np.ndarray, np.ndarray]:
    """Export ``H`` as CSR ``(indptr, indices)`` arrays for the C++ kernel.

    ``H`` is converted to CSR first. A CSC/COO matrix or a dense array
    therefore yields row pointers and column indices, instead of silently
    exporting column pointers. CSR input is not copied before the dtype cast.

    Args:
        H: Parity-check matrix of shape ``(m, n)``.

    Returns:
        Tuple ``(indptr, indices)`` of ``uint64`` arrays. ``indptr`` has
        length ``m + 1``, and ``indices[indptr[i]:indptr[i + 1]]`` are the
        column positions of row ``i``.
    """
    csr = sp.csr_matrix(H)
    return csr.indptr.astype(np.uint64), csr.indices.astype(np.uint64)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def verify_ldpc_structure(H: sp.csr_matrix) -> bool:
    """Sanity-check that ``H`` looks like a usable LDPC parity-check matrix.

    Checks run in order. On the first failure a warning is printed and
    ``False`` is returned:
      - the matrix is non-empty (``m > 0`` and ``n > 0``);
      - density ``nnz / (m * n)`` is at most 0.1;
      - no row is all-zero;
      - the average row weight lies in ``[2, 20]``.

    Args:
        H: Matrix of shape ``(m, n)``. It is converted to CSR so that row
            weights come from row pointers whatever the input format.

    Returns:
        ``True`` if every check passes, ``False`` otherwise.
    """
    csr = sp.csr_matrix(H)
    m, n = csr.shape

    # Guards the density division below; an empty matrix is not a valid code.
    if m == 0 or n == 0:
        print(f"Warning: LDPC matrix has empty shape {csr.shape}")
        return False

    density = csr.nnz / (m * n)
    if density > 0.1:
        print(f"Warning: LDPC density {density:.3f} is high")
        return False

    row_weights = np.diff(csr.indptr)
    if np.any(row_weights == 0):
        print("Warning: LDPC has zero-weight rows")
        return False

    avg_weight = np.mean(row_weights)
    if avg_weight < 2 or avg_weight > 20:
        print(f"Warning: LDPC average row weight {avg_weight:.1f} unusual")
        return False

    return True
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
# ============================================================================
|
|
268
|
+
# JAX/EDSL Implementations
|
|
269
|
+
# ============================================================================
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def generate_sparse_noise(n: int, weight: int) -> el.Object:
    """Generate a random sparse noise vector with ``weight`` distinct positions.

    Entropy comes from ``crypto.random_bytes`` at run time, not trace time.
    ``8 * weight`` bytes pick the positions, and ``16 * weight`` bytes supply
    a 128-bit value (two uint64 limbs) for each position.

    Positions are drawn without replacement. Step ``i`` maps its random word
    to a rank ``r`` in ``[0, n - i)`` and selects the ``r``-th (0-based)
    index not chosen so far. All ``weight`` positions are therefore distinct.

    Security: suitable for LPN-based protocols such as Silver VOLE.

    Args:
        n: Length of the noise vector.
        weight: Number of positions to fill, ``0 <= weight <= n``.

    Returns:
        An ``(n, 2)`` uint64 tensor that is zero except at ``weight`` distinct
        rows, which hold the random 128-bit values. A value is zero only with
        probability 2^-128.

    Raises:
        ValueError: If ``weight`` is negative or exceeds ``n``.
    """
    import jax

    if not 0 <= weight <= n:
        raise ValueError(
            f"weight must satisfy 0 <= weight <= n, got weight={weight}, n={n}"
        )

    # Phase 1: runtime entropy, 8 bytes per index plus 16 bytes per value.
    entropy = crypto.random_bytes(weight * 8 + weight * 16)

    # Phase 2: deterministic construction from the entropy.
    def _build_noise(ent: Any) -> Any:
        idx_entropy = ent[: weight * 8].view(jnp.uint64)  # (weight,)
        val_entropy = ent[weight * 8 :].view(jnp.uint64).reshape(weight, 2)

        # int64 avoids dtype-mismatch warnings in the scatter below.
        positions = jnp.zeros(weight, dtype=jnp.int64)
        for i in range(weight):  # unrolled: `i` must be static for slicing
            # Rank among the n - i indices that are still free.
            rank = jnp.int64(idx_entropy[i] % (n - i))
            taken = jnp.sort(positions[:i])

            # Walk the taken indices in ascending order. Each one at or below
            # the current candidate pushes it one slot right. The result is
            # the rank-th free index, which never collides with a taken one.
            def _skip(k: Any, pos: Any, taken: Any = taken) -> Any:
                return pos + (taken[k] <= pos).astype(pos.dtype)

            pos = jax.lax.fori_loop(0, i, _skip, rank)
            positions = positions.at[i].set(pos)

        # Sort positions for an ordered scatter.
        positions = jnp.sort(positions)
        noise = jnp.zeros((n, 2), dtype=jnp.uint64)
        return noise.at[positions].set(val_entropy)

    return cast(el.Object, tensor.run_jax(_build_noise, entropy))
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def ldpc_encode_dense_jax(message: el.Object, H_dense: el.Object) -> el.Object:
    """Compute ``H * message`` (LDPC encode) with dense JAX operations.

    Reference implementation for correctness checking; it is much slower
    than the sparse C++ kernel. Row ``r`` of the result is the XOR of every
    message row ``j`` with ``H[r, j] != 0``.

    Args:
        message: (N, 2) uint64 message (two limbs per 128-bit element).
        H_dense: (M, N) parity-check matrix with 0/1 entries (e.g. uint8).

    Returns:
        (M, 2) syndrome with the message's dtype (uint64).
    """
    import jax

    def _encode(msg: Any, h: Any) -> Any:
        # Select *whole* message elements where h is set. A bitwise AND with
        # the 0/1 matrix would keep only the lowest bit of each limb.
        selected = h != 0  # (M, N) bool
        zero = jnp.zeros((), dtype=msg.dtype)

        # Scan over the N columns, XOR-accumulating one (M, limbs) slice per
        # step so the full (M, N, limbs) product is never materialized.
        def body(acc: Any, xs: Any) -> tuple[Any, None]:
            col_mask, msg_row = xs  # (M,) bool, (limbs,) message element
            contrib = jnp.where(col_mask[:, None], msg_row, zero)
            return jnp.bitwise_xor(acc, contrib), None

        init = jnp.zeros((h.shape[0], msg.shape[1]), dtype=msg.dtype)
        syndrome, _ = jax.lax.scan(body, init, (selected.T, msg))
        return syndrome

    return cast(el.Object, tensor.run_jax(_encode, message, H_dense))
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def ldpc_encode_sparse(
    message: el.Object, h_indices: el.Object, h_indptr: el.Object, m: int, n: int
) -> el.Object:
    """Compute the syndrome ``S = H * x`` with the sparse field-dialect kernel.

    This is a thin wrapper around the ``field.ldpc_encode`` primitive. Per
    the original design notes, that primitive is dispatched directly by the
    interpreter to a C++ kernel instead of going through a JAX callback.

    Args:
        message: Message tensor ``x`` (presumably ``(n, 2)`` uint64, as in
            the dense reference encoder; confirm against the primitive).
        h_indices: CSR column indices of ``H`` (presumably as exported by
            ``matrix_to_sparse_repr``).
        h_indptr: CSR row pointers of ``H``.
        m: Number of rows of ``H`` (syndrome length).
        n: Number of columns of ``H`` (message length).

    Returns:
        The syndrome object produced by ``field.ldpc_encode``.
    """
    syndrome = field.ldpc_encode(message, h_indices, h_indptr, m, n)
    return syndrome
|