mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/ops/crypto.py
DELETED
|
@@ -1,109 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Crypto frontend operations: operation signatures, types, and high-level semantics.
|
|
17
|
-
|
|
18
|
-
Scope and contracts:
|
|
19
|
-
- This module defines portable API shapes; it does not implement cryptography.
|
|
20
|
-
- Backends execute the operations and must meet the security semantics required
|
|
21
|
-
by the deployment (confidentiality, authenticity, correctness, etc.).
|
|
22
|
-
- The enc/dec API in this frontend uses a conventional 12-byte nonce prefix
|
|
23
|
-
(ciphertext = nonce || payload), and dec expects that format. Other security
|
|
24
|
-
properties (e.g., AEAD) are backend responsibilities.
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
from __future__ import annotations
|
|
28
|
-
|
|
29
|
-
from mplang.core.dtype import UINT8
|
|
30
|
-
from mplang.core.tensor import TensorType
|
|
31
|
-
from mplang.ops.base import stateless_mod
|
|
32
|
-
|
|
33
|
-
# Stateless frontend module "crypto"; the ops in this file attach to it via
# the ``@_CRYPTO_MOD.simple_op()`` decorator.
_CRYPTO_MOD = stateless_mod("crypto")
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
@_CRYPTO_MOD.simple_op()
def keygen(*, length: int = 32) -> TensorType:
    """Describe the result of symmetric key / random-byte generation.

    API: keygen(length: int = 32) -> key: u8[length]

    Only the result type is computed here: a 1-D UINT8 tensor holding
    ``length`` elements. The random bytes themselves come from the backend.

    Raises:
        ValueError: if ``length`` <= 0.
    """
    if length <= 0:
        raise ValueError("length must be > 0")
    key_shape = (length,)
    return TensorType(UINT8, key_shape)
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
@_CRYPTO_MOD.simple_op()
def enc(plaintext: TensorType, key: TensorType) -> TensorType:
    """Compute the result type of symmetric encryption.

    API: enc(plaintext: u8[N], key: u8[M]) -> ciphertext: u8[N + 12]

    The ciphertext is ``nonce || payload`` with a 12-byte nonce, so a
    non-negative plaintext length N yields N + 12. A negative plaintext
    length (presumably a dynamic dimension) yields length -1. ``key`` is not
    inspected here.

    Raises:
        TypeError: if ``plaintext`` is not a 1-D UINT8 tensor.
    """
    if plaintext.dtype != UINT8:
        raise TypeError("enc expects UINT8 plaintext")
    if len(plaintext.shape) != 1:
        raise TypeError("enc expects 1-D plaintext")
    n_bytes = plaintext.shape[0]
    out_len = n_bytes + 12 if n_bytes >= 0 else -1
    return TensorType(UINT8, (out_len,))
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
@_CRYPTO_MOD.simple_op()
def dec(ciphertext: TensorType, key: TensorType) -> TensorType:
    """Compute the result type of symmetric decryption.

    API: dec(ciphertext: u8[N + 12], key: u8[M]) -> plaintext: u8[N]

    The ciphertext is expected to be ``nonce || payload`` with a 12-byte
    nonce, so a non-negative ciphertext length L yields L - 12. A negative
    ciphertext length (presumably a dynamic dimension) yields length -1.
    ``key`` is not inspected here.

    Raises:
        TypeError: if ``ciphertext`` is not a 1-D UINT8 tensor, or if its
            static length is shorter than the 12-byte nonce prefix.
    """
    nonce_size = 12  # must match the nonce prefix added by ``enc``
    ct_ty = ciphertext
    if ct_ty.dtype != UINT8:
        raise TypeError("dec expects UINT8 ciphertext")
    if len(ct_ty.shape) != 1:
        raise TypeError("dec expects 1-D ciphertext with nonce")
    length = ct_ty.shape[0]
    if length < 0:
        # Length not known statically; the plaintext length is unknown too.
        return TensorType(UINT8, (-1,))
    if length < nonce_size:
        # Report the actual problem (too short) rather than a shape error.
        raise TypeError(
            f"dec expects ciphertext of at least {nonce_size} bytes "
            f"(nonce prefix), got {length}"
        )
    return TensorType(UINT8, (length - nonce_size,))
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
@_CRYPTO_MOD.simple_op()
def kem_keygen(*, suite: str = "x25519") -> tuple[TensorType, TensorType]:
    """KEM-style keypair generation: describe the (sk, pk) result types.

    Both the secret key and the public key are typed as 32-byte UINT8
    vectors. ``suite`` is part of the op signature but is not consulted when
    computing the types here.
    """
    return TensorType(UINT8, (32,)), TensorType(UINT8, (32,))
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
@_CRYPTO_MOD.simple_op()
def kem_derive(
    sk: TensorType, peer_pk: TensorType, *, suite: str = "x25519"
) -> TensorType:
    """KEM-style shared secret derivation: describe the secret's type.

    The derived secret is always typed as a 32-byte UINT8 vector; ``sk``,
    ``peer_pk`` and ``suite`` are not inspected when computing the type.
    """
    del sk, peer_pk  # part of the op signature only; unused for typing
    return TensorType(UINT8, (32,))
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
@_CRYPTO_MOD.simple_op()
def hkdf(secret: TensorType, *, info: str) -> TensorType:
    """HKDF-style key derivation: describe the derived key's type.

    The derived key is always typed as a 32-byte UINT8 vector; ``secret``
    and ``info`` are not inspected when computing the type.
    """
    del secret  # part of the op signature only; unused for typing
    return TensorType(UINT8, (32,))
|
mplang/ops/ibis_cc.py
DELETED
|
@@ -1,139 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
import inspect
|
|
17
|
-
from collections.abc import Callable
|
|
18
|
-
from typing import Any
|
|
19
|
-
|
|
20
|
-
import ibis
|
|
21
|
-
from jax.tree_util import PyTreeDef, tree_flatten
|
|
22
|
-
|
|
23
|
-
from mplang.core import dtype
|
|
24
|
-
from mplang.core.mpobject import MPObject
|
|
25
|
-
from mplang.core.pfunc import PFunction
|
|
26
|
-
from mplang.core.table import TableType
|
|
27
|
-
from mplang.ops.base import FeOperation, stateless_mod
|
|
28
|
-
from mplang.utils.func_utils import normalize_fn
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def ibis2sql(
    expr: ibis.Table,
    in_schemas: list[ibis.Schema],
    in_names: list[str],
    fn_name: str = "",
) -> PFunction:
    """Compile an ibis table expression into a generic ``sql.run`` PFunction.

    The expression is rendered as DuckDB SQL. The input and output ibis
    schemas are converted to mplang ``TableType`` descriptions.

    Args:
        expr: The ibis table expression to compile.
        in_schemas: Schemas of the input tables, aligned with ``in_names``.
        in_names: Names under which the input tables are referenced in the
            SQL (with a single input this is usually "table").
        fn_name: Name recorded on the resulting PFunction.

    Returns:
        PFunction: the compiled ``sql.run`` function (dialect "duckdb").
    """
    assert len(in_schemas) == len(in_names), (
        f"length of input table names and schemas mismatch. {len(in_schemas)}!={len(in_names)}"
    )

    def _to_table_type(schema: ibis.Schema) -> TableType:
        # Map each ibis column dtype to an mplang dtype via its numpy dtype.
        pairs = [
            (col_name, dtype.from_numpy(col_dtype.to_numpy()))
            for col_name, col_dtype in schema.fields.items()
        ]
        return TableType.from_pairs(pairs)

    inputs = tuple(_to_table_type(schema) for schema in in_schemas)
    outputs = (_to_table_type(expr.schema()),)
    sql_text = ibis.to_sql(expr, dialect="duckdb")

    # Emit generic sql.run op; runtime maps to backend-specific kernel.
    return PFunction(
        fn_type="sql.run",
        fn_name=fn_name,
        fn_text=sql_text,
        ins_info=inputs,
        outs_info=outputs,
        in_names=tuple(in_names),
        dialect="duckdb",
    )
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def is_ibis_function(func: Callable) -> bool:
    """Return True if ``func`` looks like an ibis function.

    A function qualifies when its return annotation, or any of its parameter
    annotations, is exactly ``ibis.Table``, e.g.
    ``def foo(t0: ibis.Table, t1: ibis.Table) -> ibis.Table``.

    Objects whose signature cannot be inspected are reported as not being
    ibis functions.

    NOTE: the check is by identity, so string (postponed) annotations such as
    ``"ibis.Table"`` do not match.
    """
    try:
        sig = inspect.signature(func)
    except (ValueError, TypeError):
        return False

    annotations = [sig.return_annotation]
    annotations.extend(param.annotation for param in sig.parameters.values())
    return any(anno is ibis.Table for anno in annotations)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
# Stateless frontend module "ibis"; ``IbisCompiler`` below is bound to it.
_IBIS_MOD = stateless_mod("ibis")
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
class IbisCompiler(FeOperation):
    """Ibis compiler frontend operation.

    Traces a Python function written against ``ibis.Table`` arguments and
    compiles the resulting expression into a DuckDB ``sql.run`` PFunction.
    """

    def trace(
        self, func: Callable, *args: Any, **kwargs: Any
    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
        """Compile an Ibis function to SQL format.

        Every ``MPObject`` in ``args``/``kwargs`` is treated as an input
        table. Each one is replaced by an unbound ibis table named
        ``table{i}`` (i follows the order returned by ``normalize_fn``), whose
        schema mirrors the object's ``schema.columns``. ``func`` is then
        evaluated on these placeholders.

        Args:
            func: The Ibis function to compile; must produce an ``ibis.Table``.
            *args: Positional arguments to the function.
            **kwargs: Keyword arguments to the function.

        Returns:
            tuple[PFunction, list[MPObject], PyTreeDef]: the compiled
            PFunction, the input variables, and the output tree definition.
        """

        def is_variable(arg: Any) -> bool:
            return isinstance(arg, MPObject)

        normalized_fn, in_vars = normalize_fn(func, args, kwargs, is_variable)

        in_args, in_schemas, in_names = [], [], []
        for idx, arg in enumerate(in_vars):
            columns = [(p[0], p[1].to_numpy()) for p in arg.schema.columns]
            schema = ibis.schema(columns)
            name = f"table{idx}"
            in_args.append(ibis.table(schema=schema, name=name))
            in_schemas.append(schema)
            in_names.append(name)

        result = normalized_fn(in_args)
        assert isinstance(result, ibis.Table)
        # functools.partial objects and many callable instances have no
        # ``__name__``; fall back to an empty name instead of raising.
        fn_name = getattr(func, "__name__", "")
        pfunc = ibis2sql(result, in_schemas, in_names, fn_name)
        _, treedef = tree_flatten(result)
        return pfunc, in_vars, treedef
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
# Ready-to-use ibis compiler instance: op "compile" in the "ibis" module.
ibis_compile = IbisCompiler(_IBIS_MOD, "compile")
|
mplang/ops/sql.py
DELETED
|
@@ -1,61 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from jax.tree_util import PyTreeDef, tree_flatten
|
|
16
|
-
|
|
17
|
-
from mplang.core.mpobject import MPObject
|
|
18
|
-
from mplang.core.pfunc import PFunction
|
|
19
|
-
from mplang.core.table import TableType
|
|
20
|
-
from mplang.ops.base import FeOperation, stateless_mod
|
|
21
|
-
|
|
22
|
-
# Stateless frontend module "sql"; ``SqlFE`` binds its "run" op to it.
_SQL_MOD = stateless_mod("sql")
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class SqlFE(FeOperation):
    """Frontend operation that wraps a raw SQL string as a ``sql.run`` PFunction."""

    def __init__(self, dialect: str = "duckdb"):
        # Bind to sql module with a stable op name for registry/dispatch
        super().__init__(_SQL_MOD, "run")
        self._dialect = dialect

    def trace(
        self,
        sql: str,
        out_type: TableType,
        in_tables: dict[str, MPObject] | None = None,
    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
        """Build a ``sql.run`` PFunction for ``sql``.

        Args:
            sql: The SQL text to execute.
            out_type: Declared schema of the single output table.
            in_tables: Optional mapping from the table name used inside
                ``sql`` to the MPObject that supplies it. Every object must
                carry a schema.

        Returns:
            The PFunction, the input MPObjects (in ``in_tables`` iteration
            order), and the tree definition of ``out_type``.
        """
        entries = list(in_tables.items()) if in_tables else []
        for _, tbl in entries:
            assert isinstance(tbl, MPObject)
            assert tbl.schema is not None

        in_names = [name for name, _ in entries]
        in_vars = [tbl for _, tbl in entries]
        ins_info = [tbl.schema for tbl in in_vars]

        pfn = PFunction(
            fn_type="sql.run",
            fn_name="",
            fn_text=sql,
            ins_info=tuple(ins_info),
            outs_info=(out_type,),
            in_names=tuple(in_names),
            dialect=self._dialect,
        )
        _, treedef = tree_flatten(out_type)
        return pfn, in_vars, treedef
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
# Default SQL frontend op instance using the DuckDB dialect.
sql_run = SqlFE("duckdb")
|
mplang/runtime/link_comm.py
DELETED
|
@@ -1,131 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import logging
|
|
18
|
-
from typing import Any
|
|
19
|
-
|
|
20
|
-
import cloudpickle as pickle
|
|
21
|
-
import spu.libspu as libspu
|
|
22
|
-
|
|
23
|
-
from mplang.core.comm import ICollective, ICommunicator
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class LinkCommunicator(ICommunicator, ICollective):
    """Wraps libspu link communicator for distributed communication.

    Payloads are serialized with cloudpickle and hex-encoded before being
    handed to the libspu link context.

    NOTE(security): ``recv`` and ``unwrap`` unpickle bytes received from
    peers. Unpickling can execute arbitrary code, so every peer listed in
    ``addrs`` must be trusted.
    """

    def __init__(self, rank: int, addrs: list[str], *, mem_link: bool = False):
        """Create the link context for ``rank`` among the parties in ``addrs``.

        Args:
            rank: This party's rank (index into ``addrs``).
            addrs: One address per party; party ``i`` is registered as ``P{i}``.
            mem_link: Use an in-memory link (``create_mem``) instead of brpc.
        """
        self._rank = rank
        self._world_size = len(addrs)

        desc = libspu.link.Desc()  # type: ignore
        desc.recv_timeout_ms = 100 * 1000  # 100 seconds
        desc.http_max_payload_size = 32 * 1024 * 1024  # Default set link payload to 32M
        # Distinct loop variable so the ``rank`` argument is not shadowed.
        for party_rank, addr in enumerate(addrs):
            desc.add_party(f"P{party_rank}", addr)

        if mem_link:
            self.lctx = libspu.link.create_mem(desc, self._rank)
        else:
            self.lctx = libspu.link.create_brpc(desc, self._rank)

        logging.info(
            "LinkCommunicator initialized with rank=%s, world_size=%s, addrs=%s",
            self._rank,
            self._world_size,
            addrs,
        )

        self._counter = 0

    @property
    def rank(self) -> int:
        """This party's rank as reported by the link context."""
        return self.lctx.rank  # type: ignore[no-any-return]

    @property
    def world_size(self) -> int:
        """Number of parties as reported by the link context."""
        return self.lctx.world_size  # type: ignore[no-any-return]

    def get_lctx(self) -> libspu.link.Context:
        """Get the link context"""
        return self.lctx

    # override
    def new_id(self) -> str:
        """Return the next id from a per-instance counter ("0", "1", ...)."""
        res = self._counter
        self._counter += 1
        return str(res)

    def wrap(self, obj: Any) -> str:
        """Serialize ``obj`` with cloudpickle and hex-encode the bytes."""
        data = pickle.dumps(obj)
        return data.hex()  # type: ignore[no-any-return]

    def unwrap(self, obj: str) -> Any:
        """Inverse of ``wrap``; unpickles peer data (peers must be trusted)."""
        data = bytes.fromhex(obj)
        return pickle.loads(data)  # type: ignore[no-any-return]

    def send(self, to: int, key: str, data: Any) -> None:
        """Send ``(key, data)`` to party ``to`` as a hex-encoded pickle."""
        serialized = pickle.dumps((key, data))
        self.lctx.send(to, serialized.hex())

    def recv(self, frm: int, key: str) -> Any:
        """Receive a message from ``frm`` and check that it carries ``key``."""
        serialized = self.lctx.recv(frm)
        rkey, data = pickle.loads(bytes.fromhex(serialized.decode()))
        assert key == rkey, f"recv key {key} != {rkey}"
        return data  # type: ignore[no-any-return]

    def p2p(self, frm: int, to: int, data: Any) -> Any:
        """Point-to-point transfer of ``data`` from ``frm`` to ``to``.

        The receiver returns the data; every other party returns None. Each
        call consumes one ``new_id`` on every party, and ``recv`` asserts that
        the keys match, so all parties must issue p2p calls in the same order.
        """
        assert 0 <= frm < self.world_size
        assert 0 <= to < self.world_size

        # TODO: link handles cid internally?
        cid = self.new_id()

        if self.rank == frm:
            self.send(to, cid, data)
            return None
        elif self.rank == to:
            return self.recv(frm, cid)
        else:
            return None

    def gather(self, root: int, data: Any) -> list[Any]:
        """Gather ``data`` from all parties at ``root`` via the link context."""
        assert 0 <= root < self.world_size
        rets = self.lctx.gather(self.wrap(data), root)
        return [self.unwrap(ret) for ret in rets]

    def scatter(self, root: int, args: list[Any]) -> Any:
        """Scatter ``args`` (one entry per party) from ``root``."""
        assert 0 <= root < self.world_size
        assert len(args) == self.world_size, f"{len(args)} != {self.world_size}"
        ret = self.lctx.scatter([self.wrap(arg) for arg in args], root)
        return self.unwrap(ret)

    def allgather(self, arg: Any) -> list[Any]:
        """All-gather ``arg`` across all parties."""
        rets = self.lctx.all_gather(self.wrap(arg))
        return [self.unwrap(ret) for ret in rets]

    def bcast(self, root: int, arg: Any) -> Any:
        """Broadcast ``arg`` from ``root`` to all parties."""
        assert 0 <= root < self.world_size
        ret = self.lctx.broadcast(self.wrap(arg), root)
        return self.unwrap(ret)

    # The masked (``*_m``) collectives are not supported by this communicator.
    def gather_m(self, pmask: int, root: int, data: Any) -> list[Any]:
        raise ValueError("Not supported by LinkCommunicator")

    def scatter_m(self, pmask: int, root: int, args: list[Any]) -> Any:
        raise ValueError("Not supported by LinkCommunicator")

    def allgather_m(self, pmask: int, arg: Any) -> list[Any]:
        raise ValueError("Not supported by LinkCommunicator")

    def bcast_m(self, pmask: int, root: int, arg: Any) -> Any:
        raise ValueError("Not supported by LinkCommunicator")
|
mplang/simp/smpc.py
DELETED
|
@@ -1,201 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
from abc import ABC, abstractmethod
|
|
17
|
-
from collections.abc import Callable
|
|
18
|
-
from enum import Enum
|
|
19
|
-
from functools import wraps
|
|
20
|
-
from typing import Any
|
|
21
|
-
|
|
22
|
-
from jax.tree_util import tree_unflatten
|
|
23
|
-
|
|
24
|
-
from mplang.core import Mask, MPObject, Rank, peval, psize
|
|
25
|
-
from mplang.core.context_mgr import cur_ctx
|
|
26
|
-
from mplang.ops import spu
|
|
27
|
-
from mplang.simp import mpi
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class SecureAPI(ABC):
    """Base class for secure APIs.

    A backend moves data into the secure domain (``seal`` / ``sealFrom``),
    computes on it (``seval``) and moves results back out (``reveal`` /
    ``revealTo``).
    """

    @abstractmethod
    def seal(self, obj: MPObject, frm_mask: Mask | None) -> list[MPObject]:
        """Seal ``obj`` contributed by the parties in ``frm_mask``."""

    @abstractmethod
    def sealFrom(self, obj: MPObject, root: Rank) -> MPObject:
        """Seal ``obj`` contributed by the single party ``root``."""

    @abstractmethod
    def seval(self, fe_type: str, pyfn: Callable, *args: Any, **kwargs: Any) -> Any:
        """Run a function in the secure environment."""

    @abstractmethod
    def reveal(self, obj: MPObject, to_mask: Mask) -> MPObject:
        """Reveal sealed ``obj`` to the parties in ``to_mask``."""

    @abstractmethod
    def revealTo(self, obj: MPObject, to_rank: Rank) -> MPObject:
        """Reveal sealed ``obj`` to the single party ``to_rank``."""
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
class Delegation(SecureAPI):
    """Delegate to a trusted third-party to perform secure operations.

    Placeholder backend: every operation currently raises
    ``NotImplementedError("TODO")``.
    """

    def seal(self, obj: MPObject, frm_mask: Mask | None = None) -> list[MPObject]:
        """Not implemented yet."""
        raise NotImplementedError("TODO")

    def sealFrom(self, obj: MPObject, root: Rank) -> MPObject:
        """Not implemented yet."""
        raise NotImplementedError("TODO")

    def seval(self, fe_type: str, pyfn: Callable, *args: Any, **kwargs: Any) -> Any:
        """Not implemented yet."""
        raise NotImplementedError("TODO")

    def reveal(self, obj: MPObject, to_mask: Mask) -> MPObject:
        """Not implemented yet."""
        raise NotImplementedError("TODO")

    def revealTo(self, obj: MPObject, to_rank: Rank) -> MPObject:
        """Not implemented yet."""
        raise NotImplementedError("TODO")
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
class SPU(SecureAPI):
    """Use SPU to perform secure operations."""

    def get_spu_mask(self) -> Mask:
        """Return the mask of the member parties of the cluster's SPU device.

        Raises:
            ValueError: if the cluster specification has no SPU device, or
                more than one.
        """
        spu_devices = cur_ctx().cluster_spec.get_devices_by_kind("SPU")
        if not spu_devices:
            raise ValueError("No SPU device found in the cluster specification")
        if len(spu_devices) > 1:
            raise ValueError("Multiple SPU devices found in the cluster specification")
        spu_device = spu_devices[0]
        spu_mask = Mask.from_ranks([member.rank for member in spu_device.members])
        return spu_mask

    def seal(self, obj: MPObject, frm_mask: Mask | None = None) -> list[MPObject]:
        """Secret-share ``obj`` from each party in ``frm_mask`` to the SPU parties.

        Args:
            obj: The object to seal.
            frm_mask: Parties contributing ``obj``. Defaults to ``obj.pmask``
                and must be a subset of it when both are known.

        Returns:
            One sealed object per party in ``frm_mask``, in mask iteration
            order.

        Raises:
            ValueError: if both ``obj.pmask`` and ``frm_mask`` are None
                (dynamic masks are unsupported), or if ``frm_mask`` is not a
                subset of ``obj.pmask``.
        """
        spu_mask: Mask = self.get_spu_mask()
        if obj.pmask is None:
            if frm_mask is None:
                # NOTE: The length of the return list is statically determined by obj_mask,
                # so only static masks are supported here.
                raise ValueError("Seal does not support dynamic masks.")
            else:
                # Force seal from the given mask, the runtime will raise error if the mask
                # does not match obj.pmask.
                # TODO(jint): add set_pmask primitive.
                pass
        else:
            if frm_mask is None:
                frm_mask = obj.pmask
            else:
                if not Mask(frm_mask).is_subset(obj.pmask):
                    raise ValueError(
                        f"Cannot seal from {frm_mask}: not a subset of the "
                        f"object's pmask {obj.pmask}"
                    )

        # Get the world_size from spu_mask (number of parties in SPU computation)
        world_size = Mask(spu_mask).num_parties()
        pfunc, ins, _ = spu.makeshares(
            obj, world_size=world_size, visibility=spu.Visibility.SECRET
        )
        assert len(ins) == 1
        shares = peval(pfunc, ins, frm_mask)

        # For each sealing party, scatter its shares (rooted at that party)
        # to the SPU parties.
        return [mpi.scatter_m(spu_mask, rank, shares) for rank in Mask(frm_mask)]

    def sealFrom(self, obj: MPObject, root: Rank) -> MPObject:
        """Seal ``obj`` contributed by the single party ``root``."""
        # Use this instance's ``seal``. The module-level ``seal`` would
        # dispatch via the global ``mode`` and could pick another backend.
        results = self.seal(obj, frm_mask=Mask.from_ranks(root))
        assert len(results) == 1, f"Expected one result, got {len(results)}"
        return results[0]

    def seval(self, fe_type: str, pyfn: Callable, *args: Any, **kwargs: Any) -> Any:
        """Compile ``pyfn`` (JAX only) for SPU and evaluate it on the SPU parties.

        Raises:
            ValueError: if ``fe_type`` is not "jax".
        """
        if fe_type != "jax":
            raise ValueError(f"Unsupported fe_type: {fe_type}")

        spu_mask = self.get_spu_mask()
        pfunc, in_vars, out_tree = spu.jax_compile(pyfn, *args, **kwargs)
        assert all(var.pmask == spu_mask for var in in_vars), in_vars
        out_flat = peval(pfunc, in_vars, spu_mask)
        return tree_unflatten(out_tree, out_flat)

    def reveal(self, obj: MPObject, to_mask: Mask) -> MPObject:
        """Broadcast each SPU party's share to ``to_mask`` and reconstruct there."""
        spu_mask = self.get_spu_mask()

        assert obj.pmask == spu_mask, (obj.pmask, spu_mask)

        # (n_parties, n_shares)
        shares = [mpi.bcast_m(to_mask, rank, obj) for rank in Mask(spu_mask)]
        assert len(shares) == Mask(spu_mask).num_parties(), (shares, spu_mask)
        assert all(share.pmask == to_mask for share in shares)

        # Reconstruct the original object from shares
        pfunc, ins, _ = spu.reconstruct(*shares)
        return peval(pfunc, ins, to_mask)[0]  # type: ignore[no-any-return]

    def revealTo(self, obj: MPObject, to_rank: Rank) -> MPObject:
        """Reveal sealed ``obj`` to the single party ``to_rank``."""
        return self.reveal(obj, to_mask=Mask.from_ranks(to_rank))
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
class SEE(Enum):
    """Secure Execution Environment kinds, selected via the module ``mode``."""

    MOCK = 0  # dispatched to ``Delegation`` by ``_get_sapi``
    SPU = 1  # dispatched to ``SPU`` by ``_get_sapi``
    TEE = 2  # not implemented yet


# Currently active secure execution environment (read by ``_get_sapi``).
# TODO(jint): move me to options.py
mode: SEE = SEE.SPU
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
def _get_sapi() -> SecureAPI:
|
|
158
|
-
"""Get the current secure API based on the mode."""
|
|
159
|
-
if mode == SEE.MOCK:
|
|
160
|
-
return Delegation()
|
|
161
|
-
elif mode == SEE.SPU:
|
|
162
|
-
return SPU()
|
|
163
|
-
elif mode == SEE.TEE:
|
|
164
|
-
raise NotImplementedError("TEE is not implemented yet")
|
|
165
|
-
else:
|
|
166
|
-
raise ValueError(f"Unknown mode: {mode}")
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
# seal :: m a -> [s a]
def seal(obj: MPObject, frm_mask: Mask | None = None) -> list[MPObject]:
    """Seal a simp object into the secure domain.

    Returns a list of sealed objects whose i'th element is the secret
    contributed by the i'th party. Uses the backend selected by ``mode``.
    """
    sapi = _get_sapi()
    return sapi.seal(obj, frm_mask=frm_mask)
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
# sealFrom :: m a -> m Rank -> s a
def sealFrom(obj: MPObject, root: Rank) -> MPObject:
    """Seal a simp object contributed by the single party ``root``."""
    sapi = _get_sapi()
    return sapi.sealFrom(obj, root)
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
# reveal :: s a -> m a
def reveal(obj: MPObject, to_mask: Mask | None = None) -> MPObject:
    """Reveal a sealed object to the parties in ``to_mask``.

    Args:
        obj: The sealed object.
        to_mask: Parties that receive the plaintext. ``None`` (the default)
            means all ``psize()`` parties.
    """
    # Compare against None explicitly. If a Mask can be falsy (e.g. an empty
    # mask), ``to_mask or ...`` would silently widen it to every party and
    # disclose the secret to parties the caller did not ask for.
    if to_mask is None:
        to_mask = Mask.all(psize())
    return _get_sapi().reveal(obj, to_mask)
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
# revealTo :: s a -> m Rank -> m a
def revealTo(obj: MPObject, to_rank: Rank) -> MPObject:
    """Reveal a sealed object to the single party ``to_rank``."""
    sapi = _get_sapi()
    return sapi.revealTo(obj, to_rank)
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
# srun :: (a -> a) -> s a -> s a
def srun(pyfn: Callable, *, fe_type: str = "jax") -> Callable:
    """Wrap ``pyfn`` so that each call is evaluated in the secure environment.

    The backend is looked up on every call through the module-level ``mode``.
    ``fe_type`` names the frontend used to compile ``pyfn``.
    """

    def wrapped(*args: Any, **kwargs: Any) -> Any:
        return _get_sapi().seval(fe_type, pyfn, *args, **kwargs)

    return wraps(pyfn)(wrapped)
|