mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# Simp Dialect Backend Design
|
|
2
|
+
|
|
3
|
+
## Overview
|
|
4
|
+
|
|
5
|
+
The `simp` (Simple Multi-Party) dialect implements SPMD (Single Program Multiple Data) distributed execution. A single program is written once and executed across multiple parties, with the runtime handling distribution, communication, and synchronization.
|
|
6
|
+
|
|
7
|
+
## Why Two Implementations?
|
|
8
|
+
|
|
9
|
+
The simp dialect requires **two separate backend implementations** because the same primitives (`pcall`, `shuffle`, `converge`) have fundamentally different semantics depending on where they execute:
|
|
10
|
+
|
|
11
|
+
| Primitive | Driver (Host) | Worker |
|
|
12
|
+
|-----------|---------------|--------|
|
|
13
|
+
| `pcall` | Dispatch work to workers | Execute locally |
|
|
14
|
+
| `shuffle` | Route data between workers | Send/Receive via communicator |
|
|
15
|
+
| `converge` | Merge DriverVars | Pick non-null value |
|
|
16
|
+
|
|
17
|
+
This is the essence of SPMD: the Driver orchestrates, Workers execute.
|
|
18
|
+
|
|
19
|
+
## Architecture
|
|
20
|
+
|
|
21
|
+
```
|
|
22
|
+
┌─────────────────────────────────────────────────────────────────┐
|
|
23
|
+
│ dialects/simp.py │
|
|
24
|
+
│ (Primitive definitions) │
|
|
25
|
+
└─────────────────────────────────────────────────────────────────┘
|
|
26
|
+
│
|
|
27
|
+
┌─────────────┴─────────────┐
|
|
28
|
+
▼ ▼
|
|
29
|
+
┌───────────────────────────┐ ┌───────────────────────────┐
|
|
30
|
+
│ simp_driver/ │ │ simp_worker/ │
|
|
31
|
+
│ (Host/Driver side) │ │ (Worker side) │
|
|
32
|
+
├───────────────────────────┤ ├───────────────────────────┤
|
|
33
|
+
│ state.py SimpDriver │ │ state.py SimpWorker │
|
|
34
|
+
│ ops.py DRIVER_HANDLERS │ │ ops.py WORKER_HANDLERS│
|
|
35
|
+
│ values.py DriverVar │ │ │
|
|
36
|
+
│ mem.py SimpMemDriver │ │ mem.py LocalMesh │
|
|
37
|
+
│ http.py SimpHttpDriver │ │ http.py HTTP Server │
|
|
38
|
+
└───────────────────────────┘ └───────────────────────────┘
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
## Directory Structure
|
|
42
|
+
|
|
43
|
+
```
|
|
44
|
+
backends/
|
|
45
|
+
├── simp_driver/ # Driver/Host side
|
|
46
|
+
│ ├── __init__.py # Exports
|
|
47
|
+
│ ├── state.py # SimpDriver (abstract base)
|
|
48
|
+
│ ├── ops.py # DRIVER_HANDLERS
|
|
49
|
+
│ ├── values.py # DriverVar
|
|
50
|
+
│ ├── mem.py # MemCluster + SimpMemDriver + make_simulator
|
|
51
|
+
│ └── http.py # SimpHttpDriver + make_driver
|
|
52
|
+
│
|
|
53
|
+
├── simp_worker/ # Worker side
|
|
54
|
+
│ ├── __init__.py # Exports
|
|
55
|
+
│ ├── state.py # SimpWorker (DialectState)
|
|
56
|
+
│ ├── ops.py # WORKER_HANDLERS
|
|
57
|
+
│ ├── mem.py # LocalMesh + ThreadCommunicator
|
|
58
|
+
│ └── http.py # HTTP Worker Server
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
## Key Classes
|
|
62
|
+
|
|
63
|
+
### Driver Side
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
class SimpDriver(DialectState, ABC):
|
|
67
|
+
"""Abstract interface for simp drivers."""
|
|
68
|
+
dialect_name = "simp"
|
|
69
|
+
world_size: int
|
|
70
|
+
|
|
71
|
+
@abstractmethod
|
|
72
|
+
def submit(self, rank, graph, inputs, job_id=None) -> Future: ...
|
|
73
|
+
@abstractmethod
|
|
74
|
+
def fetch(self, rank, uri) -> Future: ...
|
|
75
|
+
@abstractmethod
|
|
76
|
+
def collect(self, futures) -> list: ...
|
|
77
|
+
|
|
78
|
+
class SimpMemDriver(SimpDriver):
|
|
79
|
+
"""In-memory IPC via ThreadPoolExecutor."""
|
|
80
|
+
|
|
81
|
+
class SimpHttpDriver(SimpDriver):
|
|
82
|
+
"""HTTP IPC via httpx."""
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
### Worker Side
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
class SimpWorker(DialectState):
|
|
89
|
+
"""Worker state with communicator and store."""
|
|
90
|
+
dialect_name = "simp"
|
|
91
|
+
rank: int
|
|
92
|
+
world_size: int
|
|
93
|
+
communicator: Any # ThreadCommunicator or HTTP client
|
|
94
|
+
store: ObjectStore
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
## IPC Symmetry
|
|
98
|
+
|
|
99
|
+
| IPC Type | Driver | Worker |
|
|
100
|
+
|----------|--------|--------|
|
|
101
|
+
| Memory | `simp_driver/mem.py` | `simp_worker/mem.py` |
|
|
102
|
+
| HTTP | `simp_driver/http.py` | `simp_worker/http.py` |
|
|
103
|
+
|
|
104
|
+
## Factory Functions
|
|
105
|
+
|
|
106
|
+
```python
|
|
107
|
+
# Create local simulator (memory IPC)
|
|
108
|
+
interp = simp.make_simulator(world_size=3)
|
|
109
|
+
|
|
110
|
+
# Create remote driver (HTTP IPC)
|
|
111
|
+
interp = simp.make_driver(["http://w1:8000", "http://w2:8000"])
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
## Data Flow
|
|
115
|
+
|
|
116
|
+
```
|
|
117
|
+
User Code
|
|
118
|
+
│
|
|
119
|
+
▼
|
|
120
|
+
simp.pcall(parties=(0,1), fn, args)
|
|
121
|
+
│
|
|
122
|
+
▼ (Driver Interpreter)
|
|
123
|
+
DRIVER_HANDLERS["simp.pcall"]
|
|
124
|
+
│
|
|
125
|
+
├─► driver.submit(rank=0, graph, inputs)
|
|
126
|
+
└─► driver.submit(rank=1, graph, inputs)
|
|
127
|
+
│
|
|
128
|
+
▼ (IPC: Memory or HTTP)
|
|
129
|
+
Worker Interpreters
|
|
130
|
+
│
|
|
131
|
+
▼
|
|
132
|
+
WORKER_HANDLERS["simp.pcall"]
|
|
133
|
+
│
|
|
134
|
+
▼
|
|
135
|
+
Local Execution
|
|
136
|
+
```
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Simp Driver package.
|
|
16
|
+
|
|
17
|
+
Provides Driver-side state, values, and ops for the simp dialect.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from mplang.v2.backends.simp_driver.http import SimpHttpDriver, make_driver
|
|
21
|
+
from mplang.v2.backends.simp_driver.mem import (
|
|
22
|
+
LocalCluster,
|
|
23
|
+
MemCluster,
|
|
24
|
+
SimpMemDriver,
|
|
25
|
+
make_simulator,
|
|
26
|
+
)
|
|
27
|
+
from mplang.v2.backends.simp_driver.ops import DRIVER_HANDLERS
|
|
28
|
+
from mplang.v2.backends.simp_driver.state import SimpDriver
|
|
29
|
+
from mplang.v2.backends.simp_driver.values import DriverVar
|
|
30
|
+
|
|
31
|
+
__all__ = [
|
|
32
|
+
"DRIVER_HANDLERS",
|
|
33
|
+
"DriverVar",
|
|
34
|
+
"LocalCluster",
|
|
35
|
+
"MemCluster",
|
|
36
|
+
"SimpDriver",
|
|
37
|
+
"SimpHttpDriver",
|
|
38
|
+
"SimpMemDriver",
|
|
39
|
+
"make_driver",
|
|
40
|
+
"make_simulator",
|
|
41
|
+
]
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Simp Driver HTTP IPC (SimpHttpDriver, make_driver)."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import concurrent.futures
|
|
20
|
+
import os
|
|
21
|
+
import pathlib
|
|
22
|
+
from typing import TYPE_CHECKING, Any
|
|
23
|
+
|
|
24
|
+
import httpx
|
|
25
|
+
|
|
26
|
+
from mplang.v2.backends.simp_driver.state import SimpDriver
|
|
27
|
+
from mplang.v2.edsl import serde
|
|
28
|
+
from mplang.v2.runtime.interpreter import Interpreter
|
|
29
|
+
from mplang.v2.runtime.object_store import ObjectStore
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from concurrent.futures import Future
|
|
33
|
+
|
|
34
|
+
from mplang.v2.edsl.graph import Graph
|
|
35
|
+
from mplang.v2.libs.device import ClusterSpec
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class SimpHttpDriver(SimpDriver):
    """Simp driver that dispatches work to remote workers over HTTP.

    A worker is addressed by its rank, which is the index of its URL in the
    ``endpoints`` list. ``submit`` and ``fetch`` run the blocking HTTP call on
    a thread pool and hand back a ``Future``; ``collect`` waits on a batch of
    such futures.
    """

    dialect_name: str = "simp"

    def __init__(
        self,
        endpoints: list[str],
        *,
        cluster_spec: ClusterSpec | None = None,
    ) -> None:
        """Create remote simp driver.

        Args:
            endpoints: Worker base URLs, one per rank. An entry without an
                ``http://`` or ``https://`` scheme is prefixed with ``http://``.
            cluster_spec: Optional cluster specification. When given, its
                ``cluster_id`` names the driver's directory under the data root.

        Raises:
            ValueError: If ``endpoints`` is empty.
        """
        # Fail early with a clear message. Without this check the thread pool
        # below rejects max_workers=0 with a ValueError that never mentions
        # endpoints, and only after the HTTP client has already been opened.
        if not endpoints:
            raise ValueError("endpoints must contain at least one worker URL")

        # Normalize endpoints so every entry carries an explicit scheme.
        self._endpoints = [
            ep if ep.startswith(("http://", "https://")) else f"http://{ep}"
            for ep in endpoints
        ]

        self._world_size = len(self._endpoints)
        self._cluster_spec = cluster_spec

        # Driver root: $MPLANG_DATA_ROOT (default ".mpl") / <cluster> / __driver__
        data_root = pathlib.Path(os.environ.get("MPLANG_DATA_ROOT", ".mpl"))
        if cluster_spec:
            self.driver_root = data_root / cluster_spec.cluster_id / "__driver__"
        else:
            self.driver_root = data_root / "__remote__" / "__driver__"

        # One shared HTTP client (timeouts disabled) and one pool thread per
        # worker, so a request to every rank can be in flight at once.
        self._client = httpx.Client(timeout=None)
        self._executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self._world_size
        )

    @property
    def world_size(self) -> int:
        """Number of worker endpoints this driver was created with."""
        return self._world_size

    def submit(
        self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
    ) -> Future[Any]:
        """Submit execution to remote worker via HTTP.

        Args:
            rank: Index of the target worker in the endpoint list.
            graph: Graph to execute; serialized before sending.
            inputs: Inputs for the graph; serialized before sending.
            job_id: Optional job identifier forwarded to the worker.

        Returns:
            Future resolving to the deserialized ``result`` of the response.
        """
        return self._executor.submit(self._do_request, rank, graph, inputs, job_id)

    def collect(self, futures: list[Future[Any]]) -> list[Any]:
        """Wait for every future in order and return their results as a list.

        The first failed future re-raises its exception from ``result()``.
        """
        return [f.result() for f in futures]

    def fetch(self, rank: int, uri: str) -> Future[Any]:
        """Fetch the object stored under ``uri`` on worker ``rank``.

        Returns:
            Future resolving to the deserialized ``result`` of the response.
        """
        return self._executor.submit(self._do_fetch, rank, uri)

    def _do_request(
        self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
    ) -> Any:
        """POST a serialized graph and inputs to ``<endpoint>/exec``.

        Raises:
            httpx.HTTPStatusError: If the worker answers with an error status.
        """
        url = f"{self._endpoints[rank]}/exec"
        payload = {
            "graph": serde.dumps_b64(graph),
            "inputs": serde.dumps_b64(inputs),
        }
        # job_id is only included when it is set (and non-empty).
        if job_id:
            payload["job_id"] = job_id

        resp = self._client.post(url, json=payload)
        resp.raise_for_status()
        return serde.loads_b64(resp.json()["result"])

    def _do_fetch(self, rank: int, uri: str) -> Any:
        """POST ``{"uri": uri}`` to ``<endpoint>/fetch`` and decode the result.

        Raises:
            httpx.HTTPStatusError: If the worker answers with an error status.
        """
        url = f"{self._endpoints[rank]}/fetch"
        resp = self._client.post(url, json={"uri": uri})
        resp.raise_for_status()
        return serde.loads_b64(resp.json()["result"])

    def shutdown(self) -> None:
        """Drain the executor, then close the HTTP client."""
        # Order matters: queued and in-flight requests still use the client,
        # so it has to stay open until the executor has finished them.
        try:
            self._executor.shutdown()
        finally:
            self._client.close()
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def make_driver(endpoints: list[str], *, cluster_spec: Any = None) -> Interpreter:
    """Build a driver-side Interpreter that talks to remote simp workers.

    A SimpHttpDriver is created for ``endpoints`` and registered on a new
    Interpreter as the state of the ``"simp"`` dialect. The Interpreter uses
    a copy of DRIVER_HANDLERS and an ObjectStore rooted at the driver's
    ``driver_root`` directory.

    Args:
        endpoints: List of HTTP endpoints for workers.
        cluster_spec: Optional ClusterSpec. When omitted, one is created with
            ``ClusterSpec.simple`` from the number of endpoints and the
            endpoints themselves.

    Returns:
        The configured Interpreter. The cluster spec in use is also stored on
        it as ``_cluster_spec``.

    Example:
        >>> interp = make_driver(["http://worker1:8000", "http://worker2:8000"])
        >>> with interp:
        ...     result = my_func()
    """
    # Imported locally, as in the rest of this function, rather than at
    # module import time.
    from collections.abc import Callable

    from mplang.v2.backends.simp_driver.ops import DRIVER_HANDLERS

    if cluster_spec is None:
        from mplang.v2.libs.device import ClusterSpec

        cluster_spec = ClusterSpec.simple(
            world_size=len(endpoints), endpoints=endpoints
        )

    driver = SimpHttpDriver(endpoints, cluster_spec=cluster_spec)
    root = driver.driver_root

    # Copy the handler table so this interpreter owns its own mapping.
    handlers: dict[str, Callable[..., Any]] = {**DRIVER_HANDLERS}  # type: ignore[dict-item]

    interp = Interpreter(
        name="DriverInterpreter",
        root_dir=root,
        handlers=handlers,
        store=ObjectStore(fs_root=str(root)),
    )
    interp.set_dialect_state("simp", driver)
    interp._cluster_spec = cluster_spec  # type: ignore[attr-defined]
    return interp
|
|
@@ -0,0 +1,280 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Simp Driver memory IPC (MemCluster, SimpMemDriver, make_simulator)."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import concurrent.futures
|
|
20
|
+
import os
|
|
21
|
+
import pathlib
|
|
22
|
+
from collections.abc import Callable
|
|
23
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
24
|
+
|
|
25
|
+
from mplang.v2.backends.simp_driver.state import SimpDriver
|
|
26
|
+
from mplang.v2.backends.simp_worker import WORKER_HANDLERS, SimpWorker
|
|
27
|
+
from mplang.v2.backends.simp_worker.mem import LocalMesh
|
|
28
|
+
from mplang.v2.runtime.interpreter import ExecutionTracer, Interpreter
|
|
29
|
+
from mplang.v2.runtime.object_store import ObjectStore
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from concurrent.futures import Future
|
|
33
|
+
|
|
34
|
+
from mplang.v2.edsl.graph import Graph
|
|
35
|
+
from mplang.v2.libs.device import ClusterSpec
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class MemCluster:
    """In-process cluster that builds and owns the worker Interpreters.

    The cluster creates one LocalMesh, one ExecutionTracer and ``world_size``
    worker Interpreters. It is not itself attached to an Interpreter; call
    ``create_state`` to obtain the SimpMemDriver that is.
    """

    def __init__(
        self,
        world_size: int,
        *,
        cluster_spec: ClusterSpec | None = None,
        enable_tracing: bool = False,
    ) -> None:
        """Create a local memory cluster.

        Args:
            world_size: Number of workers.
            cluster_spec: Optional cluster specification; its ``cluster_id``
                names the cluster directory. Without it the directory is
                ``local_<world_size>``.
            enable_tracing: Passed to the ExecutionTracer as ``enabled``.
        """
        self._world_size = world_size
        self._cluster_spec = cluster_spec

        # Layout: $MPLANG_DATA_ROOT (default ".mpl") / <cluster_id> / ...
        data_root = pathlib.Path(os.environ.get("MPLANG_DATA_ROOT", ".mpl"))
        if cluster_spec:
            cluster_id = cluster_spec.cluster_id
        else:
            cluster_id = f"local_{world_size}"
        cluster_root = data_root / cluster_id
        self.host_root = cluster_root / "__host__"

        # Communication mesh shared by all workers.
        self._mesh = LocalMesh(world_size)

        # One tracer for the whole cluster, started before any worker exists.
        self.tracer: ExecutionTracer = ExecutionTracer(
            enabled=enable_tracing, trace_dir=self.host_root / "trace"
        )
        self.tracer.start()

        # One Interpreter per rank, each rooted at <cluster_root>/node<rank>.
        self._workers: list[Interpreter] = [
            self._make_worker(rank, cluster_root / f"node{rank}")
            for rank in range(world_size)
        ]

    def _make_worker(self, rank: int, worker_root: pathlib.Path) -> Interpreter:
        """Build the Interpreter for one rank and attach its simp state."""
        store = ObjectStore(fs_root=str(worker_root / "store"))
        state = SimpWorker(
            rank=rank,
            world_size=self._world_size,
            communicator=self._mesh.comms[rank],
            store=store,
        )

        handlers: dict[str, Callable[..., Any]] = {**WORKER_HANDLERS}  # type: ignore[dict-item]
        interp = Interpreter(
            name=f"Worker-{rank}",
            tracer=self.tracer,
            trace_pid=rank,
            store=store,
            root_dir=worker_root,
            handlers=handlers,
        )
        interp.set_dialect_state("simp", state)

        # Opcodes assigned to the interpreter's async_ops set.
        interp.async_ops = {
            "bfv.add",
            "bfv.mul",
            "bfv.rotate",
            "bfv.batch_encode",
            "bfv.relinearize",
            "bfv.encrypt",
            "bfv.decrypt",
            "field.solve_okvs",
            "field.decode_okvs",
            "field.aes_expand",
            "field.mul",
            "simp.shuffle",
        }
        return interp

    @property
    def world_size(self) -> int:
        """Number of workers in this cluster."""
        return self._world_size

    @property
    def workers(self) -> list[Interpreter]:
        """Worker Interpreters, indexed by rank."""
        return self._workers

    def create_state(self) -> SimpMemDriver:
        """Return a SimpMemDriver sharing this cluster's workers and mesh."""
        return SimpMemDriver(
            world_size=self._world_size,
            workers=self._workers,
            mesh=self._mesh,
        )

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the mesh, forwarding ``wait`` to it."""
        self._mesh.shutdown(wait=wait)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class SimpMemDriver(SimpDriver):
    """Simp driver that runs work on in-process worker Interpreters.

    Work is scheduled on the mesh's executor. Instances are produced by
    ``MemCluster.create_state`` and attached to a driver Interpreter.
    """

    dialect_name: str = "simp"

    def __init__(
        self,
        world_size: int,
        workers: list[Interpreter],
        mesh: Any,  # LocalMesh from simp_worker.mem
    ) -> None:
        """Store the worker list and the mesh whose executor runs the work."""
        self._world_size = world_size
        self._workers = workers
        self._mesh = mesh

    def shutdown(self) -> None:
        """Shut down the mesh this driver dispatches through."""
        self._mesh.shutdown()

    @property
    def world_size(self) -> int:
        """Number of workers."""
        return self._world_size

    @property
    def workers(self) -> list[Interpreter]:
        """Worker interpreters (for backward compatibility)."""
        return self._workers

    def submit(
        self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
    ) -> Future[Any]:
        """Schedule ``graph`` on worker ``rank`` via the mesh executor."""
        pending = self._mesh.executor.submit(
            self._run_worker, rank, graph, inputs, job_id=job_id
        )
        return cast("Future[Any]", pending)

    def collect(self, futures: list[Future[Any]]) -> list[Any]:
        """Wait for ``futures`` and return their results in the given order.

        Waiting stops at the first exception. In that case every future is
        cancelled, the mesh is shut down without waiting, and the exception
        is re-raised.
        """
        finished, _ = concurrent.futures.wait(
            futures, return_when=concurrent.futures.FIRST_EXCEPTION
        )
        failures = (fut.exception() for fut in finished)
        error = next((err for err in failures if err), None)
        if error:
            for fut in futures:
                fut.cancel()
            self._mesh.shutdown(wait=False)
            raise error
        return [fut.result() for fut in futures]

    def fetch(self, rank: int, uri: str) -> Future[Any]:
        """Read ``uri`` from worker ``rank``'s store on the mesh executor."""
        worker_ctx = cast(SimpWorker, self._workers[rank].get_dialect_state("simp"))

        def _read() -> Any:
            return worker_ctx.store.get(uri)

        return self._mesh.executor.submit(_read)  # type: ignore[no-any-return]

    def _run_worker(
        self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
    ) -> Any:
        """Evaluate ``graph`` on worker ``rank`` and store its outputs.

        Inputs are store URIs, or None where the rank has no data. The return
        value is a list of store references for the results (None entries stay
        None), or None when the graph has no outputs.
        """
        interp = self._workers[rank]
        worker_ctx = cast(SimpWorker, interp.get_dialect_state("simp"))

        # Resolve each URI through the worker's store; None passes through.
        resolved: list[Any] = []
        for ref in inputs:
            resolved.append(None if ref is None else worker_ctx.store.get(ref))

        results = interp.evaluate_graph(graph, resolved, job_id)

        # The graph is still evaluated when it has no outputs; only the
        # storing step is skipped.
        if not graph.outputs:
            return None

        stored: list[Any] = []
        for value in results:
            stored.append(None if value is None else worker_ctx.store.put(value))
        return stored
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def make_simulator(
    world_size: int,
    *,
    cluster_spec: Any = None,
    enable_tracing: bool = False,
) -> Interpreter:
    """Build a host Interpreter backed by an in-process MemCluster.

    A MemCluster with ``world_size`` workers is created, and its
    SimpMemDriver is registered on a new Interpreter as the state of the
    ``"simp"`` dialect. The Interpreter shares the cluster's tracer and uses
    an ObjectStore rooted at the cluster's ``host_root`` directory.

    Args:
        world_size: Number of simulated parties.
        cluster_spec: Optional ClusterSpec. When omitted, one is created with
            ``ClusterSpec.simple(world_size)``.
        enable_tracing: If True, enable execution tracing.

    Returns:
        The configured Interpreter. The cluster and the cluster spec are also
        stored on it as ``_simp_cluster`` and ``_cluster_spec``.

    Example:
        >>> interp = make_simulator(2)
        >>> with interp:
        ...     result = my_func()
    """
    from mplang.v2.backends.simp_driver.ops import DRIVER_HANDLERS

    if cluster_spec is None:
        from mplang.v2.libs.device import ClusterSpec

        cluster_spec = ClusterSpec.simple(world_size)

    cluster = MemCluster(
        world_size=world_size,
        cluster_spec=cluster_spec,
        enable_tracing=enable_tracing,
    )
    driver_state = cluster.create_state()
    host_root = cluster.host_root

    # Copy the handler table so this interpreter owns its own mapping.
    handlers: dict[str, Callable[..., Any]] = {**DRIVER_HANDLERS}  # type: ignore[dict-item]

    interp = Interpreter(
        name="HostInterpreter",
        root_dir=host_root,
        handlers=handlers,
        tracer=cluster.tracer,
        store=ObjectStore(fs_root=str(host_root)),
    )
    interp.set_dialect_state("simp", driver_state)

    # Hold a reference to the cluster on the interpreter so it is not
    # garbage-collected while the interpreter is in use.
    interp._simp_cluster = cluster  # type: ignore[attr-defined]
    interp._cluster_spec = cluster_spec  # type: ignore[attr-defined]
    return interp
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
# Backward compatibility alias
|
|
280
|
+
# LocalCluster is a second name bound to the same class object as MemCluster;
# the simp_driver package __init__ re-exports both names.
LocalCluster = MemCluster
|