mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/ops/basic.py DELETED
@@ -1,294 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from jax.tree_util import PyTreeDef, tree_flatten
17
-
18
- from mplang.v1.core import (
19
- UINT8,
20
- UINT64,
21
- MPObject,
22
- PFunction,
23
- ScalarType,
24
- Shape,
25
- TableLike,
26
- TableType,
27
- TensorLike,
28
- TensorType,
29
- )
30
- from mplang.v1.ops.base import stateless_mod
31
- from mplang.v1.utils import table_utils
32
-
33
- _BASIC_MOD = stateless_mod("basic")
34
-
35
-
@_BASIC_MOD.simple_op()
def identity(x: TensorType) -> TensorType:
    """Pass the input type through untouched.

    When the op is invoked on an MPObject, the value is captured positionally;
    this type function only ever sees (and returns) its type.

    Args:
        x: Input tensor type.

    Returns:
        ``x`` itself.
    """
    passthrough = x
    return passthrough
48
-
49
-
@_BASIC_MOD.simple_op()
def read(*, path: str, ty: TensorType) -> TensorType:
    """Type function for reading a value described by ``ty`` from ``path``.

    Only the output type is computed here; ``path`` is kept as an op attribute
    for whatever kernel eventually performs the read.

    Args:
        path: Non-empty path or URI; must be a ``str``.
        ty: Expected output type/schema. Either a ``TensorType`` or a
            ``TableType`` is accepted, despite the narrower annotation.

    Returns:
        ``ty`` unchanged.

    Raises:
        ValueError: If ``path`` is not a non-empty string.
        TypeError: If ``ty`` is neither a TensorType nor a TableType.
    """
    # For a str, falsiness means exactly "empty string", so this rejects both
    # non-str and empty paths in one test.
    if not (isinstance(path, str) and path):
        raise ValueError("path must be a non-empty string")
    if isinstance(ty, (TensorType, TableType)):
        # The op decorator attaches ``path`` as an attribute on the PFunction.
        return ty
    raise TypeError("ty must be a TensorType or TableType")
71
-
72
-
@_BASIC_MOD.simple_op()
def write(x: TensorType, *, path: str) -> TensorType:
    """Type function for writing ``x`` to ``path``; the type passes through.

    Args:
        x: Type of the value being written (the value itself is captured
            positionally).
        path: Destination path or URI, carried as an attribute. It is not
            validated here.

    Returns:
        ``x`` unchanged.
    """
    written_type = x
    return written_type
85
-
86
-
@_BASIC_MOD.op_def()
def constant(
    data: TensorLike | ScalarType | TableLike,
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Embed a literal tensor/table and return the full triad.

    Args:
        data: Constant payload. Supports scalars, array-like tensors, or
            table-like dataframes.

    Returns:
        Tuple[PFunction, list[MPObject], PyTreeDef]:
            - PFunction: ``fn_type='basic.constant'`` with one output whose type
              matches ``data``. The payload is stored in ``data_bytes``, and
              ``data_format`` names its encoding:

              * ``'bytes[parquet]'`` for table-like data (encoded with
                ``table_utils.encode_table(..., format="parquet")``);
              * ``'bytes[numpy]'`` for scalars and tensors (the raw
                ``tobytes()`` buffer).
            - list[MPObject]: Empty (no inputs captured).
            - PyTreeDef: Output tree (single leaf).
    """
    import numpy as np

    data_bytes: bytes
    data_format: str
    out_type: TableType | TensorType

    if isinstance(data, TableLike):
        # Named ``table_format`` (not ``format``) to avoid shadowing the builtin.
        table_format = "parquet"
        data_bytes = table_utils.encode_table(data, format=table_format)
        data_format = f"bytes[{table_format}]"
        out_type = TableType.from_tablelike(data)
    elif isinstance(data, ScalarType):
        out_type = TensorType.from_obj(data)
        data_bytes = np.array(data).tobytes()
        data_format = "bytes[numpy]"
    else:
        if hasattr(data, "tobytes"):
            # Array-likes exposing ``tobytes`` are typed and serialized from
            # the original object rather than via ``np.array``.
            out_type = TensorType.from_obj(data)
            data_bytes = data.tobytes()  # type: ignore[attr-defined]
        else:
            np_data = np.array(data)
            out_type = TensorType.from_obj(np_data)
            data_bytes = np_data.tobytes()
        data_format = "bytes[numpy]"

    pfunc = PFunction(
        fn_type="basic.constant",
        ins_info=(),
        outs_info=(out_type,),
        data_bytes=data_bytes,
        data_format=data_format,
    )
    _, treedef = tree_flatten(out_type)
    return pfunc, [], treedef
139
-
140
-
@_BASIC_MOD.simple_op()
def rank() -> TensorType:
    """Type of the current party's rank: a scalar ``UINT64`` tensor.

    Returns:
        ``TensorType(UINT64, ())``, i.e. a rank-0 unsigned 64-bit tensor.
    """
    scalar_shape = ()
    return TensorType(UINT64, scalar_shape)
149
-
150
-
@_BASIC_MOD.simple_op()
def prand(*, shape: Shape = ()) -> TensorType:
    """Type of a private random ``UINT64`` tensor of the requested shape.

    Args:
        shape: Shape of the generated tensor; ``()`` (a scalar) by default.

    Returns:
        A ``UINT64`` tensor type whose shape is ``shape``.
    """
    element_type = UINT64
    return TensorType(element_type, shape)
162
-
163
-
@_BASIC_MOD.simple_op()
def debug_print(
    x: TensorType | TableType, *, prefix: str = ""
) -> TableType | TensorType:
    """Type function for printing ``x`` at runtime; the type passes through.

    Args:
        x: Value to print (captured positionally; only its type is seen here).
        prefix: Optional text placed before the printed output.

    Returns:
        ``x`` unchanged.
    """
    printed_type = x
    return printed_type
178
-
179
-
@_BASIC_MOD.simple_op()
def pack(x: TensorType | TableType) -> TensorType:
    """Type function for serializing a tensor/table into raw bytes.

    Args:
        x: Type of the value to serialize.

    Returns:
        A 1-D ``UINT8`` tensor type of dynamic length (shape ``(-1,)``); the
        real byte count is only known at runtime.

    Raises:
        TypeError: If ``x`` is neither a TensorType nor a TableType.
    """
    if isinstance(x, (TensorType, TableType)):
        dynamic_len = -1
        return TensorType(UINT8, (dynamic_len,))
    raise TypeError("pack expects TensorType or TableType input")
198
-
199
-
@_BASIC_MOD.simple_op()
def unpack(b: TensorType, *, out_ty: TensorType | TableType) -> TensorType | TableType:
    """Deserialize a byte vector into the explicit output type.

    Args:
        b: Byte vector type. Must be a ``UINT8`` TensorType with shape ``(N,)``
            (``N`` may be ``-1``).
        out_ty: Resulting type/schema after unpacking.

    Returns:
        Exactly ``out_ty``.

    Raises:
        TypeError: If ``out_ty`` is not a TensorType/TableType, or if ``b`` is
            not a 1-D UINT8 TensorType.
    """
    if not isinstance(out_ty, (TensorType, TableType)):
        raise TypeError("out_ty must be TensorType or TableType")

    # Check the kind before touching ``dtype``/``shape``, mirroring ``pack``'s
    # isinstance guard. A non-tensor input (e.g. a TableType, which may not
    # expose ``dtype``) then yields the documented TypeError.
    if not isinstance(b, TensorType) or b.dtype != UINT8 or len(b.shape) != 1:
        raise TypeError("unpack expects a 1-D UINT8 tensor")

    return out_ty
224
-
225
-
@_BASIC_MOD.simple_op()
def table_to_tensor(table: TableType, *, number_rows: int) -> TensorType:
    """Convert a homogeneous-typed table to a dense 2D tensor.

    Args:
        table: Input table whose columns all share the same dtype.
        number_rows: Number of rows in the resulting tensor. Must be a
            non-negative ``int``; ``bool`` is rejected.

    Returns:
        A rank-2 tensor type whose dtype is the shared column dtype and whose
        shape is ``(number_rows, table.num_columns())``.

    Raises:
        ValueError: If the table has no columns or ``number_rows < 0``.
        TypeError: If the table has heterogeneous column dtypes, or
            ``number_rows`` is not an int (or is a bool).
    """
    n_cols = table.num_columns()
    if n_cols == 0:
        raise ValueError("Cannot pack empty table")
    col_dtypes = list(table.column_types())
    first = col_dtypes[0]
    if not all(dt == first for dt in col_dtypes[1:]):
        raise TypeError(
            "Heterogeneous dtypes; perform casting upstream before table_to_tensor"
        )
    # bool subclasses int, so isinstance(True, int) holds. Reject it explicitly
    # so number_rows=True cannot silently become a 1-row shape.
    if not isinstance(number_rows, int) or isinstance(number_rows, bool):
        raise TypeError("number_rows must be an int")
    if number_rows < 0:
        raise ValueError("number_rows must be >= 0")
    return TensorType(first, (number_rows, n_cols))  # type: ignore[arg-type]
257
-
258
-
@_BASIC_MOD.simple_op()
def tensor_to_table(tensor: TensorType, *, column_names: list[str]) -> TableType:
    """Turn a rank-2 ``(N, F)`` tensor type into a table type with named columns.

    Every resulting column carries the tensor's dtype.

    Args:
        tensor: Rank-2 tensor type with shape ``(N, F)``.
        column_names: ``F`` unique column names, each a ``str`` containing at
            least one non-whitespace character.

    Returns:
        A table type pairing each name in ``column_names`` with ``tensor.dtype``.

    Raises:
        TypeError: If ``tensor`` is not rank-2, or a column name is not a str.
        ValueError: If ``column_names`` is empty, its length differs from
            ``F``, a name is empty/whitespace-only, or a name repeats.
    """
    if len(tensor.shape) != 2:
        raise TypeError("tensor_to_table expects a rank-2 tensor (N,F)")
    n_cols = tensor.shape[1]
    if not column_names:
        raise ValueError("column_names required (non-empty)")
    if len(column_names) != n_cols:
        raise ValueError("column_names length must match tensor second dim")

    # Pass 1: every name must be a str with some non-whitespace content.
    for idx, col in enumerate(column_names):
        if not isinstance(col, str):
            raise TypeError(f"column_names[{idx}] must be str, got {type(col).__name__}")
        if not col.strip():
            raise ValueError("column names must be non-empty and not whitespace-only")

    # Pass 2 (runs only once all names are well-formed): report the first repeat.
    seen_names: set[str] = set()
    for col in column_names:
        if col in seen_names:
            raise ValueError(f"duplicate column name: {col!r}")
        seen_names.add(col)

    dtype = tensor.dtype
    return TableType.from_pairs([(col, dtype) for col in column_names])
mplang/v1/ops/crypto.py DELETED
@@ -1,262 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Crypto frontend operations: operation signatures, types, and high-level semantics.
17
-
18
- Scope and contracts:
19
- - This module defines portable API shapes; it does not implement cryptography.
20
- - Backends execute the operations and must meet the security semantics required
21
- by the deployment (confidentiality, authenticity, correctness, etc.).
22
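- The enc/dec API in this frontend assumes ciphertext = nonce || payload (|| tag
23
- for GCM modes). The size overhead depends on ``algo``: 16 bytes (nonce only) for
24
- aes-ctr, 28 bytes (12-byte nonce + 16-byte tag) for aes-gcm/sm4-gcm; dec expects that format. Other security properties (e.g., AEAD) are backend responsibilities.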
25
- """
26
-
27
- from __future__ import annotations
28
-
29
- from jax.tree_util import PyTreeDef, tree_flatten
30
-
31
- from mplang.v1.core import UINT8, TensorType
32
- from mplang.v1.core.mpobject import MPObject
33
- from mplang.v1.core.pfunc import PFunction
34
- from mplang.v1.ops.base import stateless_mod
35
-
36
- _CRYPTO_MOD = stateless_mod("crypto")
37
-
38
-
39
- def _get_algo_overhead(algo: str) -> int:
40
- """Get ciphertext overhead for a given encryption algorithm.
41
-
42
- Args:
43
- algo: Encryption algorithm identifier
44
-
45
- Returns:
46
- int: Number of overhead bytes added to plaintext length
47
- """
48
- overhead_map = {
49
- "aes-ctr": 16, # nonce only (legacy compatibility)
50
- "aes-gcm": 28, # nonce(12) + tag(16) for AES-GCM
51
- "sm4-gcm": 28, # nonce(12) + tag(16) for SM4-GCM
52
- }
53
-
54
- if algo not in overhead_map:
55
- # return unknown overhead as -1
56
- return -1
57
- return overhead_map[algo]
58
-
59
-
@_CRYPTO_MOD.simple_op()
def keygen(*, length: int = 32) -> TensorType:
    """Type of ``length`` random bytes, usable as a symmetric key or randomness.

    API: keygen(length: int = 32) -> key: u8[length]

    Only the output type/shape is decided here; the randomness itself is
    supplied by the backend.

    Raises:
        ValueError: If ``length`` is not strictly positive.
    """
    if length <= 0:
        raise ValueError("length must be > 0")
    key_shape = (length,)
    return TensorType(UINT8, key_shape)
73
-
74
-
@_CRYPTO_MOD.op_def()
def enc(
    plaintext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Symmetric encryption whose output length accounts for the algorithm.

    API: enc(plaintext: u8[N], key: u8[M], *, algo: str = "aes-ctr") -> ciphertext: u8[N + overhead]

    Overhead per algorithm:
    - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
    - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
    - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    When either the plaintext length or the overhead is unknown, the
    ciphertext length is dynamic (``-1``). ``algo`` is stored as a PFunction
    attribute for the backend.

    Raises:
        TypeError: If the plaintext is not a 1-D UINT8 tensor.
    """
    if plaintext.dtype != UINT8:
        raise TypeError("enc expects UINT8 plaintext")
    if len(plaintext.shape) != 1:
        raise TypeError("enc expects 1-D plaintext")

    # Unknown algorithms are not rejected: the overhead comes back as -1 and
    # the output length falls back to dynamic.
    overhead = _get_algo_overhead(algo)
    pt_len = plaintext.shape[0]
    sizes_known = pt_len >= 0 and overhead >= 0
    ct_type = TensorType(UINT8, (pt_len + overhead if sizes_known else -1,))

    pfunc = PFunction(
        fn_type="crypto.enc",
        ins_info=(TensorType.from_obj(plaintext), TensorType.from_obj(key)),
        outs_info=(ct_type,),
        algo=algo,
    )
    _, treedef = tree_flatten(ct_type)
    return pfunc, [plaintext, key], treedef
114
-
115
-
@_CRYPTO_MOD.op_def()
def dec(
    ciphertext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Symmetric decryption whose input length accounts for the algorithm.

    API: dec(ciphertext: u8[N + overhead], key: u8[M], *, algo: str = "aes-ctr") -> plaintext: u8[N]

    Overhead per algorithm:
    - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
    - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
    - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    ``algo`` is stored as a PFunction attribute; parsing the ciphertext layout
    for that algorithm is the backend's job. When either the ciphertext length
    or the overhead is unknown, the plaintext length is dynamic (``-1``).

    Raises:
        TypeError: If the ciphertext is not a 1-D UINT8 tensor, or it has a
            known length that is shorter than the algorithm's overhead.
    """
    if ciphertext.dtype != UINT8:
        raise TypeError("dec expects UINT8 ciphertext")
    if len(ciphertext.shape) != 1:
        raise TypeError("dec expects 1-D ciphertext")

    overhead = _get_algo_overhead(algo)
    ct_len = ciphertext.shape[0]
    sizes_known = ct_len >= 0 and overhead >= 0

    # A statically sized ciphertext shorter than the overhead cannot contain
    # the nonce/tag, so reject it up front.
    if sizes_known and ct_len < overhead:
        raise TypeError(
            f"dec expects ciphertext with at least {overhead} bytes for algo='{algo}', but got {ct_len} bytes"
        )

    pt_type = TensorType(UINT8, (ct_len - overhead if sizes_known else -1,))

    pfunc = PFunction(
        fn_type="crypto.dec",
        ins_info=(TensorType.from_obj(ciphertext), TensorType.from_obj(key)),
        outs_info=(pt_type,),
        algo=algo,
    )
    _, treedef = tree_flatten(pt_type)
    return pfunc, [ciphertext, key], treedef
164
-
165
-
@_CRYPTO_MOD.op_def()
def kem_keygen(suite: str = "x25519") -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """KEM-style keypair generation producing ``(sk, pk)`` byte vectors.

    API: kem_keygen(suite: str = "x25519") -> (sk: u8[32], pk: u8[32])

    For ``"x25519"`` both keys are 32 bytes; any other suite yields dynamic
    (``-1``) lengths. ``suite`` is stored as a PFunction attribute.
    """
    key_len = 32 if suite == "x25519" else -1
    sk_ty = TensorType(UINT8, (key_len,))
    pk_ty = TensorType(UINT8, (key_len,))
    keypair_types = (sk_ty, pk_ty)

    pfunc = PFunction(
        fn_type="crypto.kem_keygen",
        ins_info=(),
        outs_info=keypair_types,
        suite=suite,
    )
    _, treedef = tree_flatten(keypair_types)
    return pfunc, [], treedef
191
-
192
-
@_CRYPTO_MOD.op_def()
def kem_derive(
    sk: MPObject, peer_pk: MPObject, suite: str = "x25519"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """KEM-style shared-secret derivation producing a secret byte vector.

    API: kem_derive(sk: u8[32], peer_pk: u8[32], suite: str = "x25519") -> secret: u8[32]

    For ``"x25519"`` both keys must be 32 bytes and the secret is 32 bytes;
    other suites yield a dynamic (``-1``) secret length. ``suite`` is stored as
    a PFunction attribute.

    Raises:
        TypeError: If either input is not a 1-D UINT8 tensor, or if the x25519
            key lengths are not 32.
    """
    if sk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 secret key")
    if peer_pk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 peer public key")
    if len(sk.shape) != 1 or len(peer_pk.shape) != 1:
        raise TypeError("kem_derive expects 1-D inputs")

    secret_len = -1  # unknown suite -> dynamic length
    if suite == "x25519":
        if sk.shape[0] != 32 or peer_pk.shape[0] != 32:
            raise TypeError("kem_derive expects 32-byte keys for suite 'x25519'")
        secret_len = 32
    secret_ty = TensorType(UINT8, (secret_len,))

    pfunc = PFunction(
        fn_type="crypto.kem_derive",
        ins_info=(TensorType.from_obj(sk), TensorType.from_obj(peer_pk)),
        outs_info=(secret_ty,),
        suite=suite,
    )
    _, treedef = tree_flatten(secret_ty)
    return pfunc, [sk, peer_pk], treedef
229
-
230
-
@_CRYPTO_MOD.op_def()
def hkdf(
    secret: MPObject, info: str, hash: str = "SHA-256"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """HKDF-style key derivation from a secret byte vector.

    API: hkdf(secret: u8[N], info: str, hash: str = "SHA-256") -> key: u8[32]

    "SHA-256" and "SM3" yield a 32-byte key; any other ``hash`` yields a
    dynamic (``-1``) length. ``hash`` and ``info`` are stored as PFunction
    attributes for the backend.

    Raises:
        TypeError: If ``secret`` is not a 1-D UINT8 tensor.
    """
    if secret.dtype != UINT8:
        raise TypeError("hkdf expects UINT8 secret")
    if len(secret.shape) != 1:
        raise TypeError("hkdf expects 1-D secret")

    key_len = 32 if hash in ("SHA-256", "SM3") else -1
    key_ty = TensorType(UINT8, (key_len,))

    pfunc = PFunction(
        fn_type="crypto.hkdf",
        ins_info=(TensorType.from_obj(secret),),
        outs_info=(key_ty,),
        hash=hash,
        info=info,
    )
    _, treedef = tree_flatten(key_ty)
    return pfunc, [secret], treedef