mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/edsl/serde.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
JSON-based serialization for MPLang types and graphs.
|
|
17
|
+
|
|
18
|
+
This module provides a secure, extensible serialization mechanism that replaces
|
|
19
|
+
cloudpickle. Each type is responsible for its own serialization via the
|
|
20
|
+
`@register_class` decorator pattern.
|
|
21
|
+
|
|
22
|
+
Usage:
|
|
23
|
+
from mplang.v2.edsl import serde
|
|
24
|
+
|
|
25
|
+
@serde.register_class
|
|
26
|
+
class MyType:
|
|
27
|
+
_serde_kind = "mymodule.MyType"
|
|
28
|
+
|
|
29
|
+
def to_json(self) -> dict:
|
|
30
|
+
return {"field": self.field}
|
|
31
|
+
|
|
32
|
+
@classmethod
|
|
33
|
+
def from_json(cls, data: dict) -> "MyType":
|
|
34
|
+
return cls(data["field"])
|
|
35
|
+
|
|
36
|
+
# Serialize
|
|
37
|
+
obj = MyType(...)
|
|
38
|
+
json_data = serde.to_json(obj)
|
|
39
|
+
|
|
40
|
+
# Deserialize
|
|
41
|
+
obj2 = serde.from_json(json_data)
|
|
42
|
+
|
|
43
|
+
Security:
|
|
44
|
+
Unlike pickle/cloudpickle, JSON deserialization only reconstructs data
|
|
45
|
+
structures - it cannot execute arbitrary code. The `from_json` methods
|
|
46
|
+
are explicitly defined by each registered class.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
from __future__ import annotations
|
|
50
|
+
|
|
51
|
+
import base64
|
|
52
|
+
import gzip
|
|
53
|
+
import json
|
|
54
|
+
from typing import Any, ClassVar, Protocol, TypeVar, runtime_checkable
|
|
55
|
+
|
|
56
|
+
import numpy as np
|
|
57
|
+
|
|
58
|
+
# =============================================================================
|
|
59
|
+
# Type Registry
|
|
60
|
+
# =============================================================================
|
|
61
|
+
|
|
62
|
+
# Global registry: ``_serde_kind`` string -> registered class.
# Written by ``register_class``; read by ``from_json`` (to dispatch
# non-built-in kinds), ``get_registered_class`` and ``list_registered_kinds``.
_CLASS_REGISTRY: dict[str, type] = {}

# Type variable so ``register_class`` is typed as returning exactly the class
# it was given (it is used as a decorator).
T = TypeVar("T")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def register_class(cls: type[T]) -> type[T]:
    """Decorator to register a class for JSON serialization.

    The class must define:
    - `_serde_kind: ClassVar[str]` - unique identifier for this type
    - `to_json(self) -> dict` - serialize instance to JSON-compatible dict
    - `from_json(cls, data: dict) -> Self` - deserialize from dict

    Registering the same class twice is a no-op; registering a *different*
    class under a kind that is already taken is an error.

    Kinds used by the built-in encodings (``"_null"``, ``"_int"``,
    ``"_list"``, ``"_ndarray"``, ...) are rejected: ``from_json`` decodes
    those before it consults the registry, so a class registered under one
    of them could never be deserialized.

    Example:
        @serde.register_class
        class MyType:
            _serde_kind = "mymodule.MyType"

            def to_json(self) -> dict:
                return {"value": self.value}

            @classmethod
            def from_json(cls, data: dict) -> "MyType":
                return cls(data["value"])

    Args:
        cls: The class to register.

    Returns:
        ``cls`` itself, unchanged, so this works as a decorator.

    Raises:
        ValueError: If ``_serde_kind`` is missing, is not a non-empty string,
            collides with a built-in kind, or is already registered by a
            different class; or if ``to_json``/``from_json`` are not defined.
    """
    # Kinds decoded directly by from_json ahead of the registry lookup.
    builtin_kinds = frozenset(
        {
            "_null",
            "_bool",
            "_int",
            "_float",
            "_str",
            "_list",
            "_tuple",
            "_dict",
            "_dict_pairs",
            "_bytes",
            "_ndarray",
            "_ndarray_object",
        }
    )

    kind = getattr(cls, "_serde_kind", None)
    if kind is None:
        raise ValueError(
            f"{cls.__name__} must define `_serde_kind` class variable "
            "for serialization registration"
        )
    if not isinstance(kind, str) or not kind:
        raise ValueError(
            f"{cls.__name__}._serde_kind must be a non-empty string, got {kind!r}"
        )
    if kind in builtin_kinds:
        raise ValueError(
            f"{cls.__name__}._serde_kind '{kind}' is reserved for a built-in "
            "encoding and would never reach the class registry"
        )
    # Without these, to_json() would skip the registered-class branch and
    # from_json() would fail at decode time; fail early instead.
    for method_name in ("to_json", "from_json"):
        if not callable(getattr(cls, method_name, None)):
            raise ValueError(
                f"{cls.__name__} must define `{method_name}()` "
                "for serialization registration"
            )

    existing = _CLASS_REGISTRY.get(kind)
    if existing is not None and existing is not cls:
        raise ValueError(
            f"Duplicate _serde_kind '{kind}': "
            f"already registered by {existing.__name__}"
        )
    _CLASS_REGISTRY[kind] = cls
    return cls
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def get_registered_class(kind: str) -> type | None:
    """Look up the class registered under ``kind``.

    Args:
        kind: A ``_serde_kind`` string.

    Returns:
        The registered class, or ``None`` when nothing is registered for it.
    """
    try:
        return _CLASS_REGISTRY[kind]
    except KeyError:
        return None
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def list_registered_kinds() -> list[str]:
    """Return every registered ``_serde_kind`` string.

    The result is a fresh list in registration order; modifying it does not
    affect the registry.
    """
    return [*_CLASS_REGISTRY]
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# =============================================================================
|
|
116
|
+
# Serialization Protocol
|
|
117
|
+
# =============================================================================
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@runtime_checkable
class JsonSerializable(Protocol):
    """Structural type for objects that serialize through this module.

    An object matches when it has a ``_serde_kind`` attribute, a ``to_json``
    method and a ``from_json`` classmethod. Matching the protocol alone does
    not make a class decodable: ``from_json`` only resolves kinds that were
    added with ``@register_class``.
    """

    # Wire identifier that to_json() writes into the "_kind" field.
    _serde_kind: ClassVar[str]

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> JsonSerializable:
        """Rebuild an instance from the dict its ``to_json`` produced."""
        ...

    def to_json(self) -> dict[str, Any]:
        """Return this instance as a JSON-compatible dict (without ``_kind``)."""
        ...
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# =============================================================================
|
|
133
|
+
# Core Serialization Functions
|
|
134
|
+
# =============================================================================
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def to_json(obj: Any) -> dict[str, Any]:
    """Serialize an object to a JSON-compatible dict.

    The object must either:
    1. Be an instance of a class with `_serde_kind` and a `to_json()` method
       (normally registered via `@register_class`)
    2. Be a primitive type (int, float, str, bool, None) or a numpy integer,
       floating or bool scalar (encoded as the matching Python primitive)
    3. Be a list/tuple of serializable objects
    4. Be a dict with serializable values; if any key is not a string the
       dict is encoded as key/value pairs so key types survive a round trip
    5. Be a numpy ndarray, or an array-like exposing ``__array__`` (e.g. JAX)
    6. Be bytes-like (bytes, bytearray, memoryview); these decode as ``bytes``

    Args:
        obj: Object to serialize

    Returns:
        JSON-compatible dict with `_kind` field for type dispatch

    Raises:
        TypeError: If object cannot be serialized
    """
    # Registered classes. A class object carries the same attributes as its
    # instances, so skip it: calling its unbound ``to_json`` would fail.
    if (
        not isinstance(obj, type)
        and hasattr(obj, "_serde_kind")
        and hasattr(obj, "to_json")
    ):
        # Copy before tagging so a dict the object keeps internally (and
        # returns from to_json) is never modified by serialization.
        data: dict[str, Any] = dict(obj.to_json())
        data["_kind"] = obj._serde_kind
        return data

    # Primitives
    if obj is None:
        return {"_kind": "_null"}
    if isinstance(obj, bool):  # Must check before int (bool is subclass of int)
        return {"_kind": "_bool", "v": obj}

    if isinstance(obj, int):
        return {"_kind": "_int", "v": obj}
    if isinstance(obj, float):
        return {"_kind": "_float", "v": obj}
    if isinstance(obj, str):
        return {"_kind": "_str", "v": obj}

    # Numpy scalar types (bool_, int64, float32, etc.) become Python
    # primitives. np.bool_ is neither a Python bool nor an np.integer, so it
    # needs its own branch; otherwise it would reach the array-like path
    # below and come back as a 0-d ndarray instead of a bool.
    if isinstance(obj, np.bool_):
        return {"_kind": "_bool", "v": bool(obj)}
    if isinstance(obj, np.integer):
        return {"_kind": "_int", "v": int(obj)}
    if isinstance(obj, np.floating):
        return {"_kind": "_float", "v": float(obj)}

    # Numpy array - handle both numeric and object arrays
    if isinstance(obj, np.ndarray):
        # Object arrays need element-wise serialization
        if obj.dtype == np.object_:
            return {
                "_kind": "_ndarray_object",
                "shape": list(obj.shape),
                "items": [to_json(item) for item in obj.flat],
            }
        # Numeric arrays use efficient binary format (C-order raw bytes)
        return {
            "_kind": "_ndarray",
            "dtype": str(obj.dtype),
            "shape": list(obj.shape),
            "data": base64.b64encode(obj.tobytes()).decode("ascii"),
        }

    # Array-like (JAX, etc.) - convert to numpy
    if hasattr(obj, "__array__"):
        return to_json(np.asarray(obj))

    # Collections
    if isinstance(obj, (list, tuple)):
        return {
            "_kind": "_list" if isinstance(obj, list) else "_tuple",
            "items": [to_json(item) for item in obj],
        }
    if isinstance(obj, dict):
        # Handle dicts with non-string keys by serializing as list of pairs
        # This preserves key types (int, tuple, etc.)
        has_non_string_keys = any(not isinstance(k, str) for k in obj)
        if has_non_string_keys:
            return {
                "_kind": "_dict_pairs",
                "pairs": [[to_json(k), to_json(v)] for k, v in obj.items()],
            }
        return {
            "_kind": "_dict",
            "items": {k: to_json(v) for k, v in obj.items()},
        }

    # Bytes-like values; bytearray/memoryview are copied to immutable bytes
    if isinstance(obj, (bytes, bytearray, memoryview)):
        return {
            "_kind": "_bytes",
            "data": base64.b64encode(bytes(obj)).decode("ascii"),
        }

    raise TypeError(
        f"Cannot serialize object of type {type(obj).__name__}. "
        "Ensure the class is decorated with @serde.register_class "
        "and implements to_json()/from_json()."
    )
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _decode_numeric_ndarray(payload: dict[str, Any]) -> np.ndarray:
    """Rebuild a numeric array from its dtype name, shape and base64 buffer."""
    dtype_name = payload["dtype"]
    dims = tuple(payload["shape"])
    raw = base64.b64decode(payload["data"])
    flat = np.frombuffer(raw, dtype=np.dtype(dtype_name))
    # frombuffer gives a read-only view over ``raw``; copy so the caller
    # receives an independent array.
    return flat.reshape(dims).copy()


def _decode_object_ndarray(payload: dict[str, Any]) -> np.ndarray:
    """Rebuild an object-dtype array whose elements were encoded one by one."""
    dims = tuple(payload["shape"])
    decoded = [from_json(entry) for entry in payload["items"]]
    holder = np.empty(len(decoded), dtype=object)
    # Fill slot by slot so each decoded value is stored as a single object.
    for idx, value in enumerate(decoded):
        holder[idx] = value
    # Always reshape: a shape of () turns the length-1 buffer into a 0-d array.
    return holder.reshape(dims)


# Decoders for the built-in encodings produced by ``to_json``. ``from_json``
# consults this table before the class registry.
_BUILTIN_DECODERS: dict[str, Any] = {
    "_null": lambda d: None,
    "_bool": lambda d: bool(d["v"]),
    "_int": lambda d: int(d["v"]),
    "_float": lambda d: float(d["v"]),
    "_str": lambda d: str(d["v"]),
    "_list": lambda d: [from_json(x) for x in d["items"]],
    "_tuple": lambda d: tuple(from_json(x) for x in d["items"]),
    "_dict": lambda d: {k: from_json(v) for k, v in d["items"].items()},
    # Dicts whose keys were not all strings travel as [key, value] pairs.
    "_dict_pairs": lambda d: {
        from_json(p[0]): from_json(p[1]) for p in d["pairs"]
    },
    "_bytes": lambda d: base64.b64decode(d["data"]),
    # Legacy numpy array formats - kept for backward compatibility.
    # New serializations go through TensorValue (tensor_impl.TensorValue).
    "_ndarray": _decode_numeric_ndarray,
    "_ndarray_object": _decode_object_ndarray,
}


def from_json(data: dict[str, Any]) -> Any:
    """Deserialize an object from a JSON-compatible dict.

    Built-in encodings (primitives, collections, bytes and the legacy ndarray
    formats) are decoded first. Any other ``_kind`` is looked up in the class
    registry, and that class's ``from_json`` receives the dict minus ``_kind``.

    Args:
        data: Dict with `_kind` field indicating the type

    Returns:
        Deserialized object

    Raises:
        ValueError: If ``data`` is not a dict, or `_kind` is missing or unknown
    """
    if not isinstance(data, dict):
        raise ValueError(f"Expected dict, got {type(data).__name__}")

    kind = data.get("_kind")
    if kind is None:
        raise ValueError("Missing '_kind' field in JSON data")

    decoder = _BUILTIN_DECODERS.get(kind)
    if decoder is not None:
        return decoder(data)

    target = _CLASS_REGISTRY.get(kind)
    if target is None:
        raise ValueError(
            f"Unknown type kind: '{kind}'. "
            "Ensure the class is registered with @serde.register_class "
            "and the module is imported."
        )
    # The class-specific decoder never sees the dispatch tag.
    fields = dict(data)
    del fields["_kind"]
    return target.from_json(fields)  # type: ignore[attr-defined]
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
# =============================================================================
|
|
315
|
+
# Convenience Functions for Wire Format
|
|
316
|
+
# =============================================================================
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def dumps(obj: Any, *, compress: bool = True) -> bytes:
    """Serialize object to bytes (compact UTF-8 JSON + optional gzip).

    When compressing, the gzip header's modification time is pinned to 0
    instead of the current time. Compressing the same JSON twice therefore
    yields identical bytes, so payloads can be compared, hashed or cached.

    Args:
        obj: Object to serialize
        compress: Whether to gzip compress the output (default: True)

    Returns:
        Serialized bytes

    Raises:
        TypeError: If ``obj`` (or a nested value) cannot be serialized.
    """
    text = json.dumps(to_json(obj), separators=(",", ":"))
    payload = text.encode("utf-8")
    if compress:
        # Default mtime=None stamps the wall-clock time into the header,
        # making otherwise-identical outputs differ between calls.
        payload = gzip.compress(payload, mtime=0)
    return payload
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def loads(data: bytes, *, compressed: bool = True) -> Any:
    """Deserialize an object from bytes produced by ``dumps``.

    NOTE(review): ``gzip.decompress`` places no bound on the decompressed
    size; if ``data`` can come from an untrusted peer, consider limiting it.

    Args:
        data: Serialized bytes
        compressed: Whether the data is gzip compressed (default: True);
            this should match the ``compress`` flag given to ``dumps``.

    Returns:
        Deserialized object
    """
    raw = gzip.decompress(data) if compressed else data
    decoded = json.loads(raw.decode("utf-8"))
    return from_json(decoded)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def dumps_b64(obj: Any, *, compress: bool = True) -> str:
    """Serialize an object to an ASCII base64 string, e.g. for JSON transport.

    This is ``dumps`` followed by base64 encoding; ``loads_b64`` reverses it.

    Args:
        obj: Object to serialize
        compress: Whether to gzip compress before encoding (default: True)

    Returns:
        Base64-encoded string
    """
    blob = dumps(obj, compress=compress)
    encoded = base64.b64encode(blob)
    return encoded.decode("ascii")
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def loads_b64(data: str, *, compressed: bool = True) -> Any:
    """Deserialize an object from a base64 string produced by ``dumps_b64``.

    Args:
        data: Base64-encoded string
        compressed: Whether the decoded bytes are gzip compressed (default: True)

    Returns:
        Deserialized object
    """
    blob = base64.b64decode(data)
    return loads(blob, compressed=compressed)
|