flwr 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that registry.
Files changed (167)
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  5. flwr/cli/build.py +15 -5
  6. flwr/cli/cli_user_auth_interceptor.py +1 -1
  7. flwr/cli/config_utils.py +3 -3
  8. flwr/cli/constant.py +25 -8
  9. flwr/cli/log.py +9 -9
  10. flwr/cli/login/login.py +3 -3
  11. flwr/cli/ls.py +5 -5
  12. flwr/cli/new/new.py +23 -4
  13. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  14. flwr/cli/new/templates/app/README.md.tpl +5 -0
  15. flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
  16. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
  17. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
  18. flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
  19. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -3
  20. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  21. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  22. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  23. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  24. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  25. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  26. flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
  27. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  28. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  29. flwr/cli/run/run.py +53 -50
  30. flwr/cli/stop.py +7 -4
  31. flwr/cli/utils.py +29 -11
  32. flwr/client/grpc_adapter_client/connection.py +11 -4
  33. flwr/client/grpc_rere_client/connection.py +93 -129
  34. flwr/client/rest_client/connection.py +134 -164
  35. flwr/clientapp/__init__.py +10 -0
  36. flwr/clientapp/mod/__init__.py +26 -0
  37. flwr/clientapp/mod/centraldp_mods.py +132 -0
  38. flwr/common/args.py +20 -6
  39. flwr/common/auth_plugin/__init__.py +4 -4
  40. flwr/common/auth_plugin/auth_plugin.py +7 -7
  41. flwr/common/constant.py +26 -5
  42. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  43. flwr/common/exit/__init__.py +4 -0
  44. flwr/common/exit/exit.py +8 -1
  45. flwr/common/exit/exit_code.py +42 -8
  46. flwr/common/exit/exit_handler.py +62 -0
  47. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  48. flwr/common/grpc.py +1 -1
  49. flwr/common/{inflatable_grpc_utils.py → inflatable_protobuf_utils.py} +52 -10
  50. flwr/common/inflatable_utils.py +191 -24
  51. flwr/common/logger.py +1 -1
  52. flwr/common/record/array.py +101 -22
  53. flwr/common/record/arraychunk.py +59 -0
  54. flwr/common/retry_invoker.py +30 -11
  55. flwr/common/serde.py +0 -28
  56. flwr/common/telemetry.py +4 -0
  57. flwr/compat/client/app.py +14 -31
  58. flwr/compat/server/app.py +2 -2
  59. flwr/proto/appio_pb2.py +51 -0
  60. flwr/proto/appio_pb2.pyi +195 -0
  61. flwr/proto/appio_pb2_grpc.py +4 -0
  62. flwr/proto/appio_pb2_grpc.pyi +4 -0
  63. flwr/proto/clientappio_pb2.py +4 -19
  64. flwr/proto/clientappio_pb2.pyi +0 -125
  65. flwr/proto/clientappio_pb2_grpc.py +269 -29
  66. flwr/proto/clientappio_pb2_grpc.pyi +114 -21
  67. flwr/proto/control_pb2.py +62 -0
  68. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
  69. flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
  70. flwr/proto/fleet_pb2.py +12 -20
  71. flwr/proto/fleet_pb2.pyi +6 -36
  72. flwr/proto/serverappio_pb2.py +8 -31
  73. flwr/proto/serverappio_pb2.pyi +0 -152
  74. flwr/proto/serverappio_pb2_grpc.py +107 -38
  75. flwr/proto/serverappio_pb2_grpc.pyi +47 -20
  76. flwr/proto/simulationio_pb2.py +4 -11
  77. flwr/proto/simulationio_pb2.pyi +0 -58
  78. flwr/proto/simulationio_pb2_grpc.py +129 -27
  79. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  80. flwr/server/app.py +130 -153
  81. flwr/server/fleet_event_log_interceptor.py +4 -0
  82. flwr/server/grid/grpc_grid.py +94 -54
  83. flwr/server/grid/inmemory_grid.py +1 -0
  84. flwr/server/serverapp/app.py +165 -144
  85. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +8 -0
  86. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  87. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  88. flwr/server/superlink/fleet/message_handler/message_handler.py +10 -16
  89. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -2
  90. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  91. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  92. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  93. flwr/server/superlink/linkstate/linkstate.py +2 -1
  94. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  95. flwr/server/superlink/serverappio/serverappio_grpc.py +2 -2
  96. flwr/server/superlink/serverappio/serverappio_servicer.py +95 -48
  97. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  98. flwr/server/superlink/simulation/simulationio_servicer.py +98 -22
  99. flwr/server/superlink/utils.py +0 -35
  100. flwr/serverapp/__init__.py +12 -0
  101. flwr/serverapp/dp_fixed_clipping.py +352 -0
  102. flwr/serverapp/exception.py +38 -0
  103. flwr/serverapp/strategy/__init__.py +38 -0
  104. flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
  105. flwr/serverapp/strategy/fedadagrad.py +162 -0
  106. flwr/serverapp/strategy/fedadam.py +181 -0
  107. flwr/serverapp/strategy/fedavg.py +295 -0
  108. flwr/serverapp/strategy/fedopt.py +218 -0
  109. flwr/serverapp/strategy/fedyogi.py +173 -0
  110. flwr/serverapp/strategy/result.py +105 -0
  111. flwr/serverapp/strategy/strategy.py +285 -0
  112. flwr/serverapp/strategy/strategy_utils.py +251 -0
  113. flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
  114. flwr/simulation/app.py +159 -154
  115. flwr/simulation/run_simulation.py +17 -0
  116. flwr/supercore/app_utils.py +58 -0
  117. flwr/supercore/cli/__init__.py +22 -0
  118. flwr/supercore/cli/flower_superexec.py +141 -0
  119. flwr/supercore/corestate/__init__.py +22 -0
  120. flwr/supercore/corestate/corestate.py +81 -0
  121. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  122. flwr/supercore/grpc_health/__init__.py +25 -0
  123. flwr/supercore/grpc_health/health_server.py +53 -0
  124. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  125. flwr/supercore/license_plugin/__init__.py +22 -0
  126. flwr/supercore/license_plugin/license_plugin.py +26 -0
  127. flwr/supercore/object_store/in_memory_object_store.py +31 -31
  128. flwr/supercore/object_store/object_store.py +20 -42
  129. flwr/supercore/object_store/utils.py +43 -0
  130. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  131. flwr/supercore/superexec/plugin/__init__.py +28 -0
  132. flwr/supercore/superexec/plugin/base_exec_plugin.py +53 -0
  133. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  134. flwr/supercore/superexec/plugin/exec_plugin.py +71 -0
  135. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  136. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  137. flwr/supercore/superexec/run_superexec.py +185 -0
  138. flwr/supercore/utils.py +32 -0
  139. flwr/superlink/servicer/__init__.py +15 -0
  140. flwr/superlink/servicer/control/__init__.py +22 -0
  141. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +9 -5
  142. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +39 -28
  143. flwr/superlink/servicer/control/control_license_interceptor.py +82 -0
  144. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +79 -31
  145. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +18 -10
  146. flwr/supernode/cli/flower_supernode.py +3 -7
  147. flwr/supernode/cli/flwr_clientapp.py +20 -16
  148. flwr/supernode/nodestate/in_memory_nodestate.py +13 -4
  149. flwr/supernode/nodestate/nodestate.py +3 -44
  150. flwr/supernode/runtime/run_clientapp.py +129 -115
  151. flwr/supernode/servicer/clientappio/__init__.py +1 -3
  152. flwr/supernode/servicer/clientappio/clientappio_servicer.py +217 -165
  153. flwr/supernode/start_client_internal.py +205 -148
  154. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/METADATA +5 -3
  155. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/RECORD +161 -117
  156. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
  157. flwr/common/inflatable_rest_utils.py +0 -99
  158. flwr/proto/exec_pb2.py +0 -62
  159. flwr/superexec/app.py +0 -45
  160. flwr/superexec/deployment.py +0 -192
  161. flwr/superexec/executor.py +0 -100
  162. flwr/superexec/simulation.py +0 -130
  163. /flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
  164. /flwr/{server/superlink → supercore}/ffs/__init__.py +0 -0
  165. /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  166. /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  167. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
flwr/serverapp/strategy/fedavg.py (new file, +295)
@@ -0,0 +1,295 @@
+# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower message-based FedAvg strategy."""
+
+
+from collections.abc import Iterable
+from logging import INFO
+from typing import Callable, Optional
+
+from flwr.common import (
+    ArrayRecord,
+    ConfigRecord,
+    Message,
+    MessageType,
+    MetricRecord,
+    RecordDict,
+    log,
+)
+from flwr.server import Grid
+
+from .strategy import Strategy
+from .strategy_utils import (
+    aggregate_arrayrecords,
+    aggregate_metricrecords,
+    sample_nodes,
+    validate_message_reply_consistency,
+)
+
+
+# pylint: disable=too-many-instance-attributes
+class FedAvg(Strategy):
+    """Federated Averaging strategy.
+
+    Implementation based on https://arxiv.org/abs/1602.05629
+
+    Parameters
+    ----------
+    fraction_train : float (default: 1.0)
+        Fraction of nodes used during training. In case `min_train_nodes`
+        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
+        will still be sampled.
+    fraction_evaluate : float (default: 1.0)
+        Fraction of nodes used during validation. In case `min_evaluate_nodes`
+        is larger than `fraction_evaluate * total_connected_nodes`,
+        `min_evaluate_nodes` will still be sampled.
+    min_train_nodes : int (default: 2)
+        Minimum number of nodes used during training.
+    min_evaluate_nodes : int (default: 2)
+        Minimum number of nodes used during validation.
+    min_available_nodes : int (default: 2)
+        Minimum number of total nodes in the system.
+    weighted_by_key : str (default: "num-examples")
+        The key within each MetricRecord whose value is used as the weight when
+        computing weighted averages for both ArrayRecords and MetricRecords.
+    arrayrecord_key : str (default: "arrays")
+        Key used to store the ArrayRecord when constructing Messages.
+    configrecord_key : str (default: "config")
+        Key used to store the ConfigRecord when constructing Messages.
+    train_metrics_aggr_fn : Optional[callable] (default: None)
+        Function with signature (list[RecordDict], str) -> MetricRecord,
+        used to aggregate MetricRecords from training round replies.
+        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+        average using the provided weight factor key.
+    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
+        Function with signature (list[RecordDict], str) -> MetricRecord,
+        used to aggregate MetricRecords from training round replies.
+        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+        average using the provided weight factor key.
+    """
+
+    # pylint: disable=too-many-arguments,too-many-positional-arguments
+    def __init__(
+        self,
+        fraction_train: float = 1.0,
+        fraction_evaluate: float = 1.0,
+        min_train_nodes: int = 2,
+        min_evaluate_nodes: int = 2,
+        min_available_nodes: int = 2,
+        weighted_by_key: str = "num-examples",
+        arrayrecord_key: str = "arrays",
+        configrecord_key: str = "config",
+        train_metrics_aggr_fn: Optional[
+            Callable[[list[RecordDict], str], MetricRecord]
+        ] = None,
+        evaluate_metrics_aggr_fn: Optional[
+            Callable[[list[RecordDict], str], MetricRecord]
+        ] = None,
+    ) -> None:
+        self.fraction_train = fraction_train
+        self.fraction_evaluate = fraction_evaluate
+        self.min_train_nodes = min_train_nodes
+        self.min_evaluate_nodes = min_evaluate_nodes
+        self.min_available_nodes = min_available_nodes
+        self.weighted_by_key = weighted_by_key
+        self.arrayrecord_key = arrayrecord_key
+        self.configrecord_key = configrecord_key
+        self.train_metrics_aggr_fn = train_metrics_aggr_fn or aggregate_metricrecords
+        self.evaluate_metrics_aggr_fn = (
+            evaluate_metrics_aggr_fn or aggregate_metricrecords
+        )
+
+    def summary(self) -> None:
+        """Log summary configuration of the strategy."""
+        log(INFO, "\t├──> Sampling:")
+        log(
+            INFO,
+            "\t│\t├──Fraction: train (%.2f) | evaluate ( %.2f)",
+            self.fraction_train,
+            self.fraction_evaluate,
+        )  # pylint: disable=line-too-long
+        log(
+            INFO,
+            "\t│\t├──Minimum nodes: train (%d) | evaluate (%d)",
+            self.min_train_nodes,
+            self.min_evaluate_nodes,
+        )  # pylint: disable=line-too-long
+        log(INFO, "\t│\t└──Minimum available nodes: %d", self.min_available_nodes)
+        log(INFO, "\t└──> Keys in records:")
+        log(INFO, "\t\t├── Weighted by: '%s'", self.weighted_by_key)
+        log(INFO, "\t\t├── ArrayRecord key: '%s'", self.arrayrecord_key)
+        log(INFO, "\t\t└── ConfigRecord key: '%s'", self.configrecord_key)
+
+    def _construct_messages(
+        self, record: RecordDict, node_ids: list[int], message_type: str
+    ) -> Iterable[Message]:
+        """Construct N Messages carrying the same RecordDict payload."""
+        messages = []
+        for node_id in node_ids:  # one message for each node
+            message = Message(
+                content=record,
+                message_type=message_type,
+                dst_node_id=node_id,
+            )
+            messages.append(message)
+        return messages
+
+    def configure_train(
+        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
+    ) -> Iterable[Message]:
+        """Configure the next round of federated training."""
+        # Sample nodes
+        num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_train)
+        sample_size = max(num_nodes, self.min_train_nodes)
+        node_ids, num_total = sample_nodes(grid, self.min_available_nodes, sample_size)
+        log(
+            INFO,
+            "configure_train: Sampled %s nodes (out of %s)",
+            len(node_ids),
+            len(num_total),
+        )
+        # Always inject current server round
+        config["server-round"] = server_round
+
+        # Construct messages
+        record = RecordDict(
+            {self.arrayrecord_key: arrays, self.configrecord_key: config}
+        )
+        return self._construct_messages(record, node_ids, MessageType.TRAIN)
+
+    def aggregate_train(
+        self,
+        server_round: int,
+        replies: Iterable[Message],
+    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
+        """Aggregate ArrayRecords and MetricRecords in the received Messages."""
+        if not replies:
+            return None, None
+
+        # Log if any Messages carried errors
+        # Filter messages that carry content
+        num_errors = 0
+        replies_with_content = []
+        for msg in replies:
+            if msg.has_error():
+                log(
+                    INFO,
+                    "Received error in reply from node %d: %s",
+                    msg.metadata.src_node_id,
+                    msg.error,
+                )
+                num_errors += 1
+            else:
+                replies_with_content.append(msg.content)
+
+        log(
+            INFO,
+            "aggregate_train: Received %s results and %s failures",
+            len(replies_with_content),
+            num_errors,
+        )
+
+        # Ensure expected ArrayRecords and MetricRecords are received
+        validate_message_reply_consistency(
+            replies=replies_with_content,
+            weighted_by_key=self.weighted_by_key,
+            check_arrayrecord=True,
+        )
+
+        arrays, metrics = None, None
+        if replies_with_content:
+            # Aggregate ArrayRecords
+            arrays = aggregate_arrayrecords(
+                replies_with_content,
+                self.weighted_by_key,
+            )
+
+            # Aggregate MetricRecords
+            metrics = self.train_metrics_aggr_fn(
+                replies_with_content,
+                self.weighted_by_key,
+            )
+        return arrays, metrics
+
+    def configure_evaluate(
+        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
+    ) -> Iterable[Message]:
+        """Configure the next round of federated evaluation."""
+        # Sample nodes
+        num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_evaluate)
+        sample_size = max(num_nodes, self.min_evaluate_nodes)
+        node_ids, num_total = sample_nodes(grid, self.min_available_nodes, sample_size)
+        log(
+            INFO,
+            "configure_evaluate: Sampled %s nodes (out of %s)",
+            len(node_ids),
+            len(num_total),
+        )
+
+        # Always inject current server round
+        config["server-round"] = server_round
+
+        # Construct messages
+        record = RecordDict(
+            {self.arrayrecord_key: arrays, self.configrecord_key: config}
+        )
+        return self._construct_messages(record, node_ids, MessageType.EVALUATE)
+
+    def aggregate_evaluate(
+        self,
+        server_round: int,
+        replies: Iterable[Message],
+    ) -> Optional[MetricRecord]:
+        """Aggregate MetricRecords in the received Messages."""
+        if not replies:
+            return None
+
+        # Log if any Messages carried errors
+        # Filter messages that carry content
+        num_errors = 0
+        replies_with_content = []
+        for msg in replies:
+            if msg.has_error():
+                log(
+                    INFO,
+                    "Received error in reply from node %d: %s",
+                    msg.metadata.src_node_id,
+                    msg.error,
+                )
+                num_errors += 1
+            else:
+                replies_with_content.append(msg.content)
+
+        log(
+            INFO,
+            "aggregate_evaluate: Received %s results and %s failures",
+            len(replies_with_content),
+            num_errors,
+        )
+
+        # Ensure expected ArrayRecords and MetricRecords are received
+        validate_message_reply_consistency(
+            replies=replies_with_content,
+            weighted_by_key=self.weighted_by_key,
+            check_arrayrecord=False,
+        )
+        metrics = None
+        if replies_with_content:
+            # Aggregate MetricRecords
+            metrics = self.evaluate_metrics_aggr_fn(
+                replies_with_content,
+                self.weighted_by_key,
+            )
+        return metrics
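
The hunk above adds the new message-based FedAvg strategy. As a rough illustration of how it is meant to be driven, the following sketch (editorial, not part of the diff) wires it into a ServerApp main loop. The `flwr.serverapp.strategy` import path follows the file list above; the `@app.main()` signature, `grid.send_and_receive`, and the empty `ArrayRecord()` placeholder are assumptions based on existing Flower APIs rather than code shown in this diff.

    from flwr.common import ArrayRecord, ConfigRecord, Context
    from flwr.server import Grid, ServerApp
    from flwr.serverapp.strategy import FedAvg

    app = ServerApp()


    @app.main()
    def main(grid: Grid, context: Context) -> None:
        strategy = FedAvg(fraction_train=0.5, min_train_nodes=2)
        strategy.summary()

        arrays = ArrayRecord()  # placeholder; normally initialized from model weights
        for server_round in range(1, 4):
            # Build one TRAIN message per sampled node, send, then aggregate replies
            out = strategy.configure_train(server_round, arrays, ConfigRecord(), grid)
            replies = grid.send_and_receive(out)
            new_arrays, metrics = strategy.aggregate_train(server_round, replies)
            print(f"Round {server_round} train metrics: {metrics}")
            if new_arrays is not None:
                arrays = new_arrays
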
flwr/serverapp/strategy/fedopt.py (new file, +218)
@@ -0,0 +1,218 @@
+# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Adaptive Federated Optimization (FedOpt) [Reddi et al., 2020] abstract strategy.
+
+Paper: arxiv.org/abs/2003.00295
+"""
+
+from collections.abc import Iterable
+from logging import INFO
+from typing import Callable, Optional
+
+import numpy as np
+
+from flwr.common import (
+    ArrayRecord,
+    ConfigRecord,
+    Message,
+    MetricRecord,
+    NDArray,
+    RecordDict,
+    log,
+)
+from flwr.server import Grid
+
+from ..exception import AggregationError
+from .fedavg import FedAvg
+
+
+# pylint: disable=line-too-long
+class FedOpt(FedAvg):
+    """Federated Optim strategy.
+
+    Implementation based on https://arxiv.org/abs/2003.00295v5
+
+    Parameters
+    ----------
+    fraction_train : float (default: 1.0)
+        Fraction of nodes used during training. In case `min_train_nodes`
+        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
+        will still be sampled.
+    fraction_evaluate : float (default: 1.0)
+        Fraction of nodes used during validation. In case `min_evaluate_nodes`
+        is larger than `fraction_evaluate * total_connected_nodes`,
+        `min_evaluate_nodes` will still be sampled.
+    min_train_nodes : int (default: 2)
+        Minimum number of nodes used during training.
+    min_evaluate_nodes : int (default: 2)
+        Minimum number of nodes used during validation.
+    min_available_nodes : int (default: 2)
+        Minimum number of total nodes in the system.
+    weighted_by_key : str (default: "num-examples")
+        The key within each MetricRecord whose value is used as the weight when
+        computing weighted averages for both ArrayRecords and MetricRecords.
+    arrayrecord_key : str (default: "arrays")
+        Key used to store the ArrayRecord when constructing Messages.
+    configrecord_key : str (default: "config")
+        Key used to store the ConfigRecord when constructing Messages.
+    train_metrics_aggr_fn : Optional[callable] (default: None)
+        Function with signature (list[RecordDict], str) -> MetricRecord,
+        used to aggregate MetricRecords from training round replies.
+        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+        average using the provided weight factor key.
+    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
+        Function with signature (list[RecordDict], str) -> MetricRecord,
+        used to aggregate MetricRecords from training round replies.
+        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+        average using the provided weight factor key.
+    eta : float, optional
+        Server-side learning rate. Defaults to 1e-1.
+    eta_l : float, optional
+        Client-side learning rate. Defaults to 1e-1.
+    beta_1 : float, optional
+        Momentum parameter. Defaults to 0.0.
+    beta_2 : float, optional
+        Second moment parameter. Defaults to 0.0.
+    tau : float, optional
+        Controls the algorithm's degree of adaptability. Defaults to 1e-3.
+    """
+
+    # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals, line-too-long
+    def __init__(
+        self,
+        *,
+        fraction_train: float = 1.0,
+        fraction_evaluate: float = 1.0,
+        min_train_nodes: int = 2,
+        min_evaluate_nodes: int = 2,
+        min_available_nodes: int = 2,
+        weighted_by_key: str = "num-examples",
+        arrayrecord_key: str = "arrays",
+        configrecord_key: str = "config",
+        train_metrics_aggr_fn: Optional[
+            Callable[[list[RecordDict], str], MetricRecord]
+        ] = None,
+        evaluate_metrics_aggr_fn: Optional[
+            Callable[[list[RecordDict], str], MetricRecord]
+        ] = None,
+        eta: float = 1e-1,
+        eta_l: float = 1e-1,
+        beta_1: float = 0.0,
+        beta_2: float = 0.0,
+        tau: float = 1e-3,
+    ) -> None:
+        super().__init__(
+            fraction_train=fraction_train,
+            fraction_evaluate=fraction_evaluate,
+            min_train_nodes=min_train_nodes,
+            min_evaluate_nodes=min_evaluate_nodes,
+            min_available_nodes=min_available_nodes,
+            weighted_by_key=weighted_by_key,
+            arrayrecord_key=arrayrecord_key,
+            configrecord_key=configrecord_key,
+            train_metrics_aggr_fn=train_metrics_aggr_fn,
+            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
+        )
+        self.current_arrays: Optional[dict[str, NDArray]] = None
+        self.eta = eta
+        self.eta_l = eta_l
+        self.tau = tau
+        self.beta_1 = beta_1
+        self.beta_2 = beta_2
+        self.m_t: Optional[dict[str, NDArray]] = None
+        self.v_t: Optional[dict[str, NDArray]] = None
+
+    def summary(self) -> None:
+        """Log summary configuration of the strategy."""
+        log(INFO, "\t├──> FedOpt settings:")
+        log(
+            INFO,
+            "\t│\t├── eta (%s) | eta_l (%s)",
+            f"{self.eta:.6g}",
+            f"{self.eta_l:.6g}",
+        )
+        log(
+            INFO,
+            "\t│\t├── beta_1 (%s) | beta_2 (%s)",
+            f"{self.beta_1:.6g}",
+            f"{self.beta_2:.6g}",
+        )
+        log(
+            INFO,
+            "\t│\t└── tau (%s)",
+            f"{self.tau:.6g}",
+        )
+        super().summary()
+
+    def configure_train(
+        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
+    ) -> Iterable[Message]:
+        """Configure the next round of federated training."""
+        # Keep track of array record being communicated
+        self.current_arrays = {k: array.numpy() for k, array in arrays.items()}
+        return super().configure_train(server_round, arrays, config, grid)
+
+    def _compute_deltat_and_mt(
+        self, aggregated_arrayrecord: ArrayRecord
+    ) -> tuple[dict[str, NDArray], dict[str, NDArray], dict[str, NDArray]]:
+        """Compute delta_t and m_t.
+
+        This is a shared stage during aggregation for FedAdagrad, FedAdam and FedYogi.
+        """
+        if self.current_arrays is None:
+            reason = (
+                "Current arrays not set. Ensure that `configure_train` has been "
+                "called before aggregation."
+            )
+            raise AggregationError(reason=reason)
+
+        aggregated_ndarrays = {
+            k: array.numpy() for k, array in aggregated_arrayrecord.items()
+        }
+
+        # Check keys in aggregated arrays match those in current arrays
+        if set(aggregated_ndarrays.keys()) != set(self.current_arrays.keys()):
+            reason = (
+                "Keys of the aggregated arrays do not match those of the arrays "
+                "stored at the strategy. `delta_t = aggregated_arrays - "
+                "current_arrays` cannot be computed."
+            )
+            raise AggregationError(reason=reason)
+
+        # Check that the shape of values match
+        # Only shapes that match can compute delta_t (we don't want
+        # broadcasting to happen)
+        for k, x in aggregated_ndarrays.items():
+            if x.shape != self.current_arrays[k].shape:
+                reason = (
+                    f"Shape of aggregated array '{k}' does not match "
+                    f"shape of the array under the same key stored in the strategy. "
+                    f"Cannot compute `delta_t`."
+                )
+                raise AggregationError(reason=reason)
+
+        delta_t = {
+            k: x - self.current_arrays[k] for k, x in aggregated_ndarrays.items()
+        }
+
+        # m_t
+        if not self.m_t:
+            self.m_t = {k: np.zeros_like(v) for k, v in aggregated_ndarrays.items()}
+        self.m_t = {
+            k: self.beta_1 * v + (1 - self.beta_1) * delta_t[k]
+            for k, v in self.m_t.items()
+        }
+
+        return delta_t, self.m_t, aggregated_ndarrays
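
`_compute_deltat_and_mt` above is the stage shared by FedAdagrad, FedAdam, and FedYogi: per array key it computes `delta_t = aggregated - current` and the momentum term `m_t = beta_1 * m_(t-1) + (1 - beta_1) * delta_t`. A plain-NumPy illustration with made-up values (editorial, not part of the diff):

    import numpy as np

    beta_1 = 0.9
    current = {"w": np.array([1.0, 2.0])}      # weights sent out in configure_train
    aggregated = {"w": np.array([1.5, 1.0])}   # weighted average of client replies

    # delta_t = aggregated - current, per key
    delta_t = {k: aggregated[k] - current[k] for k in aggregated}      # {"w": [0.5, -1.0]}
    # first round: m_t starts at zeros, then the momentum recursion is applied
    m_t = {k: np.zeros_like(v) for k, v in aggregated.items()}
    m_t = {k: beta_1 * m_t[k] + (1 - beta_1) * delta_t[k] for k in m_t}  # {"w": [0.05, -0.1]}
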
flwr/serverapp/strategy/fedyogi.py (new file, +173)
@@ -0,0 +1,173 @@
+# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Adaptive Federated Optimization using Yogi (FedYogi) [Reddi et al., 2020] strategy.
+
+Paper: arxiv.org/abs/2003.00295
+"""
+
+
+from collections import OrderedDict
+from collections.abc import Iterable
+from typing import Callable, Optional
+
+import numpy as np
+
+from flwr.common import Array, ArrayRecord, Message, MetricRecord, RecordDict
+
+from ..exception import AggregationError
+from .fedopt import FedOpt
+
+
+# pylint: disable=line-too-long
+class FedYogi(FedOpt):
+    """FedYogi [Reddi et al., 2020] strategy.
+
+    Implementation based on https://arxiv.org/abs/2003.00295v5
+
+
+    Parameters
+    ----------
+    fraction_train : float (default: 1.0)
+        Fraction of nodes used during training. In case `min_train_nodes`
+        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
+        will still be sampled.
+    fraction_evaluate : float (default: 1.0)
+        Fraction of nodes used during validation. In case `min_evaluate_nodes`
+        is larger than `fraction_evaluate * total_connected_nodes`,
+        `min_evaluate_nodes` will still be sampled.
+    min_train_nodes : int (default: 2)
+        Minimum number of nodes used during training.
+    min_evaluate_nodes : int (default: 2)
+        Minimum number of nodes used during validation.
+    min_available_nodes : int (default: 2)
+        Minimum number of total nodes in the system.
+    weighted_by_key : str (default: "num-examples")
+        The key within each MetricRecord whose value is used as the weight when
+        computing weighted averages for both ArrayRecords and MetricRecords.
+    arrayrecord_key : str (default: "arrays")
+        Key used to store the ArrayRecord when constructing Messages.
+    configrecord_key : str (default: "config")
+        Key used to store the ConfigRecord when constructing Messages.
+    train_metrics_aggr_fn : Optional[callable] (default: None)
+        Function with signature (list[RecordDict], str) -> MetricRecord,
+        used to aggregate MetricRecords from training round replies.
+        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+        average using the provided weight factor key.
+    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
+        Function with signature (list[RecordDict], str) -> MetricRecord,
+        used to aggregate MetricRecords from training round replies.
+        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+        average using the provided weight factor key.
+    eta : float, optional
+        Server-side learning rate. Defaults to 1e-2.
+    eta_l : float, optional
+        Client-side learning rate. Defaults to 0.0316.
+    beta_1 : float, optional
+        Momentum parameter. Defaults to 0.9.
+    beta_2 : float, optional
+        Second moment parameter. Defaults to 0.99.
+    tau : float, optional
+        Controls the algorithm's degree of adaptability.
+        Defaults to 1e-3.
+    """
+
+    # pylint: disable=too-many-arguments, too-many-locals
+    def __init__(
+        self,
+        *,
+        fraction_train: float = 1.0,
+        fraction_evaluate: float = 1.0,
+        min_train_nodes: int = 2,
+        min_evaluate_nodes: int = 2,
+        min_available_nodes: int = 2,
+        weighted_by_key: str = "num-examples",
+        arrayrecord_key: str = "arrays",
+        configrecord_key: str = "config",
+        train_metrics_aggr_fn: Optional[
+            Callable[[list[RecordDict], str], MetricRecord]
+        ] = None,
+        evaluate_metrics_aggr_fn: Optional[
+            Callable[[list[RecordDict], str], MetricRecord]
+        ] = None,
+        eta: float = 1e-2,
+        eta_l: float = 0.0316,
+        beta_1: float = 0.9,
+        beta_2: float = 0.99,
+        tau: float = 1e-3,
+    ) -> None:
+        super().__init__(
+            fraction_train=fraction_train,
+            fraction_evaluate=fraction_evaluate,
+            min_train_nodes=min_train_nodes,
+            min_evaluate_nodes=min_evaluate_nodes,
+            min_available_nodes=min_available_nodes,
+            weighted_by_key=weighted_by_key,
+            arrayrecord_key=arrayrecord_key,
+            configrecord_key=configrecord_key,
+            train_metrics_aggr_fn=train_metrics_aggr_fn,
+            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
+            eta=eta,
+            eta_l=eta_l,
+            beta_1=beta_1,
+            beta_2=beta_2,
+            tau=tau,
+        )
+
+    def aggregate_train(
+        self,
+        server_round: int,
+        replies: Iterable[Message],
+    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
+        """Aggregate ArrayRecords and MetricRecords in the received Messages."""
+        aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
+            server_round, replies
+        )
+
+        if aggregated_arrayrecord is None:
+            return aggregated_arrayrecord, aggregated_metrics
+
+        if self.current_arrays is None:
+            reason = (
+                "Current arrays not set. Ensure that `configure_train` has been "
+                "called before aggregation."
+            )
+            raise AggregationError(reason=reason)
+
+        # Compute intermediate variables
+        delta_t, m_t, aggregated_ndarrays = self._compute_deltat_and_mt(
+            aggregated_arrayrecord
+        )
+
+        # v_t
+        if not self.v_t:
+            self.v_t = {k: np.zeros_like(v) for k, v in aggregated_ndarrays.items()}
+        self.v_t = {
+            k: v
+            - (1.0 - self.beta_2) * (delta_t[k] ** 2) * np.sign(v - delta_t[k] ** 2)
+            for k, v in self.v_t.items()
+        }
+
+        new_arrays = {
+            k: x + self.eta * m_t[k] / (np.sqrt(self.v_t[k]) + self.tau)
+            for k, x in self.current_arrays.items()
+        }
+
+        # Update current arrays
+        self.current_arrays = new_arrays
+
+        return (
+            ArrayRecord(OrderedDict({k: Array(v) for k, v in new_arrays.items()})),
+            aggregated_metrics,
+        )
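
FedYogi's `aggregate_train` above then applies the Yogi second-moment update, `v_t = v_(t-1) - (1 - beta_2) * delta_t^2 * sign(v_(t-1) - delta_t^2)`, followed by the server-side step `x_t = x_(t-1) + eta * m_t / (sqrt(v_t) + tau)`. Continuing the made-up NumPy values from the FedOpt sketch (editorial illustration only, not part of the diff):

    import numpy as np

    eta, beta_2, tau = 1e-2, 0.99, 1e-3   # the FedYogi defaults above
    delta_t = np.array([0.5, -1.0])       # aggregated minus current weights
    m_t = np.array([0.05, -0.1])          # momentum term from the FedOpt stage
    current = np.array([1.0, 2.0])        # server-side weights before the update

    # Yogi second-moment update (v_t starts at zeros in the first round)
    v_t = np.zeros_like(delta_t)
    v_t = v_t - (1.0 - beta_2) * delta_t**2 * np.sign(v_t - delta_t**2)

    # Server-side step producing the model for the next round
    new_weights = current + eta * m_t / (np.sqrt(v_t) + tau)
    print(new_weights)
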