flwr-nightly 1.17.0.dev20250317__py3-none-any.whl → 1.17.0.dev20250319__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. flwr/common/constant.py +5 -0
  2. flwr/common/logger.py +2 -2
  3. flwr/common/record/parametersrecord.py +336 -92
  4. flwr/server/__init__.py +3 -1
  5. flwr/server/app.py +1 -1
  6. flwr/server/compat/__init__.py +2 -2
  7. flwr/server/compat/app.py +11 -11
  8. flwr/server/compat/app_utils.py +16 -16
  9. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +9 -9
  10. flwr/server/{driver → grid}/__init__.py +8 -7
  11. flwr/server/{driver/driver.py → grid/grid.py} +44 -15
  12. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +12 -20
  13. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +6 -14
  14. flwr/server/run_serverapp.py +4 -4
  15. flwr/server/server_app.py +38 -12
  16. flwr/server/serverapp/app.py +10 -10
  17. flwr/server/superlink/linkstate/in_memory_linkstate.py +28 -3
  18. flwr/server/superlink/linkstate/sqlite_linkstate.py +40 -2
  19. flwr/server/superlink/linkstate/utils.py +67 -10
  20. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  21. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  22. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +1 -1
  23. flwr/server/typing.py +3 -3
  24. flwr/server/workflow/default_workflows.py +17 -19
  25. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +15 -15
  26. flwr/simulation/run_simulation.py +10 -10
  27. {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/METADATA +1 -1
  28. {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/RECORD +31 -31
  29. {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/LICENSE +0 -0
  30. {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/WHEEL +0 -0
  31. {flwr_nightly-1.17.0.dev20250317.dist-info → flwr_nightly-1.17.0.dev20250319.dist-info}/entry_points.txt +0 -0
flwr/common/constant.py CHANGED
@@ -61,6 +61,7 @@ PING_CALL_TIMEOUT = 5
61
61
  PING_BASE_MULTIPLIER = 0.8
62
62
  PING_RANDOM_RANGE = (-0.1, 0.1)
63
63
  PING_MAX_INTERVAL = 1e300
64
+ PING_PATIENCE = 2
64
65
 
65
66
  # IDs
66
67
  RUN_ID_NUM_BYTES = 8
@@ -120,6 +121,9 @@ TIMESTAMP_HEADER = "flwr-timestamp"
120
121
  TIMESTAMP_TOLERANCE = 10 # General tolerance for timestamp verification
121
122
  SYSTEM_TIME_TOLERANCE = 5 # Allowance for system time drift
122
123
 
124
+ # Constants for ParametersRecord
125
+ GC_THRESHOLD = 200_000_000 # 200 MB
126
+
123
127
 
124
128
  class MessageType:
125
129
  """Message type."""
@@ -163,6 +167,7 @@ class ErrorCode:
163
167
  CLIENT_APP_RAISED_EXCEPTION = 2
164
168
  MESSAGE_UNAVAILABLE = 3
165
169
  REPLY_MESSAGE_UNAVAILABLE = 4
170
+ NODE_UNAVAILABLE = 5
166
171
 
167
172
  def __new__(cls) -> ErrorCode:
168
173
  """Prevent instantiation."""
flwr/common/logger.py CHANGED
@@ -250,9 +250,9 @@ def warn_deprecated_feature_with_example(
250
250
  log(
251
251
  WARN,
252
252
  """FEATURE UPDATE: %s
253
- ------------------------------------------------------------
253
+ ------------------------------------------------------------
254
254
  %s
255
- ------------------------------------------------------------
255
+ ------------------------------------------------------------
256
256
  """,
257
257
  example_message,
258
258
  code_example,
@@ -17,22 +17,36 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
+ import gc
21
+ import sys
20
22
  from collections import OrderedDict
21
23
  from dataclasses import dataclass
22
24
  from io import BytesIO
23
- from typing import Any, cast, overload
25
+ from typing import TYPE_CHECKING, Any, cast, overload
24
26
 
25
27
  import numpy as np
26
28
 
27
- from ..constant import SType
29
+ from ..constant import GC_THRESHOLD, SType
28
30
  from ..typing import NDArray
29
31
  from .typeddict import TypedDict
30
32
 
33
+ if TYPE_CHECKING:
34
+ import torch
35
+
31
36
 
32
37
  def _raise_array_init_error() -> None:
33
38
  raise TypeError(
34
39
  f"Invalid arguments for {Array.__qualname__}. Expected either a "
35
- "NumPy ndarray, or explicit dtype/shape/stype/data values."
40
+ "PyTorch tensor, a NumPy ndarray, or explicit"
41
+ " dtype/shape/stype/data values."
42
+ )
43
+
44
+
45
+ def _raise_parameters_record_init_error() -> None:
46
+ raise TypeError(
47
+ f"Invalid arguments for {ParametersRecord.__qualname__}. Expected either "
48
+ "a list of NumPy ndarrays, a PyTorch state_dict, or a dictionary of Arrays. "
49
+ "The `keep_input` argument is keyword-only."
36
50
  )
37
51
 
38
52
 
@@ -41,37 +55,43 @@ class Array:
41
55
  """Array type.
42
56
 
43
57
  A dataclass containing serialized data from an array-like or tensor-like object
44
- along with metadata about it. The class can be initialized in one of two ways:
58
+ along with metadata about it. The class can be initialized in one of three ways:
45
59
 
46
60
  1. By specifying explicit values for `dtype`, `shape`, `stype`, and `data`.
47
61
  2. By providing a NumPy ndarray (via the `ndarray` argument).
62
+ 3. By providing a PyTorch tensor (via the `torch_tensor` argument).
48
63
 
49
- In scenario (2), the `dtype`, `shape`, `stype`, and `data` are automatically
64
+ In scenarios (2)-(3), the `dtype`, `shape`, `stype`, and `data` are automatically
50
65
  derived from the input. In scenario (1), these fields must be specified manually.
51
66
 
52
67
  Parameters
53
68
  ----------
54
69
  dtype : Optional[str] (default: None)
55
70
  A string representing the data type of the serialized object (e.g. `"float32"`).
56
- Only required if you are not passing in a ndarray.
71
+ Only required if you are not passing in a ndarray or a tensor.
57
72
 
58
73
  shape : Optional[list[int]] (default: None)
59
74
  A list representing the shape of the unserialized array-like object. Only
60
- required if you are not passing in a ndarray.
75
+ required if you are not passing in a ndarray or a tensor.
61
76
 
62
77
  stype : Optional[str] (default: None)
63
78
  A string indicating the serialization mechanism used to generate the bytes in
64
79
  `data` from an array-like or tensor-like object. Only required if you are not
65
- passing in a ndarray.
80
+ passing in a ndarray or a tensor.
66
81
 
67
82
  data : Optional[bytes] (default: None)
68
83
  A buffer of bytes containing the data. Only required if you are not passing in
69
- a ndarray.
84
+ a ndarray or a tensor.
70
85
 
71
86
  ndarray : Optional[NDArray] (default: None)
72
87
  A NumPy ndarray. If provided, the `dtype`, `shape`, `stype`, and `data`
73
88
  fields are derived automatically from it.
74
89
 
90
+ torch_tensor : Optional[torch.Tensor] (default: None)
91
+ A PyTorch tensor. If provided, it will be **detached and moved to CPU**
92
+ before conversion, and the `dtype`, `shape`, `stype`, and `data` fields
93
+ will be derived automatically from it.
94
+
75
95
  Examples
76
96
  --------
77
97
  Initializing by specifying all fields directly:
@@ -87,6 +107,11 @@ class Array:
87
107
 
88
108
  >>> import numpy as np
89
109
  >>> arr2 = Array(np.random.randn(3, 3))
110
+
111
+ Initializing with a PyTorch tensor:
112
+
113
+ >>> import torch
114
+ >>> arr3 = Array(torch.randn(3, 3))
90
115
  """
91
116
 
92
117
  dtype: str
@@ -102,6 +127,9 @@ class Array:
102
127
  @overload
103
128
  def __init__(self, ndarray: NDArray) -> None: ... # noqa: E704
104
129
 
130
+ @overload
131
+ def __init__(self, torch_tensor: torch.Tensor) -> None: ... # noqa: E704
132
+
105
133
  def __init__( # pylint: disable=too-many-arguments, too-many-locals
106
134
  self,
107
135
  *args: Any,
@@ -110,11 +138,13 @@ class Array:
110
138
  stype: str | None = None,
111
139
  data: bytes | None = None,
112
140
  ndarray: NDArray | None = None,
141
+ torch_tensor: torch.Tensor | None = None,
113
142
  ) -> None:
114
143
  # Determine the initialization method and validate input arguments.
115
- # Support two initialization formats:
144
+ # Support three initialization formats:
116
145
  # 1. Array(dtype: str, shape: list[int], stype: str, data: bytes)
117
146
  # 2. Array(ndarray: NDArray)
147
+ # 3. Array(torch_tensor: torch.Tensor)
118
148
 
119
149
  # Initialize all arguments
120
150
  # If more than 4 positional arguments are provided, raise an error.
@@ -149,6 +179,7 @@ class Array:
149
179
  _try_set_arg(2, stype, "direct")
150
180
  _try_set_arg(3, data, "direct")
151
181
  _try_set_arg(0, ndarray, "ndarray")
182
+ _try_set_arg(0, torch_tensor, "torch_tensor")
152
183
 
153
184
  # Check if all arguments are correctly set
154
185
  all_args = [arg for arg in all_args if arg is not None]
@@ -172,6 +203,16 @@ class Array:
172
203
  self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__)
173
204
  return
174
205
 
206
+ # Handle PyTorch tensor
207
+ if not init_method or init_method == "torch_tensor":
208
+ if (
209
+ len(all_args) == 1
210
+ and "torch" in sys.modules
211
+ and isinstance(all_args[0], sys.modules["torch"].Tensor)
212
+ ):
213
+ self.__dict__.update(self.from_torch_tensor(all_args[0]).__dict__)
214
+ return
215
+
175
216
  _raise_array_init_error()
176
217
 
177
218
  @classmethod
@@ -193,6 +234,19 @@ class Array:
193
234
  data=data,
194
235
  )
195
236
 
237
+ @classmethod
238
+ def from_torch_tensor(cls, tensor: torch.Tensor) -> Array:
239
+ """Create Array from PyTorch tensor."""
240
+ if not (torch := sys.modules.get("torch")):
241
+ raise RuntimeError(
242
+ f"PyTorch is required to use {cls.from_torch_tensor.__name__}"
243
+ )
244
+
245
+ assert isinstance(
246
+ tensor, torch.Tensor
247
+ ), f"Expected PyTorch Tensor, got {type(tensor)}"
248
+ return cls.from_numpy_ndarray(tensor.detach().cpu().numpy())
249
+
196
250
  def numpy(self) -> NDArray:
197
251
  """Return the array as a NumPy array."""
198
252
  if self.stype != SType.NUMPY:
@@ -223,103 +277,293 @@ def _check_value(value: Array) -> None:
223
277
  class ParametersRecord(TypedDict[str, Array]):
224
278
  r"""Parameters record.
225
279
 
226
- A dataclass storing named Arrays in order. This means that it holds entries as an
227
- OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to
228
- PyTorch's state_dict, but holding serialised tensors instead. A
229
- :code:`ParametersRecord` is one of the types of records that a
230
- `flwr.common.RecordSet <flwr.common.RecordSet.html#recordset>`_ supports and
231
- can therefore be used to construct :code:`common.Message` objects.
280
+ A typed dictionary (``str`` to :class:`Array`) that can store named parameters
281
+ as serialized tensors. Internally, this behaves similarly to an
282
+ ``OrderedDict[str, Array]``. A ``ParametersRecord`` can be viewed as an
283
+ equivalent to PyTorch's ``state_dict``, but it holds arrays in serialized form.
284
+
285
+ This object is one of the record types supported by :class:`RecordSet` and can
286
+ therefore be stored in the ``content`` of a :class:`Message` or the ``state``
287
+ of a :class:`Context`.
288
+
289
+ This class can be instantiated in multiple ways:
290
+
291
+ 1. By providing nothing (empty container).
292
+ 2. By providing a dictionary of :class:`Array` (via the ``array_dict`` argument).
293
+ 3. By providing a list of NumPy ``ndarray`` (via the ``numpy_ndarrays`` argument).
294
+ 4. By providing a PyTorch ``state_dict`` (via the ``torch_state_dict`` argument).
232
295
 
233
296
  Parameters
234
297
  ----------
235
- array_dict : Optional[OrderedDict[str, Array]]
236
- A dictionary that stores serialized array-like or tensor-like objects.
237
- keep_input : bool (default: False)
238
- A boolean indicating whether parameters should be deleted from the input
239
- dictionary immediately after adding them to the record. If False, the
240
- dictionary passed to `set_parameters()` will be empty once exiting from that
241
- function. This is the desired behaviour when working with very large
242
- models/tensors/arrays. However, if you plan to continue working with your
243
- parameters after adding it to the record, set this flag to True. When set
244
- to True, the data is duplicated in memory.
298
+ array_dict : Optional[OrderedDict[str, Array]] (default: None)
299
+ An existing dictionary containing named :class:`Array` instances. If
300
+ provided, these entries will be used directly to populate the record.
301
+ numpy_ndarrays : Optional[list[NDArray]] (default: None)
302
+ A list of NumPy arrays. Each array will be automatically converted
303
+ into an :class:`Array` and stored in this record with generated keys.
304
+ torch_state_dict : Optional[OrderedDict[str, torch.Tensor]] (default: None)
305
+ A PyTorch ``state_dict`` (``str`` keys to ``torch.Tensor`` values). Each
306
+ tensor will be converted into an :class:`Array` and stored in this record.
307
+ keep_input : bool (default: True)
308
+ If ``False``, entries from the input are removed after being added to
309
+ this record to free up memory. If ``True``, the input remains unchanged.
310
+ Regardless of this value, no duplicate memory is used if the input is a
311
+ dictionary of :class:`Array`, i.e., ``array_dict``.
245
312
 
246
313
  Examples
247
314
  --------
248
- The usage of :code:`ParametersRecord` is envisioned for storing data arrays (e.g.
249
- parameters of a machine learning model). These first need to be serialized into
250
- a :code:`flwr.common.Array` data structure.
315
+ Initializing an empty ParametersRecord:
316
+
317
+ >>> p_record = ParametersRecord()
318
+
319
+ Initializing with a dictionary of :class:`Array`:
320
+
321
+ >>> arr = Array("float32", [5, 5], "numpy.ndarray", b"serialized_data...")
322
+ >>> p_record = ParametersRecord({"weight": arr})
251
323
 
252
- Let's see some examples:
324
+ Initializing with a list of NumPy arrays:
253
325
 
254
326
  >>> import numpy as np
255
- >>> from flwr.common import ParametersRecord
256
- >>>
257
- >>> # Let's create a simple NumPy array
258
- >>> arr_np = np.random.randn(3, 3)
259
- >>>
260
- >>> # If we print it
261
- >>> array([[-1.84242409, -1.01539537, -0.46528405],
262
- >>> [ 0.32991896, 0.55540414, 0.44085534],
263
- >>> [-0.10758364, 1.97619858, -0.37120501]])
264
- >>>
265
- >>> # Let's create an Array out of it
266
- >>> arr = Array(arr_np)
267
- >>>
268
- >>> # If we print it you'll see (note the binary data)
269
- >>> Array(dtype='float64', shape=[3,3], stype='numpy.ndarray', data=b'@\x99\x18...')
270
- >>>
271
- >>> # Adding it to a ParametersRecord:
272
- >>> p_record = ParametersRecord({"my_array": arr})
273
-
274
- Now that the NumPy array is embedded into a :code:`ParametersRecord` it could be
275
- sent if added as part of a :code:`common.Message` or it could be saved as a
276
- persistent state of a :code:`ClientApp` via its context. Regardless of the usecase,
277
- we will sooner or later want to recover the array in its original NumPy
278
- representation. For the example above, where the array was serialized using the
279
- built-in utility function, deserialization can be done as follows:
280
-
281
- >>> # Use the Array's built-in method
282
- >>> arr_np_d = arr.numpy()
283
- >>>
284
- >>> # If printed, it will show the exact same data as above:
285
- >>> array([[-1.84242409, -1.01539537, -0.46528405],
286
- >>> [ 0.32991896, 0.55540414, 0.44085534],
287
- >>> [-0.10758364, 1.97619858, -0.37120501]])
288
-
289
- If you need finer control on how your arrays are serialized and deserialized, you
290
- can construct :code:`Array` objects directly like this:
291
-
292
- >>> from flwr.common import Array
293
- >>> # Serialize your array and construct Array object
294
- >>> arr = Array(
295
- >>> data=ndarray.tobytes(),
296
- >>> dtype=str(ndarray.dtype),
297
- >>> stype="", # Could be used in a deserialization function
298
- >>> shape=list(ndarray.shape),
299
- >>> )
300
- >>>
301
- >>> # Then you can deserialize it like this
302
- >>> arr_np_d = np.frombuffer(
303
- >>> buffer=array.data,
304
- >>> dtype=array.dtype,
305
- >>> ).reshape(array.shape)
306
-
307
- Note that different arrays (e.g. from PyTorch, Tensorflow) might require different
308
- serialization mechanism. Howerver, they often support a conversion to NumPy,
309
- therefore allowing to use the same or similar steps as in the example above.
327
+ >>> arr1 = np.random.randn(3, 3)
328
+ >>> arr2 = np.random.randn(2, 2)
329
+ >>> p_record = ParametersRecord([arr1, arr2])
330
+
331
+ Initializing with a PyTorch model state_dict:
332
+
333
+ >>> import torch.nn as nn
334
+ >>> model = nn.Linear(10, 5)
335
+ >>> p_record = ParametersRecord(model.state_dict())
336
+
337
+ Initializing with a TensorFlow model weights (a list of NumPy arrays):
338
+
339
+ >>> import tensorflow as tf
340
+ >>> model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(10,))])
341
+ >>> p_record = ParametersRecord(model.get_weights())
310
342
  """
311
343
 
312
- def __init__(
344
+ @overload
345
+ def __init__(self) -> None: ... # noqa: E704
346
+
347
+ @overload
348
+ def __init__( # noqa: E704
349
+ self, array_dict: OrderedDict[str, Array], *, keep_input: bool = True
350
+ ) -> None: ...
351
+
352
+ @overload
353
+ def __init__( # noqa: E704
354
+ self, numpy_ndarrays: list[NDArray], *, keep_input: bool = True
355
+ ) -> None: ...
356
+
357
+ @overload
358
+ def __init__( # noqa: E704
359
+ self,
360
+ torch_state_dict: OrderedDict[str, torch.Tensor],
361
+ *,
362
+ keep_input: bool = True,
363
+ ) -> None: ...
364
+
365
+ def __init__( # pylint: disable=too-many-arguments
313
366
  self,
367
+ *args: Any,
368
+ numpy_ndarrays: list[NDArray] | None = None,
369
+ torch_state_dict: OrderedDict[str, torch.Tensor] | None = None,
314
370
  array_dict: OrderedDict[str, Array] | None = None,
315
- keep_input: bool = False,
371
+ keep_input: bool = True,
316
372
  ) -> None:
317
373
  super().__init__(_check_key, _check_value)
318
- if array_dict:
319
- for k in list(array_dict.keys()):
320
- self[k] = array_dict[k]
321
- if not keep_input:
322
- del array_dict[k]
374
+
375
+ # Determine the initialization method and validates input arguments.
376
+ # Support the following initialization formats:
377
+ # 1. cls(array_dict: OrderedDict[str, Array], keep_input: bool)
378
+ # 2. cls(numpy_ndarrays: list[NDArray], keep_input: bool)
379
+ # 3. cls(torch_state_dict: dict[str, torch.Tensor], keep_input: bool)
380
+
381
+ # Init the argument
382
+ if len(args) > 1:
383
+ _raise_parameters_record_init_error()
384
+ arg = args[0] if args else None
385
+ init_method: str | None = None # Track which init method is being used
386
+
387
+ # Try to assign a value to arg if it's not already set.
388
+ # If an initialization method is provided, update init_method.
389
+ def _try_set_arg(_arg: Any, method: str) -> None:
390
+ # Skip if _arg is None
391
+ if _arg is None:
392
+ return
393
+ nonlocal arg, init_method
394
+ # Raise an error if arg is already set
395
+ if arg is not None:
396
+ _raise_parameters_record_init_error()
397
+ # Raise an error if a different initialization method is already set
398
+ if init_method is not None:
399
+ _raise_parameters_record_init_error()
400
+ # Set init_method and arg
401
+ if init_method is None:
402
+ init_method = method
403
+ arg = _arg
404
+
405
+ # Try to set keyword arguments
406
+ _try_set_arg(array_dict, "array_dict")
407
+ _try_set_arg(numpy_ndarrays, "numpy_ndarrays")
408
+ _try_set_arg(torch_state_dict, "state_dict")
409
+
410
+ # If no arguments are provided, return and keep self empty
411
+ if arg is None:
412
+ return
413
+
414
+ # Handle dictionary of Arrays
415
+ if not init_method or init_method == "array_dict":
416
+ # Type check the input
417
+ if (
418
+ isinstance(arg, dict)
419
+ and all(isinstance(k, str) for k in arg.keys())
420
+ and all(isinstance(v, Array) for v in arg.values())
421
+ ):
422
+ array_dict = cast(OrderedDict[str, Array], arg)
423
+ converted = self.from_array_dict(array_dict, keep_input=keep_input)
424
+ self.__dict__.update(converted.__dict__)
425
+ return
426
+
427
+ # Handle NumPy ndarrays
428
+ if not init_method or init_method == "numpy_ndarrays":
429
+ # Type check the input
430
+ # pylint: disable-next=not-an-iterable
431
+ if isinstance(arg, list) and all(isinstance(v, np.ndarray) for v in arg):
432
+ numpy_ndarrays = cast(list[NDArray], arg)
433
+ converted = self.from_numpy_ndarrays(
434
+ numpy_ndarrays, keep_input=keep_input
435
+ )
436
+ self.__dict__.update(converted.__dict__)
437
+ return
438
+
439
+ # Handle PyTorch state_dict
440
+ if not init_method or init_method == "state_dict":
441
+ # Type check the input
442
+ if (
443
+ (torch := sys.modules.get("torch")) is not None
444
+ and isinstance(arg, dict)
445
+ and all(isinstance(k, str) for k in arg.keys())
446
+ and all(isinstance(v, torch.Tensor) for v in arg.values())
447
+ ):
448
+ torch_state_dict = cast(
449
+ OrderedDict[str, torch.Tensor], arg # type: ignore
450
+ )
451
+ converted = self.from_torch_state_dict(
452
+ torch_state_dict, keep_input=keep_input
453
+ )
454
+ self.__dict__.update(converted.__dict__)
455
+ return
456
+
457
+ _raise_parameters_record_init_error()
458
+
459
+ @classmethod
460
+ def from_array_dict(
461
+ cls,
462
+ array_dict: OrderedDict[str, Array],
463
+ *,
464
+ keep_input: bool = True,
465
+ ) -> ParametersRecord:
466
+ """Create ParametersRecord from a dictionary of :class:`Array`."""
467
+ record = ParametersRecord()
468
+ for k, v in array_dict.items():
469
+ record[k] = Array(
470
+ dtype=v.dtype, shape=list(v.shape), stype=v.stype, data=v.data
471
+ )
472
+ if not keep_input:
473
+ array_dict.clear()
474
+ return record
475
+
476
+ @classmethod
477
+ def from_numpy_ndarrays(
478
+ cls,
479
+ ndarrays: list[NDArray],
480
+ *,
481
+ keep_input: bool = True,
482
+ ) -> ParametersRecord:
483
+ """Create ParametersRecord from a list of NumPy ``ndarray``."""
484
+ record = ParametersRecord()
485
+ total_serialized_bytes = 0
486
+
487
+ for i in range(len(ndarrays)): # pylint: disable=C0200
488
+ record[str(i)] = Array.from_numpy_ndarray(ndarrays[i])
489
+
490
+ if not keep_input:
491
+ # Remove the reference
492
+ ndarrays[i] = None # type: ignore
493
+ total_serialized_bytes += len(record[str(i)].data)
494
+
495
+ # If total serialized data exceeds the threshold, trigger GC
496
+ if total_serialized_bytes > GC_THRESHOLD:
497
+ total_serialized_bytes = 0
498
+ gc.collect()
499
+
500
+ if not keep_input:
501
+ # Clear the entire list to remove all references and force GC
502
+ ndarrays.clear()
503
+ gc.collect()
504
+ return record
505
+
506
+ @classmethod
507
+ def from_torch_state_dict(
508
+ cls,
509
+ state_dict: OrderedDict[str, torch.Tensor],
510
+ *,
511
+ keep_input: bool = True,
512
+ ) -> ParametersRecord:
513
+ """Create ParametersRecord from PyTorch ``state_dict``."""
514
+ if "torch" not in sys.modules:
515
+ raise RuntimeError(
516
+ f"PyTorch is required to use {cls.from_torch_state_dict.__name__}"
517
+ )
518
+
519
+ record = ParametersRecord()
520
+
521
+ for k in list(state_dict.keys()):
522
+ v = state_dict[k] if keep_input else state_dict.pop(k)
523
+ record[k] = Array.from_numpy_ndarray(v.detach().cpu().numpy())
524
+
525
+ return record
526
+
527
+ def to_numpy_ndarrays(self, *, keep_input: bool = True) -> list[NDArray]:
528
+ """Return the ParametersRecord as a list of NumPy ``ndarray``."""
529
+ if keep_input:
530
+ return [v.numpy() for v in self.values()]
531
+
532
+ # Clear the record and return the list of NumPy arrays
533
+ ret: list[NDArray] = []
534
+ total_serialized_bytes = 0
535
+ for k in list(self.keys()):
536
+ arr = self.pop(k)
537
+ ret.append(arr.numpy())
538
+ total_serialized_bytes += len(arr.data)
539
+ del arr
540
+
541
+ # If total serialized data exceeds the threshold, trigger GC
542
+ if total_serialized_bytes > GC_THRESHOLD:
543
+ total_serialized_bytes = 0
544
+ gc.collect()
545
+
546
+ if not keep_input:
547
+ # Force GC
548
+ gc.collect()
549
+ return ret
550
+
551
+ def to_torch_state_dict(
552
+ self, *, keep_input: bool = True
553
+ ) -> OrderedDict[str, torch.Tensor]:
554
+ """Return the ParametersRecord as a PyTorch ``state_dict``."""
555
+ if not (torch := sys.modules.get("torch")):
556
+ raise RuntimeError(
557
+ f"PyTorch is required to use {self.to_torch_state_dict.__name__}"
558
+ )
559
+
560
+ state_dict = OrderedDict()
561
+
562
+ for k in list(self.keys()):
563
+ arr = self[k] if keep_input else self.pop(k)
564
+ state_dict[k] = torch.from_numpy(arr.numpy())
565
+
566
+ return state_dict
323
567
 
324
568
  def count_bytes(self) -> int:
325
569
  """Return number of Bytes stored in this object.
flwr/server/__init__.py CHANGED
@@ -21,7 +21,8 @@ from .app import start_server as start_server
21
21
  from .client_manager import ClientManager as ClientManager
22
22
  from .client_manager import SimpleClientManager as SimpleClientManager
23
23
  from .compat import LegacyContext as LegacyContext
24
- from .driver import Driver as Driver
24
+ from .grid import Driver as Driver
25
+ from .grid import Grid as Grid
25
26
  from .history import History as History
26
27
  from .server import Server as Server
27
28
  from .server_app import ServerApp as ServerApp
@@ -31,6 +32,7 @@ from .serverapp_components import ServerAppComponents as ServerAppComponents
31
32
  __all__ = [
32
33
  "ClientManager",
33
34
  "Driver",
35
+ "Grid",
34
36
  "History",
35
37
  "LegacyContext",
36
38
  "Server",
flwr/server/app.py CHANGED
@@ -79,13 +79,13 @@ from .history import History
79
79
  from .server import Server, init_defaults, run_fl
80
80
  from .server_config import ServerConfig
81
81
  from .strategy import Strategy
82
- from .superlink.driver.serverappio_grpc import run_serverappio_api_grpc
83
82
  from .superlink.ffs.ffs_factory import FfsFactory
84
83
  from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
85
84
  from .superlink.fleet.grpc_bidi.grpc_server import start_grpc_server
86
85
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
87
86
  from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
88
87
  from .superlink.linkstate import LinkStateFactory
88
+ from .superlink.serverappio.serverappio_grpc import run_serverappio_api_grpc
89
89
  from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
90
90
 
91
91
  DATABASE = ":flwr-in-memory-state:"
@@ -15,10 +15,10 @@
15
15
  """Flower ServerApp compatibility package."""
16
16
 
17
17
 
18
- from .app import start_driver as start_driver
18
+ from .app import start_grid as start_grid
19
19
  from .legacy_context import LegacyContext as LegacyContext
20
20
 
21
21
  __all__ = [
22
22
  "LegacyContext",
23
- "start_driver",
23
+ "start_grid",
24
24
  ]