mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/cli.py
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
"""
|
|
17
|
+
Command-line interface for MPLang2 clusters and jobs.
|
|
18
|
+
|
|
19
|
+
Examples:
|
|
20
|
+
# Generate a cluster config file
|
|
21
|
+
python -m mplang.v2.cli config gen -w 3 -p 8100 -o cluster.yaml
|
|
22
|
+
|
|
23
|
+
# Start a single worker (production usage)
|
|
24
|
+
python -m mplang.v2.cli worker --rank 0 -c cluster.yaml
|
|
25
|
+
|
|
26
|
+
# Start 3 local workers (development usage)
|
|
27
|
+
python -m mplang.v2.cli up -c cluster.yaml
|
|
28
|
+
|
|
29
|
+
# Check cluster status
|
|
30
|
+
python -m mplang.v2.cli status -c cluster.yaml
|
|
31
|
+
|
|
32
|
+
# Run a job
|
|
33
|
+
python -m mplang.v2.cli run -c cluster.yaml -f my_job.py
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
import argparse
|
|
37
|
+
import glob
|
|
38
|
+
import importlib.util
|
|
39
|
+
import json
|
|
40
|
+
import multiprocessing
|
|
41
|
+
import os
|
|
42
|
+
import re
|
|
43
|
+
import signal
|
|
44
|
+
import sys
|
|
45
|
+
from collections.abc import Callable
|
|
46
|
+
from types import ModuleType
|
|
47
|
+
from typing import Any, cast
|
|
48
|
+
|
|
49
|
+
import uvicorn
|
|
50
|
+
import yaml
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def run_worker(
    rank: int,
    world_size: int,
    port: int,
    endpoints: list[str],
    spu_endpoints: dict[int, str] | None = None,
) -> None:
    """Serve one worker over HTTP until the process is terminated.

    Args:
        rank: This worker's rank in ``[0, world_size)``.
        world_size: Total number of workers in the cluster.
        port: TCP port to bind the HTTP server on.
        endpoints: HTTP endpoints of all workers, indexed by rank.
        spu_endpoints: Optional SPU BRPC endpoints keyed by rank.
    """
    # Child processes inherit the parent's shutdown handlers; restore the
    # defaults so Ctrl+C / SIGTERM behave normally inside the worker.
    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, signal.SIG_DFL)

    from mplang.v2.backends.simp_worker.http import create_worker_app

    app = create_worker_app(rank, world_size, endpoints, spu_endpoints)

    # Prefix every log line with the worker rank and quiet uvicorn's loggers
    # down to WARNING so multi-worker output stays readable.
    stderr_handler = {
        "formatter": "default",
        "class": "logging.StreamHandler",
        "stream": "ext://sys.stderr",
    }
    quiet = {"handlers": ["default"], "level": "WARNING"}
    log_config: dict[str, Any] = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "default": {"format": f"[Worker {rank}] %(levelname)s: %(message)s"},
        },
        "handlers": {"default": stderr_handler},
        "loggers": {
            "uvicorn": dict(quiet),
            "uvicorn.error": dict(quiet),
            "uvicorn.access": dict(quiet),
        },
    }

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_config=log_config,
        log_level="warning",
    )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def build_endpoints(
|
|
101
|
+
args: argparse.Namespace,
|
|
102
|
+
) -> tuple[list[str], list[int], int, dict[int, str] | None]:
|
|
103
|
+
"""Build endpoints and SPU endpoints from config or CLI flags."""
|
|
104
|
+
|
|
105
|
+
def normalize_endpoint(ep: str) -> str:
|
|
106
|
+
return ep if ep.startswith("http") else f"http://{ep}"
|
|
107
|
+
|
|
108
|
+
spu_endpoints: dict[int, str] | None = None
|
|
109
|
+
|
|
110
|
+
if args.config:
|
|
111
|
+
with open(args.config, encoding="utf-8") as f:
|
|
112
|
+
conf = yaml.safe_load(f)
|
|
113
|
+
nodes = conf.get("nodes", [])
|
|
114
|
+
if not nodes:
|
|
115
|
+
raise ValueError("Config must contain nodes")
|
|
116
|
+
world_size = len(nodes)
|
|
117
|
+
endpoints = []
|
|
118
|
+
ports = []
|
|
119
|
+
for node in nodes:
|
|
120
|
+
endpoint = normalize_endpoint(node["endpoint"])
|
|
121
|
+
endpoints.append(endpoint)
|
|
122
|
+
ports.append(int(endpoint.split(":")[-1]))
|
|
123
|
+
|
|
124
|
+
devices = conf.get("devices", {})
|
|
125
|
+
for _dev_name, dev_conf in devices.items():
|
|
126
|
+
if dev_conf.get("kind", "").upper() == "SPU":
|
|
127
|
+
spu_endpoints = {}
|
|
128
|
+
spu_base_port = args.spu_base_port or (ports[0] + 1000)
|
|
129
|
+
for i, node in enumerate(nodes):
|
|
130
|
+
host = node["endpoint"].split(":")[0]
|
|
131
|
+
spu_endpoints[i] = f"{host}:{spu_base_port + i}"
|
|
132
|
+
break
|
|
133
|
+
else:
|
|
134
|
+
world_size = args.world_size
|
|
135
|
+
base_port = getattr(args, "base_port", 5000)
|
|
136
|
+
ports = [base_port + i for i in range(world_size)]
|
|
137
|
+
endpoints = [f"http://127.0.0.1:{p}" for p in ports]
|
|
138
|
+
|
|
139
|
+
if args.spu_base_port:
|
|
140
|
+
spu_endpoints = {
|
|
141
|
+
i: f"127.0.0.1:{args.spu_base_port + i}" for i in range(world_size)
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
if args.endpoints:
|
|
145
|
+
endpoints = [normalize_endpoint(ep.strip()) for ep in args.endpoints.split(",")]
|
|
146
|
+
ports = [int(ep.split(":")[-1]) for ep in endpoints]
|
|
147
|
+
world_size = len(endpoints)
|
|
148
|
+
|
|
149
|
+
return endpoints, ports, world_size, spu_endpoints
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def add_cluster_args(
    parser: argparse.ArgumentParser, *, include_ports: bool = True
) -> None:
    """Register the cluster-selection flags shared by several subcommands.

    Args:
        parser: Subcommand parser to extend in place.
        include_ports: When True, also add the ``-p/--base-port`` flag.
    """
    parser.add_argument("-c", "--config", type=str, help="Cluster config YAML")
    parser.add_argument("--endpoints", type=str, help="Comma-separated HTTP endpoints")
    parser.add_argument(
        "--spu-endpoints", type=str, help="Comma-separated SPU BRPC endpoints"
    )
    parser.add_argument(
        "--spu-base-port",
        type=int,
        default=None,
        help="Base port for SPU BRPC (default: http_port + 1000)",
    )
    parser.add_argument(
        "-w",
        "--world-size",
        type=int,
        default=3,
        help="Number of workers (default: 3)",
    )
    if not include_ports:
        return
    parser.add_argument(
        "-p",
        "--base-port",
        type=int,
        default=8100,
        help="Base port (default: 8100)",
    )
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def cmd_config_gen(args: argparse.Namespace) -> None:
    """Generate a cluster config for local workers and print or write it."""
    world_size = args.world_size
    base_port = args.base_port

    nodes = [
        {"name": f"node_{i}", "endpoint": f"127.0.0.1:{base_port + i}"}
        for i in range(world_size)
    ]

    config: dict[str, Any] = {"nodes": nodes}

    # One single-member PPU device per node, by default.
    devices: dict[str, Any] = {
        f"P{i}": {"kind": "ppu", "members": [f"node_{i}"]} for i in range(world_size)
    }

    # Optionally add a single SPU device spanning every node.
    if args.spu_base_port:
        devices["SPU0"] = {
            "kind": "SPU",
            "members": [n["name"] for n in nodes],
            "config": {"protocol": "ABY3", "field": "FM64"},
        }
    config["devices"] = devices

    yaml_content = yaml.dump(config, sort_keys=False)

    if args.output:
        with open(args.output, "w") as f:
            f.write(yaml_content)
        print(f"Config written to {args.output}")
    else:
        print(yaml_content)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def cmd_worker(args: argparse.Namespace) -> None:
    """Validate the requested rank and serve that worker in this process."""
    endpoints, ports, world_size, spu_endpoints = build_endpoints(args)
    rank = args.rank

    if not 0 <= rank < world_size:
        raise ValueError(f"Rank {rank} is out of range [0, {world_size - 1}]")

    print(f"Starting Worker {rank} on {endpoints[rank]}...")
    if spu_endpoints and rank in spu_endpoints:
        print(f" SPU BRPC: {spu_endpoints[rank]}")

    # Blocks until the server process is terminated.
    run_worker(rank, world_size, ports[rank], endpoints, spu_endpoints)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def cmd_up(args: argparse.Namespace) -> None:
    """Spawn one subprocess per worker locally and block until interrupted."""
    endpoints, ports, world_size, spu_endpoints = build_endpoints(args)

    print(f"Starting {world_size} workers...")
    for i, endpoint in enumerate(endpoints):
        print(f" Worker {i}: {endpoint}")
    if spu_endpoints:
        print("SPU BRPC endpoints:")
        for rank, ep in spu_endpoints.items():
            print(f" Rank {rank}: {ep}")

    workers: list[multiprocessing.Process] = []

    def shutdown(signum: int, frame: Any) -> None:
        # Terminate all children first, then give each a bounded join so a
        # hung worker cannot stall the whole shutdown.
        print("\nShutting down workers...")
        for proc in workers:
            proc.terminate()
        for proc in workers:
            proc.join(timeout=5)
        sys.exit(0)

    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)

    for rank in range(world_size):
        proc = multiprocessing.Process(
            target=run_worker,
            args=(rank, world_size, ports[rank], endpoints, spu_endpoints),
        )
        proc.start()
        workers.append(proc)

    print("\nWorkers started. Press Ctrl+C to stop.")

    # Wait forever; the signal handler above performs the actual teardown.
    for proc in workers:
        proc.join()
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def cmd_status(args: argparse.Namespace) -> None:
    """Probe each worker's ``/health`` endpoint and print the outcome."""
    import httpx

    endpoints, _, world_size, _ = build_endpoints(args)

    print(f"Checking {len(endpoints)} endpoints (world_size={world_size})...")
    for ep in endpoints:
        url = f"{ep}/health"
        # Any failure (connect error, non-2xx, bad JSON) is reported per
        # endpoint rather than aborting the whole sweep.
        try:
            resp = httpx.get(url, timeout=3.0)
            resp.raise_for_status()
            print(f"OK {url} -> {resp.json()}")
        except Exception as exc:
            print(f"ERR {url} -> {exc}")
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def load_user_module(path: str) -> ModuleType:
    """Load a Python module from file path.

    Imports *path* under the fixed module name ``mp_user_module`` using the
    standard importlib spec/loader recipe.

    Raises:
        FileNotFoundError: If *path* does not exist.
        ImportError: If no import spec (or loader) can be built for *path*.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    spec = importlib.util.spec_from_file_location("mp_user_module", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import module from {path}")
    module = importlib.util.module_from_spec(spec)
    # Register in sys.modules before exec so the module can reference itself
    # (e.g. via dataclasses/pickling) while its top level runs.
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def resolve_entry(module: ModuleType, name: str) -> Callable[..., Any]:
    """Look up callable *name* on *module*, raising if absent or not callable."""
    candidate = getattr(module, name, None)
    # callable(None) is False, so a missing attribute fails the same check.
    if not callable(candidate):
        raise AttributeError(f"Entry function '{name}' not found or not callable")
    return cast(Callable[..., Any], candidate)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def parse_spu_endpoints(
|
|
314
|
+
raw: str | None, world_size: int, default: dict[int, str] | None
|
|
315
|
+
) -> dict[int, str] | None:
|
|
316
|
+
if raw is None:
|
|
317
|
+
return default
|
|
318
|
+
parts = [p.strip() for p in raw.split(",") if p.strip()]
|
|
319
|
+
if len(parts) != world_size:
|
|
320
|
+
raise ValueError("spu-endpoints count must match world size")
|
|
321
|
+
return {i: parts[i] for i in range(world_size)}
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def cmd_run(args: argparse.Namespace) -> None:
    """Run a user job via HTTP cluster or local simulator."""
    from mplang.v2 import make_driver, make_simulator
    from mplang.v2.edsl.context import pop_context, push_context
    from mplang.v2.libs.device import ClusterSpec

    cluster: ClusterSpec
    if args.config:
        # A config file fully describes the cluster topology.
        with open(args.config, encoding="utf-8") as f:
            cluster = ClusterSpec.from_dict(yaml.safe_load(f))
    else:
        # Otherwise synthesize a simple cluster from the CLI flags.
        endpoints, _, world_size, _ = build_endpoints(args)
        cluster = ClusterSpec.simple(world_size, endpoints=endpoints)

    driver: Any
    if args.backend == "sim":
        driver = make_simulator(
            cluster.world_size,
            cluster_spec=cluster,
            enable_tracing=getattr(args, "profile", False),
        )
    else:
        driver = make_driver(cluster.endpoints, cluster_spec=cluster)

    # The entry function picks the driver up from the ambient context; it is
    # not passed explicitly.
    push_context(driver)

    module = load_user_module(args.file)
    entry = resolve_entry(module, args.entry)

    try:
        result = entry(*args.args)
        if result is not None:
            print(result)
    finally:
        pop_context()
        if hasattr(driver, "shutdown"):
            driver.shutdown()
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def cmd_trace_merge(args: argparse.Namespace) -> None:
    """Merge multiple Chrome Trace JSON files into a single file."""
    pattern = args.pattern
    output_file = args.output

    files = glob.glob(pattern)
    if not files:
        print(f"No files found matching pattern: {pattern}")
        sys.exit(1)

    print(f"Found {len(files)} trace files.")

    merged_events = []

    # Extract the worker rank from filenames like "trace_..._rank_0.json".
    rank_pattern = re.compile(r"_rank_(\d+)\.json$")

    for fname in files:
        print(f"Processing {fname}...")
        try:
            with open(fname) as f:
                events = json.load(f).get("traceEvents", [])

            match = rank_pattern.search(fname)
            rank = int(match.group(1)) if match else None
            # Shift each rank into its own PID band so processes on different
            # machines that happen to share an OS PID don't collide in the
            # merged timeline (rank 0 -> 10000s, rank 1 -> 20000s, ...).
            pid_offset = 0 if rank is None else (rank + 1) * 10000

            for event in events:
                if "pid" in event:
                    # new_pid = band + (original_pid % 10000); the modulo
                    # keeps thread grouping stable within a rank.
                    event["pid"] = pid_offset + (event["pid"] % 10000)
                if rank is not None:
                    # Tag each event with its source rank for inspection.
                    event.setdefault("args", {})["rank"] = rank
                merged_events.append(event)

        except Exception as e:
            print(f"Error processing {fname}: {e}")

    with open(output_file, "w") as f:
        json.dump({"traceEvents": merged_events}, f)

    print(f"Successfully merged {len(merged_events)} events into {output_file}")
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def cmd_objects(args: argparse.Namespace) -> None:
    """List objects on workers."""
    endpoints, _, world_size, _ = build_endpoints(args)

    print(f"Listing objects on {world_size} workers...")
    print("-" * 80)
    print(f"{'Rank':<6} | {'Endpoint':<25} | {'Count':<6} | {'Objects'}")
    print("-" * 80)

    import httpx

    for rank in range(world_size):
        url = f"{endpoints[rank]}/objects"
        try:
            resp = httpx.get(url, timeout=2.0)
            if resp.status_code != 200:
                print(
                    f"{rank:<6} | {endpoints[rank]:<25} | {'Err':<6} | Status {resp.status_code}"
                )
                continue
            objects = resp.json()["objects"]
            count = len(objects)
            # Show at most the first three object ids to keep rows short.
            preview = ", ".join(objects[:3])
            if count > 3:
                preview += ", ..."
            print(f"{rank:<6} | {endpoints[rank]:<25} | {count:<6} | {preview}")
        except Exception as e:
            print(f"{rank:<6} | {endpoints[rank]:<25} | {'Err':<6} | {e}")
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def main() -> None:
    """CLI entry point: build the argument parser and dispatch to the chosen subcommand."""
    parser = argparse.ArgumentParser(
        description="MPLang2 cluster and job CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # 'config' subcommand (itself a group with nested subcommands)
    config_parser = subparsers.add_parser("config", help="Configuration management")
    config_subparsers = config_parser.add_subparsers(
        dest="config_command", help="Config commands"
    )

    # 'config gen' — NOTE: defines its own -w/-p flags rather than reusing
    # add_cluster_args, since gen has no --endpoints/--config inputs.
    gen_parser = config_subparsers.add_parser("gen", help="Generate cluster config")
    gen_parser.add_argument(
        "-w", "--world-size", type=int, default=3, help="Number of workers"
    )
    gen_parser.add_argument(
        "-p", "--base-port", type=int, default=8100, help="Base port"
    )
    gen_parser.add_argument(
        "--spu-base-port", type=int, default=None, help="Base port for SPU"
    )
    gen_parser.add_argument("-o", "--output", type=str, help="Output file path")

    # 'worker' subcommand — one rank per process
    worker_parser = subparsers.add_parser("worker", help="Start a single worker")
    add_cluster_args(worker_parser)
    worker_parser.add_argument(
        "--rank", type=int, required=True, help="Rank of this worker"
    )

    # 'up' subcommand — spawn all ranks locally
    up_parser = subparsers.add_parser("up", help="Start local cluster (all workers)")
    add_cluster_args(up_parser)

    # 'status' subcommand
    status_parser = subparsers.add_parser("status", help="Check worker health")
    add_cluster_args(status_parser, include_ports=True)

    # 'run' subcommand
    run_parser = subparsers.add_parser("run", help="Run a user job")
    add_cluster_args(run_parser, include_ports=True)
    run_parser.add_argument("-f", "--file", required=True, help="Path to user script")
    run_parser.add_argument(
        "--entry",
        default="__mp_main__",
        help="Entry function name in the user script (default: __mp_main__)",
    )
    run_parser.add_argument(
        "--backend",
        choices=["http", "sim"],
        default="http",
        help="Execution backend: http (cluster) or sim (local simulator)",
    )
    run_parser.add_argument(
        "--args",
        nargs="*",
        default=[],
        help="Arguments passed to the entry function",
    )
    run_parser.add_argument(
        "--profile",
        action="store_true",
        help="Enable performance profiling (only for sim backend)",
    )

    # 'sim' subcommand — shorthand for 'run --backend sim' (see dispatch below)
    sim_parser = subparsers.add_parser("sim", help="Run a user job in local simulator")
    add_cluster_args(sim_parser, include_ports=False)
    sim_parser.add_argument("-f", "--file", required=True, help="Path to user script")
    sim_parser.add_argument(
        "--entry",
        default="__mp_main__",
        help="Entry function name in the user script (default: __mp_main__)",
    )
    sim_parser.add_argument(
        "--args",
        nargs="*",
        default=[],
        help="Arguments passed to the entry function",
    )
    sim_parser.add_argument(
        "--profile",
        action="store_true",
        help="Enable performance profiling",
    )

    # 'objects' subcommand
    objects_parser = subparsers.add_parser("objects", help="List objects on workers")
    add_cluster_args(objects_parser, include_ports=False)

    # 'trace' subcommand (group)
    trace_parser = subparsers.add_parser("trace", help="Trace utilities")
    trace_subparsers = trace_parser.add_subparsers(
        dest="trace_command", help="Trace commands"
    )

    # 'trace merge'
    merge_parser = trace_subparsers.add_parser("merge", help="Merge trace files")
    merge_parser.add_argument(
        "pattern", help="Glob pattern for trace files (e.g. 'trace_*.json')"
    )
    merge_parser.add_argument(
        "-o", "--output", default="merged_trace.json", help="Output filename"
    )

    args = parser.parse_args()

    # Dispatch. Groups with nested subcommands fall back to their own help
    # when the nested command is missing.
    if args.command == "config":
        if args.config_command == "gen":
            cmd_config_gen(args)
        else:
            config_parser.print_help()
    elif args.command == "trace":
        if args.trace_command == "merge":
            cmd_trace_merge(args)
        else:
            trace_parser.print_help()
    elif args.command == "worker":
        cmd_worker(args)
    elif args.command == "up":
        cmd_up(args)
    elif args.command == "status":
        cmd_status(args)
    elif args.command == "run":
        cmd_run(args)
    elif args.command == "sim":
        # 'sim' reuses cmd_run with the backend forced to the simulator.
        args.backend = "sim"
        cmd_run(args)
    elif args.command == "objects":
        cmd_objects(args)
    else:
        parser.print_help()
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
# Script entry point: run the CLI when executed directly
# (e.g. `python -m mplang.v2.cli`).
if __name__ == "__main__":
    main()
|