mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,135 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simp Driver ops (DRIVER_HANDLERS).
16
+
17
+ Unified SPMD dispatch pattern for all SIMP operations.
18
+ All ops: wrap → dispatch to ALL workers → collect DriverVar(s).
19
+ Op-specific logic lives in Worker handlers (simp_worker/ops.py).
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from typing import Any
25
+
26
+ from mplang.v2.backends.simp_driver.values import DriverVar
27
+ from mplang.v2.dialects import simp
28
+ from mplang.v2.edsl.graph import Graph, Operation
29
+ from mplang.v2.edsl.typing import CustomType
30
+
31
+
32
+ def _get_driver_context(interpreter: Any) -> Any:
33
+ """Get the simp driver state from interpreter."""
34
+ state = interpreter.get_dialect_state("simp")
35
+ if state is None:
36
+ raise RuntimeError("Interpreter must have simp dialect state attached")
37
+ return state
38
+
39
+
40
def _wrap_op_as_graph(op: Operation) -> Graph:
    """Build a single-operation Graph around ``op`` for worker submission.

    The new graph has one ``Any``-typed input per input of ``op`` (named
    ``in_0``, ``in_1``, ...). It contains a copy of ``op`` (same opcode and
    regions, shallow-copied attrs) fed by those inputs, and every output of
    that copied op is marked as a graph output. If ``op`` has no outputs, the
    copy is given a single ``Any``-typed output.
    """
    wrapper = Graph()
    placeholder = CustomType("Any")

    # One graph-level input per operand of the original op.
    inputs = []
    for idx in range(len(op.inputs)):
        inputs.append(wrapper.add_input(f"in_{idx}", placeholder))

    # Reuse the original output types; fall back to one Any output.
    if op.outputs:
        out_types = [o.type for o in op.outputs]
    else:
        out_types = [placeholder]

    wrapper.add_op(
        opcode=op.opcode,
        inputs=inputs,
        output_types=out_types,
        attrs=op.attrs.copy(),
        regions=op.regions,
    )

    # The op just added is the last one; expose all its results.
    new_op = wrapper.operations[-1]
    for value in new_op.outputs:
        wrapper.add_output(value)
    return wrapper
65
+
66
+
67
def _collect_to_hostvars(results: list[Any], num_outputs: int, world_size: int) -> Any:
    """Collect per-worker results into DriverVar(s).

    Args:
        results: One entry per worker; the length must equal ``world_size``.
            Each entry is either ``None`` (the worker returned nothing) or a
            list of URIs, one per output.
        num_outputs: Number of outputs per worker.
        world_size: Total number of workers; used to validate ``results``.

    Returns:
        ``None`` if ``num_outputs == 0``. Otherwise a single DriverVar if
        ``num_outputs == 1``, else a list of ``num_outputs`` DriverVars, each
        holding one entry per worker.

    Raises:
        ValueError: If ``len(results) != world_size``.
    """
    if num_outputs == 0:
        return None

    # A short or long result list would silently misalign ranks with values
    # in the DriverVar(s) built below, so reject it up front.
    if len(results) != world_size:
        raise ValueError(
            f"Expected {world_size} worker results, got {len(results)}"
        )

    # Transpose [worker][output] -> [output][worker]. A worker that returned
    # None contributes None to every output slot for its rank.
    per_output = [
        DriverVar([res[out_idx] if res is not None else None for res in results])
        for out_idx in range(num_outputs)
    ]
    return per_output[0] if num_outputs == 1 else per_output
93
+
94
+
95
def _generic_simp_dispatch(interpreter: Any, op: Operation, *args: Any) -> Any:
    """Dispatch one SIMP op to every worker and gather DriverVar(s).

    Steps: wrap ``op`` in a one-op graph, submit it to each rank with that
    rank's inputs, wait for all results, then regroup them per output.
    This single handler serves all SIMP ops; the op-specific behaviour is
    implemented by the worker handlers.
    """
    driver = _get_driver_context(interpreter)
    n_parties = driver.world_size
    graph = _wrap_op_as_graph(op)

    def _inputs_for(rank: int) -> list[Any]:
        # DriverVar args are indexed per rank; anything else is broadcast as-is.
        return [a[rank] if isinstance(a, DriverVar) else a for a in args]

    pending = [driver.submit(r, graph, _inputs_for(r)) for r in range(n_parties)]
    outputs = driver.collect(pending)

    # An op without declared outputs is wrapped with one placeholder output.
    n_out = len(op.outputs) if op.outputs else 1
    return _collect_to_hostvars(outputs, n_out, n_parties)
122
+
123
+
124
+ # =============================================================================
125
+ # All SIMP ops use unified dispatch
126
+ # =============================================================================
127
+
128
# Every SIMP primitive is handled identically on the driver side; the
# op-specific work happens in the worker handlers (simp_worker/ops.py).
DRIVER_HANDLERS = {
    prim.name: _generic_simp_dispatch
    for prim in (
        simp.pcall_static_p,
        simp.pcall_dynamic_p,
        simp.shuffle_static_p,
        simp.converge_p,
        simp.uniform_cond_p,
        simp.while_loop_p,
    )
}
@@ -0,0 +1,60 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """SimpDriver abstract base class."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from abc import ABC, abstractmethod
20
+ from typing import TYPE_CHECKING, Any
21
+
22
+ from mplang.v2.runtime.dialect_state import DialectState
23
+
24
+ if TYPE_CHECKING:
25
+ from concurrent.futures import Future
26
+
27
+ from mplang.v2.edsl.graph import Graph
28
+
29
+
30
class SimpDriver(DialectState, ABC):
    """Base class for drivers that dispatch simp work to workers.

    Concrete drivers expose the worker count and implement the
    ``submit`` / ``fetch`` / ``collect`` trio. The class declares
    ``dialect_name = "simp"``.
    """

    dialect_name: str = "simp"

    @property
    @abstractmethod
    def world_size(self) -> int:
        """Total number of workers this driver talks to."""

    @abstractmethod
    def submit(
        self, rank: int, graph: Graph, inputs: list[Any], job_id: str | None = None
    ) -> Future[Any]:
        """Schedule ``graph`` on worker ``rank`` with ``inputs``; return a Future."""

    @abstractmethod
    def fetch(self, rank: int, uri: str) -> Future[Any]:
        """Request the object at ``uri`` from worker ``rank``; return a Future."""

    @abstractmethod
    def collect(self, futures: list[Future[Any]]) -> list[Any]:
        """Resolve ``futures`` into a list of their results."""
@@ -0,0 +1,52 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simp Driver values (DriverVar)."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Any, ClassVar, Self
20
+
21
+ from mplang.v2.edsl import serde
22
+ from mplang.v2.runtime.value import Value
23
+
24
+
25
@serde.register_class
class DriverVar(Value):
    """Driver-side handle holding one entry per party.

    ``values[i]`` is the entry for rank ``i``; ``world_size`` is the number
    of entries held.
    """

    _serde_kind: ClassVar[str] = "simp.DriverVar"

    def __init__(self, values: list[Any]):
        # The list is stored by reference, not copied.
        self.values = values

    @property
    def world_size(self) -> int:
        """Number of per-party entries."""
        entries = self.values
        return len(entries)

    def __repr__(self) -> str:
        inner = self.values
        return f"DriverVar({inner})"

    def __getitem__(self, idx: int) -> Any:
        entries = self.values
        return entries[idx]

    def to_json(self) -> dict[str, Any]:
        """Serialize as ``{"values": [...]}``."""
        return dict(values=self.values)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> Self:
        """Rebuild from the mapping produced by :meth:`to_json`."""
        stored = data["values"]
        return cls(values=stored)
@@ -0,0 +1,29 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simp Worker package.
16
+
17
+ Provides Worker-side state and ops for the simp dialect.
18
+ """
19
+
20
+ from mplang.v2.backends.simp_worker.mem import LocalMesh, ThreadCommunicator
21
+ from mplang.v2.backends.simp_worker.ops import WORKER_HANDLERS
22
+ from mplang.v2.backends.simp_worker.state import SimpWorker
23
+
24
# Public names re-exported by the simp worker package.
__all__ = ["WORKER_HANDLERS", "LocalMesh", "SimpWorker", "ThreadCommunicator"]
@@ -0,0 +1,354 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """SIMP HTTP Worker module.
16
+
17
+ Provides the HTTP-based worker entry point for distributed deployment.
18
+ This module contains:
19
+ - HttpCommunicator: HTTP-based inter-worker communication
20
+ - create_worker_app: Factory for FastAPI application
21
+
22
+ Usage:
23
+ # Start a worker server
24
+ from mplang.v2.backends.simp_worker.http import create_worker_app
25
+ import uvicorn
26
+
27
+ app = create_worker_app(rank=0, world_size=3, endpoints=[...])
28
+ uvicorn.run(app, host="0.0.0.0", port=8000)
29
+
30
+ Security:
31
+ This module uses secure JSON-based serialization (serde) for computation
32
+ graphs and data between workers. Unlike pickle, JSON deserialization
33
+ cannot execute arbitrary code, making it safe for network communication.
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ import concurrent.futures
39
+ import logging
40
+ import os
41
+ import pathlib
42
+ import threading
43
+ import time
44
+ from typing import Any
45
+
46
+ import httpx
47
+ from fastapi import FastAPI, HTTPException
48
+ from pydantic import BaseModel
49
+
50
+ from mplang.v2.backends import spu_impl as _spu_impl # noqa: F401
51
+ from mplang.v2.backends import tensor_impl as _tensor_impl # noqa: F401
52
+
53
+ # Register operation implementations (side-effect imports)
54
+ from mplang.v2.backends.simp_worker import SimpWorker
55
+ from mplang.v2.backends.simp_worker import ops as _simp_worker_ops # noqa: F401
56
+ from mplang.v2.edsl import serde
57
+ from mplang.v2.edsl.graph import Graph
58
+ from mplang.v2.runtime.interpreter import ExecutionTracer, Interpreter
59
+ from mplang.v2.runtime.object_store import ObjectStore
60
+
61
# Module-level logger shared by the communicator and the HTTP handlers below.
logger = logging.getLogger(__name__)
62
+
63
+
64
class HttpCommunicator:
    """Communicator using HTTP requests for inter-worker communication.

    Sends are queued on a background thread pool so the executing graph is
    not blocked on network I/O. Incoming data is delivered by the HTTP
    ``/comm`` endpoint through :meth:`on_receive` into a mailbox keyed by
    ``(from_rank, key)``, which :meth:`recv` blocks on.

    Attributes:
        rank: This worker's rank
        world_size: Total number of workers
        endpoints: HTTP endpoints for all workers
        tracer: Optional tracer that records ``comm.*`` events
    """

    def __init__(
        self,
        rank: int,
        world_size: int,
        endpoints: list[str],
        tracer: ExecutionTracer | None = None,
    ):
        self.rank = rank
        self.world_size = world_size
        self.endpoints = endpoints
        self.tracer = tracer
        # Received payloads keyed by (from_rank, key); guarded by _cond.
        self._mailbox: dict[tuple[int, str], Any] = {}
        self._cond = threading.Condition()
        self._send_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=world_size, thread_name_prefix=f"comm_send_{rank}"
        )
        # In-flight send futures. Several graph executions may use this
        # communicator concurrently, so appends and drains share a lock.
        self._pending_sends: list[concurrent.futures.Future[None]] = []
        self._pending_lock = threading.Lock()
        self.client = httpx.Client(timeout=None)

    def send(self, to: int, key: str, data: Any) -> None:
        """Send ``data`` to rank ``to`` under ``key`` asynchronously.

        The send is queued on the background executor; use
        :meth:`wait_pending_sends` to wait for completion.
        """
        future = self._send_executor.submit(self._do_send, to, key, data)
        with self._pending_lock:
            self._pending_sends.append(future)

    def _do_send(self, to: int, key: str, data: Any) -> None:
        """Encode ``data`` and PUT it to ``{endpoints[to]}/comm/{key}``.

        Raises:
            RuntimeError: If the request fails or returns an error status
                (also raised if trace logging inside the request block fails).
        """
        url = f"{self.endpoints[to]}/comm/{key}"
        logger.debug("Rank %s sending to %s key=%s", self.rank, to, key)

        # SPU channels (tag prefix "spu:") carry raw bytes: base64 them
        # directly instead of going through serde.
        if key.startswith("spu:") and isinstance(data, bytes):
            import base64

            payload = base64.b64encode(data).decode("ascii")
            is_raw_bytes = True
        else:
            # Secure JSON serialization
            payload = serde.dumps_b64(data)
            is_raw_bytes = False

        # Length of the encoded (base64 text) payload, not the raw data.
        size_bytes = len(payload)

        if self.tracer:
            # Instant event recording the payload size (start == end).
            now = time.time()
            self.tracer.log_custom_event(
                name="comm.send",
                start_ts=now,
                end_ts=now,
                cat="comm",
                args={"to": to, "key": key, "bytes": size_bytes},
            )

        try:
            t0 = time.time()
            resp = self.client.put(
                url,
                json={
                    "data": payload,
                    "from_rank": self.rank,
                    "is_raw_bytes": is_raw_bytes,
                },
            )
            resp.raise_for_status()
            t1 = time.time()
            if self.tracer:
                self.tracer.log_custom_event(
                    name="comm.send_req",
                    start_ts=t0,
                    end_ts=t1,
                    cat="comm",
                    args={"to": to, "key": key, "bytes": size_bytes},
                )
        except Exception as e:
            logger.error("Rank %s failed to send to %s: %s", self.rank, to, e)
            raise RuntimeError(f"Failed to send to {to} ({url}): {e}") from e

    def recv(self, frm: int, key: str) -> Any:
        """Receive data from another rank (blocking).

        Polls the mailbox with a 1s wait timeout. There is no overall
        deadline, so a message that never arrives blocks forever.
        """
        logger.debug("Rank %s waiting recv from %s key=%s", self.rank, frm, key)
        mailbox_key = (frm, key)
        with self._cond:
            while mailbox_key not in self._mailbox:
                self._cond.wait(timeout=1.0)
            return self._mailbox.pop(mailbox_key)

    def on_receive(self, from_rank: int, key: str, data: Any) -> None:
        """Store data delivered by the HTTP endpoint and wake waiters.

        Raises:
            RuntimeError: If an undelivered value for ``(from_rank, key)``
                is already in the mailbox.
        """
        mailbox_key = (from_rank, key)
        with self._cond:
            if mailbox_key in self._mailbox:
                raise RuntimeError(
                    f"Mailbox overflow: key {mailbox_key} already exists"
                )
            self._mailbox[mailbox_key] = data
            self._cond.notify_all()

    def wait_pending_sends(self) -> None:
        """Wait for every send queued so far to finish.

        Failures (including a 60s per-send timeout) are logged, not raised.
        """
        # Take ownership of the current batch atomically. Sends queued
        # concurrently by another execution stay in the list for a later
        # call instead of being discarded by a clear().
        with self._pending_lock:
            pending, self._pending_sends = self._pending_sends, []
        for future in pending:
            try:
                future.result(timeout=60.0)
            except Exception as e:
                logger.error("Rank %s send failed: %s", self.rank, e)

    def shutdown(self) -> None:
        """Drain pending sends, stop the send executor and close the client."""
        self.wait_pending_sends()
        self._send_executor.shutdown(wait=True)
        self.client.close()
186
+
187
+
188
class ExecRequest(BaseModel):
    """Request model for /exec endpoint."""

    # base64 serde payload of the Graph to execute
    graph: str
    # base64 serde payload of the input list (object-store URIs or None per slot)
    inputs: str
    # Optional job identifier; forwarded to the executor (currently unused there)
    job_id: str | None = None
194
+
195
+
196
class CommRequest(BaseModel):
    """Request model for /comm endpoint."""

    # Encoded payload: plain base64 bytes if is_raw_bytes, else a serde b64 string
    data: str
    # Rank of the sending worker
    from_rank: int
    is_raw_bytes: bool = False  # True for SPU channels: data is base64 of raw bytes, not serde
202
+
203
+
204
class FetchRequest(BaseModel):
    """Request model for /fetch endpoint."""

    # Object-store URI of the value to return
    uri: str
208
+
209
+
210
def create_worker_app(
    rank: int,
    world_size: int,
    endpoints: list[str],
    spu_endpoints: dict[int, str] | None = None,
) -> FastAPI:
    """Create a FastAPI app for the worker.

    The app uses async endpoints that push blocking work (graph execution,
    object fetch and serialization) onto thread pools via
    ``run_in_executor``, so that ``/exec`` and ``/comm`` requests can be
    handled concurrently. This is essential for cross-party communication,
    where one party sends while another receives.

    Persistent state lives under
    ``${MPLANG_DATA_ROOT:-.mpl}/${MPLANG_CLUSTER_ID:-__http_<world_size>}/node<rank>/``;
    traces are written to its ``trace`` subdirectory.

    Args:
        rank: The global rank of this worker.
        world_size: Total number of workers.
        endpoints: HTTP endpoints for all workers (for shuffle communication).
        spu_endpoints: Optional dict mapping global_rank -> BRPC endpoint for SPU.

    Returns:
        FastAPI application instance
    """
    import asyncio
    from collections.abc import Callable
    from typing import cast

    from mplang.v2.backends.simp_worker.ops import WORKER_HANDLERS

    app = FastAPI(title=f"SIMP Worker {rank}")

    # Persistence root: ${MPLANG_DATA_ROOT}/<cluster_id>/node<rank>/
    data_root = pathlib.Path(os.environ.get("MPLANG_DATA_ROOT", ".mpl"))
    cluster_id = os.environ.get("MPLANG_CLUSTER_ID", f"__http_{world_size}")
    root_dir = data_root / cluster_id / f"node{rank}"
    trace_dir = root_dir / "trace"

    # Tracing is unconditionally enabled for now.
    # TODO: make this configurable (e.g. via an environment variable).
    tracer = ExecutionTracer(enabled=True, trace_dir=trace_dir)
    tracer.start()

    comm = HttpCommunicator(rank, world_size, endpoints, tracer=tracer)
    store = ObjectStore(fs_root=str(root_dir))
    ctx = SimpWorker(rank, world_size, comm, store, spu_endpoints)

    # Only the simp worker handlers are passed explicitly. Other backends
    # (tensor_impl, spu_impl) are imported at module level for their
    # registration side effects.
    handlers: dict[str, Callable[..., Any]] = {**WORKER_HANDLERS}  # type: ignore[dict-item]

    worker = Interpreter(
        tracer=tracer, root_dir=root_dir, handlers=handlers, store=store
    )
    # Register SimpWorker context as 'simp' dialect state
    worker.set_dialect_state("simp", ctx)

    # Graph executions run here, off the event loop, so that /comm requests
    # can still be served while a graph is waiting on incoming data.
    exec_pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=2, thread_name_prefix=f"exec_{rank}"
    )

    def _do_execute(graph: Graph, inputs: list[Any], job_id: str | None = None) -> Any:
        """Execute ``graph`` on this worker (runs in ``exec_pool``).

        ``inputs`` are object-store URIs; ``None`` means this rank has no
        data for that slot. Returns ``None`` if the graph has no outputs.
        Otherwise returns a list of URIs, with ``None`` for ``None`` results.
        ``job_id`` is currently unused.
        """
        resolved_inputs = [
            store.get(inp) if inp is not None else None for inp in inputs
        ]

        result = worker.evaluate_graph(graph, resolved_inputs)
        # Wait for this graph's outgoing sends before replying (send failures
        # are only logged by the communicator).
        comm.wait_pending_sends()

        # Store results and return URIs (result is always a list)
        if not graph.outputs:
            return None
        return [store.put(res) if res is not None else None for res in result]

    def _do_fetch(uri: str) -> str:
        """Load ``uri`` from the store and serialize it (runs off the event loop)."""
        state = cast(SimpWorker, worker.get_dialect_state("simp"))
        return serde.dumps_b64(state.store.get(uri))

    @app.post("/exec")
    async def execute(req: ExecRequest) -> dict[str, str]:
        """Execute a graph on this worker."""
        logger.debug("Worker %s received exec request", rank)
        try:
            # Use secure JSON deserialization
            graph = serde.loads_b64(req.graph)
            inputs = serde.loads_b64(req.inputs)
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                exec_pool, _do_execute, graph, inputs, req.job_id
            )
            return {"result": serde.dumps_b64(result)}
        except Exception as e:
            logger.exception("Worker %s exec failed: %s", rank, e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.put("/comm/{key}")
    async def receive_comm(key: str, req: CommRequest) -> dict[str, str]:
        """Receive communication data from another worker."""
        logger.debug("Worker %s received comm key=%s from %s", rank, key, req.from_rank)
        try:
            # Handle raw bytes (SPU channels) vs serde data
            if req.is_raw_bytes:
                import base64

                data = base64.b64decode(req.data)
            else:
                # Use secure JSON deserialization
                data = serde.loads_b64(req.data)

            comm.on_receive(req.from_rank, key, data)
            return {"status": "ok"}
        except Exception as e:
            logger.exception("Worker %s comm failed: %s", rank, e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.post("/fetch")
    async def fetch(req: FetchRequest) -> dict[str, str]:
        """Fetch data by URI."""
        logger.debug("Worker %s received fetch request for %s", rank, req.uri)
        try:
            # Store reads and serialization can be heavy; keep them off the
            # event loop so concurrent /comm requests are not stalled.
            loop = asyncio.get_running_loop()
            payload = await loop.run_in_executor(None, _do_fetch, req.uri)
            return {"result": payload}
        except Exception as e:
            logger.exception("Worker %s fetch failed: %s", rank, e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/objects")
    async def list_objects() -> dict[str, list[str]]:
        """List all objects in the worker's store."""
        try:
            state = cast(SimpWorker, worker.get_dialect_state("simp"))
            return {"objects": state.store.list_objects()}
        except Exception as e:
            logger.exception("Worker %s list_objects failed: %s", rank, e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/health")
    async def health() -> dict[str, str]:
        """Health check endpoint."""
        return {"status": "ok", "rank": str(rank), "world_size": str(world_size)}

    @app.on_event("shutdown")
    def shutdown_event() -> None:
        """Cleanup on shutdown."""
        # Drain in-flight executions first: they may still call comm.send(),
        # which would fail once the communicator's executor and client are closed.
        exec_pool.shutdown(wait=True)
        comm.shutdown()

    return app