mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
|
@@ -18,8 +18,8 @@ from __future__ import annotations
|
|
|
18
18
|
|
|
19
19
|
from typing import Any, ClassVar, Self
|
|
20
20
|
|
|
21
|
-
from mplang.
|
|
22
|
-
from mplang.
|
|
21
|
+
from mplang.edsl import serde
|
|
22
|
+
from mplang.runtime.value import Value
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
@serde.register_class
|
|
@@ -17,9 +17,9 @@
|
|
|
17
17
|
Provides Worker-side state and ops for the simp dialect.
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
|
-
from mplang.
|
|
21
|
-
from mplang.
|
|
22
|
-
from mplang.
|
|
20
|
+
from mplang.backends.simp_worker.mem import LocalMesh, ThreadCommunicator
|
|
21
|
+
from mplang.backends.simp_worker.ops import WORKER_HANDLERS
|
|
22
|
+
from mplang.backends.simp_worker.state import SimpWorker
|
|
23
23
|
|
|
24
24
|
__all__ = [
|
|
25
25
|
"WORKER_HANDLERS",
|
|
@@ -21,7 +21,7 @@ This module contains:
|
|
|
21
21
|
|
|
22
22
|
Usage:
|
|
23
23
|
# Start a worker server
|
|
24
|
-
from mplang.
|
|
24
|
+
from mplang.backends.simp_http_worker import create_worker_app
|
|
25
25
|
import uvicorn
|
|
26
26
|
|
|
27
27
|
app = create_worker_app(rank=0, world_size=3, endpoints=[...])
|
|
@@ -47,16 +47,16 @@ import httpx
|
|
|
47
47
|
from fastapi import FastAPI, HTTPException
|
|
48
48
|
from pydantic import BaseModel
|
|
49
49
|
|
|
50
|
-
from mplang.
|
|
51
|
-
from mplang.
|
|
50
|
+
from mplang.backends import spu_impl as _spu_impl # noqa: F401
|
|
51
|
+
from mplang.backends import tensor_impl as _tensor_impl # noqa: F401
|
|
52
52
|
|
|
53
53
|
# Register operation implementations (side-effect imports)
|
|
54
|
-
from mplang.
|
|
55
|
-
from mplang.
|
|
56
|
-
from mplang.
|
|
57
|
-
from mplang.
|
|
58
|
-
from mplang.
|
|
59
|
-
from mplang.
|
|
54
|
+
from mplang.backends.simp_worker import SimpWorker
|
|
55
|
+
from mplang.backends.simp_worker import ops as _simp_worker_ops # noqa: F401
|
|
56
|
+
from mplang.edsl import serde
|
|
57
|
+
from mplang.edsl.graph import Graph
|
|
58
|
+
from mplang.runtime.interpreter import ExecutionTracer, Interpreter
|
|
59
|
+
from mplang.runtime.object_store import ObjectStore
|
|
60
60
|
|
|
61
61
|
logger = logging.getLogger(__name__)
|
|
62
62
|
|
|
@@ -250,7 +250,7 @@ def create_worker_app(
|
|
|
250
250
|
from collections.abc import Callable
|
|
251
251
|
from typing import cast
|
|
252
252
|
|
|
253
|
-
from mplang.
|
|
253
|
+
from mplang.backends.simp_worker.ops import WORKER_HANDLERS
|
|
254
254
|
|
|
255
255
|
# func_impl is already imported at module level for side-effects
|
|
256
256
|
handlers: dict[str, Callable[..., Any]] = {**WORKER_HANDLERS} # type: ignore[dict-item]
|
|
@@ -53,7 +53,7 @@ class ThreadCommunicator:
|
|
|
53
53
|
def send(self, to: int, key: str, data: Any) -> None:
|
|
54
54
|
assert 0 <= to < self.world_size
|
|
55
55
|
if self.use_serde:
|
|
56
|
-
from mplang.
|
|
56
|
+
from mplang.edsl import serde
|
|
57
57
|
|
|
58
58
|
data = serde.loads(serde.dumps(data))
|
|
59
59
|
self.peers[to]._on_receive(self.rank, key, data)
|
|
@@ -22,9 +22,9 @@ from __future__ import annotations
|
|
|
22
22
|
|
|
23
23
|
from typing import Any
|
|
24
24
|
|
|
25
|
-
from mplang.
|
|
26
|
-
from mplang.
|
|
27
|
-
from mplang.
|
|
25
|
+
from mplang.dialects import simp
|
|
26
|
+
from mplang.edsl.graph import Operation
|
|
27
|
+
from mplang.runtime.interpreter import Interpreter
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
def _ensure_worker_context(interpreter: Any, op_name: str) -> Any:
|
|
@@ -111,7 +111,7 @@ def _uniform_cond_worker_impl(
|
|
|
111
111
|
interpreter: Interpreter, op: Operation, pred: Any, *args: Any
|
|
112
112
|
) -> Any:
|
|
113
113
|
"""Worker implementation of simp.uniform_cond."""
|
|
114
|
-
from mplang.
|
|
114
|
+
from mplang.backends.tensor_impl import TensorValue
|
|
115
115
|
|
|
116
116
|
if op.attrs.get("verify_uniform", True):
|
|
117
117
|
pass # TODO: Implement AllReduce verification
|
|
@@ -128,7 +128,7 @@ def _uniform_cond_worker_impl(
|
|
|
128
128
|
|
|
129
129
|
def _while_loop_worker_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
|
|
130
130
|
"""Worker implementation of simp.while_loop."""
|
|
131
|
-
from mplang.
|
|
131
|
+
from mplang.backends.tensor_impl import TensorValue
|
|
132
132
|
|
|
133
133
|
cond_graph = op.regions[0]
|
|
134
134
|
body_graph = op.regions[1]
|
|
@@ -18,10 +18,8 @@ from __future__ import annotations
|
|
|
18
18
|
|
|
19
19
|
from typing import Any
|
|
20
20
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
from mplang.v2.runtime.dialect_state import DialectState
|
|
24
|
-
from mplang.v2.runtime.object_store import ObjectStore
|
|
21
|
+
from mplang.runtime.dialect_state import DialectState
|
|
22
|
+
from mplang.runtime.object_store import ObjectStore
|
|
25
23
|
|
|
26
24
|
|
|
27
25
|
class SimpWorker(DialectState):
|
|
@@ -26,13 +26,13 @@ import numpy as np
|
|
|
26
26
|
import spu.api as spu_api
|
|
27
27
|
import spu.libspu as libspu
|
|
28
28
|
|
|
29
|
-
from mplang.
|
|
30
|
-
from mplang.
|
|
31
|
-
from mplang.
|
|
32
|
-
from mplang.
|
|
33
|
-
from mplang.
|
|
34
|
-
from mplang.
|
|
35
|
-
from mplang.
|
|
29
|
+
from mplang.backends.spu_state import SPUState
|
|
30
|
+
from mplang.backends.tensor_impl import TensorValue
|
|
31
|
+
from mplang.dialects import spu
|
|
32
|
+
from mplang.edsl import serde
|
|
33
|
+
from mplang.edsl.graph import Operation
|
|
34
|
+
from mplang.runtime.interpreter import Interpreter
|
|
35
|
+
from mplang.runtime.value import WrapValue
|
|
36
36
|
|
|
37
37
|
# =============================================================================
|
|
38
38
|
# SPU Share Wrapper
|
|
@@ -160,7 +160,7 @@ def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
|
|
|
160
160
|
The SPU config must contain parties info to correctly map global rank
|
|
161
161
|
to local SPU rank and determine SPU world size.
|
|
162
162
|
"""
|
|
163
|
-
from mplang.
|
|
163
|
+
from mplang.backends.simp_worker.state import SimpWorker
|
|
164
164
|
|
|
165
165
|
# Get SPU config from attrs (passed through from run_jax)
|
|
166
166
|
config: spu.SPUConfig = op.attrs["config"]
|
|
@@ -25,10 +25,10 @@ from typing import TYPE_CHECKING, Any
|
|
|
25
25
|
import spu.api as spu_api
|
|
26
26
|
import spu.libspu as libspu
|
|
27
27
|
|
|
28
|
-
from mplang.
|
|
28
|
+
from mplang.runtime.dialect_state import DialectState
|
|
29
29
|
|
|
30
30
|
if TYPE_CHECKING:
|
|
31
|
-
from mplang.
|
|
31
|
+
from mplang.dialects import spu
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
class SPUState(DialectState):
|
|
@@ -74,7 +74,7 @@ class SPUState(DialectState):
|
|
|
74
74
|
Returns:
|
|
75
75
|
A tuple of (Runtime, Io) for this party.
|
|
76
76
|
"""
|
|
77
|
-
from mplang.
|
|
77
|
+
from mplang.backends.spu_impl import to_runtime_config
|
|
78
78
|
|
|
79
79
|
# Determine link mode
|
|
80
80
|
if communicator is not None:
|
|
@@ -143,7 +143,7 @@ class SPUState(DialectState):
|
|
|
143
143
|
Returns:
|
|
144
144
|
libspu link context using BaseChannel adapters
|
|
145
145
|
"""
|
|
146
|
-
from mplang.
|
|
146
|
+
from mplang.backends.channel import BaseChannel
|
|
147
147
|
|
|
148
148
|
# Get this worker's global rank
|
|
149
149
|
global_rank = parties[local_rank]
|
|
@@ -18,9 +18,9 @@ from __future__ import annotations
|
|
|
18
18
|
|
|
19
19
|
from typing import Any
|
|
20
20
|
|
|
21
|
-
from mplang.
|
|
22
|
-
from mplang.
|
|
23
|
-
from mplang.
|
|
21
|
+
from mplang.dialects import store
|
|
22
|
+
from mplang.edsl.graph import Operation
|
|
23
|
+
from mplang.runtime.interpreter import Interpreter
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
def _get_uri(uri_base: str) -> str:
|
|
@@ -28,13 +28,13 @@ import duckdb
|
|
|
28
28
|
import pandas as pd
|
|
29
29
|
import pyarrow as pa
|
|
30
30
|
|
|
31
|
-
import mplang.
|
|
32
|
-
from mplang.
|
|
33
|
-
from mplang.
|
|
34
|
-
from mplang.
|
|
35
|
-
from mplang.
|
|
36
|
-
from mplang.
|
|
37
|
-
from mplang.
|
|
31
|
+
import mplang.edsl.typing as elt
|
|
32
|
+
from mplang.backends.tensor_impl import TensorValue
|
|
33
|
+
from mplang.dialects import table
|
|
34
|
+
from mplang.edsl import serde
|
|
35
|
+
from mplang.edsl.graph import Operation
|
|
36
|
+
from mplang.runtime.interpreter import Interpreter
|
|
37
|
+
from mplang.runtime.value import WrapValue
|
|
38
38
|
|
|
39
39
|
|
|
40
40
|
class BatchReader(ABC):
|
|
@@ -631,7 +631,7 @@ def table2tensor_impl(interpreter: Interpreter, op: Operation, table_val: Any) -
|
|
|
631
631
|
|
|
632
632
|
Returns TensorValue if tensor_impl is available, otherwise raw np.ndarray.
|
|
633
633
|
"""
|
|
634
|
-
from mplang.
|
|
634
|
+
from mplang.backends.tensor_impl import TensorValue
|
|
635
635
|
|
|
636
636
|
tbl = _unwrap(table_val)
|
|
637
637
|
df = tbl.to_pandas()
|
|
@@ -31,14 +31,14 @@ from typing import TYPE_CHECKING, Any, ClassVar
|
|
|
31
31
|
|
|
32
32
|
import numpy as np
|
|
33
33
|
|
|
34
|
-
from mplang.
|
|
35
|
-
from mplang.
|
|
36
|
-
from mplang.
|
|
37
|
-
from mplang.
|
|
34
|
+
from mplang.backends.crypto_impl import BytesValue, PublicKeyValue
|
|
35
|
+
from mplang.dialects import tee
|
|
36
|
+
from mplang.edsl import serde
|
|
37
|
+
from mplang.runtime.value import Value
|
|
38
38
|
|
|
39
39
|
if TYPE_CHECKING:
|
|
40
|
-
from mplang.
|
|
41
|
-
from mplang.
|
|
40
|
+
from mplang.edsl.graph import Operation
|
|
41
|
+
from mplang.runtime.interpreter import Interpreter
|
|
42
42
|
|
|
43
43
|
|
|
44
44
|
# ==============================================================================
|
|
@@ -32,12 +32,12 @@ import numpy as np
|
|
|
32
32
|
from jax._src import compiler
|
|
33
33
|
from numpy.typing import ArrayLike
|
|
34
34
|
|
|
35
|
-
import mplang.
|
|
36
|
-
from mplang.
|
|
37
|
-
from mplang.
|
|
38
|
-
from mplang.
|
|
39
|
-
from mplang.
|
|
40
|
-
from mplang.
|
|
35
|
+
import mplang.edsl.typing as elt
|
|
36
|
+
from mplang.dialects import dtypes, tensor
|
|
37
|
+
from mplang.edsl import serde
|
|
38
|
+
from mplang.edsl.graph import Operation
|
|
39
|
+
from mplang.runtime.interpreter import Interpreter
|
|
40
|
+
from mplang.runtime.value import Value, WrapValue
|
|
41
41
|
|
|
42
42
|
# =============================================================================
|
|
43
43
|
# TensorValue Wrapper
|
mplang/{v2/cli.py → cli.py}
RENAMED
|
@@ -18,19 +18,19 @@ Command-line interface for MPLang2 clusters and jobs.
|
|
|
18
18
|
|
|
19
19
|
Examples:
|
|
20
20
|
# Generate a cluster config file
|
|
21
|
-
python -m mplang.
|
|
21
|
+
python -m mplang.cli config gen -w 3 -p 8100 -o cluster.yaml
|
|
22
22
|
|
|
23
23
|
# Start a single worker (production usage)
|
|
24
|
-
python -m mplang.
|
|
24
|
+
python -m mplang.cli worker --rank 0 -c cluster.yaml
|
|
25
25
|
|
|
26
26
|
# Start 3 local workers (development usage)
|
|
27
|
-
python -m mplang.
|
|
27
|
+
python -m mplang.cli up -c cluster.yaml
|
|
28
28
|
|
|
29
29
|
# Check cluster status
|
|
30
|
-
python -m mplang.
|
|
30
|
+
python -m mplang.cli status -c cluster.yaml
|
|
31
31
|
|
|
32
32
|
# Run a job
|
|
33
|
-
python -m mplang.
|
|
33
|
+
python -m mplang.cli run -c cluster.yaml -f my_job.py
|
|
34
34
|
"""
|
|
35
35
|
|
|
36
36
|
import argparse
|
|
@@ -62,7 +62,7 @@ def run_worker(
|
|
|
62
62
|
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
|
63
63
|
signal.signal(signal.SIGTERM, signal.SIG_DFL)
|
|
64
64
|
|
|
65
|
-
from mplang.
|
|
65
|
+
from mplang.backends.simp_worker.http import create_worker_app
|
|
66
66
|
|
|
67
67
|
app = create_worker_app(rank, world_size, endpoints, spu_endpoints)
|
|
68
68
|
|
|
@@ -323,9 +323,9 @@ def parse_spu_endpoints(
|
|
|
323
323
|
|
|
324
324
|
def cmd_run(args: argparse.Namespace) -> None:
|
|
325
325
|
"""Run a user job via HTTP cluster or local simulator."""
|
|
326
|
-
from mplang
|
|
327
|
-
from mplang.
|
|
328
|
-
from mplang.
|
|
326
|
+
from mplang import make_driver, make_simulator
|
|
327
|
+
from mplang.edsl.context import pop_context, push_context
|
|
328
|
+
from mplang.libs.device import ClusterSpec
|
|
329
329
|
|
|
330
330
|
cluster: ClusterSpec
|
|
331
331
|
|
|
@@ -26,7 +26,7 @@ First, generate a `cluster.yaml` file. This defines the topology of your MPLang
|
|
|
26
26
|
|
|
27
27
|
```bash
|
|
28
28
|
# Generate a config for 2 workers starting at port 8100
|
|
29
|
-
python -m mplang.
|
|
29
|
+
python -m mplang.cli config gen -w 2 -p 8100 -o cluster.yaml
|
|
30
30
|
```
|
|
31
31
|
|
|
32
32
|
### 2. Start the Cluster (Terminal 1)
|
|
@@ -35,7 +35,7 @@ In your **first terminal**, start the cluster using the `up` command. This will
|
|
|
35
35
|
|
|
36
36
|
```bash
|
|
37
37
|
# Terminal 1
|
|
38
|
-
python -m mplang.
|
|
38
|
+
python -m mplang.cli up -c cluster.yaml
|
|
39
39
|
```
|
|
40
40
|
|
|
41
41
|
You should see logs indicating that workers have started (e.g., `[Worker 0] INFO: Started server process...`). Keep this terminal open.
|
|
@@ -46,7 +46,7 @@ Create a Python script (e.g., `my_job.py`) that defines the computation you want
|
|
|
46
46
|
|
|
47
47
|
```python
|
|
48
48
|
# my_job.py
|
|
49
|
-
from mplang.
|
|
49
|
+
from mplang.dialects import simp
|
|
50
50
|
import numpy as np
|
|
51
51
|
|
|
52
52
|
def main():
|
|
@@ -71,7 +71,7 @@ In your **second terminal**, use the `run` command to submit the script to the r
|
|
|
71
71
|
|
|
72
72
|
```bash
|
|
73
73
|
# Terminal 2
|
|
74
|
-
python -m mplang.
|
|
74
|
+
python -m mplang.cli run -c cluster.yaml -f my_job.py
|
|
75
75
|
```
|
|
76
76
|
|
|
77
77
|
The CLI will connect to the driver, which orchestrates the execution across the workers.
|
|
@@ -82,7 +82,7 @@ You can check the health and latency of your workers at any time.
|
|
|
82
82
|
|
|
83
83
|
```bash
|
|
84
84
|
# Terminal 2
|
|
85
|
-
python -m mplang.
|
|
85
|
+
python -m mplang.cli status -c cluster.yaml
|
|
86
86
|
```
|
|
87
87
|
|
|
88
88
|
**Output Example:**
|
|
@@ -99,7 +99,7 @@ To debug or verify intermediate results, you can list the objects currently stor
|
|
|
99
99
|
|
|
100
100
|
```bash
|
|
101
101
|
# Terminal 2
|
|
102
|
-
python -m mplang.
|
|
102
|
+
python -m mplang.cli objects -c cluster.yaml
|
|
103
103
|
```
|
|
104
104
|
|
|
105
105
|
**Output Example:**
|
|
@@ -114,9 +114,9 @@ Rank | Endpoint | Count | Objects
|
|
|
114
114
|
|
|
115
115
|
| Command | Description | Usage |
|
|
116
116
|
| :--- | :--- | :--- |
|
|
117
|
-
| `config gen` | Generate cluster config file | `python -m mplang.
|
|
118
|
-
| `up` | Start all workers locally | `python -m mplang.
|
|
119
|
-
| `run` | Submit a job script | `python -m mplang.
|
|
120
|
-
| `status` | Check worker health | `python -m mplang.
|
|
121
|
-
| `objects` | List objects on workers | `python -m mplang.
|
|
122
|
-
| `worker` | Start a single worker (prod) | `python -m mplang.
|
|
117
|
+
| `config gen` | Generate cluster config file | `python -m mplang.cli config gen -w <workers> -o <file>` |
|
|
118
|
+
| `up` | Start all workers locally | `python -m mplang.cli up -c <config>` |
|
|
119
|
+
| `run` | Submit a job script | `python -m mplang.cli run -c <config> -f <script>` |
|
|
120
|
+
| `status` | Check worker health | `python -m mplang.cli status -c <config>` |
|
|
121
|
+
| `objects` | List objects on workers | `python -m mplang.cli objects -c <config>` |
|
|
122
|
+
| `worker` | Start a single worker (prod) | `python -m mplang.cli worker --rank <id> -c <config>` |
|
|
@@ -29,8 +29,8 @@ from __future__ import annotations
|
|
|
29
29
|
|
|
30
30
|
# Import dialects to trigger their type registrations
|
|
31
31
|
# Each dialect module registers its types at import time via _register_*_types()
|
|
32
|
-
from mplang.
|
|
33
|
-
from mplang.
|
|
34
|
-
from mplang.
|
|
35
|
-
from mplang.
|
|
36
|
-
from mplang.
|
|
32
|
+
from mplang.dialects import bfv as _bfv # noqa: F401
|
|
33
|
+
from mplang.dialects import crypto as _crypto # noqa: F401
|
|
34
|
+
from mplang.dialects import spu as _spu # noqa: F401
|
|
35
|
+
from mplang.dialects import store as _store # noqa: F401
|
|
36
|
+
from mplang.dialects import tee as _tee # noqa: F401
|
|
@@ -54,8 +54,8 @@ Architecture:
|
|
|
54
54
|
|
|
55
55
|
Example:
|
|
56
56
|
```python
|
|
57
|
-
from mplang.
|
|
58
|
-
import mplang.
|
|
57
|
+
from mplang.dialects import tensor, bfv
|
|
58
|
+
import mplang.edsl.typing as elt
|
|
59
59
|
import numpy as np
|
|
60
60
|
|
|
61
61
|
# 1. Setup
|
|
@@ -91,9 +91,9 @@ from __future__ import annotations
|
|
|
91
91
|
|
|
92
92
|
from typing import Any, ClassVar, Literal, cast
|
|
93
93
|
|
|
94
|
-
import mplang.
|
|
95
|
-
import mplang.
|
|
96
|
-
from mplang.
|
|
94
|
+
import mplang.edsl as el
|
|
95
|
+
import mplang.edsl.typing as elt
|
|
96
|
+
from mplang.edsl import serde
|
|
97
97
|
|
|
98
98
|
# ==============================================================================
|
|
99
99
|
# --- Type Definitions
|
|
@@ -369,7 +369,7 @@ def _batch_encode_trace(
|
|
|
369
369
|
encoder: el.Object,
|
|
370
370
|
key: el.Object,
|
|
371
371
|
) -> tuple[el.Object, ...]:
|
|
372
|
-
from mplang.
|
|
372
|
+
from mplang.edsl.tracer import TraceObject, Tracer
|
|
373
373
|
|
|
374
374
|
ctx = el.get_current_context()
|
|
375
375
|
if not isinstance(ctx, Tracer):
|
|
@@ -21,9 +21,9 @@ from __future__ import annotations
|
|
|
21
21
|
|
|
22
22
|
from typing import Any, ClassVar
|
|
23
23
|
|
|
24
|
-
import mplang.
|
|
25
|
-
import mplang.
|
|
26
|
-
from mplang.
|
|
24
|
+
import mplang.edsl as el
|
|
25
|
+
import mplang.edsl.typing as elt
|
|
26
|
+
from mplang.edsl import serde
|
|
27
27
|
|
|
28
28
|
# ==============================================================================
|
|
29
29
|
# --- Type Definitions
|
|
@@ -607,7 +607,7 @@ def random_tensor(shape: tuple[int, ...], dtype: elt.ScalarType) -> el.Object:
|
|
|
607
607
|
import math
|
|
608
608
|
from typing import cast
|
|
609
609
|
|
|
610
|
-
from mplang.
|
|
610
|
+
from mplang.dialects import dtypes, tensor
|
|
611
611
|
|
|
612
612
|
# Get byte size from numpy dtype
|
|
613
613
|
np_dtype = dtypes.to_numpy(dtype)
|
|
@@ -644,7 +644,7 @@ def random_bits(n: int) -> el.Object:
|
|
|
644
644
|
|
|
645
645
|
import jax.numpy as jnp
|
|
646
646
|
|
|
647
|
-
from mplang.
|
|
647
|
+
from mplang.dialects import tensor
|
|
648
648
|
|
|
649
649
|
# Generate enough bytes to cover n bits
|
|
650
650
|
num_bytes = (n + 7) // 8
|
|
@@ -18,7 +18,7 @@ This module provides bidirectional conversion between MPLang's type system
|
|
|
18
18
|
(ScalarType hierarchy) and external library types (NumPy, JAX, PyArrow, Pandas).
|
|
19
19
|
|
|
20
20
|
Usage:
|
|
21
|
-
from mplang.
|
|
21
|
+
from mplang.dialects import dtypes
|
|
22
22
|
|
|
23
23
|
# MPLang ScalarType → JAX/NumPy
|
|
24
24
|
jax_dtype = dtypes.to_jax(scalar_types.f32) # → jnp.float32
|
|
@@ -40,7 +40,7 @@ from typing import Any
|
|
|
40
40
|
import jax.numpy as jnp
|
|
41
41
|
import numpy as np
|
|
42
42
|
|
|
43
|
-
import mplang.
|
|
43
|
+
import mplang.edsl.typing as scalar_types
|
|
44
44
|
|
|
45
45
|
# ==============================================================================
|
|
46
46
|
# MPLang ScalarType → JAX/NumPy conversion
|
|
@@ -29,9 +29,9 @@ from typing import Any, cast
|
|
|
29
29
|
|
|
30
30
|
import jax.numpy as jnp
|
|
31
31
|
|
|
32
|
-
import mplang.
|
|
33
|
-
import mplang.
|
|
34
|
-
from mplang.
|
|
32
|
+
import mplang.edsl as el
|
|
33
|
+
import mplang.edsl.typing as elt
|
|
34
|
+
from mplang.dialects import tensor
|
|
35
35
|
|
|
36
36
|
# =============================================================================
|
|
37
37
|
# Primitives
|
|
@@ -24,8 +24,8 @@ from __future__ import annotations
|
|
|
24
24
|
from collections.abc import Callable
|
|
25
25
|
from typing import Any
|
|
26
26
|
|
|
27
|
-
import mplang.
|
|
28
|
-
import mplang.
|
|
27
|
+
import mplang.edsl as el
|
|
28
|
+
import mplang.edsl.typing as elt
|
|
29
29
|
|
|
30
30
|
func_def_p = el.Primitive[el.TraceObject]("func.func")
|
|
31
31
|
call_p = el.Primitive[Any]("func.call")
|
|
@@ -35,8 +35,8 @@ Architecture:
|
|
|
35
35
|
|
|
36
36
|
Example:
|
|
37
37
|
```python
|
|
38
|
-
from mplang.
|
|
39
|
-
import mplang.
|
|
38
|
+
from mplang.dialects import tensor, phe
|
|
39
|
+
import mplang.edsl.typing as elt
|
|
40
40
|
import numpy as np
|
|
41
41
|
|
|
42
42
|
# 1. Generate keys (cryptographic only)
|
|
@@ -75,9 +75,9 @@ from __future__ import annotations
|
|
|
75
75
|
from collections.abc import Callable
|
|
76
76
|
from typing import Any, NamedTuple
|
|
77
77
|
|
|
78
|
-
import mplang.
|
|
79
|
-
import mplang.
|
|
80
|
-
from mplang.
|
|
78
|
+
import mplang.edsl as el
|
|
79
|
+
import mplang.edsl.typing as elt
|
|
80
|
+
from mplang.dialects import tensor
|
|
81
81
|
|
|
82
82
|
# ==============================================================================
|
|
83
83
|
# --- Type Definitions
|
|
@@ -415,7 +415,7 @@ def create_encoder(
|
|
|
415
415
|
PHE encoder configured for the specified dtype
|
|
416
416
|
|
|
417
417
|
Example:
|
|
418
|
-
>>> import mplang.
|
|
418
|
+
>>> import mplang.edsl.typing as elt
|
|
419
419
|
>>>
|
|
420
420
|
>>> # Float encoder with 16-bit fractional precision
|
|
421
421
|
>>> encoder_f64 = phe.create_encoder(dtype=elt.f64, fxp_bits=16)
|
|
@@ -36,8 +36,8 @@ from typing import Any, cast
|
|
|
36
36
|
|
|
37
37
|
from jax.tree_util import tree_flatten, tree_unflatten
|
|
38
38
|
|
|
39
|
-
import mplang.
|
|
40
|
-
import mplang.
|
|
39
|
+
import mplang.edsl as el
|
|
40
|
+
import mplang.edsl.typing as elt
|
|
41
41
|
|
|
42
42
|
# ---------------------------------------------------------------------------
|
|
43
43
|
# Global configuration
|
|
@@ -809,7 +809,7 @@ def constant(parties: tuple[int, ...], data: Any) -> el.Object:
|
|
|
809
809
|
import jax.numpy as jnp
|
|
810
810
|
import numpy as np
|
|
811
811
|
|
|
812
|
-
from mplang.
|
|
812
|
+
from mplang.dialects import table, tensor
|
|
813
813
|
|
|
814
814
|
# 1. Scalars (int, float, bool, numpy scalars)
|
|
815
815
|
if isinstance(data, (int, float, bool, np.number, np.bool_)):
|
|
@@ -888,11 +888,11 @@ def make_simulator(
|
|
|
888
888
|
... result = my_func()
|
|
889
889
|
"""
|
|
890
890
|
if enable_profiling:
|
|
891
|
-
from mplang.
|
|
891
|
+
from mplang.edsl import registry
|
|
892
892
|
|
|
893
893
|
registry.enable_profiling()
|
|
894
894
|
|
|
895
|
-
from mplang.
|
|
895
|
+
from mplang.backends.simp_driver.mem import make_simulator as _make_sim
|
|
896
896
|
|
|
897
897
|
return _make_sim(
|
|
898
898
|
world_size, cluster_spec=cluster_spec, enable_tracing=enable_tracing
|
|
@@ -917,7 +917,7 @@ def make_driver(endpoints: list[str], *, cluster_spec: Any = None) -> Any:
|
|
|
917
917
|
>>> with interp:
|
|
918
918
|
... result = my_func()
|
|
919
919
|
"""
|
|
920
|
-
from mplang.
|
|
920
|
+
from mplang.backends.simp_driver.http import make_driver as _make_drv
|
|
921
921
|
|
|
922
922
|
return _make_drv(endpoints, cluster_spec=cluster_spec)
|
|
923
923
|
|
|
@@ -27,8 +27,8 @@ Concepts:
|
|
|
27
27
|
Example:
|
|
28
28
|
```python
|
|
29
29
|
import jax.numpy as jnp
|
|
30
|
-
from mplang.
|
|
31
|
-
import mplang.
|
|
30
|
+
from mplang.dialects import spu, tensor, simp
|
|
31
|
+
import mplang.edsl.typing as elt
|
|
32
32
|
|
|
33
33
|
# 0. Setup
|
|
34
34
|
spu_device = spu.SPUDevice(parties=(0, 1, 2))
|
|
@@ -83,11 +83,11 @@ import spu.utils.frontend as spu_fe
|
|
|
83
83
|
from jax import ShapeDtypeStruct
|
|
84
84
|
from jax.tree_util import tree_flatten, tree_unflatten
|
|
85
85
|
|
|
86
|
-
import mplang.
|
|
87
|
-
import mplang.
|
|
88
|
-
from mplang.
|
|
89
|
-
from mplang.
|
|
90
|
-
from mplang.
|
|
86
|
+
import mplang.edsl as el
|
|
87
|
+
import mplang.edsl.typing as elt
|
|
88
|
+
from mplang.dialects import dtypes
|
|
89
|
+
from mplang.edsl import serde
|
|
90
|
+
from mplang.utils import normalize_fn
|
|
91
91
|
|
|
92
92
|
# ==============================================================================
|
|
93
93
|
# --- Configuration
|
|
@@ -16,8 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
from __future__ import annotations
|
|
18
18
|
|
|
19
|
-
import mplang.
|
|
20
|
-
import mplang.
|
|
19
|
+
import mplang.edsl as el
|
|
20
|
+
import mplang.edsl.typing as elt
|
|
21
21
|
|
|
22
22
|
save_p: el.Primitive[el.Object] = el.Primitive("store.save")
|
|
23
23
|
load_p: el.Primitive[el.Object] = el.Primitive("store.load")
|
|
@@ -18,8 +18,8 @@ from __future__ import annotations
|
|
|
18
18
|
|
|
19
19
|
from typing import Any, cast
|
|
20
20
|
|
|
21
|
-
import mplang.
|
|
22
|
-
import mplang.
|
|
21
|
+
import mplang.edsl as el
|
|
22
|
+
import mplang.edsl.typing as elt
|
|
23
23
|
|
|
24
24
|
run_sql_p: el.Primitive[Any] = el.Primitive("table.run_sql")
|
|
25
25
|
table2tensor_p: el.Primitive[el.Object] = el.Primitive("table.table2tensor")
|
|
@@ -182,7 +182,7 @@ def _constant_ae(*, data: Any) -> elt.TableType:
|
|
|
182
182
|
import pandas as pd
|
|
183
183
|
import pyarrow as pa
|
|
184
184
|
|
|
185
|
-
from mplang.
|
|
185
|
+
from mplang.dialects import dtypes
|
|
186
186
|
|
|
187
187
|
# Handle PyArrow Table directly
|
|
188
188
|
if isinstance(data, pa.Table):
|