mplang-nightly 0.1.dev164__py3-none-any.whl → 0.1.dev166__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mplang/kernels/phe.py CHANGED
@@ -14,15 +14,27 @@
14
14
 
15
15
  """PHE (Partially Homomorphic Encryption) backend implementation using lightPHE."""
16
16
 
17
- from typing import Any
17
+ from __future__ import annotations
18
+
19
+ import json
20
+ from typing import Any, ClassVar
18
21
 
19
22
  import numpy as np
20
23
  from lightphe import LightPHE
24
+ from lightphe.models.Ciphertext import Ciphertext
21
25
 
22
26
  from mplang.core.dtype import DType
23
- from mplang.core.mptype import TensorLike
24
27
  from mplang.core.pfunc import PFunction
25
28
  from mplang.kernels.base import kernel_def
29
+ from mplang.kernels.value import (
30
+ TensorValue,
31
+ Value,
32
+ ValueDecodeError,
33
+ ValueProtoBuilder,
34
+ ValueProtoReader,
35
+ register_value,
36
+ )
37
+ from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
26
38
 
27
39
  # This controls the decimal precision used in lightPHE for float operations
28
40
  # we force it to 0 to only support integer operations
@@ -30,8 +42,12 @@ from mplang.kernels.base import kernel_def
30
42
  PRECISION = 0
31
43
 
32
44
 
33
- class PublicKey:
34
- """PHE Public Key that implements TensorLike protocol."""
45
+ @register_value
46
+ class PublicKey(Value):
47
+ """PHE Public Key Value type."""
48
+
49
+ KIND: ClassVar[str] = "mplang.phe.PublicKey"
50
+ WIRE_VERSION: ClassVar[int] = 1
35
51
 
36
52
  def __init__(
37
53
  self,
@@ -62,12 +78,56 @@ class PublicKey:
62
78
  """Maximum float value that can be encoded."""
63
79
  return float(self.max_value / (2**self.fxp_bits))
64
80
 
81
def to_proto(self) -> _value_pb2.ValueProto:
    """Encode this public key as a ``ValueProto``.

    Runtime attributes carry the scalar metadata: ``scheme``, ``key_size``,
    ``max_value``, ``fxp_bits`` and ``modulus``. The modulus is written as a
    decimal string, or ``""`` when it is ``None``. The ``key_data`` dict is
    JSON-encoded (UTF-8) into the payload.
    """
    builder = ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
    attributes = (
        ("scheme", self.scheme),
        ("key_size", self.key_size),
        ("max_value", self.max_value),
        ("fxp_bits", self.fxp_bits),
        ("modulus", "" if self.modulus is None else str(self.modulus)),
    )
    for attr_name, attr_value in attributes:
        builder = builder.set_attr(attr_name, attr_value)
    encoded_key = json.dumps(self.key_data).encode("utf-8")
    return builder.set_payload(encoded_key).build()
93
+
94
@classmethod
def from_proto(cls, proto: _value_pb2.ValueProto) -> PublicKey:
    """Rebuild a ``PublicKey`` from a ``ValueProto`` written by ``to_proto``.

    An empty-string ``modulus`` attribute decodes to ``None``; any other
    value is parsed with ``int``. The payload is decoded as UTF-8 JSON and
    becomes ``key_data``.

    Raises:
        ValueDecodeError: if the proto's wire version is not ``WIRE_VERSION``.
    """
    reader = ValueProtoReader(proto)
    if reader.version != cls.WIRE_VERSION:
        raise ValueDecodeError(f"Unsupported PublicKey version {reader.version}")

    metadata = {
        field: reader.get_attr(field)
        for field in ("scheme", "key_size", "max_value", "fxp_bits")
    }
    raw_modulus = reader.get_attr("modulus")
    parsed_modulus = None if raw_modulus == "" else int(raw_modulus)

    decoded_key = json.loads(reader.payload.decode("utf-8"))
    return cls(key_data=decoded_key, modulus=parsed_modulus, **metadata)
120
+
65
121
  def __repr__(self) -> str:
66
122
  return f"PublicKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
67
123
 
68
124
 
69
- class PrivateKey:
70
- """PHE Private Key that implements TensorLike protocol."""
125
+ @register_value
126
+ class PrivateKey(Value):
127
+ """PHE Private Key Value type."""
128
+
129
+ KIND: ClassVar[str] = "mplang.phe.PrivateKey"
130
+ WIRE_VERSION: ClassVar[int] = 1
71
131
 
72
132
  def __init__(
73
133
  self,
@@ -100,12 +160,63 @@ class PrivateKey:
100
160
  """Maximum float value that can be encoded."""
101
161
  return float(self.max_value / (2**self.fxp_bits))
102
162
 
163
def to_proto(self) -> _value_pb2.ValueProto:
    """Encode this private key as a ``ValueProto``.

    Scalar metadata goes into runtime attributes, with the same layout
    ``PublicKey`` uses. Both key dicts travel in a single UTF-8 JSON object
    ``{"sk": ..., "pk": ...}`` in the payload. A single object means no
    per-key length metadata is needed to split them again.
    """
    builder = ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
    attributes = (
        ("scheme", self.scheme),
        ("key_size", self.key_size),
        ("max_value", self.max_value),
        ("fxp_bits", self.fxp_bits),
        ("modulus", "" if self.modulus is None else str(self.modulus)),
    )
    for attr_name, attr_value in attributes:
        builder = builder.set_attr(attr_name, attr_value)

    # Key order ("sk" then "pk") is fixed so the encoded bytes are stable.
    combined_keys = {"sk": self.sk_data, "pk": self.pk_data}
    return builder.set_payload(json.dumps(combined_keys).encode("utf-8")).build()
179
+
180
@classmethod
def from_proto(cls, proto: _value_pb2.ValueProto) -> PrivateKey:
    """Deserialize PrivateKey from wire format (inverse of ``to_proto``).

    An empty-string ``modulus`` attribute decodes to ``None``. The payload
    must be a UTF-8 JSON object holding both the ``"sk"`` and ``"pk"`` dicts.

    Raises:
        ValueDecodeError: if the wire version is unsupported, or the payload
            is not a JSON object with ``"sk"`` and ``"pk"`` entries.
    """
    reader = ValueProtoReader(proto)
    if reader.version != cls.WIRE_VERSION:
        raise ValueDecodeError(f"Unsupported PrivateKey version {reader.version}")

    # Read metadata from runtime_attrs
    scheme = reader.get_attr("scheme")
    key_size = reader.get_attr("key_size")
    max_value = reader.get_attr("max_value")
    fxp_bits = reader.get_attr("fxp_bits")
    modulus_str = reader.get_attr("modulus")
    modulus = None if modulus_str == "" else int(modulus_str)

    # JSON deserialize both key dicts. A malformed payload is reported as a
    # decode error instead of leaking a KeyError/TypeError to the caller.
    keys_dict = json.loads(reader.payload.decode("utf-8"))
    if not isinstance(keys_dict, dict) or "sk" not in keys_dict or "pk" not in keys_dict:
        raise ValueDecodeError(
            "Invalid PrivateKey: serialized payload must be a JSON object with 'sk' and 'pk' entries"
        )

    return cls(
        sk_data=keys_dict["sk"],
        pk_data=keys_dict["pk"],
        scheme=scheme,
        key_size=key_size,
        max_value=max_value,
        fxp_bits=fxp_bits,
        modulus=modulus,
    )
209
+
103
210
  def __repr__(self) -> str:
104
211
  return f"PrivateKey(scheme={self.scheme}, key_size={self.key_size}, max_value={self.max_value}, fxp_bits={self.fxp_bits})"
105
212
 
106
213
 
107
- class CipherText:
108
- """PHE CipherText that implements TensorLike protocol."""
214
+ @register_value
215
+ class CipherText(Value):
216
+ """PHE CipherText Value type."""
217
+
218
+ KIND: ClassVar[str] = "mplang.phe.CipherText"
219
+ WIRE_VERSION: ClassVar[int] = 1
109
220
 
110
221
  def __init__(
111
222
  self,
@@ -142,6 +253,106 @@ class CipherText:
142
253
  """Maximum float value that can be encoded."""
143
254
  return float(self.max_value / (2**self.fxp_bits))
144
255
 
256
def to_proto(self) -> _value_pb2.ValueProto:
    """Serialize CipherText to wire format.

    Layout:
      - runtime attrs: ``semantic_dtype`` (``str`` of the DType),
        ``semantic_shape`` (list), ``scheme``, ``key_size``, ``max_value``,
        ``fxp_bits`` and ``modulus`` (decimal string, ``""`` when ``None``);
      - payload: UTF-8 JSON object ``{"ct_list": [...], "pk": pk_data}``.
        Each ``ct_list`` entry holds one lightphe ``Ciphertext``'s ``value``,
        ``algorithm_name`` and ``keys``.

    WARNING: This serialization is tightly coupled to lightphe.Ciphertext
    internal attributes (value, algorithm_name, keys). Any changes to these
    attributes in future lightphe versions will break serialization.

    NOTE(review): JSON encodes tuples as lists. If a lightphe scheme stores
    a tuple in ``value`` or ``keys``, it comes back as a list. Confirm the
    round trip for every scheme used beyond integer-valued ones.

    TODO: Check if lightphe provides official serialization methods and
    migrate to them if available. Consider adding version compatibility checks.

    Raises:
        ValueError: if ``ct_data`` is not a list.
        TypeError: if an element of ``ct_data`` is not a lightphe
            ``Ciphertext``, or a component is not JSON-serializable.
    """
    if not isinstance(self.ct_data, list):
        raise ValueError(f"ct_data should be a list, got {type(self.ct_data)}")

    ct_list = []
    for ct in self.ct_data:
        if not isinstance(ct, Ciphertext):
            raise TypeError(
                f"ct_data must contain lightphe.Ciphertext objects, got {type(ct).__name__}"
            )
        ct_list.append({
            "value": ct.value,
            "algorithm_name": ct.algorithm_name,
            "keys": ct.keys,
        })

    # Ciphertexts and the (possibly None) public key share one JSON document.
    payload_dict = {"ct_list": ct_list, "pk": self.pk_data}

    return (
        ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
        .set_attr("semantic_dtype", str(self.semantic_dtype))
        .set_attr("semantic_shape", list(self.semantic_shape))
        .set_attr("scheme", self.scheme)
        .set_attr("key_size", self.key_size)
        .set_attr("max_value", self.max_value)
        .set_attr("fxp_bits", self.fxp_bits)
        .set_attr("modulus", str(self.modulus) if self.modulus is not None else "")
        .set_payload(json.dumps(payload_dict).encode("utf-8"))
        .build()
    )
303
+
304
@classmethod
def from_proto(cls, proto: _value_pb2.ValueProto) -> CipherText:
    """Deserialize CipherText from wire format (inverse of ``to_proto``).

    Raises:
        ValueDecodeError: in any of these cases:
            - the wire version is unsupported;
            - the payload is not a JSON object with ``ct_list`` (a list) and
              ``pk`` entries;
            - a ciphertext entry is not an object, or lacks ``value``,
              ``keys`` or ``algorithm_name``.
    """
    reader = ValueProtoReader(proto)
    if reader.version != cls.WIRE_VERSION:
        raise ValueDecodeError(f"Unsupported CipherText version {reader.version}")

    # Read metadata from runtime_attrs
    semantic_dtype_str = reader.get_attr("semantic_dtype")
    semantic_shape = reader.get_attr("semantic_shape")
    scheme = reader.get_attr("scheme")
    key_size = reader.get_attr("key_size")
    max_value = reader.get_attr("max_value")
    fxp_bits = reader.get_attr("fxp_bits")
    modulus_str = reader.get_attr("modulus")
    modulus = None if modulus_str == "" else int(modulus_str)

    # JSON deserialize ciphertext and public key
    payload_dict = json.loads(reader.payload.decode("utf-8"))
    if (
        not isinstance(payload_dict, dict)
        or "ct_list" not in payload_dict
        or "pk" not in payload_dict
    ):
        raise ValueDecodeError(
            "Invalid CipherText: serialized payload must be a JSON object with 'ct_list' and 'pk' entries"
        )
    ct_list = payload_dict["ct_list"]
    pk_data = payload_dict["pk"]
    if not isinstance(ct_list, list):
        raise ValueDecodeError(
            f"Invalid CipherText: ct_list must be a list, got {type(ct_list).__name__}"
        )

    # Reconstruct ct_data: list of Ciphertext objects
    ct_data = []
    for index, ct_dict in enumerate(ct_list):
        if not isinstance(ct_dict, dict):
            raise ValueDecodeError(
                f"Invalid CipherText: entry {index} is not a JSON object"
            )
        # .get(): an absent field must hit the same decode error as an
        # explicit null instead of escaping as a KeyError.
        if ct_dict.get("keys") is None or ct_dict.get("algorithm_name") is None:
            raise ValueDecodeError(
                "Invalid CipherText: missing keys or algorithm_name in serialized data"
            )
        if "value" not in ct_dict:
            raise ValueDecodeError(
                f"Invalid CipherText: entry {index} is missing value in serialized data"
            )
        ct_data.append(
            Ciphertext(
                algorithm_name=ct_dict["algorithm_name"],
                keys=ct_dict["keys"],
                value=ct_dict["value"],
            )
        )

    # Parse dtype string back to DType
    dtype = DType.from_any(semantic_dtype_str)

    return cls(
        ct_data=ct_data,
        semantic_dtype=dtype,
        semantic_shape=tuple(semantic_shape),
        scheme=scheme,
        key_size=key_size,
        pk_data=pk_data,
        max_value=max_value,
        fxp_bits=fxp_bits,
        modulus=modulus,
    )
355
+
145
356
  def __repr__(self) -> str:
146
357
  return f"CipherText(dtype={self.semantic_dtype}, shape={self.semantic_shape}, scheme={self.scheme})"
147
358
 
@@ -257,23 +468,6 @@ def _range_decode_mixed(
257
468
  return _range_decode_integer(encoded_value, max_value, modulus)
258
469
 
259
470
 
260
- def _convert_to_numpy(obj: TensorLike) -> np.ndarray:
261
- """Convert a TensorLike object to numpy array."""
262
- if isinstance(obj, np.ndarray):
263
- return obj
264
-
265
- # Try to use .numpy() method if available
266
- if hasattr(obj, "numpy"):
267
- numpy_method = getattr(obj, "numpy", None)
268
- if callable(numpy_method):
269
- try:
270
- return np.asarray(numpy_method())
271
- except Exception:
272
- pass
273
-
274
- return np.asarray(obj)
275
-
276
-
277
471
  @kernel_def("phe.keygen")
278
472
  def _phe_keygen(pfunc: PFunction) -> Any:
279
473
  scheme = pfunc.attrs.get("scheme", "paillier")
@@ -334,14 +528,16 @@ def _phe_keygen(pfunc: PFunction) -> Any:
334
528
 
335
529
 
336
530
  @kernel_def("phe.encrypt")
337
- def _phe_encrypt(pfunc: PFunction, plaintext: Any, public_key: PublicKey) -> Any:
531
+ def _phe_encrypt(
532
+ pfunc: PFunction, plaintext: TensorValue, public_key: PublicKey
533
+ ) -> Any:
338
534
  # Validate public_key type
339
535
  if not isinstance(public_key, PublicKey):
340
536
  raise ValueError("Second argument must be a PublicKey instance")
341
537
 
342
538
  try:
343
539
  # Convert plaintext to numpy to get semantic type info
344
- plaintext_np = _convert_to_numpy(plaintext)
540
+ plaintext_np = plaintext.to_numpy()
345
541
  semantic_dtype = DType.from_numpy(plaintext_np.dtype)
346
542
  semantic_shape = plaintext_np.shape
347
543
 
@@ -403,14 +599,14 @@ def _phe_encrypt(pfunc: PFunction, plaintext: Any, public_key: PublicKey) -> Any
403
599
 
404
600
 
405
601
  @kernel_def("phe.mul")
406
- def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: Any) -> Any:
602
+ def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
407
603
  # Validate that first argument is a CipherText
408
604
  if not isinstance(ciphertext, CipherText):
409
605
  raise ValueError("First argument must be a CipherText instance")
410
606
 
411
607
  try:
412
608
  # Convert plaintext to numpy
413
- plaintext_np = _convert_to_numpy(plaintext)
609
+ plaintext_np = plaintext.to_numpy()
414
610
 
415
611
  # Check if plaintext is floating point type - multiplication not supported
416
612
  if np.issubdtype(plaintext_np.dtype, np.floating):
@@ -511,7 +707,7 @@ def _phe_add(pfunc: PFunction, lhs: Any, rhs: Any) -> Any:
511
707
  elif isinstance(rhs, CipherText):
512
708
  return _phe_add_ct2pt(rhs, lhs)
513
709
  else:
514
- return _convert_to_numpy(lhs) + _convert_to_numpy(rhs)
710
+ return TensorValue(lhs.to_numpy() + rhs.to_numpy())
515
711
  except ValueError:
516
712
  raise
517
713
  except Exception as e: # pragma: no cover
@@ -593,9 +789,9 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
593
789
  )
594
790
 
595
791
 
596
- def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorLike) -> CipherText:
792
+ def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorValue) -> CipherText:
597
793
  # Convert plaintext to numpy
598
- plaintext_np = _convert_to_numpy(plaintext)
794
+ plaintext_np = plaintext.to_numpy()
599
795
  plaintext_dtype = DType.from_numpy(plaintext_np.dtype)
600
796
 
601
797
  # Check for mixed precision issue: floating point ciphertext + integer plaintext
@@ -816,14 +1012,14 @@ def _phe_decrypt(
816
1012
  ciphertext.semantic_shape
817
1013
  )
818
1014
 
819
- return [plaintext_np]
1015
+ return [TensorValue(plaintext_np)]
820
1016
 
821
1017
  except Exception as e:
822
1018
  raise RuntimeError(f"Failed to decrypt data: {e}") from e
823
1019
 
824
1020
 
825
1021
  @kernel_def("phe.dot")
826
- def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) -> Any:
1022
+ def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -> Any:
827
1023
  """Execute homomorphic dot product with zero-value optimization.
828
1024
 
829
1025
  Supports various dot product operations:
@@ -844,7 +1040,7 @@ def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) ->
844
1040
 
845
1041
  try:
846
1042
  # Convert plaintext to numpy
847
- plaintext_np = _convert_to_numpy(plaintext)
1043
+ plaintext_np = plaintext.to_numpy()
848
1044
 
849
1045
  # Check if plaintext is floating point type - dot product not supported
850
1046
  if np.issubdtype(plaintext_np.dtype, np.floating):
@@ -1109,7 +1305,7 @@ def _phe_dot(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorLike) ->
1109
1305
 
1110
1306
 
1111
1307
  @kernel_def("phe.gather")
1112
- def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
1308
+ def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: TensorValue) -> Any:
1113
1309
  """Execute gather operation on CipherText.
1114
1310
 
1115
1311
  Supports gathering from multidimensional CipherText using multidimensional indices.
@@ -1126,7 +1322,7 @@ def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
1126
1322
 
1127
1323
  try:
1128
1324
  # Convert indices to numpy
1129
- indices_np = _convert_to_numpy(indices)
1325
+ indices_np = indices.to_numpy()
1130
1326
 
1131
1327
  if not np.issubdtype(indices_np.dtype, np.integer):
1132
1328
  raise ValueError("Indices must be of integer type")
@@ -1224,7 +1420,10 @@ def _phe_gather(pfunc: PFunction, ciphertext: CipherText, indices: Any) -> Any:
1224
1420
 
1225
1421
  @kernel_def("phe.scatter")
1226
1422
  def _phe_scatter(
1227
- pfunc: PFunction, ciphertext: CipherText, indices: TensorLike, updated: CipherText
1423
+ pfunc: PFunction,
1424
+ ciphertext: CipherText,
1425
+ indices: TensorValue,
1426
+ updated: CipherText,
1228
1427
  ) -> Any:
1229
1428
  """Execute scatter operation on CipherText.
1230
1429
 
@@ -1252,7 +1451,7 @@ def _phe_scatter(
1252
1451
 
1253
1452
  try:
1254
1453
  # Convert indices to numpy
1255
- indices_np = _convert_to_numpy(indices)
1454
+ indices_np = indices.to_numpy()
1256
1455
 
1257
1456
  if not np.issubdtype(indices_np.dtype, np.integer):
1258
1457
  raise ValueError("Indices must be of integer type")
mplang/kernels/spu.py CHANGED
@@ -15,15 +15,37 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  from dataclasses import dataclass
18
- from typing import Any
18
+ from typing import Any, ClassVar
19
19
 
20
20
  import numpy as np
21
21
  import spu.api as spu_api
22
22
  import spu.libspu as libspu
23
23
 
24
- from mplang.core.mptype import TensorLike
24
+ from mplang.core.dtype import (
25
+ BOOL,
26
+ FLOAT32,
27
+ FLOAT64,
28
+ INT8,
29
+ INT16,
30
+ INT32,
31
+ INT64,
32
+ UINT8,
33
+ UINT16,
34
+ UINT32,
35
+ UINT64,
36
+ DType,
37
+ )
25
38
  from mplang.core.pfunc import PFunction
26
39
  from mplang.kernels.base import cur_kctx, kernel_def
40
+ from mplang.kernels.value import (
41
+ TensorValue,
42
+ Value,
43
+ ValueDecodeError,
44
+ ValueProtoBuilder,
45
+ ValueProtoReader,
46
+ register_value,
47
+ )
48
+ from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
27
49
  from mplang.runtime.link_comm import LinkCommunicator
28
50
 
29
51
 
@@ -32,36 +54,106 @@ def shape_spu_to_np(spu_shape: Any) -> tuple[int, ...]:
32
54
  return tuple(spu_shape.dims)
33
55
 
34
56
 
35
# libspu.DataType -> MPLang DType lookup, built once at import time instead
# of on every conversion.
_SPU_DTYPE_TO_MPL: dict[libspu.DataType, DType] = {
    libspu.DataType.DT_F32: FLOAT32,
    libspu.DataType.DT_F64: FLOAT64,
    libspu.DataType.DT_I1: BOOL,
    libspu.DataType.DT_I8: INT8,
    libspu.DataType.DT_U8: UINT8,
    libspu.DataType.DT_I16: INT16,
    libspu.DataType.DT_U16: UINT16,
    libspu.DataType.DT_I32: INT32,
    libspu.DataType.DT_U32: UINT32,
    libspu.DataType.DT_I64: INT64,
    libspu.DataType.DT_U64: UINT64,
}


def dtype_spu_to_mpl(spu_dtype: libspu.DataType) -> DType:
    """Convert libspu.DataType to MPLang DType.

    Args:
        spu_dtype: data type reported by SPU (e.g. from ``get_var_meta``).

    Returns:
        The corresponding MPLang ``DType``.

    Raises:
        KeyError: if ``spu_dtype`` has no entry in the conversion table.
    """
    try:
        return _SPU_DTYPE_TO_MPL[spu_dtype]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a useful message.
        raise KeyError(f"Unsupported SPU data type: {spu_dtype!r}") from None
51
73
 
52
74
 
75
@register_value
@dataclass
class SpuValue(Value):
    """SPU value container for secure computation (Value type).

    Holds a single party's ``libspu.Share`` together with the semantic shape,
    MPLang dtype and SPU visibility of the value. ``spu.makeshares`` produces
    one instance per party.
    """

    KIND: ClassVar[str] = "mplang.spu.SpuValue"
    WIRE_VERSION: ClassVar[int] = 1

    shape: tuple[int, ...]
    dtype: DType  # MPLang's unified DType (not libspu.DataType)
    vtype: libspu.Visibility
    share: libspu.Share

    def __repr__(self) -> str:
        return f"SpuValue({self.shape},{self.dtype},{self.vtype})"

    def to_proto(self) -> _value_pb2.ValueProto:
        """Serialize SpuValue to wire format.

        ``libspu.Share`` exposes ``meta`` (bytes) and ``share_chunks``
        (list of bytes). Layout:
          - runtime attrs: ``shape`` (list), ``dtype`` (``DType.name``),
            ``vtype`` (int), ``share_meta`` (``share.meta`` bytes) and
            ``chunk_lengths`` (byte length of each share chunk);
          - payload: the share chunks concatenated in order. ``share.meta``
            lives in the attrs, not in the payload.
        """
        # Materialize the chunks once; they are needed both for the lengths
        # and for the payload.
        chunks = list(self.share.share_chunks)
        return (
            ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
            .set_attr("shape", list(self.shape))
            .set_attr("dtype", self.dtype.name)
            .set_attr("vtype", int(self.vtype))
            .set_attr("share_meta", self.share.meta)
            .set_attr("chunk_lengths", [len(chunk) for chunk in chunks])
            # b"".join is linear; repeated bytes += was quadratic in share size.
            .set_payload(b"".join(chunks))
            .build()
        )

    @classmethod
    def from_proto(cls, proto: _value_pb2.ValueProto) -> SpuValue:
        """Deserialize SpuValue from wire format (inverse of ``to_proto``).

        Raises:
            ValueDecodeError: if the wire version is unsupported, or
                ``chunk_lengths`` does not exactly partition the payload.
        """
        reader = ValueProtoReader(proto)
        if reader.version != cls.WIRE_VERSION:
            raise ValueDecodeError(f"Unsupported SpuValue version {reader.version}")

        shape = tuple(reader.get_attr("shape"))
        # NOTE(review): relies on DType.name being a numpy dtype string that
        # DType.from_numpy accepts; confirm for every supported dtype.
        dtype = DType.from_numpy(reader.get_attr("dtype"))
        vtype = libspu.Visibility(reader.get_attr("vtype"))
        share_meta = reader.get_attr("share_meta")
        chunk_lengths = [int(length) for length in reader.get_attr("chunk_lengths")]

        payload = reader.payload
        total = sum(chunk_lengths)
        # A truncated or corrupted payload would otherwise silently produce
        # short (or empty) chunks.
        if any(length < 0 for length in chunk_lengths) or total != len(payload):
            raise ValueDecodeError(
                f"Corrupted SpuValue payload: chunk lengths sum to {total} bytes, "
                f"payload has {len(payload)} bytes"
            )

        # Split payload back into [chunk_0][chunk_1]...
        share_chunks: list[bytes] = []
        offset = 0
        for length in chunk_lengths:
            share_chunks.append(payload[offset : offset + length])
            offset += length

        share = libspu.Share()
        share.meta = share_meta
        share.share_chunks = share_chunks

        return cls(shape=shape, dtype=dtype, vtype=vtype, share=share)
156
+
65
157
 
66
158
  def _get_spu_config_and_world() -> tuple[libspu.RuntimeConfig, int]:
67
159
  kctx = cur_kctx()
@@ -128,33 +220,25 @@ def _spu_seed_env(pfunc: PFunction, *args: Any) -> Any:
128
220
 
129
221
 
130
222
  @kernel_def("spu.makeshares")
131
- def _spu_makeshares(pfunc: PFunction, *args: Any) -> Any:
132
- """Create SPU shares from input data.
133
-
134
- Args:
135
- pfunc: PFunction containing makeshares metadata
136
- args: Input data to be shared (single tensor)
137
-
138
- Returns:
139
- Tuple of SPU shares (SpuValue), one for each party.
140
- """
141
- assert len(args) == 1
142
-
223
+ def _spu_makeshares(pfunc: PFunction, tensor: TensorValue) -> tuple[SpuValue, ...]:
224
+ """Create SPU shares from input TensorValue data."""
143
225
  visibility_value = pfunc.attrs.get("visibility", libspu.Visibility.VIS_SECRET.value)
144
226
  if isinstance(visibility_value, int):
145
227
  visibility = libspu.Visibility(visibility_value)
146
228
  else:
147
229
  visibility = visibility_value
148
230
 
149
- arg = np.array(args[0], copy=False)
231
+ arg = tensor.to_numpy()
150
232
  cfg, world = _get_spu_config_and_world()
151
233
  spu_io = spu_api.Io(world, cfg)
152
234
  shares = spu_io.make_shares(arg, visibility)
153
235
  assert len(shares) == world, f"Expected {world} shares, got {len(shares)}"
236
+ # Store MPLang DType instead of libspu.DataType
237
+ dtype = DType.from_numpy(arg.dtype)
154
238
  return tuple(
155
239
  SpuValue(
156
240
  shape=arg.shape,
157
- dtype=arg.dtype,
241
+ dtype=dtype,
158
242
  vtype=visibility,
159
243
  share=share,
160
244
  )
@@ -163,24 +247,29 @@ def _spu_makeshares(pfunc: PFunction, *args: Any) -> Any:
163
247
 
164
248
 
165
249
@kernel_def("spu.reconstruct")
def _spu_reconstruct(pfunc: PFunction, *shares: SpuValue) -> TensorValue:
    """Reconstruct plaintext data from SPU shares.

    Args:
        pfunc: PFunction for this kernel (not read here).
        *shares: one SpuValue per SPU party; there must be exactly
            ``world`` of them.

    Returns:
        TensorValue with the reconstructed array, cast to the dtype and
        reshaped to the shape recorded on the first share.

    Raises:
        AssertionError: if the number of shares differs from the SPU world size.
        ValueError: if any input is not an SpuValue.
    """
    cfg, world = _get_spu_config_and_world()
    assert len(shares) == world, f"Expected {world} shares, got {len(shares)}"
    for i, share in enumerate(shares):
        if not isinstance(share, SpuValue):
            raise ValueError(
                f"Input {i} must be SpuValue, got {type(share)}. Reconstruction requires SPU shares as input."
            )

    spu_io = spu_api.Io(world, cfg)
    reconstructed = spu_io.reconstruct([spu_value.share for spu_value in shares])

    # Respect semantic dtype/shape recorded on shares (all shares share same meta).
    semantic_dtype = shares[0].dtype.to_numpy()
    semantic_shape = shares[0].shape
    # np.asarray instead of np.array(..., copy=False): under NumPy 2 the latter
    # raises ValueError whenever a copy cannot be avoided.
    restored = np.asarray(reconstructed, dtype=semantic_dtype).reshape(semantic_shape)
    return TensorValue(restored)
180
269
 
181
270
 
182
271
  @kernel_def("spu.run_pphlo")
183
- def _spu_run_mlir(pfunc: PFunction, *args: Any) -> Any:
272
+ def _spu_run_mlir(pfunc: PFunction, *args: SpuValue) -> tuple[SpuValue, ...]:
184
273
  """Execute compiled SPU function (spu.run_pphlo) and return SpuValue outputs.
185
274
 
186
275
  Participation rule: a rank participates iff its entry in the stored
@@ -240,10 +329,10 @@ def _spu_run_mlir(pfunc: PFunction, *args: Any) -> Any:
240
329
  spu_rt.run(executable)
241
330
  shares = [spu_rt.get_var(out_name) for out_name in output_names]
242
331
  metas = [spu_rt.get_var_meta(out_name) for out_name in output_names]
243
- results: list[TensorLike] = [
332
+ results: list[SpuValue] = [
244
333
  SpuValue(
245
334
  shape=shape_spu_to_np(meta.shape),
246
- dtype=dtype_spu_to_np(meta.data_type),
335
+ dtype=dtype_spu_to_mpl(meta.data_type),
247
336
  vtype=meta.visibility,
248
337
  share=shares[idx],
249
338
  )