mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -16,11 +16,11 @@ from __future__ import annotations
|
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
|
|
19
|
-
from mplang.core import PFunction, TableType, TensorType
|
|
20
|
-
from mplang.kernels.base import cur_kctx, kernel_def
|
|
21
|
-
from mplang.kernels.value import TableValue, TensorValue, Value
|
|
22
|
-
from mplang.runtime.data_providers import get_provider, resolve_uri
|
|
23
|
-
from mplang.utils import table_utils
|
|
19
|
+
from mplang.v1.core import PFunction, TableType, TensorType
|
|
20
|
+
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
21
|
+
from mplang.v1.kernels.value import TableValue, TensorValue, Value
|
|
22
|
+
from mplang.v1.runtime.data_providers import get_provider, resolve_uri
|
|
23
|
+
from mplang.v1.utils import table_utils
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
@kernel_def("basic.identity")
|
|
@@ -45,17 +45,17 @@ def _read(pfunc: PFunction) -> Value:
|
|
|
45
45
|
except Exception as e: # pragma: no cover - provider errors
|
|
46
46
|
raise RuntimeError(f"basic.read failed: {e}") from e
|
|
47
47
|
|
|
48
|
+
if isinstance(data, Value):
|
|
49
|
+
return data
|
|
50
|
+
|
|
48
51
|
if isinstance(out_t, TableType):
|
|
49
|
-
if isinstance(data, TableValue):
|
|
50
|
-
return data
|
|
51
52
|
return TableValue(data)
|
|
52
|
-
|
|
53
|
-
if isinstance(data, TensorValue):
|
|
54
|
-
return data
|
|
53
|
+
elif isinstance(out_t, TensorType):
|
|
55
54
|
return TensorValue(np.asarray(data))
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
55
|
+
else:
|
|
56
|
+
raise TypeError(
|
|
57
|
+
f"basic.read only supports TableType/TensorType outputs, got {type(out_t).__name__}"
|
|
58
|
+
)
|
|
59
59
|
|
|
60
60
|
|
|
61
61
|
@kernel_def("basic.write")
|
|
@@ -85,9 +85,9 @@ def _constant(pfunc: PFunction) -> Value:
|
|
|
85
85
|
out_t = pfunc.outs_info[0]
|
|
86
86
|
fmt = pfunc.attrs.get("data_format")
|
|
87
87
|
if isinstance(out_t, TableType):
|
|
88
|
-
if fmt != "bytes[
|
|
88
|
+
if fmt != "bytes[parquet]":
|
|
89
89
|
raise ValueError(f"unsupported table constant format {fmt}")
|
|
90
|
-
df = table_utils.
|
|
90
|
+
df = table_utils.decode_table(data_bytes, format="parquet")
|
|
91
91
|
return TableValue(df)
|
|
92
92
|
# tensor path
|
|
93
93
|
shape = out_t.shape # type: ignore[attr-defined,union-attr]
|
|
@@ -17,12 +17,12 @@ from __future__ import annotations
|
|
|
17
17
|
from collections.abc import Mapping
|
|
18
18
|
from typing import Any
|
|
19
19
|
|
|
20
|
-
from mplang.core.dtypes import UINT8, DType
|
|
21
|
-
from mplang.core.pfunc import PFunction
|
|
22
|
-
from mplang.core.table import TableLike, TableType
|
|
23
|
-
from mplang.core.tensor import TensorLike, TensorType
|
|
24
|
-
from mplang.kernels import base
|
|
25
|
-
from mplang.kernels.base import KernelContext, get_kernel_spec, kernel_exists
|
|
20
|
+
from mplang.v1.core.dtypes import UINT8, DType
|
|
21
|
+
from mplang.v1.core.pfunc import PFunction
|
|
22
|
+
from mplang.v1.core.table import PandasTableLike, TableLike, TableType
|
|
23
|
+
from mplang.v1.core.tensor import TensorLike, TensorType
|
|
24
|
+
from mplang.v1.kernels import base
|
|
25
|
+
from mplang.v1.kernels.base import KernelContext, get_kernel_spec, kernel_exists
|
|
26
26
|
|
|
27
27
|
# Default bindings
|
|
28
28
|
# Import kernel implementation modules explicitly so their @kernel_def entries
|
|
@@ -35,14 +35,14 @@ def _ensure_impl_imported() -> None:
|
|
|
35
35
|
global _IMPL_IMPORTED
|
|
36
36
|
if _IMPL_IMPORTED:
|
|
37
37
|
return
|
|
38
|
-
from mplang.kernels import basic as _impl_basic # noqa: F401
|
|
39
|
-
from mplang.kernels import crypto as _impl_crypto # noqa: F401
|
|
40
|
-
from mplang.kernels import fhe as _impl_fhe # noqa: F401
|
|
41
|
-
from mplang.kernels import mock_tee as _impl_tee # noqa: F401
|
|
42
|
-
from mplang.kernels import phe as _impl_phe # noqa: F401
|
|
43
|
-
from mplang.kernels import spu as _impl_spu # noqa: F401
|
|
44
|
-
from mplang.kernels import sql_duckdb as _impl_sql_duckdb # noqa: F401
|
|
45
|
-
from mplang.kernels import stablehlo as _impl_stablehlo # noqa: F401
|
|
38
|
+
from mplang.v1.kernels import basic as _impl_basic # noqa: F401
|
|
39
|
+
from mplang.v1.kernels import crypto as _impl_crypto # noqa: F401
|
|
40
|
+
from mplang.v1.kernels import fhe as _impl_fhe # noqa: F401
|
|
41
|
+
from mplang.v1.kernels import mock_tee as _impl_tee # noqa: F401
|
|
42
|
+
from mplang.v1.kernels import phe as _impl_phe # noqa: F401
|
|
43
|
+
from mplang.v1.kernels import spu as _impl_spu # noqa: F401
|
|
44
|
+
from mplang.v1.kernels import sql_duckdb as _impl_sql_duckdb # noqa: F401
|
|
45
|
+
from mplang.v1.kernels import stablehlo as _impl_stablehlo # noqa: F401
|
|
46
46
|
|
|
47
47
|
_IMPL_IMPORTED = True
|
|
48
48
|
|
|
@@ -317,9 +317,12 @@ def _validate_table_arg(
|
|
|
317
317
|
raise TypeError(
|
|
318
318
|
f"kernel {fn_type} input[{arg_index}] expects TableLike, got {type(value).__name__}"
|
|
319
319
|
)
|
|
320
|
-
|
|
320
|
+
columns = (
|
|
321
|
+
value.columns if isinstance(value, PandasTableLike) else value.column_names
|
|
322
|
+
)
|
|
323
|
+
if len(columns) != len(spec.columns):
|
|
321
324
|
raise ValueError(
|
|
322
|
-
f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(
|
|
325
|
+
f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(columns)}, expected {len(spec.columns)}"
|
|
323
326
|
)
|
|
324
327
|
|
|
325
328
|
|
|
@@ -18,10 +18,10 @@ import os
|
|
|
18
18
|
|
|
19
19
|
import numpy as np
|
|
20
20
|
|
|
21
|
-
from mplang.core import PFunction
|
|
22
|
-
from mplang.kernels.base import cur_kctx, kernel_def
|
|
23
|
-
from mplang.kernels.value import TensorValue
|
|
24
|
-
from mplang.utils.crypto import blake2b
|
|
21
|
+
from mplang.v1.core import PFunction
|
|
22
|
+
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
23
|
+
from mplang.v1.kernels.value import TensorValue
|
|
24
|
+
from mplang.v1.utils.crypto import blake2b
|
|
25
25
|
|
|
26
26
|
__all__: list[str] = [] # No public exports currently
|
|
27
27
|
|
|
@@ -45,11 +45,9 @@ def _get_rng() -> np.random.Generator:
|
|
|
45
45
|
def _keystream(key: bytes, nonce: bytes, length: int) -> bytes:
|
|
46
46
|
# WARNING (INSECURE): hash-based keystream (key||nonce||counter)
|
|
47
47
|
out = bytearray()
|
|
48
|
-
counter = 0
|
|
49
48
|
while len(out) < length:
|
|
50
|
-
chunk = blake2b(key + nonce
|
|
49
|
+
chunk = blake2b(key + nonce)
|
|
51
50
|
out.extend(chunk)
|
|
52
|
-
counter += 1
|
|
53
51
|
return bytes(out[:length])
|
|
54
52
|
|
|
55
53
|
|
|
@@ -68,7 +66,7 @@ def _crypto_encrypt(
|
|
|
68
66
|
pt_bytes_np = pt_bytes.to_numpy().astype(np.uint8, copy=False)
|
|
69
67
|
key_np = key.to_numpy().astype(np.uint8, copy=False)
|
|
70
68
|
rng = _get_rng()
|
|
71
|
-
nonce = rng.integers(0, 256, size=(
|
|
69
|
+
nonce = rng.integers(0, 256, size=(16,), dtype=np.uint8)
|
|
72
70
|
stream = np.frombuffer(
|
|
73
71
|
_keystream(key_np.tobytes(), nonce.tobytes(), pt_bytes_np.size), dtype=np.uint8
|
|
74
72
|
)
|
|
@@ -83,8 +81,8 @@ def _crypto_decrypt(
|
|
|
83
81
|
) -> TensorValue:
|
|
84
82
|
ct_np = ct_with_nonce.to_numpy().astype(np.uint8, copy=False)
|
|
85
83
|
key_np = key.to_numpy().astype(np.uint8, copy=False)
|
|
86
|
-
nonce = ct_np[:
|
|
87
|
-
ct = ct_np[
|
|
84
|
+
nonce = ct_np[:16]
|
|
85
|
+
ct = ct_np[16:]
|
|
88
86
|
stream = np.frombuffer(
|
|
89
87
|
_keystream(key_np.tobytes(), nonce.tobytes(), len(ct)), dtype=np.uint8
|
|
90
88
|
)
|
|
@@ -23,8 +23,9 @@ from typing import Any
|
|
|
23
23
|
import numpy as np
|
|
24
24
|
import tenseal as ts
|
|
25
25
|
|
|
26
|
-
from mplang.core import DType, PFunction, TensorLike
|
|
27
|
-
from mplang.kernels.base import kernel_def
|
|
26
|
+
from mplang.v1.core import DType, PFunction, TensorLike
|
|
27
|
+
from mplang.v1.kernels.base import kernel_def
|
|
28
|
+
from mplang.v1.kernels.value import TensorValue
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
class FHEContext:
|
|
@@ -337,13 +338,14 @@ def _fhe_decrypt(pfunc: PFunction, ciphertext: CipherText, context: FHEContext)
|
|
|
337
338
|
|
|
338
339
|
# Restore original shape
|
|
339
340
|
if ciphertext.semantic_shape == ():
|
|
340
|
-
#
|
|
341
|
-
|
|
341
|
+
# Scalar: shape ()
|
|
342
|
+
result_np = decrypted_np[0:1].reshape(())
|
|
342
343
|
else:
|
|
343
|
-
#
|
|
344
|
-
|
|
344
|
+
# Vector: keep 1D array
|
|
345
|
+
result_np = decrypted_np
|
|
345
346
|
|
|
346
|
-
|
|
347
|
+
# Return TensorValue to adhere to kernel Value I/O convention
|
|
348
|
+
return (TensorValue(np.asarray(result_np)),)
|
|
347
349
|
|
|
348
350
|
except Exception as e:
|
|
349
351
|
raise RuntimeError(f"FHE vector decryption failed: {e}") from e
|
|
@@ -20,9 +20,9 @@ import warnings
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
from numpy.typing import NDArray
|
|
22
22
|
|
|
23
|
-
from mplang.core import PFunction
|
|
24
|
-
from mplang.kernels.base import cur_kctx, kernel_def
|
|
25
|
-
from mplang.kernels.value import TensorValue
|
|
23
|
+
from mplang.v1.core import PFunction
|
|
24
|
+
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
25
|
+
from mplang.v1.kernels.value import TensorValue
|
|
26
26
|
|
|
27
27
|
__all__: list[str] = []
|
|
28
28
|
|
|
@@ -23,9 +23,9 @@ import numpy as np
|
|
|
23
23
|
from lightphe import LightPHE
|
|
24
24
|
from lightphe.models.Ciphertext import Ciphertext
|
|
25
25
|
|
|
26
|
-
from mplang.core import DType, PFunction
|
|
27
|
-
from mplang.kernels.base import kernel_def
|
|
28
|
-
from mplang.kernels.value import (
|
|
26
|
+
from mplang.v1.core import DType, PFunction
|
|
27
|
+
from mplang.v1.kernels.base import kernel_def
|
|
28
|
+
from mplang.v1.kernels.value import (
|
|
29
29
|
TensorValue,
|
|
30
30
|
Value,
|
|
31
31
|
ValueDecodeError,
|
|
@@ -33,7 +33,7 @@ from mplang.kernels.value import (
|
|
|
33
33
|
ValueProtoReader,
|
|
34
34
|
register_value,
|
|
35
35
|
)
|
|
36
|
-
from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
|
|
36
|
+
from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
|
|
37
37
|
|
|
38
38
|
# This controls the decimal precision used in lightPHE for float operations
|
|
39
39
|
# we force it to 0 to only support integer operations
|
|
@@ -473,10 +473,9 @@ def _phe_keygen(pfunc: PFunction) -> Any:
|
|
|
473
473
|
# use small key_size to speed up tests
|
|
474
474
|
# in production use at least 2048 bits or 3072 bits for better security
|
|
475
475
|
key_size = pfunc.attrs.get("key_size", 2048)
|
|
476
|
-
max_value
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
fxp_bits = pfunc.attrs.get("fxp_bits", 12)
|
|
476
|
+
# Accept very large max_value; allow decimal string input, kept simple like other attrs
|
|
477
|
+
max_value = int(pfunc.attrs.get("max_value", 2**32))
|
|
478
|
+
fxp_bits = int(pfunc.attrs.get("fxp_bits", 12))
|
|
480
479
|
|
|
481
480
|
# Validate scheme
|
|
482
481
|
if scheme.lower() not in ["paillier"]:
|
|
@@ -638,7 +637,8 @@ def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -
|
|
|
638
637
|
# Use numpy to create a properly broadcasted index mapping
|
|
639
638
|
# Create a dummy array with same shape as ciphertext, fill with indices
|
|
640
639
|
dummy_ct = (
|
|
641
|
-
np
|
|
640
|
+
np
|
|
641
|
+
.arange(np.prod(ciphertext.semantic_shape))
|
|
642
642
|
.reshape(ciphertext.semantic_shape)
|
|
643
643
|
.astype(np.int64)
|
|
644
644
|
)
|
|
@@ -745,7 +745,8 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
|
|
|
745
745
|
# Broadcast ct1 if needed
|
|
746
746
|
if ct1.semantic_shape != result_shape:
|
|
747
747
|
dummy_ct1 = (
|
|
748
|
-
np
|
|
748
|
+
np
|
|
749
|
+
.arange(np.prod(ct1.semantic_shape))
|
|
749
750
|
.reshape(ct1.semantic_shape)
|
|
750
751
|
.astype(np.int64)
|
|
751
752
|
)
|
|
@@ -758,7 +759,8 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
|
|
|
758
759
|
# Broadcast ct2 if needed
|
|
759
760
|
if ct2.semantic_shape != result_shape:
|
|
760
761
|
dummy_ct2 = (
|
|
761
|
-
np
|
|
762
|
+
np
|
|
763
|
+
.arange(np.prod(ct2.semantic_shape))
|
|
762
764
|
.reshape(ct2.semantic_shape)
|
|
763
765
|
.astype(np.int64)
|
|
764
766
|
)
|
|
@@ -831,7 +833,8 @@ def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorValue) -> CipherText
|
|
|
831
833
|
# Broadcast ciphertext if needed
|
|
832
834
|
if ciphertext.semantic_shape != result_shape:
|
|
833
835
|
dummy_ct = (
|
|
834
|
-
np
|
|
836
|
+
np
|
|
837
|
+
.arange(np.prod(ciphertext.semantic_shape))
|
|
835
838
|
.reshape(ciphertext.semantic_shape)
|
|
836
839
|
.astype(np.int64)
|
|
837
840
|
)
|
|
@@ -997,12 +1000,17 @@ def _phe_decrypt(
|
|
|
997
1000
|
# Convert to target dtype
|
|
998
1001
|
if target_dtype.kind in "iu": # integer types
|
|
999
1002
|
# Convert floats back to integers for integer semantic types
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1003
|
+
# decoded_data are numeric (ints or floats); normalize to Python int
|
|
1004
|
+
ints = [round(v) if isinstance(v, float) else v for v in decoded_data]
|
|
1005
|
+
if np.issubdtype(target_dtype, np.unsignedinteger):
|
|
1006
|
+
# Reduce modulo 2^k for unsigned to preserve ring semantics
|
|
1007
|
+
width = np.iinfo(target_dtype).bits
|
|
1008
|
+
mod = 1 << width
|
|
1009
|
+
processed_data = [v % mod for v in ints]
|
|
1010
|
+
else:
|
|
1011
|
+
# Signed integers: clamp to dtype range
|
|
1012
|
+
info = np.iinfo(target_dtype)
|
|
1013
|
+
processed_data = [max(info.min, min(info.max, v)) for v in ints]
|
|
1006
1014
|
else: # float types
|
|
1007
1015
|
processed_data = decoded_data
|
|
1008
1016
|
|
|
@@ -21,7 +21,7 @@ import numpy as np
|
|
|
21
21
|
import spu.api as spu_api
|
|
22
22
|
import spu.libspu as libspu
|
|
23
23
|
|
|
24
|
-
from mplang.core import (
|
|
24
|
+
from mplang.v1.core import (
|
|
25
25
|
BOOL,
|
|
26
26
|
FLOAT32,
|
|
27
27
|
FLOAT64,
|
|
@@ -36,8 +36,8 @@ from mplang.core import (
|
|
|
36
36
|
DType,
|
|
37
37
|
PFunction,
|
|
38
38
|
)
|
|
39
|
-
from mplang.kernels.base import cur_kctx, kernel_def
|
|
40
|
-
from mplang.kernels.value import (
|
|
39
|
+
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
40
|
+
from mplang.v1.kernels.value import (
|
|
41
41
|
TensorValue,
|
|
42
42
|
Value,
|
|
43
43
|
ValueDecodeError,
|
|
@@ -45,8 +45,8 @@ from mplang.kernels.value import (
|
|
|
45
45
|
ValueProtoReader,
|
|
46
46
|
register_value,
|
|
47
47
|
)
|
|
48
|
-
from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
|
|
49
|
-
from mplang.runtime.link_comm import LinkCommunicator
|
|
48
|
+
from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
|
|
49
|
+
from mplang.v1.runtime.link_comm import LinkCommunicator
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
def shape_spu_to_np(spu_shape: Any) -> tuple[int, ...]:
|
|
@@ -14,9 +14,9 @@
|
|
|
14
14
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
|
-
from mplang.core import PFunction
|
|
18
|
-
from mplang.kernels.base import kernel_def
|
|
19
|
-
from mplang.kernels.value import TableValue
|
|
17
|
+
from mplang.v1.core import PFunction
|
|
18
|
+
from mplang.v1.kernels.base import kernel_def
|
|
19
|
+
from mplang.v1.kernels.value import TableValue
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
@kernel_def("duckdb.run_sql")
|
|
@@ -38,5 +38,7 @@ def _duckdb_sql(pfunc: PFunction, *args: TableValue) -> TableValue:
|
|
|
38
38
|
arrow_table = arg.to_arrow()
|
|
39
39
|
conn.register(name, arrow_table)
|
|
40
40
|
# Fetch result as Arrow table for consistency
|
|
41
|
+
if pfunc.fn_text is None:
|
|
42
|
+
raise ValueError("SQL function text is None")
|
|
41
43
|
res_arrow = conn.execute(pfunc.fn_text).fetch_arrow_table()
|
|
42
44
|
return TableValue(res_arrow)
|
|
@@ -17,14 +17,14 @@ from __future__ import annotations
|
|
|
17
17
|
from typing import Any
|
|
18
18
|
|
|
19
19
|
import jax
|
|
20
|
+
import jax.extend as jxt
|
|
20
21
|
import jax.numpy as jnp
|
|
21
22
|
import numpy as np
|
|
22
|
-
from jax._src import
|
|
23
|
-
from jax.lib import xla_client as xc
|
|
23
|
+
from jax._src import compiler
|
|
24
24
|
|
|
25
|
-
from mplang.core import PFunction
|
|
26
|
-
from mplang.kernels.base import cur_kctx, kernel_def
|
|
27
|
-
from mplang.kernels.value import TensorValue
|
|
25
|
+
from mplang.v1.core import PFunction
|
|
26
|
+
from mplang.v1.kernels.base import cur_kctx, kernel_def
|
|
27
|
+
from mplang.v1.kernels.value import TensorValue
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
@kernel_def("mlir.stablehlo")
|
|
@@ -47,11 +47,13 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
|
47
47
|
key = f"stablehlo.compile_cache.{h}"
|
|
48
48
|
compiled = rt.get_state(key)
|
|
49
49
|
if compiled is None:
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
50
|
+
client = jxt.backend.get_backend()
|
|
51
|
+
compile_options = compiler.get_compile_options(num_replicas=1, num_partitions=1)
|
|
52
|
+
|
|
53
53
|
try:
|
|
54
|
-
compiled = client.
|
|
54
|
+
compiled = client.compile_and_load(
|
|
55
|
+
mlir_text, client.devices(), compile_options
|
|
56
|
+
)
|
|
55
57
|
except Exception as e: # pragma: no cover
|
|
56
58
|
raise RuntimeError(f"StableHLO compile failed: {e}") from e
|
|
57
59
|
rt.set_state(key, compiled)
|
|
@@ -76,14 +78,13 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
|
|
|
76
78
|
]
|
|
77
79
|
|
|
78
80
|
try:
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
flat.extend(TensorValue(np.asarray(a)) for a in lst)
|
|
81
|
+
# Execute with the new LoadedExecutable interface
|
|
82
|
+
result = compiled.execute(jax_args)
|
|
83
|
+
|
|
84
|
+
# Use jax.tree_util.tree_flatten to robustly handle any PyTree structure
|
|
85
|
+
flat_results, _ = jax.tree_util.tree_flatten(result)
|
|
86
|
+
flat = [TensorValue(np.asarray(item)) for item in flat_results]
|
|
87
|
+
|
|
87
88
|
return tuple(flat)
|
|
88
89
|
except Exception as e: # pragma: no cover
|
|
89
90
|
raise RuntimeError(f"StableHLO execute failed: {e}") from e
|
|
@@ -17,7 +17,7 @@ from __future__ import annotations
|
|
|
17
17
|
from abc import ABC, abstractmethod
|
|
18
18
|
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
|
|
19
19
|
|
|
20
|
-
from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
|
|
20
|
+
from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
|
|
21
21
|
|
|
22
22
|
if TYPE_CHECKING:
|
|
23
23
|
import numpy as np
|
|
@@ -591,7 +591,7 @@ class TableValue(Value): # well-known table (Arrow IPC) Value
|
|
|
591
591
|
|
|
592
592
|
Note: This creates a copy and converts from Arrow to pandas format.
|
|
593
593
|
For better performance, consider using to_arrow() and working with
|
|
594
|
-
Arrow-native APIs (DuckDB,
|
|
594
|
+
Arrow-native APIs (DuckDB, etc.) directly.
|
|
595
595
|
|
|
596
596
|
Returns:
|
|
597
597
|
pandas.DataFrame: Converted dataframe
|
|
@@ -19,15 +19,15 @@ This module contains compilers that transform high-level functions into
|
|
|
19
19
|
portable, serializable intermediate representations.
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
|
-
from mplang.ops import basic, crypto,
|
|
23
|
-
from mplang.ops.base import FeOperation as FeOperation
|
|
22
|
+
from mplang.v1.ops import basic, crypto, jax_cc, nnx_cc, phe, spu, sql_cc, tee
|
|
23
|
+
from mplang.v1.ops.base import FeOperation as FeOperation
|
|
24
24
|
|
|
25
25
|
__all__ = [
|
|
26
26
|
"FeOperation",
|
|
27
27
|
"basic",
|
|
28
28
|
"crypto",
|
|
29
|
-
"ibis_cc",
|
|
30
29
|
"jax_cc",
|
|
30
|
+
"nnx_cc",
|
|
31
31
|
"phe",
|
|
32
32
|
"spu",
|
|
33
33
|
"sql_cc",
|
mplang/{ops → v1/ops}/base.py
RENAMED
|
@@ -20,7 +20,7 @@ from typing import Any
|
|
|
20
20
|
|
|
21
21
|
from jax.tree_util import PyTreeDef, tree_flatten
|
|
22
22
|
|
|
23
|
-
from mplang.core import MPContext, MPObject, PFunction, TableType, TensorType
|
|
23
|
+
from mplang.v1.core import MPContext, MPObject, PFunction, TableType, TensorType
|
|
24
24
|
|
|
25
25
|
# -----------------------------------------------------------------------------
|
|
26
26
|
# Triad ABI
|
mplang/{ops → v1/ops}/basic.py
RENAMED
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
from jax.tree_util import PyTreeDef, tree_flatten
|
|
17
17
|
|
|
18
|
-
from mplang.core import (
|
|
18
|
+
from mplang.v1.core import (
|
|
19
19
|
UINT8,
|
|
20
20
|
UINT64,
|
|
21
21
|
MPObject,
|
|
@@ -27,8 +27,8 @@ from mplang.core import (
|
|
|
27
27
|
TensorLike,
|
|
28
28
|
TensorType,
|
|
29
29
|
)
|
|
30
|
-
from mplang.ops.base import stateless_mod
|
|
31
|
-
from mplang.utils import table_utils
|
|
30
|
+
from mplang.v1.ops.base import stateless_mod
|
|
31
|
+
from mplang.v1.utils import table_utils
|
|
32
32
|
|
|
33
33
|
_BASIC_MOD = stateless_mod("basic")
|
|
34
34
|
|
|
@@ -108,8 +108,9 @@ def constant(
|
|
|
108
108
|
out_type: TableType | TensorType
|
|
109
109
|
|
|
110
110
|
if isinstance(data, TableLike):
|
|
111
|
-
|
|
112
|
-
|
|
111
|
+
format = "parquet"
|
|
112
|
+
data_bytes = table_utils.encode_table(data, format=format)
|
|
113
|
+
data_format = f"bytes[{format}]"
|
|
113
114
|
out_type = TableType.from_tablelike(data)
|
|
114
115
|
elif isinstance(data, ScalarType):
|
|
115
116
|
out_type = TensorType.from_obj(data)
|