flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +94 -59
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/new.py +12 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +48 -49
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +38 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +15 -8
- flwr/client/grpc_rere_client/connection.py +142 -97
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +176 -103
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +39 -8
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit_code.py +16 -1
- flwr/common/exit_handlers.py +30 -0
- flwr/common/grpc.py +12 -1
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_protobuf_utils.py +141 -0
- flwr/common/inflatable_utils.py +508 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +402 -0
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -211
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +28 -185
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/appio_pb2.py +43 -0
- flwr/proto/appio_pb2.pyi +151 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +12 -19
- flwr/proto/clientappio_pb2.pyi +23 -101
- flwr/proto/clientappio_pb2_grpc.py +269 -28
- flwr/proto/clientappio_pb2_grpc.pyi +114 -20
- flwr/proto/fleet_pb2.py +24 -27
- flwr/proto/fleet_pb2.pyi +19 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +9 -23
- flwr/proto/serverappio_pb2.pyi +0 -110
- flwr/proto/serverappio_pb2_grpc.py +177 -72
- flwr/proto/serverappio_pb2_grpc.pyi +75 -33
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +69 -187
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +6 -2
- flwr/server/grid/grpc_grid.py +148 -41
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +45 -17
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
- flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
- flwr/server/superlink/utils.py +9 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +25 -0
- flwr/simulation/run_simulation.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +22 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +170 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/supercore/object_store/utils.py +43 -0
- flwr/supercore/scheduler/__init__.py +22 -0
- flwr/supercore/scheduler/plugin.py +71 -0
- flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
- flwr/superexec/deployment.py +7 -4
- flwr/superexec/exec_event_log_interceptor.py +8 -4
- flwr/superexec/exec_grpc.py +25 -5
- flwr/superexec/exec_license_interceptor.py +82 -0
- flwr/superexec/exec_servicer.py +135 -24
- flwr/superexec/exec_user_auth_interceptor.py +45 -8
- flwr/superexec/executor.py +5 -1
- flwr/superexec/simulation.py +8 -3
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/supernode/cli/__init__.py +24 -0
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
- flwr/supernode/cli/flwr_clientapp.py +88 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
- flwr/supernode/nodestate/nodestate.py +227 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
- flwr/supernode/scheduler/__init__.py +22 -0
- flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +22 -0
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
- flwr/supernode/start_client_internal.py +589 -0
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
- flwr/client/clientapp/clientappio_servicer.py +0 -244
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
|
@@ -0,0 +1,402 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Array."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import json
|
|
21
|
+
import sys
|
|
22
|
+
from dataclasses import dataclass
|
|
23
|
+
from io import BytesIO
|
|
24
|
+
from typing import TYPE_CHECKING, Any, cast, overload
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
|
|
28
|
+
from ..constant import MAX_ARRAY_CHUNK_SIZE, SType
|
|
29
|
+
from ..inflatable import (
|
|
30
|
+
InflatableObject,
|
|
31
|
+
add_header_to_object_body,
|
|
32
|
+
get_object_body,
|
|
33
|
+
get_object_children_ids_from_object_content,
|
|
34
|
+
)
|
|
35
|
+
from ..typing import NDArray
|
|
36
|
+
from .arraychunk import ArrayChunk
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
import torch
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _raise_array_init_error() -> None:
    """Raise the ``TypeError`` shared by every invalid ``Array(...)`` call.

    This helper never returns normally.

    Raises
    ------
    TypeError
        Always. The message lists the accepted ways of constructing an Array.
    """
    message = (
        f"Invalid arguments for {Array.__qualname__}. Expected either a "
        "PyTorch tensor, a NumPy ndarray, or explicit"
        " dtype/shape/stype/data values."
    )
    raise TypeError(message)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class Array(InflatableObject):
    """Array type.

    A dataclass containing serialized data from an array-like or tensor-like object
    along with metadata about it. The class can be initialized in one of three ways:

    1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
    2. By providing a NumPy ndarray (via the `ndarray` argument).
    3. By providing a PyTorch tensor (via the `torch_tensor` argument).

    In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
    derived from the input. In scenario (1), these fields must be specified manually.

    Parameters
    ----------
    dtype : Optional[str] (default: None)
        A string representing the data type of the serialized object (e.g. `"float32"`).
        Only required if you are not passing in a ndarray or a tensor.

    shape : Optional[tuple[int, ...]] (default: None)
        A tuple representing the shape of the unserialized array-like object. Only
        required if you are not passing in a ndarray or a tensor. Must be a
        ``tuple`` of ``int``; other sequence types such as ``list`` are rejected.

    stype : Optional[str] (default: None)
        A string indicating the serialization mechanism used to generate the bytes in
        `data` from an array-like or tensor-like object. Only required if you are not
        passing in a ndarray or a tensor.

    data : Optional[bytes] (default: None)
        A buffer of bytes containing the data. Only required if you are not passing in
        a ndarray or a tensor.

    ndarray : Optional[NDArray] (default: None)
        A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
        fields are derived automatically from it.

    torch_tensor : Optional[torch.Tensor] (default: None)
        A PyTorch tensor. If provided, it will be **detached and moved to CPU**
        before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
        will be derived automatically from it.

    Raises
    ------
    TypeError
        If the arguments do not match exactly one of the three forms above.

    Examples
    --------
    Initializing by specifying all fields directly::

        arr1 = Array(
            dtype="float32",
            shape=(3, 3),
            stype="numpy.ndarray",
            data=b"serialized_data...",
        )

    Initializing with a NumPy ndarray::

        import numpy as np
        arr2 = Array(np.random.randn(3, 3))

    Initializing with a PyTorch tensor::

        import torch
        arr3 = Array(torch.randn(3, 3))
    """

    dtype: str
    shape: tuple[int, ...]
    stype: str
    data: bytes

    @overload
    def __init__(  # noqa: E704
        self, dtype: str, shape: tuple[int, ...], stype: str, data: bytes
    ) -> None: ...

    @overload
    def __init__(self, ndarray: NDArray) -> None: ...  # noqa: E704

    @overload
    def __init__(self, torch_tensor: torch.Tensor) -> None: ...  # noqa: E704

    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
        self,
        *args: Any,
        dtype: str | None = None,
        shape: tuple[int, ...] | None = None,
        stype: str | None = None,
        data: bytes | None = None,
        ndarray: NDArray | None = None,
        torch_tensor: torch.Tensor | None = None,
    ) -> None:
        # Supported initialization formats:
        #   1. Array(dtype: str, shape: tuple[int, ...], stype: str, data: bytes)
        #   2. Array(ndarray: NDArray)
        #   3. Array(torch_tensor: torch.Tensor)
        # Positional and keyword arguments are merged into four slots
        # (dtype, shape, stype, data); an ndarray/tensor always uses slot 0.

        # The "direct" form has four fields, so more positional args is invalid.
        if len(args) > 4:
            _raise_array_init_error()
        all_args: list[Any] = [None] * 4
        for i, arg in enumerate(args):
            all_args[i] = arg
        # Form selected by keyword arguments; stays None when only positional
        # arguments are given (the form is then inferred from the values).
        init_method: str | None = None

        def _try_set_arg(index: int, arg: Any, method: str) -> None:
            """Place a keyword argument into ``all_args[index]``.

            ``None`` is ignored. Raises via ``_raise_array_init_error`` if the
            slot is already filled or if ``method`` conflicts with the form
            chosen by an earlier keyword argument.
            """
            nonlocal init_method
            if arg is None:
                return
            if all_args[index] is not None:
                _raise_array_init_error()
            if init_method is not None and init_method != method:
                _raise_array_init_error()
            if init_method is None:
                init_method = method
            all_args[index] = arg

        # Merge keyword arguments into the slots
        _try_set_arg(0, dtype, "direct")
        _try_set_arg(1, shape, "direct")
        _try_set_arg(2, stype, "direct")
        _try_set_arg(3, data, "direct")
        _try_set_arg(0, ndarray, "ndarray")
        _try_set_arg(0, torch_tensor, "torch_tensor")

        # Keep only the slots that were actually provided
        all_args = [arg for arg in all_args if arg is not None]

        # Handle direct field initialization
        if not init_method or init_method == "direct":
            if (
                len(all_args) == 4  # pylint: disable=too-many-boolean-expressions
                and isinstance(all_args[0], str)
                and isinstance(all_args[1], tuple)
                and all(isinstance(i, int) for i in all_args[1])
                and isinstance(all_args[2], str)
                and isinstance(all_args[3], bytes)
            ):
                self.dtype, self.shape, self.stype, self.data = all_args
                return

        # Handle NumPy array
        if not init_method or init_method == "ndarray":
            if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
                self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
                return

        # Handle PyTorch tensor (only checkable if torch is already imported)
        if not init_method or init_method == "torch_tensor":
            if (
                len(all_args) == 1
                and "torch" in sys.modules
                and isinstance(all_args[0], sys.modules["torch"].Tensor)
            ):
                self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
                return

        _raise_array_init_error()

    @classmethod
    def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
        """Create Array from NumPy ndarray.

        The array is serialized with ``np.save`` (``allow_pickle=False``) and
        tagged with ``stype=SType.NUMPY``.
        """
        assert isinstance(
            ndarray, np.ndarray
        ), f"Expected NumPy ndarray, got {type(ndarray)}"
        buffer = BytesIO()
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.save(buffer, ndarray, allow_pickle=False)
        data = buffer.getvalue()
        return cls(
            dtype=str(ndarray.dtype),
            shape=tuple(ndarray.shape),
            stype=SType.NUMPY,
            data=data,
        )

    @classmethod
    def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
        """Create Array from PyTorch tensor.

        The tensor is detached and moved to CPU, then converted via NumPy.

        Raises
        ------
        RuntimeError
            If ``torch`` has not been imported in this process.
        """
        if not (torch := sys.modules.get("torch")):
            raise RuntimeError(
                f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
            )

        assert isinstance(
            tensor, torch.Tensor
        ), f"Expected PyTorch Tensor, got {type(tensor)}"
        return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())

    def numpy(self) -> NDArray:
        """Return the array as a NumPy array.

        Raises
        ------
        TypeError
            If ``stype`` is not ``SType.NUMPY``.
        """
        if self.stype != SType.NUMPY:
            raise TypeError(
                f"Unsupported serialization type for numpy conversion: '{self.stype}'"
            )
        bytes_io = BytesIO(self.data)
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
        ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
        return cast(NDArray, ndarray_deserialized)

    @property
    def children(self) -> dict[str, InflatableObject]:
        """Return a dictionary of ArrayChunks with their Object IDs as keys.

        Repeated chunk IDs collapse into one entry; key order is the order in
        which each ID first appears in ``slice_array()``.
        """
        return dict(self.slice_array())

    def slice_array(self) -> list[tuple[str, InflatableObject]]:
        """Slice Array data and construct a list of ArrayChunks.

        Returns one ``(object_id, ArrayChunk)`` pair per slice of at most
        ``MAX_ARRAY_CHUNK_SIZE`` bytes, in data order. The result is cached
        until ``data`` is reassigned.
        """
        # Return cached chunks if they exist
        if "_chunks" in self.__dict__:
            return cast(list[tuple[str, InflatableObject]], self.__dict__["_chunks"])

        # This list may contain repeated IDs (identical chunks), so it is
        # kept separate from the de-duplicated `children` mapping.
        chunks: list[tuple[str, InflatableObject]] = []
        # memoryview allows for zero-copy slicing
        data_view = memoryview(self.data)
        for start in range(0, len(data_view), MAX_ARRAY_CHUNK_SIZE):
            end = min(start + MAX_ARRAY_CHUNK_SIZE, len(data_view))
            ac = ArrayChunk(data_view[start:end])
            chunks.append((ac.object_id, ac))

        # Cache the chunks for future use
        self.__dict__["_chunks"] = chunks
        return chunks

    def deflate(self) -> bytes:
        """Deflate the Array.

        The body is a UTF-8 JSON object with ``dtype``, ``shape``, ``stype``,
        and ``arraychunk_ids``. The last is the ordered list of positions,
        within ``children``, of every chunk making up ``data``.
        """
        # Record every chunk occurrence, even repeated ones: chunks can carry
        # exactly the same data, for example when the array has only zeros.
        children_list = self.slice_array()
        # Store a small integer position into the unique children (carried in
        # the object head) instead of each full object_id. A dict lookup keeps
        # this linear in the number of chunks.
        position_of = {ch_id: idx for idx, ch_id in enumerate(self.children)}
        arraychunk_ids = [position_of[ch_id] for ch_id, _ in children_list]

        # The deflated Array carries everything but the data.
        # The `arraychunk_ids` will be used during Array inflation
        # to rematerialize the data from ArrayChunk objects.
        array_metadata: dict[str, str | tuple[int, ...] | list[int]] = {
            "dtype": self.dtype,
            "shape": self.shape,
            "stype": self.stype,
            "arraychunk_ids": arraychunk_ids,
        }

        # Serialize metadata dict
        obj_body = json.dumps(array_metadata).encode("utf-8")
        return add_header_to_object_body(object_body=obj_body, obj=self)

    @classmethod
    def inflate(
        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
    ) -> Array:
        """Inflate an Array from bytes.

        Parameters
        ----------
        object_content : bytes
            The deflated object content of the Array.

        children : Optional[dict[str, InflatableObject]] (default: None)
            Mapping from object ID to ``ArrayChunk`` for every unique chunk
            referenced by ``object_content``. Its keys must exactly match the
            referenced chunk IDs. ``None`` is treated as an empty mapping, which
            is only valid for an Array whose ``data`` is empty.

        Returns
        -------
        Array
            The inflated Array.

        Raises
        ------
        ValueError
            If the IDs in ``children`` differ from the chunk IDs referenced by
            ``object_content``.
        """
        if children is None:
            children = {}

        obj_body = get_object_body(object_content, cls)

        # Extract children IDs from head
        children_ids = get_object_children_ids_from_object_content(object_content)
        # Decode the Array body
        array_metadata: dict[str, str | tuple[int, ...] | list[int]] = json.loads(
            obj_body.decode(encoding="utf-8")
        )

        # Convert the stored positions back to chunk IDs (in data order)
        chunk_ids_indices = cast(list[int], array_metadata["arraychunk_ids"])
        chunk_ids = [children_ids[i] for i in chunk_ids_indices]
        # Verify children ids in body match those passed for inflation
        unique_arrayschunks = set(chunk_ids)
        children_obj_ids = set(children.keys())
        if unique_arrayschunks != children_obj_ids:
            raise ValueError(
                "Unexpected set of `children`. "
                f"Expected {unique_arrayschunks} but got {children_obj_ids}."
            )

        # Reassemble the data from the chunks in a single copy
        data = b"".join(cast(ArrayChunk, children[ch_id]).data for ch_id in chunk_ids)

        return cls(
            dtype=cast(str, array_metadata["dtype"]),
            shape=cast(tuple[int, ...], tuple(array_metadata["shape"])),
            stype=cast(str, array_metadata["stype"]),
            data=data,
        )

    @property
    def object_id(self) -> str:
        """Get object ID and reset the dirty flag."""
        ret = super().object_id
        self.is_dirty = False  # Reset dirty flag
        return ret

    @property
    def is_dirty(self) -> bool:
        """Check if the object is dirty after the last deflation.

        Defaults to ``True`` for an object whose ID has never been computed.
        """
        if "_is_dirty" not in self.__dict__:
            self.__dict__["_is_dirty"] = True
        return cast(bool, self.__dict__["_is_dirty"])

    @is_dirty.setter
    def is_dirty(self, value: bool) -> None:
        """Set the dirty flag."""
        self.__dict__["_is_dirty"] = value

    def __setattr__(self, name: str, value: Any) -> None:
        """Set attribute with special handling for dirty state.

        Assigning any of the four dataclass fields marks the object dirty and
        drops the cached object ID; assigning ``data`` also drops the cached
        chunks, since they depend only on ``data``.
        """
        if name in ("dtype", "shape", "stype", "data"):
            # Mark as dirty if any of the main attributes are set
            self.is_dirty = True
            # Clear cached object ID
            self.__dict__.pop("_object_id", None)
            # Clear cached chunks if data is set
            if name == "data":
                self.__dict__.pop("_chunks", None)
        super().__setattr__(name, value)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""ArrayChunk."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
|
|
22
|
+
from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class ArrayChunk(InflatableObject):
    """ArrayChunk type.

    Holds one contiguous slice of an ``Array``'s serialized bytes as a
    ``memoryview``. It is a leaf object: it never has children.
    """

    data: memoryview

    def deflate(self) -> bytes:
        """Deflate the ArrayChunk: the object header followed by the raw bytes."""
        chunk_body = self.data
        return add_header_to_object_body(object_body=chunk_body, obj=self)

    @classmethod
    def inflate(
        cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
    ) -> ArrayChunk:
        """Inflate an ArrayChunk from bytes.

        Parameters
        ----------
        object_content : bytes
            The deflated object content of the ArrayChunk.

        children : Optional[dict[str, InflatableObject]] (default: None)
            Must be ``None`` or empty. ``ArrayChunk`` does not support child
            objects; a non-empty mapping raises a ``ValueError``.

        Returns
        -------
        ArrayChunk
            The inflated ArrayChunk.
        """
        has_children = bool(children)
        if has_children:
            raise ValueError("`ArrayChunk` objects do not have children.")

        chunk_bytes = get_object_body(object_content, cls)
        view = memoryview(chunk_bytes)
        return cls(data=view)
|