mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/cli.py ADDED
@@ -0,0 +1,603 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Ant Group Co., Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Command-line interface for MPLang2 clusters and jobs.
18
+
19
+ Examples:
20
+ # Generate a cluster config file
21
+ python -m mplang.v2.cli config gen -w 3 -p 8100 -o cluster.yaml
22
+
23
+ # Start a single worker (production usage)
24
+ python -m mplang.v2.cli worker --rank 0 -c cluster.yaml
25
+
26
+ # Start 3 local workers (development usage)
27
+ python -m mplang.v2.cli up -c cluster.yaml
28
+
29
+ # Check cluster status
30
+ python -m mplang.v2.cli status -c cluster.yaml
31
+
32
+ # Run a job
33
+ python -m mplang.v2.cli run -c cluster.yaml -f my_job.py
34
+ """
35
+
36
+ import argparse
37
+ import glob
38
+ import importlib.util
39
+ import json
40
+ import multiprocessing
41
+ import os
42
+ import re
43
+ import signal
44
+ import sys
45
+ from collections.abc import Callable
46
+ from types import ModuleType
47
+ from typing import Any, cast
48
+
49
+ import uvicorn
50
+ import yaml
51
+
52
+
53
def run_worker(
    rank: int,
    world_size: int,
    port: int,
    endpoints: list[str],
    spu_endpoints: dict[int, str] | None = None,
) -> None:
    """Serve one worker's HTTP application in the current process.

    Blocks inside ``uvicorn.run`` until the server stops.

    Args:
        rank: Index of this worker in the cluster.
        world_size: Total number of workers.
        port: Port the server listens on (bound on all interfaces, 0.0.0.0).
        endpoints: HTTP endpoints of every worker, indexed by rank.
        spu_endpoints: Optional mapping from rank to SPU BRPC endpoint.
    """
    # This function is also used as a multiprocessing target (see `cmd_up`).
    # Restore default SIGINT/SIGTERM behaviour so a handler inherited from the
    # parent process never runs inside the worker.
    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, signal.SIG_DFL)

    # Imported here rather than at module top, presumably so CLI commands that
    # never start a worker do not import the worker backend.
    from mplang.v2.backends.simp_worker.http import create_worker_app

    app = create_worker_app(rank, world_size, endpoints, spu_endpoints)

    # uvicorn loggers are limited to WARNING and tagged with the worker rank.
    uvicorn_loggers = ("uvicorn", "uvicorn.error", "uvicorn.access")
    log_config: dict[str, Any] = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "default": {"format": f"[Worker {rank}] %(levelname)s: %(message)s"},
        },
        "handlers": {
            "default": {
                "formatter": "default",
                "class": "logging.StreamHandler",
                "stream": "ext://sys.stderr",
            },
        },
        "loggers": {
            logger_name: {"handlers": ["default"], "level": "WARNING"}
            for logger_name in uvicorn_loggers
        },
    }

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_config=log_config,
        log_level="warning",
    )
+ )
98
+
99
+
100
+ def build_endpoints(
101
+ args: argparse.Namespace,
102
+ ) -> tuple[list[str], list[int], int, dict[int, str] | None]:
103
+ """Build endpoints and SPU endpoints from config or CLI flags."""
104
+
105
+ def normalize_endpoint(ep: str) -> str:
106
+ return ep if ep.startswith("http") else f"http://{ep}"
107
+
108
+ spu_endpoints: dict[int, str] | None = None
109
+
110
+ if args.config:
111
+ with open(args.config, encoding="utf-8") as f:
112
+ conf = yaml.safe_load(f)
113
+ nodes = conf.get("nodes", [])
114
+ if not nodes:
115
+ raise ValueError("Config must contain nodes")
116
+ world_size = len(nodes)
117
+ endpoints = []
118
+ ports = []
119
+ for node in nodes:
120
+ endpoint = normalize_endpoint(node["endpoint"])
121
+ endpoints.append(endpoint)
122
+ ports.append(int(endpoint.split(":")[-1]))
123
+
124
+ devices = conf.get("devices", {})
125
+ for _dev_name, dev_conf in devices.items():
126
+ if dev_conf.get("kind", "").upper() == "SPU":
127
+ spu_endpoints = {}
128
+ spu_base_port = args.spu_base_port or (ports[0] + 1000)
129
+ for i, node in enumerate(nodes):
130
+ host = node["endpoint"].split(":")[0]
131
+ spu_endpoints[i] = f"{host}:{spu_base_port + i}"
132
+ break
133
+ else:
134
+ world_size = args.world_size
135
+ base_port = getattr(args, "base_port", 5000)
136
+ ports = [base_port + i for i in range(world_size)]
137
+ endpoints = [f"http://127.0.0.1:{p}" for p in ports]
138
+
139
+ if args.spu_base_port:
140
+ spu_endpoints = {
141
+ i: f"127.0.0.1:{args.spu_base_port + i}" for i in range(world_size)
142
+ }
143
+
144
+ if args.endpoints:
145
+ endpoints = [normalize_endpoint(ep.strip()) for ep in args.endpoints.split(",")]
146
+ ports = [int(ep.split(":")[-1]) for ep in endpoints]
147
+ world_size = len(endpoints)
148
+
149
+ return endpoints, ports, world_size, spu_endpoints
150
+
151
+
152
def add_cluster_args(
    parser: argparse.ArgumentParser, *, include_ports: bool = True
) -> None:
    """Register the cluster-addressing flags shared by several subcommands.

    Always adds ``-c/--config``, ``--endpoints``, ``--spu-endpoints``,
    ``--spu-base-port`` and ``-w/--world-size``. ``-p/--base-port`` is added
    only when ``include_ports`` is true.
    """
    add = parser.add_argument

    add("-c", "--config", type=str, help="Cluster config YAML")
    add("--endpoints", type=str, help="Comma-separated HTTP endpoints")
    add("--spu-endpoints", type=str, help="Comma-separated SPU BRPC endpoints")
    add(
        "--spu-base-port",
        type=int,
        default=None,
        help="Base port for SPU BRPC (default: http_port + 1000)",
    )
    add(
        "-w",
        "--world-size",
        type=int,
        default=3,
        help="Number of workers (default: 3)",
    )

    if not include_ports:
        return

    add(
        "-p",
        "--base-port",
        type=int,
        default=8100,
        help="Base port (default: 8100)",
    )
183
+
184
+
185
def cmd_config_gen(args: argparse.Namespace) -> None:
    """Emit a local cluster config as YAML, to ``args.output`` or stdout.

    The config contains:

    * one ``node_<i>`` per worker at ``127.0.0.1:<base_port + i>``;
    * one ``ppu`` device ``P<i>`` per node;
    * when ``--spu-base-port`` is given, an ``SPU0`` device (ABY3 / FM64)
      whose members are all nodes.
    """
    node_names = [f"node_{i}" for i in range(args.world_size)]
    nodes = [
        {"name": name, "endpoint": f"127.0.0.1:{args.base_port + i}"}
        for i, name in enumerate(node_names)
    ]

    # One PPU device per node.
    devices: dict[str, Any] = {
        f"P{i}": {"kind": "ppu", "members": [name]}
        for i, name in enumerate(node_names)
    }
    if args.spu_base_port:
        devices["SPU0"] = {
            "kind": "SPU",
            # Copy, so YAML never emits an anchor/alias for a shared list.
            "members": list(node_names),
            "config": {"protocol": "ABY3", "field": "FM64"},
        }

    config: dict[str, Any] = {"nodes": nodes, "devices": devices}
    yaml_content = yaml.dump(config, sort_keys=False)

    if not args.output:
        print(yaml_content)
        return

    with open(args.output, "w") as f:
        f.write(yaml_content)
    print(f"Config written to {args.output}")
220
+
221
+
222
def cmd_worker(args: argparse.Namespace) -> None:
    """Run the single worker selected by ``--rank`` in the foreground.

    Raises:
        ValueError: if ``--rank`` is not in ``[0, world_size)``.
    """
    endpoints, ports, world_size, spu_endpoints = build_endpoints(args)
    rank = args.rank

    if not 0 <= rank < world_size:
        raise ValueError(f"Rank {rank} is out of range [0, {world_size - 1}]")

    print(f"Starting Worker {rank} on {endpoints[rank]}...")
    # A missing or empty SPU mapping means there is no SPU line to show.
    if rank in (spu_endpoints or {}):
        print(f" SPU BRPC: {spu_endpoints[rank]}")

    run_worker(rank, world_size, ports[rank], endpoints, spu_endpoints)
235
+
236
+
237
def cmd_up(args: argparse.Namespace) -> None:
    """Spawn every worker as a local child process and wait for them.

    While waiting, SIGINT/SIGTERM in this parent process:

    1. terminate all children;
    2. join each one with a 5 second timeout;
    3. exit with status 0.
    """
    endpoints, ports, world_size, spu_endpoints = build_endpoints(args)

    print(f"Starting {world_size} workers...")
    for idx, url in enumerate(endpoints):
        print(f" Worker {idx}: {url}")
    if spu_endpoints:
        print("SPU BRPC endpoints:")
        for spu_rank, spu_ep in spu_endpoints.items():
            print(f" Rank {spu_rank}: {spu_ep}")

    workers: list[multiprocessing.Process] = []

    def _on_signal(signum: int, frame: Any) -> None:
        print("\nShutting down workers...")
        for proc in workers:
            proc.terminate()
        for proc in workers:
            proc.join(timeout=5)
        sys.exit(0)

    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, _on_signal)

    # `run_worker` resets these handlers to SIG_DFL inside each child.
    for rank in range(world_size):
        proc = multiprocessing.Process(
            target=run_worker,
            args=(rank, world_size, ports[rank], endpoints, spu_endpoints),
        )
        proc.start()
        workers.append(proc)

    print("\nWorkers started. Press Ctrl+C to stop.")

    for proc in workers:
        proc.join()
274
+
275
+
276
def cmd_status(args: argparse.Namespace) -> None:
    """Probe ``GET <endpoint>/health`` on each worker and print the outcome.

    Any exception (connection failure, HTTP error status, undecodable body)
    is reported for that endpoint, and the sweep continues.
    """
    import httpx

    endpoints, _ports, world_size, _spu = build_endpoints(args)

    print(f"Checking {len(endpoints)} endpoints (world_size={world_size})...")
    for health_url in (f"{ep}/health" for ep in endpoints):
        try:
            response = httpx.get(health_url, timeout=3.0)
            response.raise_for_status()
            print(f"OK {health_url} -> {response.json()}")
        except Exception as exc:
            print(f"ERR {health_url} -> {exc}")
291
+
292
+
293
def load_user_module(path: str) -> ModuleType:
    """Load a Python source file as the module ``mp_user_module``.

    The module is registered in ``sys.modules`` before it executes, following
    the importlib "import a source file directly" recipe. It is removed again
    if execution raises, so a failed load never leaves a half-initialized
    module behind.

    Args:
        path: Filesystem path of the user script.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ImportError: if no import spec/loader can be created for ``path``.
        Exception: whatever the user module itself raises while executing.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    spec = importlib.util.spec_from_file_location("mp_user_module", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import module from {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the import system: drop the partially initialized module.
        sys.modules.pop(spec.name, None)
        raise
    return module
304
+
305
+
306
def resolve_entry(module: ModuleType, name: str) -> Callable[..., Any]:
    """Return the callable attribute ``name`` of ``module``.

    Raises:
        AttributeError: if the attribute is missing or is not callable.
    """
    candidate = getattr(module, name, None)
    # callable(None) is False, so this also covers a missing attribute.
    if callable(candidate):
        return cast(Callable[..., Any], candidate)
    raise AttributeError(f"Entry function '{name}' not found or not callable")
311
+
312
+
313
+ def parse_spu_endpoints(
314
+ raw: str | None, world_size: int, default: dict[int, str] | None
315
+ ) -> dict[int, str] | None:
316
+ if raw is None:
317
+ return default
318
+ parts = [p.strip() for p in raw.split(",") if p.strip()]
319
+ if len(parts) != world_size:
320
+ raise ValueError("spu-endpoints count must match world size")
321
+ return {i: parts[i] for i in range(world_size)}
322
+
323
+
324
def cmd_run(args: argparse.Namespace) -> None:
    """Run a user job on an HTTP cluster or on the local simulator.

    The cluster comes from ``--config`` when given, otherwise from the CLI
    endpoint flags. The steps are:

    1. Create the driver and push it as the active context.
    2. Load the script ``--file``.
    3. Call its ``--entry`` function with the ``--args`` strings, printing a
       non-``None`` result.

    The context is always popped, and the driver is shut down (when it has a
    ``shutdown`` method), even if loading the script, resolving the entry, or
    running it raises.
    """
    from mplang.v2 import make_driver, make_simulator
    from mplang.v2.edsl.context import pop_context, push_context
    from mplang.v2.libs.device import ClusterSpec

    cluster: ClusterSpec

    if args.config:
        # Load cluster from config file
        with open(args.config, encoding="utf-8") as f:
            conf = yaml.safe_load(f)
        cluster = ClusterSpec.from_dict(conf)
    else:
        # Build cluster from CLI arguments
        endpoints, _, world_size, _ = build_endpoints(args)
        cluster = ClusterSpec.simple(world_size, endpoints=endpoints)

    driver: Any
    if args.backend == "sim":
        driver = make_simulator(
            cluster.world_size,
            cluster_spec=cluster,
            enable_tracing=getattr(args, "profile", False),
        )
    else:
        driver = make_driver(cluster.endpoints, cluster_spec=cluster)

    try:
        push_context(driver)
        try:
            # The script is imported with the driver context already active,
            # as before; module-level code in it may depend on that.
            module = load_user_module(args.file)
            entry = resolve_entry(module, args.entry)
            # Entry function doesn't need driver parameter - it uses context
            result = entry(*args.args)
            if result is not None:
                print(result)
        finally:
            pop_context()
    finally:
        if hasattr(driver, "shutdown"):
            driver.shutdown()
369
+
370
+
371
def cmd_trace_merge(args: argparse.Namespace) -> None:
    """Merge multiple Chrome Trace JSON files into a single file.

    Inputs are the files matching ``args.pattern``, excluding ``args.output``
    itself, processed in sorted order. Both the JSON Object format
    (``{"traceEvents": [...]}``) and the bare JSON Array format are accepted.

    Per-file remapping:

    * Filename ends in ``_rank_<N>.json``: each event's ``pid`` becomes
      ``(N + 1) * 10000 + pid % 10000`` and ``args["rank"] = N`` is added, so
      each rank occupies its own pid range.
    * Other files: ``pid % 10000``.

    A file that fails to load or remap is reported and skipped as a whole.
    The merged events are written to ``args.output`` as
    ``{"traceEvents": [...]}``. Exits with status 1 when no input file
    matches.
    """
    pattern = args.pattern
    output_file = args.output

    # Sort for a deterministic merge order. Skip the output file itself: a
    # broad pattern such as '*.json' would otherwise re-ingest a previous
    # merge result and duplicate every event.
    output_abs = os.path.abspath(output_file)
    files = sorted(f for f in glob.glob(pattern) if os.path.abspath(f) != output_abs)
    if not files:
        print(f"No files found matching pattern: {pattern}")
        sys.exit(1)

    print(f"Found {len(files)} trace files.")

    merged_events: list[Any] = []

    # Regex to extract rank from filename if present (e.g., trace_..._rank_0.json)
    rank_pattern = re.compile(r"_rank_(\d+)\.json$")

    for fname in files:
        print(f"Processing {fname}...")
        try:
            with open(fname, encoding="utf-8") as f:
                data = json.load(f)
            # Chrome Trace allows a bare array of events as well as an object.
            events = data if isinstance(data, list) else data.get("traceEvents", [])

            match = rank_pattern.search(fname)
            rank = int(match.group(1)) if match else None
            # rank 0 -> 10000, rank 1 -> 20000, ...; files without a rank use 0.
            pid_offset = 0 if rank is None else (rank + 1) * 10000

            # Collect per file, so a bad event cannot leave a half-merged file.
            file_events = []
            for event in events:
                if "pid" in event:
                    # Shift into the rank's pid range so identical OS pids on
                    # different machines don't collide; keeping pid % 10000
                    # preserves thread grouping within a rank.
                    event["pid"] = pid_offset + (event["pid"] % 10000)
                if rank is not None:
                    event_args = event.get("args", {})
                    event_args["rank"] = rank
                    event["args"] = event_args
                file_events.append(event)
        except Exception as e:
            # Best effort: report the bad file and continue with the rest.
            print(f"Error processing {fname}: {e}")
            continue
        merged_events.extend(file_events)

    # Write merged file
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump({"traceEvents": merged_events}, f)

    print(f"Successfully merged {len(merged_events)} events into {output_file}")
431
+
432
+
433
def cmd_objects(args: argparse.Namespace) -> None:
    """Print a table of the objects each worker reports via ``GET /objects``.

    At most three object names are shown per worker, followed by ``...`` when
    there are more. A non-200 response or any exception is shown as an
    ``Err`` row for that worker.
    """
    endpoints, _, world_size, _ = build_endpoints(args)

    rule = "-" * 80
    print(f"Listing objects on {world_size} workers...")
    print(rule)
    print(f"{'Rank':<6} | {'Endpoint':<25} | {'Count':<6} | {'Objects'}")
    print(rule)

    import httpx

    for rank in range(world_size):
        endpoint = endpoints[rank]
        row_prefix = f"{rank:<6} | {endpoint:<25} | "
        try:
            resp = httpx.get(f"{endpoint}/objects", timeout=2.0)
            if resp.status_code != 200:
                print(row_prefix + f"{'Err':<6} | Status {resp.status_code}")
                continue
            names = resp.json()["objects"]
            shown = ", ".join(names[:3])
            if len(names) > 3:
                shown += ", ..."
            print(row_prefix + f"{len(names):<6} | {shown}")
        except Exception as e:
            print(row_prefix + f"{'Err':<6} | {e}")
462
+
463
+
464
def main() -> None:
    """Build the argument parser and dispatch to the selected ``cmd_*`` handler.

    Prints the matching help text when no command, or no ``config`` / ``trace``
    sub-command, is given.
    """
    parser = argparse.ArgumentParser(
        description="MPLang2 cluster and job CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # 'config' subcommand
    config_parser = subparsers.add_parser("config", help="Configuration management")
    config_subparsers = config_parser.add_subparsers(
        dest="config_command", help="Config commands"
    )

    # 'config gen'
    gen_parser = config_subparsers.add_parser("gen", help="Generate cluster config")
    gen_parser.add_argument(
        "-w", "--world-size", type=int, default=3, help="Number of workers"
    )
    gen_parser.add_argument(
        "-p", "--base-port", type=int, default=8100, help="Base port"
    )
    gen_parser.add_argument(
        "--spu-base-port", type=int, default=None, help="Base port for SPU"
    )
    gen_parser.add_argument("-o", "--output", type=str, help="Output file path")

    # 'worker' subcommand
    worker_parser = subparsers.add_parser("worker", help="Start a single worker")
    add_cluster_args(worker_parser)
    worker_parser.add_argument(
        "--rank", type=int, required=True, help="Rank of this worker"
    )

    # 'up' subcommand
    up_parser = subparsers.add_parser("up", help="Start local cluster (all workers)")
    add_cluster_args(up_parser)

    # 'status' subcommand
    status_parser = subparsers.add_parser("status", help="Check worker health")
    add_cluster_args(status_parser, include_ports=True)

    # 'run' subcommand
    run_parser = subparsers.add_parser("run", help="Run a user job")
    add_cluster_args(run_parser, include_ports=True)
    run_parser.add_argument("-f", "--file", required=True, help="Path to user script")
    run_parser.add_argument(
        "--entry",
        default="__mp_main__",
        help="Entry function name in the user script (default: __mp_main__)",
    )
    run_parser.add_argument(
        "--backend",
        choices=["http", "sim"],
        default="http",
        help="Execution backend: http (cluster) or sim (local simulator)",
    )
    run_parser.add_argument(
        "--args",
        nargs="*",
        default=[],
        help="Arguments passed to the entry function",
    )
    run_parser.add_argument(
        "--profile",
        action="store_true",
        help="Enable performance profiling (only for sim backend)",
    )

    # 'sim' subcommand
    sim_parser = subparsers.add_parser("sim", help="Run a user job in local simulator")
    add_cluster_args(sim_parser, include_ports=False)
    sim_parser.add_argument("-f", "--file", required=True, help="Path to user script")
    sim_parser.add_argument(
        "--entry",
        default="__mp_main__",
        help="Entry function name in the user script (default: __mp_main__)",
    )
    sim_parser.add_argument(
        "--args",
        nargs="*",
        default=[],
        help="Arguments passed to the entry function",
    )
    sim_parser.add_argument(
        "--profile",
        action="store_true",
        help="Enable performance profiling",
    )

    # 'objects' subcommand. It queries running workers just like 'status',
    # so it needs -p/--base-port to reach a cluster that `up -p ...` started
    # on a non-default port.
    objects_parser = subparsers.add_parser("objects", help="List objects on workers")
    add_cluster_args(objects_parser, include_ports=True)

    # 'trace' subcommand
    trace_parser = subparsers.add_parser("trace", help="Trace utilities")
    trace_subparsers = trace_parser.add_subparsers(
        dest="trace_command", help="Trace commands"
    )

    # 'trace merge'
    merge_parser = trace_subparsers.add_parser("merge", help="Merge trace files")
    merge_parser.add_argument(
        "pattern", help="Glob pattern for trace files (e.g. 'trace_*.json')"
    )
    merge_parser.add_argument(
        "-o", "--output", default="merged_trace.json", help="Output filename"
    )

    args = parser.parse_args()

    if args.command == "config":
        if args.config_command == "gen":
            cmd_config_gen(args)
        else:
            config_parser.print_help()
    elif args.command == "trace":
        if args.trace_command == "merge":
            cmd_trace_merge(args)
        else:
            trace_parser.print_help()
    elif args.command == "worker":
        cmd_worker(args)
    elif args.command == "up":
        cmd_up(args)
    elif args.command == "status":
        cmd_status(args)
    elif args.command == "run":
        cmd_run(args)
    elif args.command == "sim":
        # 'sim' is 'run' pinned to the simulator backend.
        args.backend = "sim"
        cmd_run(args)
    elif args.command == "objects":
        cmd_objects(args)
    else:
        parser.print_help()
600
+
601
+
602
# Run the CLI when this file is executed directly or via `python -m mplang.v2.cli`.
if __name__ == "__main__":
    main()