mplang-nightly 0.1.dev163__py3-none-any.whl → 0.1.dev165__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/core/expr/evaluator.py +55 -15
- mplang/device.py +4 -18
- mplang/kernels/__init__.py +28 -0
- mplang/kernels/builtin.py +91 -56
- mplang/kernels/crypto.py +39 -30
- mplang/kernels/mock_tee.py +10 -11
- mplang/kernels/phe.py +238 -39
- mplang/kernels/spu.py +134 -45
- mplang/kernels/sql_duckdb.py +8 -13
- mplang/kernels/stablehlo.py +15 -9
- mplang/kernels/value.py +626 -0
- mplang/ops/tee.py +7 -21
- mplang/protos/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/protos/v1alpha1/value_pb2.py +34 -0
- mplang/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/runtime/client.py +19 -8
- mplang/runtime/communicator.py +11 -4
- mplang/runtime/driver.py +16 -1
- mplang/runtime/link_comm.py +26 -79
- mplang/runtime/server.py +30 -29
- mplang/runtime/session.py +9 -0
- mplang/runtime/simulation.py +4 -5
- mplang/simp/__init__.py +1 -1
- {mplang_nightly-0.1.dev163.dist-info → mplang_nightly-0.1.dev165.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev163.dist-info → mplang_nightly-0.1.dev165.dist-info}/RECORD +28 -25
- {mplang_nightly-0.1.dev163.dist-info → mplang_nightly-0.1.dev165.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev163.dist-info → mplang_nightly-0.1.dev165.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev163.dist-info → mplang_nightly-0.1.dev165.dist-info}/licenses/LICENSE +0 -0
mplang/kernels/phe.py
CHANGED
@@ -14,15 +14,27 @@
|
|
14
14
|
|
15
15
|
"""PHE (Partially Homomorphic Encryption) backend implementation using lightPHE."""
|
16
16
|
|
17
|
-
from
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
import json
|
20
|
+
from typing import Any, ClassVar
|
18
21
|
|
19
22
|
import numpy as np
|
20
23
|
from lightphe import LightPHE
|
24
|
+
from lightphe.models.Ciphertext import Ciphertext
|
21
25
|
|
22
26
|
from mplang.core.dtype import DType
|
23
|
-
from mplang.core.mptype import TensorLike
|
24
27
|
from mplang.core.pfunc import PFunction
|
25
28
|
from mplang.kernels.base import kernel_def
|
29
|
+
from mplang.kernels.value import (
|
30
|
+
TensorValue,
|
31
|
+
Value,
|
32
|
+
ValueDecodeError,
|
33
|
+
ValueProtoBuilder,
|
34
|
+
ValueProtoReader,
|
35
|
+
register_value,
|
36
|
+
)
|
37
|
+
from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
|
26
38
|
|
27
39
|
# This controls the decimal precision used in lightPHE for float operations
|
28
40
|
# we force it to 0 to only support integer operations
|
@@ -30,8 +42,12 @@ from mplang.kernels.base import kernel_def
|
|
30
42
|
PRECISION = 0
|
31
43
|
|
32
44
|
|
33
|
-
|
34
|
-
|
45
|
+
@register_value
|
46
|
+
class PublicKey(Value):
|
47
|
+
"""PHE Public Key Value type."""
|
48
|
+
|
49
|
+
KIND: ClassVar[str] = "mplang.phe.PublicKey"
|
50
|
+
WIRE_VERSION: ClassVar[int] = 1
|
35
51
|
|
36
52
|
def __init__(
|
37
53
|
self,
|
@@ -62,12 +78,56 @@ class PublicKey:
|
|
62
78
|
"""Maximum float value that can be encoded."""
|
63
79
|
return float(self.max_value / (2**self.fxp_bits))
|
64
80
|
|
81
|
+
def to_proto(self) -> _value_pb2.ValueProto:
|
82
|
+
"""Serialize PublicKey to wire format."""
|
83
|
+
return (
|
84
|
+
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
85
|
+
.set_attr("scheme", self.scheme)
|
86
|
+
.set_attr("key_size", self.key_size)
|
87
|
+
.set_attr("max_value", self.max_value)
|
88
|
+
.set_attr("fxp_bits", self.fxp_bits)
|
89
|
+
.set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
|
90
|
+
.set_payload(json.dumps(self.key_data).encode("utf-8"))
|
91
|
+
.build()
|
92
|
+
)
|
93
|
+
|
94
|
+
@classmethod
|
95
|
+
def from_proto(cls, proto: _value_pb2.ValueProto) -> PublicKey:
|
96
|
+
"""Deserialize PublicKey from wire format."""
|
97
|
+
reader = ValueProtoReader(proto)
|
98
|
+
if reader.version != cls.WIRE_VERSION:
|
99
|
+
raise ValueDecodeError(f"Unsupported PublicKey version {reader.version}")
|
100
|
+
|
101
|
+
# Read metadata from runtime_attrs
|
102
|
+
scheme = reader.get_attr("scheme")
|
103
|
+
key_size = reader.get_attr("key_size")
|
104
|
+
max_value = reader.get_attr("max_value")
|
105
|
+
fxp_bits = reader.get_attr("fxp_bits")
|
106
|
+
modulus_str = reader.get_attr("modulus")
|
107
|
+
modulus = None if modulus_str == "" else int(modulus_str)
|
108
|
+
|
109
|
+
# JSON deserialize the public key dict
|
110
|
+
key_data = json.loads(reader.payload.decode("utf-8"))
|
111
|
+
|
112
|
+
return cls(
|
113
|
+
key_data=key_data,
|
114
|
+
scheme=scheme,
|
115
|
+
key_size=key_size,
|
116
|
+
max_value=max_value,
|
117
|
+
fxp_bits=fxp_bits,
|
118
|
+
modulus=modulus,
|
119
|
+
)
|
120
|
+
|
65
121
|
def __repr__(self) -> str:
|
66
122
|
return f"PublicKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
|
67
123
|
|
68
124
|
|
69
|
-
|
70
|
-
|
125
|
+
@register_value
|
126
|
+
class PrivateKey(Value):
|
127
|
+
"""PHE Private Key Value type."""
|
128
|
+
|
129
|
+
KIND: ClassVar[str] = "mplang.phe.PrivateKey"
|
130
|
+
WIRE_VERSION: ClassVar[int] = 1
|
71
131
|
|
72
132
|
def __init__(
|
73
133
|
self,
|
@@ -100,12 +160,63 @@ class PrivateKey:
|
|
100
160
|
"""Maximum float value that can be encoded."""
|
101
161
|
return float(self.max_value / (2**self.fxp_bits))
|
102
162
|
|
163
|
+
def to_proto(self) -> _value_pb2.ValueProto:
|
164
|
+
"""Serialize PrivateKey to wire format."""
|
165
|
+
# JSON serialize both key dicts (contain int values)
|
166
|
+
# Store both keys in a single dict to avoid needing length metadata
|
167
|
+
keys_dict = {"sk": self.sk_data, "pk": self.pk_data}
|
168
|
+
|
169
|
+
return (
|
170
|
+
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
171
|
+
.set_attr("scheme", self.scheme)
|
172
|
+
.set_attr("key_size", self.key_size)
|
173
|
+
.set_attr("max_value", self.max_value)
|
174
|
+
.set_attr("fxp_bits", self.fxp_bits)
|
175
|
+
.set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
|
176
|
+
.set_payload(json.dumps(keys_dict).encode("utf-8"))
|
177
|
+
.build()
|
178
|
+
)
|
179
|
+
|
180
|
+
@classmethod
|
181
|
+
def from_proto(cls, proto: _value_pb2.ValueProto) -> PrivateKey:
|
182
|
+
"""Deserialize PrivateKey from wire format."""
|
183
|
+
reader = ValueProtoReader(proto)
|
184
|
+
if reader.version != cls.WIRE_VERSION:
|
185
|
+
raise ValueDecodeError(f"Unsupported PrivateKey version {reader.version}")
|
186
|
+
|
187
|
+
# Read metadata from runtime_attrs
|
188
|
+
scheme = reader.get_attr("scheme")
|
189
|
+
key_size = reader.get_attr("key_size")
|
190
|
+
max_value = reader.get_attr("max_value")
|
191
|
+
fxp_bits = reader.get_attr("fxp_bits")
|
192
|
+
modulus_str = reader.get_attr("modulus")
|
193
|
+
modulus = None if modulus_str == "" else int(modulus_str)
|
194
|
+
|
195
|
+
# JSON deserialize both key dicts
|
196
|
+
keys_dict = json.loads(reader.payload.decode("utf-8"))
|
197
|
+
sk_data = keys_dict["sk"]
|
198
|
+
pk_data = keys_dict["pk"]
|
199
|
+
|
200
|
+
return cls(
|
201
|
+
sk_data=sk_data,
|
202
|
+
pk_data=pk_data,
|
203
|
+
scheme=scheme,
|
204
|
+
key_size=key_size,
|
205
|
+
max_value=max_value,
|
206
|
+
fxp_bits=fxp_bits,
|
207
|
+
modulus=modulus,
|
208
|
+
)
|
209
|
+
|
103
210
|
def __repr__(self) -> str:
|
104
211
|
return f"PrivateKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
|
105
212
|
|
106
213
|
|
107
|
-
|
108
|
-
|
214
|
+
@register_value
|
215
|
+
class CipherText(Value):
|
216
|
+
"""PHE CipherText Value type."""
|
217
|
+
|
218
|
+
KIND: ClassVar[str] = "mplang.phe.CipherText"
|
219
|
+
WIRE_VERSION: ClassVar[int] = 1
|
109
220
|
|
110
221
|
def __init__(
|
111
222
|
self,
|
@@ -142,6 +253,106 @@ class CipherText:
|
|
142
253
|
"""Maximum float value that can be encoded."""
|
143
254
|
return float(self.max_value / (2**self.fxp_bits))
|
144
255
|
|
256
|
+
def to_proto(self) -> _value_pb2.ValueProto:
|
257
|
+
"""Serialize CipherText to wire format.
|
258
|
+
|
259
|
+
WARNING: This serialization is tightly coupled to lightphe.Ciphertext
|
260
|
+
internal attributes (value, algorithm_name, keys). Any changes to these
|
261
|
+
attributes in future lightphe versions will break serialization.
|
262
|
+
|
263
|
+
TODO: Check if lightphe provides official serialization methods and
|
264
|
+
migrate to them if available. Consider adding version compatibility checks.
|
265
|
+
"""
|
266
|
+
# JSON serialize ciphertext components
|
267
|
+
# ct_data is a list of lightPHE Ciphertext objects
|
268
|
+
# Each Ciphertext has: value, algorithm_name, keys
|
269
|
+
# We need to serialize the list of ciphertexts
|
270
|
+
if not isinstance(self.ct_data, list):
|
271
|
+
raise ValueError(f"ct_data should be a list, got {type(self.ct_data)}")
|
272
|
+
|
273
|
+
ct_list = []
|
274
|
+
for ct in self.ct_data:
|
275
|
+
if not isinstance(ct, Ciphertext):
|
276
|
+
raise TypeError(
|
277
|
+
f"ct_data must contain lightphe.Ciphertext objects, got {type(ct).__name__}"
|
278
|
+
)
|
279
|
+
ct_list.append({
|
280
|
+
"value": ct.value,
|
281
|
+
"algorithm_name": ct.algorithm_name,
|
282
|
+
"keys": ct.keys,
|
283
|
+
})
|
284
|
+
|
285
|
+
# Combine ct_data and pk_data into single dict
|
286
|
+
payload_dict = {
|
287
|
+
"ct_list": ct_list,
|
288
|
+
"pk": self.pk_data if self.pk_data is not None else None,
|
289
|
+
}
|
290
|
+
|
291
|
+
return (
|
292
|
+
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
293
|
+
.set_attr("semantic_dtype", str(self.semantic_dtype))
|
294
|
+
.set_attr("semantic_shape", list(self.semantic_shape))
|
295
|
+
.set_attr("scheme", self.scheme)
|
296
|
+
.set_attr("key_size", self.key_size)
|
297
|
+
.set_attr("max_value", self.max_value)
|
298
|
+
.set_attr("fxp_bits", self.fxp_bits)
|
299
|
+
.set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
|
300
|
+
.set_payload(json.dumps(payload_dict).encode("utf-8"))
|
301
|
+
.build()
|
302
|
+
)
|
303
|
+
|
304
|
+
@classmethod
|
305
|
+
def from_proto(cls, proto: _value_pb2.ValueProto) -> CipherText:
|
306
|
+
"""Deserialize CipherText from wire format."""
|
307
|
+
reader = ValueProtoReader(proto)
|
308
|
+
if reader.version != cls.WIRE_VERSION:
|
309
|
+
raise ValueDecodeError(f"Unsupported CipherText version {reader.version}")
|
310
|
+
|
311
|
+
# Read metadata from runtime_attrs
|
312
|
+
semantic_dtype_str = reader.get_attr("semantic_dtype")
|
313
|
+
semantic_shape = reader.get_attr("semantic_shape")
|
314
|
+
scheme = reader.get_attr("scheme")
|
315
|
+
key_size = reader.get_attr("key_size")
|
316
|
+
max_value = reader.get_attr("max_value")
|
317
|
+
fxp_bits = reader.get_attr("fxp_bits")
|
318
|
+
modulus_str = reader.get_attr("modulus")
|
319
|
+
modulus = None if modulus_str == "" else int(modulus_str)
|
320
|
+
|
321
|
+
# JSON deserialize ciphertext and public key
|
322
|
+
payload_dict = json.loads(reader.payload.decode("utf-8"))
|
323
|
+
ct_list = payload_dict["ct_list"]
|
324
|
+
pk_data = payload_dict["pk"]
|
325
|
+
|
326
|
+
# Reconstruct ct_data: list of Ciphertext objects
|
327
|
+
ct_data = []
|
328
|
+
for ct_dict in ct_list:
|
329
|
+
if ct_dict["keys"] is None or ct_dict["algorithm_name"] is None:
|
330
|
+
raise ValueDecodeError(
|
331
|
+
"Invalid CipherText: missing keys or algorithm_name in serialized data"
|
332
|
+
)
|
333
|
+
ct_data.append(
|
334
|
+
Ciphertext(
|
335
|
+
algorithm_name=ct_dict["algorithm_name"],
|
336
|
+
keys=ct_dict["keys"],
|
337
|
+
value=ct_dict["value"],
|
338
|
+
)
|
339
|
+
)
|
340
|
+
|
341
|
+
# Parse dtype string back to DType
|
342
|
+
dtype = DType.from_any(semantic_dtype_str)
|
343
|
+
|
344
|
+
return cls(
|
345
|
+
ct_data=ct_data,
|
346
|
+
semantic_dtype=dtype,
|
347
|
+
semantic_shape=tuple(semantic_shape),
|
348
|
+
scheme=scheme,
|
349
|
+
key_size=key_size,
|
350
|
+
pk_data=pk_data,
|
351
|
+
max_value=max_value,
|
352
|
+
fxp_bits=fxp_bits,
|
353
|
+
modulus=modulus,
|
354
|
+
)
|
355
|
+
|
145
356
|
def __repr__(self) -> str:
|
146
357
|
return f"CipherText(dtype={self.semantic_dtype}, shape={self.semantic_shape}, scheme={self.scheme})"
|
147
358
|
|
@@ -257,23 +468,6 @@ def _range_decode_mixed(
|
|
257
468
|
return _range_decode_integer(encoded_value, max_value, modulus)
|
258
469
|
|
259
470
|
|
260
|
-
def _convert_to_numpy(obj: TensorLike) -> np.ndarray:
|
261
|
-
"""Convert a TensorLike object to numpy array."""
|
262
|
-
if isinstance(obj, np.ndarray):
|
263
|
-
return obj
|
264
|
-
|
265
|
-
# Try to use .numpy() method if available
|
266
|
-
if hasattr(obj, "numpy"):
|
267
|
-
numpy_method = getattr(obj, "numpy", None)
|
268
|
-
if callable(numpy_method):
|
269
|
-
try:
|
270
|
-
return np.asarray(numpy_method())
|
271
|
-
except Exception:
|
272
|
-
pass
|
273
|
-
|
274
|
-
return np.asarray(obj)
|
275
|
-
|
276
|
-
|
277
471
|
@kernel_def("phe.keygen")
|
278
472
|
def _phe_keygen(pfunc: PFunction) -> Any:
|
279
473
|
scheme = pfunc.attrs.get("scheme", "paillier")
|
@@ -334,14 +528,16 @@ def _phe_keygen(pfunc: PFunction) -> Any:
|
|
334
528
|
|
335
529
|
|
336
530
|
@kernel_def("phe.encrypt")
|
337
|
-
def _phe_encrypt(
|
531
|
+
def _phe_encrypt(
|
532
|
+
pfunc: PFunction, plaintext: TensorValue, public_key: PublicKey
|
533
|
+
) -> Any:
|
338
534
|
# Validate public_key type
|
339
535
|
if not isinstance(public_key, PublicKey):
|
340
536
|
raise ValueError("Second argument must be a PublicKey instance")
|
341
537
|
|
342
538
|
try:
|
343
539
|
# Convert plaintext to numpy to get semantic type info
|
344
|
-
plaintext_np =
|
540
|
+
plaintext_np = plaintext.to_numpy()
|
345
541
|
semantic_dtype = DType.from_numpy(plaintext_np.dtype)
|
346
542
|
semantic_shape = plaintext_np.shape
|
347
543
|
|
@@ -403,14 +599,14 @@ def _phe_encrypt(pfunc: PFunction, plaintext: Any, public_key: PublicKey) -> Any
|
|
403
599
|
|
404
600
|
|
405
601
|
@kernel_def("phe.mul")
|
406
|
-
def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext:
|
602
|
+
def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
|
407
603
|
# Validate that first argument is a CipherText
|
408
604
|
if not isinstance(ciphertext, CipherText):
|
409
605
|
raise ValueError("First argument must be a CipherText instance")
|
410
606
|
|
411
607
|
try:
|
412
608
|
# Convert plaintext to numpy
|
413
|
-
plaintext_np =
|
609
|
+
plaintext_np = plaintext.to_numpy()
|
414
610
|
|
415
611
|
# Check if plaintext is floating point type - multiplication not supported
|
416
612
|
if np.issubdtype(plaintext_np.dtype, np.floating):
|
@@ -511,7 +707,7 @@ def _phe_add(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
|
|
511
707
|
elif isinstance(rhs, CipherText):
|
512
708
|
return _phe_add_ct2pt(rhs, lhs)
|
513
709
|
else:
|
514
|
-
return
|
710
|
+
return TensorValue(lhs.to_numpy() + rhs.to_numpy())
|
515
711
|
except ValueError:
|
516
712
|
raise
|
517
713
|
except Exception as e: # pragma: no cover
|
@@ -593,9 +789,9 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
|
|
593
789
|
)
|
594
790
|
|
595
791
|
|
596
|
-
def _phe_add_ct2pt(ciphertext: CipherText, plaintext:
|
792
|
+
def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorValue) -> CipherText:
|
597
793
|
# Convert plaintext to numpy
|
598
|
-
plaintext_np =
|
794
|
+
plaintext_np = plaintext.to_numpy()
|
599
795
|
plaintext_dtype = DType.from_numpy(plaintext_np.dtype)
|
600
796
|
|
601
797
|
# Check for mixed precision issue: floating point ciphertext + integer plaintext
|
@@ -816,14 +1012,14 @@ def _phe_decrypt(
|
|
816
1012
|
ciphertext.semantic_shape
|
817
1013
|
)
|
818
1014
|
|
819
|
-
return [plaintext_np]
|
1015
|
+
return [TensorValue(plaintext_np)]
|
820
1016
|
|
821
1017
|
except Exception as e:
|
822
1018
|
raise RuntimeError(f"Failed to decrypt data: {e}") from e
|
823
1019
|
|
824
1020
|
|
825
1021
|
@kernel_def("phe.dot")
|
826
|
-
def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext:
|
1022
|
+
def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
|
827
1023
|
"""Execute homomorphic dot product with zero-value optimization.
|
828
1024
|
|
829
1025
|
Supports various dot product operations:
|
@@ -844,7 +1040,7 @@ def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) ->
|
|
844
1040
|
|
845
1041
|
try:
|
846
1042
|
# Convert plaintext to numpy
|
847
|
-
plaintext_np =
|
1043
|
+
plaintext_np = plaintext.to_numpy()
|
848
1044
|
|
849
1045
|
# Check if plaintext is floating point type - dot product not supported
|
850
1046
|
if np.issubdtype(plaintext_np.dtype, np.floating):
|
@@ -1109,7 +1305,7 @@ def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) ->
|
|
1109
1305
|
|
1110
1306
|
|
1111
1307
|
@kernel_def("phe.gather")
|
1112
|
-
def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices:
|
1308
|
+
def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: TensorValue) -> Any:
|
1113
1309
|
"""Execute gather operation on CipherText.
|
1114
1310
|
|
1115
1311
|
Supports gathering from multidimensional CipherText using multidimensional indices.
|
@@ -1126,7 +1322,7 @@ def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
|
|
1126
1322
|
|
1127
1323
|
try:
|
1128
1324
|
# Convert indices to numpy
|
1129
|
-
indices_np =
|
1325
|
+
indices_np = indices.to_numpy()
|
1130
1326
|
|
1131
1327
|
if not np.issubdtype(indices_np.dtype, np.integer):
|
1132
1328
|
raise ValueError("Indices must be of integer type")
|
@@ -1224,7 +1420,10 @@ def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
|
|
1224
1420
|
|
1225
1421
|
@kernel_def("phe.scatter")
|
1226
1422
|
def _phe_scatter(
|
1227
|
-
pfunc: PFunction,
|
1423
|
+
pfunc: PFunction,
|
1424
|
+
ciphertext: CipherText,
|
1425
|
+
indices: TensorValue,
|
1426
|
+
updated: CipherText,
|
1228
1427
|
) -> Any:
|
1229
1428
|
"""Execute scatter operation on CipherText.
|
1230
1429
|
|
@@ -1252,7 +1451,7 @@ def _phe_scatter(
|
|
1252
1451
|
|
1253
1452
|
try:
|
1254
1453
|
# Convert indices to numpy
|
1255
|
-
indices_np =
|
1454
|
+
indices_np = indices.to_numpy()
|
1256
1455
|
|
1257
1456
|
if not np.issubdtype(indices_np.dtype, np.integer):
|
1258
1457
|
raise ValueError("Indices must be of integer type")
|
mplang/kernels/spu.py
CHANGED
@@ -15,15 +15,37 @@
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
17
|
from dataclasses import dataclass
|
18
|
-
from typing import Any
|
18
|
+
from typing import Any, ClassVar
|
19
19
|
|
20
20
|
import numpy as np
|
21
21
|
import spu.api as spu_api
|
22
22
|
import spu.libspu as libspu
|
23
23
|
|
24
|
-
from mplang.core.
|
24
|
+
from mplang.core.dtype import (
|
25
|
+
BOOL,
|
26
|
+
FLOAT32,
|
27
|
+
FLOAT64,
|
28
|
+
INT8,
|
29
|
+
INT16,
|
30
|
+
INT32,
|
31
|
+
INT64,
|
32
|
+
UINT8,
|
33
|
+
UINT16,
|
34
|
+
UINT32,
|
35
|
+
UINT64,
|
36
|
+
DType,
|
37
|
+
)
|
25
38
|
from mplang.core.pfunc import PFunction
|
26
39
|
from mplang.kernels.base import cur_kctx, kernel_def
|
40
|
+
from mplang.kernels.value import (
|
41
|
+
TensorValue,
|
42
|
+
Value,
|
43
|
+
ValueDecodeError,
|
44
|
+
ValueProtoBuilder,
|
45
|
+
ValueProtoReader,
|
46
|
+
register_value,
|
47
|
+
)
|
48
|
+
from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
|
27
49
|
from mplang.runtime.link_comm import LinkCommunicator
|
28
50
|
|
29
51
|
|
@@ -32,36 +54,106 @@ def shape_spu_to_np(spu_shape: Any) -> tuple[int, ...]:
|
|
32
54
|
return tuple(spu_shape.dims)
|
33
55
|
|
34
56
|
|
35
|
-
def
|
36
|
-
"""Convert
|
57
|
+
def dtype_spu_to_mpl(spu_dtype: libspu.DataType) -> DType:
|
58
|
+
"""Convert libspu.DataType to MPLang DType."""
|
37
59
|
MAP = {
|
38
|
-
libspu.DataType.DT_F32:
|
39
|
-
libspu.DataType.DT_F64:
|
40
|
-
libspu.DataType.DT_I1:
|
41
|
-
libspu.DataType.DT_I8:
|
42
|
-
libspu.DataType.DT_U8:
|
43
|
-
libspu.DataType.DT_I16:
|
44
|
-
libspu.DataType.DT_U16:
|
45
|
-
libspu.DataType.DT_I32:
|
46
|
-
libspu.DataType.DT_U32:
|
47
|
-
libspu.DataType.DT_I64:
|
48
|
-
libspu.DataType.DT_U64:
|
60
|
+
libspu.DataType.DT_F32: FLOAT32,
|
61
|
+
libspu.DataType.DT_F64: FLOAT64,
|
62
|
+
libspu.DataType.DT_I1: BOOL,
|
63
|
+
libspu.DataType.DT_I8: INT8,
|
64
|
+
libspu.DataType.DT_U8: UINT8,
|
65
|
+
libspu.DataType.DT_I16: INT16,
|
66
|
+
libspu.DataType.DT_U16: UINT16,
|
67
|
+
libspu.DataType.DT_I32: INT32,
|
68
|
+
libspu.DataType.DT_U32: UINT32,
|
69
|
+
libspu.DataType.DT_I64: INT64,
|
70
|
+
libspu.DataType.DT_U64: UINT64,
|
49
71
|
}
|
50
|
-
return MAP[spu_dtype]
|
72
|
+
return MAP[spu_dtype]
|
51
73
|
|
52
74
|
|
75
|
+
@register_value
|
53
76
|
@dataclass
|
54
|
-
class SpuValue:
|
55
|
-
"""SPU value container for secure computation."""
|
77
|
+
class SpuValue(Value):
|
78
|
+
"""SPU value container for secure computation (Value type)."""
|
79
|
+
|
80
|
+
KIND: ClassVar[str] = "mplang.spu.SpuValue"
|
81
|
+
WIRE_VERSION: ClassVar[int] = 1
|
56
82
|
|
57
83
|
shape: tuple[int, ...]
|
58
|
-
dtype:
|
84
|
+
dtype: DType # Now uses MPLang's unified DType
|
59
85
|
vtype: libspu.Visibility
|
60
86
|
share: libspu.Share
|
61
87
|
|
62
88
|
def __repr__(self) -> str:
|
63
89
|
return f"SpuValue({self.shape},{self.dtype},{self.vtype})"
|
64
90
|
|
91
|
+
def to_proto(self) -> _value_pb2.ValueProto:
|
92
|
+
"""Serialize SpuValue to wire format.
|
93
|
+
|
94
|
+
libspu.Share has two attributes:
|
95
|
+
- meta: bytes (protobuf serialized metadata)
|
96
|
+
- share_chunks: list[bytes] (the actual secret share data)
|
97
|
+
|
98
|
+
Strategy: Store shape/dtype/vtype in runtime_attrs, concatenate share.meta + all chunks in payload.
|
99
|
+
"""
|
100
|
+
# Store metadata in runtime_attrs; keep chunk lengths for payload splitting
|
101
|
+
chunk_lengths = [len(chunk) for chunk in self.share.share_chunks]
|
102
|
+
|
103
|
+
# Payload contains only share chunks (meta stored in attrs)
|
104
|
+
payload = b""
|
105
|
+
for chunk in self.share.share_chunks:
|
106
|
+
payload += chunk
|
107
|
+
|
108
|
+
return (
|
109
|
+
ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
|
110
|
+
.set_attr("shape", list(self.shape))
|
111
|
+
.set_attr("dtype", self.dtype.name) # Serialize DType name
|
112
|
+
.set_attr("vtype", int(self.vtype))
|
113
|
+
.set_attr("share_meta", self.share.meta)
|
114
|
+
.set_attr("chunk_lengths", chunk_lengths)
|
115
|
+
.set_payload(payload)
|
116
|
+
.build()
|
117
|
+
)
|
118
|
+
|
119
|
+
@classmethod
|
120
|
+
def from_proto(cls, proto: _value_pb2.ValueProto) -> SpuValue:
|
121
|
+
"""Deserialize SpuValue from wire format."""
|
122
|
+
reader = ValueProtoReader(proto)
|
123
|
+
if reader.version != cls.WIRE_VERSION:
|
124
|
+
raise ValueDecodeError(f"Unsupported SpuValue version {reader.version}")
|
125
|
+
|
126
|
+
# Read metadata from runtime_attrs
|
127
|
+
shape = tuple(reader.get_attr("shape"))
|
128
|
+
dtype_name = reader.get_attr("dtype")
|
129
|
+
# Reconstruct DType from serialized name (numpy dtype string)
|
130
|
+
dtype = DType.from_numpy(dtype_name)
|
131
|
+
vtype = libspu.Visibility(reader.get_attr("vtype"))
|
132
|
+
share_meta = reader.get_attr("share_meta")
|
133
|
+
chunk_lengths = reader.get_attr("chunk_lengths")
|
134
|
+
|
135
|
+
# Parse payload: [chunk_0][chunk_1]...
|
136
|
+
payload = reader.payload
|
137
|
+
offset = 0
|
138
|
+
|
139
|
+
share_chunks: list[bytes] = []
|
140
|
+
for chunk_len in chunk_lengths:
|
141
|
+
chunk = payload[offset : offset + chunk_len]
|
142
|
+
offset += chunk_len
|
143
|
+
share_chunks.append(chunk)
|
144
|
+
|
145
|
+
# Reconstruct libspu.Share
|
146
|
+
share = libspu.Share()
|
147
|
+
share.meta = share_meta
|
148
|
+
share.share_chunks = share_chunks
|
149
|
+
|
150
|
+
return cls(
|
151
|
+
shape=shape,
|
152
|
+
dtype=dtype,
|
153
|
+
vtype=vtype,
|
154
|
+
share=share,
|
155
|
+
)
|
156
|
+
|
65
157
|
|
66
158
|
def _get_spu_config_and_world() -> tuple[libspu.RuntimeConfig, int]:
|
67
159
|
kctx = cur_kctx()
|
@@ -128,33 +220,25 @@ def _spu_seed_env(pfunc: PFunction, *args: Any) -> Any:
|
|
128
220
|
|
129
221
|
|
130
222
|
@kernel_def("spu.makeshares")
|
131
|
-
def _spu_makeshares(pfunc: PFunction,
|
132
|
-
"""Create SPU shares from input data.
|
133
|
-
|
134
|
-
Args:
|
135
|
-
pfunc: PFunction containing makeshares metadata
|
136
|
-
args: Input data to be shared (single tensor)
|
137
|
-
|
138
|
-
Returns:
|
139
|
-
Tuple of SPU shares (SpuValue), one for each party.
|
140
|
-
"""
|
141
|
-
assert len(args) == 1
|
142
|
-
|
223
|
+
def _spu_makeshares(pfunc: PFunction, tensor: TensorValue) -> tuple[SpuValue, ...]:
|
224
|
+
"""Create SPU shares from input TensorValue data."""
|
143
225
|
visibility_value = pfunc.attrs.get("visibility", libspu.Visibility.VIS_SECRET.value)
|
144
226
|
if isinstance(visibility_value, int):
|
145
227
|
visibility = libspu.Visibility(visibility_value)
|
146
228
|
else:
|
147
229
|
visibility = visibility_value
|
148
230
|
|
149
|
-
arg =
|
231
|
+
arg = tensor.to_numpy()
|
150
232
|
cfg, world = _get_spu_config_and_world()
|
151
233
|
spu_io = spu_api.Io(world, cfg)
|
152
234
|
shares = spu_io.make_shares(arg, visibility)
|
153
235
|
assert len(shares) == world, f"Expected {world} shares, got {len(shares)}"
|
236
|
+
# Store MPLang DType instead of libspu.DataType
|
237
|
+
dtype = DType.from_numpy(arg.dtype)
|
154
238
|
return tuple(
|
155
239
|
SpuValue(
|
156
240
|
shape=arg.shape,
|
157
|
-
dtype=
|
241
|
+
dtype=dtype,
|
158
242
|
vtype=visibility,
|
159
243
|
share=share,
|
160
244
|
)
|
@@ -163,24 +247,29 @@ def _spu_makeshares(pfunc: PFunction, *args: Any) -> Any:
|
|
163
247
|
|
164
248
|
|
165
249
|
@kernel_def("spu.reconstruct")
|
166
|
-
def _spu_reconstruct(pfunc: PFunction, *
|
250
|
+
def _spu_reconstruct(pfunc: PFunction, *shares: SpuValue) -> TensorValue:
|
167
251
|
"""Reconstruct plaintext data from SPU shares."""
|
168
252
|
cfg, world = _get_spu_config_and_world()
|
169
|
-
assert len(
|
170
|
-
for i,
|
171
|
-
if not isinstance(
|
253
|
+
assert len(shares) == world, f"Expected {world} shares, got {len(shares)}"
|
254
|
+
for i, share in enumerate(shares):
|
255
|
+
if not isinstance(share, SpuValue):
|
172
256
|
raise ValueError(
|
173
|
-
f"Input {i} must be SpuValue, got {type(
|
257
|
+
f"Input {i} must be SpuValue, got {type(share)}. Reconstruction requires SPU shares as input."
|
174
258
|
)
|
175
|
-
spu_args: list[SpuValue] = list(
|
176
|
-
|
259
|
+
spu_args: list[SpuValue] = list(shares) # type: ignore
|
260
|
+
share_payloads = [spu_arg.share for spu_arg in spu_args]
|
177
261
|
spu_io = spu_api.Io(world, cfg)
|
178
|
-
reconstructed = spu_io.reconstruct(
|
179
|
-
|
262
|
+
reconstructed = spu_io.reconstruct(share_payloads)
|
263
|
+
base = np.array(reconstructed, copy=False)
|
264
|
+
# Respect semantic dtype/shape recorded on shares (all shares share same meta).
|
265
|
+
semantic_dtype = shares[0].dtype.to_numpy() # DType now has to_numpy() method
|
266
|
+
semantic_shape = shares[0].shape
|
267
|
+
restored = np.asarray(base, dtype=semantic_dtype).reshape(semantic_shape)
|
268
|
+
return TensorValue(np.array(restored, copy=False))
|
180
269
|
|
181
270
|
|
182
271
|
@kernel_def("spu.run_pphlo")
|
183
|
-
def _spu_run_mlir(pfunc: PFunction, *args:
|
272
|
+
def _spu_run_mlir(pfunc: PFunction, *args: SpuValue) -> tuple[SpuValue, ...]:
|
184
273
|
"""Execute compiled SPU function (spu.run_pphlo) and return SpuValue outputs.
|
185
274
|
|
186
275
|
Participation rule: a rank participates iff its entry in the stored
|
@@ -240,10 +329,10 @@ def _spu_run_mlir(pfunc: PFunction, *args: Any) -> Any:
|
|
240
329
|
spu_rt.run(executable)
|
241
330
|
shares = [spu_rt.get_var(out_name) for out_name in output_names]
|
242
331
|
metas = [spu_rt.get_var_meta(out_name) for out_name in output_names]
|
243
|
-
results: list[
|
332
|
+
results: list[SpuValue] = [
|
244
333
|
SpuValue(
|
245
334
|
shape=shape_spu_to_np(meta.shape),
|
246
|
-
dtype=
|
335
|
+
dtype=dtype_spu_to_mpl(meta.data_type),
|
247
336
|
vtype=meta.visibility,
|
248
337
|
share=shares[idx],
|
249
338
|
)
|