flwr 1.15.2__py3-none-any.whl → 1.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/client/client_app.py +147 -36
- flwr/client/clientapp/app.py +4 -0
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/constant.py +16 -0
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/message.py +18 -7
- flwr/common/object_ref.py +0 -10
- flwr/common/record/conversion_utils.py +8 -17
- flwr/common/record/parametersrecord.py +151 -16
- flwr/common/record/recordset.py +95 -88
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/serde.py +8 -126
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +36 -0
- flwr/server/app.py +18 -2
- flwr/server/compat/app.py +4 -1
- flwr/server/compat/app_utils.py +10 -2
- flwr/server/compat/driver_client_proxy.py +2 -2
- flwr/server/driver/driver.py +1 -1
- flwr/server/driver/grpc_driver.py +10 -1
- flwr/server/driver/inmemory_driver.py +17 -20
- flwr/server/run_serverapp.py +2 -13
- flwr/server/server_app.py +93 -20
- flwr/server/superlink/driver/serverappio_servicer.py +25 -27
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +32 -35
- flwr/server/superlink/linkstate/in_memory_linkstate.py +140 -126
- flwr/server/superlink/linkstate/linkstate.py +47 -60
- flwr/server/superlink/linkstate/sqlite_linkstate.py +210 -276
- flwr/server/superlink/linkstate/utils.py +91 -119
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +4 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
- flwr/superexec/app.py +0 -14
- flwr/superexec/exec_servicer.py +4 -4
- flwr/superexec/exec_user_auth_interceptor.py +5 -3
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/METADATA +4 -4
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/RECORD +63 -66
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/WHEEL +0 -0
flwr/common/constant.py
CHANGED
|
@@ -212,3 +212,19 @@ class AuthType:
|
|
|
212
212
|
def __new__(cls) -> AuthType:
|
|
213
213
|
"""Prevent instantiation."""
|
|
214
214
|
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class EventLogWriterType:
|
|
218
|
+
"""Event log writer types."""
|
|
219
|
+
|
|
220
|
+
FALSE = "false"
|
|
221
|
+
STDOUT = "stdout"
|
|
222
|
+
|
|
223
|
+
def __new__(cls) -> EventLogWriterType:
|
|
224
|
+
"""Prevent instantiation."""
|
|
225
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
|
226
|
+
|
|
227
|
+
@classmethod
|
|
228
|
+
def choices(cls) -> list[str]:
|
|
229
|
+
"""Return a list of available log writer choices."""
|
|
230
|
+
return [cls.FALSE, cls.STDOUT]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Event log plugin components."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from .event_log_plugin import EventLogWriterPlugin as EventLogWriterPlugin
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"EventLogWriterPlugin",
|
|
22
|
+
]
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Abstract class for Flower Event Log Writer Plugin."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
from typing import Optional, Union
|
|
20
|
+
|
|
21
|
+
import grpc
|
|
22
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
23
|
+
|
|
24
|
+
from flwr.common.typing import LogEntry, UserInfo
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class EventLogWriterPlugin(ABC):
|
|
28
|
+
"""Abstract Flower Event Log Writer Plugin class for ExecServicer."""
|
|
29
|
+
|
|
30
|
+
@abstractmethod
|
|
31
|
+
def __init__(self) -> None:
|
|
32
|
+
"""Abstract constructor."""
|
|
33
|
+
|
|
34
|
+
@abstractmethod
|
|
35
|
+
def compose_log_before_event( # pylint: disable=too-many-arguments
|
|
36
|
+
self,
|
|
37
|
+
request: GrpcMessage,
|
|
38
|
+
context: grpc.ServicerContext,
|
|
39
|
+
user_info: UserInfo,
|
|
40
|
+
method_name: str,
|
|
41
|
+
) -> LogEntry:
|
|
42
|
+
"""Compose pre-event log entry from the provided request and context."""
|
|
43
|
+
|
|
44
|
+
@abstractmethod
|
|
45
|
+
def compose_log_after_event( # pylint: disable=too-many-arguments,R0917
|
|
46
|
+
self,
|
|
47
|
+
request: GrpcMessage,
|
|
48
|
+
context: grpc.ServicerContext,
|
|
49
|
+
user_info: UserInfo,
|
|
50
|
+
method_name: str,
|
|
51
|
+
response: Optional[Union[GrpcMessage, Exception]],
|
|
52
|
+
) -> LogEntry:
|
|
53
|
+
"""Compose post-event log entry from the provided response and context."""
|
|
54
|
+
|
|
55
|
+
@abstractmethod
|
|
56
|
+
def write_log(
|
|
57
|
+
self,
|
|
58
|
+
log_entry: LogEntry,
|
|
59
|
+
) -> None:
|
|
60
|
+
"""Write the event log to the specified data sink."""
|
flwr/common/grpc.py
CHANGED
|
@@ -27,7 +27,7 @@ import grpc
|
|
|
27
27
|
from .address import is_port_in_use
|
|
28
28
|
from .logger import log
|
|
29
29
|
|
|
30
|
-
GRPC_MAX_MESSAGE_LENGTH: int =
|
|
30
|
+
GRPC_MAX_MESSAGE_LENGTH: int = 2_147_483_647 # == 2048 * 1024 * 1024 -1 (2GB)
|
|
31
31
|
|
|
32
32
|
INVALID_CERTIFICATES_ERR_MSG = """
|
|
33
33
|
When setting any of root_certificate, certificate, or private_key,
|
flwr/common/message.py
CHANGED
|
@@ -25,7 +25,7 @@ from .constant import MESSAGE_TTL_TOLERANCE
|
|
|
25
25
|
from .logger import log
|
|
26
26
|
from .record import RecordSet
|
|
27
27
|
|
|
28
|
-
DEFAULT_TTL =
|
|
28
|
+
DEFAULT_TTL = 43200 # This is 12 hours
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
class Metadata: # pylint: disable=too-many-instance-attributes
|
|
@@ -126,6 +126,16 @@ class Metadata: # pylint: disable=too-many-instance-attributes
|
|
|
126
126
|
"""Set creation timestamp for this message."""
|
|
127
127
|
self.__dict__["_created_at"] = value
|
|
128
128
|
|
|
129
|
+
@property
|
|
130
|
+
def delivered_at(self) -> str:
|
|
131
|
+
"""Unix timestamp when the message was delivered."""
|
|
132
|
+
return cast(str, self.__dict__["_delivered_at"])
|
|
133
|
+
|
|
134
|
+
@delivered_at.setter
|
|
135
|
+
def delivered_at(self, value: str) -> None:
|
|
136
|
+
"""Set delivery timestamp of this message."""
|
|
137
|
+
self.__dict__["_delivered_at"] = value
|
|
138
|
+
|
|
129
139
|
@property
|
|
130
140
|
def ttl(self) -> float:
|
|
131
141
|
"""Time-to-live for this message."""
|
|
@@ -223,6 +233,7 @@ class Message:
|
|
|
223
233
|
raise ValueError("Either `content` or `error` must be set, but not both.")
|
|
224
234
|
|
|
225
235
|
metadata.created_at = time.time() # Set the message creation timestamp
|
|
236
|
+
metadata.delivered_at = ""
|
|
226
237
|
var_dict = {
|
|
227
238
|
"_metadata": metadata,
|
|
228
239
|
"_content": content,
|
|
@@ -310,7 +321,7 @@ class Message:
|
|
|
310
321
|
)
|
|
311
322
|
message.metadata.ttl = ttl
|
|
312
323
|
|
|
313
|
-
self.
|
|
324
|
+
self._limit_message_res_ttl(message)
|
|
314
325
|
|
|
315
326
|
return message
|
|
316
327
|
|
|
@@ -353,7 +364,7 @@ class Message:
|
|
|
353
364
|
)
|
|
354
365
|
message.metadata.ttl = ttl
|
|
355
366
|
|
|
356
|
-
self.
|
|
367
|
+
self._limit_message_res_ttl(message)
|
|
357
368
|
|
|
358
369
|
return message
|
|
359
370
|
|
|
@@ -368,14 +379,14 @@ class Message:
|
|
|
368
379
|
)
|
|
369
380
|
return f"{self.__class__.__qualname__}({view})"
|
|
370
381
|
|
|
371
|
-
def
|
|
372
|
-
"""Limit the
|
|
373
|
-
replies to.
|
|
382
|
+
def _limit_message_res_ttl(self, message: Message) -> None:
|
|
383
|
+
"""Limit the TTL of the provided Message to not exceed the expiration time of
|
|
384
|
+
this Message it replies to.
|
|
374
385
|
|
|
375
386
|
Parameters
|
|
376
387
|
----------
|
|
377
388
|
message : Message
|
|
378
|
-
The
|
|
389
|
+
The reply Message to limit the TTL for.
|
|
379
390
|
"""
|
|
380
391
|
# Calculate the maximum allowed TTL
|
|
381
392
|
max_allowed_ttl = (
|
flwr/common/object_ref.py
CHANGED
|
@@ -170,7 +170,6 @@ def load_app( # pylint: disable= too-many-branches
|
|
|
170
170
|
module = importlib.import_module(module_str)
|
|
171
171
|
else:
|
|
172
172
|
module = sys.modules[module_str]
|
|
173
|
-
_reload_modules(project_dir)
|
|
174
173
|
|
|
175
174
|
except ModuleNotFoundError as err:
|
|
176
175
|
raise error_type(
|
|
@@ -200,15 +199,6 @@ def _unload_modules(project_dir: Path) -> None:
|
|
|
200
199
|
del sys.modules[name]
|
|
201
200
|
|
|
202
201
|
|
|
203
|
-
def _reload_modules(project_dir: Path) -> None:
|
|
204
|
-
"""Reload modules from the project directory."""
|
|
205
|
-
dir_str = str(project_dir.absolute())
|
|
206
|
-
for m in list(sys.modules.values()):
|
|
207
|
-
path: Optional[str] = getattr(m, "__file__", None)
|
|
208
|
-
if path is not None and path.startswith(dir_str):
|
|
209
|
-
importlib.reload(m)
|
|
210
|
-
|
|
211
|
-
|
|
212
202
|
def _set_sys_path(directory: Optional[Union[str, Path]]) -> None:
|
|
213
203
|
"""Set the system path."""
|
|
214
204
|
if directory is None:
|
|
@@ -15,26 +15,17 @@
|
|
|
15
15
|
"""Conversion utility functions for Records."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from
|
|
19
|
-
|
|
20
|
-
import numpy as np
|
|
21
|
-
|
|
22
|
-
from ..constant import SType
|
|
18
|
+
from ..logger import warn_deprecated_feature
|
|
23
19
|
from ..typing import NDArray
|
|
24
20
|
from .parametersrecord import Array
|
|
25
21
|
|
|
22
|
+
WARN_DEPRECATED_MESSAGE = (
|
|
23
|
+
"`array_from_numpy` is deprecated. Instead, use the `Array(ndarray)` class "
|
|
24
|
+
"directly or `Array.from_numpy_ndarray(ndarray)`."
|
|
25
|
+
)
|
|
26
|
+
|
|
26
27
|
|
|
27
28
|
def array_from_numpy(ndarray: NDArray) -> Array:
|
|
28
29
|
"""Create Array from NumPy ndarray."""
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
# Reason: loading pickled data can execute arbitrary code
|
|
32
|
-
# Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
|
|
33
|
-
np.save(buffer, ndarray, allow_pickle=False)
|
|
34
|
-
data = buffer.getvalue()
|
|
35
|
-
return Array(
|
|
36
|
-
dtype=str(ndarray.dtype),
|
|
37
|
-
shape=list(ndarray.shape),
|
|
38
|
-
stype=SType.NUMPY,
|
|
39
|
-
data=data,
|
|
40
|
-
)
|
|
30
|
+
warn_deprecated_feature(WARN_DEPRECATED_MESSAGE)
|
|
31
|
+
return Array.from_numpy_ndarray(ndarray)
|
|
@@ -15,10 +15,12 @@
|
|
|
15
15
|
"""ParametersRecord and Array."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
18
20
|
from collections import OrderedDict
|
|
19
21
|
from dataclasses import dataclass
|
|
20
22
|
from io import BytesIO
|
|
21
|
-
from typing import
|
|
23
|
+
from typing import Any, cast, overload
|
|
22
24
|
|
|
23
25
|
import numpy as np
|
|
24
26
|
|
|
@@ -27,29 +29,64 @@ from ..typing import NDArray
|
|
|
27
29
|
from .typeddict import TypedDict
|
|
28
30
|
|
|
29
31
|
|
|
32
|
+
def _raise_array_init_error() -> None:
|
|
33
|
+
raise TypeError(
|
|
34
|
+
f"Invalid arguments for {Array.__qualname__}. Expected either a "
|
|
35
|
+
"NumPy ndarray, or explicit dtype/shape/stype/data values."
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
30
39
|
@dataclass
|
|
31
40
|
class Array:
|
|
32
41
|
"""Array type.
|
|
33
42
|
|
|
34
43
|
A dataclass containing serialized data from an array-like or tensor-like object
|
|
35
|
-
along with
|
|
44
|
+
along with metadata about it. The class can be initialized in one of two ways:
|
|
45
|
+
|
|
46
|
+
1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
|
|
47
|
+
2. By providing a NumPy ndarray (via the `ndarray` argument).
|
|
48
|
+
|
|
49
|
+
In scenario (2), the `dtype`, `shape`, `stype`, and `data` are automatically
|
|
50
|
+
derived from the input. In scenario (1), these fields must be specified manually.
|
|
36
51
|
|
|
37
52
|
Parameters
|
|
38
53
|
----------
|
|
39
|
-
dtype : str
|
|
40
|
-
A string representing the data type of the
|
|
54
|
+
dtype : Optional[str] (default: None)
|
|
55
|
+
A string representing the data type of the serialized object (e.g. `"float32"`).
|
|
56
|
+
Only required if you are not passing in a ndarray.
|
|
41
57
|
|
|
42
|
-
shape :
|
|
43
|
-
A list representing the shape of the unserialized array-like object.
|
|
44
|
-
|
|
45
|
-
as a metadata field.
|
|
58
|
+
shape : Optional[list[int]] (default: None)
|
|
59
|
+
A list representing the shape of the unserialized array-like object. Only
|
|
60
|
+
required if you are not passing in a ndarray.
|
|
46
61
|
|
|
47
|
-
stype : str
|
|
48
|
-
A string indicating the
|
|
49
|
-
|
|
62
|
+
stype : Optional[str] (default: None)
|
|
63
|
+
A string indicating the serialization mechanism used to generate the bytes in
|
|
64
|
+
`data` from an array-like or tensor-like object. Only required if you are not
|
|
65
|
+
passing in a ndarray.
|
|
50
66
|
|
|
51
|
-
data: bytes
|
|
52
|
-
A buffer of bytes containing the data.
|
|
67
|
+
data : Optional[bytes] (default: None)
|
|
68
|
+
A buffer of bytes containing the data. Only required if you are not passing in
|
|
69
|
+
a ndarray.
|
|
70
|
+
|
|
71
|
+
ndarray : Optional[NDArray] (default: None)
|
|
72
|
+
A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
|
|
73
|
+
fields are derived automatically from it.
|
|
74
|
+
|
|
75
|
+
Examples
|
|
76
|
+
--------
|
|
77
|
+
Initializing by specifying all fields directly:
|
|
78
|
+
|
|
79
|
+
>>> arr1 = Array(
|
|
80
|
+
>>> dtype="float32",
|
|
81
|
+
>>> shape=[3, 3],
|
|
82
|
+
>>> stype="numpy.ndarray",
|
|
83
|
+
>>> data=b"serialized_data...",
|
|
84
|
+
>>> )
|
|
85
|
+
|
|
86
|
+
Initializing with a NumPy ndarray:
|
|
87
|
+
|
|
88
|
+
>>> import numpy as np
|
|
89
|
+
>>> arr2 = Array(np.random.randn(3, 3))
|
|
53
90
|
"""
|
|
54
91
|
|
|
55
92
|
dtype: str
|
|
@@ -57,6 +94,105 @@ class Array:
|
|
|
57
94
|
stype: str
|
|
58
95
|
data: bytes
|
|
59
96
|
|
|
97
|
+
@overload
|
|
98
|
+
def __init__( # noqa: E704
|
|
99
|
+
self, dtype: str, shape: list[int], stype: str, data: bytes
|
|
100
|
+
) -> None: ...
|
|
101
|
+
|
|
102
|
+
@overload
|
|
103
|
+
def __init__(self, ndarray: NDArray) -> None: ... # noqa: E704
|
|
104
|
+
|
|
105
|
+
def __init__( # pylint: disable=too-many-arguments, too-many-locals
|
|
106
|
+
self,
|
|
107
|
+
*args: Any,
|
|
108
|
+
dtype: str | None = None,
|
|
109
|
+
shape: list[int] | None = None,
|
|
110
|
+
stype: str | None = None,
|
|
111
|
+
data: bytes | None = None,
|
|
112
|
+
ndarray: NDArray | None = None,
|
|
113
|
+
) -> None:
|
|
114
|
+
# Determine the initialization method and validate input arguments.
|
|
115
|
+
# Support two initialization formats:
|
|
116
|
+
# 1. Array(dtype: str, shape: list[int], stype: str, data: bytes)
|
|
117
|
+
# 2. Array(ndarray: NDArray)
|
|
118
|
+
|
|
119
|
+
# Initialize all arguments
|
|
120
|
+
# If more than 4 positional arguments are provided, raise an error.
|
|
121
|
+
if len(args) > 4:
|
|
122
|
+
_raise_array_init_error()
|
|
123
|
+
all_args = [None] * 4
|
|
124
|
+
for i, arg in enumerate(args):
|
|
125
|
+
all_args[i] = arg
|
|
126
|
+
init_method: str | None = None # Track which init method is being used
|
|
127
|
+
|
|
128
|
+
# Try to assign a value to all_args[index] if it's not already set.
|
|
129
|
+
# If an initialization method is provided, update init_method.
|
|
130
|
+
def _try_set_arg(index: int, arg: Any, method: str) -> None:
|
|
131
|
+
# Skip if arg is None
|
|
132
|
+
if arg is None:
|
|
133
|
+
return
|
|
134
|
+
# Raise an error if all_args[index] is already set
|
|
135
|
+
if all_args[index] is not None:
|
|
136
|
+
_raise_array_init_error()
|
|
137
|
+
# Raise an error if a different initialization method is already set
|
|
138
|
+
nonlocal init_method
|
|
139
|
+
if init_method is not None and init_method != method:
|
|
140
|
+
_raise_array_init_error()
|
|
141
|
+
# Set init_method and all_args[index]
|
|
142
|
+
if init_method is None:
|
|
143
|
+
init_method = method
|
|
144
|
+
all_args[index] = arg
|
|
145
|
+
|
|
146
|
+
# Try to set keyword arguments in all_args
|
|
147
|
+
_try_set_arg(0, dtype, "direct")
|
|
148
|
+
_try_set_arg(1, shape, "direct")
|
|
149
|
+
_try_set_arg(2, stype, "direct")
|
|
150
|
+
_try_set_arg(3, data, "direct")
|
|
151
|
+
_try_set_arg(0, ndarray, "ndarray")
|
|
152
|
+
|
|
153
|
+
# Check if all arguments are correctly set
|
|
154
|
+
all_args = [arg for arg in all_args if arg is not None]
|
|
155
|
+
|
|
156
|
+
# Handle direct field initialization
|
|
157
|
+
if not init_method or init_method == "direct":
|
|
158
|
+
if (
|
|
159
|
+
len(all_args) == 4 # pylint: disable=too-many-boolean-expressions
|
|
160
|
+
and isinstance(all_args[0], str)
|
|
161
|
+
and isinstance(all_args[1], list)
|
|
162
|
+
and all(isinstance(i, int) for i in all_args[1])
|
|
163
|
+
and isinstance(all_args[2], str)
|
|
164
|
+
and isinstance(all_args[3], bytes)
|
|
165
|
+
):
|
|
166
|
+
self.dtype, self.shape, self.stype, self.data = all_args
|
|
167
|
+
return
|
|
168
|
+
|
|
169
|
+
# Handle NumPy array
|
|
170
|
+
if not init_method or init_method == "ndarray":
|
|
171
|
+
if len(all_args) == 1 and isinstance(all_args[0], np.ndarray):
|
|
172
|
+
self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
|
|
173
|
+
return
|
|
174
|
+
|
|
175
|
+
_raise_array_init_error()
|
|
176
|
+
|
|
177
|
+
@classmethod
|
|
178
|
+
def from_numpy_ndarray(cls, ndarray: NDArray) -> Array:
|
|
179
|
+
"""Create Array from NumPy ndarray."""
|
|
180
|
+
assert isinstance(
|
|
181
|
+
ndarray, np.ndarray
|
|
182
|
+
), f"Expected NumPy ndarray, got {type(ndarray)}"
|
|
183
|
+
buffer = BytesIO()
|
|
184
|
+
# WARNING: NEVER set allow_pickle to true.
|
|
185
|
+
# Reason: loading pickled data can execute arbitrary code
|
|
186
|
+
# Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
|
|
187
|
+
np.save(buffer, ndarray, allow_pickle=False)
|
|
188
|
+
data = buffer.getvalue()
|
|
189
|
+
return Array(
|
|
190
|
+
dtype=str(ndarray.dtype),
|
|
191
|
+
shape=list(ndarray.shape),
|
|
192
|
+
stype=SType.NUMPY,
|
|
193
|
+
data=data,
|
|
194
|
+
)
|
|
195
|
+
|
|
60
196
|
def numpy(self) -> NDArray:
|
|
61
197
|
"""Return the array as a NumPy array."""
|
|
62
198
|
if self.stype != SType.NUMPY:
|
|
@@ -117,7 +253,6 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
|
117
253
|
|
|
118
254
|
>>> import numpy as np
|
|
119
255
|
>>> from flwr.common import ParametersRecord
|
|
120
|
-
>>> from flwr.common import array_from_numpy
|
|
121
256
|
>>>
|
|
122
257
|
>>> # Let's create a simple NumPy array
|
|
123
258
|
>>> arr_np = np.random.randn(3, 3)
|
|
@@ -128,7 +263,7 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
|
128
263
|
>>> [-0.10758364, 1.97619858, -0.37120501]])
|
|
129
264
|
>>>
|
|
130
265
|
>>> # Let's create an Array out of it
|
|
131
|
-
>>> arr =
|
|
266
|
+
>>> arr = Array(arr_np)
|
|
132
267
|
>>>
|
|
133
268
|
>>> # If we print it you'll see (note the binary data)
|
|
134
269
|
>>> Array(dtype='float64', shape=[3,3], stype='numpy.ndarray', data=b'@\x99\x18...')
|
|
@@ -176,7 +311,7 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
|
176
311
|
|
|
177
312
|
def __init__(
|
|
178
313
|
self,
|
|
179
|
-
array_dict:
|
|
314
|
+
array_dict: OrderedDict[str, Array] | None = None,
|
|
180
315
|
keep_input: bool = False,
|
|
181
316
|
) -> None:
|
|
182
317
|
super().__init__(_check_key, _check_value)
|