mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/runtime/session.py
DELETED
|
@@ -1,270 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""Core Session model (pure, no global registries).
|
|
16
|
-
|
|
17
|
-
Contents:
|
|
18
|
-
* SessionState dataclass
|
|
19
|
-
* Symbol / Computation records (named values and named expressions)
|
|
20
|
-
* Session (topology derivation, runtime init, SPU env seeding, local symbol/computation storage)
|
|
21
|
-
|
|
22
|
-
Process-wide registries (sessions, global symbols) live in the server layer
|
|
23
|
-
(`server.py`) so this module remains portable and easy to unit test.
|
|
24
|
-
"""
|
|
25
|
-
|
|
26
|
-
from __future__ import annotations
|
|
27
|
-
|
|
28
|
-
import time
|
|
29
|
-
from dataclasses import dataclass, field
|
|
30
|
-
from functools import cached_property
|
|
31
|
-
from typing import TYPE_CHECKING, Any, cast
|
|
32
|
-
|
|
33
|
-
import spu.libspu as libspu
|
|
34
|
-
|
|
35
|
-
from mplang.v1.core.cluster import ClusterSpec
|
|
36
|
-
from mplang.v1.core.comm import ICommunicator
|
|
37
|
-
from mplang.v1.core.expr.ast import Expr
|
|
38
|
-
from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
|
|
39
|
-
from mplang.v1.core.mask import Mask
|
|
40
|
-
from mplang.v1.kernels.context import RuntimeContext
|
|
41
|
-
from mplang.v1.kernels.spu import PFunction # type: ignore
|
|
42
|
-
from mplang.v1.kernels.value import Value
|
|
43
|
-
from mplang.v1.runtime.communicator import HttpCommunicator
|
|
44
|
-
from mplang.v1.runtime.exceptions import ResourceNotFound
|
|
45
|
-
from mplang.v1.runtime.link_comm import LinkCommunicator
|
|
46
|
-
from mplang.v1.utils.spu_utils import parse_field, parse_protocol
|
|
47
|
-
|
|
48
|
-
if TYPE_CHECKING: # pragma: no cover - import only for type checking
|
|
49
|
-
from mplang.v1.core.cluster import ClusterSpec, Node, RuntimeInfo
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@dataclass
class Symbol:
    """A named value held in a session's symbol table.

    Fields are positional in declaration order: ``name``, ``mptype``, ``data``.
    """

    # Key under which the symbol is stored in the session's symbol table.
    name: str
    # Type descriptor of the value; Session.execute currently stores ``{}`` here.
    mptype: Any
    # The payload itself (kernel ``Value`` or None when produced by Session.execute).
    data: Any
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
@dataclass
class Computation:
    """A named expression graph registered with a session for later execution."""

    # Key under which the computation is stored in the session.
    name: str
    # Root expression that Session.execute hands to the evaluator.
    expr: Expr
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
@dataclass
class SessionState:
    """Mutable per-session bookkeeping owned by a ``Session``."""

    # Lazily created kernel runtime (see Session.ensure_runtime); None until first use.
    runtime: RuntimeContext | None = None
    # Registered computations keyed by name.
    computations: dict[str, Computation] = field(default_factory=dict)
    # Stored symbols keyed by name.
    symbols: dict[str, Symbol] = field(default_factory=dict)
    # Set once the SPU kernel environment has been seeded on ``runtime``.
    spu_seeded: bool = False
    # Wall-clock creation time (seconds since the epoch, via time.time).
    created_ts: float = field(default_factory=time.time)
    # Initialized to creation time; not updated anywhere in this module.
    last_access_ts: float = field(default_factory=time.time)
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
class Session:
    """Per-rank execution context for one named session.

    Immutable config: ``name``, ``rank``, ``cluster_spec``, ``communicator``.
    Derived from ``cluster_spec``: ``node``, ``runtime_info``, ``endpoints``,
    ``spu_device``, ``spu_mask``, ``spu_protocol``/``spu_field``,
    ``is_spu_party``.
    Mutable: ``state`` (runtime object, symbols, computations, seeded flag).

    Note: the communicator is assumed to be initialized with cluster spec info
    (e.g. endpoints).
    """

    def __init__(
        self,
        name: str,
        rank: int,
        cluster_spec: ClusterSpec,
        communicator: ICommunicator,
    ):
        self.name = name
        self.rank = rank
        self.cluster_spec = cluster_spec
        self.state = SessionState()
        self.communicator = communicator

    # --- Derived topology ---
    @cached_property
    def node(self) -> Node:
        """The cluster node for this session's rank (computed once, then cached)."""
        return self.cluster_spec.get_node_by_rank(self.rank)

    @property
    def runtime_info(self) -> RuntimeInfo:
        """Runtime info attached to this rank's node."""
        return self.node.runtime_info

    @property
    def endpoints(self) -> list[str]:
        """Endpoints reported by the cluster spec."""
        return self.cluster_spec.endpoints

    @cached_property
    def spu_device(self):  # type: ignore
        """The single SPU device declared in the cluster spec (cached).

        Raises:
            RuntimeError: if the spec declares zero or more than one SPU device.
        """
        devs = self.cluster_spec.get_devices_by_kind("SPU")
        if len(devs) != 1:
            raise RuntimeError(
                f"Expected exactly one SPU device, got {len(devs)} (session={self.name})"
            )
        return devs[0]

    @cached_property
    def spu_mask(self) -> Mask:
        """Mask of the ranks that are members of the SPU device (cached)."""
        return Mask.from_ranks([m.rank for m in self.spu_device.members])

    @property
    def spu_protocol(self) -> str:
        """SPU protocol name from the device config, defaulting to ``"SEMI2K"``."""
        return cast(str, self.spu_device.config.get("protocol", "SEMI2K"))

    @property
    def spu_field(self) -> str:
        """SPU field name from the device config, defaulting to ``"FM64"``."""
        return cast(str, self.spu_device.config.get("field", "FM64"))

    @property
    def is_spu_party(self) -> bool:
        """Whether this rank is a member of the SPU device."""
        return self.rank in self.spu_mask

    # --- Runtime helpers ---
    def ensure_runtime(self) -> RuntimeContext:
        """Return this session's RuntimeContext, creating it on first call.

        The runtime is sized to the number of cluster nodes and starts from the
        node's ``op_bindings`` (or no bindings if ``runtime_info`` is falsy).
        """
        if self.state.runtime is None:
            self.state.runtime = RuntimeContext(
                rank=self.rank,
                world_size=len(self.cluster_spec.nodes),  # type: ignore[attr-defined]
                initial_bindings=(
                    self.runtime_info.op_bindings if self.runtime_info else {}
                ),
            )
        return self.state.runtime

    def ensure_spu_env(self) -> None:
        """Ensure the SPU kernel env (config/world[/link]) is registered on this runtime.

        Config and world size are registered on ALL parties (once per session),
        so non-SPU ranks do not fail early with "SPU kernel state not
        initialized" when the global program contains SPU ops. A link context is
        attached only for participating SPU ranks; non-parties still error if
        they actually execute a link-dependent SPU kernel (such ops are expected
        to be guarded by masks in the IR).
        """
        if self.state.spu_seeded:
            return

        link_ctx = None

        if self.is_spu_party:
            # Channels mode reuses the existing communicator, so no separate
            # BRPC ports (SPU_PORT_OFFSET) are needed.
            from mplang.v1.core.comm import CommunicatorBase

            # The LinkCommunicator needs the concrete CommunicatorBase API.
            comm = cast(CommunicatorBase, self.communicator)
            link_ctx = LinkCommunicator(
                rank=self.rank,
                comm=comm,
                spu_mask=self.spu_mask,
            )

        spu_config = libspu.RuntimeConfig(
            protocol=parse_protocol(self.spu_protocol),
            field=parse_field(self.spu_field),
            fxp_fraction_bits=18,
        )
        seed_pfunc = PFunction(
            fn_type="spu.seed_env",
            ins_info=(),
            outs_info=(),
            config=spu_config,
            world=self.spu_mask.num_parties(),
            link=link_ctx,
        )
        self.ensure_runtime().run_kernel(seed_pfunc, [])
        # Only mark as seeded after the kernel ran successfully.
        self.state.spu_seeded = True

    # --- Computations & Symbols (instance-local) ---
    def add_computation(self, computation: Computation) -> None:
        """Register (or overwrite) a computation under its name."""
        self.state.computations[computation.name] = computation

    def get_computation(self, name: str) -> Computation | None:
        """Return the named computation, or None if absent."""
        return self.state.computations.get(name)

    def add_symbol(self, symbol: Symbol) -> None:
        """Store (or overwrite) a symbol under its name."""
        self.state.symbols[symbol.name] = symbol

    def get_symbol(self, name: str) -> Symbol | None:
        """Return the named symbol, or None if absent."""
        return self.state.symbols.get(name)

    def list_symbols(self) -> list[str]:  # pragma: no cover - trivial
        """Names of all stored symbols."""
        return list(self.state.symbols.keys())

    def delete_symbol(self, name: str) -> bool:
        """Delete a symbol; return True if it existed."""
        if name in self.state.symbols:
            del self.state.symbols[name]
            return True
        return False

    def list_computations(self) -> list[str]:  # pragma: no cover - trivial
        """Names of all registered computations."""
        return list(self.state.computations.keys())

    def delete_computation(self, name: str) -> bool:
        """Delete a computation; return True if it existed."""
        if name in self.state.computations:
            del self.state.computations[name]
            return True
        return False

    # --- Execution ---
    def execute(
        self, computation: Computation, input_names: list[str], output_names: list[str]
    ) -> None:
        """Evaluate ``computation`` on this rank and store its outputs as symbols.

        Args:
            computation: Computation whose ``expr`` is evaluated.
            input_names: Symbols bound into the evaluation environment by name.
            output_names: Names for the results, in result order.

        Raises:
            ResourceNotFound: if an input symbol is missing from this session.
            RuntimeError: if the number of results differs from ``output_names``.
            TypeError: if a non-None result is not a kernel ``Value``. No output
                symbol is stored in that case.
        """
        env: dict[str, Any] = {}
        for in_name in input_names:
            sym = self.get_symbol(in_name)
            if sym is None:
                raise ResourceNotFound(
                    f"Input symbol '{in_name}' not found in session '{self.name}'"
                )
            env[in_name] = sym.data
        rt = self.ensure_runtime()
        self.ensure_spu_env()
        evaluator: IEvaluator = create_evaluator(
            rank=self.rank, env=env, comm=self.communicator, runtime=rt
        )
        raw_results = evaluator.evaluate(computation.expr)
        # Normalize so an empty/None result is still checked against
        # output_names here, instead of failing opaquely inside zip(strict=True).
        results = list(raw_results) if raw_results is not None else []
        if len(results) != len(output_names):
            raise RuntimeError(
                f"Expected {len(output_names)} results, got {len(results)}"
            )
        # Validate every output before storing any, so a bad result does not
        # leave the symbol table partially updated.
        for name, val in zip(output_names, results, strict=True):
            # In pure SIMP model, all nodes should have the same symbol table.
            # Non-participating nodes get None values.
            if val is not None and not isinstance(val, Value):
                raise TypeError(
                    "Session executions must produce kernel Value outputs; "
                    f"got {type(val).__name__} for symbol '{name}'"
                )
        for name, val in zip(output_names, results, strict=True):
            self.add_symbol(Symbol(name=name, mptype={}, data=val))
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
# --- Convenience constructor using HttpCommunicator ---
|
|
259
|
-
def create_session_from_spec(name: str, rank: int, spec: ClusterSpec) -> Session:
    """Create a Session that talks to its peers through an HttpCommunicator.

    Args:
        name: Session name; also passed to the communicator as ``session_name``.
        rank: This party's rank within ``spec``.
        spec: Cluster specification; must declare at least one SPU device.

    Returns:
        A new Session whose communicator is built from ``spec.endpoints``.

    Raises:
        RuntimeError: if ``spec`` declares no SPU device.
    """
    spu_devices = spec.get_devices_by_kind("SPU")
    if not spu_devices:
        raise RuntimeError("No SPU device found in cluster_spec")

    http_comm = HttpCommunicator(
        session_name=name,
        rank=rank,
        endpoints=spec.endpoints,
    )
    return Session(
        name=name,
        rank=rank,
        cluster_spec=spec,
        communicator=http_comm,
    )
|
mplang/v1/runtime/simulation.py
DELETED
|
@@ -1,324 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import concurrent.futures
|
|
18
|
-
import faulthandler
|
|
19
|
-
import logging
|
|
20
|
-
import sys
|
|
21
|
-
import threading
|
|
22
|
-
import traceback
|
|
23
|
-
from collections.abc import Sequence
|
|
24
|
-
from typing import Any, cast
|
|
25
|
-
|
|
26
|
-
import spu.libspu as libspu
|
|
27
|
-
|
|
28
|
-
from mplang.v1.core import (
|
|
29
|
-
ClusterSpec,
|
|
30
|
-
CollectiveMixin,
|
|
31
|
-
CommunicatorBase,
|
|
32
|
-
InterpContext,
|
|
33
|
-
InterpVar,
|
|
34
|
-
IrReader,
|
|
35
|
-
IrWriter,
|
|
36
|
-
Mask,
|
|
37
|
-
MPObject,
|
|
38
|
-
MPType,
|
|
39
|
-
PFunction, # for spu.seed_env kernel seeding
|
|
40
|
-
TensorLike,
|
|
41
|
-
)
|
|
42
|
-
from mplang.v1.core.expr.ast import Expr
|
|
43
|
-
from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
|
|
44
|
-
from mplang.v1.kernels.context import RuntimeContext
|
|
45
|
-
from mplang.v1.runtime.link_comm import LinkCommunicator
|
|
46
|
-
from mplang.v1.utils.spu_utils import parse_field, parse_protocol
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
class ThreadCommunicator(CommunicatorBase, CollectiveMixin):
    """Thread-based communicator for in-memory communication between threads.

    Each simulated rank owns one instance. ``send`` calls the destination
    peer's ``onSent`` directly, in-process.
    """

    def __init__(self, rank: int, world_size: int):
        super().__init__(rank, world_size)
        # Communicators of all ranks, indexed by rank; filled by set_peers().
        self.peers: list[ThreadCommunicator] = []
        logging.debug(
            "ThreadCommunicator initialized with rank=%s, world_size=%s",
            self.rank,
            self.world_size,
        )

    def set_peers(self, peers: list[ThreadCommunicator]) -> None:
        """Register the communicators of all ranks (index == rank).

        Raises:
            ValueError: if ``len(peers)`` differs from ``world_size``.
        """
        if len(peers) != self.world_size:
            raise ValueError(
                f"Expected {self.world_size} peers, got {len(peers)}"
            )
        self.peers = peers

    def send(self, to: int, key: str, data: Any) -> None:
        """Deliver ``data`` under ``key`` to the peer with rank ``to``.

        Raises:
            ValueError: if ``to`` is outside ``[0, world_size)``. This is an
                explicit check rather than an ``assert``: with assertions
                disabled (``python -O``) a negative rank would index
                ``self.peers`` from the end and silently reach the wrong party.
        """
        if not 0 <= to < self.world_size:
            raise ValueError(
                f"Destination rank {to} out of range [0, {self.world_size})"
            )
        self.peers[to].onSent(self.rank, key, data)
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
class SimVar(InterpVar):
    """Handle to a value computed by a Simulator.

    Holds one raw value per simulated rank; ``values`` exposes them in a
    user-friendly (numpy, where supported) form.
    """

    def __init__(self, ctx: Simulator, mptype: MPType, values: list[Any]):
        super().__init__(ctx, mptype)
        # Raw per-rank results; Simulator.evaluate indexes this list by rank.
        self._values = values

    @property
    def values(self) -> list[Any]:
        """Per-rank values, converted with ``to_numpy()`` when available."""
        converted: list[Any] = []
        for raw in self._values:
            if hasattr(raw, "to_numpy"):
                converted.append(raw.to_numpy())
            else:
                converted.append(raw)
        return converted

    def __repr__(self) -> str:
        type_text = f"{self.mptype}"
        return f"SimVar({type_text})"
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
class Simulator(InterpContext):
    """In-process multi-party simulator.

    Every party of the cluster runs in its own thread inside this process and
    communicates through ThreadCommunicator instances. Expressions are
    round-tripped through MPIR serialization before evaluation (see
    ``_do_evaluate``).
    """

    def __init__(
        self,
        cluster_spec: ClusterSpec,
        *,
        trace_ranks: list[int] | None = None,
    ) -> None:
        """Initialize a simulator with the given cluster specification.

        Args:
            cluster_spec: The cluster specification defining the simulation environment.
            trace_ranks: List of ranks to trace execution for debugging.

        Per-node op binding overrides should be provided via each node's
        ``runtime_info.op_bindings`` in the supplied ``cluster_spec``.

        Raises:
            ValueError: if the spec declares no SPU device or more than one.
            RuntimeError: if creating an SPU link context fails for any SPU rank.
        """
        super().__init__(cluster_spec)
        self._trace_ranks = trace_ranks or []

        spu_devices = cluster_spec.get_devices_by_kind("SPU")
        if not spu_devices:
            raise ValueError("No SPU device found in the cluster specification")
        if len(spu_devices) > 1:
            raise ValueError("Multiple SPU devices found in the cluster specification")
        spu_device = spu_devices[0]

        # compute spu_mask from spu_device members
        spu_mask = Mask.from_ranks([member.rank for member in spu_device.members])

        # Convert protocol and field from config. Missing keys fall back to the
        # same defaults the runtime Session uses ("SEMI2K" / "FM64") instead of
        # failing with a bare KeyError.
        spu_protocol = parse_protocol(spu_device.config.get("protocol", "SEMI2K"))
        spu_field = parse_field(spu_device.config.get("field", "FM64"))

        world_size = self.world_size()

        # Setup communicators: a fully connected in-memory mesh.
        self._comms = [
            ThreadCommunicator(rank, world_size) for rank in range(world_size)
        ]
        for comm in self._comms:
            comm.set_peers(self._comms)

        # Link contexts for SPU parties (consumed when seeding at evaluate time).
        # Channels mode reuses ThreadCommunicator instead of a separate mem_link.
        self._spu_link_ctxs: list[LinkCommunicator | None] = [None] * world_size

        # Create LinkCommunicators in parallel to avoid deadlock
        # (create_with_channels does handshake via TestSend/TestRecv).
        exceptions: dict[int, Exception] = {}

        def create_link(g_rank: int) -> None:
            try:
                self._spu_link_ctxs[g_rank] = LinkCommunicator(
                    rank=g_rank,
                    comm=self._comms[g_rank],
                    spu_mask=spu_mask,
                )
            except Exception as e:
                exceptions[g_rank] = e

        threads = [
            threading.Thread(target=create_link, args=(g_rank,)) for g_rank in spu_mask
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Check for exceptions during link creation
        if exceptions:
            first_exc = next(iter(exceptions.values()))
            raise RuntimeError(
                f"Failed to create SPU link contexts for ranks {list(exceptions.keys())}"
            ) from first_exc

        self._spu_runtime_cfg = libspu.RuntimeConfig(
            protocol=spu_protocol, field=spu_field
        )
        self._spu_world = spu_mask.num_parties()
        self._spu_mask = spu_mask

        # Persistent per-rank RuntimeContext instances, reused across evaluates.
        # Evaluators are built per evaluate() because each call has its own env.
        self._runtimes: list[RuntimeContext] = []
        for rank in range(self.world_size()):
            node = self.cluster_spec.get_node_by_rank(rank)
            rt = RuntimeContext(
                rank=rank,
                world_size=self.world_size(),
                initial_bindings=node.runtime_info.op_bindings,
            )
            self._runtimes.append(rt)

    @classmethod
    def simple(
        cls,
        world_size: int,
        op_bindings: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> Simulator:
        """Create a simple simulator with the given number of parties.

        This is a convenience method that creates a ClusterSpec.simple()
        configuration for quick testing and prototyping.

        Args:
            world_size: Number of simulated parties.
            op_bindings: Optional op bindings applied to every node's runtime_info.
            **kwargs: Additional arguments passed to the Simulator constructor.

        Returns:
            A Simulator instance with a simple cluster configuration.
        """
        cluster_spec = ClusterSpec.simple(world_size)
        if op_bindings:
            # Apply the same op_bindings to every node's runtime_info for convenience
            for node in cluster_spec.nodes.values():
                node.runtime_info.op_bindings.update(op_bindings)
        return cls(cluster_spec, **kwargs)

    def _do_evaluate(self, expr: Expr, evaluator_engine: IEvaluator) -> Any:
        """Evaluate ``expr`` after an MPIR serialize -> deserialize round trip.

        Simulates the real-world path instead of calling expr.accept directly,
        which exposes MPIR serialization bugs.

        Raises:
            ValueError: if deserialization yields no expression.
        """
        writer = IrWriter()
        graph_proto = writer.dumps(expr)

        reader = IrReader()
        deserialized_expr = reader.loads(graph_proto)

        if deserialized_expr is None:
            raise ValueError("Failed to deserialize expression")

        return evaluator_engine.evaluate(deserialized_expr)

    # override
    def fetch(self, obj: MPObject) -> list[TensorLike]:
        """Return the per-rank values of ``obj``, converted via to_numpy() when possible.

        Raises:
            ValueError: if ``obj`` is not a SimVar.
        """
        if not isinstance(obj, SimVar):
            raise ValueError(f"Expected SimVar, got {type(obj)}")
        return [v.to_numpy() if hasattr(v, "to_numpy") else v for v in obj._values]

    def _make_party_evaluators(
        self, bindings: dict[str, MPObject]
    ) -> list[IEvaluator]:
        """Build one evaluator per rank, seeding the SPU kernel env on first use."""
        evaluators: list[IEvaluator] = []
        for rank in range(self.world_size()):
            runtime = self._runtimes[rank]
            env = {
                name: cast(SimVar, var)._values[rank] for name, var in bindings.items()
            }
            ev = create_evaluator(
                rank,
                env,
                self._comms[rank],
                runtime,
                None,
            )
            # Seed SPU once per runtime; the flag lives in the runtime's state
            # dict so it persists across evaluate() calls.
            spu_meta = runtime.state.setdefault("_spu", {})
            if not spu_meta.get("inited", False):
                seed_fn = PFunction(
                    fn_type="spu.seed_env",
                    ins_info=(),
                    outs_info=(),
                    config=self._spu_runtime_cfg,
                    world=self._spu_world,
                    link=self._spu_link_ctxs[rank],
                )
                ev.runtime.run_kernel(seed_fn, [])  # type: ignore[arg-type]
                spu_meta["inited"] = True
            evaluators.append(ev)
        return evaluators

    def _run_parties(self, expr: Expr, evaluators: list[IEvaluator]) -> list[Any]:
        """Evaluate ``expr`` on all parties concurrently; return results by rank.

        The pool has one worker per party: parties block on each other's
        messages, so a smaller pool (the executor default is bounded by the
        CPU count) would leave some parties unscheduled and deadlock.

        The executor is not used as a context manager. Its ``__exit__`` calls
        ``shutdown(wait=True)``, which would join threads stuck waiting on a
        failed peer and hang forever, so the error would never reach the caller.
        """
        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=max(1, len(evaluators))
        )
        futures = [
            executor.submit(self._do_evaluate, expr, evaluator)
            for evaluator in evaluators
        ]
        results: list[Any] = []
        try:
            for i, future in enumerate(futures):
                try:
                    results.append(future.result(100))  # 100 second timeout
                except concurrent.futures.TimeoutError:
                    faulthandler.dump_traceback(file=sys.stderr, all_threads=True)
                    raise
                except Exception as e:
                    print(
                        f"Exception in party {i}: {type(e).__name__}: {e}",
                        file=sys.stderr,
                    )
                    traceback.print_exc(file=sys.stderr)
                    raise
        except BaseException:
            # Do not wait for parties that may be blocked on the failed one.
            executor.shutdown(wait=False, cancel_futures=True)
            raise
        executor.shutdown(wait=True)
        return results

    # override
    def evaluate(self, expr: Expr, bindings: dict[str, MPObject]) -> Sequence[MPObject]:
        """Run ``expr`` on every simulated party and wrap outputs as SimVars.

        Raises:
            ValueError: if a binding belongs to another context, or parties
                return differing numbers of outputs.
        """
        # sanity check for bindings.
        for name, var in bindings.items():
            if var.ctx is not self:
                raise ValueError(f"Variable {name} not in this context, got {var.ctx}.")

        evaluators = self._make_party_evaluators(bindings)
        pts_results = self._run_parties(expr, evaluators)

        # pts_results is (n_parties, n_outputs); transpose to (n_outputs, n_parties).
        assert len(pts_results) == self.world_size()

        # Ensure all parties returned the same number of outputs (matrix validation)
        if pts_results and not all(
            len(row) == len(pts_results[0]) for row in pts_results
        ):
            raise ValueError("Inconsistent number of outputs across parties")

        output_values = list(zip(*pts_results, strict=False))
        output_types = expr.mptypes

        return [
            SimVar(self, mptype, list(values))
            for values, mptype in zip(output_values, output_types, strict=False)
        ]
|
mplang/v1/simp/__init__.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|