mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,454 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Field Backend Implementation.
16
+
17
+ Implements runtime execution logic for Field dialect primitives,
18
+ including bindings to C++ kernels (libmplang_kernels.so) and
19
+ NumPy fallbacks where appropriate.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import ctypes
25
+ import os
26
+ import threading
27
+
28
+ # print("DEBUG: Importing field_impl.py")
29
+ import jax.numpy as jnp
30
+ import numpy as np
31
+
32
+ from mplang.v2.backends.tensor_impl import TensorValue, _unwrap, _wrap
33
+ from mplang.v2.dialects import field
34
+ from mplang.v2.edsl.graph import Operation
35
+ from mplang.v2.kernels import py_kernels
36
+ from mplang.v2.runtime.interpreter import Interpreter
37
+
38
+ # =============================================================================
39
+ # Kernel Loading
40
+ # =============================================================================
41
+
42
# Native kernel library location, resolved relative to this module:
# <this package dir>/../kernels/libmplang_kernels.so
_KERNEL_LIB_PATH = os.path.join(
    os.path.dirname(__file__), os.pardir, "kernels", "libmplang_kernels.so"
)

# Lazily-populated ctypes handle; stays None until _get_lib() loads it.
_LIB = None

# Serializes the one-time initialisation of _LIB across threads.
_LIB_LOCK = threading.Lock()
49
+
50
+
51
# Set once a load attempt has failed, so later calls return None immediately
# instead of retrying the dlopen (and re-emitting the warning) on every call.
_LIB_LOAD_FAILED = False


def _declare_kernel_signatures(lib: ctypes.CDLL) -> None:
    """Declare ctypes argument types for every native kernel used here.

    Raises:
        AttributeError: If ``lib`` does not export one of the expected symbols.
    """
    u64_ptr = ctypes.POINTER(ctypes.c_uint64)

    lib.gf128_mul.argtypes = [u64_ptr, u64_ptr, u64_ptr]
    lib.gf128_mul_batch.argtypes = [u64_ptr, u64_ptr, u64_ptr, ctypes.c_int64]
    lib.solve_okvs.argtypes = [
        u64_ptr,  # keys
        u64_ptr,  # values
        u64_ptr,  # output
        ctypes.c_uint64,  # n
        ctypes.c_uint64,  # m
        u64_ptr,  # seed
    ]
    lib.decode_okvs.argtypes = [
        u64_ptr,  # keys
        u64_ptr,  # storage
        u64_ptr,  # output
        ctypes.c_uint64,  # n
        ctypes.c_uint64,  # m
        u64_ptr,  # seed
    ]
    # Optimized mega-binning variants share the naive kernels' signatures.
    lib.solve_okvs_opt.argtypes = lib.solve_okvs.argtypes
    lib.decode_okvs_opt.argtypes = lib.decode_okvs.argtypes

    lib.aes_128_expand.argtypes = [
        u64_ptr,  # seeds
        u64_ptr,  # output
        ctypes.c_uint64,  # num_seeds
        ctypes.c_uint64,  # length
    ]
    lib.ldpc_encode.argtypes = [
        u64_ptr,  # message
        u64_ptr,  # indices
        u64_ptr,  # indptr
        u64_ptr,  # output
        ctypes.c_uint64,  # m
        ctypes.c_uint64,  # n
    ]


def _get_lib() -> ctypes.CDLL | None:
    """Return the lazily-loaded native kernel library, or ``None``.

    On first use the library at ``_KERNEL_LIB_PATH`` is opened with
    :class:`ctypes.CDLL` and its kernel signatures are declared. The handle is
    published to ``_LIB`` only after every signature has been configured. A
    library that is missing an expected symbol is therefore treated like a
    missing library, rather than being cached in a half-configured state.

    Loading is attempted at most once. After a failure a single
    ``RuntimeWarning`` is emitted and every later call returns ``None``, so
    callers take their pure-Python / NumPy fallbacks without repeated dlopen
    attempts.

    Returns:
        The configured ``ctypes.CDLL`` handle, or ``None`` if it is unavailable.
    """
    global _LIB, _LIB_LOAD_FAILED
    with _LIB_LOCK:
        if _LIB is None and not _LIB_LOAD_FAILED:
            try:
                lib = ctypes.CDLL(_KERNEL_LIB_PATH)
                _declare_kernel_signatures(lib)
            except (OSError, AttributeError) as exc:
                import warnings

                _LIB_LOAD_FAILED = True
                warnings.warn(
                    f"Could not load kernels from {_KERNEL_LIB_PATH}: {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )
            else:
                _LIB = lib
        return _LIB
106
+
107
+
108
+ # =============================================================================
109
+ # Helper Implementations (C++ Wrappers)
110
+ # =============================================================================
111
+
112
+
113
def _gf128_mul_impl(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Multiply two arrays of GF(2^128) elements element-wise.

    Each field element is stored as two consecutive uint64 words, so the
    inputs typically have shape ``(N, 2)``. The native ``gf128_mul_batch``
    kernel is used when the library is available. Otherwise the pure-Python
    ``py_kernels.gf128_mul_batch`` fallback handles the call.

    Args:
        a: Left operands (converted to contiguous uint64 on the native path).
        b: Right operands; must match ``a`` on the native path.

    Returns:
        The products. On the native path the result has the same shape as
        ``a``.

    Raises:
        ValueError: On the native path, if ``a`` and ``b`` differ in shape, or
            if they hold an odd number of uint64 words. The kernel reads and
            writes ``2 * n`` words from every buffer, so either case would
            otherwise be an out-of-bounds access or a silent truncation.
    """
    lib = _get_lib()
    if lib is None:
        # Use pure Python fallback
        return py_kernels.gf128_mul_batch(a, b)

    # Enforce contiguous C-order uint64 buffers (required for ctypes).
    # ascontiguousarray avoids a copy when the input already qualifies.
    a_contig = np.ascontiguousarray(a, dtype=np.uint64)
    b_contig = np.ascontiguousarray(b, dtype=np.uint64)
    if a_contig.shape != b_contig.shape:
        raise ValueError(
            f"gf128 mul operands must have the same shape, got "
            f"{a_contig.shape} and {b_contig.shape}"
        )
    if a_contig.size % 2:
        raise ValueError(
            f"gf128 elements are pairs of uint64 words; got an odd word "
            f"count {a_contig.size}"
        )
    out = np.zeros_like(a_contig)

    # Two uint64 words per field element.
    n_elements = a_contig.size // 2

    u64_ptr = ctypes.POINTER(ctypes.c_uint64)
    lib.gf128_mul_batch(
        a_contig.ctypes.data_as(u64_ptr),
        b_contig.ctypes.data_as(u64_ptr),
        out.ctypes.data_as(u64_ptr),
        n_elements,
    )
    return out
139
+
140
+
141
def _okvs_solve_opt_impl(
    keys: np.ndarray, values: np.ndarray, m: int, seed: np.ndarray
) -> np.ndarray:
    """Solve an OKVS table, preferring the optimized mega-binning kernel.

    The call is delegated to the naive solver :func:`_okvs_solve_impl` when:
      * the native library is unavailable (there is no Python version of the
        optimized kernel);
      * there are fewer than 200,000 keys (mega-binning is unstable there);
      * the expansion ratio ``m / n`` is below 1.32. Mega-binning needs about
        1.35x expansion, while the naive solver works at about 1.25x.

    Returns:
        On the native optimized path, an ``(m, 2)`` uint64 array.
    """
    lib = _get_lib()
    flat_seed = seed if seed.ndim <= 1 else seed.flatten()

    if lib is None:
        return _okvs_solve_impl(keys, values, m, flat_seed)

    n = keys.shape[0]
    # Short-circuit keeps n == 0 away from the division.
    if n < 200_000 or m / n < 1.32:
        return _okvs_solve_impl(keys, values, m, flat_seed)

    u64_ptr = ctypes.POINTER(ctypes.c_uint64)
    keys_buf = np.ascontiguousarray(keys, dtype=np.uint64)
    values_buf = np.ascontiguousarray(values, dtype=np.uint64)
    seed_buf = np.ascontiguousarray(flat_seed, dtype=np.uint64)
    table = np.zeros((m, 2), dtype=np.uint64)

    lib.solve_okvs_opt(
        keys_buf.ctypes.data_as(u64_ptr),
        values_buf.ctypes.data_as(u64_ptr),
        table.ctypes.data_as(u64_ptr),
        n,
        m,
        seed_buf.ctypes.data_as(u64_ptr),
    )
    return table
177
+
178
+
179
def _okvs_decode_opt_impl(
    keys: np.ndarray, storage: np.ndarray, m: int, seed: np.ndarray
) -> np.ndarray:
    """Decode OKVS values for ``keys``, preferring the mega-binning kernel.

    The optimized kernel spreads keys over 1024 bins. It is only efficient and
    stable with enough items per bin, so the call goes to the naive decoder
    :func:`_okvs_decode_impl` when:
      * the native library is unavailable;
      * there are fewer than 200,000 keys (roughly 200 items per bin). Below
        that size the naive decoder is already fast (under about 50 ms);
      * the expansion ratio ``m / n`` is below 1.32.

    Returns:
        On the native optimized path, an ``(n, 2)`` uint64 array.
    """
    lib = _get_lib()
    flat_seed = seed if seed.ndim <= 1 else seed.flatten()

    if lib is None:
        return _okvs_decode_impl(keys, storage, m, flat_seed)

    n = keys.shape[0]
    # Short-circuit keeps n == 0 away from the division.
    if n < 200_000 or m / n < 1.32:
        return _okvs_decode_impl(keys, storage, m, flat_seed)

    u64_ptr = ctypes.POINTER(ctypes.c_uint64)
    keys_buf = np.ascontiguousarray(keys, dtype=np.uint64)
    storage_buf = np.ascontiguousarray(storage, dtype=np.uint64)
    seed_buf = np.ascontiguousarray(flat_seed, dtype=np.uint64)
    decoded = np.zeros((n, 2), dtype=np.uint64)

    lib.decode_okvs_opt(
        keys_buf.ctypes.data_as(u64_ptr),
        storage_buf.ctypes.data_as(u64_ptr),
        decoded.ctypes.data_as(u64_ptr),
        n,
        m,
        seed_buf.ctypes.data_as(u64_ptr),
    )
    return decoded
214
+
215
+
216
def _okvs_solve_impl(
    keys: np.ndarray, values: np.ndarray, m: int, seed: np.ndarray
) -> np.ndarray:
    """Solve an OKVS table of size ``m`` with the naive (non-binned) solver.

    Uses the native ``solve_okvs`` kernel when available, otherwise the
    pure-Python ``py_kernels.okvs_solve``. ``seed`` is flattened and its first
    two words are converted to Python ints up front. A seed with fewer than
    two entries therefore fails before either path runs.

    Returns:
        On the native path, an ``(m, 2)`` uint64 array.
    """
    lib = _get_lib()
    flat_seed = seed if seed.ndim <= 1 else seed.flatten()
    seed_pair = (int(flat_seed[0]), int(flat_seed[1]))

    if lib is None:
        # Pure-Python fallback; it is handed 1-D keys.
        flat_keys = keys if keys.ndim <= 1 else keys.flatten()
        return py_kernels.okvs_solve(flat_keys, values, m, seed=seed_pair)

    n = keys.shape[0]
    u64_ptr = ctypes.POINTER(ctypes.c_uint64)
    keys_buf = np.ascontiguousarray(keys, dtype=np.uint64)
    values_buf = np.ascontiguousarray(values, dtype=np.uint64)
    seed_buf = np.ascontiguousarray(flat_seed, dtype=np.uint64)
    table = np.zeros((m, 2), dtype=np.uint64)

    lib.solve_okvs(
        keys_buf.ctypes.data_as(u64_ptr),
        values_buf.ctypes.data_as(u64_ptr),
        table.ctypes.data_as(u64_ptr),
        n,
        m,
        seed_buf.ctypes.data_as(u64_ptr),
    )
    return table
245
+
246
+
247
def _okvs_decode_impl(
    keys: np.ndarray, storage: np.ndarray, m: int, seed: np.ndarray
) -> np.ndarray:
    """Decode the values stored for ``keys`` with the naive OKVS decoder.

    Uses the native ``decode_okvs`` kernel when available, otherwise the
    pure-Python ``py_kernels.okvs_decode``. ``seed`` is flattened and its
    first two words are converted to Python ints up front, so a short seed
    fails before either path runs.

    Returns:
        On the native path, an ``(n, 2)`` uint64 array, where ``n`` is the
        number of keys.
    """
    lib = _get_lib()
    flat_seed = seed if seed.ndim <= 1 else seed.flatten()
    seed_pair = (int(flat_seed[0]), int(flat_seed[1]))

    if lib is None:
        # Pure-Python fallback; it is handed 1-D keys.
        flat_keys = keys if keys.ndim <= 1 else keys.flatten()
        return py_kernels.okvs_decode(flat_keys, storage, m, seed=seed_pair)

    n = keys.shape[0]
    u64_ptr = ctypes.POINTER(ctypes.c_uint64)
    keys_buf = np.ascontiguousarray(keys, dtype=np.uint64)
    storage_buf = np.ascontiguousarray(storage, dtype=np.uint64)
    seed_buf = np.ascontiguousarray(flat_seed, dtype=np.uint64)
    decoded = np.zeros((n, 2), dtype=np.uint64)

    lib.decode_okvs(
        keys_buf.ctypes.data_as(u64_ptr),
        storage_buf.ctypes.data_as(u64_ptr),
        decoded.ctypes.data_as(u64_ptr),
        n,
        m,
        seed_buf.ctypes.data_as(u64_ptr),
    )
    return decoded
276
+
277
+
278
def ldpc_encode_impl(
    message: np.ndarray, h_indices: np.ndarray, h_indptr: np.ndarray, m: int
) -> np.ndarray:
    """Encode ``message`` with a sparse matrix given by ``h_indices`` / ``h_indptr``.

    NOTE(review): the indices/indptr pair looks like a CSR layout; confirm
    against the LDPC kernel.

    Uses the native ``ldpc_encode`` kernel when available, otherwise the
    pure-Python ``py_kernels.ldpc_encode``. On the native path, ``n`` is taken
    from ``message.shape[0]`` and the result is an ``(m, 2)`` uint64 array.
    """
    lib = _get_lib()
    if lib is None:
        # The pure-Python fallback is handed flattened index arrays.
        flat_indices = h_indices if h_indices.ndim <= 1 else h_indices.flatten()
        flat_indptr = h_indptr if h_indptr.ndim <= 1 else h_indptr.flatten()
        return py_kernels.ldpc_encode(message, flat_indices, flat_indptr, m)

    u64_ptr = ctypes.POINTER(ctypes.c_uint64)
    message_buf = np.ascontiguousarray(message, dtype=np.uint64)
    indices_buf = np.ascontiguousarray(h_indices, dtype=np.uint64)
    indptr_buf = np.ascontiguousarray(h_indptr, dtype=np.uint64)
    encoded = np.zeros((m, 2), dtype=np.uint64)
    message_len = message.shape[0]

    lib.ldpc_encode(
        message_buf.ctypes.data_as(u64_ptr),
        indices_buf.ctypes.data_as(u64_ptr),
        indptr_buf.ctypes.data_as(u64_ptr),
        encoded.ctypes.data_as(u64_ptr),
        m,
        message_len,
    )
    return encoded
307
+
308
+
309
+ # =============================================================================
310
+ # Primitive Implementations
311
+ # =============================================================================
312
+
313
+
314
@field.ldpc_encode_p.def_impl
def _ldpc_encode_impl_prim(
    interpreter: Interpreter,
    op: Operation,
    message_val: TensorValue,
    indices_val: TensorValue,
    indptr_val: TensorValue,
) -> TensorValue:
    """Interpreter kernel for ``field.ldpc_encode``.

    Reads the output length ``m`` from the op attributes, unwraps the three
    tensor operands and delegates to :func:`ldpc_encode_impl`.
    """
    out_len = op.attrs["m"]
    encoded = ldpc_encode_impl(
        _unwrap(message_val), _unwrap(indices_val), _unwrap(indptr_val), out_len
    )
    return _wrap(encoded)
328
+
329
+
330
@field.aes_expand_p.def_impl
def _aes_expand_impl_prim(
    interpreter: Interpreter, op: Operation, seeds_val: TensorValue
) -> TensorValue:
    """Interpreter kernel for ``field.aes_expand``.

    Expands every seed (a pair of uint64 words) into ``length`` pairs of
    pseudo-random uint64 words. The result has shape
    ``(num_seeds, length, 2)``.

    Seed layout handling:
      * ``uint8`` seeds whose last axis is 16 are reinterpreted (``.view``)
        as two uint64 words per seed;
      * any layout whose last axis is then not 2 is reshaped to ``(-1, 2)``.

    The native ``aes_128_expand`` kernel is used when the library loads.
    Otherwise each output row is filled from ``np.random.default_rng`` seeded
    with that seed pair.

    NOTE(review): the NumPy fallback is not AES, so its stream presumably
    differs from the native kernel's. All parties must take the same path.
    Also, ``rng.integers`` treats ``high`` as exclusive, so the fallback never
    emits 0xFFFFFFFFFFFFFFFF.
    NOTE(review): ``jnp.array`` keeps uint64 only when JAX x64 mode is
    enabled. Confirm the runtime sets ``jax_enable_x64``.
    """
    length = op.attrs["length"]
    seeds = _unwrap(seeds_val)

    # Raw 16-byte seeds -> two uint64 words each.
    if seeds.dtype == np.uint8 and seeds.shape[-1] == 16:
        seeds = seeds.view(np.uint64)
    if seeds.shape[-1] != 2:
        seeds = seeds.reshape(-1, 2)

    num_seeds = seeds.shape[0]
    output = np.zeros((num_seeds, length, 2), dtype=np.uint64)

    lib = _get_lib()
    if lib is None:
        # Slow fallback: one NumPy generator per seed row.
        for row in range(num_seeds):
            rng = np.random.default_rng([int(seeds[row, 0]), int(seeds[row, 1])])
            output[row] = rng.integers(
                0, 0xFFFFFFFFFFFFFFFF, size=(length, 2), dtype=np.uint64
            )
    else:
        seeds_buf = np.ascontiguousarray(seeds, dtype=np.uint64)
        u64_ptr = ctypes.POINTER(ctypes.c_uint64)
        lib.aes_128_expand(
            seeds_buf.ctypes.data_as(u64_ptr),
            output.ctypes.data_as(u64_ptr),
            num_seeds,
            length,
        )

    # Returned as a JAX array for downstream consumers.
    return _wrap(jnp.array(output))
380
+
381
+
382
@field.mul_p.def_impl
def _mul_impl(
    interpreter: Interpreter, op: Operation, a_val: TensorValue, b_val: TensorValue
) -> TensorValue:
    """Interpreter kernel for ``field.mul``: element-wise GF(2^128) product.

    Uses the same ``_unwrap`` / ``_wrap`` helpers as every other field
    primitive, so operand unwrapping and result wrapping stay uniform.
    """
    a = _unwrap(a_val)
    b = _unwrap(b_val)
    return _wrap(_gf128_mul_impl(a, b))
390
+
391
+
392
@field.solve_okvs_p.def_impl
def _solve_okvs_impl(
    interpreter: Interpreter,
    op: Operation,
    keys_val: TensorValue,
    values_val: TensorValue,
    seed_val: TensorValue,
) -> TensorValue:
    """Interpreter kernel for ``field.solve_okvs``.

    Reads the table size ``m`` from the op attributes and delegates to the
    naive solver :func:`_okvs_solve_impl`.
    """
    table_size = op.attrs["m"]
    table = _okvs_solve_impl(
        _unwrap(keys_val), _unwrap(values_val), table_size, _unwrap(seed_val)
    )
    return _wrap(table)
406
+
407
+
408
@field.decode_okvs_p.def_impl
def _decode_okvs_impl(
    interpreter: Interpreter,
    op: Operation,
    keys_val: TensorValue,
    store_val: TensorValue,
    seed_val: TensorValue,
) -> TensorValue:
    """Interpreter kernel for ``field.decode_okvs``.

    Unwraps the operands and delegates to the naive decoder
    :func:`_okvs_decode_impl`.
    """
    keys = _unwrap(keys_val)
    storage = _unwrap(store_val)
    seed = _unwrap(seed_val)
    # The table size is implied by the stored table's leading dimension.
    m = storage.shape[0]
    res = _okvs_decode_impl(keys, storage, m, seed)
    return _wrap(res)
423
+
424
+
425
@field.solve_okvs_opt_p.def_impl
def _solve_okvs_opt_impl_prim(
    interpreter: Interpreter,
    op: Operation,
    keys_val: TensorValue,
    values_val: TensorValue,
    seed_val: TensorValue,
) -> TensorValue:
    """Interpreter kernel for ``field.solve_okvs_opt``.

    Reads the table size ``m`` from the op attributes and delegates to
    :func:`_okvs_solve_opt_impl`, which may itself fall back to the naive
    solver.
    """
    table_size = op.attrs["m"]
    table = _okvs_solve_opt_impl(
        _unwrap(keys_val), _unwrap(values_val), table_size, _unwrap(seed_val)
    )
    return _wrap(table)
439
+
440
+
441
@field.decode_okvs_opt_p.def_impl
def _decode_okvs_opt_impl_prim(
    interpreter: Interpreter,
    op: Operation,
    keys_val: TensorValue,
    store_val: TensorValue,
    seed_val: TensorValue,
) -> TensorValue:
    """Interpreter kernel for ``field.decode_okvs_opt``.

    The table size is taken from the stored table's leading dimension, and
    decoding is delegated to :func:`_okvs_decode_opt_impl`.
    """
    keys, storage, seed = _unwrap(keys_val), _unwrap(store_val), _unwrap(seed_val)
    decoded = _okvs_decode_opt_impl(keys, storage, storage.shape[0], seed)
    return _wrap(decoded)
@@ -0,0 +1,107 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Generic kernel implementations for the `func` dialect.
16
+
17
+ Design: Function as Value
18
+ - func.func impl: Returns a FunctionValue wrapping the traced graph.
19
+ - func.call impl: Executes the function graph with provided arguments.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from typing import TYPE_CHECKING, Any, ClassVar
25
+
26
+ from mplang.v2.dialects.func import call_p, func_def_p
27
+ from mplang.v2.edsl import serde
28
+ from mplang.v2.edsl.graph import Graph, Operation
29
+ from mplang.v2.runtime.value import Value
30
+
31
+ if TYPE_CHECKING:
32
+ from typing import Self
33
+
34
+
35
@serde.register_class
class FunctionValue(Value):
    """A traced function graph carried around as an ordinary runtime Value.

    ``func.func`` produces instances of this class and ``func.call`` consumes
    them. The interpreter moves Values between Operations. Wrapping a Graph
    this way lets a function be passed along, stored and serialized like any
    other piece of data.
    """

    _serde_kind: ClassVar[str] = "func.FunctionValue"

    def __init__(self, graph: Graph, name: str = "anonymous") -> None:
        self._graph, self._name = graph, name

    @property
    def graph(self) -> Graph:
        """The wrapped function body."""
        return self._graph

    @property
    def name(self) -> str:
        """Symbolic name of the function (``"anonymous"`` by default)."""
        return self._name

    def __repr__(self) -> str:
        op_count = len(self._graph.operations)
        return f"FunctionValue({self._name!r}, ops={op_count})"

    def to_json(self) -> dict[str, Any]:
        """Serialize as ``{"graph": <serialized graph>, "name": <name>}``."""
        payload: dict[str, Any] = {"graph": serde.to_json(self._graph)}
        payload["name"] = self._name
        return payload

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> Self:
        """Rebuild a FunctionValue from the output of :meth:`to_json`."""
        body = serde.from_json(data["graph"])
        label = data.get("name", "anonymous")
        return cls(graph=body, name=label)
77
+
78
+
79
@func_def_p.def_impl
def _func_def_impl(interpreter: Any, op: Operation, *args: Any) -> FunctionValue:
    """Kernel for ``func.func``: wrap the op's first region as a FunctionValue.

    The function name comes from the ``sym_name`` attribute and defaults to
    ``"anonymous"``.

    Raises:
        ValueError: If the operation has no body region.
    """
    regions = op.regions
    if not regions:
        raise ValueError("func.func operation missing body region")

    fn_name = op.attrs.get("sym_name", "anonymous")
    body = regions[0]
    return FunctionValue(graph=body, name=fn_name)
87
+
88
+
89
@call_p.def_impl
def _call_impl(
    interpreter: Any, op: Operation, fn_obj: FunctionValue, *args: Any
) -> Any:
    """Kernel for ``func.call``: evaluate a FunctionValue's graph on ``args``.

    Args:
        interpreter: Interpreter used to evaluate the function body.
        op: The ``func.call`` operation being executed.
        fn_obj: The FunctionValue produced by ``func.func``.
        *args: Positional arguments bound to the graph's inputs.

    Returns:
        The single output when the graph has exactly one output, otherwise the
        full result returned by ``interpreter.evaluate_graph``.

    Raises:
        TypeError: If ``fn_obj`` is not a FunctionValue.
    """
    if isinstance(fn_obj, FunctionValue):
        graph = fn_obj.graph
        outputs = interpreter.evaluate_graph(graph, [*args])
        if len(graph.outputs) == 1:
            return outputs[0]
        return outputs

    raise TypeError(f"func.call expects FunctionValue, got {type(fn_obj)}")
@@ -0,0 +1,148 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """PHE Runtime Implementation using LightPHE."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from dataclasses import dataclass
20
+ from typing import Any, cast
21
+
22
+ from lightphe import LightPHE
23
+ from lightphe.models.Ciphertext import Ciphertext
24
+
25
+ from mplang.v2.dialects import phe
26
+ from mplang.v2.edsl.graph import Operation
27
+ from mplang.v2.runtime.interpreter import Interpreter
28
+
29
+
30
class PHEContext:
    """Wraps a LightPHE cryptosystem instance.

    The algorithm name is matched case-insensitively against LightPHE's
    canonical names. The previous ``str.capitalize()`` normalization mangled
    mixed-case names such as "RSA" ("Rsa"), "ElGamal" ("Elgamal") and
    "Damgard-Jurik" ("Damgard-jurik"). Names outside the known set still fall
    back to ``capitalize()``, so earlier behaviour is preserved for them.
    """

    # Lower-cased name -> canonical LightPHE algorithm name.
    _CANONICAL_ALGORITHMS = {
        name.lower(): name
        for name in (
            "RSA",
            "ElGamal",
            "Exponential-ElGamal",
            "EllipticCurve-ElGamal",
            "Paillier",
            "Damgard-Jurik",
            "Okamoto-Uchiyama",
            "Benaloh",
            "Naccache-Stern",
            "Goldwasser-Micali",
        )
    }

    def __init__(self, algorithm_name: str = "Paillier", key_size: int = 2048):
        normalized_name = self._normalize_algorithm_name(algorithm_name)
        self.cs = LightPHE(algorithm_name=normalized_name, key_size=key_size)

    @classmethod
    def _normalize_algorithm_name(cls, algorithm_name: str) -> str:
        """Map any casing of a known algorithm to its canonical LightPHE name."""
        return cls._CANONICAL_ALGORITHMS.get(
            algorithm_name.lower(), algorithm_name.capitalize()
        )

    def encrypt(self, value: int) -> Ciphertext:
        """Encrypt an integer plaintext."""
        return self.cs.encrypt(value)

    def decrypt(self, ct: Ciphertext) -> int:
        """Decrypt a ciphertext back to its integer plaintext."""
        return cast(int, self.cs.decrypt(ct))
43
+
44
+
45
@dataclass
class PHEEncoder:
    """Parameters of a simple fixed-point encoder for PHE plaintexts.

    Encoding multiplies a real value by ``scale``; decoding divides by it.
    """

    scale: float  # fixed-point multiplier (2 ** fxp_bits when built by create_encoder)
50
+
51
+
52
@dataclass
class WrappedCiphertext:
    """A ciphertext bundled with the context that produced it.

    Supported operators:
      * ciphertext + ciphertext;
      * ciphertext + int, where the plaintext is first encrypted with ``ctx``;
      * ciphertext * int.

    A plaintext ``int`` may appear on either side (``2 + ct``, ``3 * ct``),
    because both addition and scalar multiplication are commutative. Any
    other operand type yields ``NotImplemented``, which Python turns into a
    ``TypeError``.
    """

    ct: Ciphertext  # underlying LightPHE ciphertext
    ctx: PHEContext  # context used to encrypt plaintext addends

    def __add__(self, other: Any) -> WrappedCiphertext:
        if isinstance(other, WrappedCiphertext):
            # ct + ct
            return WrappedCiphertext(self.ct + other.ct, self.ctx)
        if isinstance(other, int):
            # ct + int -> ct + encrypt(int)
            ct_other = self.ctx.encrypt(other)
            return WrappedCiphertext(self.ct + ct_other, self.ctx)
        return NotImplemented

    def __radd__(self, other: Any) -> WrappedCiphertext:
        # int + ct is the same as ct + int.
        return self.__add__(other)

    def __mul__(self, other: Any) -> WrappedCiphertext:
        if isinstance(other, int):
            # ct * int
            return WrappedCiphertext(self.ct * other, self.ctx)
        return NotImplemented

    def __rmul__(self, other: Any) -> WrappedCiphertext:
        # int * ct is the same as ct * int.
        return self.__mul__(other)
75
+
76
+
77
@phe.keygen_p.def_impl
def keygen_impl(
    interpreter: Interpreter, op: Operation, *args: Any
) -> tuple[PHEContext, PHEContext]:
    """Interpreter kernel for ``phe.keygen``.

    Builds a single PHEContext from the ``key_size`` (default 2048) and
    ``scheme`` (default ``"Paillier"``) attributes, and returns it as both
    the public and the secret key.

    NOTE(review): both tuple elements are the same object, and that object
    exposes ``decrypt``. Whoever receives the "public" half can therefore also
    decrypt. Confirm this is acceptable for the runtime's trust model.
    """
    attrs = op.attrs
    key_size = attrs.get("key_size", 2048)
    scheme = attrs.get("scheme", "Paillier")

    context = PHEContext(algorithm_name=scheme, key_size=key_size)
    return context, context
87
+
88
+
89
@phe.create_encoder_p.def_impl
def create_encoder_impl(
    interpreter: Interpreter, op: Operation, *args: Any
) -> PHEEncoder:
    """Interpreter kernel for ``phe.create_encoder``.

    The encoder's scale is ``2.0 ** fxp_bits``; ``fxp_bits`` defaults to 16.
    """
    bits = op.attrs.get("fxp_bits", 16)
    return PHEEncoder(scale=2.0**bits)
96
+
97
+
98
@phe.encode_p.def_impl
def encode_impl(
    interpreter: Interpreter, op: Operation, value: float, encoder: PHEEncoder
) -> int:
    """Interpreter kernel for ``phe.encode``: real -> fixed-point integer.

    The scaled value is rounded to the nearest integer (ties to even, per
    Python's ``round``) rather than truncated. ``int()`` alone rounds toward
    zero, which biases every encoding toward zero and doubles the worst-case
    quantization error (one unit instead of half a unit). That bias
    accumulates when many encoded values are summed homomorphically.
    """
    return int(round(value * encoder.scale))
103
+
104
+
105
@phe.decode_p.def_impl
def decode_impl(
    interpreter: Interpreter, op: Operation, value: int, encoder: PHEEncoder
) -> float:
    """Interpreter kernel for ``phe.decode``: undo the fixed-point scaling."""
    as_real = float(value)
    return as_real / encoder.scale
110
+
111
+
112
@phe.encrypt_p.def_impl
def encrypt_impl(
    interpreter: Interpreter, op: Operation, value: int, pk: PHEContext
) -> WrappedCiphertext:
    """Interpreter kernel for ``phe.encrypt``.

    Encrypts ``value`` under ``pk`` and keeps ``pk`` alongside the ciphertext
    for later plaintext additions.
    """
    return WrappedCiphertext(ct=pk.encrypt(value), ctx=pk)
118
+
119
+
120
@phe.decrypt_p.def_impl
def decrypt_impl(
    interpreter: Interpreter, op: Operation, wct: WrappedCiphertext, sk: PHEContext
) -> int:
    """Interpreter kernel for ``phe.decrypt``: recover the integer plaintext."""
    ciphertext = wct.ct
    return sk.decrypt(ciphertext)
125
+
126
+
127
@phe.add_cc_p.def_impl
def add_cc_impl(
    interpreter: Interpreter,
    op: Operation,
    lhs: WrappedCiphertext,
    rhs: WrappedCiphertext,
) -> WrappedCiphertext:
    """Interpreter kernel for ``phe.add_cc``: ciphertext + ciphertext."""
    total = lhs + rhs
    return total
135
+
136
+
137
@phe.add_cp_p.def_impl
def add_cp_impl(
    interpreter: Interpreter, op: Operation, lhs: WrappedCiphertext, rhs: int
) -> WrappedCiphertext:
    """Interpreter kernel for ``phe.add_cp``: ciphertext + integer plaintext."""
    total = lhs + rhs
    return total
142
+
143
+
144
@phe.mul_cp_p.def_impl
def mul_cp_impl(
    interpreter: Interpreter, op: Operation, lhs: WrappedCiphertext, rhs: int
) -> WrappedCiphertext:
    """Interpreter kernel for ``phe.mul_cp``: ciphertext * integer plaintext."""
    product = lhs * rhs
    return product