mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,383 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """LDPC (Low-Density Parity-Check) Code Implementation for Silver VOLE.
16
+
17
+ This module provides LDPC matrix generation and encoding functions used in
18
+ the Silver protocol for efficient silent VOLE generation.
19
+
20
+ Silver uses a specific LDPC structure optimized for:
21
+ 1. Fast encoding (quasi-cyclic structure)
22
+ 2. Efficient syndrome computation
23
+ 3. Low-density for minimal communication
24
+
25
+ Reference: "Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding"
26
+ CRYPTO 2021
27
+ """
28
+
29
+ from typing import Any, cast
30
+
31
+ import jax.numpy as jnp
32
+ import numpy as np
33
+ import scipy.sparse as sp
34
+
35
+ import mplang.v2.edsl as el
36
+ from mplang.v2.dialects import crypto, field, tensor
37
+
38
+ # ============================================================================
39
+ # Constants
40
+ # ============================================================================
41
+
42
# Silver default parameters (values taken from the paper).
SILVER_WEIGHT = 5  # Target number of ones per parity-check row (row weight)
SILVER_GAP = 16  # Max random jitter (+/-) applied to each row's base column positions
45
+
46
+
47
+ # ============================================================================
48
+ # LDPC Matrix Generation
49
+ # ============================================================================
50
+
51
+
52
def generate_silver_ldpc(n: int, m: int, seed: int = 42) -> sp.csr_matrix:
    """Generate a Silver-style sparse LDPC parity-check matrix.

    Each of the ``m`` rows starts from ``min(SILVER_WEIGHT, n)`` evenly spaced
    column positions. Row ``i`` cyclically shifts them by ``i * n // m`` and
    then jitters each one by a random offset in ``[-SILVER_GAP, SILVER_GAP]``
    (all modulo ``n``). Duplicate columns produced by the jitter are dropped,
    so a row may end up with fewer than ``SILVER_WEIGHT`` ones. Rows left with
    fewer than ``min(3, row_weight)`` ones are topped up with uniformly random
    columns.

    Args:
        n: Number of columns (message length).
        m: Number of rows (syndrome length, typically n/10 to n/5).
        seed: Seed for ``np.random.RandomState``. The same ``(n, m, seed)``
            always yields the same matrix.

    Returns:
        Sparse CSR matrix ``H`` of shape ``(m, n)`` and dtype ``uint8`` whose
        stored entries are all 1.
    """
    rng = np.random.RandomState(seed)

    # Use a regular LDPC structure with a fixed (target) row weight.
    row_weight = min(SILVER_WEIGHT, n)

    # Evenly spaced column template. It is identical for every row, so it is
    # built once instead of once per row.
    template = np.linspace(0, n - 1, row_weight, dtype=int)

    rows: list[int] = []
    cols: list[int] = []
    for i in range(m):
        # Row-dependent cyclic shift. Without it, every row would draw from the
        # same ~row_weight * (2 * SILVER_GAP + 1) columns regardless of n,
        # leaving almost every column of H identically zero.
        shift = (i * n) // m
        offsets = rng.randint(-SILVER_GAP, SILVER_GAP + 1, size=row_weight)
        positions = np.unique((template + shift + offsets) % n)

        # Guarantee a minimum number of distinct entries per row.
        while len(positions) < min(3, row_weight):
            extra = rng.randint(0, n, size=row_weight - len(positions))
            positions = np.unique(np.concatenate([positions, extra]))

        rows.extend([i] * len(positions))
        cols.extend(positions.tolist())

    data = np.ones(len(rows), dtype=np.uint8)
    H = sp.coo_matrix((data, (rows, cols)), shape=(m, n), dtype=np.uint8)
    return H.tocsr()
99
+
100
+
101
def generate_silver_ldpc_systematic(
    n: int, k: int, seed: int = 42
) -> "tuple[sp.csr_matrix, sp.csr_matrix | None]":
    """Generate the parity-check matrix for an [n, k] Silver code.

    Only ``H`` is built. Silver uses syndrome encoding and does not need the
    generator matrix, so ``None`` is returned in its place to save the cost of
    computing it.

    Args:
        n: Codeword length.
        k: Message length. Must not exceed ``n``; ``H`` gets ``n - k`` rows.
        seed: Random seed forwarded to ``generate_silver_ldpc``.

    Returns:
        Tuple ``(H, None)``, where ``H`` is an ``(n - k, n)`` CSR matrix.

    Raises:
        ValueError: If ``k > n`` (the parity part would have negative size).
    """
    m = n - k  # Number of parity bits
    if m < 0:
        raise ValueError(f"message length k={k} exceeds codeword length n={n}")
    H = generate_silver_ldpc(n, m, seed)

    # G is not needed for Silver's syndrome encoding; skip computing it.
    return H, None
125
+
126
+
127
+ # ============================================================================
128
+ # LDPC Decoding (For Testing / Verification)
129
+ # ============================================================================
130
+
131
+
132
def ldpc_decode_syndrome(
    syndrome: np.ndarray, H: sp.csr_matrix, noise_weight: int
) -> np.ndarray:
    """Greedily decode a syndrome back to a sparse error vector (testing only).

    Intended for round-trip checks of ``H`` and the encoders:
    ``encode(error) -> syndrome -> decode(syndrome) == error``.

    At most ``noise_weight`` greedy steps are taken. Each step:

    1. Picks the column of ``H`` whose support overlaps the most non-zero
       entries of the remaining syndrome (ties go to the lowest column index).
    2. Toggles that error position.
    3. XORs the column back into the remaining syndrome.

    Decoding stops early once no column overlaps a non-zero syndrome entry.
    This is a simple placeholder, not a belief-propagation decoder, and it can
    miss the true error.

    Args:
        syndrome: Syndrome of shape ``(m,)`` with 0/1 entries, or ``(m, w)``
            with an integer dtype.
        H: Parity-check matrix of shape ``(m, n)``. Only its non-zero pattern
            is used.
        noise_weight: Maximum number of greedy steps (expected error weight).

    Returns:
        For a 1-D syndrome, a ``uint8`` vector of shape ``(n,)``. For a 2-D
        syndrome, an array of shape ``(n, w)`` with ``syndrome.dtype`` whose
        selected rows hold the value 1 in every lane. Only the error *support*
        is recovered, not the actual error values.
    """
    m, n = H.shape
    multi_lane = syndrome.ndim != 1

    if multi_lane:
        error = np.zeros((n, syndrome.shape[1]), dtype=syndrome.dtype)
        flip = np.ones(syndrome.shape[1], dtype=syndrome.dtype)
    else:
        error = np.zeros(n, dtype=np.uint8)
        flip = np.uint8(1)

    if n == 0:
        return error

    # Binary incidence matrix in CSC form. Column j's row indices form a
    # contiguous slice, and B.T @ v scores every column in one sparse matvec,
    # instead of densifying each column separately on every step.
    B = H.tocsc(copy=True).astype(np.int64)
    B.sum_duplicates()
    B.eliminate_zeros()
    B.data[:] = 1
    B_T = B.T.tocsr()

    remaining = syndrome.copy()
    for _ in range(noise_weight):
        nonzero = remaining != 0
        # Per-row count of non-zero syndrome entries (1-D: 0/1; 2-D: per lane).
        active = (nonzero.sum(axis=1) if multi_lane else nonzero).astype(np.int64)
        scores = B_T @ active
        best_col = int(np.argmax(scores))  # first maximum -> lowest index
        if scores[best_col] <= 0:
            break

        support = B.indices[B.indptr[best_col] : B.indptr[best_col + 1]]

        # Toggle (rather than set) the error position, so that
        # H·error XOR remaining == syndrome keeps holding even if the same
        # column is selected more than once.
        error[best_col] ^= flip
        if multi_lane:
            remaining[support] ^= flip
        else:
            col = np.zeros(m, dtype=np.uint8)
            col[support] = 1
            remaining = (remaining + col) % 2

    return error
199
+
200
+
201
+ # ============================================================================
202
+ # Silver-specific Parameters
203
+ # ============================================================================
204
+
205
+
206
def get_silver_params(n: int) -> tuple[int, int, int]:
    """Return the recommended Silver parameters for ``n`` VOLE correlations.

    The code length equals ``n``. The syndrome is about one tenth of it (a
    10:1 compression) but never shorter than 128. The noise weight is fixed
    at 64.

    Args:
        n: Desired number of VOLE correlations.

    Returns:
        ``(code_length, syndrome_length, noise_weight)``.
    """
    min_syndrome = 128  # lower bound kept for security
    noise = 64  # low noise keeps decoding efficient
    return n, max(n // 10, min_syndrome), noise
221
+
222
+
223
+ # ============================================================================
224
+ # Utility Functions
225
+ # ============================================================================
226
+
227
+
228
def matrix_to_sparse_repr(H: sp.csr_matrix) -> tuple[np.ndarray, np.ndarray]:
    """Export the CSR sparsity pattern of ``H`` as ``uint64`` arrays for C++.

    Only the structure is exported; the stored values of ``H`` are ignored.
    NOTE: the returned order is ``(indptr, indices)``, whereas
    ``ldpc_encode_sparse`` takes ``h_indices`` before ``h_indptr``.

    Returns:
        ``(indptr, indices)`` as newly allocated ``np.uint64`` arrays.
    """
    row_ptr = H.indptr.astype(np.uint64)
    col_idx = H.indices.astype(np.uint64)
    return row_ptr, col_idx
235
+
236
+
237
def verify_ldpc_structure(H: sp.csr_matrix) -> bool:
    """Sanity-check the structure of an LDPC parity-check matrix.

    Checks, in order:
      - the matrix has at least one row and one column;
      - sparsity: density ``nnz / (m * n)`` is at most 0.1;
      - no row is all zeros;
      - the average row weight lies in ``[2, 20]``.

    A warning is printed for the first failing check.

    Args:
        H: Parity-check matrix in CSR form. Row weights are read from
            ``H.indptr``.

    Returns:
        True if every check passes, otherwise False.
    """
    m, n = H.shape

    # Guard against ZeroDivisionError in the density computation below.
    if m == 0 or n == 0:
        print(f"Warning: LDPC matrix is empty (shape {m}x{n})")
        return False

    # Check sparsity
    density = H.nnz / (m * n)
    if density > 0.1:
        print(f"Warning: LDPC density {density:.3f} is high")
        return False

    # Check row weights (stored entries per CSR row)
    row_weights = np.diff(H.indptr)
    if np.any(row_weights == 0):
        print("Warning: LDPC has zero-weight rows")
        return False

    avg_weight = np.mean(row_weights)
    if avg_weight < 2 or avg_weight > 20:
        print(f"Warning: LDPC average row weight {avg_weight:.1f} unusual")
        return False

    return True
265
+
266
+
267
+ # ============================================================================
268
+ # JAX/EDSL Implementations
269
+ # ============================================================================
270
+
271
+
272
def generate_sparse_noise(n: int, weight: int) -> el.Object:
    """Generate a sparse noise vector from runtime randomness.

    Draws ``weight * 24`` bytes with ``crypto.random_bytes`` at runtime (not
    trace time). The bytes are then turned deterministically into ``weight``
    distinct positions in ``[0, n)``, each paired with a random 128-bit value
    (two ``uint64`` words).

    Positions are sampled without replacement. For step ``i``, a random rank
    ``r`` in ``[0, n - i)`` is mapped to the ``r``-th index not yet used, so
    exactly ``weight`` distinct positions are selected. The ``u64 % (n - i)``
    reduction carries a modulo bias of at most ``(n - i) / 2**64``.

    Security: intended for LPN-based protocols such as Silver VOLE; the
    randomness is generated at runtime, not at trace time.

    Args:
        n: Length of the noise vector.
        weight: Number of selected (non-zero) positions; ``0 <= weight <= n``.

    Returns:
        (n, 2) uint64 tensor that is zero except at ``weight`` distinct
        positions holding random 128-bit values. A random value could itself
        be zero, but only with negligible probability.

    Raises:
        ValueError: If ``weight`` is negative or larger than ``n``.
    """
    if weight < 0 or weight > n:
        raise ValueError(
            f"weight must satisfy 0 <= weight <= n, got weight={weight}, n={n}"
        )

    # Phase 1: runtime entropy.
    # 8 bytes per position (index selection) + 16 bytes per value.
    entropy_needed = weight * 8 + weight * 16
    entropy = crypto.random_bytes(entropy_needed)

    # Phase 2: deterministic construction from the entropy.
    def _build_noise(ent: Any) -> Any:
        # Split entropy into index-selection and value parts.
        idx_entropy = ent[: weight * 8].view(jnp.uint64)  # (weight,)
        val_entropy = ent[weight * 8 :].view(jnp.uint64).reshape(weight, 2)

        # Use int64 to avoid a dtype-mismatch warning in the scatter below.
        positions = jnp.zeros(weight, dtype=jnp.int64)

        # Unrolled at trace time (weight is a Python int).
        for i in range(weight):
            # Rank among the n - i indices that are still free.
            rank = jnp.int64(idx_entropy[i] % (n - i))

            # Map rank -> actual index. Sort the used indices as
            # u_0 < ... < u_{i-1}. Exactly (u_j - j) free indices precede u_j,
            # so u_j lies before the rank-th free index iff u_j - j <= rank.
            # A single "count used <= rank" shift is NOT enough: it can land
            # on an index that is already taken.
            used = jnp.sort(positions[:i])
            pos = rank + jnp.sum(used - jnp.arange(i, dtype=jnp.int64) <= rank)

            positions = positions.at[i].set(pos)

        # Sort positions so the scatter indices are monotone.
        positions = jnp.sort(positions)

        # Build the sparse noise vector by scattering the random values.
        noise = jnp.zeros((n, 2), dtype=jnp.uint64)
        return noise.at[positions].set(val_entropy)

    return cast(el.Object, tensor.run_jax(_build_noise, entropy))
328
+
329
+
330
def ldpc_encode_dense_jax(message: el.Object, H_dense: el.Object) -> el.Object:
    """Compute ``H * message`` (LDPC encode) with dense JAX operations.

    This is a reference implementation for correctness checking. It is
    significantly slower than the sparse C++ kernel.

    Row ``r`` of the result is the XOR of every 128-bit message element
    ``message[j]`` (two ``uint64`` words) for which ``H_dense[r, j] != 0``.

    Args:
        message: (N, 2) uint64 message.
        H_dense: (M, N) parity-check matrix (0/1); any non-zero entry selects.

    Returns:
        (M, 2) uint64 syndrome.
    """
    import jax

    def _encode(msg: Any, h: Any) -> Any:
        msg = msg.astype(jnp.uint64)  # (N, 2)

        # Transposing lets the scan iterate over H's N columns.
        h_cols = jnp.transpose(h != 0)  # (N, M) bool

        def body(acc: Any, xs: Any) -> tuple[Any, None]:
            h_col, msg_row = xs  # (M,), (2,)
            # Select the WHOLE 128-bit element where H is non-zero. ANDing
            # with a 0/1 mask would keep only each word's lowest bit.
            selected = jnp.where(h_col[:, None], msg_row[None, :], jnp.uint64(0))
            return jnp.bitwise_xor(acc, selected), None

        # Only the (M, 2) accumulator is carried; the (M, N, 2) tensor of
        # selected terms is never materialized.
        init = jnp.zeros((h.shape[0], 2), dtype=jnp.uint64)
        res, _ = jax.lax.scan(body, init, (h_cols, msg))
        return res

    return cast(el.Object, tensor.run_jax(_encode, message, H_dense))
373
+
374
+
375
def ldpc_encode_sparse(
    message: el.Object, h_indices: el.Object, h_indptr: el.Object, m: int, n: int
) -> el.Object:
    """Compute the syndrome ``S = H * x`` with the native sparse LDPC kernel.

    Thin wrapper over the ``field.ldpc_encode`` primitive. That primitive is
    dispatched directly by the interpreter, avoiding JAX callback overhead.

    Args:
        message: The vector ``x`` to encode.
        h_indices: Sparse column indices of ``H`` (presumably CSR layout, as
            produced by ``matrix_to_sparse_repr`` — confirm against the kernel).
        h_indptr: Sparse row pointers of ``H`` (same assumption as above).
        m: Number of rows of ``H`` (assumed syndrome length).
        n: Number of columns of ``H`` (assumed message length).

    Returns:
        Whatever ``field.ldpc_encode`` returns for these arguments.
    """
    syndrome = field.ldpc_encode(message, h_indices, h_indptr, m, n)
    return syndrome