mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/ops/basic.py
DELETED
|
@@ -1,294 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
from jax.tree_util import PyTreeDef, tree_flatten
|
|
17
|
-
|
|
18
|
-
from mplang.v1.core import (
|
|
19
|
-
UINT8,
|
|
20
|
-
UINT64,
|
|
21
|
-
MPObject,
|
|
22
|
-
PFunction,
|
|
23
|
-
ScalarType,
|
|
24
|
-
Shape,
|
|
25
|
-
TableLike,
|
|
26
|
-
TableType,
|
|
27
|
-
TensorLike,
|
|
28
|
-
TensorType,
|
|
29
|
-
)
|
|
30
|
-
from mplang.v1.ops.base import stateless_mod
|
|
31
|
-
from mplang.v1.utils import table_utils
|
|
32
|
-
|
|
33
|
-
# Stateless op module named "basic"; the ops below register through its
# ``simple_op`` / ``op_def`` decorators.
_BASIC_MOD = stateless_mod("basic")
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
@_BASIC_MOD.simple_op()
def identity(x: TensorType) -> TensorType:
    """Pass the input type through untouched.

    When invoked on an MPObject, the runtime value is bound as a positional
    input; this type-level function only ever observes its type.

    Args:
        x: Type of the value to forward.

    Returns:
        ``x`` itself, unchanged.
    """
    passthrough = x
    return passthrough
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
@_BASIC_MOD.simple_op()
def read(*, path: str, ty: TensorType | TableType) -> TensorType | TableType:
    """Declare reading a value of type ``ty`` from ``path`` (type-only).

    Args:
        path: Non-empty path or URI to read from (stored as an attribute).
        ty: The expected output type/schema; either a TensorType or a
            TableType.

    Returns:
        Exactly ``ty``.

    Raises:
        ValueError: If ``path`` is not a string or is empty.
        TypeError: If ``ty`` is not a TensorType or TableType.
    """
    if not isinstance(path, str) or path == "":
        raise ValueError("path must be a non-empty string")
    if not isinstance(ty, (TensorType, TableType)):
        raise TypeError("ty must be a TensorType or TableType")
    # The ``simple_op`` decorator is expected to record ``path`` as an
    # attribute and build the PFunction; this body only computes the output
    # type.
    return ty
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
@_BASIC_MOD.simple_op()
def write(x: TensorType, *, path: str) -> TensorType:
    """Record a write of ``x`` to ``path``; the output type mirrors the input.

    Args:
        x: Type of the value being written. The actual value is bound as a
            positional input.
        path: Destination path or URI, carried as an op attribute.

    Returns:
        ``x``, unchanged.
    """
    written_type = x
    return written_type
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
@_BASIC_MOD.op_def()
def constant(
    data: TensorLike | ScalarType | TableLike,
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Embed a literal tensor/table and return the full triad.

    Args:
        data: Constant payload. Supports scalars, array-like tensors, or
            table-like dataframes.

    Returns:
        Tuple[PFunction, list[MPObject], PyTreeDef]:
            - PFunction: ``fn_type='basic.constant'`` with one output whose
              type matches ``data``. The payload is serialized via
              ``data_bytes`` with ``data_format`` set to 'bytes[numpy]' for
              scalars/tensors or 'bytes[parquet]' for tables.
            - list[MPObject]: Empty (no inputs captured).
            - PyTreeDef: Output tree (single leaf).
    """
    import numpy as np

    data_bytes: bytes
    data_format: str
    out_type: TableType | TensorType

    if isinstance(data, TableLike):
        # Tables go through table_utils; the encoding name is recorded in
        # data_format so the consumer knows how to decode the payload.
        table_fmt = "parquet"
        data_bytes = table_utils.encode_table(data, format=table_fmt)
        data_format = f"bytes[{table_fmt}]"
        out_type = TableType.from_tablelike(data)
    elif isinstance(data, ScalarType):
        out_type = TensorType.from_obj(data)
        data_bytes = np.array(data).tobytes()
        data_format = "bytes[numpy]"
    else:
        if hasattr(data, "tobytes"):
            # Objects that already expose raw bytes are serialized directly.
            out_type = TensorType.from_obj(data)
            data_bytes = data.tobytes()  # type: ignore[attr-defined]
        else:
            # Plain array-likes (e.g. nested lists) are materialized first.
            np_data = np.array(data)
            out_type = TensorType.from_obj(np_data)
            data_bytes = np_data.tobytes()
        data_format = "bytes[numpy]"

    pfunc = PFunction(
        fn_type="basic.constant",
        ins_info=(),
        outs_info=(out_type,),
        data_bytes=data_bytes,
        data_format=data_format,
    )
    _, treedef = tree_flatten(out_type)
    return pfunc, [], treedef
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
@_BASIC_MOD.simple_op()
def rank() -> TensorType:
    """Describe the type of the current party's rank.

    Returns:
        A zero-dimensional (shape ``()``) ``UINT64`` tensor type.
    """
    scalar_shape = ()
    return TensorType(UINT64, scalar_shape)
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
@_BASIC_MOD.simple_op()
def prand(*, shape: Shape = ()) -> TensorType:
    """Describe a privately generated random ``UINT64`` tensor.

    Args:
        shape: Shape of the random tensor; ``()`` (a scalar) when omitted.

    Returns:
        A ``UINT64`` tensor type carrying ``shape``.
    """
    random_dtype = UINT64
    return TensorType(random_dtype, shape)
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
@_BASIC_MOD.simple_op()
def debug_print(
    x: TensorType | TableType, *, prefix: str = ""
) -> TableType | TensorType:
    """Emit ``x`` at runtime for debugging; its type passes through unchanged.

    Args:
        x: Type of the value to print. The value is bound positionally; this
            function only sees its type.
        prefix: Text placed before the printed value (empty by default).

    Returns:
        ``x``, unchanged.
    """
    printed_type = x
    return printed_type
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
@_BASIC_MOD.simple_op()
def pack(x: TensorType | TableType) -> TensorType:
    """Describe serializing a tensor or table into a flat byte vector.

    Args:
        x: Type of the value to serialize.

    Returns:
        A 1-D ``UINT8`` tensor type of shape ``(-1,)``; the byte count is only
        known at runtime.

    Raises:
        TypeError: If ``x`` is neither a TensorType nor a TableType.
    """
    accepted_kinds = (TensorType, TableType)
    if isinstance(x, accepted_kinds):
        return TensorType(UINT8, (-1,))
    raise TypeError("pack expects TensorType or TableType input")
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
@_BASIC_MOD.simple_op()
def unpack(b: TensorType, *, out_ty: TensorType | TableType) -> TensorType | TableType:
    """Deserialize a byte vector into the explicit output type.

    Args:
        b: Byte vector type. Must be a ``UINT8`` TensorType with shape
            ``(N,)`` (``N`` may be ``-1``).
        out_ty: Resulting type/schema after unpacking.

    Returns:
        Exactly ``out_ty``.

    Raises:
        TypeError: If ``out_ty`` is not a TensorType/TableType, or if ``b`` is
            not a 1-D UINT8 TensorType (including when ``b`` is a TableType).
    """
    if not isinstance(out_ty, (TensorType, TableType)):
        raise TypeError("out_ty must be TensorType or TableType")

    # Check the kind first so a non-tensor type (e.g. a TableType) fails with
    # the documented TypeError rather than whatever accessing ``.dtype`` /
    # ``.shape`` on it would raise.
    if not isinstance(b, TensorType) or b.dtype != UINT8 or len(b.shape) != 1:
        raise TypeError("unpack expects a 1-D UINT8 tensor")

    return out_ty
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
@_BASIC_MOD.simple_op()
def table_to_tensor(table: TableType, *, number_rows: int) -> TensorType:
    """Convert a homogeneous-typed table to a dense 2D tensor.

    Args:
        table: Input table whose columns all share the same dtype.
        number_rows: Number of rows in the resulting tensor. Must be a
            non-negative ``int``. ``bool`` is rejected even though it is an
            ``int`` subclass.

    Returns:
        A rank-2 tensor with dtype equal to the table column dtype and shape
        ``(number_rows, table.num_columns())``.

    Raises:
        ValueError: If the table has no columns or ``number_rows < 0``.
        TypeError: If the table has heterogeneous column dtypes, or if
            ``number_rows`` is not an int (or is a bool).
    """
    if table.num_columns() == 0:
        raise ValueError("Cannot pack empty table")
    col_dtypes = list(table.column_types())
    first = col_dtypes[0]
    if not all(dt == first for dt in col_dtypes[1:]):
        raise TypeError(
            "Heterogeneous dtypes; perform casting upstream before table_to_tensor"
        )
    # bool subclasses int; without this guard True/False would silently be
    # accepted as a row count of 1/0.
    if isinstance(number_rows, bool) or not isinstance(number_rows, int):
        raise TypeError("number_rows must be an int")
    if number_rows < 0:
        raise ValueError("number_rows must be >= 0")
    shape = (number_rows, table.num_columns())
    return TensorType(first, shape)  # type: ignore[arg-type]
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
@_BASIC_MOD.simple_op()
def tensor_to_table(tensor: TensorType, *, column_names: list[str]) -> TableType:
    """Reinterpret a rank-2 ``(N, F)`` tensor type as a table with ``F`` named columns.

    Args:
        tensor: Rank-2 tensor type with shape ``(N, F)``.
        column_names: ``F`` distinct column names. None may be empty or
            whitespace-only.

    Returns:
        A table type whose columns carry the given names, in order, and all
        share ``tensor.dtype``.

    Raises:
        TypeError: If ``tensor`` is not rank-2, or a column name is not a str.
        ValueError: If ``column_names`` is empty, has the wrong length,
            contains a blank name, or repeats a name.
    """
    if len(tensor.shape) != 2:
        raise TypeError("tensor_to_table expects a rank-2 tensor (N,F)")
    width = tensor.shape[1]
    if not column_names:
        raise ValueError("column_names required (non-empty)")
    if len(column_names) != width:
        raise ValueError("column_names length must match tensor second dim")

    # Pass 1: every entry must be a non-blank string. This pass runs fully
    # before duplicate detection so type/blank errors take precedence.
    for idx, col in enumerate(column_names):
        if not isinstance(col, str):
            raise TypeError(f"column_names[{idx}] must be str, got {type(col).__name__}")
        if not col.strip():
            raise ValueError("column names must be non-empty and not whitespace-only")

    # Pass 2: report the first name that repeats, in order of appearance.
    encountered: set[str] = set()
    for col in column_names:
        if col in encountered:
            raise ValueError(f"duplicate column name: {col!r}")
        encountered.add(col)

    shared_dtype = tensor.dtype
    return TableType.from_pairs([(col, shared_dtype) for col in column_names])
|
mplang/v1/ops/crypto.py
DELETED
|
@@ -1,262 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
Crypto frontend operations: operation signatures, types, and high-level semantics.

Scope and contracts:
- This module defines portable API shapes; it does not implement cryptography.
- Backends execute the operations and must meet the security semantics required
  by the deployment (confidentiality, authenticity, correctness, etc.).
- The enc/dec API in this frontend assumes ciphertext = nonce || payload, and dec
  expects that format.
- The fixed per-algorithm size overhead used for shape inference is defined in
  ``_get_algo_overhead``: a 12-byte nonce plus a 16-byte tag for the GCM modes,
  and 16 bytes for legacy "aes-ctr".
- Other security properties (e.g., AEAD) are backend responsibilities.
"""
|
|
26
|
-
|
|
27
|
-
from __future__ import annotations
|
|
28
|
-
|
|
29
|
-
from jax.tree_util import PyTreeDef, tree_flatten
|
|
30
|
-
|
|
31
|
-
from mplang.v1.core import UINT8, TensorType
|
|
32
|
-
from mplang.v1.core.mpobject import MPObject
|
|
33
|
-
from mplang.v1.core.pfunc import PFunction
|
|
34
|
-
from mplang.v1.ops.base import stateless_mod
|
|
35
|
-
|
|
36
|
-
# Stateless op module named "crypto"; the ops below register through its
# ``simple_op`` / ``op_def`` decorators.
_CRYPTO_MOD = stateless_mod("crypto")
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def _get_algo_overhead(algo: str) -> int:
    """Look up how many bytes encryption with ``algo`` adds to the plaintext.

    Args:
        algo: Encryption algorithm identifier.

    Returns:
        int: The byte overhead for a known ``algo``, or ``-1`` when the
        algorithm is not listed (overhead unknown).
    """
    known_overheads = {
        "aes-ctr": 16,  # nonce only (legacy compatibility)
        "aes-gcm": 28,  # nonce(12) + tag(16) for AES-GCM
        "sm4-gcm": 28,  # nonce(12) + tag(16) for SM4-GCM
    }
    # -1 is a sentinel for "unknown"; callers then fall back to a dynamic
    # (-1) output length.
    return known_overheads.get(algo, -1)
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
@_CRYPTO_MOD.simple_op()
def keygen(*, length: int = 32) -> TensorType:
    """Produce the type of a freshly generated random byte string.

    API: keygen(length: int = 32) -> key: u8[length]

    Notes:
        - Only the type/shape is fixed here; the backend supplies the bytes.
        - A ValueError is raised unless ``length`` is positive.
    """
    key_shape = (length,)
    if length <= 0:
        raise ValueError("length must be > 0")
    return TensorType(UINT8, key_shape)
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
@_CRYPTO_MOD.op_def()
def enc(
    plaintext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a symmetric-encryption op whose output size accounts for ``algo``.

    API: enc(plaintext: u8[N], key: u8[M], *, algo: str = "aes-ctr") -> ciphertext: u8[N + overhead]

    Supported algorithms and overhead:
        - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
        - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
        - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    ``algo`` is forwarded as a PFunction attribute for the backend. When the
    plaintext length or the algorithm overhead is unknown, the ciphertext
    length is reported as ``-1`` (dynamic).
    """
    if plaintext.dtype != UINT8:
        raise TypeError("enc expects UINT8 plaintext")
    if len(plaintext.shape) != 1:
        raise TypeError("enc expects 1-D plaintext")

    overhead = _get_algo_overhead(algo)
    pt_len = plaintext.shape[0]
    # The output length is static only when both inputs to it are known.
    is_static = pt_len >= 0 and overhead >= 0
    ct_len = pt_len + overhead if is_static else -1
    ct_ty = TensorType(UINT8, (ct_len,))

    pfunc = PFunction(
        fn_type="crypto.enc",
        ins_info=(TensorType.from_obj(plaintext), TensorType.from_obj(key)),
        outs_info=(ct_ty,),
        algo=algo,
    )
    _, treedef = tree_flatten(ct_ty)
    return pfunc, [plaintext, key], treedef
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
@_CRYPTO_MOD.op_def()
def dec(
    ciphertext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a symmetric-decryption op whose output size accounts for ``algo``.

    API: dec(ciphertext: u8[N + overhead], key: u8[M], *, algo: str = "aes-ctr") -> plaintext: u8[N]

    Supported algorithms and overhead:
        - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
        - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
        - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    ``algo`` is forwarded as a PFunction attribute; the backend parses the
    ciphertext layout according to it. A statically known ciphertext shorter
    than the overhead is rejected with TypeError.
    """
    if ciphertext.dtype != UINT8:
        raise TypeError("dec expects UINT8 ciphertext")
    if len(ciphertext.shape) != 1:
        raise TypeError("dec expects 1-D ciphertext")

    overhead = _get_algo_overhead(algo)
    ct_len = ciphertext.shape[0]
    is_static = ct_len >= 0 and overhead >= 0

    if is_static and ct_len < overhead:
        raise TypeError(
            f"dec expects ciphertext with at least {overhead} bytes for algo='{algo}', but got {ct_len} bytes"
        )

    # Unknown ciphertext length or overhead -> dynamic (-1) plaintext length.
    pt_len = ct_len - overhead if is_static else -1
    pt_ty = TensorType(UINT8, (pt_len,))

    pfunc = PFunction(
        fn_type="crypto.dec",
        ins_info=(TensorType.from_obj(ciphertext), TensorType.from_obj(key)),
        outs_info=(pt_ty,),
        algo=algo,
    )
    _, treedef = tree_flatten(pt_ty)
    return pfunc, [ciphertext, key], treedef
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
@_CRYPTO_MOD.op_def()
def kem_keygen(suite: str = "x25519") -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a KEM keypair-generation op producing ``(sk, pk)`` byte vectors.

    API: kem_keygen(suite: str = "x25519") -> (sk: u8[32], pk: u8[32])

    ``suite`` is forwarded as a PFunction attribute. For any suite other than
    "x25519", both key lengths are reported as ``-1`` (dynamic).
    """
    key_len = 32 if suite == "x25519" else -1
    outs_info = (TensorType(UINT8, (key_len,)), TensorType(UINT8, (key_len,)))

    pfunc = PFunction(
        fn_type="crypto.kem_keygen",
        ins_info=(),
        outs_info=outs_info,
        suite=suite,
    )
    _, treedef = tree_flatten(outs_info)
    return pfunc, [], treedef
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
@_CRYPTO_MOD.op_def()
def kem_derive(
    sk: MPObject, peer_pk: MPObject, suite: str = "x25519"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a KEM shared-secret derivation op from a local sk and a peer pk.

    API: kem_derive(sk: u8[32], peer_pk: u8[32], suite: str = "x25519") -> secret: u8[32]

    ``suite`` is forwarded as a PFunction attribute. With "x25519", both keys
    must be 32 bytes long. For other suites the secret length is reported as
    ``-1`` (dynamic).
    """
    if sk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 secret key")
    if peer_pk.dtype != UINT8:
        raise TypeError("kem_derive expects UINT8 peer public key")
    if len(sk.shape) != 1 or len(peer_pk.shape) != 1:
        raise TypeError("kem_derive expects 1-D inputs")

    if suite != "x25519":
        secret_len = -1
    else:
        if sk.shape[0] != 32 or peer_pk.shape[0] != 32:
            raise TypeError("kem_derive expects 32-byte keys for suite 'x25519'")
        secret_len = 32
    secret_ty = TensorType(UINT8, (secret_len,))

    pfunc = PFunction(
        fn_type="crypto.kem_derive",
        ins_info=(TensorType.from_obj(sk), TensorType.from_obj(peer_pk)),
        outs_info=(secret_ty,),
        suite=suite,
    )
    _, treedef = tree_flatten(secret_ty)
    return pfunc, [sk, peer_pk], treedef
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
@_CRYPTO_MOD.op_def()
def hkdf(
    secret: MPObject, info: str, hash: str = "SHA-256"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build an HKDF-style key-derivation op over a byte-vector secret.

    API: hkdf(secret: u8[N], info: str, hash: str = "SHA-256") -> key: u8[32]

    ``hash`` and ``info`` are forwarded as PFunction attributes. The derived
    key is 32 bytes for "SHA-256" and "SM3"; for any other hash name its
    length is reported as ``-1`` (dynamic).
    """
    if secret.dtype != UINT8:
        raise TypeError("hkdf expects UINT8 secret")
    if len(secret.shape) != 1:
        raise TypeError("hkdf expects 1-D secret")

    key_len = 32 if hash in ("SHA-256", "SM3") else -1
    key_ty = TensorType(UINT8, (key_len,))

    pfunc = PFunction(
        fn_type="crypto.hkdf",
        ins_info=(TensorType.from_obj(secret),),
        outs_info=(key_ty,),
        hash=hash,
        info=info,
    )
    _, treedef = tree_flatten(key_ty)
    return pfunc, [secret], treedef
|