mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,136 @@
1
+ # Simp Dialect Backend Design
2
+
3
+ ## Overview
4
+
5
+ The `simp` (Simple Multi-Party) dialect implements SPMD (Single Program Multiple Data) distributed execution. A single program is written once and executed across multiple parties, with the runtime handling distribution, communication, and synchronization.
6
+
7
+ ## Why Two Implementations?
8
+
9
+ The simp dialect requires **two separate backend implementations** because the same primitives (`pcall`, `shuffle`, `converge`) have fundamentally different semantics depending on where they execute:
10
+
11
+ | Primitive | Driver (Host) | Worker |
12
+ |-----------|---------------|--------|
13
+ | `pcall` | Dispatch work to workers | Execute locally |
14
+ | `shuffle` | Route data between workers | Send/Receive via communicator |
15
+ | `converge` | Merge DriverVars | Pick non-null value |
16
+
17
+ This is the essence of SPMD: the Driver orchestrates, Workers execute.
18
+
19
+ ## Architecture
20
+
21
+ ```
22
+ ┌─────────────────────────────────────────────────────────────────┐
23
+ │ dialects/simp.py │
24
+ │ (Primitive definitions) │
25
+ └─────────────────────────────────────────────────────────────────┘
26
+
27
+ ┌─────────────┴─────────────┐
28
+ ▼ ▼
29
+ ┌───────────────────────────┐ ┌───────────────────────────┐
30
+ │ simp_driver/ │ │ simp_worker/ │
31
+ │ (Host/Driver side) │ │ (Worker side) │
32
+ ├───────────────────────────┤ ├───────────────────────────┤
33
+ │ state.py SimpDriver │ │ state.py SimpWorker │
34
+ │ ops.py DRIVER_HANDLERS │ │ ops.py WORKER_HANDLERS│
35
+ │ values.py DriverVar │ │ │
36
+ │ mem.py SimpMemDriver │ │ mem.py LocalMesh │
37
+ │ http.py SimpHttpDriver │ │ http.py HTTP Server │
38
+ └───────────────────────────┘ └───────────────────────────┘
39
+ ```
40
+
41
+ ## Directory Structure
42
+
43
+ ```
44
+ backends/
45
+ ├── simp_driver/ # Driver/Host side
46
+ │ ├── __init__.py # Exports
47
+ │ ├── state.py # SimpDriver (abstract base)
48
+ │ ├── ops.py # DRIVER_HANDLERS
49
+ │ ├── values.py # DriverVar
50
+ │ ├── mem.py # MemCluster + SimpMemDriver + make_simulator
51
+ │ └── http.py # SimpHttpDriver + make_driver
52
+
53
+ ├── simp_worker/ # Worker side
54
+ │ ├── __init__.py # Exports
55
+ │ ├── state.py # SimpWorker (DialectState)
56
+ │ ├── ops.py # WORKER_HANDLERS
57
+ │ ├── mem.py # LocalMesh + ThreadCommunicator
58
+ │ └── http.py # HTTP Worker Server
59
+ ```
60
+
61
+ ## Key Classes
62
+
63
+ ### Driver Side
64
+
65
+ ```python
66
+ class SimpDriver(DialectState, ABC):
67
+ """Abstract interface for simp drivers."""
68
+ dialect_name = "simp"
69
+ world_size: int
70
+
71
+ @abstractmethod
72
+ def submit(self, rank, graph, inputs, job_id=None) -> Future: ...
73
+ @abstractmethod
74
+ def fetch(self, rank, uri) -> Future: ...
75
+ @abstractmethod
76
+ def collect(self, futures) -> list: ...
77
+
78
+ class SimpMemDriver(SimpDriver):
79
+ """In-memory IPC via ThreadPoolExecutor."""
80
+
81
+ class SimpHttpDriver(SimpDriver):
82
+ """HTTP IPC via httpx."""
83
+ ```
84
+
85
+ ### Worker Side
86
+
87
+ ```python
88
+ class SimpWorker(DialectState):
89
+ """Worker state with communicator and store."""
90
+ dialect_name = "simp"
91
+ rank: int
92
+ world_size: int
93
+ communicator: Any # ThreadCommunicator or HTTP client
94
+ store: ObjectStore
95
+ ```
96
+
97
+ ## IPC Symmetry
98
+
99
+ | IPC Type | Driver | Worker |
100
+ |----------|--------|--------|
101
+ | Memory | `simp_driver/mem.py` | `simp_worker/mem.py` |
102
+ | HTTP | `simp_driver/http.py` | `simp_worker/http.py` |
103
+
104
+ ## Factory Functions
105
+
106
+ ```python
107
+ # Create local simulator (memory IPC)
108
+ interp = simp.make_simulator(world_size=3)
109
+
110
+ # Create remote driver (HTTP IPC)
111
+ interp = simp.make_driver(["http://w1:8000", "http://w2:8000"])
112
+ ```
113
+
114
+ ## Data Flow
115
+
116
+ ```
117
+ User Code
118
+
119
+
120
+ simp.pcall(parties=(0,1), fn, args)
121
+
122
+ ▼ (Driver Interpreter)
123
+ DRIVER_HANDLERS["simp.pcall"]
124
+
125
+ ├─► driver.submit(rank=0, graph, inputs)
126
+ └─► driver.submit(rank=1, graph, inputs)
127
+
128
+ ▼ (IPC: Memory or HTTP)
129
+ Worker Interpreters
130
+
131
+
132
+ WORKER_HANDLERS["simp.pcall"]
133
+
134
+
135
+ Local Execution
136
+ ```
@@ -0,0 +1,41 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simp Driver package.
16
+
17
+ Provides Driver-side state, values, and ops for the simp dialect.
18
+ """
19
+
20
+ from mplang.v2.backends.simp_driver.http import SimpHttpDriver, make_driver
21
+ from mplang.v2.backends.simp_driver.mem import (
22
+ LocalCluster,
23
+ MemCluster,
24
+ SimpMemDriver,
25
+ make_simulator,
26
+ )
27
+ from mplang.v2.backends.simp_driver.ops import DRIVER_HANDLERS
28
+ from mplang.v2.backends.simp_driver.state import SimpDriver
29
+ from mplang.v2.backends.simp_driver.values import DriverVar
30
+
31
+ __all__ = [
32
+ "DRIVER_HANDLERS",
33
+ "DriverVar",
34
+ "LocalCluster",
35
+ "MemCluster",
36
+ "SimpDriver",
37
+ "SimpHttpDriver",
38
+ "SimpMemDriver",
39
+ "make_driver",
40
+ "make_simulator",
41
+ ]
@@ -0,0 +1,168 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simp Driver HTTP IPC (SimpHttpDriver, make_driver)."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import concurrent.futures
20
+ import os
21
+ import pathlib
22
+ from typing import TYPE_CHECKING, Any
23
+
24
+ import httpx
25
+
26
+ from mplang.v2.backends.simp_driver.state import SimpDriver
27
+ from mplang.v2.edsl import serde
28
+ from mplang.v2.runtime.interpreter import Interpreter
29
+ from mplang.v2.runtime.object_store import ObjectStore
30
+
31
+ if TYPE_CHECKING:
32
+ from concurrent.futures import Future
33
+
34
+ from mplang.v2.edsl.graph import Graph
35
+ from mplang.v2.libs.device import ClusterSpec
36
+
37
+
38
class SimpHttpDriver(SimpDriver):
    """Simp Driver for remote HTTP IPC.

    Implements the submit/fetch/collect interface by POSTing to each worker's
    ``/exec`` and ``/fetch`` routes. Requests are issued from a thread pool
    with one thread per endpoint, so ``submit`` and ``fetch`` return futures.
    """

    dialect_name: str = "simp"

    def __init__(
        self,
        endpoints: list[str],
        *,
        cluster_spec: ClusterSpec | None = None,
    ) -> None:
        """Create remote simp driver.

        Args:
            endpoints: HTTP endpoints for workers; the list index is used as
                the worker rank. Entries without an ``http://`` or
                ``https://`` scheme get ``http://`` prepended, and trailing
                slashes are dropped so route paths can be appended safely.
            cluster_spec: Optional cluster specification for metadata. When
                given, its ``cluster_id`` names the driver root directory.

        Raises:
            ValueError: If ``endpoints`` is empty.
        """
        # Fail before allocating the HTTP client: an empty list would
        # otherwise only surface as ThreadPoolExecutor(max_workers=0) raising
        # ValueError after the client has already been created (and leaked).
        if not endpoints:
            raise ValueError("endpoints must contain at least one worker endpoint")

        # Normalize endpoints
        self._endpoints: list[str] = []
        for ep in endpoints:
            if not ep.startswith(("http://", "https://")):
                ep = f"http://{ep}"
            # "http://w1:8000/" + "/exec" would yield a "//exec" path.
            self._endpoints.append(ep.rstrip("/"))

        self._world_size = len(self._endpoints)
        self._cluster_spec = cluster_spec

        # Construct driver root
        data_root = pathlib.Path(os.environ.get("MPLANG_DATA_ROOT", ".mpl"))
        if cluster_spec:
            self.driver_root = data_root / cluster_spec.cluster_id / "__driver__"
        else:
            self.driver_root = data_root / "__remote__" / "__driver__"

        # HTTP client and executor. timeout=None turns off httpx's request
        # timeouts, so a request waits for as long as the worker takes.
        self._client = httpx.Client(timeout=None)
        self._executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self._world_size
        )

    @property
    def world_size(self) -> int:
        """Number of worker endpoints this driver talks to."""
        return self._world_size

    def submit(
        self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
    ) -> Future[Any]:
        """Submit execution to remote worker via HTTP.

        Returns:
            A future resolving to the deserialized ``result`` of the worker's
            ``/exec`` response.
        """
        return self._executor.submit(self._do_request, rank, graph, inputs, job_id)

    def collect(self, futures: list[Future[Any]]) -> list[Any]:
        """Collect results from futures, in order.

        Blocks on each future in turn; an exception raised by a request
        propagates from the corresponding ``result()`` call.
        """
        return [f.result() for f in futures]

    def fetch(self, rank: int, uri: str) -> Future[Any]:
        """Fetch data from remote worker.

        Returns:
            A future resolving to the deserialized ``result`` of the worker's
            ``/fetch`` response.
        """
        return self._executor.submit(self._do_fetch, rank, uri)

    def _post(self, rank: int, route: str, payload: dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``route`` on worker ``rank`` and decode the result.

        Raises:
            httpx.HTTPStatusError: If the worker answers with an error status.
        """
        resp = self._client.post(f"{self._endpoints[rank]}/{route}", json=payload)
        resp.raise_for_status()
        return serde.loads_b64(resp.json()["result"])

    def _do_request(
        self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
    ) -> Any:
        """Execute HTTP request."""
        payload = {
            "graph": serde.dumps_b64(graph),
            "inputs": serde.dumps_b64(inputs),
        }
        # Only forward a job id when the caller supplied a non-empty one.
        if job_id:
            payload["job_id"] = job_id
        return self._post(rank, "exec", payload)

    def _do_fetch(self, rank: int, uri: str) -> Any:
        """Fetch data from remote worker."""
        return self._post(rank, "fetch", {"uri": uri})

    def shutdown(self) -> None:
        """Close HTTP client and executor."""
        try:
            self._client.close()
        finally:
            # Release the pool threads even if closing the client raised.
            self._executor.shutdown()
+
126
+
127
def make_driver(endpoints: list[str], *, cluster_spec: Any = None) -> Interpreter:
    """Build an Interpreter that drives remote SIMP workers over HTTP.

    A SimpHttpDriver is created for ``endpoints`` and registered on the
    returned Interpreter as the ``"simp"`` dialect state.

    Args:
        endpoints: List of HTTP endpoints for workers.
        cluster_spec: Optional ClusterSpec for metadata. When omitted, one is
            built with ``ClusterSpec.simple`` from the endpoint list.

    Returns:
        Configured Interpreter with simp state attached.

    Example:
        >>> interp = make_driver(["http://worker1:8000", "http://worker2:8000"])
        >>> with interp:
        ...     result = my_func()
    """
    from collections.abc import Callable

    from mplang.v2.backends.simp_driver.ops import DRIVER_HANDLERS

    spec = cluster_spec
    if spec is None:
        from mplang.v2.libs.device import ClusterSpec

        spec = ClusterSpec.simple(world_size=len(endpoints), endpoints=endpoints)

    driver_state = SimpHttpDriver(endpoints, cluster_spec=spec)
    root = driver_state.driver_root

    # Private copy so the interpreter never shares the module-level table.
    op_table: dict[str, Callable[..., Any]] = dict(DRIVER_HANDLERS)  # type: ignore[arg-type]
    interpreter = Interpreter(
        name="DriverInterpreter",
        root_dir=root,
        handlers=op_table,
        store=ObjectStore(fs_root=str(root)),
    )
    interpreter.set_dialect_state("simp", driver_state)
    interpreter._cluster_spec = spec  # type: ignore[attr-defined]
    return interpreter
@@ -0,0 +1,280 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simp Driver memory IPC (MemCluster, SimpMemDriver, make_simulator)."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import concurrent.futures
20
+ import os
21
+ import pathlib
22
+ from collections.abc import Callable
23
+ from typing import TYPE_CHECKING, Any, cast
24
+
25
+ from mplang.v2.backends.simp_driver.state import SimpDriver
26
+ from mplang.v2.backends.simp_worker import WORKER_HANDLERS, SimpWorker
27
+ from mplang.v2.backends.simp_worker.mem import LocalMesh
28
+ from mplang.v2.runtime.interpreter import ExecutionTracer, Interpreter
29
+ from mplang.v2.runtime.object_store import ObjectStore
30
+
31
+ if TYPE_CHECKING:
32
+ from concurrent.futures import Future
33
+
34
+ from mplang.v2.edsl.graph import Graph
35
+ from mplang.v2.libs.device import ClusterSpec
36
+
37
+
38
class MemCluster:
    """Owner of the in-process worker Interpreters used for local runs.

    The cluster builds the workers and the mesh they communicate over. It is
    not itself attached to an Interpreter; ``create_state`` hands out a
    SimpMemDriver for that purpose.
    """

    def __init__(
        self,
        world_size: int,
        *,
        cluster_spec: ClusterSpec | None = None,
        enable_tracing: bool = False,
    ) -> None:
        """Create a local memory cluster.

        Args:
            world_size: Number of workers.
            cluster_spec: Optional cluster specification for metadata; its
                ``cluster_id`` names the on-disk cluster directory.
            enable_tracing: If True, enable execution tracing.
        """
        self._world_size = world_size
        self._cluster_spec = cluster_spec

        # Directory layout: <MPLANG_DATA_ROOT or .mpl>/<cluster id>/...
        base_dir = pathlib.Path(os.environ.get("MPLANG_DATA_ROOT", ".mpl"))
        if cluster_spec:
            cluster_id = cluster_spec.cluster_id
        else:
            cluster_id = f"local_{world_size}"
        cluster_root = base_dir / cluster_id
        self.host_root = cluster_root / "__host__"

        # Communication mesh shared by all workers.
        self._mesh = LocalMesh(world_size)

        # One tracer is shared by every worker interpreter.
        self.tracer: ExecutionTracer = ExecutionTracer(
            enabled=enable_tracing, trace_dir=self.host_root / "trace"
        )
        self.tracer.start()

        self._workers: list[Interpreter] = []
        for rank in range(world_size):
            self._workers.append(self._build_worker(rank, cluster_root))

    def _build_worker(self, rank: int, cluster_root: pathlib.Path) -> Interpreter:
        """Construct the worker Interpreter for ``rank`` rooted under ``cluster_root``."""
        worker_root = cluster_root / f"node{rank}"
        store = ObjectStore(fs_root=str(worker_root / "store"))

        state = SimpWorker(
            rank=rank,
            world_size=self._world_size,
            communicator=self._mesh.comms[rank],
            store=store,
        )

        handler_table: dict[str, Callable[..., Any]] = {**WORKER_HANDLERS}  # type: ignore[dict-item]
        interp = Interpreter(
            name=f"Worker-{rank}",
            tracer=self.tracer,
            trace_pid=rank,
            store=store,
            root_dir=worker_root,
            handlers=handler_table,
        )
        interp.set_dialect_state("simp", state)

        # Each worker gets its own set of op names assigned to ``async_ops``.
        interp.async_ops = {
            "bfv.add",
            "bfv.mul",
            "bfv.rotate",
            "bfv.batch_encode",
            "bfv.relinearize",
            "bfv.encrypt",
            "bfv.decrypt",
            "field.solve_okvs",
            "field.decode_okvs",
            "field.aes_expand",
            "field.mul",
            "simp.shuffle",
        }
        return interp

    @property
    def world_size(self) -> int:
        """Number of workers in this cluster."""
        return self._world_size

    @property
    def workers(self) -> list[Interpreter]:
        """The worker Interpreters, indexed by rank."""
        return self._workers

    def create_state(self) -> SimpMemDriver:
        """Return a SimpMemDriver bound to this cluster's workers and mesh."""
        return SimpMemDriver(
            world_size=self._world_size,
            workers=self._workers,
            mesh=self._mesh,
        )

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the mesh, forwarding ``wait`` to it."""
        self._mesh.shutdown(wait=wait)
137
+
138
class SimpMemDriver(SimpDriver):
    """Simp Driver backed by in-process workers.

    Provides submit/fetch/collect by scheduling work on the mesh's executor
    and reading worker stores directly. Instances are produced by MemCluster
    and attached to a Driver Interpreter.
    """

    dialect_name: str = "simp"

    def __init__(
        self,
        world_size: int,
        workers: list[Interpreter],
        mesh: Any,  # LocalMesh from simp_worker.mem
    ) -> None:
        """Bind the driver to existing worker interpreters and their mesh."""
        self._world_size = world_size
        self._workers = workers
        self._mesh = mesh

    def shutdown(self) -> None:
        """Shut down the mesh this driver dispatches through."""
        self._mesh.shutdown()

    @property
    def world_size(self) -> int:
        """Number of workers this driver addresses."""
        return self._world_size

    @property
    def workers(self) -> list[Interpreter]:
        """Worker interpreters (for backward compatibility)."""
        return self._workers

    def submit(
        self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
    ) -> Future[Any]:
        """Schedule ``graph`` on worker ``rank`` via the mesh executor."""
        pending = self._mesh.executor.submit(
            self._run_worker, rank, graph, inputs, job_id=job_id
        )
        return cast("Future[Any]", pending)

    def collect(self, futures: list[Future[Any]]) -> list[Any]:
        """Wait for ``futures`` and return their results in order.

        Waiting stops at the first failure. In that case every future is
        cancelled, the mesh is shut down without waiting, and the exception
        is re-raised.
        """
        finished, _unfinished = concurrent.futures.wait(
            futures, return_when=concurrent.futures.FIRST_EXCEPTION
        )
        for fut in finished:
            err = fut.exception()
            if not err:
                continue
            for other in futures:
                other.cancel()
            self._mesh.shutdown(wait=False)
            raise err
        return [fut.result() for fut in futures]

    def fetch(self, rank: int, uri: str) -> Future[Any]:
        """Read ``uri`` from worker ``rank``'s store on the mesh executor."""
        state = cast(SimpWorker, self._workers[rank].get_dialect_state("simp"))

        def _load() -> Any:
            return state.store.get(uri)

        return self._mesh.executor.submit(_load)  # type: ignore[no-any-return]

    def _run_worker(
        self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
    ) -> Any:
        """Evaluate ``graph`` on worker ``rank`` and return store references.

        Non-None inputs are looked up in the worker's store before evaluation.
        Returns None when the graph declares no outputs; otherwise a list in
        which each non-None result has been put into the worker's store.
        """
        interp = self._workers[rank]
        state = cast(SimpWorker, interp.get_dialect_state("simp"))

        # None marks an input this rank holds no data for; pass it through.
        resolved: list[Any] = []
        for ref in inputs:
            resolved.append(None if ref is None else state.store.get(ref))

        outputs = interp.evaluate_graph(graph, resolved, job_id)

        if not graph.outputs:
            return None
        return [None if out is None else state.store.put(out) for out in outputs]
222
+
223
+
224
def make_simulator(
    world_size: int,
    *,
    cluster_spec: Any = None,
    enable_tracing: bool = False,
) -> Interpreter:
    """Build an Interpreter that simulates SIMP parties in-process.

    A MemCluster with ``world_size`` workers is created, and its
    SimpMemDriver is registered on the returned Interpreter as the ``"simp"``
    dialect state.

    Args:
        world_size: Number of simulated parties.
        cluster_spec: Optional ClusterSpec for metadata. When omitted, one is
            built with ``ClusterSpec.simple(world_size)``.
        enable_tracing: If True, enable execution tracing.

    Returns:
        Configured Interpreter with simp state attached.

    Example:
        >>> interp = make_simulator(2)
        >>> with interp:
        ...     result = my_func()
    """
    from mplang.v2.backends.simp_driver.ops import DRIVER_HANDLERS

    spec = cluster_spec
    if spec is None:
        from mplang.v2.libs.device import ClusterSpec

        spec = ClusterSpec.simple(world_size)

    mem_cluster = MemCluster(
        world_size=world_size,
        cluster_spec=spec,
        enable_tracing=enable_tracing,
    )
    driver_state = mem_cluster.create_state()
    host_root = mem_cluster.host_root

    # Private copy so the interpreter never shares the module-level table.
    op_table: dict[str, Callable[..., Any]] = dict(DRIVER_HANDLERS)  # type: ignore[arg-type]
    host_interp = Interpreter(
        name="HostInterpreter",
        root_dir=host_root,
        handlers=op_table,
        tracer=mem_cluster.tracer,
        store=ObjectStore(fs_root=str(host_root)),
    )
    host_interp.set_dialect_state("simp", driver_state)

    # Hold a reference to the cluster so it is not garbage-collected while
    # the interpreter is in use.
    host_interp._simp_cluster = mem_cluster  # type: ignore[attr-defined]
    host_interp._cluster_spec = spec  # type: ignore[attr-defined]

    return host_interp
277
+
278
+
279
+ # Backward compatibility alias
280
+ LocalCluster = MemCluster