mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Simp Driver ops (DRIVER_HANDLERS).
|
|
16
|
+
|
|
17
|
+
Unified SPMD dispatch pattern for all SIMP operations.
|
|
18
|
+
All ops: wrap → dispatch to ALL workers → collect DriverVar(s).
|
|
19
|
+
Op-specific logic lives in Worker handlers (simp_worker/ops.py).
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
from typing import Any
|
|
25
|
+
|
|
26
|
+
from mplang.v2.backends.simp_driver.values import DriverVar
|
|
27
|
+
from mplang.v2.dialects import simp
|
|
28
|
+
from mplang.v2.edsl.graph import Graph, Operation
|
|
29
|
+
from mplang.v2.edsl.typing import CustomType
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _get_driver_context(interpreter: Any) -> Any:
    """Return the ``"simp"`` dialect state attached to ``interpreter``.

    Args:
        interpreter: Object exposing ``get_dialect_state(name)``.

    Returns:
        Whatever ``interpreter.get_dialect_state("simp")`` returns, as long as
        it is not ``None``.

    Raises:
        RuntimeError: If the interpreter has no state for the ``"simp"`` dialect.
    """
    simp_state = interpreter.get_dialect_state("simp")
    if simp_state is not None:
        return simp_state
    raise RuntimeError("Interpreter must have simp dialect state attached")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _wrap_op_as_graph(op: Operation) -> Graph:
    """Build a single-operation Graph that re-issues ``op`` for a worker.

    The new graph declares one input of ``CustomType("Any")`` per input of
    ``op`` (named ``in_0``, ``in_1``, ...), adds one operation carrying the
    opcode, a shallow copy of the attrs and the regions of ``op``, and marks
    every output of that operation as a graph output.

    Args:
        op: The operation to wrap.

    Returns:
        The freshly built wrapper graph.
    """
    wrapper = Graph()
    placeholder = CustomType("Any")

    # One placeholder-typed graph input per operand of the original op.
    wrapper_inputs = [
        wrapper.add_input(f"in_{idx}", placeholder)
        for idx, _ in enumerate(op.inputs)
    ]

    # Reuse the declared output types; an op without outputs still gets a
    # single placeholder-typed output.
    if op.outputs:
        result_types = [value.type for value in op.outputs]
    else:
        result_types = [placeholder]

    # add_op creates the operation's output values inside the wrapper graph.
    wrapper.add_op(
        opcode=op.opcode,
        inputs=wrapper_inputs,
        output_types=result_types,
        attrs=op.attrs.copy(),
        regions=op.regions,
    )

    # Expose everything produced by the operation that was just added.
    added_op = wrapper.operations[-1]
    for produced in added_op.outputs:
        wrapper.add_output(produced)

    return wrapper
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _collect_to_hostvars(results: list[Any], num_outputs: int, world_size: int) -> Any:
    """Regroup per-worker results into per-output DriverVar(s).

    Args:
        results: One entry per worker. Each entry is either ``None`` or an
            indexable holding that worker's outputs (URIs, per the driver
            protocol).
        num_outputs: Number of outputs each worker produced.
        world_size: Total number of workers. Not read by this function; kept
            for interface compatibility.

    Returns:
        ``None`` when ``num_outputs`` is 0, a single DriverVar when it is 1,
        otherwise a list with one DriverVar per output.
    """
    if num_outputs == 0:
        return None

    def _gather(out_idx: int) -> Any:
        # Column ``out_idx`` of the [worker][output] table; a worker that
        # returned None contributes None for every output.
        return DriverVar(
            [None if res is None else res[out_idx] for res in results]
        )

    # Transpose [worker][output] -> [output][worker].
    per_output = [_gather(out_idx) for out_idx in range(num_outputs)]

    return per_output[0] if num_outputs == 1 else per_output
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _generic_simp_dispatch(interpreter: Any, op: Operation, *args: Any) -> Any:
    """Driver-side handler shared by every SIMP op.

    Steps: wrap ``op`` into a one-op graph, submit that graph to every rank
    in ``range(driver.world_size)``, collect the results and regroup them
    into DriverVar(s). The op-specific work is left to whatever executes the
    submitted graph.

    Args:
        interpreter: Interpreter carrying the ``"simp"`` dialect state.
        op: The operation being dispatched.
        *args: Runtime operands. A DriverVar operand is indexed by rank; any
            other operand is passed unchanged to every rank.

    Returns:
        The result of ``_collect_to_hostvars`` for the collected results.
    """
    driver = _get_driver_context(interpreter)
    world_size = driver.world_size

    # 1. Wrap the operation into a standalone graph.
    wrapper_graph = _wrap_op_as_graph(op)

    def _inputs_for(rank: int) -> list[Any]:
        # Pick this rank's slice out of every DriverVar operand.
        return [arg[rank] if isinstance(arg, DriverVar) else arg for arg in args]

    # 2. SPMD submit: the same graph goes to all ranks, in rank order.
    futures = [
        driver.submit(rank, wrapper_graph, _inputs_for(rank))
        for rank in range(world_size)
    ]

    # 3. Wait for all ranks.
    results = driver.collect(futures)

    # 4. An op without outputs is wrapped with exactly one output, so expect
    #    one result slot per rank in that case.
    if op.outputs:
        num_outputs = len(op.outputs)
    else:
        num_outputs = 1
    return _collect_to_hostvars(results, num_outputs, world_size)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# =============================================================================
|
|
125
|
+
# All SIMP ops use unified dispatch
|
|
126
|
+
# =============================================================================
|
|
127
|
+
|
|
128
|
+
# Maps each SIMP primitive name to its driver-side handler. All of them share
# the generic dispatcher.
DRIVER_HANDLERS = {
    primitive.name: _generic_simp_dispatch
    for primitive in (
        simp.pcall_static_p,
        simp.pcall_dynamic_p,
        simp.shuffle_static_p,
        simp.converge_p,
        simp.uniform_cond_p,
        simp.while_loop_p,
    )
}
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""SimpDriver abstract base class."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from abc import ABC, abstractmethod
|
|
20
|
+
from typing import TYPE_CHECKING, Any
|
|
21
|
+
|
|
22
|
+
from mplang.v2.runtime.dialect_state import DialectState
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from concurrent.futures import Future
|
|
26
|
+
|
|
27
|
+
from mplang.v2.edsl.graph import Graph
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SimpDriver(DialectState, ABC):
    """Abstract base class for Simp Host drivers.

    All simp drivers must implement submit/fetch/collect interface
    for dispatching work to workers, plus the ``world_size`` property.
    Concrete subclasses decide how a worker is reached; nothing is
    implemented here.
    """

    # Dialect this state belongs to.
    # NOTE(review): presumably the key under which an interpreter registers
    # this state — confirm against DialectState.
    dialect_name: str = "simp"

    @property
    @abstractmethod
    def world_size(self) -> int:
        """Number of workers."""
        ...

    @abstractmethod
    def submit(
        self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
    ) -> Future[Any]:
        """Submit graph execution to a worker.

        Args:
            rank: Rank of the target worker.
            graph: Graph to execute on that worker.
            inputs: Inputs for the graph on that worker.
            job_id: Optional job identifier; defaults to ``None``.

        Returns:
            A future for the worker's result.
        """
        ...

    @abstractmethod
    def fetch(self, rank: int, uri: str) -> Future[Any]:
        """Fetch data from a worker.

        Args:
            rank: Rank of the worker to fetch from.
            uri: Identifier of the data on that worker.

        Returns:
            A future for the fetched data.
        """
        ...

    @abstractmethod
    def collect(self, futures: list[Future[Any]]) -> list[Any]:
        """Collect results from futures.

        Args:
            futures: Futures to resolve, e.g. as returned by ``submit``.

        Returns:
            A list of results.
            NOTE(review): presumably one result per future, in the same
            order — confirm against the concrete drivers.
        """
        ...
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Simp Driver values (DriverVar)."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from typing import Any, ClassVar, Self
|
|
20
|
+
|
|
21
|
+
from mplang.v2.edsl import serde
|
|
22
|
+
from mplang.v2.runtime.value import Value
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@serde.register_class
class DriverVar(Value):
    """Driver-side value made of one entry per party.

    The entries are kept in ``values`` and can be read by index with
    ``var[idx]``. The class is registered with ``serde`` under the kind
    ``"simp.DriverVar"``.
    """

    _serde_kind: ClassVar[str] = "simp.DriverVar"

    def __init__(self, values: list[Any]):
        """Store ``values`` as given (no copy)."""
        self.values = values

    # -- serde hooks ---------------------------------------------------------

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> Self:
        """Rebuild an instance from the mapping produced by ``to_json``."""
        return cls(values=data["values"])

    def to_json(self) -> dict[str, Any]:
        """Return the serializable form: ``{"values": <the entry list>}``."""
        payload: dict[str, Any] = {"values": self.values}
        return payload

    # -- container-like access -----------------------------------------------

    def __getitem__(self, idx: int) -> Any:
        """Return the entry stored at position ``idx``."""
        return self.values[idx]

    @property
    def world_size(self) -> int:
        """Number of entries held."""
        return len(self.values)

    def __repr__(self) -> str:
        return f"DriverVar({self.values})"
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Simp Worker package.
|
|
16
|
+
|
|
17
|
+
Provides Worker-side state and ops for the simp dialect.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from mplang.v2.backends.simp_worker.mem import LocalMesh, ThreadCommunicator
|
|
21
|
+
from mplang.v2.backends.simp_worker.ops import WORKER_HANDLERS
|
|
22
|
+
from mplang.v2.backends.simp_worker.state import SimpWorker
|
|
23
|
+
|
|
24
|
+
# Public names of the simp_worker package; each one is re-exported from the
# submodule imports above.
__all__ = [
    "WORKER_HANDLERS",
    "LocalMesh",
    "SimpWorker",
    "ThreadCommunicator",
]
|
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""SIMP HTTP Worker module.
|
|
16
|
+
|
|
17
|
+
Provides the HTTP-based worker entry point for distributed deployment.
|
|
18
|
+
This module contains:
|
|
19
|
+
- HttpCommunicator: HTTP-based inter-worker communication
|
|
20
|
+
- create_worker_app: Factory for FastAPI application
|
|
21
|
+
|
|
22
|
+
Usage:
|
|
23
|
+
# Start a worker server
|
|
24
|
+
from mplang.v2.backends.simp_worker.http import create_worker_app
|
|
25
|
+
import uvicorn
|
|
26
|
+
|
|
27
|
+
app = create_worker_app(rank=0, world_size=3, endpoints=[...])
|
|
28
|
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
29
|
+
|
|
30
|
+
Security:
|
|
31
|
+
This module uses secure JSON-based serialization (serde) for computation
|
|
32
|
+
graphs and data between workers. Unlike pickle, JSON deserialization
|
|
33
|
+
cannot execute arbitrary code, making it safe for network communication.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
from __future__ import annotations
|
|
37
|
+
|
|
38
|
+
import concurrent.futures
|
|
39
|
+
import logging
|
|
40
|
+
import os
|
|
41
|
+
import pathlib
|
|
42
|
+
import threading
|
|
43
|
+
import time
|
|
44
|
+
from typing import Any
|
|
45
|
+
|
|
46
|
+
import httpx
|
|
47
|
+
from fastapi import FastAPI, HTTPException
|
|
48
|
+
from pydantic import BaseModel
|
|
49
|
+
|
|
50
|
+
from mplang.v2.backends import spu_impl as _spu_impl # noqa: F401
|
|
51
|
+
from mplang.v2.backends import tensor_impl as _tensor_impl # noqa: F401
|
|
52
|
+
|
|
53
|
+
# Register operation implementations (side-effect imports)
|
|
54
|
+
from mplang.v2.backends.simp_worker import SimpWorker
|
|
55
|
+
from mplang.v2.backends.simp_worker import ops as _simp_worker_ops # noqa: F401
|
|
56
|
+
from mplang.v2.edsl import serde
|
|
57
|
+
from mplang.v2.edsl.graph import Graph
|
|
58
|
+
from mplang.v2.runtime.interpreter import ExecutionTracer, Interpreter
|
|
59
|
+
from mplang.v2.runtime.object_store import ObjectStore
|
|
60
|
+
|
|
61
|
+
logger = logging.getLogger(__name__)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class HttpCommunicator:
    """Communicator using HTTP requests for inter-worker communication.

    Outgoing messages are handed to a background thread pool so that the
    calling thread is not blocked by the HTTP round trip. Incoming messages
    are delivered through ``on_receive`` into a mailbox keyed by
    ``(from_rank, key)`` and picked up by ``recv``.

    Attributes:
        rank: This worker's rank
        world_size: Total number of workers
        endpoints: HTTP endpoints for all workers
        tracer: Optional tracer that receives ``comm.*`` events
        client: Shared ``httpx.Client`` used for all sends (no timeout)
    """

    def __init__(
        self,
        rank: int,
        world_size: int,
        endpoints: list[str],
        tracer: ExecutionTracer | None = None,
    ):
        self.rank = rank
        self.world_size = world_size
        self.endpoints = endpoints
        self.tracer = tracer
        # Received messages waiting for recv(); guarded by _cond.
        self._mailbox: dict[tuple[int, str], Any] = {}
        self._cond = threading.Condition()
        self._send_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=world_size, thread_name_prefix=f"comm_send_{rank}"
        )
        # Futures of sends that have been issued but not yet waited on.
        self._pending_sends: list[concurrent.futures.Future[None]] = []
        # Guards _pending_sends: send() and wait_pending_sends() can run on
        # different threads at the same time.
        self._pending_lock = threading.Lock()
        self.client = httpx.Client(timeout=None)

    def send(self, to: int, key: str, data: Any) -> None:
        """Send data to another rank asynchronously.

        The HTTP request runs on the send executor; its future is recorded so
        that ``wait_pending_sends`` can wait for it later.
        """
        future = self._send_executor.submit(self._do_send, to, key, data)
        with self._pending_lock:
            self._pending_sends.append(future)

    def _do_send(self, to: int, key: str, data: Any) -> None:
        """Perform the HTTP send (runs on the send executor).

        Raises:
            RuntimeError: If the request or the trace logging around it fails;
                the original exception is chained as the cause.
        """
        url = f"{self.endpoints[to]}/comm/{key}"
        logger.debug(f"Rank {self.rank} sending to {to} key={key}")

        # Detect SPU channel (tag prefix "spu:") and handle bytes
        if key.startswith("spu:") and isinstance(data, bytes):
            # Send raw bytes for SPU channels
            import base64

            payload = base64.b64encode(data).decode("ascii")
            is_raw_bytes = True
        else:
            # Use secure JSON serialization
            payload = serde.dumps_b64(data)
            is_raw_bytes = False

        # Size of the encoded payload string, not of the original data.
        size_bytes = len(payload)

        # Instant event recording the payload size; the request duration is
        # logged separately below as "comm.send_req".
        if self.tracer:
            now = time.time()
            self.tracer.log_custom_event(
                name="comm.send",
                start_ts=now,
                end_ts=now,
                cat="comm",
                args={"to": to, "key": key, "bytes": size_bytes},
            )

        try:
            t0 = time.time()
            resp = self.client.put(
                url,
                json={
                    "data": payload,
                    "from_rank": self.rank,
                    "is_raw_bytes": is_raw_bytes,
                },
            )
            resp.raise_for_status()
            duration = time.time() - t0
            if self.tracer:
                self.tracer.log_custom_event(
                    name="comm.send_req",
                    start_ts=t0,
                    end_ts=t0 + duration,
                    cat="comm",
                    args={"to": to, "key": key, "bytes": size_bytes},
                )
        except Exception as e:
            logger.error(f"Rank {self.rank} failed to send to {to}: {e}")
            raise RuntimeError(f"Failed to send to {to} ({url}): {e}") from e

    def recv(self, frm: int, key: str) -> Any:
        """Receive data from another rank (blocking).

        Blocks until a message for ``(frm, key)`` is in the mailbox, then
        removes and returns it. There is no overall timeout; the 1 second wait
        only bounds how long a single wait on the condition lasts.
        """
        logger.debug(f"Rank {self.rank} waiting recv from {frm} key={key}")
        mailbox_key = (frm, key)
        with self._cond:
            while mailbox_key not in self._mailbox:
                self._cond.wait(timeout=1.0)
            return self._mailbox.pop(mailbox_key)

    def on_receive(self, from_rank: int, key: str, data: Any) -> None:
        """Called when data is received from the HTTP endpoint.

        Raises:
            RuntimeError: If a message for ``(from_rank, key)`` is already
                waiting in the mailbox.
        """
        mailbox_key = (from_rank, key)
        with self._cond:
            if mailbox_key in self._mailbox:
                raise RuntimeError(
                    f"Mailbox overflow: key {mailbox_key} already exists"
                )
            self._mailbox[mailbox_key] = data
            self._cond.notify_all()

    def wait_pending_sends(self) -> None:
        """Wait for all pending sends to complete.

        Each send is given up to 60 seconds. Failures (including timeouts) are
        logged and not re-raised. Only the sends that were pending when this
        call started are waited on and dropped from the pending list; sends
        registered while waiting stay pending for the next call.
        """
        with self._pending_lock:
            snapshot = list(self._pending_sends)

        for future in snapshot:
            try:
                future.result(timeout=60.0)
            except Exception as e:
                logger.error(f"Rank {self.rank} send failed: {e}")

        # Remove only what was actually waited on, so a future appended by a
        # concurrent send() is never discarded unobserved.
        waited = set(snapshot)
        with self._pending_lock:
            self._pending_sends[:] = [
                future for future in self._pending_sends if future not in waited
            ]

    def shutdown(self) -> None:
        """Wait for pending sends, then stop the send executor and close the client."""
        self.wait_pending_sends()
        self._send_executor.shutdown(wait=True)
        self.client.close()
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class ExecRequest(BaseModel):
    """Request model for /exec endpoint."""

    # Graph to execute, encoded so that serde.loads_b64 can decode it.
    graph: str
    # Graph inputs, encoded so that serde.loads_b64 can decode them.
    inputs: str
    # Optional job identifier forwarded to the execution function.
    job_id: str | None = None
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class CommRequest(BaseModel):
    """Request model for /comm endpoint."""

    # Encoded payload: plain base64 when is_raw_bytes is true, otherwise a
    # serde base64 string (decoded with serde.loads_b64).
    data: str
    # Rank of the sending worker.
    from_rank: int
    # True when `data` carries raw bytes (base64) rather than a serde payload.
    is_raw_bytes: bool = False
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class FetchRequest(BaseModel):
    """Request model for /fetch endpoint."""

    # URI of the object to look up in the worker's store.
    uri: str
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def create_worker_app(
    rank: int,
    world_size: int,
    endpoints: list[str],
    spu_endpoints: dict[int, str] | None = None,
) -> FastAPI:
    """Create a FastAPI app for the worker.

    The app uses async endpoints with run_in_executor to allow concurrent
    handling of /exec and /comm requests. This is essential for cross-party
    communication where one party sends while another receives.

    Data is persisted under ``${MPLANG_DATA_ROOT}/<cluster_id>/node<rank>/``,
    where ``MPLANG_DATA_ROOT`` defaults to ``.mpl`` and the cluster id comes
    from ``MPLANG_CLUSTER_ID`` (default ``__http_<world_size>``). Traces go to
    the ``trace`` subdirectory.

    Routes registered on the app:
        POST /exec       -- execute a serde-encoded graph, return result URIs
        PUT  /comm/{key} -- deliver a message from another worker
        POST /fetch      -- return a stored object by URI
        GET  /objects    -- list the objects in the store
        GET  /health     -- report status, rank and world size

    Args:
        rank: The global rank of this worker.
        world_size: Total number of workers.
        endpoints: HTTP endpoints for all workers (for shuffle communication).
        spu_endpoints: Optional dict mapping global_rank -> BRPC endpoint for SPU.

    Returns:
        FastAPI application instance
    """
    import asyncio

    app = FastAPI(title=f"SIMP Worker {rank}")

    # Persistence root: ${MPLANG_DATA_ROOT}/<cluster_id>/node<rank>/
    data_root = pathlib.Path(os.environ.get("MPLANG_DATA_ROOT", ".mpl"))
    cluster_id = os.environ.get("MPLANG_CLUSTER_ID", f"__http_{world_size}")
    root_dir = data_root / cluster_id / f"node{rank}"
    trace_dir = root_dir / "trace"

    # Enable tracing by default for now (or make it configurable via env)
    # NOTE(review): the tracer is started here but never stopped in
    # shutdown_event -- confirm whether ExecutionTracer needs an explicit
    # stop/flush.
    tracer = ExecutionTracer(enabled=True, trace_dir=trace_dir)
    tracer.start()

    comm = HttpCommunicator(rank, world_size, endpoints, tracer=tracer)
    store = ObjectStore(fs_root=str(root_dir))
    ctx = SimpWorker(rank, world_size, comm, store, spu_endpoints)

    # Register handlers
    from collections.abc import Callable
    from typing import cast

    from mplang.v2.backends.simp_worker.ops import WORKER_HANDLERS

    # The backend implementation modules (spu_impl, tensor_impl) and
    # simp_worker.ops are already imported at module level for their
    # registration side effects.
    handlers: dict[str, Callable[..., Any]] = {**WORKER_HANDLERS}  # type: ignore[dict-item]

    worker = Interpreter(
        tracer=tracer, root_dir=root_dir, handlers=handlers, store=store
    )
    # Register SimpWorker context as 'simp' dialect state
    worker.set_dialect_state("simp", ctx)

    # Graph execution runs off the event loop so that /comm requests can still
    # be served while an execution is blocked waiting for a message.
    exec_pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=2, thread_name_prefix=f"exec_{rank}"
    )

    def _do_execute(graph: Graph, inputs: list[Any], job_id: str | None = None) -> Any:
        """Execute graph in worker thread.

        ``job_id`` is accepted for interface compatibility and not used here.
        Returns ``None`` for a graph without outputs, otherwise a list with
        one store URI (or ``None``) per result.
        """
        # Resolve URI inputs (None means rank has no data)
        resolved_inputs = [
            store.get(inp) if inp is not None else None for inp in inputs
        ]

        result = worker.evaluate_graph(graph, resolved_inputs)
        # Do not report completion before this worker's pending sends have
        # been waited on.
        comm.wait_pending_sends()

        # Store results and return URIs (result is always a list)
        if not graph.outputs:
            return None
        return [store.put(res) if res is not None else None for res in result]

    @app.post("/exec")
    async def execute(req: ExecRequest) -> dict[str, str]:
        """Execute a graph on this worker."""
        logger.debug(f"Worker {rank} received exec request")
        try:
            # Use secure JSON deserialization
            graph = serde.loads_b64(req.graph)
            inputs = serde.loads_b64(req.inputs)
            # Inside a coroutine a loop is always running; get_running_loop()
            # is the documented way to obtain it.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                exec_pool, _do_execute, graph, inputs, req.job_id
            )
            return {"result": serde.dumps_b64(result)}
        except Exception as e:
            logger.error(f"Worker {rank} exec failed: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.put("/comm/{key}")
    async def receive_comm(key: str, req: CommRequest) -> dict[str, str]:
        """Receive communication data from another worker."""
        logger.debug(f"Worker {rank} received comm key={key} from {req.from_rank}")
        try:
            # Handle raw bytes (SPU channels) vs serde data
            if req.is_raw_bytes:
                # Decode base64 to raw bytes
                import base64

                data = base64.b64decode(req.data)
            else:
                # Use secure JSON deserialization
                data = serde.loads_b64(req.data)

            comm.on_receive(req.from_rank, key, data)
            return {"status": "ok"}
        except Exception as e:
            logger.error(f"Worker {rank} comm failed: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.post("/fetch")
    async def fetch(req: FetchRequest) -> dict[str, str]:
        """Fetch data by URI."""
        logger.debug(f"Worker {rank} received fetch request for {req.uri}")
        try:
            state = cast(SimpWorker, worker.get_dialect_state("simp"))
            val = state.store.get(req.uri)
            return {"result": serde.dumps_b64(val)}
        except Exception as e:
            logger.error(f"Worker {rank} fetch failed: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/objects")
    async def list_objects() -> dict[str, list[str]]:
        """List all objects in the worker's store."""
        try:
            state = cast(SimpWorker, worker.get_dialect_state("simp"))
            return {"objects": state.store.list_objects()}
        except Exception as e:
            logger.error(f"Worker {rank} list_objects failed: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/health")
    async def health() -> dict[str, str]:
        """Health check endpoint."""
        return {"status": "ok", "rank": str(rank), "world_size": str(world_size)}

    @app.on_event("shutdown")
    def shutdown_event() -> None:
        """Cleanup on shutdown: stop the communicator, then the exec pool."""
        comm.shutdown()
        exec_pool.shutdown(wait=True)

    return app
|