mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/{simp → v1/simp}/api.py
RENAMED
|
@@ -17,7 +17,7 @@ from __future__ import annotations
|
|
|
17
17
|
from collections.abc import Callable
|
|
18
18
|
from typing import Any, cast
|
|
19
19
|
|
|
20
|
-
from mplang.core import (
|
|
20
|
+
from mplang.v1.core import (
|
|
21
21
|
Mask,
|
|
22
22
|
MPObject,
|
|
23
23
|
Rank,
|
|
@@ -28,8 +28,8 @@ from mplang.core import (
|
|
|
28
28
|
builtin_function,
|
|
29
29
|
peval,
|
|
30
30
|
)
|
|
31
|
-
from mplang.ops import basic,
|
|
32
|
-
from mplang.ops.base import FeOperation
|
|
31
|
+
from mplang.v1.ops import basic, jax_cc, nnx_cc, sql_cc
|
|
32
|
+
from mplang.v1.ops.base import FeOperation
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
def run(
|
|
@@ -273,23 +273,81 @@ def run_jax_at(rank: Rank, jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
|
|
|
273
273
|
return run_at(rank, jax_cc.run_jax, jax_fn, *args, **kwargs)
|
|
274
274
|
|
|
275
275
|
|
|
276
|
-
def run_ibis(ibis_expr: Any, *args: Any, **kwargs: Any) -> Any:
|
|
277
|
-
# TODO(jint): add docstring, add type hints, describe args and kwargs constraints.
|
|
278
|
-
return run(None, ibis_cc.run_ibis, ibis_expr, *args, **kwargs)
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
def run_ibis_at(rank: Rank, ibis_fn: Any, *args: Any, **kwargs: Any) -> Any:
|
|
282
|
-
return run_at(rank, ibis_cc.run_ibis, ibis_fn, *args, **kwargs)
|
|
283
|
-
|
|
284
|
-
|
|
285
276
|
def run_sql(
|
|
286
277
|
query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
|
|
287
278
|
) -> Any:
|
|
288
279
|
# TODO(jint): add docstring, drop out_type.
|
|
289
|
-
return run(None, sql_cc.
|
|
280
|
+
return run(None, sql_cc.run_sql_raw, query, out_type, in_tables)
|
|
290
281
|
|
|
291
282
|
|
|
292
283
|
def run_sql_at(
|
|
293
284
|
rank: Rank, query: str, out_type: Any, in_tables: dict[str, MPObject] | None = None
|
|
294
285
|
) -> Any:
|
|
295
|
-
return run_at(rank, sql_cc.
|
|
286
|
+
return run_at(rank, sql_cc.run_sql_raw, query, out_type, in_tables)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def run_nnx(nnx_fn: Callable, *args: Any, **kwargs: Any) -> Any:
|
|
290
|
+
"""Run an NNX function.
|
|
291
|
+
|
|
292
|
+
Args:
|
|
293
|
+
nnx_fn: The NNX function to be executed.
|
|
294
|
+
*args: Positional arguments to pass to the NNX function.
|
|
295
|
+
**kwargs: Keyword arguments to pass to the NNX function.
|
|
296
|
+
|
|
297
|
+
Returns:
|
|
298
|
+
The result of evaluating the NNX function through the mplang system.
|
|
299
|
+
|
|
300
|
+
Raises:
|
|
301
|
+
TypeError: If the function compilation or evaluation fails.
|
|
302
|
+
RuntimeError: If the underlying peval execution encounters errors.
|
|
303
|
+
|
|
304
|
+
Notes:
|
|
305
|
+
Argument binding semantics with respect to NNX static arguments:
|
|
306
|
+
|
|
307
|
+
- If an argument (or any leaf within a PyTree argument) is an
|
|
308
|
+
:class:`~mplang.v1.core.mpobject.MPObject`, it is captured as a runtime
|
|
309
|
+
variable (dynamic value) in the traced program and is not treated as a
|
|
310
|
+
NNX static argument.
|
|
311
|
+
- If an argument contains no :class:`MPObject` leaves, it is treated as a
|
|
312
|
+
constant configuration with respect to NNX; effectively it behaves
|
|
313
|
+
like a static argument and may contribute to NNX compilation cache
|
|
314
|
+
keys (similar to ``static_argnums`` semantics). Changing such constant
|
|
315
|
+
arguments can lead to different compiled variants/cached entries.
|
|
316
|
+
|
|
317
|
+
Examples:
|
|
318
|
+
Defining and running a simple NNX function:
|
|
319
|
+
|
|
320
|
+
>>> from flax import nnx
|
|
321
|
+
>>> import jax.numpy as jnp
|
|
322
|
+
>>> def nnx_linear(inputs, weights, bias):
|
|
323
|
+
... return jnp.dot(inputs, weights) + bias
|
|
324
|
+
>>> result = run_nnx(nnx_linear, inputs, weights, bias)
|
|
325
|
+
|
|
326
|
+
Running an NNX model:
|
|
327
|
+
|
|
328
|
+
>>> class LinearModel(nnx.Module):
|
|
329
|
+
... def __init__(self, features: int, rngs: nnx.Rngs):
|
|
330
|
+
... self.linear = nnx.Linear(features, features, rngs=rngs)
|
|
331
|
+
...
|
|
332
|
+
... def __call__(self, x):
|
|
333
|
+
... return self.linear(x)
|
|
334
|
+
>>> def forward_pass(model, x):
|
|
335
|
+
... return model(x)
|
|
336
|
+
>>> output = run_nnx(forward_pass, model, input_data)
|
|
337
|
+
"""
|
|
338
|
+
return run(None, nnx_cc.run_nnx, nnx_fn, *args, **kwargs)
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def run_nnx_at(rank: Rank, nnx_fn: Callable, *args: Any, **kwargs: Any) -> Any:
|
|
342
|
+
"""Run an NNX function at a specific rank.
|
|
343
|
+
|
|
344
|
+
Args:
|
|
345
|
+
rank: The rank where the NNX function should be executed.
|
|
346
|
+
nnx_fn: The NNX function to be executed.
|
|
347
|
+
*args: Positional arguments to pass to the NNX function.
|
|
348
|
+
**kwargs: Keyword arguments to pass to the NNX function.
|
|
349
|
+
|
|
350
|
+
Returns:
|
|
351
|
+
The result of evaluating the NNX function at the specified rank.
|
|
352
|
+
"""
|
|
353
|
+
return run_at(rank, nnx_cc.run_nnx, nnx_fn, *args, **kwargs)
|
mplang/{simp → v1/simp}/mpi.py
RENAMED
mplang/{simp → v1/simp}/party.py
RENAMED
|
@@ -22,9 +22,9 @@ from functools import wraps
|
|
|
22
22
|
from types import ModuleType
|
|
23
23
|
from typing import Any
|
|
24
24
|
|
|
25
|
-
from mplang.ops.base import FeOperation
|
|
26
|
-
from mplang.simp.api import run_at, run_jax_at
|
|
27
|
-
from mplang.simp.mpi import p2p
|
|
25
|
+
from mplang.v1.ops.base import FeOperation
|
|
26
|
+
from mplang.v1.simp.api import run_at, run_jax_at
|
|
27
|
+
from mplang.v1.simp.mpi import p2p
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
def P2P(src: Party, dst: Party, value: Any) -> Any:
|
|
@@ -163,7 +163,7 @@ def _load_prelude_modules() -> None:
|
|
|
163
163
|
unwieldy we can switch to an allowlist.
|
|
164
164
|
"""
|
|
165
165
|
try:
|
|
166
|
-
import mplang.ops as _fe # type: ignore
|
|
166
|
+
import mplang.v1.ops as _fe # type: ignore
|
|
167
167
|
except (ImportError, ModuleNotFoundError): # pragma: no cover
|
|
168
168
|
# Frontend package not present (minimal install); safe to skip.
|
|
169
169
|
return
|
|
@@ -173,7 +173,7 @@ def _load_prelude_modules() -> None:
|
|
|
173
173
|
if m.name.startswith("_"):
|
|
174
174
|
continue
|
|
175
175
|
if m.name not in _NAMESPACE_REGISTRY:
|
|
176
|
-
_NAMESPACE_REGISTRY[m.name] = f"mplang.ops.{m.name}"
|
|
176
|
+
_NAMESPACE_REGISTRY[m.name] = f"mplang.v1.ops.{m.name}"
|
|
177
177
|
|
|
178
178
|
|
|
179
179
|
def load_module(module: str, alias: str | None = None) -> None:
|
|
@@ -21,8 +21,8 @@ import jax.numpy as jnp
|
|
|
21
21
|
import jax.random as jr
|
|
22
22
|
from jax.typing import ArrayLike
|
|
23
23
|
|
|
24
|
-
from mplang.core import MPObject, Shape, function, pmask, psize
|
|
25
|
-
from mplang.simp.api import prand, prank, run_jax
|
|
24
|
+
from mplang.v1.core import MPObject, Shape, function, pmask, psize
|
|
25
|
+
from mplang.v1.simp.api import prand, prank, run_jax
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
@function
|
mplang/v1/simp/smpc.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
SMPC on simp: conventions and object semantics
|
|
17
|
+
|
|
18
|
+
Overview
|
|
19
|
+
- simp is party-centric. Objects produced purely by simp code carry only an execution
|
|
20
|
+
mask ("pmask") and have no security device semantics by default.
|
|
21
|
+
- Secure semantics (secret sharing, protected execution, declassification) are introduced
|
|
22
|
+
only when using the device API or the helpers in this module: "seal", "srun", "reveal".
|
|
23
|
+
|
|
24
|
+
Definitions
|
|
25
|
+
- "__device__" attribute is attached by the device API to indicate the concrete device
|
|
26
|
+
an object is bound to (e.g., an SPU/TEE/PPU name). See mplang.v1._device.DEVICE_ATTR_NAME.
|
|
27
|
+
- pmask describes which parties currently hold/execute the value under the simp model.
|
|
28
|
+
|
|
29
|
+
Conventions
|
|
30
|
+
1) If an object has NO "__device__" attribute (i.e., it has not gone through the device API):
|
|
31
|
+
- It is a simp object, privately owned on the parties indicated by its pmask.
|
|
32
|
+
- When sealed via "seal(obj)", we infer target PPU device(s) from pmask:
|
|
33
|
+
• one-hot pmask {pi} → route to PPU(pi).
|
|
34
|
+
• multi-party pmask → fan out per party and seal independently to each party's PPU.
|
|
35
|
+
- Such objects CANNOT be passed to "srun"/"reveal" directly; seal first.
|
|
36
|
+
|
|
37
|
+
2) If an object HAS a "__device__" attribute:
|
|
38
|
+
- Its behavior follows the bound device (e.g., SPU/TEE/PPU) and its membership.
|
|
39
|
+
- "srun" executes on that device; "reveal" declassifies from that device to the requested parties.
|
|
40
|
+
- pmask must be consistent with the device membership during transitions; inconsistencies raise errors.
|
|
41
|
+
|
|
42
|
+
Notes
|
|
43
|
+
- "seal"/"seal_from" construct secret shares on the chosen secure device and attach the
|
|
44
|
+
"__device__" attribute to outputs. "srun"/"reveal" assume inputs are already sealed
|
|
45
|
+
(device-bound) and validate pmask ↔ device-membership consistency.
|
|
46
|
+
- These rules align with "design/simp_vs_device.md" and keep routing unambiguous.
|
|
47
|
+
|
|
48
|
+
Examples (obj state → interpretation/behavior)
|
|
49
|
+
- {pmask={A}, dev_attr=None}: simp plaintext on party A. "seal" routes to PPU(A);
|
|
50
|
+
must "seal" before "srun"/"reveal".
|
|
51
|
+
- {pmask={A,B}, dev_attr=None}: simp plaintext held by A and B. "seal" produces two
|
|
52
|
+
per-party sealed objects via PPU(A) and PPU(B), respectively.
|
|
53
|
+
- {pmask={A,B}, dev_attr="spu:spu0"}: device object on SPU(spu0) whose members are {A,B};
|
|
54
|
+
"srun" runs on spu0; "reveal(to={A})" reveals result to party A.
|
|
55
|
+
- {pmask={A}, dev_attr="ppu:A"}: device object on PPU(A); "reveal(to={A})" returns A's plaintext.
|
|
56
|
+
- {pmask=None, dev_attr=None}: dynamic pmask; "seal" is unsupported and will error.
|
|
57
|
+
- {pmask={A}, dev_attr="spu:spu0"} where A ∉ members(spu0): inconsistent; operations will error.
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
from collections.abc import Callable
|
|
61
|
+
from typing import Any
|
|
62
|
+
|
|
63
|
+
from mplang.v1 import _device
|
|
64
|
+
from mplang.v1.core import Mask, MPObject, Rank, psize
|
|
65
|
+
from mplang.v1.core.cluster import Device
|
|
66
|
+
from mplang.v1.core.context_mgr import cur_ctx
|
|
67
|
+
from mplang.v1.core.primitive import pconv
|
|
68
|
+
from mplang.v1.simp.api import set_mask
|
|
69
|
+
from mplang.v1.utils.func_utils import normalize_fn
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _determine_secure_device(*args: MPObject) -> Device:
|
|
73
|
+
"""Determine secure device from args, or find any available if no args."""
|
|
74
|
+
if not args:
|
|
75
|
+
# Find an available secure device (fallback when no args provided).
|
|
76
|
+
devices = cur_ctx().cluster_spec.get_devices_by_kind("SPU")
|
|
77
|
+
if devices:
|
|
78
|
+
return devices[0]
|
|
79
|
+
|
|
80
|
+
devices = cur_ctx().cluster_spec.get_devices_by_kind("TEE")
|
|
81
|
+
if devices:
|
|
82
|
+
return devices[0]
|
|
83
|
+
|
|
84
|
+
raise ValueError(
|
|
85
|
+
"No secure device (SPU or TEE) found in the cluster specification"
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
dev_names: list[str] = []
|
|
89
|
+
for arg in args:
|
|
90
|
+
if not _device.is_device_obj(arg):
|
|
91
|
+
raise ValueError(
|
|
92
|
+
"srun/reveal expect sealed inputs with a device attribute; "
|
|
93
|
+
f"got an unsealed object: {arg}. Please call seal()/seal_from() first."
|
|
94
|
+
)
|
|
95
|
+
dev_names.append(_device.get_dev_attr(arg))
|
|
96
|
+
|
|
97
|
+
if len(set(dev_names)) != 1:
|
|
98
|
+
raise ValueError(f"Ambiguous secure devices among arguments: {dev_names}")
|
|
99
|
+
|
|
100
|
+
dev_name = dev_names[0]
|
|
101
|
+
|
|
102
|
+
cluster_spec = cur_ctx().cluster_spec
|
|
103
|
+
assert dev_name in cluster_spec.devices
|
|
104
|
+
return cluster_spec.devices[dev_name]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _get_ppu_from_rank(rank: Rank) -> Device:
|
|
108
|
+
"""Get the PPU device for a specific rank."""
|
|
109
|
+
for dev in cur_ctx().cluster_spec.get_devices_by_kind("PPU"):
|
|
110
|
+
assert len(dev.members) == 1, "Expected single member PPU devices."
|
|
111
|
+
if dev.members[0].rank == rank:
|
|
112
|
+
return dev
|
|
113
|
+
raise ValueError(f"No PPU device found for rank {rank}.")
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def seal(obj: MPObject) -> list[MPObject] | MPObject:
|
|
117
|
+
"""Seal a simp object to a secure device.
|
|
118
|
+
|
|
119
|
+
Args:
|
|
120
|
+
obj: The simp object to seal.
|
|
121
|
+
|
|
122
|
+
Returns:
|
|
123
|
+
The sealed object(s). If the input is a plaintext simp object with a multi-party
|
|
124
|
+
mask, a list of sealed objects (one per party) is returned. Otherwise, a
|
|
125
|
+
single sealed object is returned.
|
|
126
|
+
"""
|
|
127
|
+
|
|
128
|
+
if obj.pmask is None:
|
|
129
|
+
raise ValueError("Seal does not support dynamic masks.")
|
|
130
|
+
|
|
131
|
+
if _device.is_device_obj(obj):
|
|
132
|
+
sdev = _determine_secure_device()
|
|
133
|
+
return _device._d2d(sdev.name, obj)
|
|
134
|
+
else:
|
|
135
|
+
# it's a normal plaintext simp object, treat as a list of PPU objects
|
|
136
|
+
rets: list[MPObject] = []
|
|
137
|
+
for rank in obj.pmask:
|
|
138
|
+
ppu_obj = set_mask(obj, Mask.from_ranks([rank]))
|
|
139
|
+
_device.set_dev_attr(ppu_obj, _get_ppu_from_rank(rank).name)
|
|
140
|
+
sealed = seal(ppu_obj)
|
|
141
|
+
assert isinstance(sealed, MPObject), (
|
|
142
|
+
"Expected single sealed object per rank"
|
|
143
|
+
)
|
|
144
|
+
rets.append(sealed)
|
|
145
|
+
return rets
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def seal_from(from_rank: Rank, obj: MPObject) -> MPObject:
|
|
149
|
+
"""Seal a simp object from a specific party to its PPU.
|
|
150
|
+
|
|
151
|
+
Args:
|
|
152
|
+
from_rank: The party rank from which to seal the object.
|
|
153
|
+
obj: The simp object to seal.
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
The sealed object.
|
|
157
|
+
"""
|
|
158
|
+
obj = set_mask(obj, Mask.from_ranks([from_rank]))
|
|
159
|
+
out = seal(obj)
|
|
160
|
+
assert isinstance(out, list), "seal_from should return a list of sealed objects."
|
|
161
|
+
assert len(out) == 1, "seal_from should return a single sealed object."
|
|
162
|
+
return out[0]
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
# reveal :: s a -> m a
|
|
166
|
+
def reveal(obj: MPObject, to_mask: Mask | None = None) -> MPObject:
|
|
167
|
+
"""Reveal a sealed object to pmask'ed parties."""
|
|
168
|
+
assert isinstance(obj, MPObject), "reveal expects an MPObject."
|
|
169
|
+
|
|
170
|
+
if not _device.is_device_obj(obj):
|
|
171
|
+
raise ValueError(f"reveal does not support non-device object={obj}.")
|
|
172
|
+
|
|
173
|
+
if to_mask is None:
|
|
174
|
+
ranks = []
|
|
175
|
+
for rank in range(psize()):
|
|
176
|
+
try:
|
|
177
|
+
_get_ppu_from_rank(rank)
|
|
178
|
+
except ValueError:
|
|
179
|
+
continue
|
|
180
|
+
ranks.append(rank)
|
|
181
|
+
to_mask = Mask.from_ranks(ranks)
|
|
182
|
+
rets = [reveal_to(rank, obj) for rank in to_mask]
|
|
183
|
+
return pconv(rets)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def reveal_to(to_rank: Rank, obj: MPObject) -> MPObject:
|
|
187
|
+
"""Reveal a sealed object to a specific party."""
|
|
188
|
+
if not _device.is_device_obj(obj):
|
|
189
|
+
raise ValueError("reveal_to expects a device object (sealed value).")
|
|
190
|
+
|
|
191
|
+
to_dev = _get_ppu_from_rank(to_rank)
|
|
192
|
+
return _device._d2d(to_dev.name, obj)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def srun(fe_type: str, pyfn: Callable, *args: Any, **kwargs: Any) -> Any:
|
|
196
|
+
"""Run a function on sealed values securely.
|
|
197
|
+
|
|
198
|
+
This function executes a computation on sealed (secret-shared) values
|
|
199
|
+
using secure multi-party computation (MPC).
|
|
200
|
+
|
|
201
|
+
Args:
|
|
202
|
+
fe_type: The front-end type, e.g., "jax"
|
|
203
|
+
pyfn: A function to run on sealed values
|
|
204
|
+
*args: Positional arguments (sealed values)
|
|
205
|
+
**kwargs: Keyword arguments (sealed values)
|
|
206
|
+
|
|
207
|
+
Returns:
|
|
208
|
+
The result of the computation, still in sealed form
|
|
209
|
+
"""
|
|
210
|
+
|
|
211
|
+
fn_flat, args_flat = normalize_fn(
|
|
212
|
+
pyfn, args, kwargs, lambda x: isinstance(x, MPObject)
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
dev_info = _determine_secure_device(*args_flat)
|
|
216
|
+
|
|
217
|
+
dev_kind = dev_info.kind.upper()
|
|
218
|
+
if dev_kind in {"SPU", "TEE"}:
|
|
219
|
+
return _device.device(dev_info.name, fe_type=fe_type)(fn_flat)(args_flat)
|
|
220
|
+
else:
|
|
221
|
+
raise ValueError(f"Unsupported secure device kind: {dev_kind}")
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def srun_jax(jax_fn: Callable, *args: Any, **kwargs: Any) -> Any:
|
|
225
|
+
"""Run a jax function on sealed values securely.
|
|
226
|
+
|
|
227
|
+
This function executes a JAX computation on sealed (secret-shared) values
|
|
228
|
+
using secure multi-party computation (MPC).
|
|
229
|
+
|
|
230
|
+
Args:
|
|
231
|
+
jax_fn: A JAX function to run on sealed values
|
|
232
|
+
*args: Positional arguments (sealed values)
|
|
233
|
+
**kwargs: Keyword arguments (sealed values)
|
|
234
|
+
|
|
235
|
+
Returns:
|
|
236
|
+
The result of the computation, still in sealed form
|
|
237
|
+
"""
|
|
238
|
+
return srun("jax", jax_fn, *args, **kwargs)
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import io
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
import pyarrow as pa
|
|
21
|
+
import pyarrow.csv as pa_csv
|
|
22
|
+
import pyarrow.orc as pa_orc
|
|
23
|
+
import pyarrow.parquet as pa_pq
|
|
24
|
+
|
|
25
|
+
from mplang.v1.core.table import TableLike
|
|
26
|
+
|
|
27
|
+
__all__ = ["decode_table", "encode_table", "read_table", "write_table"]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _parse_kwargs(kwargs: dict[str, Any], keys: list[str]) -> dict[str, Any] | None:
|
|
31
|
+
if not kwargs:
|
|
32
|
+
return None
|
|
33
|
+
|
|
34
|
+
return {key: kwargs[key] for key in keys if key in kwargs}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
_csv_read_option_keys = [
|
|
38
|
+
"skip_rows",
|
|
39
|
+
"skip_rows_after_names",
|
|
40
|
+
"column_names",
|
|
41
|
+
"autogenerate_column_names",
|
|
42
|
+
"encoding",
|
|
43
|
+
]
|
|
44
|
+
_csv_parse_option_keys = [
|
|
45
|
+
"delimiter",
|
|
46
|
+
"quote_char",
|
|
47
|
+
"double_quote",
|
|
48
|
+
"escape_char",
|
|
49
|
+
"newlines_in_values",
|
|
50
|
+
"ignore_empty_lines",
|
|
51
|
+
]
|
|
52
|
+
_csv_convert_option_keys = [
|
|
53
|
+
"check_utf8",
|
|
54
|
+
"column_types",
|
|
55
|
+
"null_values",
|
|
56
|
+
"true_values",
|
|
57
|
+
"false_values",
|
|
58
|
+
"decimal_point",
|
|
59
|
+
"strings_can_be_null",
|
|
60
|
+
"quoted_strings_can_be_null",
|
|
61
|
+
"include_columns",
|
|
62
|
+
"include_missing_columns",
|
|
63
|
+
"auto_dict_encode",
|
|
64
|
+
"auto_dict_max_cardinality",
|
|
65
|
+
"timestamp_parsers",
|
|
66
|
+
]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def read_table(
|
|
70
|
+
source: Any,
|
|
71
|
+
format: str = "parquet",
|
|
72
|
+
columns: list[str] | None = None,
|
|
73
|
+
**kwargs: Any,
|
|
74
|
+
) -> pa.Table:
|
|
75
|
+
"""Read data from a file and return a PyArrow table.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
source: The source to read data from (file path, file-like object, etc.)
|
|
79
|
+
format: The format of the data source ("parquet", "csv", or "orc")
|
|
80
|
+
columns: List of column names to read (None means all columns)
|
|
81
|
+
**kwargs: Additional keyword arguments passed to the underlying reader
|
|
82
|
+
|
|
83
|
+
Returns:
|
|
84
|
+
A PyArrow Table containing the data from the source
|
|
85
|
+
|
|
86
|
+
Raises:
|
|
87
|
+
ValueError: If an unsupported format is specified
|
|
88
|
+
"""
|
|
89
|
+
match format:
|
|
90
|
+
case "csv":
|
|
91
|
+
if columns:
|
|
92
|
+
kwargs["include_columns"] = columns
|
|
93
|
+
read_args = _parse_kwargs(kwargs, _csv_read_option_keys)
|
|
94
|
+
parse_args = _parse_kwargs(kwargs, _csv_parse_option_keys)
|
|
95
|
+
convert_args = _parse_kwargs(kwargs, _csv_convert_option_keys)
|
|
96
|
+
|
|
97
|
+
read_opts = pa_csv.ReadOptions(**read_args) if read_args else None
|
|
98
|
+
parse_opts = pa_csv.ParseOptions(**parse_args) if parse_args else None
|
|
99
|
+
conv_opts = pa_csv.ConvertOptions(**convert_args) if convert_args else None
|
|
100
|
+
return pa_csv.read_csv(
|
|
101
|
+
source,
|
|
102
|
+
read_options=read_opts,
|
|
103
|
+
parse_options=parse_opts,
|
|
104
|
+
convert_options=conv_opts,
|
|
105
|
+
)
|
|
106
|
+
case "orc":
|
|
107
|
+
return pa_orc.read_table(source, columns=columns, **kwargs)
|
|
108
|
+
case "parquet":
|
|
109
|
+
return pa_pq.read_table(source, columns=columns, **kwargs)
|
|
110
|
+
case _:
|
|
111
|
+
raise ValueError(f"unsupported data format. {format}")
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def write_table(
|
|
115
|
+
data: TableLike,
|
|
116
|
+
where: Any,
|
|
117
|
+
format: str = "parquet",
|
|
118
|
+
**kwargs: Any,
|
|
119
|
+
) -> None:
|
|
120
|
+
"""Write a table-like object to a file in the specified format.
|
|
121
|
+
|
|
122
|
+
Args:
|
|
123
|
+
data: The table-like object to write (PyArrow Table or other compatible format)
|
|
124
|
+
where: The destination to write to (file path, file-like object, etc.)
|
|
125
|
+
format: The format to write the data in ("parquet", "csv", or "orc")
|
|
126
|
+
**kwargs: Additional keyword arguments passed to the underlying writer
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
None
|
|
130
|
+
|
|
131
|
+
Raises:
|
|
132
|
+
ValueError: If the table has no columns or an unsupported format is specified
|
|
133
|
+
"""
|
|
134
|
+
# Convert data to PyArrow Table if needed
|
|
135
|
+
table = data if isinstance(data, pa.Table) else pa.table(data)
|
|
136
|
+
if len(table.column_names) == 0:
|
|
137
|
+
raise ValueError("Cannot convert Table with no columns.")
|
|
138
|
+
|
|
139
|
+
match format:
|
|
140
|
+
case "csv":
|
|
141
|
+
options = pa_csv.WriteOptions(**kwargs) if kwargs else None
|
|
142
|
+
pa_csv.write_csv(table, where, write_options=options)
|
|
143
|
+
case "orc":
|
|
144
|
+
pa_orc.write_table(table, where, **kwargs)
|
|
145
|
+
case "parquet":
|
|
146
|
+
pa_pq.write_table(table, where, **kwargs)
|
|
147
|
+
case _:
|
|
148
|
+
raise ValueError(f"unsupported data format. {format}")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def decode_table(
|
|
152
|
+
data: bytes,
|
|
153
|
+
format: str = "parquet",
|
|
154
|
+
columns: list[str] | None = None,
|
|
155
|
+
**kwargs: Any,
|
|
156
|
+
) -> pa.Table:
|
|
157
|
+
"""Decode a bytes object into a PyArrow table.
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
data: The bytes object containing the encoded table data
|
|
161
|
+
format: The format of the encoded data ("parquet", "csv", or "orc")
|
|
162
|
+
columns: List of column names to decode (None means all columns)
|
|
163
|
+
**kwargs: Additional keyword arguments passed to the underlying reader
|
|
164
|
+
|
|
165
|
+
Returns:
|
|
166
|
+
A PyArrow Table decoded from the bytes data
|
|
167
|
+
"""
|
|
168
|
+
buffer = io.BytesIO(data)
|
|
169
|
+
return read_table(buffer, format=format, columns=columns, **kwargs)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def encode_table(data: TableLike, format: str = "parquet", **kwargs: Any) -> bytes:
|
|
173
|
+
"""Encode a table-like object into bytes.
|
|
174
|
+
|
|
175
|
+
Args:
|
|
176
|
+
data: The table-like object to encode (PyArrow Table or other compatible format)
|
|
177
|
+
format: The format to encode the data in ("parquet", "csv", or "orc")
|
|
178
|
+
**kwargs: Additional keyword arguments passed to the underlying writer
|
|
179
|
+
|
|
180
|
+
Returns:
|
|
181
|
+
Bytes object containing the encoded table data
|
|
182
|
+
"""
|
|
183
|
+
buffer = io.BytesIO()
|
|
184
|
+
write_table(data, buffer, format, **kwargs)
|
|
185
|
+
return buffer.getvalue()
|