mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,102 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simp Worker memory IPC runtime (LocalMesh, ThreadCommunicator)."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import concurrent.futures
20
+ import threading
21
+ from typing import Any
22
+
23
+
24
class ThreadCommunicator:
    """Thread-based communicator for in-memory communication.

    Each communicator owns a mailbox keyed by ``(from_rank, tag)``. ``send``
    deposits directly into the receiving peer's mailbox; ``recv`` blocks on a
    condition variable until the matching message arrives or the
    communicator is shut down.

    Args:
        rank: This communicator's rank.
        world_size: Total number of parties.
        use_serde: If True, round-trip data through ``serde`` on send, so the
            in-memory path exercises the same serialization as real transports.
    """

    def __init__(self, rank: int, world_size: int, *, use_serde: bool = False):
        self.rank = rank
        self.world_size = world_size
        self.use_serde = use_serde
        self.peers: list[ThreadCommunicator] = []
        # Mailbox keyed by (from_rank, tag): each key holds at most one message.
        self._mailbox: dict[tuple[int, str], Any] = {}
        self._cond = threading.Condition()
        # NOTE(review): not read anywhere in this module; kept for compatibility.
        self._sent_events: dict[str, threading.Event] = {}
        self._shutdown = False

    def set_peers(self, peers: list[ThreadCommunicator]) -> None:
        """Register the full list of communicators, indexed by rank.

        Raises:
            ValueError: If ``peers`` does not have exactly ``world_size`` entries.
        """
        if len(peers) != self.world_size:
            raise ValueError(f"Expected {self.world_size} peers, got {len(peers)}")
        self.peers = peers

    def shutdown(self) -> None:
        """Mark this communicator shut down and wake every thread in ``recv``."""
        with self._cond:
            self._shutdown = True
            self._cond.notify_all()

    def send(self, to: int, key: str, data: Any) -> None:
        """Deliver ``data`` to peer ``to`` under tag ``key``.

        Raises:
            ValueError: If ``to`` is not a rank in ``[0, world_size)``.
            RuntimeError: If the peer still holds an undelivered message with
                the same ``(rank, key)``.
        """
        # Explicit check rather than ``assert``: under ``python -O`` a negative
        # rank would otherwise index ``self.peers`` from the end and deliver
        # to the wrong party.
        if not 0 <= to < self.world_size:
            raise ValueError(
                f"Invalid destination rank {to} (world_size={self.world_size})"
            )
        if self.use_serde:
            from mplang.v2.edsl import serde

            data = serde.loads(serde.dumps(data))
        self.peers[to]._on_receive(self.rank, key, data)

    def recv(self, frm: int, key: str) -> Any:
        """Block until the message tagged ``key`` from rank ``frm`` is available.

        A message that was already delivered is returned even if the
        communicator has since been shut down, so delivered data is not lost.

        Raises:
            RuntimeError: If the communicator is shut down before the message
                arrives.
        """
        mailbox_key = (frm, key)
        with self._cond:
            while mailbox_key not in self._mailbox:
                if self._shutdown:
                    raise RuntimeError("Communicator shut down")
                self._cond.wait()
            return self._mailbox.pop(mailbox_key)

    def _on_receive(self, frm: int, key: str, data: Any) -> None:
        """Store an incoming message and wake waiting receivers.

        Raises:
            RuntimeError: If a message with the same ``(frm, key)`` is already
                waiting in the mailbox.
        """
        mailbox_key = (frm, key)
        with self._cond:
            if mailbox_key in self._mailbox:
                raise RuntimeError(
                    f"Mailbox overflow: key {mailbox_key} already exists"
                )
            self._mailbox[mailbox_key] = data
            self._cond.notify_all()
79
+
80
+
81
class LocalMesh:
    """In-process communication mesh used for local SIMP simulation.

    Builds one ``ThreadCommunicator`` per rank, wires every communicator to
    the full peer list, and owns a thread pool with one worker per rank for
    worker-side execution.
    """

    def __init__(self, world_size: int, *, use_serde: bool = False):
        self.world_size = world_size
        self.use_serde = use_serde
        comms = []
        for rank in range(world_size):
            comms.append(ThreadCommunicator(rank, world_size, use_serde=use_serde))
        self.comms = comms
        for member in comms:
            member.set_peers(comms)
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=world_size)

    def shutdown(self, wait: bool = True) -> None:
        """Shut down every communicator first, then the executor.

        Communicators are shut down before the executor so that threads
        blocked in ``recv`` are woken before the executor waits on them.
        """
        for member in self.comms:
            member.shutdown()
        self.executor.shutdown(wait=wait)
@@ -0,0 +1,167 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simp Worker ops (WORKER_HANDLERS).
16
+
17
+ This module contains all simp operation implementations for the Worker Interpreter.
18
+ These implementations execute locally on a single party.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from typing import Any
24
+
25
+ from mplang.v2.dialects import simp
26
+ from mplang.v2.edsl.graph import Operation
27
+ from mplang.v2.runtime.interpreter import Interpreter
28
+
29
+
30
def _ensure_worker_context(interpreter: Any, op_name: str) -> Any:
    """Fetch the ``simp`` dialect state and check that it is Worker-side.

    Worker state is recognised structurally: it must be present and expose a
    ``communicator`` attribute.

    Raises:
        RuntimeError: If no simp state is attached or it has no communicator.
    """
    state = interpreter.get_dialect_state("simp")
    is_worker = state is not None and hasattr(state, "communicator")
    if not is_worker:
        raise RuntimeError(f"{op_name} requires simp Worker state (with communicator)")
    return state
36
+
37
+
38
def _pcall_static_worker_impl(
    interpreter: Interpreter, op: Operation, *args: Any
) -> Any:
    """Run the pcall_static region on this worker if it participates.

    While the region executes, ``worker.current_parties`` is set to the op's
    ``parties`` attribute. It is restored afterwards, even on error, so
    nested ops such as ``spu.exec`` can see which parties are active.
    Non-participating ranks produce ``None`` for every output.

    Raises:
        ValueError: If the op has no ``parties`` attribute.
    """
    worker = _ensure_worker_context(interpreter, "pcall_static_impl")

    parties = op.attrs.get("parties")
    if parties is None:
        raise ValueError("pcall_static requires 'parties' attribute")

    if worker.rank not in parties:
        # This rank holds no data for the call.
        n_outputs = len(op.outputs)
        return None if n_outputs == 1 else [None] * n_outputs

    region = op.regions[0]
    saved_parties = worker.current_parties
    worker.current_parties = parties
    try:
        outputs = interpreter.evaluate_graph(region, list(args))
        # The interpreter expects a bare value for single-output ops.
        if len(op.outputs) == 1:
            return outputs[0]
        return outputs
    finally:
        worker.current_parties = saved_parties
62
+
63
+
64
def _pcall_dynamic_worker_impl(
    interpreter: Interpreter, op: Operation, *args: Any
) -> Any:
    """Run the pcall_dynamic region unconditionally on this worker.

    Returns the first element of the interpreter's result for single-output
    ops, otherwise the result unchanged.
    """
    outputs = interpreter.evaluate_graph(op.regions[0], list(args))
    if len(op.outputs) == 1:
        return outputs[0]
    return outputs
71
+
72
+
73
def _shuffle_static_worker_impl(
    interpreter: Interpreter, op: Operation, *args: Any
) -> Any:
    """Move data between ranks according to the op's ``routing`` attribute.

    ``routing`` maps target rank -> source rank. This worker first sends its
    input to every other rank that sources from it. It then returns one of:
    its own input (when it is its own source), the value received from its
    source, or ``None`` when it is not a routing target. Without a
    ``routing`` attribute the input is returned unchanged.
    """
    worker = _ensure_worker_context(interpreter, "shuffle_static_impl")

    routing = op.attrs.get("routing")
    if routing is None:
        return args[0]

    comm = worker.communicator
    me = worker.rank
    payload = args[0]

    # Push our input to every other rank that sources from us.
    outgoing = [dst for dst, src in routing.items() if src == me and dst != me]
    for dst in outgoing:
        comm.send(dst, f"shuffle_{op.name}_{dst}", payload)

    if me not in routing:
        return None
    src = routing[me]
    if src == me:
        return payload
    return comm.recv(src, f"shuffle_{op.name}_{me}")
100
+
101
+
102
def _converge_worker_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
    """Return the first argument that is not ``None`` (or ``None`` if none is)."""
    return next((value for value in args if value is not None), None)
108
+
109
+
110
def _uniform_cond_worker_impl(
    interpreter: Interpreter, op: Operation, pred: Any, *args: Any
) -> Any:
    """Evaluate one of two branch regions depending on ``pred``.

    ``op.regions[0]`` runs when ``pred`` is truthy and ``op.regions[1]``
    otherwise; the chosen region receives ``args``. A ``TensorValue``
    predicate is unwrapped and coerced with ``bool`` first.
    """
    from mplang.v2.backends.tensor_impl import TensorValue

    verify_uniform = op.attrs.get("verify_uniform", True)
    if verify_uniform:
        # TODO: Implement AllReduce verification (presumably checking that
        # pred agrees across parties).
        pass

    flag = bool(pred.unwrap()) if isinstance(pred, TensorValue) else pred
    branch = op.regions[0] if flag else op.regions[1]
    outputs = interpreter.evaluate_graph(branch, list(args))
    return outputs[0] if len(op.outputs) == 1 else outputs
127
+
128
+
129
def _while_loop_worker_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
    """Worker implementation of simp.while_loop.

    ``args`` holds the initial loop state (one entry per op output) followed
    by captured values. Each iteration evaluates the condition region
    (``op.regions[0]``). While that is truthy, it evaluates the body region
    (``op.regions[1]``). Both regions receive ``state + captures``. The body
    must return exactly one value per loop-state slot.

    Returns:
        The final state: a bare value for single-output ops, else a list.

    Raises:
        RuntimeError: If the body returns a different number of values than
            the loop carries as state.
    """
    from mplang.v2.backends.tensor_impl import TensorValue

    cond_graph = op.regions[0]
    body_graph = op.regions[1]

    num_state = len(op.outputs)
    current_state = list(args[:num_state])
    captures = list(args[num_state:])

    while True:
        region_inputs = current_state + captures

        cond_res = interpreter.evaluate_graph(cond_graph, region_inputs)
        # The condition region yields a single boolean; an empty result ends
        # the loop.
        cond_val = cond_res[0] if cond_res else False
        if isinstance(cond_val, TensorValue):
            cond_val = bool(cond_val.unwrap())
        if not cond_val:
            break

        # Copy into a list so the ``current_state + captures`` concatenation
        # above keeps working even if the interpreter returns a tuple.
        body_res = list(interpreter.evaluate_graph(body_graph, region_inputs))
        if len(body_res) != num_state:
            # A mismatch would silently shift captures into state slots on
            # the next iteration.
            raise RuntimeError(
                f"while_loop body returned {len(body_res)} values, "
                f"expected {num_state}"
            )
        current_state = body_res

    # Bare value for single-output ops, matching the other worker handlers.
    return current_state[0] if num_state == 1 else current_state
158
+
159
+
160
# (primitive, worker handler) pairs; WORKER_HANDLERS is keyed by primitive name.
_WORKER_IMPLS = (
    (simp.pcall_static_p, _pcall_static_worker_impl),
    (simp.pcall_dynamic_p, _pcall_dynamic_worker_impl),
    (simp.shuffle_static_p, _shuffle_static_worker_impl),
    (simp.converge_p, _converge_worker_impl),
    (simp.uniform_cond_p, _uniform_cond_worker_impl),
    (simp.while_loop_p, _while_loop_worker_impl),
)

WORKER_HANDLERS = {prim.name: impl for prim, impl in _WORKER_IMPLS}
@@ -0,0 +1,49 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simp Worker state (SimpWorker)."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Any
20
+
21
+ import mplang.v2.backends.field_impl # noqa: F401
22
+ import mplang.v2.backends.tensor_impl # noqa: F401
23
+ from mplang.v2.runtime.dialect_state import DialectState
24
+ from mplang.v2.runtime.object_store import ObjectStore
25
+
26
+
27
class SimpWorker(DialectState):
    """Per-party dialect state for SIMP execution on a Worker Interpreter.

    Carries the party's identity (``rank`` / ``world_size``) and the
    capabilities handlers rely on: an object store, a communicator, and
    optional SPU endpoints. It also tracks the parties active in the
    enclosing ``pcall_static`` via ``current_parties``, which is ``None`` by
    default.
    """

    dialect_name: str = "simp"

    def __init__(
        self,
        rank: int,
        world_size: int,
        communicator: Any,
        store: ObjectStore,
        spu_endpoints: dict[int, str] | None = None,
    ):
        # Party identity.
        self.rank = rank
        self.world_size = world_size
        # Capabilities exposed to op handlers.
        self.communicator = communicator
        self.store = store
        # Optional mapping of global rank -> SPU link endpoint.
        self.spu_endpoints = spu_endpoints
        # Active party set of the enclosing pcall_static; None outside one.
        self.current_parties: tuple[int, ...] | None = None
@@ -0,0 +1,275 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """SPU Runtime Implementation.
16
+
17
+ Implements execution logic for SPU primitives using libspu.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import base64
23
+ from typing import Any, ClassVar
24
+
25
+ import numpy as np
26
+ import spu.api as spu_api
27
+ import spu.libspu as libspu
28
+
29
+ from mplang.v2.backends.spu_state import SPUState
30
+ from mplang.v2.backends.tensor_impl import TensorValue
31
+ from mplang.v2.dialects import spu
32
+ from mplang.v2.edsl import serde
33
+ from mplang.v2.edsl.graph import Operation
34
+ from mplang.v2.runtime.interpreter import Interpreter
35
+ from mplang.v2.runtime.value import WrapValue
36
+
37
+ # =============================================================================
38
+ # SPU Share Wrapper
39
+ # =============================================================================
40
+
41
+
42
@serde.register_class
class SPUShareValue(WrapValue[libspu.Share]):
    """Value wrapper holding one SPU secret share (a ``libspu.Share``).

    The share object is kept by reference in memory, with no copy. Only on
    serialization are its ``meta`` bytes and each ``share_chunks`` entry
    base64-encoded into a JSON-friendly dict.
    """

    _serde_kind: ClassVar[str] = "spu_impl.SPUShareValue"

    def _convert(self, data: Any) -> libspu.Share:
        """Normalize ``data`` to a raw ``libspu.Share``.

        Raises:
            TypeError: If ``data`` is neither a wrapper nor a ``libspu.Share``.
        """
        if isinstance(data, SPUShareValue):
            return data.unwrap()
        if not isinstance(data, libspu.Share):
            raise TypeError(f"Expected libspu.Share, got {type(data)}")
        return data

    @property
    def libspu_share(self) -> libspu.Share:
        """The wrapped ``libspu.Share``, returned without copying."""
        return self._data

    def to_json(self) -> dict[str, Any]:
        """Serialize as base64 ASCII strings under ``meta`` / ``share_chunks``."""
        share = self._data

        def _b64(raw: bytes) -> str:
            return base64.b64encode(raw).decode("ascii")

        encoded_meta = _b64(share.meta)
        encoded_chunks = [_b64(chunk) for chunk in share.share_chunks]
        return {"meta": encoded_meta, "share_chunks": encoded_chunks}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> SPUShareValue:
        """Inverse of ``to_json``: rebuild the ``libspu.Share`` and wrap it."""
        restored = libspu.Share()
        restored.meta = base64.b64decode(data["meta"])
        restored.share_chunks = list(map(base64.b64decode, data["share_chunks"]))
        return cls(restored)

    @classmethod
    def from_libspu(cls, share: libspu.Share) -> SPUShareValue:
        """Wrap an existing ``libspu.Share`` without copying it."""
        return cls(share)
89
+
90
+
91
+ # =============================================================================
92
+ # SPU Config Helpers
93
+ # =============================================================================
94
+
95
+
96
def to_runtime_config(config: spu.SPUConfig) -> libspu.RuntimeConfig:
    """Build a ``libspu.RuntimeConfig`` from a string-based ``SPUConfig``.

    ``config.protocol`` and ``config.field`` are resolved by attribute name
    on ``libspu.ProtocolKind`` / ``libspu.FieldType``. For example, the
    protocol name is ``"SEMI2K"``, not ``"PROT_SEMI2K"``. This is
    runtime-only and should only be called from backend implementations.
    """
    rt_config = libspu.RuntimeConfig()
    rt_config.protocol = getattr(libspu.ProtocolKind, config.protocol)
    rt_config.field = getattr(libspu.FieldType, config.field)
    rt_config.fxp_fraction_bits = config.fxp_fraction_bits
    return rt_config
108
+
109
+
110
@spu.makeshares_p.def_impl
def makeshares_impl(
    interpreter: Interpreter, op: Operation, data: TensorValue
) -> tuple[SPUShareValue, ...]:
    """Split a plaintext tensor into ``op.attrs["count"]`` SPU secret shares.

    A standalone ``spu.Io`` is used, since producing shares needs no link.
    Shares are generated with ``VIS_SECRET`` visibility.
    """
    num_parties = op.attrs["count"]
    cfg: spu.SPUConfig = op.attrs["config"]
    io = spu_api.Io(num_parties, to_runtime_config(cfg))

    plain = np.asarray(data.unwrap())
    raw_shares = io.make_shares(plain, libspu.Visibility.VIS_SECRET)
    return tuple(map(SPUShareValue.from_libspu, raw_shares))
133
+
134
+
135
@spu.reconstruct_p.def_impl
def reconstruct_impl(
    interpreter: Interpreter, op: Operation, *shares: SPUShareValue
) -> TensorValue:
    """Recombine SPU secret shares into a plaintext ``TensorValue``.

    The ``spu.Io`` is sized by the number of shares supplied.
    """
    num_parties = len(shares)
    cfg: spu.SPUConfig = op.attrs["config"]
    io = spu_api.Io(num_parties, to_runtime_config(cfg))

    plain = io.reconstruct([s.libspu_share for s in shares])
    return TensorValue.wrap(plain)
154
+
155
+
156
def _resolve_spu_endpoints(
    interpreter: Interpreter, context: Any, parties: tuple[int, ...]
) -> list[str] | None:
    """Return SPU endpoints for ``parties`` in party order, or None if unset.

    The endpoint map (global rank -> endpoint) is looked up on the
    interpreter first, then on the SimpWorker state.

    Raises:
        RuntimeError: If a map exists but lacks an entry for some party.
    """
    endpoints_map: dict[int, str] | None = getattr(interpreter, "spu_endpoints", None)
    if endpoints_map is None:
        endpoints_map = getattr(context, "spu_endpoints", None)
    if endpoints_map is None:
        return None

    endpoints: list[str] = []
    for party_rank in parties:
        if party_rank not in endpoints_map:
            raise RuntimeError(
                f"SPU endpoint not found for party {party_rank}. "
                f"Available: {list(endpoints_map.keys())}"
            )
        endpoints.append(endpoints_map[party_rank])
    return endpoints


def _to_libspu_share(value: Any, io: Any, local_rank: int) -> libspu.Share:
    """Turn an ``spu.exec`` input into this party's ``libspu.Share``.

    Secret inputs arrive as ``SPUShareValue`` and are unwrapped directly.
    Anything else is treated as a public plaintext:
    - ``TensorValue`` is unwrapped, as ``makeshares_impl`` does.
    - The value is converted to a numpy array and shared with ``VIS_PUBLIC``.
    - This party keeps the share at ``local_rank``.
    """
    if isinstance(value, SPUShareValue):
        return value.libspu_share
    if isinstance(value, TensorValue):
        value = value.unwrap()
    # make_shares expects a numpy array; np.asarray is a no-op for ndarrays
    # and also converts Python/numpy scalars.
    shares = io.make_shares(np.asarray(value), libspu.Visibility.VIS_PUBLIC)
    return shares[local_rank]


@spu.exec_p.def_impl
def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
    """Execute an SPU kernel on this party using ``libspu`` Runtime.

    The SPU parties come from ``SimpWorker.current_parties``, which the
    enclosing ``pcall_static`` sets. This party's local SPU rank is its index
    in that sequence, and the SPU world size is the sequence's length.

    If SPU endpoints are configured (on the interpreter or the SimpWorker),
    they are used. Otherwise the worker's communicator is reused. The
    Runtime and Io are cached in an ``SPUState`` dialect state.

    Returns:
        A single ``SPUShareValue`` for one output, else a tuple of them.

    Raises:
        RuntimeError: In any of these cases:
            - the simp state is not a SimpWorker;
            - the call is made outside a ``pcall_static``;
            - this rank is not a participant;
            - a configured endpoint map is missing a party.
    """
    from mplang.v2.backends.simp_worker.state import SimpWorker

    # SPU config is passed through from run_jax.
    config: spu.SPUConfig = op.attrs["config"]

    context = interpreter.get_dialect_state("simp")
    if not isinstance(context, SimpWorker):
        raise RuntimeError(f"spu.exec requires SimpWorker, got {type(context)}")

    parties = context.current_parties
    if parties is None:
        raise RuntimeError(
            "spu.exec requires 'current_parties' in SimpWorker state. "
            "Ensure it is called within a pcall_static block."
        )

    global_rank = context.rank
    if global_rank not in parties:
        raise RuntimeError(
            f"Global rank {global_rank} is not in current parties {parties}"
        )

    # Map the global rank to this party's SPU-local rank.
    local_rank = parties.index(global_rank)
    spu_world_size = len(parties)

    spu_endpoints = _resolve_spu_endpoints(interpreter, context, parties)
    # Without configured endpoints, reuse the worker's communicator
    # ("Channels" mode).
    communicator = context.communicator if spu_endpoints is None else None

    spu_state = interpreter.get_dialect_state(SPUState.dialect_name)
    if not isinstance(spu_state, SPUState):
        spu_state = SPUState()
        interpreter.set_dialect_state(SPUState.dialect_name, spu_state)

    runtime, io = spu_state.get_or_create(
        local_rank,
        spu_world_size,
        config,
        spu_endpoints,
        communicator=communicator,
        parties=list(parties),
    )

    input_names = op.attrs["input_names"]
    output_names = op.attrs["output_names"]
    executable = libspu.Executable(
        name="spu_kernel",
        input_names=input_names,
        output_names=output_names,
        code=op.attrs["executable"],
    )

    for name, value in zip(input_names, args, strict=True):
        runtime.set_var(name, _to_libspu_share(value, io, local_rank))

    runtime.run(executable)

    results = [
        SPUShareValue.from_libspu(runtime.get_var(name)) for name in output_names
    ]
    return results[0] if len(results) == 1 else tuple(results)