flwr-nightly 1.21.0.dev20250903__py3-none-any.whl → 1.21.0.dev20250905__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,14 +19,20 @@ from .dp_fixed_clipping import (
19
19
  DifferentialPrivacyClientSideFixedClipping,
20
20
  DifferentialPrivacyServerSideFixedClipping,
21
21
  )
22
+ from .fedadagrad import FedAdagrad
23
+ from .fedadam import FedAdam
22
24
  from .fedavg import FedAvg
25
+ from .fedyogi import FedYogi
23
26
  from .result import Result
24
27
  from .strategy import Strategy
25
28
 
26
29
  __all__ = [
27
30
  "DifferentialPrivacyClientSideFixedClipping",
28
31
  "DifferentialPrivacyServerSideFixedClipping",
32
+ "FedAdagrad",
33
+ "FedAdam",
29
34
  "FedAvg",
35
+ "FedYogi",
30
36
  "Result",
31
37
  "Strategy",
32
38
  ]
@@ -0,0 +1,162 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """FedAdagrad [Reddi et al., 2020] strategy.
16
+
17
+ Adaptive Federated Optimization using Adagrad.
18
+
19
+ Paper: arxiv.org/abs/2003.00295
20
+ """
21
+
22
+ from collections import OrderedDict
23
+ from collections.abc import Iterable
24
+ from typing import Callable, Optional
25
+
26
+ import numpy as np
27
+
28
+ from flwr.common import Array, ArrayRecord, Message, MetricRecord, RecordDict
29
+
30
+ from .fedopt import FedOpt
31
+ from .strategy_utils import AggregationError
32
+
33
+
34
+ # pylint: disable=line-too-long
35
class FedAdagrad(FedOpt):
    """FedAdagrad strategy - Adaptive Federated Optimization using Adagrad.

    Implementation based on https://arxiv.org/abs/2003.00295v5

    Each round, the aggregated arrays are compared with the arrays sent out in
    `configure_train`; their difference ``delta_t`` acts as a pseudo-gradient
    for an Adagrad-style server update. Momentum is disabled: ``beta_1`` and
    ``beta_2`` are fixed to 0.0 when calling the base class, so ``m_t`` equals
    ``delta_t``.

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for both ArrayRecords and MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    eta : float, optional
        Server-side learning rate. Defaults to 1e-1.
    eta_l : float, optional
        Client-side learning rate. Defaults to 1e-1. It is stored on the
        strategy but not used by this class's server-side update.
    tau : float, optional
        Controls the algorithm's degree of adaptability. Defaults to 1e-3.
        It is added to ``sqrt(v_t)`` in the denominator of the update.
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        *,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        eta: float = 1e-1,
        eta_l: float = 1e-1,
        tau: float = 1e-3,
    ) -> None:
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
            eta=eta,
            eta_l=eta_l,
            # FedAdagrad uses no momentum and no second-moment decay
            beta_1=0.0,
            beta_2=0.0,
            tau=tau,
        )

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        The replies are first aggregated by the parent strategy. The resulting
        arrays are then used to take one Adagrad step, starting from the
        arrays that were sent out in `configure_train`.

        Parameters
        ----------
        server_round : int
            The current server round.
        replies : Iterable[Message]
            Replies received for this training round.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            The updated global arrays (``None`` if the parent strategy
            produced no aggregated arrays) and the aggregated training metrics.

        Raises
        ------
        AggregationError
            If `configure_train` has not been called before aggregation, or if
            the aggregated arrays do not match the stored arrays.
        """
        aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
            server_round, replies
        )

        # Nothing to update if the parent strategy could not aggregate arrays
        if aggregated_arrayrecord is None:
            return aggregated_arrayrecord, aggregated_metrics

        # The server step is relative to the arrays sent in `configure_train`
        if self.current_arrays is None:
            reason = (
                "Current arrays not set. Ensure that `configure_train` has been "
                "called before aggregation."
            )
            raise AggregationError(reason=reason)

        # Compute intermediate variables: pseudo-gradient delta_t and first
        # moment m_t (equal to delta_t here, because beta_1 == 0.0)
        delta_t, m_t, aggregated_ndarrays = self._compute_deltat_and_mt(
            aggregated_arrayrecord
        )

        # v_t: Adagrad accumulates the sum of squared pseudo-gradients across
        # rounds (zero-initialized on first use)
        if not self.v_t:
            self.v_t = {k: np.zeros_like(v) for k, v in aggregated_ndarrays.items()}
        self.v_t = {k: v + (delta_t[k] ** 2) for k, v in self.v_t.items()}

        # Adaptive server step: x + eta * m_t / (sqrt(v_t) + tau)
        new_arrays = {
            k: x + self.eta * m_t[k] / (np.sqrt(self.v_t[k]) + self.tau)
            for k, x in self.current_arrays.items()
        }

        # Keep the strategy's copy of the global arrays in sync with the result
        self.current_arrays = new_arrays

        return (
            ArrayRecord(OrderedDict({k: Array(v) for k, v in new_arrays.items()})),
            aggregated_metrics,
        )
@@ -0,0 +1,181 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Adaptive Federated Optimization using Adam (FedAdam) strategy.
16
+
17
+ [Reddi et al., 2020]
18
+
19
+ Paper: arxiv.org/abs/2003.00295
20
+ """
21
+
22
+ from collections import OrderedDict
23
+ from collections.abc import Iterable
24
+ from typing import Callable, Optional
25
+
26
+ import numpy as np
27
+
28
+ from flwr.common import Array, ArrayRecord, Message, MetricRecord, RecordDict
29
+
30
+ from .fedopt import FedOpt
31
+ from .strategy_utils import AggregationError
32
+
33
+
34
+ # pylint: disable=line-too-long
35
class FedAdam(FedOpt):
    """FedAdam - Adaptive Federated Optimization using Adam.

    Implementation based on https://arxiv.org/abs/2003.00295v5

    Each round, the aggregated arrays are compared with the arrays sent out in
    `configure_train`; their difference ``delta_t`` acts as a pseudo-gradient
    for an Adam-style server update with a bias-corrected step size.

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for both ArrayRecords and MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    eta : float, optional
        Server-side learning rate. Defaults to 1e-1.
    eta_l : float, optional
        Client-side learning rate. Defaults to 1e-1. It is stored on the
        strategy but not used by this class's server-side update.
    beta_1 : float, optional
        Momentum parameter. Defaults to 0.9.
    beta_2 : float, optional
        Second moment parameter. Defaults to 0.99.
    tau : float, optional
        Controls the algorithm's degree of adaptability. Defaults to 1e-3.
        It is added to ``sqrt(v_t)`` in the denominator of the update.
    """

    # pylint: disable=too-many-arguments, too-many-locals
    def __init__(
        self,
        *,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        eta: float = 1e-1,
        eta_l: float = 1e-1,
        beta_1: float = 0.9,
        beta_2: float = 0.99,
        tau: float = 1e-3,
    ) -> None:
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
            eta=eta,
            eta_l=eta_l,
            beta_1=beta_1,
            beta_2=beta_2,
            tau=tau,
        )

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        The replies are first aggregated by the parent strategy. The resulting
        arrays are then used to take one Adam step, starting from the arrays
        that were sent out in `configure_train`.

        Parameters
        ----------
        server_round : int
            The current server round. It is also used as the step count in
            the bias correction of the learning rate.
        replies : Iterable[Message]
            Replies received for this training round.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            The updated global arrays (``None`` if the parent strategy
            produced no aggregated arrays) and the aggregated training metrics.

        Raises
        ------
        AggregationError
            If `configure_train` has not been called before aggregation, or if
            the aggregated arrays do not match the stored arrays.
        """
        aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
            server_round, replies
        )

        # Nothing to update if the parent strategy could not aggregate arrays
        if aggregated_arrayrecord is None:
            return aggregated_arrayrecord, aggregated_metrics

        # The server step is relative to the arrays sent in `configure_train`
        if self.current_arrays is None:
            reason = (
                "Current arrays not set. Ensure that `configure_train` has been "
                "called before aggregation."
            )
            raise AggregationError(reason=reason)

        # Compute intermediate variables: pseudo-gradient delta_t and its
        # exponential moving average m_t
        delta_t, m_t, aggregated_ndarrays = self._compute_deltat_and_mt(
            aggregated_arrayrecord
        )

        # v_t: exponential moving average of the squared pseudo-gradient
        # (zero-initialized on first use)
        if not self.v_t:
            self.v_t = {k: np.zeros_like(v) for k, v in aggregated_ndarrays.items()}
        self.v_t = {
            k: self.beta_2 * v + (1 - self.beta_2) * (delta_t[k] ** 2)
            for k, v in self.v_t.items()
        }

        # Compute the bias-corrected learning rate, `eta_norm` for improving convergence
        # in the early rounds of FL training. This `eta_norm` is `\alpha_t` in Kingma &
        # Ba, 2014 (http://arxiv.org/abs/1412.6980) "Adam: A Method for Stochastic
        # Optimization" in the formula line right before Section 2.1.
        # NOTE(review): the step count used here is `server_round + 1`. If rounds
        # start at 1, the first update uses t=2 rather than t=1 as in Kingma & Ba.
        # Left unchanged here; confirm intent.
        eta_norm = (
            self.eta
            * np.sqrt(1 - np.power(self.beta_2, server_round + 1.0))
            / (1 - np.power(self.beta_1, server_round + 1.0))
        )

        # Adaptive server step: x + eta_norm * m_t / (sqrt(v_t) + tau)
        new_arrays = {
            k: x + eta_norm * m_t[k] / (np.sqrt(self.v_t[k]) + self.tau)
            for k, x in self.current_arrays.items()
        }

        # Keep the strategy's copy of the global arrays in sync with the result
        self.current_arrays = new_arrays

        return (
            ArrayRecord(OrderedDict({k: Array(v) for k, v in new_arrays.items()})),
            aggregated_metrics,
        )
@@ -0,0 +1,218 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Adaptive Federated Optimization (FedOpt) [Reddi et al., 2020] abstract strategy.
16
+
17
+ Paper: arxiv.org/abs/2003.00295
18
+ """
19
+
20
+ from collections.abc import Iterable
21
+ from logging import INFO
22
+ from typing import Callable, Optional
23
+
24
+ import numpy as np
25
+
26
+ from flwr.common import (
27
+ ArrayRecord,
28
+ ConfigRecord,
29
+ Message,
30
+ MetricRecord,
31
+ NDArray,
32
+ RecordDict,
33
+ log,
34
+ )
35
+ from flwr.server import Grid
36
+
37
+ from .fedavg import FedAvg
38
+ from .strategy_utils import AggregationError
39
+
40
+
41
+ # pylint: disable=line-too-long
42
class FedOpt(FedAvg):
    """Federated Optim strategy.

    Base class for the adaptive federated optimization strategies FedAdagrad,
    FedAdam and FedYogi. It performs two jobs:

    - It records the arrays sent to nodes in `configure_train`.
    - It provides `_compute_deltat_and_mt`, the shared computation of the
      pseudo-gradient ``delta_t`` and its momentum ``m_t``.

    Subclasses implement the second-moment (``v_t``) update and the
    server-side step in `aggregate_train`. FedOpt itself does not override
    `aggregate_train`.

    Implementation based on https://arxiv.org/abs/2003.00295v5

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for both ArrayRecords and MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    eta : float, optional
        Server-side learning rate. Defaults to 1e-1.
    eta_l : float, optional
        Client-side learning rate. Defaults to 1e-1. This class only stores it
        and reports it in `summary()`. It is not added to the training config.
    beta_1 : float, optional
        Momentum parameter. Defaults to 0.0.
    beta_2 : float, optional
        Second moment parameter. Defaults to 0.0.
    tau : float, optional
        Controls the algorithm's degree of adaptability. Defaults to 1e-3.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals, line-too-long
    def __init__(
        self,
        *,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        eta: float = 1e-1,
        eta_l: float = 1e-1,
        beta_1: float = 0.0,
        beta_2: float = 0.0,
        tau: float = 1e-3,
    ) -> None:
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
        )
        # Arrays sent to nodes in the latest `configure_train` call. They are
        # the reference point for `delta_t` during aggregation.
        self.current_arrays: Optional[dict[str, NDArray]] = None
        self.eta = eta
        self.eta_l = eta_l
        self.tau = tau
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        # First (m_t) and second (v_t) moment estimates, keyed like the arrays.
        # Both are zero-initialized lazily at the first aggregation; v_t is
        # maintained by subclasses.
        self.m_t: Optional[dict[str, NDArray]] = None
        self.v_t: Optional[dict[str, NDArray]] = None

    def summary(self) -> None:
        """Log the FedOpt hyperparameters, then the parent strategy's summary."""
        log(INFO, "\t├──> FedOpt settings:")
        log(
            INFO,
            "\t│\t├── eta (%s) | eta_l (%s)",
            f"{self.eta:.6g}",
            f"{self.eta_l:.6g}",
        )
        log(
            INFO,
            "\t│\t├── beta_1 (%s) | beta_2 (%s)",
            f"{self.beta_1:.6g}",
            f"{self.beta_2:.6g}",
        )
        log(
            INFO,
            "\t│\t└── tau (%s)",
            f"{self.tau:.6g}",
        )
        super().summary()

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training.

        Stores a NumPy copy of `arrays` in `self.current_arrays`, so that the
        aggregation step can compute ``delta_t`` against exactly what was
        sent. Message construction is delegated to the parent strategy.
        """
        # Keep track of array record being communicated
        self.current_arrays = {k: array.numpy() for k, array in arrays.items()}
        return super().configure_train(server_round, arrays, config, grid)

    def _compute_deltat_and_mt(
        self, aggregated_arrayrecord: ArrayRecord
    ) -> tuple[dict[str, NDArray], dict[str, NDArray], dict[str, NDArray]]:
        """Compute delta_t and m_t.

        This is a shared stage during aggregation for FedAdagrad, FedAdam and
        FedYogi.

        ``delta_t = aggregated_arrays - current_arrays`` is the pseudo-gradient.
        ``m_t = beta_1 * m_prev + (1 - beta_1) * delta_t`` is its exponential
        moving average, and ``self.m_t`` is replaced by this new value.

        Parameters
        ----------
        aggregated_arrayrecord : ArrayRecord
            Arrays produced by the parent aggregation for this round.

        Returns
        -------
        tuple[dict[str, NDArray], dict[str, NDArray], dict[str, NDArray]]
            ``(delta_t, m_t, aggregated_ndarrays)``, each keyed like the arrays.

        Raises
        ------
        AggregationError
            If `configure_train` has not been called yet, or if the keys or
            shapes of the aggregated arrays differ from the stored arrays.
        """
        if self.current_arrays is None:
            reason = (
                "Current arrays not set. Ensure that `configure_train` has been "
                "called before aggregation."
            )
            raise AggregationError(reason=reason)

        aggregated_ndarrays = {
            k: array.numpy() for k, array in aggregated_arrayrecord.items()
        }

        # Check keys in aggregated arrays match those in current arrays
        if set(aggregated_ndarrays.keys()) != set(self.current_arrays.keys()):
            reason = (
                "Keys of the aggregated arrays do not match those of the arrays "
                "stored at the strategy. `delta_t = aggregated_arrays - "
                "current_arrays` cannot be computed."
            )
            raise AggregationError(reason=reason)

        # Check that the shape of values match
        # Only shapes that match can compute delta_t (we don't want
        # broadcasting to happen)
        for k, x in aggregated_ndarrays.items():
            if x.shape != self.current_arrays[k].shape:
                reason = (
                    f"Shape of aggregated array '{k}' does not match "
                    "shape of the array under the same key stored in the strategy. "
                    "Cannot compute `delta_t`."
                )
                raise AggregationError(reason=reason)

        delta_t = {
            k: x - self.current_arrays[k] for k, x in aggregated_ndarrays.items()
        }

        # m_t (zero-initialized on first use)
        if not self.m_t:
            self.m_t = {k: np.zeros_like(v) for k, v in aggregated_ndarrays.items()}
        self.m_t = {
            k: self.beta_1 * v + (1 - self.beta_1) * delta_t[k]
            for k, v in self.m_t.items()
        }

        return delta_t, self.m_t, aggregated_ndarrays
@@ -0,0 +1,173 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Adaptive Federated Optimization using Yogi (FedYogi) [Reddi et al., 2020] strategy.
16
+
17
+ Paper: arxiv.org/abs/2003.00295
18
+ """
19
+
20
+
21
+ from collections import OrderedDict
22
+ from collections.abc import Iterable
23
+ from typing import Callable, Optional
24
+
25
+ import numpy as np
26
+
27
+ from flwr.common import Array, ArrayRecord, Message, MetricRecord, RecordDict
28
+
29
+ from .fedopt import FedOpt
30
+ from .strategy_utils import AggregationError
31
+
32
+
33
+ # pylint: disable=line-too-long
34
class FedYogi(FedOpt):
    """FedYogi [Reddi et al., 2020] strategy.

    Implementation based on https://arxiv.org/abs/2003.00295v5

    Each round, the aggregated arrays are compared with the arrays sent out in
    `configure_train`; their difference ``delta_t`` acts as a pseudo-gradient
    for a Yogi-style server update.

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for both ArrayRecords and MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    eta : float, optional
        Server-side learning rate. Defaults to 1e-2.
    eta_l : float, optional
        Client-side learning rate. Defaults to 0.0316. It is stored on the
        strategy but not used by this class's server-side update.
    beta_1 : float, optional
        Momentum parameter. Defaults to 0.9.
    beta_2 : float, optional
        Second moment parameter. Defaults to 0.99.
    tau : float, optional
        Controls the algorithm's degree of adaptability.
        Defaults to 1e-3. It is added to ``sqrt(v_t)`` in the denominator of
        the update.
    """

    # pylint: disable=too-many-arguments, too-many-locals
    def __init__(
        self,
        *,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        eta: float = 1e-2,
        eta_l: float = 0.0316,
        beta_1: float = 0.9,
        beta_2: float = 0.99,
        tau: float = 1e-3,
    ) -> None:
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
            eta=eta,
            eta_l=eta_l,
            beta_1=beta_1,
            beta_2=beta_2,
            tau=tau,
        )

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        The replies are first aggregated by the parent strategy. The resulting
        arrays are then used to take one Yogi step, starting from the arrays
        that were sent out in `configure_train`.

        Parameters
        ----------
        server_round : int
            The current server round.
        replies : Iterable[Message]
            Replies received for this training round.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            The updated global arrays (``None`` if the parent strategy
            produced no aggregated arrays) and the aggregated training metrics.

        Raises
        ------
        AggregationError
            If `configure_train` has not been called before aggregation, or if
            the aggregated arrays do not match the stored arrays.
        """
        aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
            server_round, replies
        )

        # Nothing to update if the parent strategy could not aggregate arrays
        if aggregated_arrayrecord is None:
            return aggregated_arrayrecord, aggregated_metrics

        # The server step is relative to the arrays sent in `configure_train`
        if self.current_arrays is None:
            reason = (
                "Current arrays not set. Ensure that `configure_train` has been "
                "called before aggregation."
            )
            raise AggregationError(reason=reason)

        # Compute intermediate variables: pseudo-gradient delta_t and its
        # exponential moving average m_t
        delta_t, m_t, aggregated_ndarrays = self._compute_deltat_and_mt(
            aggregated_arrayrecord
        )

        # v_t: Yogi's additive update (zero-initialized on first use). v moves
        # toward delta_t**2 by a step of (1 - beta_2) * delta_t**2, in the
        # direction of sign(delta_t**2 - v), instead of Adam's multiplicative
        # decay.
        if not self.v_t:
            self.v_t = {k: np.zeros_like(v) for k, v in aggregated_ndarrays.items()}
        self.v_t = {
            k: v
            - (1.0 - self.beta_2) * (delta_t[k] ** 2) * np.sign(v - delta_t[k] ** 2)
            for k, v in self.v_t.items()
        }

        # Adaptive server step: x + eta * m_t / (sqrt(v_t) + tau)
        new_arrays = {
            k: x + self.eta * m_t[k] / (np.sqrt(self.v_t[k]) + self.tau)
            for k, x in self.current_arrays.items()
        }

        # Keep the strategy's copy of the global arrays in sync with the result
        self.current_arrays = new_arrays

        return (
            ArrayRecord(OrderedDict({k: Array(v) for k, v in new_arrays.items()})),
            aggregated_metrics,
        )
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Strategy results."""
16
16
 
17
+
17
18
  import pprint
18
19
  from dataclasses import dataclass, field
19
20
 
@@ -61,19 +62,19 @@ class Result:
61
62
  arr_size = sum(len(array.data) for array in self.arrays.values()) / (1024**2)
62
63
  rep += "Global Arrays:\n" + f"\tArrayRecord ({arr_size:.3f} MB)\n" + "\n"
63
64
  rep += (
64
- "Aggregated Client-side Train Metrics:\n"
65
+ "Aggregated ClientApp-side Train Metrics:\n"
65
66
  + pprint.pformat(stringify_dict(self.train_metrics_clientapp), indent=2)
66
67
  + "\n\n"
67
68
  )
68
69
 
69
70
  rep += (
70
- "Aggregated Client-side Evaluate Metrics:\n"
71
+ "Aggregated ClientApp-side Evaluate Metrics:\n"
71
72
  + pprint.pformat(stringify_dict(self.evaluate_metrics_clientapp), indent=2)
72
73
  + "\n\n"
73
74
  )
74
75
 
75
76
  rep += (
76
- "Server-side Evaluate Metrics:\n"
77
+ "ServerApp-side Evaluate Metrics:\n"
77
78
  + pprint.pformat(stringify_dict(self.evaluate_metrics_serverapp), indent=2)
78
79
  + "\n"
79
80
  )
@@ -140,7 +140,9 @@ class Strategy(ABC):
140
140
  timeout: float = 3600,
141
141
  train_config: Optional[ConfigRecord] = None,
142
142
  evaluate_config: Optional[ConfigRecord] = None,
143
- evaluate_fn: Optional[Callable[[int, ArrayRecord], MetricRecord]] = None,
143
+ evaluate_fn: Optional[
144
+ Callable[[int, ArrayRecord], Optional[MetricRecord]]
145
+ ] = None,
144
146
  ) -> Result:
145
147
  """Execute the federated learning strategy.
146
148
 
@@ -164,11 +166,11 @@ class Strategy(ABC):
164
166
  evaluate_config : ConfigRecord, optional
165
167
  Configuration to be sent to nodes during evaluation rounds.
166
168
  If unset, an empty ConfigRecord will be used.
167
- evaluate_fn : Callable[[int, ArrayRecord], MetricRecord], optional
169
+ evaluate_fn : Callable[[int, ArrayRecord], Optional[MetricRecord]], optional
168
170
  Optional function for centralized evaluation of the global model. Takes
169
- server round number and array record, returns a MetricRecord. If provided,
170
- will be called before the first round and after each round. Defaults to
171
- None.
171
+ server round number and array record, returns a MetricRecord or None. If
172
+ provided, will be called before the first round and after each round.
173
+ Defaults to None.
172
174
 
173
175
  Returns
174
176
  -------
@@ -193,7 +195,8 @@ class Strategy(ABC):
193
195
  if evaluate_fn:
194
196
  res = evaluate_fn(0, initial_arrays)
195
197
  log(INFO, "Initial global evaluation results: %s", res)
196
- result.evaluate_metrics_serverapp[0] = res
198
+ if res is not None:
199
+ result.evaluate_metrics_serverapp[0] = res
197
200
 
198
201
  arrays = initial_arrays
199
202
 
@@ -202,7 +205,7 @@ class Strategy(ABC):
202
205
  log(INFO, "[ROUND %s/%s]", current_round, num_rounds)
203
206
 
204
207
  # -----------------------------------------------------------------
205
- # --- TRAINING ----------------------------------------------------
208
+ # --- TRAINING (CLIENTAPP-SIDE) -----------------------------------
206
209
  # -----------------------------------------------------------------
207
210
 
208
211
  # Call strategy to configure training round
@@ -232,7 +235,7 @@ class Strategy(ABC):
232
235
  result.train_metrics_clientapp[current_round] = agg_train_metrics
233
236
 
234
237
  # -----------------------------------------------------------------
235
- # --- EVALUATION (LOCAL) ------------------------------------------
238
+ # --- EVALUATION (CLIENTAPP-SIDE) ---------------------------------
236
239
  # -----------------------------------------------------------------
237
240
 
238
241
  # Call strategy to configure evaluation round
@@ -259,7 +262,7 @@ class Strategy(ABC):
259
262
  result.evaluate_metrics_clientapp[current_round] = agg_evaluate_metrics
260
263
 
261
264
  # -----------------------------------------------------------------
262
- # --- EVALUATION (GLOBAL) -----------------------------------------
265
+ # --- EVALUATION (SERVERAPP-SIDE) ---------------------------------
263
266
  # -----------------------------------------------------------------
264
267
 
265
268
  # Centralized evaluation
@@ -267,7 +270,8 @@ class Strategy(ABC):
267
270
  log(INFO, "Global evaluation")
268
271
  res = evaluate_fn(current_round, arrays)
269
272
  log(INFO, "\t└──> MetricRecord: %s", res)
270
- result.evaluate_metrics_serverapp[current_round] = res
273
+ if res is not None:
274
+ result.evaluate_metrics_serverapp[current_round] = res
271
275
 
272
276
  log(INFO, "")
273
277
  log(INFO, "Strategy execution finished in %.2fs", time.time() - t_start)
@@ -45,6 +45,15 @@ class InconsistentMessageReplies(AppExitException):
45
45
  super().__init__(reason)
46
46
 
47
47
 
48
+ class AggregationError(AppExitException):
49
+ """Exception triggered when aggregation fails."""
50
+
51
+ exit_code = ExitCode.SERVERAPP_STRATEGY_AGGREGATION_ERROR
52
+
53
+ def __init__(self, reason: str):
54
+ super().__init__(reason)
55
+
56
+
48
57
  def config_to_str(config: ConfigRecord) -> str:
49
58
  """Convert a ConfigRecord to a string representation masking bytes."""
50
59
  content = ", ".join(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: flwr-nightly
3
- Version: 1.21.0.dev20250903
3
+ Version: 1.21.0.dev20250905
4
4
  Summary: Flower: A Friendly Federated AI Framework
5
5
  License: Apache-2.0
6
6
  Keywords: Artificial Intelligence,Federated AI,Federated Analytics,Federated Evaluation,Federated Learning,Flower,Machine Learning
@@ -330,12 +330,16 @@ flwr/server/workflow/secure_aggregation/secagg_workflow.py,sha256=b_pKk7gmbahwyj
330
330
  flwr/server/workflow/secure_aggregation/secaggplus_workflow.py,sha256=DkayCsnlAya6Y2PZsueLgoUCMRtV-GbnW08RfWx_SXM,29460
331
331
  flwr/serverapp/__init__.py,sha256=dUGPpyO0YEJRIjwNw2YrUWXgsEj9JOUrP5OGm8bPX9k,774
332
332
  flwr/serverapp/dp_fixed_clipping.py,sha256=wbP4W7CaUHXdll8ZSVUnTBSEWrnWM00CGk63rOR-Q2s,12133
333
- flwr/serverapp/strategy/__init__.py,sha256=FpZN4AafpTSxW65dAPJ0zekHo9bU84tV4uhO4XVHJTc,1088
333
+ flwr/serverapp/strategy/__init__.py,sha256=yAYBZUkp4aNmcTLsvormEc9HyO34oEoFN45LiHgujE0,1229
334
334
  flwr/serverapp/strategy/dp_fixed_clipping.py,sha256=wbP4W7CaUHXdll8ZSVUnTBSEWrnWM00CGk63rOR-Q2s,12133
335
+ flwr/serverapp/strategy/fedadagrad.py,sha256=talxGzeGSIIkLPtZk4i_qXZTksRoBeajrmeUbnGHTUY,6347
336
+ flwr/serverapp/strategy/fedadam.py,sha256=TGLnxoJro758DUc9tAxBks9tSRtXDqy-4vWthiqscIo,7142
335
337
  flwr/serverapp/strategy/fedavg.py,sha256=C8UUvLTjodMpGRb4PNej5gW2cPbXsPKebGX1zPfAMUo,11020
336
- flwr/serverapp/strategy/result.py,sha256=rcg3NqZyC-A_x-f40fq8fk9ovL1zcSz9Jxpr32MRRIc,4285
337
- flwr/serverapp/strategy/strategy.py,sha256=9udL2q1zEpVw-rKDMoZG_fwoklF4t1HC9hrnPaYiEhA,10663
338
- flwr/serverapp/strategy/strategy_utils.py,sha256=rtcBQwFtWAihNdcWEAHdAqScPlRZSwqbkGjGxaWkmLE,9547
338
+ flwr/serverapp/strategy/fedopt.py,sha256=aIN5CtgsE88bAodN_M_pf_01a2vMIj9R_7CYwd8VeMU,8481
339
+ flwr/serverapp/strategy/fedyogi.py,sha256=UQnEKVTpJiB_zbCREfI8CEHiuJMIRmEIu5DV50FG_5s,6657
340
+ flwr/serverapp/strategy/result.py,sha256=E0Hl2VLnZAgQJjE2GDoKsK7JX-kPPU2KXc47Axt6hGw,4295
341
+ flwr/serverapp/strategy/strategy.py,sha256=8uJGGm1ROLZERQ_dkRS7Z_rs-yK6XCE0UxXtIdFiEWk,10789
342
+ flwr/serverapp/strategy/strategy_utils.py,sha256=C8vU8JqKhMylq102x5jjQITzv_X2Khfo-uXkPTpnHms,9779
339
343
  flwr/serverapp/strategy/strategy_utils_tests.py,sha256=taG6HwApwutkjUuMY3R8Ib48Xepw6g5xl9HEB_-leoY,9232
340
344
  flwr/simulation/__init__.py,sha256=Gg6OsP1Z-ixc3-xxzvl7j7rz2Fijy9rzyEPpxgAQCeM,1556
341
345
  flwr/simulation/app.py,sha256=LbGLMvN9Ap119yBqsUcNNmVLRnCySnr4VechqcQ1hpA,10401
@@ -397,7 +401,7 @@ flwr/supernode/servicer/__init__.py,sha256=lucTzre5WPK7G1YLCfaqg3rbFWdNSb7ZTt-ca
397
401
  flwr/supernode/servicer/clientappio/__init__.py,sha256=7Oy62Y_oijqF7Dxi6tpcUQyOpLc_QpIRZ83NvwmB0Yg,813
398
402
  flwr/supernode/servicer/clientappio/clientappio_servicer.py,sha256=nIHRu38EWK-rpNOkcgBRAAKwYQQWFeCwu0lkO7OPZGQ,10239
399
403
  flwr/supernode/start_client_internal.py,sha256=Y9S1-QlO2WP6eo4JvWzIpfaCoh2aoE7bjEYyxNNnlyg,20777
400
- flwr_nightly-1.21.0.dev20250903.dist-info/METADATA,sha256=x-9LvDwIejGasw9rlnPTi3qrHp3Z2hgABMTMj-_iU3k,15967
401
- flwr_nightly-1.21.0.dev20250903.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
402
- flwr_nightly-1.21.0.dev20250903.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
403
- flwr_nightly-1.21.0.dev20250903.dist-info/RECORD,,
404
+ flwr_nightly-1.21.0.dev20250905.dist-info/METADATA,sha256=j_7UpChhN5B6nRj4Gelt99YjECPo5fbZV8k6K3B86Go,15967
405
+ flwr_nightly-1.21.0.dev20250905.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
406
+ flwr_nightly-1.21.0.dev20250905.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
407
+ flwr_nightly-1.21.0.dev20250905.dist-info/RECORD,,