flwr 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +82 -57
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +10 -18
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +31 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +4 -4
- flwr/client/grpc_rere_client/connection.py +130 -60
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +173 -67
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +36 -7
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit_handlers.py +30 -0
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_grpc_utils.py +99 -0
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +341 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +323 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -183
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +19 -159
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/fleet_pb2.py +32 -27
- flwr/proto/fleet_pb2.pyi +49 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +32 -23
- flwr/proto/serverappio_pb2.pyi +45 -3
- flwr/proto/serverappio_pb2_grpc.py +138 -34
- flwr/proto/serverappio_pb2_grpc.pyi +54 -13
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +68 -186
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +104 -34
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +18 -0
- flwr/server/superlink/ffs/__init__.py +2 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +13 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +101 -7
- flwr/server/superlink/fleet/message_handler/message_handler.py +135 -18
- flwr/server/superlink/fleet/rest_rere/rest_api.py +72 -11
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
- flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
- flwr/server/superlink/simulation/simulationio_servicer.py +25 -1
- flwr/server/superlink/utils.py +44 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +192 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_grpc.py +7 -3
- flwr/superexec/exec_servicer.py +125 -23
- flwr/superexec/exec_user_auth_interceptor.py +37 -8
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +7 -14
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -12
- flwr/supernode/cli/flwr_clientapp.py +81 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
- flwr/supernode/nodestate/nodestate.py +212 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +25 -56
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/supernode/start_client_internal.py +491 -0
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/METADATA +5 -4
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/RECORD +141 -108
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
- /flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
flwr/common/record/array.py (new file)
@@ -0,0 +1,323 @@
+# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Array."""
+
+
+from __future__ import annotations
+
+import sys
+from dataclasses import dataclass
+from io import BytesIO
+from typing import TYPE_CHECKING, Any, cast, overload
+
+import numpy as np
+
+from flwr.proto.recorddict_pb2 import Array as ArrayProto  # pylint: disable=E0611
+
+from ..constant import SType
+from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
+from ..typing import NDArray
+
+if TYPE_CHECKING:
+    import torch
+
+
+def _raise_array_init_error() -> None:
+    raise TypeError(
+        f"Invalid arguments for {Array.__qualname__}. Expected either a "
+        "PyTorch tensor, a NumPy ndarray, or explicit"
+        " dtype/shape/stype/data values."
+    )
+
+
+@dataclass
+class Array(InflatableObject):
+    """Array type.
+
+    A dataclass containing serialized data from an array-like or tensor-like object
+    along with metadata about it. The class can be initialized in one of three ways:
+
+    1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
+    2. By providing a NumPy ndarray (via the `ndarray` argument).
+    3. By providing a PyTorch tensor (via the `torch_tensor` argument).
+
+    In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
+    derived from the input. In scenario (1), these fields must be specified manually.
+
+    Parameters
+    ----------
+    dtype : Optional[str] (default: None)
+        A string representing the data type of the serialized object (e.g. `"float32"`).
+        Only required if you are not passing in a ndarray or a tensor.
+
+    shape : Optional[tuple[int, ...]] (default: None)
+        A tuple representing the shape of the unserialized array-like object. Only
+        required if you are not passing in a ndarray or a tensor.
+
+    stype : Optional[str] (default: None)
+        A string indicating the serialization mechanism used to generate the bytes in
+        `data` from an array-like or tensor-like object. Only required if you are not
+        passing in a ndarray or a tensor.
+
+    data : Optional[bytes] (default: None)
+        A buffer of bytes containing the data. Only required if you are not passing in
+        a ndarray or a tensor.
+
+    ndarray : Optional[NDArray] (default: None)
+        A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
+        fields are derived automatically from it.
+
+    torch_tensor : Optional[torch.Tensor] (default: None)
+        A PyTorch tensor. If provided, it will be **detached and moved to CPU**
+        before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
+        will be derived automatically from it.
+
+    Examples
+    --------
+    Initializing by specifying all fields directly::
+
+        arr1 = Array(
+            dtype="float32",
+            shape=[3, 3],
+            stype="numpy.ndarray",
+            data=b"serialized_data...",
+        )
+
+    Initializing with a NumPy ndarray::
+
+        import numpy as np
+        arr2 = Array(np.random.randn(3, 3))
+
+    Initializing with a PyTorch tensor::
+
+        import torch
+        arr3 = Array(torch.randn(3, 3))
+    """
+
+    dtype: str
+    shape: tuple[int, ...]
+    stype: str
+    data: bytes
+
+    @overload
+    def __init__(  # noqa: E704
+        self, dtype: str, shape: tuple[int, ...], stype: str, data: bytes
+    ) -> None: ...
+
+    @overload
+    def __init__(self, ndarray: NDArray) -> None: ...  # noqa: E704
+
+    @overload
+    def __init__(self, torch_tensor: torch.Tensor) -> None: ...  # noqa: E704
+
+    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
+        self,
+        *args: Any,
+        dtype: str | None = None,
+        shape: tuple[int, ...] | None = None,
+        stype: str | None = None,
+        data: bytes | None = None,
+        ndarray: NDArray | None = None,
+        torch_tensor: torch.Tensor | None = None,
+    ) -> None:
+        # Determine the initialization method and validate input arguments.
+        # Support three initialization formats:
+        # 1. Array(dtype: str, shape: tuple[int, ...], stype: str, data: bytes)
+        # 2. Array(ndarray: NDArray)
+        # 3. Array(torch_tensor: torch.Tensor)
+
+        # Initialize all arguments
+        # If more than 4 positional arguments are provided, raise an error.
+        if len(args) > 4:
+            _raise_array_init_error()
+        all_args = [None] * 4
+        for i, arg in enumerate(args):
+            all_args[i] = arg
+        init_method: str | None = None  # Track which init method is being used
+
+        # Try to assign a value to all_args[index] if it's not already set.
+        # If an initialization method is provided, update init_method.
+        def _try_set_arg(index: int, arg: Any, method: str) -> None:
+            # Skip if arg is None
+            if arg is None:
+                return
+            # Raise an error if all_args[index] is already set
+            if all_args[index] is not None:
+                _raise_array_init_error()
+            # Raise an error if a different initialization method is already set
+            nonlocal init_method
+            if init_method is not None and init_method != method:
+                _raise_array_init_error()
+            # Set init_method and all_args[index]
+            if init_method is None:
+                init_method = method
+            all_args[index] = arg
+
+        # Try to set keyword arguments in all_args
+        _try_set_arg(0, dtype, "direct")
+        _try_set_arg(1, shape, "direct")
+        _try_set_arg(2, stype, "direct")
+        _try_set_arg(3, data, "direct")
+        _try_set_arg(0, ndarray, "ndarray")
+        _try_set_arg(0, torch_tensor, "torch_tensor")
+
+        # Check if all arguments are correctly set
+        all_args = [arg for arg in all_args if arg is not None]
+
+        # Handle direct field initialization
+        if not init_method or init_method == "direct":
+            if (
+                len(all_args) == 4  # pylint: disable=too-many-boolean-expressions
+                and isinstance(all_args[0], str)
+                and isinstance(all_args[1], tuple)
+                and all(isinstance(i, int) for i in all_args[1])
+                and isinstance(all_args[2], str)
+                and isinstance(all_args[3], bytes)
+            ):
+                self.dtype, self.shape, self.stype, self.data = all_args
+                return
+
+        # Handle NumPy array
+        if not init_method or init_method == "ndarray":
+            if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
+                self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
+                return
+
+        # Handle PyTorch tensor
+        if not init_method or init_method == "torch_tensor":
+            if (
+                len(all_args) == 1
+                and "torch" in sys.modules
+                and isinstance(all_args[0], sys.modules["torch"].Tensor)
+            ):
+                self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
+                return
+
+        _raise_array_init_error()
+
+    @classmethod
+    def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
+        """Create Array from NumPy ndarray."""
+        assert isinstance(
+            ndarray, np.ndarray
+        ), f"Expected NumPy ndarray, got {type(ndarray)}"
+        buffer = BytesIO()
+        # WARNING: NEVER set allow_pickle to true.
+        # Reason: loading pickled data can execute arbitrary code
+        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
+        np.save(buffer, ndarray, allow_pickle=False)
+        data = buffer.getvalue()
+        return Array(
+            dtype=str(ndarray.dtype),
+            shape=tuple(ndarray.shape),
+            stype=SType.NUMPY,
+            data=data,
+        )
+
+    @classmethod
+    def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
+        """Create Array from PyTorch tensor."""
+        if not (torch := sys.modules.get("torch")):
+            raise RuntimeError(
+                f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
+            )
+
+        assert isinstance(
+            tensor, torch.Tensor
+        ), f"Expected PyTorch Tensor, got {type(tensor)}"
+        return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())
+
+    def numpy(self) -> NDArray:
+        """Return the array as a NumPy array."""
+        if self.stype != SType.NUMPY:
+            raise TypeError(
+                f"Unsupported serialization type for numpy conversion: '{self.stype}'"
+            )
+        bytes_io = BytesIO(self.data)
+        # WARNING: NEVER set allow_pickle to true.
+        # Reason: loading pickled data can execute arbitrary code
+        # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
+        ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
+        return cast(NDArray, ndarray_deserialized)
+
+    def deflate(self) -> bytes:
+        """Deflate the Array."""
+        array_proto = ArrayProto(
+            dtype=self.dtype,
+            shape=self.shape,
+            stype=self.stype,
+            data=self.data,
+        )
+
+        obj_body = array_proto.SerializeToString(deterministic=True)
+        return add_header_to_object_body(object_body=obj_body, obj=self)
+
+    @classmethod
+    def inflate(
+        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
+    ) -> Array:
+        """Inflate an Array from bytes.
+
+        Parameters
+        ----------
+        object_content : bytes
+            The deflated object content of the Array.
+
+        children : Optional[dict[str, InflatableObject]] (default: None)
+            Must be ``None``. ``Array`` does not support child objects.
+            Providing any children will raise a ``ValueError``.
+
+        Returns
+        -------
+        Array
+            The inflated Array.
+        """
+        if children:
+            raise ValueError("`Array` objects do not have children.")
+
+        obj_body = get_object_body(object_content, cls)
+        proto_array = ArrayProto.FromString(obj_body)
+        return cls(
+            dtype=proto_array.dtype,
+            shape=tuple(proto_array.shape),
+            stype=proto_array.stype,
+            data=proto_array.data,
+        )
+
+    @property
+    def object_id(self) -> str:
+        """Get object ID."""
+        ret = super().object_id
+        self.is_dirty = False  # Reset dirty flag
+        return ret
+
+    @property
+    def is_dirty(self) -> bool:
+        """Check if the object is dirty after the last deflation."""
+        if "_is_dirty" not in self.__dict__:
+            self.__dict__["_is_dirty"] = True
+        return cast(bool, self.__dict__["_is_dirty"])
+
+    @is_dirty.setter
+    def is_dirty(self, value: bool) -> None:
+        """Set the dirty flag."""
+        self.__dict__["_is_dirty"] = value
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Set attribute with special handling for dirty state."""
+        if name in ("dtype", "shape", "stype", "data"):
+            # Mark as dirty if any of the main attributes are set
+            self.is_dirty = True
+        super().__setattr__(name, value)
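
The new `Array` shown above supports three construction paths (explicit `dtype`/`shape`/`stype`/`data`, a NumPy ndarray, or a PyTorch tensor) and deserializes back with `.numpy()`. A minimal usage sketch, assuming `Array` is still re-exported from `flwr.common` as in earlier releases:

```python
import numpy as np

from flwr.common import Array  # assumed re-export; the class itself now lives in array.py

# Construct from a NumPy ndarray: dtype/shape/stype/data are derived automatically
original = np.arange(12, dtype="float32").reshape(3, 4)
arr = Array(original)
print(arr.dtype, arr.shape)  # float32 (3, 4) -- shape is a tuple in 1.19.0

# Round-trip back to NumPy
restored = arr.numpy()
assert np.array_equal(original, restored)

# Equivalent explicit-field construction (fields describe already-serialized data)
arr2 = Array(dtype=arr.dtype, shape=arr.shape, stype=arr.stype, data=arr.data)
assert arr2.numpy().shape == (3, 4)
```

Per the `deflate`/`inflate` pair added above, the same object can also be serialized with an object header for transport; based on that code, `Array.inflate(arr.deflate())` should reconstruct an equal `Array`, and `children` must stay `None` since `Array` has no child objects.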
flwr/common/record/arrayrecord.py
@@ -12,38 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""ArrayRecord
+"""ArrayRecord."""
 
 
 from __future__ import annotations
 
 import gc
+import json
 import sys
 from collections import OrderedDict
-from dataclasses import dataclass
-from io import BytesIO
 from logging import WARN
 from typing import TYPE_CHECKING, Any, cast, overload
 
 import numpy as np
 
-from ..constant import GC_THRESHOLD, SType
+from ..constant import GC_THRESHOLD
+from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
 from ..logger import log
 from ..typing import NDArray
+from .array import Array
 from .typeddict import TypedDict
 
 if TYPE_CHECKING:
     import torch
 
 
-def _raise_array_init_error() -> None:
-    raise TypeError(
-        f"Invalid arguments for {Array.__qualname__}. Expected either a "
-        "PyTorch tensor, a NumPy ndarray, or explicit"
-        " dtype/shape/stype/data values."
-    )
-
-
 def _raise_array_record_init_error() -> None:
     raise TypeError(
         f"Invalid arguments for {ArrayRecord.__qualname__}. Expected either "
@@ -52,217 +45,6 @@ def _raise_array_record_init_error() -> None:
     )
 
 
-@dataclass
-class Array:
-    """Array type.
-
-    A dataclass containing serialized data from an array-like or tensor-like object
-    along with metadata about it. The class can be initialized in one of three ways:
-
-    1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
-    2. By providing a NumPy ndarray (via the `ndarray` argument).
-    3. By providing a PyTorch tensor (via the `torch_tensor` argument).
-
-    In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
-    derived from the input. In scenario (1), these fields must be specified manually.
-
-    Parameters
-    ----------
-    dtype : Optional[str] (default: None)
-        A string representing the data type of the serialized object (e.g. `"float32"`).
-        Only required if you are not passing in a ndarray or a tensor.
-
-    shape : Optional[list[int]] (default: None)
-        A list representing the shape of the unserialized array-like object. Only
-        required if you are not passing in a ndarray or a tensor.
-
-    stype : Optional[str] (default: None)
-        A string indicating the serialization mechanism used to generate the bytes in
-        `data` from an array-like or tensor-like object. Only required if you are not
-        passing in a ndarray or a tensor.
-
-    data : Optional[bytes] (default: None)
-        A buffer of bytes containing the data. Only required if you are not passing in
-        a ndarray or a tensor.
-
-    ndarray : Optional[NDArray] (default: None)
-        A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
-        fields are derived automatically from it.
-
-    torch_tensor : Optional[torch.Tensor] (default: None)
-        A PyTorch tensor. If provided, it will be **detached and moved to CPU**
-        before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
-        will be derived automatically from it.
-
-    Examples
-    --------
-    Initializing by specifying all fields directly::
-
-        arr1 = Array(
-            dtype="float32",
-            shape=[3, 3],
-            stype="numpy.ndarray",
-            data=b"serialized_data...",
-        )
-
-    Initializing with a NumPy ndarray::
-
-        import numpy as np
-        arr2 = Array(np.random.randn(3, 3))
-
-    Initializing with a PyTorch tensor::
-
-        import torch
-        arr3 = Array(torch.randn(3, 3))
-    """
-
-    dtype: str
-    shape: list[int]
-    stype: str
-    data: bytes
-
-    @overload
-    def __init__(  # noqa: E704
-        self, dtype: str, shape: list[int], stype: str, data: bytes
-    ) -> None: ...
-
-    @overload
-    def __init__(self, ndarray: NDArray) -> None: ...  # noqa: E704
-
-    @overload
-    def __init__(self, torch_tensor: torch.Tensor) -> None: ...  # noqa: E704
-
-    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
-        self,
-        *args: Any,
-        dtype: str | None = None,
-        shape: list[int] | None = None,
-        stype: str | None = None,
-        data: bytes | None = None,
-        ndarray: NDArray | None = None,
-        torch_tensor: torch.Tensor | None = None,
-    ) -> None:
-        # Determine the initialization method and validate input arguments.
-        # Support three initialization formats:
-        # 1. Array(dtype: str, shape: list[int], stype: str, data: bytes)
-        # 2. Array(ndarray: NDArray)
-        # 3. Array(torch_tensor: torch.Tensor)
-
-        # Initialize all arguments
-        # If more than 4 positional arguments are provided, raise an error.
-        if len(args) > 4:
-            _raise_array_init_error()
-        all_args = [None] * 4
-        for i, arg in enumerate(args):
-            all_args[i] = arg
-        init_method: str | None = None  # Track which init method is being used
-
-        # Try to assign a value to all_args[index] if it's not already set.
-        # If an initialization method is provided, update init_method.
-        def _try_set_arg(index: int, arg: Any, method: str) -> None:
-            # Skip if arg is None
-            if arg is None:
-                return
-            # Raise an error if all_args[index] is already set
-            if all_args[index] is not None:
-                _raise_array_init_error()
-            # Raise an error if a different initialization method is already set
-            nonlocal init_method
-            if init_method is not None and init_method != method:
-                _raise_array_init_error()
-            # Set init_method and all_args[index]
-            if init_method is None:
-                init_method = method
-            all_args[index] = arg
-
-        # Try to set keyword arguments in all_args
-        _try_set_arg(0, dtype, "direct")
-        _try_set_arg(1, shape, "direct")
-        _try_set_arg(2, stype, "direct")
-        _try_set_arg(3, data, "direct")
-        _try_set_arg(0, ndarray, "ndarray")
-        _try_set_arg(0, torch_tensor, "torch_tensor")
-
-        # Check if all arguments are correctly set
-        all_args = [arg for arg in all_args if arg is not None]
-
-        # Handle direct field initialization
-        if not init_method or init_method == "direct":
-            if (
-                len(all_args) == 4  # pylint: disable=too-many-boolean-expressions
-                and isinstance(all_args[0], str)
-                and isinstance(all_args[1], list)
-                and all(isinstance(i, int) for i in all_args[1])
-                and isinstance(all_args[2], str)
-                and isinstance(all_args[3], bytes)
-            ):
-                self.dtype, self.shape, self.stype, self.data = all_args
-                return
-
-        # Handle NumPy array
-        if not init_method or init_method == "ndarray":
-            if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
-                self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
-                return
-
-        # Handle PyTorch tensor
-        if not init_method or init_method == "torch_tensor":
-            if (
-                len(all_args) == 1
-                and "torch" in sys.modules
-                and isinstance(all_args[0], sys.modules["torch"].Tensor)
-            ):
-                self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
-                return
-
-        _raise_array_init_error()
-
-    @classmethod
-    def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
-        """Create Array from NumPy ndarray."""
-        assert isinstance(
-            ndarray, np.ndarray
-        ), f"Expected NumPy ndarray, got {type(ndarray)}"
-        buffer = BytesIO()
-        # WARNING: NEVER set allow_pickle to true.
-        # Reason: loading pickled data can execute arbitrary code
-        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
-        np.save(buffer, ndarray, allow_pickle=False)
-        data = buffer.getvalue()
-        return Array(
-            dtype=str(ndarray.dtype),
-            shape=list(ndarray.shape),
-            stype=SType.NUMPY,
-            data=data,
-        )
-
-    @classmethod
-    def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
-        """Create Array from PyTorch tensor."""
-        if not (torch := sys.modules.get("torch")):
-            raise RuntimeError(
-                f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
-            )
-
-        assert isinstance(
-            tensor, torch.Tensor
-        ), f"Expected PyTorch Tensor, got {type(tensor)}"
-        return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())
-
-    def numpy(self) -> NDArray:
-        """Return the array as a NumPy array."""
-        if self.stype != SType.NUMPY:
-            raise TypeError(
-                f"Unsupported serialization type for numpy conversion: '{self.stype}'"
-            )
-        bytes_io = BytesIO(self.data)
-        # WARNING: NEVER set allow_pickle to true.
-        # Reason: loading pickled data can execute arbitrary code
-        # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
-        ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
-        return cast(NDArray, ndarray_deserialized)
-
-
 def _check_key(key: str) -> None:
     """Check if key is of expected type."""
     if not isinstance(key, str):
@@ -276,7 +58,7 @@ def _check_value(value: Array) -> None:
         )
 
 
-class ArrayRecord(TypedDict[str, Array]):
+class ArrayRecord(TypedDict[str, Array], InflatableObject):
     """Array record.
 
     A typed dictionary (``str`` to :class:`Array`) that can store named arrays,
@@ -470,7 +252,7 @@ class ArrayRecord(TypedDict[str, Array]):
         record = ArrayRecord()
         for k, v in array_dict.items():
             record[k] = Array(
-                dtype=v.dtype, shape=list(v.shape), stype=v.stype, data=v.data
+                dtype=v.dtype, shape=tuple(v.shape), stype=v.stype, data=v.data
            )
         if not keep_input:
            array_dict.clear()
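
Two behavioral changes are visible in the hunks above: `ArrayRecord` now also inherits from `InflatableObject`, and `Array.shape` (including arrays built from an `array_dict`, as in the last hunk) is now a `tuple[int, ...]` rather than a `list[int]`. A short illustrative sketch of the shape change, under the same `flwr.common` import assumption as before:

```python
import numpy as np

from flwr.common import Array  # assumed re-export

arr = Array(np.zeros((2, 3)))
assert isinstance(arr.shape, tuple) and arr.shape == (2, 3)  # was [2, 3] in 1.18.0
# Callers that mutated the shape in place (e.g. appending to the list) must now
# build a new tuple instead.
```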
@@ -585,6 +367,102 @@ class ArrayRecord(TypedDict[str, Array]):
 
         return num_bytes
 
+    @property
+    def children(self) -> dict[str, InflatableObject]:
+        """Return a dictionary of Arrays with their Object IDs as keys."""
+        return {arr.object_id: arr for arr in self.values()}
+
+    def deflate(self) -> bytes:
+        """Deflate the ArrayRecord."""
+        # array_name: array_object_id mapping
+        array_refs: dict[str, str] = {}
+
+        for array_name, array in self.items():
+            array_refs[array_name] = array.object_id
+
+        # Serialize references dict
+        object_body = json.dumps(array_refs).encode("utf-8")
+        return add_header_to_object_body(object_body=object_body, obj=self)
+
+    @classmethod
+    def inflate(
+        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
+    ) -> ArrayRecord:
+        """Inflate an ArrayRecord from bytes.
+
+        Parameters
+        ----------
+        object_content : bytes
+            The deflated object content of the ArrayRecord.
+        children : Optional[dict[str, InflatableObject]] (default: None)
+            Dictionary of children InflatableObjects mapped to their Object IDs.
+            These children enable the full inflation of the ArrayRecord.
+
+        Returns
+        -------
+        ArrayRecord
+            The inflated ArrayRecord.
+        """
+        if children is None:
+            children = {}
+
+        # Inflate mapping of array_names (keys in the ArrayRecord) to Arrays' object IDs
+        obj_body = get_object_body(object_content, cls)
+        array_refs: dict[str, str] = json.loads(obj_body.decode(encoding="utf-8"))
+
+        unique_arrays = set(array_refs.values())
+        children_obj_ids = set(children.keys())
+        if unique_arrays != children_obj_ids:
+            raise ValueError(
+                "Unexpected set of `children`. "
+                f"Expected {unique_arrays} but got {children_obj_ids}."
+            )
+
+        # Ensure children are of type Array
+        if not all(isinstance(arr, Array) for arr in children.values()):
+            raise ValueError("`Children` are expected to be of type `Array`.")
+
+        # Instantiate new ArrayRecord
+        return ArrayRecord(
+            OrderedDict(
+                {name: children[object_id] for name, object_id in array_refs.items()}
+            )
+        )
+
+    @property
+    def object_id(self) -> str:
+        """Get object ID."""
+        ret = super().object_id
+        self.is_dirty = False  # Reset dirty flag
+        return ret
+
+    @property
+    def is_dirty(self) -> bool:
+        """Check if the object is dirty after the last deflation."""
+        if "_is_dirty" not in self.__dict__:
+            self.__dict__["_is_dirty"] = True
+
+        if not self.__dict__["_is_dirty"]:
+            if any(v.is_dirty for v in self.values()):
+                # If any Array is dirty, mark the record as dirty
+                self.__dict__["_is_dirty"] = True
+        return cast(bool, self.__dict__["_is_dirty"])
+
+    @is_dirty.setter
+    def is_dirty(self, value: bool) -> None:
+        """Set the dirty flag."""
+        self.__dict__["_is_dirty"] = value
+
+    def __setitem__(self, key: str, value: Array) -> None:
+        """Set item and mark the record as dirty."""
+        self.is_dirty = True  # Mark as dirty when setting an item
+        super().__setitem__(key, value)
+
+    def __delitem__(self, key: str) -> None:
+        """Delete item and mark the record as dirty."""
+        self.is_dirty = True  # Mark as dirty when deleting an item
+        super().__delitem__(key)
+
 
 class ParametersRecord(ArrayRecord):
     """Deprecated class ``ParametersRecord``, use ``ArrayRecord`` instead.