flwr-nightly 1.17.0.dev20250317__py3-none-any.whl → 1.17.0.dev20250319__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/common/constant.py +5 -0
- flwr/common/logger.py +2 -2
- flwr/common/record/parametersrecord.py +336 -92
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +1 -1
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +9 -9
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +44 -15
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +12 -20
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +6 -14
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +38 -12
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/linkstate/in_memory_linkstate.py +28 -3
- flwr/server/superlink/linkstate/sqlite_linkstate.py +40 -2
- flwr/server/superlink/linkstate/utils.py +67 -10
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +1 -1
- flwr/server/typing.py +3 -3
- flwr/server/workflow/default_workflows.py +17 -19
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +15 -15
- flwr/simulation/run_simulation.py +10 -10
- {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/RECORD +31 -31
- {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/entry_points.txt +0 -0
flwr/common/constant.py
CHANGED
@@ -61,6 +61,7 @@ PING_CALL_TIMEOUT = 5
|
|
61
61
|
PING_BASE_MULTIPLIER = 0.8
|
62
62
|
PING_RANDOM_RANGE = (-0.1, 0.1)
|
63
63
|
PING_MAX_INTERVAL = 1e300
|
64
|
+
PING_PATIENCE = 2
|
64
65
|
|
65
66
|
# IDs
|
66
67
|
RUN_ID_NUM_BYTES = 8
|
@@ -120,6 +121,9 @@ TIMESTAMP_HEADER = "flwr-timestamp"
|
|
120
121
|
TIMESTAMP_TOLERANCE = 10 # General tolerance for timestamp verification
|
121
122
|
SYSTEM_TIME_TOLERANCE = 5 # Allowance for system time drift
|
122
123
|
|
124
|
+
# Constants for ParametersRecord
|
125
|
+
GC_THRESHOLD = 200_000_000 # 200 MB
|
126
|
+
|
123
127
|
|
124
128
|
class MessageType:
|
125
129
|
"""Message type."""
|
@@ -163,6 +167,7 @@ class ErrorCode:
|
|
163
167
|
CLIENT_APP_RAISED_EXCEPTION = 2
|
164
168
|
MESSAGE_UNAVAILABLE = 3
|
165
169
|
REPLY_MESSAGE_UNAVAILABLE = 4
|
170
|
+
NODE_UNAVAILABLE = 5
|
166
171
|
|
167
172
|
def __new__(cls) -> ErrorCode:
|
168
173
|
"""Prevent instantiation."""
|
flwr/common/logger.py
CHANGED
@@ -250,9 +250,9 @@ def warn_deprecated_feature_with_example(
|
|
250
250
|
log(
|
251
251
|
WARN,
|
252
252
|
"""FEATURE UPDATE: %s
|
253
|
-
|
253
|
+
------------------------------------------------------------
|
254
254
|
%s
|
255
|
-
|
255
|
+
------------------------------------------------------------
|
256
256
|
""",
|
257
257
|
example_message,
|
258
258
|
code_example,
|
@@ -17,22 +17,36 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
+
import gc
|
21
|
+
import sys
|
20
22
|
from collections import OrderedDict
|
21
23
|
from dataclasses import dataclass
|
22
24
|
from io import BytesIO
|
23
|
-
from typing import Any, cast, overload
|
25
|
+
from typing import TYPE_CHECKING, Any, cast, overload
|
24
26
|
|
25
27
|
import numpy as np
|
26
28
|
|
27
|
-
from ..constant import SType
|
29
|
+
from ..constant import GC_THRESHOLD, SType
|
28
30
|
from ..typing import NDArray
|
29
31
|
from .typeddict import TypedDict
|
30
32
|
|
33
|
+
if TYPE_CHECKING:
|
34
|
+
import torch
|
35
|
+
|
31
36
|
|
32
37
|
def _raise_array_init_error() -> None:
|
33
38
|
raise TypeError(
|
34
39
|
f"Invalid arguments for {Array.__qualname__}. Expected either a "
|
35
|
-
"NumPy ndarray, or explicit
|
40
|
+
"PyTorch tensor, a NumPy ndarray, or explicit"
|
41
|
+
" dtype/shape/stype/data values."
|
42
|
+
)
|
43
|
+
|
44
|
+
|
45
|
+
def _raise_parameters_record_init_error() -> None:
|
46
|
+
raise TypeError(
|
47
|
+
f"Invalid arguments for {ParametersRecord.__qualname__}. Expected either "
|
48
|
+
"a list of NumPy ndarrays, a PyTorch state_dict, or a dictionary of Arrays. "
|
49
|
+
"The `keep_input` argument is keyword-only."
|
36
50
|
)
|
37
51
|
|
38
52
|
|
@@ -41,37 +55,43 @@ class Array:
|
|
41
55
|
"""Array type.
|
42
56
|
|
43
57
|
A dataclass containing serialized data from an array-like or tensor-like object
|
44
|
-
along with metadata about it. The class can be initialized in one of
|
58
|
+
along with metadata about it. The class can be initialized in one of three ways:
|
45
59
|
|
46
60
|
1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
|
47
61
|
2. By providing a NumPy ndarray (via the `ndarray` argument).
|
62
|
+
3. By providing a PyTorch tensor (via the `torch_tensor` argument).
|
48
63
|
|
49
|
-
In
|
64
|
+
In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
|
50
65
|
derived from the input. In scenario (1), these fields must be specified manually.
|
51
66
|
|
52
67
|
Parameters
|
53
68
|
----------
|
54
69
|
dtype : Optional[str] (default: None)
|
55
70
|
A string representing the data type of the serialized object (e.g. `"float32"`).
|
56
|
-
Only required if you are not passing in a ndarray.
|
71
|
+
Only required if you are not passing in a ndarray or a tensor.
|
57
72
|
|
58
73
|
shape : Optional[list[int]] (default: None)
|
59
74
|
A list representing the shape of the unserialized array-like object. Only
|
60
|
-
required if you are not passing in a ndarray.
|
75
|
+
required if you are not passing in a ndarray or a tensor.
|
61
76
|
|
62
77
|
stype : Optional[str] (default: None)
|
63
78
|
A string indicating the serialization mechanism used to generate the bytes in
|
64
79
|
`data` from an array-like or tensor-like object. Only required if you are not
|
65
|
-
passing in a ndarray.
|
80
|
+
passing in a ndarray or a tensor.
|
66
81
|
|
67
82
|
data : Optional[bytes] (default: None)
|
68
83
|
A buffer of bytes containing the data. Only required if you are not passing in
|
69
|
-
a ndarray.
|
84
|
+
a ndarray or a tensor.
|
70
85
|
|
71
86
|
ndarray : Optional[NDArray] (default: None)
|
72
87
|
A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
|
73
88
|
fields are derived automatically from it.
|
74
89
|
|
90
|
+
torch_tensor : Optional[torch.Tensor] (default: None)
|
91
|
+
A PyTorch tensor. If provided, it will be **detached and moved to CPU**
|
92
|
+
before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
|
93
|
+
will be derived automatically from it.
|
94
|
+
|
75
95
|
Examples
|
76
96
|
--------
|
77
97
|
Initializing by specifying all fields directly:
|
@@ -87,6 +107,11 @@ class Array:
|
|
87
107
|
|
88
108
|
>>> import numpy as np
|
89
109
|
>>> arr2 = Array(np.random.randn(3, 3))
|
110
|
+
|
111
|
+
Initializing with a PyTorch tensor:
|
112
|
+
|
113
|
+
>>> import torch
|
114
|
+
>>> arr3 = Array(torch.randn(3, 3))
|
90
115
|
"""
|
91
116
|
|
92
117
|
dtype: str
|
@@ -102,6 +127,9 @@ class Array:
|
|
102
127
|
@overload
|
103
128
|
def __init__(self, ndarray: NDArray) -> None: ... # noqa: E704
|
104
129
|
|
130
|
+
@overload
|
131
|
+
def __init__(self, torch_tensor: torch.Tensor) -> None: ... # noqa: E704
|
132
|
+
|
105
133
|
def __init__( # pylint: disable=too-many-arguments, too-many-locals
|
106
134
|
self,
|
107
135
|
*args: Any,
|
@@ -110,11 +138,13 @@ class Array:
|
|
110
138
|
stype: str | None = None,
|
111
139
|
data: bytes | None = None,
|
112
140
|
ndarray: NDArray | None = None,
|
141
|
+
torch_tensor: torch.Tensor | None = None,
|
113
142
|
) -> None:
|
114
143
|
# Determine the initialization method and validate input arguments.
|
115
|
-
# Support
|
144
|
+
# Support three initialization formats:
|
116
145
|
# 1. Array(dtype: str, shape: list[int], stype: str, data: bytes)
|
117
146
|
# 2. Array(ndarray: NDArray)
|
147
|
+
# 3. Array(torch_tensor: torch.Tensor)
|
118
148
|
|
119
149
|
# Initialize all arguments
|
120
150
|
# If more than 4 positional arguments are provided, raise an error.
|
@@ -149,6 +179,7 @@ class Array:
|
|
149
179
|
_try_set_arg(2, stype, "direct")
|
150
180
|
_try_set_arg(3, data, "direct")
|
151
181
|
_try_set_arg(0, ndarray, "ndarray")
|
182
|
+
_try_set_arg(0, torch_tensor, "torch_tensor")
|
152
183
|
|
153
184
|
# Check if all arguments are correctly set
|
154
185
|
all_args = [arg for arg in all_args if arg is not None]
|
@@ -172,6 +203,16 @@ class Array:
|
|
172
203
|
self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
|
173
204
|
return
|
174
205
|
|
206
|
+
# Handle PyTorch tensor
|
207
|
+
if not init_method or init_method == "torch_tensor":
|
208
|
+
if (
|
209
|
+
len(all_args) == 1
|
210
|
+
and "torch" in sys.modules
|
211
|
+
and isinstance(all_args[0], sys.modules["torch"].Tensor)
|
212
|
+
):
|
213
|
+
self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
|
214
|
+
return
|
215
|
+
|
175
216
|
_raise_array_init_error()
|
176
217
|
|
177
218
|
@classmethod
|
@@ -193,6 +234,19 @@ class Array:
|
|
193
234
|
data=data,
|
194
235
|
)
|
195
236
|
|
237
|
+
@classmethod
|
238
|
+
def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
|
239
|
+
"""Create Array from PyTorch tensor."""
|
240
|
+
if not (torch := sys.modules.get("torch")):
|
241
|
+
raise RuntimeError(
|
242
|
+
f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
|
243
|
+
)
|
244
|
+
|
245
|
+
assert isinstance(
|
246
|
+
tensor, torch.Tensor
|
247
|
+
), f"Expected PyTorch Tensor, got {type(tensor)}"
|
248
|
+
return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())
|
249
|
+
|
196
250
|
def numpy(self) -> NDArray:
|
197
251
|
"""Return the array as a NumPy array."""
|
198
252
|
if self.stype != SType.NUMPY:
|
@@ -223,103 +277,293 @@ def _check_value(value: Array) -> None:
|
|
223
277
|
class ParametersRecord(TypedDict[str, Array]):
|
224
278
|
r"""Parameters record.
|
225
279
|
|
226
|
-
A
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
280
|
+
A typed dictionary (``str`` to :class:`Array`) that can store named parameters
|
281
|
+
as serialized tensors. Internally, this behaves similarly to an
|
282
|
+
``OrderedDict[str, Array]``. A ``ParametersRecord`` can be viewed as an
|
283
|
+
equivalent to PyTorch's ``state_dict``, but it holds arrays in serialized form.
|
284
|
+
|
285
|
+
This object is one of the record types supported by :class:`RecordSet` and can
|
286
|
+
therefore be stored in the ``content`` of a :class:`Message` or the ``state``
|
287
|
+
of a :class:`Context`.
|
288
|
+
|
289
|
+
This class can be instantiated in multiple ways:
|
290
|
+
|
291
|
+
1. By providing nothing (empty container).
|
292
|
+
2. By providing a dictionary of :class:`Array` (via the ``array_dict`` argument).
|
293
|
+
3. By providing a list of NumPy ``ndarray`` (via the ``numpy_ndarrays`` argument).
|
294
|
+
4. By providing a PyTorch ``state_dict`` (via the ``torch_state_dict`` argument).
|
232
295
|
|
233
296
|
Parameters
|
234
297
|
----------
|
235
|
-
array_dict : Optional[OrderedDict[str, Array]]
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
298
|
+
array_dict : Optional[OrderedDict[str, Array]] (default: None)
|
299
|
+
An existing dictionary containing named :class:`Array` instances. If
|
300
|
+
provided, these entries will be used directly to populate the record.
|
301
|
+
numpy_ndarrays : Optional[list[NDArray]] (default: None)
|
302
|
+
A list of NumPy arrays. Each array will be automatically converted
|
303
|
+
into an :class:`Array` and stored in this record with generated keys.
|
304
|
+
torch_state_dict : Optional[OrderedDict[str, torch.Tensor]] (default: None)
|
305
|
+
A PyTorch ``state_dict`` (``str`` keys to ``torch.Tensor`` values). Each
|
306
|
+
tensor will be converted into an :class:`Array` and stored in this record.
|
307
|
+
keep_input : bool (default: True)
|
308
|
+
If ``False``, entries from the input are removed after being added to
|
309
|
+
this record to free up memory. If ``True``, the input remains unchanged.
|
310
|
+
Regardless of this value, no duplicate memory is used if the input is a
|
311
|
+
dictionary of :class:`Array`, i.e., ``array_dict``.
|
245
312
|
|
246
313
|
Examples
|
247
314
|
--------
|
248
|
-
|
249
|
-
|
250
|
-
|
315
|
+
Initializing an empty ParametersRecord:
|
316
|
+
|
317
|
+
>>> p_record = ParametersRecord()
|
318
|
+
|
319
|
+
Initializing with a dictionary of :class:`Array`:
|
320
|
+
|
321
|
+
>>> arr = Array("float32", [5, 5], "numpy.ndarray", b"serialized_data...")
|
322
|
+
>>> p_record = ParametersRecord({"weight": arr})
|
251
323
|
|
252
|
-
|
324
|
+
Initializing with a list of NumPy arrays:
|
253
325
|
|
254
326
|
>>> import numpy as np
|
255
|
-
>>>
|
256
|
-
>>>
|
257
|
-
>>>
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
>>>
|
262
|
-
>>>
|
263
|
-
>>>
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
>>>
|
268
|
-
>>>
|
269
|
-
>>>
|
270
|
-
>>>
|
271
|
-
>>> # Adding it to a ParametersRecord:
|
272
|
-
>>> p_record = ParametersRecord({"my_array": arr})
|
273
|
-
|
274
|
-
Now that the NumPy array is embedded into a :code:`ParametersRecord` it could be
|
275
|
-
sent if added as part of a :code:`common.Message` or it could be saved as a
|
276
|
-
persistent state of a :code:`ClientApp` via its context. Regardless of the usecase,
|
277
|
-
we will sooner or later want to recover the array in its original NumPy
|
278
|
-
representation. For the example above, where the array was serialized using the
|
279
|
-
built-in utility function, deserialization can be done as follows:
|
280
|
-
|
281
|
-
>>> # Use the Array's built-in method
|
282
|
-
>>> arr_np_d = arr.numpy()
|
283
|
-
>>>
|
284
|
-
>>> # If printed, it will show the exact same data as above:
|
285
|
-
>>> array([[-1.84242409, -1.01539537, -0.46528405],
|
286
|
-
>>> [ 0.32991896, 0.55540414, 0.44085534],
|
287
|
-
>>> [-0.10758364, 1.97619858, -0.37120501]])
|
288
|
-
|
289
|
-
If you need finer control on how your arrays are serialized and deserialized, you
|
290
|
-
can construct :code:`Array` objects directly like this:
|
291
|
-
|
292
|
-
>>> from flwr.common import Array
|
293
|
-
>>> # Serialize your array and construct Array object
|
294
|
-
>>> arr = Array(
|
295
|
-
>>> data=ndarray.tobytes(),
|
296
|
-
>>> dtype=str(ndarray.dtype),
|
297
|
-
>>> stype="", # Could be used in a deserialization function
|
298
|
-
>>> shape=list(ndarray.shape),
|
299
|
-
>>> )
|
300
|
-
>>>
|
301
|
-
>>> # Then you can deserialize it like this
|
302
|
-
>>> arr_np_d = np.frombuffer(
|
303
|
-
>>> buffer=array.data,
|
304
|
-
>>> dtype=array.dtype,
|
305
|
-
>>> ).reshape(array.shape)
|
306
|
-
|
307
|
-
Note that different arrays (e.g. from PyTorch, Tensorflow) might require different
|
308
|
-
serialization mechanism. Howerver, they often support a conversion to NumPy,
|
309
|
-
therefore allowing to use the same or similar steps as in the example above.
|
327
|
+
>>> arr1 = np.random.randn(3, 3)
|
328
|
+
>>> arr2 = np.random.randn(2, 2)
|
329
|
+
>>> p_record = ParametersRecord([arr1, arr2])
|
330
|
+
|
331
|
+
Initializing with a PyTorch model state_dict:
|
332
|
+
|
333
|
+
>>> import torch.nn as nn
|
334
|
+
>>> model = nn.Linear(10, 5)
|
335
|
+
>>> p_record = ParametersRecord(model.state_dict())
|
336
|
+
|
337
|
+
Initializing with a TensorFlow model weights (a list of NumPy arrays):
|
338
|
+
|
339
|
+
>>> import tensorflow as tf
|
340
|
+
>>> model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(10,))])
|
341
|
+
>>> p_record = ParametersRecord(model.get_weights())
|
310
342
|
"""
|
311
343
|
|
312
|
-
|
344
|
+
@overload
|
345
|
+
def __init__(self) -> None: ... # noqa: E704
|
346
|
+
|
347
|
+
@overload
|
348
|
+
def __init__( # noqa: E704
|
349
|
+
self, array_dict: OrderedDict[str, Array], *, keep_input: bool = True
|
350
|
+
) -> None: ...
|
351
|
+
|
352
|
+
@overload
|
353
|
+
def __init__( # noqa: E704
|
354
|
+
self, numpy_ndarrays: list[NDArray], *, keep_input: bool = True
|
355
|
+
) -> None: ...
|
356
|
+
|
357
|
+
@overload
|
358
|
+
def __init__( # noqa: E704
|
359
|
+
self,
|
360
|
+
torch_state_dict: OrderedDict[str, torch.Tensor],
|
361
|
+
*,
|
362
|
+
keep_input: bool = True,
|
363
|
+
) -> None: ...
|
364
|
+
|
365
|
+
def __init__( # pylint: disable=too-many-arguments
|
313
366
|
self,
|
367
|
+
*args: Any,
|
368
|
+
numpy_ndarrays: list[NDArray] | None = None,
|
369
|
+
torch_state_dict: OrderedDict[str, torch.Tensor] | None = None,
|
314
370
|
array_dict: OrderedDict[str, Array] | None = None,
|
315
|
-
keep_input: bool =
|
371
|
+
keep_input: bool = True,
|
316
372
|
) -> None:
|
317
373
|
super().__init__(_check_key, _check_value)
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
374
|
+
|
375
|
+
# Determine the initialization method and validates input arguments.
|
376
|
+
# Support the following initialization formats:
|
377
|
+
# 1. cls(array_dict: OrderedDict[str, Array], keep_input: bool)
|
378
|
+
# 2. cls(numpy_ndarrays: list[NDArray], keep_input: bool)
|
379
|
+
# 3. cls(torch_state_dict: dict[str, torch.Tensor], keep_input: bool)
|
380
|
+
|
381
|
+
# Init the argument
|
382
|
+
if len(args) > 1:
|
383
|
+
_raise_parameters_record_init_error()
|
384
|
+
arg = args[0] if args else None
|
385
|
+
init_method: str | None = None # Track which init method is being used
|
386
|
+
|
387
|
+
# Try to assign a value to arg if it's not already set.
|
388
|
+
# If an initialization method is provided, update init_method.
|
389
|
+
def _try_set_arg(_arg: Any, method: str) -> None:
|
390
|
+
# Skip if _arg is None
|
391
|
+
if _arg is None:
|
392
|
+
return
|
393
|
+
nonlocal arg, init_method
|
394
|
+
# Raise an error if arg is already set
|
395
|
+
if arg is not None:
|
396
|
+
_raise_parameters_record_init_error()
|
397
|
+
# Raise an error if a different initialization method is already set
|
398
|
+
if init_method is not None:
|
399
|
+
_raise_parameters_record_init_error()
|
400
|
+
# Set init_method and arg
|
401
|
+
if init_method is None:
|
402
|
+
init_method = method
|
403
|
+
arg = _arg
|
404
|
+
|
405
|
+
# Try to set keyword arguments
|
406
|
+
_try_set_arg(array_dict, "array_dict")
|
407
|
+
_try_set_arg(numpy_ndarrays, "numpy_ndarrays")
|
408
|
+
_try_set_arg(torch_state_dict, "state_dict")
|
409
|
+
|
410
|
+
# If no arguments are provided, return and keep self empty
|
411
|
+
if arg is None:
|
412
|
+
return
|
413
|
+
|
414
|
+
# Handle dictionary of Arrays
|
415
|
+
if not init_method or init_method == "array_dict":
|
416
|
+
# Type check the input
|
417
|
+
if (
|
418
|
+
isinstance(arg, dict)
|
419
|
+
and all(isinstance(k, str) for k in arg.keys())
|
420
|
+
and all(isinstance(v, Array) for v in arg.values())
|
421
|
+
):
|
422
|
+
array_dict = cast(OrderedDict[str, Array], arg)
|
423
|
+
converted = self.from_array_dict(array_dict, keep_input=keep_input)
|
424
|
+
self.__dict__.update(converted.__dict__)
|
425
|
+
return
|
426
|
+
|
427
|
+
# Handle NumPy ndarrays
|
428
|
+
if not init_method or init_method == "numpy_ndarrays":
|
429
|
+
# Type check the input
|
430
|
+
# pylint: disable-next=not-an-iterable
|
431
|
+
if isinstance(arg, list) and all(isinstance(v, np.ndarray) for v in arg):
|
432
|
+
numpy_ndarrays = cast(list[NDArray], arg)
|
433
|
+
converted = self.from_numpy_ndarrays(
|
434
|
+
numpy_ndarrays, keep_input=keep_input
|
435
|
+
)
|
436
|
+
self.__dict__.update(converted.__dict__)
|
437
|
+
return
|
438
|
+
|
439
|
+
# Handle PyTorch state_dict
|
440
|
+
if not init_method or init_method == "state_dict":
|
441
|
+
# Type check the input
|
442
|
+
if (
|
443
|
+
(torch := sys.modules.get("torch")) is not None
|
444
|
+
and isinstance(arg, dict)
|
445
|
+
and all(isinstance(k, str) for k in arg.keys())
|
446
|
+
and all(isinstance(v, torch.Tensor) for v in arg.values())
|
447
|
+
):
|
448
|
+
torch_state_dict = cast(
|
449
|
+
OrderedDict[str, torch.Tensor], arg # type: ignore
|
450
|
+
)
|
451
|
+
converted = self.from_torch_state_dict(
|
452
|
+
torch_state_dict, keep_input=keep_input
|
453
|
+
)
|
454
|
+
self.__dict__.update(converted.__dict__)
|
455
|
+
return
|
456
|
+
|
457
|
+
_raise_parameters_record_init_error()
|
458
|
+
|
459
|
+
@classmethod
|
460
|
+
def from_array_dict(
|
461
|
+
cls,
|
462
|
+
array_dict: OrderedDict[str, Array],
|
463
|
+
*,
|
464
|
+
keep_input: bool = True,
|
465
|
+
) -> ParametersRecord:
|
466
|
+
"""Create ParametersRecord from a dictionary of :class:`Array`."""
|
467
|
+
record = ParametersRecord()
|
468
|
+
for k, v in array_dict.items():
|
469
|
+
record[k] = Array(
|
470
|
+
dtype=v.dtype, shape=list(v.shape), stype=v.stype, data=v.data
|
471
|
+
)
|
472
|
+
if not keep_input:
|
473
|
+
array_dict.clear()
|
474
|
+
return record
|
475
|
+
|
476
|
+
@classmethod
|
477
|
+
def from_numpy_ndarrays(
|
478
|
+
cls,
|
479
|
+
ndarrays: list[NDArray],
|
480
|
+
*,
|
481
|
+
keep_input: bool = True,
|
482
|
+
) -> ParametersRecord:
|
483
|
+
"""Create ParametersRecord from a list of NumPy ``ndarray``."""
|
484
|
+
record = ParametersRecord()
|
485
|
+
total_serialized_bytes = 0
|
486
|
+
|
487
|
+
for i in range(len(ndarrays)): # pylint: disable=C0200
|
488
|
+
record[str(i)] = Array.from_numpy_ndarray(ndarrays[i])
|
489
|
+
|
490
|
+
if not keep_input:
|
491
|
+
# Remove the reference
|
492
|
+
ndarrays[i] = None # type: ignore
|
493
|
+
total_serialized_bytes += len(record[str(i)].data)
|
494
|
+
|
495
|
+
# If total serialized data exceeds the threshold, trigger GC
|
496
|
+
if total_serialized_bytes > GC_THRESHOLD:
|
497
|
+
total_serialized_bytes = 0
|
498
|
+
gc.collect()
|
499
|
+
|
500
|
+
if not keep_input:
|
501
|
+
# Clear the entire list to remove all references and force GC
|
502
|
+
ndarrays.clear()
|
503
|
+
gc.collect()
|
504
|
+
return record
|
505
|
+
|
506
|
+
@classmethod
|
507
|
+
def from_torch_state_dict(
|
508
|
+
cls,
|
509
|
+
state_dict: OrderedDict[str, torch.Tensor],
|
510
|
+
*,
|
511
|
+
keep_input: bool = True,
|
512
|
+
) -> ParametersRecord:
|
513
|
+
"""Create ParametersRecord from PyTorch ``state_dict``."""
|
514
|
+
if "torch" not in sys.modules:
|
515
|
+
raise RuntimeError(
|
516
|
+
f"PyTorch is required to use {cls.from_torch_state_dict.__name__}"
|
517
|
+
)
|
518
|
+
|
519
|
+
record = ParametersRecord()
|
520
|
+
|
521
|
+
for k in list(state_dict.keys()):
|
522
|
+
v = state_dict[k] if keep_input else state_dict.pop(k)
|
523
|
+
record[k] = Array.from_numpy_ndarray(v.detach().cpu().numpy())
|
524
|
+
|
525
|
+
return record
|
526
|
+
|
527
|
+
def to_numpy_ndarrays(self, *, keep_input: bool = True) -> list[NDArray]:
|
528
|
+
"""Return the ParametersRecord as a list of NumPy ``ndarray``."""
|
529
|
+
if keep_input:
|
530
|
+
return [v.numpy() for v in self.values()]
|
531
|
+
|
532
|
+
# Clear the record and return the list of NumPy arrays
|
533
|
+
ret: list[NDArray] = []
|
534
|
+
total_serialized_bytes = 0
|
535
|
+
for k in list(self.keys()):
|
536
|
+
arr = self.pop(k)
|
537
|
+
ret.append(arr.numpy())
|
538
|
+
total_serialized_bytes += len(arr.data)
|
539
|
+
del arr
|
540
|
+
|
541
|
+
# If total serialized data exceeds the threshold, trigger GC
|
542
|
+
if total_serialized_bytes > GC_THRESHOLD:
|
543
|
+
total_serialized_bytes = 0
|
544
|
+
gc.collect()
|
545
|
+
|
546
|
+
if not keep_input:
|
547
|
+
# Force GC
|
548
|
+
gc.collect()
|
549
|
+
return ret
|
550
|
+
|
551
|
+
def to_torch_state_dict(
|
552
|
+
self, *, keep_input: bool = True
|
553
|
+
) -> OrderedDict[str, torch.Tensor]:
|
554
|
+
"""Return the ParametersRecord as a PyTorch ``state_dict``."""
|
555
|
+
if not (torch := sys.modules.get("torch")):
|
556
|
+
raise RuntimeError(
|
557
|
+
f"PyTorch is required to use {self.to_torch_state_dict.__name__}"
|
558
|
+
)
|
559
|
+
|
560
|
+
state_dict = OrderedDict()
|
561
|
+
|
562
|
+
for k in list(self.keys()):
|
563
|
+
arr = self[k] if keep_input else self.pop(k)
|
564
|
+
state_dict[k] = torch.from_numpy(arr.numpy())
|
565
|
+
|
566
|
+
return state_dict
|
323
567
|
|
324
568
|
def count_bytes(self) -> int:
|
325
569
|
"""Return number of Bytes stored in this object.
|
flwr/server/__init__.py
CHANGED
@@ -21,7 +21,8 @@ from .app import start_server as start_server
|
|
21
21
|
from .client_manager import ClientManager as ClientManager
|
22
22
|
from .client_manager import SimpleClientManager as SimpleClientManager
|
23
23
|
from .compat import LegacyContext as LegacyContext
|
24
|
-
from .
|
24
|
+
from .grid import Driver as Driver
|
25
|
+
from .grid import Grid as Grid
|
25
26
|
from .history import History as History
|
26
27
|
from .server import Server as Server
|
27
28
|
from .server_app import ServerApp as ServerApp
|
@@ -31,6 +32,7 @@ from .serverapp_components import ServerAppComponents as ServerAppComponents
|
|
31
32
|
__all__ = [
|
32
33
|
"ClientManager",
|
33
34
|
"Driver",
|
35
|
+
"Grid",
|
34
36
|
"History",
|
35
37
|
"LegacyContext",
|
36
38
|
"Server",
|
flwr/server/app.py
CHANGED
@@ -79,13 +79,13 @@ from .history import History
|
|
79
79
|
from .server import Server, init_defaults, run_fl
|
80
80
|
from .server_config import ServerConfig
|
81
81
|
from .strategy import Strategy
|
82
|
-
from .superlink.driver.serverappio_grpc import run_serverappio_api_grpc
|
83
82
|
from .superlink.ffs.ffs_factory import FfsFactory
|
84
83
|
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
85
84
|
from .superlink.fleet.grpc_bidi.grpc_server import start_grpc_server
|
86
85
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
87
86
|
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
88
87
|
from .superlink.linkstate import LinkStateFactory
|
88
|
+
from .superlink.serverappio.serverappio_grpc import run_serverappio_api_grpc
|
89
89
|
from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
|
90
90
|
|
91
91
|
DATABASE = ":flwr-in-memory-state:"
|
flwr/server/compat/__init__.py
CHANGED
@@ -15,10 +15,10 @@
|
|
15
15
|
"""Flower ServerApp compatibility package."""
|
16
16
|
|
17
17
|
|
18
|
-
from .app import
|
18
|
+
from .app import start_grid as start_grid
|
19
19
|
from .legacy_context import LegacyContext as LegacyContext
|
20
20
|
|
21
21
|
__all__ = [
|
22
22
|
"LegacyContext",
|
23
|
-
"
|
23
|
+
"start_grid",
|
24
24
|
]
|