mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Simp Worker memory IPC runtime (LocalMesh, ThreadCommunicator)."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import concurrent.futures
|
|
20
|
+
import threading
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ThreadCommunicator:
    """In-memory, thread-based point-to-point communicator.

    Every party owns one instance; all instances live in the same process
    and deliver messages by writing directly into the receiver's mailbox.

    Args:
        rank: This communicator's rank.
        world_size: Total number of parties.
        use_serde: If True, serialize/deserialize data through serde on send.
    """

    def __init__(self, rank: int, world_size: int, *, use_serde: bool = False):
        self.rank = rank
        self.world_size = world_size
        self.use_serde = use_serde
        self.peers: list[ThreadCommunicator] = []
        # Mailbox keyed by (from_rank, tag): each key holds exactly one message.
        self._mailbox: dict[tuple[int, str], Any] = {}
        self._cond = threading.Condition()
        self._sent_events: dict[str, threading.Event] = {}
        self._shutdown = False

    def set_peers(self, peers: list[ThreadCommunicator]) -> None:
        """Register the full communicator group (including self)."""
        assert len(peers) == self.world_size
        self.peers = peers

    def shutdown(self) -> None:
        """Mark this communicator dead and wake up every blocked receiver."""
        with self._cond:
            self._shutdown = True
            self._cond.notify_all()

    def send(self, to: int, key: str, data: Any) -> None:
        """Deliver ``data`` tagged with ``key`` to party ``to`` (non-blocking)."""
        assert 0 <= to < self.world_size
        if self.use_serde:
            from mplang.v2.edsl import serde

            # Round-trip through serde to mimic cross-process transport.
            data = serde.loads(serde.dumps(data))
        self.peers[to]._on_receive(self.rank, key, data)

    def recv(self, frm: int, key: str) -> Any:
        """Block until a message tagged ``key`` arrives from party ``frm``.

        Raises:
            RuntimeError: If the communicator is shut down while waiting.
        """
        slot = (frm, key)
        with self._cond:
            self._cond.wait_for(lambda: slot in self._mailbox or self._shutdown)
            if self._shutdown:
                raise RuntimeError("Communicator shut down")
            return self._mailbox.pop(slot)

    def _on_receive(self, frm: int, key: str, data: Any) -> None:
        """Peer-side delivery hook: file the message and wake receivers."""
        slot = (frm, key)
        with self._cond:
            if slot in self._mailbox:
                raise RuntimeError(
                    f"Mailbox overflow: key {slot} already exists"
                )
            self._mailbox[slot] = data
            self._cond.notify_all()
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class LocalMesh:
    """Communication mesh for local SIMP simulation.

    Builds one ThreadCommunicator per party, wires them all to each other,
    and owns a thread pool used for worker-side execution.
    """

    def __init__(self, world_size: int, *, use_serde: bool = False):
        self.world_size = world_size
        self.use_serde = use_serde
        self.comms = []
        for party in range(world_size):
            self.comms.append(
                ThreadCommunicator(party, world_size, use_serde=use_serde)
            )
        # Give every communicator a handle to the whole group.
        for channel in self.comms:
            channel.set_peers(self.comms)
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=world_size)

    def shutdown(self, wait: bool = True) -> None:
        """Shut down every communicator first, then the executor."""
        for channel in self.comms:
            channel.shutdown()
        self.executor.shutdown(wait=wait)
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Simp Worker ops (WORKER_HANDLERS).
|
|
16
|
+
|
|
17
|
+
This module contains all simp operation implementations for the Worker Interpreter.
|
|
18
|
+
These implementations execute locally on a single party.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from typing import Any
|
|
24
|
+
|
|
25
|
+
from mplang.v2.dialects import simp
|
|
26
|
+
from mplang.v2.edsl.graph import Operation
|
|
27
|
+
from mplang.v2.runtime.interpreter import Interpreter
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _ensure_worker_context(interpreter: Any, op_name: str) -> Any:
|
|
31
|
+
"""Validate that interpreter has a Worker context."""
|
|
32
|
+
state = interpreter.get_dialect_state("simp")
|
|
33
|
+
if state is None or not hasattr(state, "communicator"):
|
|
34
|
+
raise RuntimeError(f"{op_name} requires simp Worker state (with communicator)")
|
|
35
|
+
return state
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _pcall_static_worker_impl(
    interpreter: Interpreter, op: Operation, *args: Any
) -> Any:
    """Worker implementation of pcall_static.

    Runs the op's region only when this rank belongs to the static party
    set; non-member ranks produce None placeholder outputs. While the
    region runs, ``current_parties`` on the worker state is temporarily
    set to this op's party set (and restored afterwards).
    """
    worker = _ensure_worker_context(interpreter, "pcall_static_impl")

    parties = op.attrs.get("parties")
    if parties is None:
        raise ValueError("pcall_static requires 'parties' attribute")

    if worker.rank not in parties:
        # This rank does not participate: emit None placeholders.
        return None if len(op.outputs) == 1 else [None] * len(op.outputs)

    saved_parties = worker.current_parties
    worker.current_parties = parties
    try:
        outs = interpreter.evaluate_graph(op.regions[0], list(args))
    finally:
        worker.current_parties = saved_parties
    # Interpreter expects a bare value (not a 1-list) for single-output ops.
    return outs[0] if len(op.outputs) == 1 else outs
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _pcall_dynamic_worker_impl(
|
|
65
|
+
interpreter: Interpreter, op: Operation, *args: Any
|
|
66
|
+
) -> Any:
|
|
67
|
+
"""Worker implementation of pcall_dynamic."""
|
|
68
|
+
fn_graph = op.regions[0]
|
|
69
|
+
result = interpreter.evaluate_graph(fn_graph, list(args))
|
|
70
|
+
return result[0] if len(op.outputs) == 1 else result
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _shuffle_static_worker_impl(
    interpreter: Interpreter, op: Operation, *args: Any
) -> Any:
    """Worker implementation of shuffle_static.

    Routes values between ranks according to the static ``routing`` attr,
    a mapping of target rank -> source rank. All outgoing sends are issued
    before the local receive so peers are never left waiting on this rank.

    Returns:
        The routed value for this rank, or None if this rank receives nothing.
    """
    worker = _ensure_worker_context(interpreter, "shuffle_static_impl")

    routing = op.attrs.get("routing")
    if routing is None:
        # No routing attr: identity shuffle, pass the input straight through.
        return args[0]

    comm = worker.communicator
    my_rank = worker.rank
    data = args[0]

    # Phase 1: push our value to every target whose source is this rank.
    # The tag embeds the op name and target rank so concurrent shuffles
    # and multiple targets never collide in the mailbox.
    for tgt, src in routing.items():
        if src == my_rank and tgt != my_rank:
            key = f"shuffle_{op.name}_{tgt}"
            comm.send(tgt, key, data)

    # Phase 2: obtain our own value (local copy or a blocking receive).
    if my_rank in routing:
        src = routing[my_rank]
        if src == my_rank:
            # Self-route: no communication needed.
            return data
        key = f"shuffle_{op.name}_{my_rank}"
        return comm.recv(src, key)
    else:
        # This rank is not a routing target: it receives no data.
        return None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _converge_worker_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
|
|
103
|
+
"""Worker implementation of simp.converge."""
|
|
104
|
+
for arg in args:
|
|
105
|
+
if arg is not None:
|
|
106
|
+
return arg
|
|
107
|
+
return None
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _uniform_cond_worker_impl(
    interpreter: Interpreter, op: Operation, pred: Any, *args: Any
) -> Any:
    """Worker implementation of simp.uniform_cond.

    Evaluates region 0 (then-branch) or region 1 (else-branch) based on the
    predicate, which is assumed to be uniform across parties.
    """
    from mplang.v2.backends.tensor_impl import TensorValue

    if op.attrs.get("verify_uniform", True):
        pass  # TODO: Implement AllReduce verification

    # Unwrap a tensor predicate into a plain Python bool.
    if isinstance(pred, TensorValue):
        pred = bool(pred.unwrap())

    branch = op.regions[0] if pred else op.regions[1]
    outs = interpreter.evaluate_graph(branch, list(args))
    # Single-output ops return the bare value.
    return outs[0] if len(op.outputs) == 1 else outs
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _while_loop_worker_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
    """Worker implementation of simp.while_loop.

    Args are split into loop-carried state (the first ``len(op.outputs)``
    values) followed by captured constants; the condition region is
    re-evaluated against the updated state until it yields a falsy value.
    """
    from mplang.v2.backends.tensor_impl import TensorValue

    cond_graph, body_graph = op.regions[0], op.regions[1]

    num_state = len(op.outputs)
    state = list(args[:num_state])
    captures = list(args[num_state:])

    while True:
        region_inputs = state + captures

        cond_out = interpreter.evaluate_graph(cond_graph, region_inputs)
        # evaluate_graph returns a list; an empty result means "stop".
        flag = cond_out[0] if cond_out else False
        if isinstance(flag, TensorValue):
            flag = bool(flag.unwrap())
        if not flag:
            break

        # The body returns the complete next state as a list.
        state = interpreter.evaluate_graph(body_graph, region_inputs)

    # Single-output loops return the bare value.
    return state[0] if len(state) == 1 else state
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
# Dispatch table mapping simp primitive name -> worker-side implementation.
# Consumed by the Worker Interpreter to execute simp ops on a single party.
WORKER_HANDLERS = {
    simp.pcall_static_p.name: _pcall_static_worker_impl,
    simp.pcall_dynamic_p.name: _pcall_dynamic_worker_impl,
    simp.shuffle_static_p.name: _shuffle_static_worker_impl,
    simp.converge_p.name: _converge_worker_impl,
    simp.uniform_cond_p.name: _uniform_cond_worker_impl,
    simp.while_loop_p.name: _while_loop_worker_impl,
}
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Simp Worker state (SimpWorker)."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import mplang.v2.backends.field_impl # noqa: F401
|
|
22
|
+
import mplang.v2.backends.tensor_impl # noqa: F401
|
|
23
|
+
from mplang.v2.runtime.dialect_state import DialectState
|
|
24
|
+
from mplang.v2.runtime.object_store import ObjectStore
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SimpWorker(DialectState):
    """Worker state for SIMP execution.

    This state provides capabilities (Store, Communicator) to the Interpreter.
    Attached to Worker Interpreters.

    Args:
        rank: Global rank of this party.
        world_size: Total number of parties.
        communicator: Transport used for point-to-point messaging.
        store: Object store backing value storage for this worker.
        spu_endpoints: Optional mapping of global rank -> SPU BRPC endpoint;
            None means channel-reuse mode (see spu exec backend).
    """

    dialect_name: str = "simp"

    def __init__(
        self,
        rank: int,
        world_size: int,
        communicator: Any,
        store: ObjectStore,
        spu_endpoints: dict[int, str] | None = None,
    ):
        self.rank = rank
        self.world_size = world_size
        self.communicator = communicator
        self.store = store
        self.spu_endpoints = spu_endpoints
        # Set temporarily by pcall_static while executing inside a party subset.
        self.current_parties: tuple[int, ...] | None = None
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""SPU Runtime Implementation.
|
|
16
|
+
|
|
17
|
+
Implements execution logic for SPU primitives using libspu.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import base64
|
|
23
|
+
from typing import Any, ClassVar
|
|
24
|
+
|
|
25
|
+
import numpy as np
|
|
26
|
+
import spu.api as spu_api
|
|
27
|
+
import spu.libspu as libspu
|
|
28
|
+
|
|
29
|
+
from mplang.v2.backends.spu_state import SPUState
|
|
30
|
+
from mplang.v2.backends.tensor_impl import TensorValue
|
|
31
|
+
from mplang.v2.dialects import spu
|
|
32
|
+
from mplang.v2.edsl import serde
|
|
33
|
+
from mplang.v2.edsl.graph import Operation
|
|
34
|
+
from mplang.v2.runtime.interpreter import Interpreter
|
|
35
|
+
from mplang.v2.runtime.value import WrapValue
|
|
36
|
+
|
|
37
|
+
# =============================================================================
|
|
38
|
+
# SPU Share Wrapper
|
|
39
|
+
# =============================================================================
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@serde.register_class
class SPUShareValue(WrapValue[libspu.Share]):
    """Wrapper for libspu.Share representing an SPU secret share.

    This wraps the external libspu library's Share type to provide
    proper serialization support via the Value base class.

    In-memory, we hold the libspu.Share directly to avoid copying.
    Serialization extracts meta/share_chunks when needed.
    """

    # Stable identifier used by serde to locate this class on deserialization.
    _serde_kind: ClassVar[str] = "spu_impl.SPUShareValue"

    def _convert(self, data: Any) -> libspu.Share:
        """Normalize constructor input: accept a wrapper or a raw share."""
        if isinstance(data, SPUShareValue):
            return data.unwrap()
        if isinstance(data, libspu.Share):
            return data
        raise TypeError(f"Expected libspu.Share, got {type(data)}")

    @property
    def libspu_share(self) -> libspu.Share:
        """Get the underlying libspu.Share object."""
        return self._data

    def to_json(self) -> dict[str, Any]:
        """Serialize to a JSON-safe dict (binary fields are base64-encoded)."""
        return {
            "meta": base64.b64encode(self._data.meta).decode("ascii"),
            "share_chunks": [
                base64.b64encode(chunk).decode("ascii")
                for chunk in self._data.share_chunks
            ],
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> SPUShareValue:
        """Rebuild a share from the dict produced by :meth:`to_json`."""
        share = libspu.Share()
        share.meta = base64.b64decode(data["meta"])
        share.share_chunks = [
            base64.b64decode(chunk_b64) for chunk_b64 in data["share_chunks"]
        ]
        return cls(share)

    @classmethod
    def from_libspu(cls, share: libspu.Share) -> SPUShareValue:
        """Create SPUShareValue from a libspu.Share (zero-copy)."""
        return cls(share)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# =============================================================================
|
|
92
|
+
# SPU Config Helpers
|
|
93
|
+
# =============================================================================
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def to_runtime_config(config: spu.SPUConfig) -> libspu.RuntimeConfig:
    """Convert SPUConfig to libspu.RuntimeConfig.

    This is a runtime-only function that maps the string-based configuration
    to libspu enums. Should only be called in the backend implementation.

    Raises:
        AttributeError: If ``config.protocol`` or ``config.field`` is not the
            name of a libspu enum member (getattr lookup fails).
    """
    runtime_config = libspu.RuntimeConfig()
    # ProtocolKind uses "SEMI2K" not "PROT_SEMI2K"
    runtime_config.protocol = getattr(libspu.ProtocolKind, config.protocol)
    runtime_config.field = getattr(libspu.FieldType, config.field)
    runtime_config.fxp_fraction_bits = config.fxp_fraction_bits
    return runtime_config
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@spu.makeshares_p.def_impl
def makeshares_impl(
    interpreter: Interpreter, op: Operation, data: TensorValue
) -> tuple[SPUShareValue, ...]:
    """Generate secret shares for data using spu.Io.

    Attrs read from ``op``:
        count: Number of shares to produce (one per SPU party).
        config: SPUConfig describing protocol/field/fxp settings.

    Returns:
        One SPUShareValue per party, in party order.
    """
    count = op.attrs["count"]
    config: spu.SPUConfig = op.attrs["config"]

    # We create a standalone Io for share generation (no link needed for make_shares)
    runtime_config = to_runtime_config(config)
    io = spu_api.Io(count, runtime_config)

    # Unwrap TensorValue
    arr = data.unwrap()

    # data is expected to be numpy array
    arr = np.asarray(arr)

    # Generate shares (VIS_SECRET)
    libspu_shares = io.make_shares(arr, libspu.Visibility.VIS_SECRET)

    # Wrap libspu.Share objects in SPUShareValue
    return tuple(SPUShareValue.from_libspu(share) for share in libspu_shares)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@spu.reconstruct_p.def_impl
def reconstruct_impl(
    interpreter: Interpreter, op: Operation, *shares: SPUShareValue
) -> TensorValue:
    """Reconstruct data from secret shares using spu.Io.

    The Io world size is taken from the number of shares supplied; the
    runtime config comes from the op's ``config`` attr.
    """
    count = len(shares)
    config: spu.SPUConfig = op.attrs["config"]

    runtime_config = to_runtime_config(config)
    io = spu_api.Io(count, runtime_config)

    # Unwrap SPUShareValue to libspu.Share
    libspu_shares = [share.libspu_share for share in shares]

    # Reconstruct
    result = io.reconstruct(libspu_shares)

    # Wrap result as TensorValue
    return TensorValue.wrap(result)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@spu.exec_p.def_impl
def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
    """Execute SPU kernel using spu.Runtime.

    The SPU config must contain parties info to correctly map global rank
    to local SPU rank and determine SPU world size.

    Attrs read from ``op``: ``config``, ``executable``, ``input_names``,
    ``output_names``. Must run inside a pcall_static block so that
    ``current_parties`` is populated on the SimpWorker state.
    """
    from mplang.v2.backends.simp_worker.state import SimpWorker

    # Get SPU config from attrs (passed through from run_jax)
    config: spu.SPUConfig = op.attrs["config"]

    # Get parties from SimpWorker state (injected by pcall_static_impl)
    context = interpreter.get_dialect_state("simp")
    if not isinstance(context, SimpWorker):
        raise RuntimeError(f"spu.exec requires SimpWorker, got {type(context)}")

    parties = context.current_parties
    if parties is None:
        raise RuntimeError(
            "spu.exec requires 'current_parties' in SimpWorker state. "
            "Ensure it is called within a pcall_static block."
        )

    global_rank = context.rank

    if global_rank not in parties:
        raise RuntimeError(
            f"Global rank {global_rank} is not in current parties {parties}"
        )

    # Convert global rank to local SPU rank
    local_rank = parties.index(global_rank)
    spu_world_size = len(parties)

    # Get SPU endpoints from interpreter (set by WorkerInterpreter for BRPC mode)
    # spu_endpoints is a dict mapping global_rank -> brpc_endpoint
    spu_endpoints_map: dict[int, str] | None = getattr(
        interpreter, "spu_endpoints", None
    )
    if spu_endpoints_map is None:
        # Try getting from SimpWorker context (context is already SimpWorker)
        spu_endpoints_map = getattr(context, "spu_endpoints", None)

    # Build ordered list of endpoints for SPU parties
    spu_endpoints: list[str] | None = None
    if spu_endpoints_map is not None:
        spu_endpoints = []
        for party_rank in parties:
            if party_rank not in spu_endpoints_map:
                raise RuntimeError(
                    f"SPU endpoint not found for party {party_rank}. "
                    f"Available: {list(spu_endpoints_map.keys())}"
                )
            spu_endpoints.append(spu_endpoints_map[party_rank])

    # Get communicator for Channels mode (reuse existing communication)
    # If no BRPC endpoints configured, use Channels mode
    communicator = None
    if spu_endpoints is None:
        # Use worker's communicator for channel reuse
        # (SimpWorker already imported at function start)
        communicator = context.communicator

    # Get or create SPUState for caching Runtime/Io
    spu_state = interpreter.get_dialect_state(SPUState.dialect_name)
    if not isinstance(spu_state, SPUState):
        spu_state = SPUState()
        interpreter.set_dialect_state(SPUState.dialect_name, spu_state)

    runtime, io = spu_state.get_or_create(
        local_rank,
        spu_world_size,
        config,
        spu_endpoints,
        communicator=communicator,
        parties=list(parties),
    )

    executable_code = op.attrs["executable"]
    input_names = op.attrs["input_names"]
    output_names = op.attrs["output_names"]

    # Create Executable
    executable = libspu.Executable(
        name="spu_kernel",
        input_names=input_names,
        output_names=output_names,
        code=executable_code,
    )

    # Set inputs
    for name, share in zip(input_names, args, strict=True):
        # Handle SPUShareValue wrapper - unwrap to libspu.Share
        if isinstance(share, SPUShareValue):
            libspu_share = share.libspu_share
        else:
            # Handle public input (numpy array)
            # Generate shares with VIS_PUBLIC
            # make_shares expects numpy array
            if not isinstance(share, (np.ndarray, np.generic, int, float)):
                share = np.array(share)

            shares = io.make_shares(share, libspu.Visibility.VIS_PUBLIC)
            libspu_share = shares[local_rank]

        runtime.set_var(name, libspu_share)

    # Run
    runtime.run(executable)

    # Get outputs and wrap in SPUShareValue
    results = []
    for name in output_names:
        libspu_share = runtime.get_var(name)
        results.append(SPUShareValue.from_libspu(libspu_share))

    if len(results) == 1:
        return results[0]
    return tuple(results)
|