flwr-nightly 1.17.0.dev20250320__py3-none-any.whl → 1.17.0.dev20250322__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/run/run.py +5 -9
- flwr/client/client_app.py +10 -12
- flwr/client/grpc_client/connection.py +3 -3
- flwr/client/message_handler/message_handler.py +3 -3
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +26 -26
- flwr/common/__init__.py +10 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +1 -1
- flwr/common/record/__init__.py +6 -3
- flwr/common/record/{parametersrecord.py → arrayrecord.py} +74 -31
- flwr/common/record/{configsrecord.py → configrecord.py} +73 -27
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +77 -31
- flwr/common/record/recorddict.py +95 -56
- flwr/common/recorddict_compat.py +54 -62
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +42 -43
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +30 -30
- flwr/proto/exec_pb2.pyi +2 -2
- flwr/proto/recorddict_pb2.py +29 -29
- flwr/proto/recorddict_pb2.pyi +33 -33
- flwr/proto/run_pb2.py +2 -2
- flwr/proto/run_pb2.pyi +2 -2
- flwr/server/compat/grid_client_proxy.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +4 -4
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +7 -7
- flwr/server/superlink/linkstate/utils.py +9 -9
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/workflow/default_workflows.py +27 -34
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +32 -34
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/run_simulation.py +2 -2
- flwr/superexec/deployment.py +3 -3
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +2 -2
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/RECORD +49 -49
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250320.dist-info → flwr_nightly-1.17.0.dev20250322.dist-info}/entry_points.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""
|
15
|
+
"""ArrayRecord and Array."""
|
16
16
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
@@ -22,11 +22,13 @@ import sys
|
|
22
22
|
from collections import OrderedDict
|
23
23
|
from dataclasses import dataclass
|
24
24
|
from io import BytesIO
|
25
|
+
from logging import WARN
|
25
26
|
from typing import TYPE_CHECKING, Any, cast, overload
|
26
27
|
|
27
28
|
import numpy as np
|
28
29
|
|
29
30
|
from ..constant import GC_THRESHOLD, SType
|
31
|
+
from ..logger import log
|
30
32
|
from ..typing import NDArray
|
31
33
|
from .typeddict import TypedDict
|
32
34
|
|
@@ -42,9 +44,9 @@ def _raise_array_init_error() -> None:
|
|
42
44
|
)
|
43
45
|
|
44
46
|
|
45
|
-
def
|
47
|
+
def _raise_array_record_init_error() -> None:
|
46
48
|
raise TypeError(
|
47
|
-
f"Invalid arguments for {
|
49
|
+
f"Invalid arguments for {ArrayRecord.__qualname__}. Expected either "
|
48
50
|
"a list of NumPy ndarrays, a PyTorch state_dict, or a dictionary of Arrays. "
|
49
51
|
"The `keep_input` argument is keyword-only."
|
50
52
|
)
|
@@ -274,13 +276,14 @@ def _check_value(value: Array) -> None:
|
|
274
276
|
)
|
275
277
|
|
276
278
|
|
277
|
-
class
|
278
|
-
|
279
|
+
class ArrayRecord(TypedDict[str, Array]):
|
280
|
+
"""Array record.
|
279
281
|
|
280
|
-
A typed dictionary (``str`` to :class:`Array`) that can store named
|
281
|
-
|
282
|
-
``OrderedDict[str, Array]``.
|
283
|
-
equivalent to PyTorch's ``state_dict``,
|
282
|
+
A typed dictionary (``str`` to :class:`Array`) that can store named arrays,
|
283
|
+
including model parameters, gradients, embeddings or non-parameter arrays.
|
284
|
+
Internally, this behaves similarly to an ``OrderedDict[str, Array]``.
|
285
|
+
An ``ArrayRecord`` can be viewed as an equivalent to PyTorch's ``state_dict``,
|
286
|
+
but it holds arrays in a serialized form.
|
284
287
|
|
285
288
|
This object is one of the record types supported by :class:`RecordDict` and can
|
286
289
|
therefore be stored in the ``content`` of a :class:`Message` or the ``state``
|
@@ -312,33 +315,33 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
312
315
|
|
313
316
|
Examples
|
314
317
|
--------
|
315
|
-
Initializing an empty
|
318
|
+
Initializing an empty ArrayRecord:
|
316
319
|
|
317
|
-
>>>
|
320
|
+
>>> record = ArrayRecord()
|
318
321
|
|
319
322
|
Initializing with a dictionary of :class:`Array`:
|
320
323
|
|
321
324
|
>>> arr = Array("float32", [5, 5], "numpy.ndarray", b"serialized_data...")
|
322
|
-
>>>
|
325
|
+
>>> record = ArrayRecord({"weight": arr})
|
323
326
|
|
324
327
|
Initializing with a list of NumPy arrays:
|
325
328
|
|
326
329
|
>>> import numpy as np
|
327
330
|
>>> arr1 = np.random.randn(3, 3)
|
328
331
|
>>> arr2 = np.random.randn(2, 2)
|
329
|
-
>>>
|
332
|
+
>>> record = ArrayRecord([arr1, arr2])
|
330
333
|
|
331
334
|
Initializing with a PyTorch model state_dict:
|
332
335
|
|
333
336
|
>>> import torch.nn as nn
|
334
337
|
>>> model = nn.Linear(10, 5)
|
335
|
-
>>>
|
338
|
+
>>> record = ArrayRecord(model.state_dict())
|
336
339
|
|
337
340
|
Initializing with a TensorFlow model weights (a list of NumPy arrays):
|
338
341
|
|
339
342
|
>>> import tensorflow as tf
|
340
343
|
>>> model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(10,))])
|
341
|
-
>>>
|
344
|
+
>>> record = ArrayRecord(model.get_weights())
|
342
345
|
"""
|
343
346
|
|
344
347
|
@overload
|
@@ -380,7 +383,7 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
380
383
|
|
381
384
|
# Init the argument
|
382
385
|
if len(args) > 1:
|
383
|
-
|
386
|
+
_raise_array_record_init_error()
|
384
387
|
arg = args[0] if args else None
|
385
388
|
init_method: str | None = None # Track which init method is being used
|
386
389
|
|
@@ -393,10 +396,10 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
393
396
|
nonlocal arg, init_method
|
394
397
|
# Raise an error if arg is already set
|
395
398
|
if arg is not None:
|
396
|
-
|
399
|
+
_raise_array_record_init_error()
|
397
400
|
# Raise an error if a different initialization method is already set
|
398
401
|
if init_method is not None:
|
399
|
-
|
402
|
+
_raise_array_record_init_error()
|
400
403
|
# Set init_method and arg
|
401
404
|
if init_method is None:
|
402
405
|
init_method = method
|
@@ -454,7 +457,7 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
454
457
|
self.__dict__.update(converted.__dict__)
|
455
458
|
return
|
456
459
|
|
457
|
-
|
460
|
+
_raise_array_record_init_error()
|
458
461
|
|
459
462
|
@classmethod
|
460
463
|
def from_array_dict(
|
@@ -462,9 +465,9 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
462
465
|
array_dict: OrderedDict[str, Array],
|
463
466
|
*,
|
464
467
|
keep_input: bool = True,
|
465
|
-
) ->
|
466
|
-
"""Create
|
467
|
-
record =
|
468
|
+
) -> ArrayRecord:
|
469
|
+
"""Create ArrayRecord from a dictionary of :class:`Array`."""
|
470
|
+
record = ArrayRecord()
|
468
471
|
for k, v in array_dict.items():
|
469
472
|
record[k] = Array(
|
470
473
|
dtype=v.dtype, shape=list(v.shape), stype=v.stype, data=v.data
|
@@ -479,9 +482,9 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
479
482
|
ndarrays: list[NDArray],
|
480
483
|
*,
|
481
484
|
keep_input: bool = True,
|
482
|
-
) ->
|
483
|
-
"""Create
|
484
|
-
record =
|
485
|
+
) -> ArrayRecord:
|
486
|
+
"""Create ArrayRecord from a list of NumPy ``ndarray``."""
|
487
|
+
record = ArrayRecord()
|
485
488
|
total_serialized_bytes = 0
|
486
489
|
|
487
490
|
for i in range(len(ndarrays)): # pylint: disable=C0200
|
@@ -509,14 +512,14 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
509
512
|
state_dict: OrderedDict[str, torch.Tensor],
|
510
513
|
*,
|
511
514
|
keep_input: bool = True,
|
512
|
-
) ->
|
513
|
-
"""Create
|
515
|
+
) -> ArrayRecord:
|
516
|
+
"""Create ArrayRecord from PyTorch ``state_dict``."""
|
514
517
|
if "torch" not in sys.modules:
|
515
518
|
raise RuntimeError(
|
516
519
|
f"PyTorch is required to use {cls.from_torch_state_dict.__name__}"
|
517
520
|
)
|
518
521
|
|
519
|
-
record =
|
522
|
+
record = ArrayRecord()
|
520
523
|
|
521
524
|
for k in list(state_dict.keys()):
|
522
525
|
v = state_dict[k] if keep_input else state_dict.pop(k)
|
@@ -525,7 +528,7 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
525
528
|
return record
|
526
529
|
|
527
530
|
def to_numpy_ndarrays(self, *, keep_input: bool = True) -> list[NDArray]:
|
528
|
-
"""Return the
|
531
|
+
"""Return the ArrayRecord as a list of NumPy ``ndarray``."""
|
529
532
|
if keep_input:
|
530
533
|
return [v.numpy() for v in self.values()]
|
531
534
|
|
@@ -551,7 +554,7 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
551
554
|
def to_torch_state_dict(
|
552
555
|
self, *, keep_input: bool = True
|
553
556
|
) -> OrderedDict[str, torch.Tensor]:
|
554
|
-
"""Return the
|
557
|
+
"""Return the ArrayRecord as a PyTorch ``state_dict``."""
|
555
558
|
if not (torch := sys.modules.get("torch")):
|
556
559
|
raise RuntimeError(
|
557
560
|
f"PyTorch is required to use {self.to_torch_state_dict.__name__}"
|
@@ -581,3 +584,43 @@ class ParametersRecord(TypedDict[str, Array]):
|
|
581
584
|
num_bytes += len(k)
|
582
585
|
|
583
586
|
return num_bytes
|
587
|
+
|
588
|
+
|
589
|
+
class ParametersRecord(ArrayRecord):
|
590
|
+
"""Deprecated class ``ParametersRecord``, use ``ArrayRecord`` instead.
|
591
|
+
|
592
|
+
This class exists solely for backward compatibility with legacy
|
593
|
+
code that previously used ``ParametersRecord``. It has been renamed
|
594
|
+
to ``ArrayRecord``.
|
595
|
+
|
596
|
+
.. warning::
|
597
|
+
``ParametersRecord`` is deprecated and will be removed in a future release.
|
598
|
+
Use ``ArrayRecord`` instead.
|
599
|
+
|
600
|
+
Examples
|
601
|
+
--------
|
602
|
+
Legacy (deprecated) usage::
|
603
|
+
|
604
|
+
from flwr.common import ParametersRecord
|
605
|
+
|
606
|
+
record = ParametersRecord()
|
607
|
+
|
608
|
+
Updated usage::
|
609
|
+
|
610
|
+
from flwr.common import ArrayRecord
|
611
|
+
|
612
|
+
record = ArrayRecord()
|
613
|
+
"""
|
614
|
+
|
615
|
+
_warning_logged = False
|
616
|
+
|
617
|
+
def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None:
|
618
|
+
if not ParametersRecord._warning_logged:
|
619
|
+
ParametersRecord._warning_logged = True
|
620
|
+
log(
|
621
|
+
WARN,
|
622
|
+
"The `ParametersRecord` class has been renamed to `ArrayRecord`. "
|
623
|
+
"Support for `ParametersRecord` will be removed in a future release. "
|
624
|
+
"Please update your code accordingly.",
|
625
|
+
)
|
626
|
+
super().__init__(*args, **kwargs)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,15 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""
|
15
|
+
"""ConfigRecord."""
|
16
16
|
|
17
17
|
|
18
|
+
from logging import WARN
|
18
19
|
from typing import Optional, get_args
|
19
20
|
|
20
|
-
from flwr.common.typing import
|
21
|
+
from flwr.common.typing import ConfigRecordValues, ConfigScalar
|
21
22
|
|
23
|
+
from ..logger import log
|
22
24
|
from .typeddict import TypedDict
|
23
25
|
|
24
26
|
|
@@ -28,20 +30,20 @@ def _check_key(key: str) -> None:
|
|
28
30
|
raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.")
|
29
31
|
|
30
32
|
|
31
|
-
def _check_value(value:
|
32
|
-
def is_valid(__v:
|
33
|
+
def _check_value(value: ConfigRecordValues) -> None:
|
34
|
+
def is_valid(__v: ConfigScalar) -> None:
|
33
35
|
"""Check if value is of expected type."""
|
34
|
-
if not isinstance(__v, get_args(
|
36
|
+
if not isinstance(__v, get_args(ConfigScalar)):
|
35
37
|
raise TypeError(
|
36
38
|
"Not all values are of valid type."
|
37
|
-
f" Expected `{
|
39
|
+
f" Expected `{ConfigRecordValues}` but `{type(__v)}` was passed."
|
38
40
|
)
|
39
41
|
|
40
42
|
if isinstance(value, list):
|
41
43
|
# If your lists are large (e.g. 1M+ elements) this will be slow
|
42
44
|
# 1s to check 10M element list on a M2 Pro
|
43
45
|
# In such settings, you'd be better of treating such config as
|
44
|
-
# an array and pass it to a
|
46
|
+
# an array and pass it to a ArrayRecord.
|
45
47
|
# Empty lists are valid
|
46
48
|
if len(value) > 0:
|
47
49
|
is_valid(value[0])
|
@@ -51,24 +53,24 @@ def _check_value(value: ConfigsRecordValues) -> None:
|
|
51
53
|
if not all(isinstance(v, value_type) for v in value):
|
52
54
|
raise TypeError(
|
53
55
|
"All values in a list must be of the same valid type. "
|
54
|
-
f"One of {
|
56
|
+
f"One of {ConfigScalar}."
|
55
57
|
)
|
56
58
|
else:
|
57
59
|
is_valid(value)
|
58
60
|
|
59
61
|
|
60
|
-
class
|
62
|
+
class ConfigRecord(TypedDict[str, ConfigRecordValues]):
|
61
63
|
"""Configs record.
|
62
64
|
|
63
|
-
A :code:`
|
64
|
-
each key-value pair adheres to specified data types. A :code:`
|
65
|
+
A :code:`ConfigRecord` is a Python dictionary designed to ensure that
|
66
|
+
each key-value pair adheres to specified data types. A :code:`ConfigRecord`
|
65
67
|
is one of the types of records that a
|
66
68
|
`flwr.common.RecordDict <flwr.common.RecordDict.html#recorddict>`_ supports and
|
67
69
|
can therefore be used to construct :code:`common.Message` objects.
|
68
70
|
|
69
71
|
Parameters
|
70
72
|
----------
|
71
|
-
|
73
|
+
config_dict : Optional[Dict[str, ConfigRecordValues]]
|
72
74
|
A dictionary that stores basic types (i.e. `str`, `int`, `float`, `bytes` as
|
73
75
|
defined in `ConfigsScalar`) and lists of such types (see
|
74
76
|
`ConfigsScalarList`).
|
@@ -80,20 +82,20 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
|
|
80
82
|
|
81
83
|
Examples
|
82
84
|
--------
|
83
|
-
The usage of a :code:`
|
85
|
+
The usage of a :code:`ConfigRecord` is envisioned for sending configuration values
|
84
86
|
telling the target node how to perform a certain action (e.g. train/evaluate a model
|
85
87
|
). You can use standard Python built-in types such as :code:`float`, :code:`str`
|
86
88
|
, :code:`bytes`. All types allowed are defined in
|
87
|
-
:code:`flwr.common.
|
88
|
-
encourage you to use a :code:`
|
89
|
+
:code:`flwr.common.ConfigRecordValues`. While lists are supported, we
|
90
|
+
encourage you to use a :code:`ArrayRecord` instead if these are of high
|
89
91
|
dimensionality.
|
90
92
|
|
91
|
-
Let's see some examples of how to construct a :code:`
|
93
|
+
Let's see some examples of how to construct a :code:`ConfigRecord` from scratch:
|
92
94
|
|
93
|
-
>>> from flwr.common import
|
95
|
+
>>> from flwr.common import ConfigRecord
|
94
96
|
>>>
|
95
|
-
>>> # A `
|
96
|
-
>>> record =
|
97
|
+
>>> # A `ConfigRecord` is a specialized Python dictionary
|
98
|
+
>>> record = ConfigRecord({"lr": 0.1, "batch-size": 128})
|
97
99
|
>>> # You can add more content to an existing record
|
98
100
|
>>> record["compute-average"] = True
|
99
101
|
>>> # It also supports lists
|
@@ -104,21 +106,21 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
|
|
104
106
|
Just like the other types of records in a :code:`flwr.common.RecordDict`, types are
|
105
107
|
enforced. If you need to add a custom data structure or object, we recommend to
|
106
108
|
serialise it into bytes and save it as such (bytes are allowed in a
|
107
|
-
:code:`
|
109
|
+
:code:`ConfigRecord`)
|
108
110
|
"""
|
109
111
|
|
110
112
|
def __init__(
|
111
113
|
self,
|
112
|
-
|
114
|
+
config_dict: Optional[dict[str, ConfigRecordValues]] = None,
|
113
115
|
keep_input: bool = True,
|
114
116
|
) -> None:
|
115
117
|
|
116
118
|
super().__init__(_check_key, _check_value)
|
117
|
-
if
|
118
|
-
for k in list(
|
119
|
-
self[k] =
|
119
|
+
if config_dict:
|
120
|
+
for k in list(config_dict.keys()):
|
121
|
+
self[k] = config_dict[k]
|
120
122
|
if not keep_input:
|
121
|
-
del
|
123
|
+
del config_dict[k]
|
122
124
|
|
123
125
|
def count_bytes(self) -> int:
|
124
126
|
"""Return number of Bytes stored in this object.
|
@@ -126,7 +128,7 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
|
|
126
128
|
This function counts booleans as occupying 1 Byte.
|
127
129
|
"""
|
128
130
|
|
129
|
-
def get_var_bytes(value:
|
131
|
+
def get_var_bytes(value: ConfigScalar) -> int:
|
130
132
|
"""Return Bytes of value passed."""
|
131
133
|
var_bytes = 0
|
132
134
|
if isinstance(value, bool):
|
@@ -161,3 +163,47 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]):
|
|
161
163
|
num_bytes += len(k)
|
162
164
|
|
163
165
|
return num_bytes
|
166
|
+
|
167
|
+
|
168
|
+
class ConfigsRecord(ConfigRecord):
|
169
|
+
"""Deprecated class ``ConfigsRecord``, use ``ConfigRecord`` instead.
|
170
|
+
|
171
|
+
This class exists solely for backward compatibility with legacy
|
172
|
+
code that previously used ``ConfigsRecord``. It has been renamed
|
173
|
+
to ``ConfigRecord``.
|
174
|
+
|
175
|
+
.. warning::
|
176
|
+
``ConfigsRecord`` is deprecated and will be removed in a future release.
|
177
|
+
Use ``ConfigRecord`` instead.
|
178
|
+
|
179
|
+
Examples
|
180
|
+
--------
|
181
|
+
Legacy (deprecated) usage::
|
182
|
+
|
183
|
+
from flwr.common import ConfigsRecord
|
184
|
+
|
185
|
+
record = ConfigsRecord()
|
186
|
+
|
187
|
+
Updated usage::
|
188
|
+
|
189
|
+
from flwr.common import ConfigRecord
|
190
|
+
|
191
|
+
record = ConfigRecord()
|
192
|
+
"""
|
193
|
+
|
194
|
+
_warning_logged = False
|
195
|
+
|
196
|
+
def __init__(
|
197
|
+
self,
|
198
|
+
config_dict: Optional[dict[str, ConfigRecordValues]] = None,
|
199
|
+
keep_input: bool = True,
|
200
|
+
):
|
201
|
+
if not ConfigsRecord._warning_logged:
|
202
|
+
ConfigsRecord._warning_logged = True
|
203
|
+
log(
|
204
|
+
WARN,
|
205
|
+
"The `ConfigsRecord` class has been renamed to `ConfigRecord`. "
|
206
|
+
"Support for `ConfigsRecord` will be removed in a future release. "
|
207
|
+
"Please update your code accordingly.",
|
208
|
+
)
|
209
|
+
super().__init__(config_dict, keep_input)
|
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
from ..logger import warn_deprecated_feature
|
19
19
|
from ..typing import NDArray
|
20
|
-
from .
|
20
|
+
from .arrayrecord import Array
|
21
21
|
|
22
22
|
WARN_DEPRECATED_MESSAGE = (
|
23
23
|
"`array_from_numpy` is deprecated. Instead, use the `Array(ndarray)` class "
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,15 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""
|
15
|
+
"""MetricRecord."""
|
16
16
|
|
17
17
|
|
18
|
+
from logging import WARN
|
18
19
|
from typing import Optional, get_args
|
19
20
|
|
20
|
-
from flwr.common.typing import
|
21
|
+
from flwr.common.typing import MetricRecordValues, MetricScalar
|
21
22
|
|
23
|
+
from ..logger import log
|
22
24
|
from .typeddict import TypedDict
|
23
25
|
|
24
26
|
|
@@ -28,20 +30,20 @@ def _check_key(key: str) -> None:
|
|
28
30
|
raise TypeError(f"Key must be of type `str` but `{type(key)}` was passed.")
|
29
31
|
|
30
32
|
|
31
|
-
def _check_value(value:
|
32
|
-
def is_valid(__v:
|
33
|
+
def _check_value(value: MetricRecordValues) -> None:
|
34
|
+
def is_valid(__v: MetricScalar) -> None:
|
33
35
|
"""Check if value is of expected type."""
|
34
|
-
if not isinstance(__v, get_args(
|
36
|
+
if not isinstance(__v, get_args(MetricScalar)) or isinstance(__v, bool):
|
35
37
|
raise TypeError(
|
36
38
|
"Not all values are of valid type."
|
37
|
-
f" Expected `{
|
39
|
+
f" Expected `{MetricRecordValues}` but `{type(__v)}` was passed."
|
38
40
|
)
|
39
41
|
|
40
42
|
if isinstance(value, list):
|
41
43
|
# If your lists are large (e.g. 1M+ elements) this will be slow
|
42
44
|
# 1s to check 10M element list on a M2 Pro
|
43
45
|
# In such settings, you'd be better of treating such metric as
|
44
|
-
# an array and pass it to
|
46
|
+
# an array and pass it to an ArrayRecord.
|
45
47
|
# Empty lists are valid
|
46
48
|
if len(value) > 0:
|
47
49
|
is_valid(value[0])
|
@@ -51,26 +53,26 @@ def _check_value(value: MetricsRecordValues) -> None:
|
|
51
53
|
if not all(isinstance(v, value_type) for v in value):
|
52
54
|
raise TypeError(
|
53
55
|
"All values in a list must be of the same valid type. "
|
54
|
-
f"One of {
|
56
|
+
f"One of {MetricScalar}."
|
55
57
|
)
|
56
58
|
else:
|
57
59
|
is_valid(value)
|
58
60
|
|
59
61
|
|
60
|
-
class
|
62
|
+
class MetricRecord(TypedDict[str, MetricRecordValues]):
|
61
63
|
"""Metrics recod.
|
62
64
|
|
63
|
-
A :code:`
|
64
|
-
each key-value pair adheres to specified data types. A :code:`
|
65
|
+
A :code:`MetricRecord` is a Python dictionary designed to ensure that
|
66
|
+
each key-value pair adheres to specified data types. A :code:`MetricRecord`
|
65
67
|
is one of the types of records that a
|
66
68
|
`flwr.common.RecordDict <flwr.common.RecordDict.html#recorddict>`_ supports and
|
67
69
|
can therefore be used to construct :code:`common.Message` objects.
|
68
70
|
|
69
71
|
Parameters
|
70
72
|
----------
|
71
|
-
|
73
|
+
metric_dict : Optional[Dict[str, MetricRecordValues]]
|
72
74
|
A dictionary that stores basic types (i.e. `int`, `float` as defined
|
73
|
-
in `
|
75
|
+
in `MetricScalar`) and list of such types (see `MetricScalarList`).
|
74
76
|
keep_input : bool (default: True)
|
75
77
|
A boolean indicating whether metrics should be deleted from the input
|
76
78
|
dictionary immediately after adding them to the record. When set
|
@@ -79,7 +81,7 @@ class MetricsRecord(TypedDict[str, MetricsRecordValues]):
|
|
79
81
|
|
80
82
|
Examples
|
81
83
|
--------
|
82
|
-
The usage of a :code:`
|
84
|
+
The usage of a :code:`MetricRecord` is envisioned for communicating results
|
83
85
|
obtained when a node performs an action. A few typical examples include:
|
84
86
|
communicating the training accuracy after a model is trained locally by a
|
85
87
|
:code:`ClientApp`, reporting the validation loss obtained at a :code:`ClientApp`,
|
@@ -87,43 +89,43 @@ class MetricsRecord(TypedDict[str, MetricsRecordValues]):
|
|
87
89
|
Common to these examples is that the output can be typically represented by
|
88
90
|
a single scalar (:code:`int`, :code:`float`) or list of scalars.
|
89
91
|
|
90
|
-
Let's see some examples of how to construct a :code:`
|
92
|
+
Let's see some examples of how to construct a :code:`MetricRecord` from scratch:
|
91
93
|
|
92
|
-
>>> from flwr.common import
|
94
|
+
>>> from flwr.common import MetricRecord
|
93
95
|
>>>
|
94
|
-
>>> # A `
|
95
|
-
>>> record =
|
96
|
+
>>> # A `MetricRecord` is a specialized Python dictionary
|
97
|
+
>>> record = MetricRecord({"accuracy": 0.94})
|
96
98
|
>>> # You can add more content to an existing record
|
97
99
|
>>> record["loss"] = 0.01
|
98
100
|
>>> # It also supports lists
|
99
101
|
>>> record["loss-historic"] = [0.9, 0.5, 0.01]
|
100
102
|
|
101
103
|
Since types are enforced, the types of the objects inserted are checked. For a
|
102
|
-
:code:`
|
103
|
-
:code:`flwr.common.
|
104
|
+
:code:`MetricRecord`, value types allowed are those in defined in
|
105
|
+
:code:`flwr.common.MetricRecordValues`. Similarly, only :code:`str` keys are
|
104
106
|
allowed.
|
105
107
|
|
106
|
-
>>> from flwr.common import
|
108
|
+
>>> from flwr.common import MetricRecord
|
107
109
|
>>>
|
108
|
-
>>> record =
|
110
|
+
>>> record = MetricRecord() # an empty record
|
109
111
|
>>> # Add unsupported value
|
110
112
|
>>> record["something-unsupported"] = {'a': 123} # Will throw a `TypeError`
|
111
113
|
|
112
|
-
If you need a more versatily type of record try :code:`
|
113
|
-
:code:`
|
114
|
+
If you need a more versatily type of record try :code:`ConfigRecord` or
|
115
|
+
:code:`ArrayRecord`.
|
114
116
|
"""
|
115
117
|
|
116
118
|
def __init__(
|
117
119
|
self,
|
118
|
-
|
120
|
+
metric_dict: Optional[dict[str, MetricRecordValues]] = None,
|
119
121
|
keep_input: bool = True,
|
120
|
-
):
|
122
|
+
) -> None:
|
121
123
|
super().__init__(_check_key, _check_value)
|
122
|
-
if
|
123
|
-
for k in list(
|
124
|
-
self[k] =
|
124
|
+
if metric_dict:
|
125
|
+
for k in list(metric_dict.keys()):
|
126
|
+
self[k] = metric_dict[k]
|
125
127
|
if not keep_input:
|
126
|
-
del
|
128
|
+
del metric_dict[k]
|
127
129
|
|
128
130
|
def count_bytes(self) -> int:
|
129
131
|
"""Return number of Bytes stored in this object."""
|
@@ -140,3 +142,47 @@ class MetricsRecord(TypedDict[str, MetricsRecordValues]):
|
|
140
142
|
# We also count the bytes footprint of the keys
|
141
143
|
num_bytes += len(k)
|
142
144
|
return num_bytes
|
145
|
+
|
146
|
+
|
147
|
+
class MetricsRecord(MetricRecord):
|
148
|
+
"""Deprecated class ``MetricsRecord``, use ``MetricRecord`` instead.
|
149
|
+
|
150
|
+
This class exists solely for backward compatibility with legacy
|
151
|
+
code that previously used ``MetricsRecord``. It has been renamed
|
152
|
+
to ``MetricRecord``.
|
153
|
+
|
154
|
+
.. warning::
|
155
|
+
``MetricsRecord`` is deprecated and will be removed in a future release.
|
156
|
+
Use ``MetricRecord`` instead.
|
157
|
+
|
158
|
+
Examples
|
159
|
+
--------
|
160
|
+
Legacy (deprecated) usage::
|
161
|
+
|
162
|
+
from flwr.common import MetricsRecord
|
163
|
+
|
164
|
+
record = MetricsRecord()
|
165
|
+
|
166
|
+
Updated usage::
|
167
|
+
|
168
|
+
from flwr.common import MetricRecord
|
169
|
+
|
170
|
+
record = MetricRecord()
|
171
|
+
"""
|
172
|
+
|
173
|
+
_warning_logged = False
|
174
|
+
|
175
|
+
def __init__(
|
176
|
+
self,
|
177
|
+
metric_dict: Optional[dict[str, MetricRecordValues]] = None,
|
178
|
+
keep_input: bool = True,
|
179
|
+
):
|
180
|
+
if not MetricsRecord._warning_logged:
|
181
|
+
MetricsRecord._warning_logged = True
|
182
|
+
log(
|
183
|
+
WARN,
|
184
|
+
"The `MetricsRecord` class has been renamed to `MetricRecord`. "
|
185
|
+
"Support for `MetricsRecord` will be removed in a future release. "
|
186
|
+
"Please update your code accordingly.",
|
187
|
+
)
|
188
|
+
super().__init__(metric_dict, keep_input)
|