mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,240 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import numpy as np
18
+
19
+ from mplang.v1.core import PFunction, TableType, TensorType
20
+ from mplang.v1.kernels.base import cur_kctx, kernel_def
21
+ from mplang.v1.kernels.value import TableValue, TensorValue, Value
22
+ from mplang.v1.runtime.data_providers import get_provider, resolve_uri
23
+ from mplang.v1.utils import table_utils
24
+
25
+
26
@kernel_def("basic.identity")
def _identity(pfunc: PFunction, value: Value) -> Value:
    """Return the single input ``value`` unchanged.

    The runtime supplies exactly one argument, so no arity validation is
    done here.
    """
    del pfunc  # unused; present only to match the kernel signature
    return value
30
+
31
+
32
@kernel_def("basic.read")
def _read(pfunc: PFunction) -> Value:
    """Load the resource named by the ``path`` attribute via a data provider.

    The path is resolved to a URI and dispatched to the provider registered
    for its scheme. The provider's result is returned as-is if it is already
    a Value; otherwise it is wrapped according to the declared output type.

    Raises:
        ValueError: the ``path`` attribute is absent.
        NotImplementedError: no provider is registered for the URI scheme.
        RuntimeError: the provider raised while reading.
        TypeError: the declared output is neither TableType nor TensorType.
    """
    raw_path = pfunc.attrs.get("path")
    if raw_path is None:
        raise ValueError("missing path attr for basic.read")
    expected_type = pfunc.outs_info[0]
    target = resolve_uri(str(raw_path))
    provider = get_provider(target.scheme)
    if provider is None:
        raise NotImplementedError(f"no resource provider for scheme: {target.scheme}")
    kctx = cur_kctx()
    try:
        loaded = provider.read(target, expected_type, ctx=kctx)
    except Exception as err:  # pragma: no cover - provider errors
        raise RuntimeError(f"basic.read failed: {err}") from err

    # Providers may already hand back a fully-formed Value.
    if isinstance(loaded, Value):
        return loaded
    if isinstance(expected_type, TableType):
        return TableValue(loaded)
    if isinstance(expected_type, TensorType):
        return TensorValue(np.asarray(loaded))
    raise TypeError(
        "basic.read only supports TableType/TensorType outputs, "
        f"got {type(expected_type).__name__}"
    )
59
+
60
+
61
@kernel_def("basic.write")
def _write(pfunc: PFunction, obj: Value) -> Value:
    """Write ``obj`` to the location in the ``path`` attribute, then return it.

    The Value is forwarded to the provider untouched; the provider decides
    how to serialize it.

    Raises:
        ValueError: the ``path`` attribute is absent.
        NotImplementedError: no provider is registered for the URI scheme.
        RuntimeError: the provider raised while writing.
    """
    destination = pfunc.attrs.get("path")
    if destination is None:
        raise ValueError("missing path attr for basic.write")
    target = resolve_uri(str(destination))
    provider = get_provider(target.scheme)
    if provider is None:
        raise NotImplementedError(f"no resource provider for scheme: {target.scheme}")
    kctx = cur_kctx()
    try:
        provider.write(target, obj, ctx=kctx)
    except Exception as err:  # pragma: no cover
        raise RuntimeError(f"basic.write failed: {err}") from err
    # The written value itself is the kernel's output.
    return obj
77
+
78
+
79
@kernel_def("basic.constant")
def _constant(pfunc: PFunction) -> Value:
    """Materialize a constant embedded in the PFunction attributes.

    Table constants must be parquet-encoded (``data_format`` equal to
    ``"bytes[parquet]"``). Any other declared output is treated as a tensor:
    the raw ``data_bytes`` are reinterpreted with the declared dtype and
    reshaped to the declared shape.

    Raises:
        ValueError: ``data_bytes`` is missing, or a table constant uses an
            unsupported ``data_format``.
    """
    payload = pfunc.attrs.get("data_bytes")
    if payload is None:
        raise ValueError("missing data_bytes attr for basic.constant")
    declared = pfunc.outs_info[0]
    encoding = pfunc.attrs.get("data_format")

    if not isinstance(declared, TableType):
        # Tensor constant: reinterpret the raw buffer, then apply the shape.
        dims = declared.shape  # type: ignore[attr-defined,union-attr]
        np_dtype = declared.dtype.numpy_dtype()  # type: ignore[attr-defined,union-attr]
        return TensorValue(np.frombuffer(payload, dtype=np_dtype).reshape(dims))

    if encoding != "bytes[parquet]":
        raise ValueError(f"unsupported table constant format {encoding}")
    return TableValue(table_utils.decode_table(payload, format="parquet"))
97
+
98
+
99
@kernel_def("basic.rank")
def _rank(pfunc: PFunction) -> TensorValue:
    """Return the current kernel context's rank as a uint64 TensorValue."""
    return TensorValue(np.array(cur_kctx().rank, dtype=np.uint64))
105
+
106
+
107
@kernel_def("basic.prand")
def _prand(pfunc: PFunction) -> TensorValue:
    """Draw uniformly random uint64 values spanning the full type range.

    The output shape comes from the ``shape`` attribute (default ``()``).
    A new ``default_rng()`` generator is created on every call.
    """
    out_shape = pfunc.attrs.get("shape", ())
    generator = np.random.default_rng()
    bounds = np.iinfo(np.uint64)
    # endpoint=True makes the upper bound inclusive, so the full range is covered.
    samples = generator.integers(
        bounds.min, bounds.max, size=out_shape, dtype=np.uint64, endpoint=True
    )
    return TensorValue(samples)
117
+
118
+
119
@kernel_def("basic.table_to_tensor")
def _table_to_tensor(pfunc: PFunction, table: TableValue) -> TensorValue:
    """Stack every column of ``table`` side by side into a TensorValue.

    Each Arrow column is converted to a NumPy array and the arrays are
    combined with ``np.column_stack``.

    Raises:
        ValueError: the table has no columns.
    """
    arrow_tbl = table.to_arrow()
    if arrow_tbl.num_columns == 0:
        raise ValueError("cannot pack empty table")
    per_column = [chunked.to_numpy() for chunked in arrow_tbl.columns]
    return TensorValue(np.column_stack(per_column))
130
+
131
+
132
@kernel_def("basic.tensor_to_table")
def _tensor_to_table(pfunc: PFunction, tensor: TensorValue) -> TableValue:
    """Split a rank-2 tensor into the columns of a new Arrow-backed TableValue.

    Column ``i`` of the table holds ``tensor[:, i]`` and is named
    ``column_names[i]`` from the PFunction attributes. There must be exactly
    one unique name per tensor column.

    Raises:
        ValueError: the tensor is not rank 2; ``column_names`` is missing;
            the number of names differs from the number of tensor columns;
            or the names contain duplicates.
    """
    import pyarrow as pa  # type: ignore

    arr = tensor.to_numpy()
    if arr.ndim != 2:
        raise ValueError("tensor_to_table expects rank-2 array")
    col_names = pfunc.attrs.get("column_names")
    if col_names is None:
        raise ValueError("missing column_names attr")
    names = list(col_names)
    n_cols = arr.shape[1]
    # Check explicitly so the error names the mismatch instead of relying on
    # zip(strict=True)'s generic message.
    if len(names) != n_cols:
        raise ValueError(
            f"tensor_to_table got {len(names)} column_names for {n_cols} tensor columns"
        )
    # Building the table from a dict would keep only the last column for a
    # repeated name, silently dropping data.
    if len(set(names)) != len(names):
        raise ValueError(f"tensor_to_table column_names must be unique, got {names}")
    # Create Arrow table directly from numpy array columns
    arrays = [pa.array(arr[:, i]) for i in range(n_cols)]
    arrow_table = pa.table(dict(zip(names, arrays, strict=True)))
    return TableValue(arrow_table)
147
+
148
+
149
def _summ(v: Value) -> str:
    """Render a compact preview of ``v`` for debug printing.

    TableValues show at most their first 8 rows via Arrow's string form.
    TensorValues are formatted with ``np.array2string`` (summarization
    threshold 64, 3 edge items, precision 6, small values suppressed).
    Anything else falls back to ``repr``. Formatting failures are not
    propagated; a placeholder string describing the error is returned
    instead.
    """
    try:
        if isinstance(v, TableValue):
            arrow_tbl = v.to_arrow()
            rendered = str(arrow_tbl.slice(0, min(8, arrow_tbl.num_rows)))
        elif isinstance(v, TensorValue):
            rendered = str(
                np.array2string(
                    v.to_numpy(),
                    threshold=64,
                    edgeitems=3,
                    precision=6,
                    suppress_small=True,
                )
            )
        else:
            rendered = repr(v)
    except Exception as exc:  # pragma: no cover
        rendered = f"<unprintable {type(v).__name__}: {exc}>"
    return rendered
167
+
168
+
169
@kernel_def("basic.debug_print")
def _debug_print(pfunc: PFunction, val: Value) -> Value:
    """Print a rank-tagged preview of ``val`` to stdout and pass it through.

    The optional ``prefix`` attribute (default empty) is placed directly
    before the preview text.
    """
    label = pfunc.attrs.get("prefix", "")
    rank = cur_kctx().rank
    print(f"[debug_print][rank={rank}] {label}{_summ(val)}")
    return val
175
+
176
+
177
@kernel_def("basic.pack")
def _pack(pfunc: PFunction, value: Value) -> TensorValue:
    """Serialize ``value`` into a flat uint8 TensorValue.

    TableValues are encoded as an Arrow IPC stream, for consistency with
    Value serde. TensorValues contribute their raw C-order bytes. The
    PFunction must declare exactly one output, a uint8 TensorType.

    Raises:
        ValueError: the PFunction does not declare exactly one output.
        TypeError: the declared output is not a uint8 TensorType, or
            ``value`` is neither a TableValue nor a TensorValue.
    """
    declared = pfunc.outs_info
    if len(declared) != 1:
        raise ValueError("basic.pack expects single output type")
    target = declared[0]
    if not isinstance(target, TensorType):
        raise TypeError("basic.pack must return TensorType")
    if target.dtype.numpy_dtype() != np.uint8:
        raise TypeError("basic.pack output dtype must be uint8")

    if isinstance(value, TableValue):
        import pyarrow as pa  # type: ignore
        import pyarrow.ipc as pa_ipc  # type: ignore

        arrow_tbl = value.to_arrow()
        buffer_sink = pa.BufferOutputStream()
        with pa_ipc.new_stream(buffer_sink, arrow_tbl.schema) as stream_writer:  # type: ignore[arg-type]
            stream_writer.write_table(arrow_tbl)  # type: ignore[arg-type]
        # Read the buffer only after the stream is closed.
        payload = buffer_sink.getvalue().to_pybytes()
    elif isinstance(value, TensorValue):
        payload = value.to_numpy().tobytes(order="C")
    else:
        raise TypeError(f"basic.pack does not support Value type {type(value).__name__}")
    return TensorValue(np.frombuffer(payload, dtype=np.uint8))
205
+
206
+
207
@kernel_def("basic.unpack")
def _unpack(pfunc: PFunction, packed: TensorValue) -> Value:
    """Rebuild a Value from a flat byte tensor such as ``basic.pack`` emits.

    For a TensorType output, the bytes are reinterpreted with the declared
    dtype and reshaped to the declared static shape. For a TableType output,
    they are read back as an Arrow IPC stream.

    Raises:
        ValueError: the PFunction does not declare exactly one output; the
            tensor shape has a negative (dynamic) dimension; or the byte
            count does not match the shape and dtype.
        TypeError: the declared output is neither TensorType nor TableType.
    """
    declared = pfunc.outs_info
    if len(declared) != 1:
        raise ValueError("basic.unpack expects single output type")
    target = declared[0]

    flat = packed.to_numpy().astype(np.uint8, copy=False).reshape(-1)

    if isinstance(target, TensorType):
        np_dtype = target.dtype.numpy_dtype()
        shape = tuple(target.shape)
        if any(d < 0 for d in shape):
            raise ValueError("basic.unpack does not support dynamic tensor shapes")
        n_bytes = int(np.prod(shape)) * np.dtype(np_dtype).itemsize
        if flat.size != n_bytes:
            raise ValueError(
                f"unpack size mismatch: got {flat.size} bytes, expect {n_bytes} for {np_dtype} {shape}"
            )
        restored = np.frombuffer(flat.tobytes(), dtype=np_dtype)
        return TensorValue(restored.reshape(shape))

    if isinstance(target, TableType):
        import pyarrow as pa  # type: ignore
        import pyarrow.ipc as pa_ipc  # type: ignore

        stream_reader = pa_ipc.open_stream(pa.py_buffer(flat.tobytes()))
        return TableValue(stream_reader.read_all())

    raise TypeError("basic.unpack output type must be TensorType or TableType")
@@ -17,12 +17,12 @@ from __future__ import annotations
17
17
  from collections.abc import Mapping
18
18
  from typing import Any
19
19
 
20
- from mplang.core.dtype import UINT8, DType
21
- from mplang.core.pfunc import PFunction
22
- from mplang.core.table import TableLike, TableType
23
- from mplang.core.tensor import TensorLike, TensorType
24
- from mplang.kernels import base
25
- from mplang.kernels.base import KernelContext, get_kernel_spec, kernel_exists
20
+ from mplang.v1.core.dtypes import UINT8, DType
21
+ from mplang.v1.core.pfunc import PFunction
22
+ from mplang.v1.core.table import PandasTableLike, TableLike, TableType
23
+ from mplang.v1.core.tensor import TensorLike, TensorType
24
+ from mplang.v1.kernels import base
25
+ from mplang.v1.kernels.base import KernelContext, get_kernel_spec, kernel_exists
26
26
 
27
27
  # Default bindings
28
28
  # Import kernel implementation modules explicitly so their @kernel_def entries
@@ -35,13 +35,14 @@ def _ensure_impl_imported() -> None:
35
35
  global _IMPL_IMPORTED
36
36
  if _IMPL_IMPORTED:
37
37
  return
38
- from mplang.kernels import builtin as _impl_builtin # noqa: F401
39
- from mplang.kernels import crypto as _impl_crypto # noqa: F401
40
- from mplang.kernels import mock_tee as _impl_tee # noqa: F401
41
- from mplang.kernels import phe as _impl_phe # noqa: F401
42
- from mplang.kernels import spu as _impl_spu # noqa: F401
43
- from mplang.kernels import sql_duckdb as _impl_sql_duckdb # noqa: F401
44
- from mplang.kernels import stablehlo as _impl_stablehlo # noqa: F401
38
+ from mplang.v1.kernels import basic as _impl_basic # noqa: F401
39
+ from mplang.v1.kernels import crypto as _impl_crypto # noqa: F401
40
+ from mplang.v1.kernels import fhe as _impl_fhe # noqa: F401
41
+ from mplang.v1.kernels import mock_tee as _impl_tee # noqa: F401
42
+ from mplang.v1.kernels import phe as _impl_phe # noqa: F401
43
+ from mplang.v1.kernels import spu as _impl_spu # noqa: F401
44
+ from mplang.v1.kernels import sql_duckdb as _impl_sql_duckdb # noqa: F401
45
+ from mplang.v1.kernels import stablehlo as _impl_stablehlo # noqa: F401
45
46
 
46
47
  _IMPL_IMPORTED = True
47
48
 
@@ -49,18 +50,18 @@ def _ensure_impl_imported() -> None:
49
50
  # imports consolidated above
50
51
 
51
52
  _DEFAULT_BINDINGS: dict[str, str] = {
52
- # builtin
53
- "builtin.identity": "builtin.identity",
54
- "builtin.read": "builtin.read",
55
- "builtin.write": "builtin.write",
56
- "builtin.constant": "builtin.constant",
57
- "builtin.rank": "builtin.rank",
58
- "builtin.prand": "builtin.prand",
59
- "builtin.table_to_tensor": "builtin.table_to_tensor",
60
- "builtin.tensor_to_table": "builtin.tensor_to_table",
61
- "builtin.debug_print": "builtin.debug_print",
62
- "builtin.pack": "builtin.pack",
63
- "builtin.unpack": "builtin.unpack",
53
+ # basic
54
+ "basic.identity": "basic.identity",
55
+ "basic.read": "basic.read",
56
+ "basic.write": "basic.write",
57
+ "basic.constant": "basic.constant",
58
+ "basic.rank": "basic.rank",
59
+ "basic.prand": "basic.prand",
60
+ "basic.table_to_tensor": "basic.table_to_tensor",
61
+ "basic.tensor_to_table": "basic.tensor_to_table",
62
+ "basic.debug_print": "basic.debug_print",
63
+ "basic.pack": "basic.pack",
64
+ "basic.unpack": "basic.unpack",
64
65
  # crypto
65
66
  "crypto.keygen": "crypto.keygen",
66
67
  "crypto.enc": "crypto.enc",
@@ -80,6 +81,17 @@ _DEFAULT_BINDINGS: dict[str, str] = {
80
81
  "phe.concat": "phe.concat",
81
82
  "phe.reshape": "phe.reshape",
82
83
  "phe.transpose": "phe.transpose",
84
+ # fhe
85
+ "fhe.keygen": "fhe.keygen",
86
+ "fhe.encrypt": "fhe.encrypt",
87
+ "fhe.decrypt": "fhe.decrypt",
88
+ "fhe.add": "fhe.add",
89
+ "fhe.mul": "fhe.mul",
90
+ "fhe.dot": "fhe.dot",
91
+ "fhe.polyval": "fhe.polyval",
92
+ "fhe.sub": "fhe.sub",
93
+ "fhe.negate": "fhe.negate",
94
+ "fhe.square": "fhe.square",
83
95
  # spu
84
96
  "spu.seed_env": "spu.seed_env",
85
97
  "spu.makeshares": "spu.makeshares",
@@ -305,9 +317,12 @@ def _validate_table_arg(
305
317
  raise TypeError(
306
318
  f"kernel {fn_type} input[{arg_index}] expects TableLike, got {type(value).__name__}"
307
319
  )
308
- if len(value.columns) != len(spec.columns):
320
+ columns = (
321
+ value.columns if isinstance(value, PandasTableLike) else value.column_names
322
+ )
323
+ if len(columns) != len(spec.columns):
309
324
  raise ValueError(
310
- f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(value.columns)}, expected {len(spec.columns)}"
325
+ f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(columns)}, expected {len(spec.columns)}"
311
326
  )
312
327
 
313
328
 
@@ -15,15 +15,15 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import os
18
- from typing import Any
19
18
 
20
19
  import numpy as np
21
20
 
22
- from mplang.core.pfunc import PFunction
23
- from mplang.kernels.base import cur_kctx, kernel_def
24
- from mplang.utils.crypto import blake2b
21
+ from mplang.v1.core import PFunction
22
+ from mplang.v1.kernels.base import cur_kctx, kernel_def
23
+ from mplang.v1.kernels.value import TensorValue
24
+ from mplang.v1.utils.crypto import blake2b
25
25
 
26
# Kernels are registered through @kernel_def when the module is imported;
# nothing is exported by name.
__all__: list[str] = []
27
27
 
28
28
 
29
29
  def _get_rng() -> np.random.Generator:
@@ -45,71 +45,78 @@ def _get_rng() -> np.random.Generator:
45
45
def _keystream(key: bytes, nonce: bytes, length: int) -> bytes:
    """Derive ``length`` pseudo-random bytes from ``key`` and ``nonce``.

    WARNING (INSECURE): this is a toy hash-based keystream, not a vetted
    cipher. Chunk ``i`` is ``blake2b(key || nonce || i)``, with ``i`` encoded
    as a 4-byte little-endian counter. Chunks are concatenated and truncated
    to ``length``.

    The counter is essential. Without it, every chunk is the same digest, so
    the keystream repeats every digest-length bytes. Plaintext bytes one
    period apart would then be XORed with identical pad bytes.
    """
    out = bytearray()
    counter = 0
    while len(out) < length:
        out.extend(blake2b(key + nonce + counter.to_bytes(4, "little")))
        counter += 1
    return bytes(out[:length])
54
52
 
55
53
 
56
54
@kernel_def("crypto.keygen")
def _crypto_keygen(pfunc: PFunction) -> TensorValue:
    """Generate a symmetric key of ``length`` bytes (attribute, default 32).

    Bytes are drawn uniformly from [0, 256) using the module RNG
    (``_get_rng``).
    """
    n_bytes = int(pfunc.attrs.get("length", 32))
    key_bytes = _get_rng().integers(0, 256, size=(n_bytes,), dtype=np.uint8)
    return TensorValue(key_bytes)
62
60
 
63
61
 
64
62
@kernel_def("crypto.enc")
def _crypto_encrypt(
    pfunc: PFunction, pt_bytes: TensorValue, key: TensorValue
) -> TensorValue:
    """XOR-encrypt ``pt_bytes`` under ``key`` with a fresh 16-byte nonce.

    The output layout is ``nonce (16 bytes) || ciphertext``. This is
    INSECURE: it relies on the toy ``_keystream`` construction.
    """
    plain = pt_bytes.to_numpy().astype(np.uint8, copy=False)
    key_arr = key.to_numpy().astype(np.uint8, copy=False)
    nonce = _get_rng().integers(0, 256, size=(16,), dtype=np.uint8)
    pad = np.frombuffer(
        _keystream(key_arr.tobytes(), nonce.tobytes(), plain.size), dtype=np.uint8
    )
    cipher = np.bitwise_xor(plain, pad).astype(np.uint8)
    return TensorValue(np.concatenate([nonce, cipher]).astype(np.uint8))
76
76
 
77
77
 
78
78
@kernel_def("crypto.dec")
def _crypto_decrypt(
    pfunc: PFunction, ct_with_nonce: TensorValue, key: TensorValue
) -> TensorValue:
    """Invert ``crypto.enc``: strip the 16-byte nonce prefix and XOR the rest.

    The input must have the ``nonce (16 bytes) || ciphertext`` layout that
    ``crypto.enc`` produces. A 16-byte input (empty plaintext) is valid.

    Raises:
        ValueError: the input is shorter than the 16-byte nonce, so it
            cannot be a ``crypto.enc`` output. Previously such input was
            silently "decrypted" to an empty array.
    """
    ct_np = ct_with_nonce.to_numpy().astype(np.uint8, copy=False)
    key_np = key.to_numpy().astype(np.uint8, copy=False)
    if ct_np.size < 16:
        raise ValueError(
            f"crypto.dec input too short: got {ct_np.size} bytes, "
            "need at least 16 for the nonce"
        )
    nonce = ct_np[:16]
    ct = ct_np[16:]
    stream = np.frombuffer(
        _keystream(key_np.tobytes(), nonce.tobytes(), len(ct)), dtype=np.uint8
    )
    pt_bytes = (ct ^ stream).astype(np.uint8)
    return TensorValue(pt_bytes)
89
91
 
90
92
 
91
93
@kernel_def("crypto.kem_keygen")
def _crypto_kem_keygen(pfunc: PFunction) -> tuple[TensorValue, TensorValue]:
    """Create a toy KEM key pair ``(sk, pk)``.

    ``sk`` is 32 random bytes from the module RNG. ``pk`` is the first 32
    bytes of ``blake2b(sk)``.
    """
    secret_key = _get_rng().integers(0, 256, size=(32,), dtype=np.uint8)
    public_key = np.frombuffer(blake2b(secret_key.tobytes())[:32], dtype=np.uint8)
    return TensorValue(secret_key), TensorValue(public_key)
97
100
 
98
101
 
99
102
@kernel_def("crypto.kem_derive")
def _crypto_kem_derive(
    pfunc: PFunction, sk: TensorValue, peer_pk: TensorValue
) -> TensorValue:
    """Derive a 32-byte shared secret from our ``sk`` and the peer's public key.

    Our own public key is recomputed as ``blake2b(sk)[:32]``, the same way
    ``crypto.kem_keygen`` builds it. It is XORed with ``peer_pk`` and the
    result is hashed again. Because XOR is commutative, both parties obtain
    the same secret.
    """
    own_sk = sk.to_numpy().astype(np.uint8, copy=False)
    their_pk = peer_pk.to_numpy().astype(np.uint8, copy=False)
    own_pk = np.frombuffer(blake2b(own_sk.tobytes())[:32], dtype=np.uint8)
    mixed = np.bitwise_xor(own_pk, their_pk).astype(np.uint8)
    digest = blake2b(mixed.tobytes())[:32]
    return TensorValue(np.frombuffer(digest, dtype=np.uint8))
107
114
 
108
115
 
109
116
@kernel_def("crypto.hkdf")
def _crypto_hkdf(pfunc: PFunction, secret: TensorValue) -> TensorValue:
    """Toy key derivation: the first 32 bytes of ``blake2b(secret || info)``.

    ``info`` is the ``info`` attribute (default empty), stringified and
    UTF-8 encoded. Despite the kernel name, this is not RFC 5869 HKDF; there
    is no separate extract/expand step.
    """
    secret_bytes = secret.to_numpy().astype(np.uint8, copy=False).tobytes()
    context = str(pfunc.attrs.get("info", "")).encode("utf-8")
    derived = blake2b(secret_bytes + context)[:32]
    return TensorValue(np.frombuffer(derived, dtype=np.uint8))