mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""
|
|
2
|
+
@generated by mypy-protobuf. Do not edit manually!
|
|
3
|
+
isort:skip_file
|
|
4
|
+
Copyright 2025 Ant Group Co., Ltd.
|
|
5
|
+
|
|
6
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
you may not use this file except in compliance with the License.
|
|
8
|
+
You may obtain a copy of the License at
|
|
9
|
+
|
|
10
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
|
|
12
|
+
Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
See the License for the specific language governing permissions and
|
|
16
|
+
limitations under the License.
|
|
17
|
+
"""
|
|
18
|
+
import builtins
|
|
19
|
+
import collections.abc
|
|
20
|
+
import google.protobuf.descriptor
|
|
21
|
+
import google.protobuf.internal.containers
|
|
22
|
+
import google.protobuf.internal.enum_type_wrapper
|
|
23
|
+
import google.protobuf.message
|
|
24
|
+
import sys
|
|
25
|
+
import typing
|
|
26
|
+
|
|
27
|
+
if sys.version_info >= (3, 10):
|
|
28
|
+
import typing as typing_extensions
|
|
29
|
+
else:
|
|
30
|
+
import typing_extensions
|
|
31
|
+
|
|
32
|
+
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
|
33
|
+
|
|
34
|
+
@typing_extensions.final
|
|
35
|
+
class ValueAttrProto(google.protobuf.message.Message):
|
|
36
|
+
"""Lightweight attribute proto used solely for ValueProto runtime metadata."""
|
|
37
|
+
|
|
38
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
39
|
+
|
|
40
|
+
class _AttrType:
|
|
41
|
+
ValueType = typing.NewType("ValueType", builtins.int)
|
|
42
|
+
V: typing_extensions.TypeAlias = ValueType
|
|
43
|
+
|
|
44
|
+
class _AttrTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ValueAttrProto._AttrType.ValueType], builtins.type):
|
|
45
|
+
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
|
46
|
+
UNDEFINED: ValueAttrProto._AttrType.ValueType # 0
|
|
47
|
+
FLOAT: ValueAttrProto._AttrType.ValueType # 1
|
|
48
|
+
INT: ValueAttrProto._AttrType.ValueType # 2
|
|
49
|
+
STRING: ValueAttrProto._AttrType.ValueType # 3
|
|
50
|
+
BOOL: ValueAttrProto._AttrType.ValueType # 4
|
|
51
|
+
BYTES: ValueAttrProto._AttrType.ValueType # 5
|
|
52
|
+
FLOATS: ValueAttrProto._AttrType.ValueType # 6
|
|
53
|
+
INTS: ValueAttrProto._AttrType.ValueType # 7
|
|
54
|
+
STRINGS: ValueAttrProto._AttrType.ValueType # 8
|
|
55
|
+
EMPTY: ValueAttrProto._AttrType.ValueType # 9
|
|
56
|
+
"""Represents an explicitly empty attribute value (e.g., empty list)"""
|
|
57
|
+
|
|
58
|
+
class AttrType(_AttrType, metaclass=_AttrTypeEnumTypeWrapper): ...
|
|
59
|
+
UNDEFINED: ValueAttrProto.AttrType.ValueType # 0
|
|
60
|
+
FLOAT: ValueAttrProto.AttrType.ValueType # 1
|
|
61
|
+
INT: ValueAttrProto.AttrType.ValueType # 2
|
|
62
|
+
STRING: ValueAttrProto.AttrType.ValueType # 3
|
|
63
|
+
BOOL: ValueAttrProto.AttrType.ValueType # 4
|
|
64
|
+
BYTES: ValueAttrProto.AttrType.ValueType # 5
|
|
65
|
+
FLOATS: ValueAttrProto.AttrType.ValueType # 6
|
|
66
|
+
INTS: ValueAttrProto.AttrType.ValueType # 7
|
|
67
|
+
STRINGS: ValueAttrProto.AttrType.ValueType # 8
|
|
68
|
+
EMPTY: ValueAttrProto.AttrType.ValueType # 9
|
|
69
|
+
"""Represents an explicitly empty attribute value (e.g., empty list)"""
|
|
70
|
+
|
|
71
|
+
TYPE_FIELD_NUMBER: builtins.int
|
|
72
|
+
F_FIELD_NUMBER: builtins.int
|
|
73
|
+
I_FIELD_NUMBER: builtins.int
|
|
74
|
+
S_FIELD_NUMBER: builtins.int
|
|
75
|
+
B_FIELD_NUMBER: builtins.int
|
|
76
|
+
RAW_BYTES_FIELD_NUMBER: builtins.int
|
|
77
|
+
FLOATS_FIELD_NUMBER: builtins.int
|
|
78
|
+
INTS_FIELD_NUMBER: builtins.int
|
|
79
|
+
STRS_FIELD_NUMBER: builtins.int
|
|
80
|
+
type: global___ValueAttrProto.AttrType.ValueType
|
|
81
|
+
f: builtins.float
|
|
82
|
+
i: builtins.int
|
|
83
|
+
s: builtins.str
|
|
84
|
+
b: builtins.bool
|
|
85
|
+
raw_bytes: builtins.bytes
|
|
86
|
+
@property
|
|
87
|
+
def floats(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ...
|
|
88
|
+
@property
|
|
89
|
+
def ints(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
|
|
90
|
+
@property
|
|
91
|
+
def strs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
|
|
92
|
+
def __init__(
|
|
93
|
+
self,
|
|
94
|
+
*,
|
|
95
|
+
type: global___ValueAttrProto.AttrType.ValueType = ...,
|
|
96
|
+
f: builtins.float = ...,
|
|
97
|
+
i: builtins.int = ...,
|
|
98
|
+
s: builtins.str = ...,
|
|
99
|
+
b: builtins.bool = ...,
|
|
100
|
+
raw_bytes: builtins.bytes = ...,
|
|
101
|
+
floats: collections.abc.Iterable[builtins.float] | None = ...,
|
|
102
|
+
ints: collections.abc.Iterable[builtins.int] | None = ...,
|
|
103
|
+
strs: collections.abc.Iterable[builtins.str] | None = ...,
|
|
104
|
+
) -> None: ...
|
|
105
|
+
def ClearField(self, field_name: typing_extensions.Literal["b", b"b", "f", b"f", "floats", b"floats", "i", b"i", "ints", b"ints", "raw_bytes", b"raw_bytes", "s", b"s", "strs", b"strs", "type", b"type"]) -> None: ...
|
|
106
|
+
|
|
107
|
+
global___ValueAttrProto = ValueAttrProto
|
|
108
|
+
|
|
109
|
+
@typing_extensions.final
|
|
110
|
+
class ValueProto(google.protobuf.message.Message):
|
|
111
|
+
"""Generic envelope for kernel-level transferable values.
|
|
112
|
+
|
|
113
|
+
DESIGN PRINCIPLES
|
|
114
|
+
* Small, stable schema: only descriptors needed for dynamic dispatch.
|
|
115
|
+
* Payload is opaque to the envelope; per-kind versioning lives in
|
|
116
|
+
value_version.
|
|
117
|
+
* Backward-compatible evolution: only append new optional fields.
|
|
118
|
+
|
|
119
|
+
Versioning Guidelines:
|
|
120
|
+
- value_version: per-kind semantic payload version (maintained by KernelValue
|
|
121
|
+
subclass).
|
|
122
|
+
- Adding fields: assign new unique field numbers; never reuse old numbers.
|
|
123
|
+
- Removing fields: reserve the field number & (optionally) name.
|
|
124
|
+
"""
|
|
125
|
+
|
|
126
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
127
|
+
|
|
128
|
+
@typing_extensions.final
|
|
129
|
+
class RuntimeAttrsEntry(google.protobuf.message.Message):
|
|
130
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
131
|
+
|
|
132
|
+
KEY_FIELD_NUMBER: builtins.int
|
|
133
|
+
VALUE_FIELD_NUMBER: builtins.int
|
|
134
|
+
key: builtins.str
|
|
135
|
+
@property
|
|
136
|
+
def value(self) -> global___ValueAttrProto: ...
|
|
137
|
+
def __init__(
|
|
138
|
+
self,
|
|
139
|
+
*,
|
|
140
|
+
key: builtins.str = ...,
|
|
141
|
+
value: global___ValueAttrProto | None = ...,
|
|
142
|
+
) -> None: ...
|
|
143
|
+
def HasField(self, field_name: typing_extensions.Literal["value", b"value"]) -> builtins.bool: ...
|
|
144
|
+
def ClearField(self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]) -> None: ...
|
|
145
|
+
|
|
146
|
+
KIND_FIELD_NUMBER: builtins.int
|
|
147
|
+
VALUE_VERSION_FIELD_NUMBER: builtins.int
|
|
148
|
+
PAYLOAD_FIELD_NUMBER: builtins.int
|
|
149
|
+
RUNTIME_ATTRS_FIELD_NUMBER: builtins.int
|
|
150
|
+
kind: builtins.str
|
|
151
|
+
"""Globally unique identifier for Value subclass, e.g. "mplang.ndarray"."""
|
|
152
|
+
value_version: builtins.int
|
|
153
|
+
"""Per-kind payload schema version (>=1)."""
|
|
154
|
+
payload: builtins.bytes
|
|
155
|
+
"""Primary data payload bytes (layout defined by each Value subclass)."""
|
|
156
|
+
@property
|
|
157
|
+
def runtime_attrs(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___ValueAttrProto]:
|
|
158
|
+
"""Additional runtime metadata required to recreate the Value instance."""
|
|
159
|
+
def __init__(
|
|
160
|
+
self,
|
|
161
|
+
*,
|
|
162
|
+
kind: builtins.str = ...,
|
|
163
|
+
value_version: builtins.int = ...,
|
|
164
|
+
payload: builtins.bytes = ...,
|
|
165
|
+
runtime_attrs: collections.abc.Mapping[builtins.str, global___ValueAttrProto] | None = ...,
|
|
166
|
+
) -> None: ...
|
|
167
|
+
def ClearField(self, field_name: typing_extensions.Literal["kind", b"kind", "payload", b"payload", "runtime_attrs", b"runtime_attrs", "value_version", b"value_version"]) -> None: ...
|
|
168
|
+
|
|
169
|
+
global___ValueProto = ValueProto
|
|
@@ -20,8 +20,8 @@ This module contains runtime implementations including:
|
|
|
20
20
|
- Driver for distributed execution
|
|
21
21
|
"""
|
|
22
22
|
|
|
23
|
-
from mplang.runtime.driver import Driver, DriverVar
|
|
24
|
-
from mplang.runtime.simulation import Simulator
|
|
23
|
+
from mplang.v1.runtime.driver import Driver, DriverVar
|
|
24
|
+
from mplang.v1.runtime.simulation import Simulator
|
|
25
25
|
|
|
26
26
|
__all__ = [
|
|
27
27
|
"Driver",
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""SPU IChannel implementation that bridges to MPLang CommunicatorBase.
|
|
16
|
+
|
|
17
|
+
This module provides BaseChannel, which allows SPU to reuse MPLang's
|
|
18
|
+
existing communication layer (ThreadCommunicator/HttpCommunicator) instead
|
|
19
|
+
of creating separate BRPC connections.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import logging
|
|
25
|
+
from typing import TYPE_CHECKING
|
|
26
|
+
|
|
27
|
+
import spu.libspu as libspu
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from mplang.v1.core.comm import CommunicatorBase
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BaseChannel(libspu.link.IChannel):
|
|
34
|
+
"""Bridge MPLang CommunicatorBase to SPU IChannel interface.
|
|
35
|
+
|
|
36
|
+
This adapter allows SPU to use MPLang's existing communication layer
|
|
37
|
+
(ThreadCommunicator or HttpCommunicator) instead of creating separate
|
|
38
|
+
BRPC connections.
|
|
39
|
+
|
|
40
|
+
Each BaseChannel represents a channel to ONE peer rank.
|
|
41
|
+
|
|
42
|
+
Communication Protocol:
|
|
43
|
+
- SPU calls send(tag, bytes_data) -> MPLang comm.send(peer, key, bytes_data)
|
|
44
|
+
- SPU calls recv(tag) -> bytes_data <- MPLang comm.recv(peer, key)
|
|
45
|
+
|
|
46
|
+
Tag Namespace:
|
|
47
|
+
All tags are prefixed with "spu:" to avoid collision with other
|
|
48
|
+
MPLang traffic on the same communicator.
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
def __init__(
|
|
52
|
+
self,
|
|
53
|
+
comm: CommunicatorBase,
|
|
54
|
+
local_rank: int,
|
|
55
|
+
peer_rank: int,
|
|
56
|
+
tag_prefix: str = "spu",
|
|
57
|
+
):
|
|
58
|
+
"""Initialize channel to a specific peer.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
comm: MPLang communicator instance (Thread/Http)
|
|
62
|
+
local_rank: Global rank of this party (for logging/debugging)
|
|
63
|
+
peer_rank: Global rank of the peer party
|
|
64
|
+
tag_prefix: Prefix for all tags to avoid collision (default: "spu")
|
|
65
|
+
"""
|
|
66
|
+
super().__init__()
|
|
67
|
+
self._comm = comm
|
|
68
|
+
self._local_rank = local_rank
|
|
69
|
+
self._peer_rank = peer_rank
|
|
70
|
+
self._tag_prefix = tag_prefix
|
|
71
|
+
|
|
72
|
+
logging.debug(
|
|
73
|
+
f"BaseChannel initialized: local_rank={local_rank}, "
|
|
74
|
+
f"peer_rank={peer_rank}, tag_prefix={tag_prefix}"
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
def _make_key(self, tag: str) -> str:
|
|
78
|
+
"""Create unique key for MPLang comm.
|
|
79
|
+
|
|
80
|
+
Prefixes the tag to avoid collision with non-SPU traffic.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
tag: SPU-provided tag (e.g., "send_0", "recv_0")
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
Prefixed key (e.g., "spu:send_0")
|
|
87
|
+
"""
|
|
88
|
+
return f"{self._tag_prefix}:{tag}"
|
|
89
|
+
|
|
90
|
+
def Send(self, tag: str, data: bytes) -> None:
|
|
91
|
+
"""Send bytes to peer (synchronous in SPU semantics).
|
|
92
|
+
|
|
93
|
+
Args:
|
|
94
|
+
tag: Message tag for matching send/recv pairs
|
|
95
|
+
data: Raw bytes to send
|
|
96
|
+
"""
|
|
97
|
+
key = self._make_key(tag)
|
|
98
|
+
logging.debug(
|
|
99
|
+
f"BaseChannel.Send: {self._local_rank} -> {self._peer_rank}, "
|
|
100
|
+
f"tag={tag}, key={key}, size={len(data)}"
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
# Send raw bytes directly
|
|
104
|
+
# Note: CommunicatorBase.send expects Any type, bytes is acceptable
|
|
105
|
+
self._comm.send(self._peer_rank, key, data)
|
|
106
|
+
|
|
107
|
+
def Recv(self, tag: str) -> bytes:
|
|
108
|
+
"""Receive bytes from peer (blocking).
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
tag: Message tag for matching send/recv pairs
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
Raw bytes received
|
|
115
|
+
"""
|
|
116
|
+
key = self._make_key(tag)
|
|
117
|
+
logging.debug(
|
|
118
|
+
f"BaseChannel.Recv: {self._local_rank} <- {self._peer_rank}, "
|
|
119
|
+
f"tag={tag}, key={key}"
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
# Receive data (should be bytes)
|
|
123
|
+
data = self._comm.recv(self._peer_rank, key)
|
|
124
|
+
|
|
125
|
+
# Validate data type
|
|
126
|
+
if not isinstance(data, bytes):
|
|
127
|
+
raise TypeError(
|
|
128
|
+
f"Expected bytes from communicator, got {type(data).__name__}. "
|
|
129
|
+
f"Communicator must support raw bytes transmission for SPU channels."
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
logging.debug(
|
|
133
|
+
f"BaseChannel.Recv complete: {self._local_rank} <- {self._peer_rank}, "
|
|
134
|
+
f"tag={tag}, size={len(data)}"
|
|
135
|
+
)
|
|
136
|
+
return data
|
|
137
|
+
|
|
138
|
+
def SendAsync(self, tag: str, data: bytes) -> None:
|
|
139
|
+
"""Async send (MPLang's send is already async at network layer).
|
|
140
|
+
|
|
141
|
+
For HttpCommunicator, the underlying httpx.put() is non-blocking
|
|
142
|
+
at the HTTP client level. For ThreadCommunicator, send is instant
|
|
143
|
+
(memory transfer).
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
tag: Message tag
|
|
147
|
+
data: Raw bytes to send
|
|
148
|
+
"""
|
|
149
|
+
# Reuse synchronous send - it's already async underneath
|
|
150
|
+
self.Send(tag, data)
|
|
151
|
+
|
|
152
|
+
def SendAsyncThrottled(self, tag: str, data: bytes) -> None:
|
|
153
|
+
"""Throttled async send.
|
|
154
|
+
|
|
155
|
+
Currently maps to regular SendAsync. Future optimization could
|
|
156
|
+
implement rate limiting if needed.
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
tag: Message tag
|
|
160
|
+
data: Raw bytes to send
|
|
161
|
+
"""
|
|
162
|
+
self.SendAsync(tag, data)
|
|
163
|
+
|
|
164
|
+
def TestSend(self, timeout: int) -> None:
|
|
165
|
+
"""Test if this channel can send a dummy msg to peer.
|
|
166
|
+
|
|
167
|
+
Uses fixed 0 seq_id as dummy msg's id to make this function reentrant.
|
|
168
|
+
ConnectToMesh will retry on this multiple times.
|
|
169
|
+
|
|
170
|
+
Args:
|
|
171
|
+
timeout: Timeout in milliseconds
|
|
172
|
+
"""
|
|
173
|
+
# Send a handshake message to test connectivity
|
|
174
|
+
# Use fixed tag "__test__" to make this reentrant (idempotent)
|
|
175
|
+
test_data = b"\x00" # Minimal 1-byte message with seq_id=0
|
|
176
|
+
self.Send("__test__", test_data)
|
|
177
|
+
|
|
178
|
+
def TestRecv(self) -> None:
|
|
179
|
+
"""Wait for dummy msg from peer.
|
|
180
|
+
|
|
181
|
+
Timeout is controlled by recv_timeout_ms in link descriptor.
|
|
182
|
+
"""
|
|
183
|
+
# Receive the handshake message from peer
|
|
184
|
+
# This blocks until message arrives (timeout from desc.recv_timeout_ms)
|
|
185
|
+
test_data = self.Recv("__test__")
|
|
186
|
+
# Validate it's the expected handshake message
|
|
187
|
+
if test_data != b"\x00":
|
|
188
|
+
logging.warning(
|
|
189
|
+
f"TestRecv: unexpected handshake data from {self._peer_rank}, "
|
|
190
|
+
f"expected b'\\x00', got {test_data!r}"
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
def WaitLinkTaskFinish(self) -> None:
|
|
194
|
+
"""Wait for all pending async tasks.
|
|
195
|
+
|
|
196
|
+
For MPLang communicators:
|
|
197
|
+
- ThreadCommunicator: No-op (instant memory transfer)
|
|
198
|
+
- HttpCommunicator: No explicit wait needed (httpx handles it)
|
|
199
|
+
|
|
200
|
+
This is a no-op in current implementation.
|
|
201
|
+
"""
|
|
202
|
+
|
|
203
|
+
def Abort(self) -> None:
|
|
204
|
+
"""Abort communication (cleanup resources).
|
|
205
|
+
|
|
206
|
+
This could be extended to notify the communicator to drop pending
|
|
207
|
+
messages for this channel, but currently is a no-op.
|
|
208
|
+
"""
|
|
209
|
+
logging.warning(
|
|
210
|
+
f"BaseChannel.Abort called: {self._local_rank} <-> {self._peer_rank}"
|
|
211
|
+
)
|
|
212
|
+
# Future: Could call comm.abort_session() if implemented
|
|
213
|
+
|
|
214
|
+
def SetThrottleWindowSize(self, size: int) -> None:
|
|
215
|
+
"""Set throttle window size.
|
|
216
|
+
|
|
217
|
+
Not applicable to MPLang communicators. No-op.
|
|
218
|
+
|
|
219
|
+
Args:
|
|
220
|
+
size: Window size (ignored)
|
|
221
|
+
"""
|
|
222
|
+
|
|
223
|
+
def SetChunkParallelSendSize(self, size: int) -> None:
|
|
224
|
+
"""Set chunk parallel send size.
|
|
225
|
+
|
|
226
|
+
Not applicable to MPLang communicators. No-op.
|
|
227
|
+
|
|
228
|
+
Args:
|
|
229
|
+
size: Chunk size (ignored)
|
|
230
|
+
"""
|
|
@@ -26,9 +26,9 @@ from typing import Any
|
|
|
26
26
|
import uvicorn
|
|
27
27
|
import yaml
|
|
28
28
|
|
|
29
|
-
from mplang.core
|
|
30
|
-
from mplang.runtime.client import HttpExecutorClient
|
|
31
|
-
from mplang.runtime.server import app
|
|
29
|
+
from mplang.v1.core import ClusterSpec
|
|
30
|
+
from mplang.v1.runtime.client import HttpExecutorClient
|
|
31
|
+
from mplang.v1.runtime.server import app
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
def load_config(config_path: str) -> ClusterSpec:
|
|
@@ -195,14 +195,14 @@ def status_command(args: argparse.Namespace) -> int:
|
|
|
195
195
|
"""
|
|
196
196
|
|
|
197
197
|
async def _get_node_status(
|
|
198
|
-
node_id: str, endpoint: str, details:
|
|
198
|
+
node_id: str, endpoint: str, details: int = 0, timeout: int = 60
|
|
199
199
|
) -> dict[str, Any]:
|
|
200
200
|
"""Get status information for a single node.
|
|
201
201
|
|
|
202
202
|
Args:
|
|
203
203
|
node_id: Identifier for the node
|
|
204
204
|
endpoint: HTTP endpoint of the node
|
|
205
|
-
details:
|
|
205
|
+
details: Verbosity level (0=basic, 1=-v, 2=-vv)
|
|
206
206
|
timeout: HTTP request timeout in seconds (default: 60)
|
|
207
207
|
"""
|
|
208
208
|
|
|
@@ -224,21 +224,26 @@ def status_command(args: argparse.Namespace) -> int:
|
|
|
224
224
|
sessions = await client.list_sessions()
|
|
225
225
|
status["sessions"] = sessions
|
|
226
226
|
|
|
227
|
-
# Get detailed session info
|
|
228
|
-
|
|
227
|
+
# Get detailed session info based on verbosity level
|
|
228
|
+
# details=1 (-v): show session names and basic counts
|
|
229
|
+
# details=2 (-vv): show full computation and symbol lists
|
|
230
|
+
if details >= 1:
|
|
229
231
|
session_details = []
|
|
230
232
|
for session_name in sessions:
|
|
231
233
|
try:
|
|
232
234
|
# Get computations and symbols for each session
|
|
233
235
|
computations = await client.list_computations(session_name)
|
|
234
236
|
symbols = await client.list_symbols(session_name)
|
|
235
|
-
|
|
237
|
+
session_info = {
|
|
236
238
|
"name": session_name,
|
|
237
239
|
"computations": len(computations),
|
|
238
240
|
"symbols": len(symbols),
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
241
|
+
}
|
|
242
|
+
# Include full lists only at -vv level
|
|
243
|
+
if details >= 2:
|
|
244
|
+
session_info["computation_list"] = computations
|
|
245
|
+
session_info["symbol_list"] = symbols
|
|
246
|
+
session_details.append(session_info)
|
|
242
247
|
except Exception as e:
|
|
243
248
|
session_details.append({
|
|
244
249
|
"name": session_name,
|
|
@@ -255,13 +260,13 @@ def status_command(args: argparse.Namespace) -> int:
|
|
|
255
260
|
return status
|
|
256
261
|
|
|
257
262
|
async def _collect_cluster_status(
|
|
258
|
-
nodes: dict[str, str], details:
|
|
263
|
+
nodes: dict[str, str], details: int = 0
|
|
259
264
|
) -> list[dict[str, Any] | BaseException]:
|
|
260
265
|
"""Collect status from all nodes concurrently.
|
|
261
266
|
|
|
262
267
|
Args:
|
|
263
268
|
nodes: Dictionary mapping node IDs to their HTTP endpoints
|
|
264
|
-
details:
|
|
269
|
+
details: Verbosity level (0=basic, 1=-v, 2=-vv)
|
|
265
270
|
|
|
266
271
|
Returns:
|
|
267
272
|
List of status dictionaries or exceptions for each node
|
|
@@ -284,7 +289,8 @@ def status_command(args: argparse.Namespace) -> int:
|
|
|
284
289
|
node_addrs = {node_id: node.endpoint for node_id, node in nodes.items()}
|
|
285
290
|
|
|
286
291
|
# Collect status from all nodes
|
|
287
|
-
|
|
292
|
+
verbosity = getattr(args, "verbose", 0)
|
|
293
|
+
cluster_status = asyncio.run(_collect_cluster_status(node_addrs, verbosity))
|
|
288
294
|
|
|
289
295
|
# Basic node health check
|
|
290
296
|
print("Node Status:")
|
|
@@ -314,8 +320,8 @@ def status_command(args: argparse.Namespace) -> int:
|
|
|
314
320
|
print(f"{node_id:<15} {endpoint:<20} UNHEALTHY")
|
|
315
321
|
all_healthy = False
|
|
316
322
|
|
|
317
|
-
# If
|
|
318
|
-
if
|
|
323
|
+
# If verbose mode is enabled, show detailed information
|
|
324
|
+
if verbosity >= 1 and valid_statuses:
|
|
319
325
|
print("\nDetailed Runtime Status:")
|
|
320
326
|
print("-" * 50)
|
|
321
327
|
|
|
@@ -351,6 +357,14 @@ def status_command(args: argparse.Namespace) -> int:
|
|
|
351
357
|
print(
|
|
352
358
|
f" - Session '{session_name}': {computations} computations, {symbols} symbols"
|
|
353
359
|
)
|
|
360
|
+
# At -vv level, show the actual lists
|
|
361
|
+
if verbosity >= 2:
|
|
362
|
+
comp_list = session.get("computation_list", [])
|
|
363
|
+
symbol_list = session.get("symbol_list", [])
|
|
364
|
+
if comp_list:
|
|
365
|
+
print(f" Computations: {comp_list}")
|
|
366
|
+
if symbol_list:
|
|
367
|
+
print(f" Symbols: {symbol_list}")
|
|
354
368
|
elif not sessions:
|
|
355
369
|
print(" - No active sessions")
|
|
356
370
|
|
|
@@ -381,10 +395,11 @@ def main() -> int:
|
|
|
381
395
|
"--config", "-c", required=True, help="Path to the YAML configuration file"
|
|
382
396
|
)
|
|
383
397
|
status_parser.add_argument(
|
|
384
|
-
"--
|
|
385
|
-
"-
|
|
386
|
-
action="
|
|
387
|
-
|
|
398
|
+
"--verbose",
|
|
399
|
+
"-v",
|
|
400
|
+
action="count",
|
|
401
|
+
default=0,
|
|
402
|
+
help="Increase verbosity: -v for session details, -vv for full lists",
|
|
388
403
|
)
|
|
389
404
|
status_parser.set_defaults(func=status_command)
|
|
390
405
|
|
|
@@ -25,9 +25,10 @@ from __future__ import annotations
|
|
|
25
25
|
import base64
|
|
26
26
|
from typing import Any
|
|
27
27
|
|
|
28
|
-
import cloudpickle as pickle
|
|
29
28
|
import httpx
|
|
30
29
|
|
|
30
|
+
from mplang.v1.kernels.value import Value, decode_value, encode_value
|
|
31
|
+
|
|
31
32
|
|
|
32
33
|
class ExecutionStatus:
|
|
33
34
|
"""Status of a computation execution."""
|
|
@@ -253,8 +254,10 @@ class HttpExecutorClient:
|
|
|
253
254
|
"""
|
|
254
255
|
url = f"/sessions/{session_name}/symbols/{symbol_name}"
|
|
255
256
|
|
|
256
|
-
# Serialize data
|
|
257
|
-
|
|
257
|
+
# Serialize data using Value envelope
|
|
258
|
+
if not isinstance(data, Value):
|
|
259
|
+
raise TypeError(f"Data must be a Value instance, got {type(data)}")
|
|
260
|
+
data_bytes = encode_value(data)
|
|
258
261
|
data_b64 = base64.b64encode(data_bytes).decode("utf-8")
|
|
259
262
|
|
|
260
263
|
payload = {"data": data_b64, "mptype": mptype or {}}
|
|
@@ -286,11 +289,15 @@ class HttpExecutorClient:
|
|
|
286
289
|
response.raise_for_status()
|
|
287
290
|
symbol_data = response.json()
|
|
288
291
|
|
|
289
|
-
# Deserialize data
|
|
292
|
+
# Deserialize data using Value envelope
|
|
290
293
|
data_bytes = base64.b64decode(symbol_data["data"])
|
|
291
|
-
return
|
|
294
|
+
return decode_value(data_bytes)
|
|
292
295
|
|
|
293
|
-
except
|
|
296
|
+
except httpx.HTTPStatusError as e:
|
|
297
|
+
if e.response is not None and e.response.status_code == 404:
|
|
298
|
+
return None
|
|
299
|
+
raise self._raise_http_error(f"get symbol {symbol_name}", e) from e
|
|
300
|
+
except httpx.RequestError as e:
|
|
294
301
|
raise self._raise_http_error(f"get symbol {symbol_name}", e) from e
|
|
295
302
|
|
|
296
303
|
async def delete_symbol(self, session_name: str, symbol_name: str) -> None:
|
|
@@ -403,8 +410,12 @@ class HttpExecutorClient:
|
|
|
403
410
|
"""
|
|
404
411
|
url = f"/api/v1/symbols/{symbol_name}"
|
|
405
412
|
try:
|
|
413
|
+
# Serialize using Value envelope
|
|
414
|
+
if not isinstance(data, Value):
|
|
415
|
+
raise TypeError(f"Data must be a Value instance, got {type(data)}")
|
|
416
|
+
data_bytes = encode_value(data)
|
|
406
417
|
payload = {
|
|
407
|
-
"data": base64.b64encode(
|
|
418
|
+
"data": base64.b64encode(data_bytes).decode("utf-8"),
|
|
408
419
|
"mptype": mptype or {},
|
|
409
420
|
}
|
|
410
421
|
resp = await self._client.put(url, json=payload)
|
|
@@ -421,7 +432,7 @@ class HttpExecutorClient:
|
|
|
421
432
|
resp.raise_for_status()
|
|
422
433
|
payload = resp.json()
|
|
423
434
|
data_bytes = base64.b64decode(payload["data"])
|
|
424
|
-
return
|
|
435
|
+
return decode_value(data_bytes)
|
|
425
436
|
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
|
426
437
|
raise self._raise_http_error(f"get global symbol {symbol_name}", e) from e
|
|
427
438
|
|