flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/app.py +2 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  6. flwr/cli/cli_user_auth_interceptor.py +1 -1
  7. flwr/cli/config_utils.py +3 -3
  8. flwr/cli/constant.py +25 -8
  9. flwr/cli/log.py +9 -9
  10. flwr/cli/login/login.py +3 -3
  11. flwr/cli/ls.py +5 -5
  12. flwr/cli/new/new.py +15 -2
  13. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  14. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
  15. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  16. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  17. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  18. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  19. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  20. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  21. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  22. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  23. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  24. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  26. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  27. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  28. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  29. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  30. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  31. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  32. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  33. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  34. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  35. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  36. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  37. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  38. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  39. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  40. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  41. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  42. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  43. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  44. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
  45. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  46. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  47. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  49. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  50. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  52. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  53. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  54. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
  55. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  56. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  57. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  58. flwr/cli/pull.py +100 -0
  59. flwr/cli/run/run.py +9 -13
  60. flwr/cli/stop.py +7 -4
  61. flwr/cli/utils.py +36 -8
  62. flwr/client/grpc_rere_client/connection.py +1 -12
  63. flwr/client/rest_client/connection.py +3 -0
  64. flwr/clientapp/__init__.py +10 -0
  65. flwr/clientapp/mod/__init__.py +29 -0
  66. flwr/clientapp/mod/centraldp_mods.py +248 -0
  67. flwr/clientapp/mod/localdp_mod.py +169 -0
  68. flwr/clientapp/typing.py +22 -0
  69. flwr/common/args.py +20 -6
  70. flwr/common/auth_plugin/__init__.py +4 -4
  71. flwr/common/auth_plugin/auth_plugin.py +7 -7
  72. flwr/common/constant.py +26 -4
  73. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  74. flwr/common/exit/__init__.py +4 -0
  75. flwr/common/exit/exit.py +8 -1
  76. flwr/common/exit/exit_code.py +30 -7
  77. flwr/common/exit/exit_handler.py +62 -0
  78. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  79. flwr/common/grpc.py +0 -11
  80. flwr/common/inflatable_utils.py +1 -1
  81. flwr/common/logger.py +1 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/retry_invoker.py +30 -11
  84. flwr/common/telemetry.py +4 -0
  85. flwr/compat/server/app.py +2 -2
  86. flwr/proto/appio_pb2.py +25 -17
  87. flwr/proto/appio_pb2.pyi +46 -2
  88. flwr/proto/clientappio_pb2.py +3 -11
  89. flwr/proto/clientappio_pb2.pyi +0 -47
  90. flwr/proto/clientappio_pb2_grpc.py +19 -20
  91. flwr/proto/clientappio_pb2_grpc.pyi +10 -11
  92. flwr/proto/control_pb2.py +66 -0
  93. flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
  94. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
  95. flwr/proto/control_pb2_grpc.pyi +106 -0
  96. flwr/proto/serverappio_pb2.py +2 -2
  97. flwr/proto/serverappio_pb2_grpc.py +68 -0
  98. flwr/proto/serverappio_pb2_grpc.pyi +26 -0
  99. flwr/proto/simulationio_pb2.py +4 -11
  100. flwr/proto/simulationio_pb2.pyi +0 -58
  101. flwr/proto/simulationio_pb2_grpc.py +129 -27
  102. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  103. flwr/server/app.py +142 -152
  104. flwr/server/grid/grpc_grid.py +3 -0
  105. flwr/server/grid/inmemory_grid.py +1 -0
  106. flwr/server/serverapp/app.py +157 -146
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  110. flwr/server/superlink/linkstate/linkstate.py +2 -1
  111. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  112. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
  113. flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
  114. flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
  115. flwr/serverapp/__init__.py +12 -0
  116. flwr/serverapp/exception.py +38 -0
  117. flwr/serverapp/strategy/__init__.py +64 -0
  118. flwr/serverapp/strategy/bulyan.py +238 -0
  119. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  120. flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
  121. flwr/serverapp/strategy/fedadagrad.py +159 -0
  122. flwr/serverapp/strategy/fedadam.py +178 -0
  123. flwr/serverapp/strategy/fedavg.py +320 -0
  124. flwr/serverapp/strategy/fedavgm.py +198 -0
  125. flwr/serverapp/strategy/fedmedian.py +105 -0
  126. flwr/serverapp/strategy/fedopt.py +218 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +170 -0
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/result.py +105 -0
  136. flwr/serverapp/strategy/strategy.py +285 -0
  137. flwr/serverapp/strategy/strategy_utils.py +299 -0
  138. flwr/simulation/app.py +161 -164
  139. flwr/simulation/run_simulation.py +25 -30
  140. flwr/supercore/app_utils.py +58 -0
  141. flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
  142. flwr/supercore/cli/flower_superexec.py +166 -0
  143. flwr/supercore/constant.py +19 -0
  144. flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
  145. flwr/supercore/corestate/corestate.py +81 -0
  146. flwr/supercore/grpc_health/__init__.py +3 -0
  147. flwr/supercore/grpc_health/health_server.py +53 -0
  148. flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
  149. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  150. flwr/supercore/superexec/plugin/__init__.py +28 -0
  151. flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
  152. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  153. flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
  154. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  155. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  156. flwr/supercore/superexec/run_superexec.py +199 -0
  157. flwr/superlink/artifact_provider/__init__.py +22 -0
  158. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  159. flwr/superlink/servicer/__init__.py +15 -0
  160. flwr/superlink/servicer/control/__init__.py +22 -0
  161. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
  162. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
  163. flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
  164. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
  165. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
  166. flwr/supernode/cli/flower_supernode.py +3 -0
  167. flwr/supernode/cli/flwr_clientapp.py +18 -21
  168. flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
  169. flwr/supernode/nodestate/nodestate.py +3 -59
  170. flwr/supernode/runtime/run_clientapp.py +39 -102
  171. flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
  172. flwr/supernode/start_client_internal.py +35 -76
  173. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
  174. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
  175. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
  176. flwr/proto/exec_pb2.py +0 -62
  177. flwr/proto/exec_pb2_grpc.pyi +0 -93
  178. flwr/superexec/app.py +0 -45
  179. flwr/superexec/deployment.py +0 -191
  180. flwr/superexec/executor.py +0 -100
  181. flwr/superexec/simulation.py +0 -129
  182. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
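The headline addition in this release pair is the new message-based strategy package under flwr/serverapp/strategy/ (items 117-137 above), alongside the superexec-to-control servicer rename and the new SuperExec plugin infrastructure. As a quick orientation before the excerpts below, here is a minimal construction sketch; it assumes flwr/serverapp/strategy/__init__.py re-exports the strategy classes added in this diff, and every keyword argument is taken from the FedAdam signature shown later:

    from flwr.serverapp.strategy import FedAdam

    # All keyword arguments below appear in the FedAdam.__init__ shown in the
    # excerpt that follows; the import path assumes the package __init__
    # re-exports the class.
    strategy = FedAdam(
        fraction_train=0.5,  # sample 50% of connected nodes for training
        min_train_nodes=2,   # but never fewer than 2
        eta=1e-1,            # server-side Adam learning rate
        beta_1=0.9,          # first-moment decay
        beta_2=0.99,         # second-moment decay
        tau=1e-3,            # adaptability constant added to sqrt(v_t)
    )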
flwr/serverapp/strategy/fedadam.py
@@ -0,0 +1,178 @@
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Adaptive Federated Optimization using Adam (FedAdam) strategy.
+
+ [Reddi et al., 2020]
+
+ Paper: arxiv.org/abs/2003.00295
+ """
+
+ from collections import OrderedDict
+ from collections.abc import Iterable
+ from typing import Callable, Optional
+
+ import numpy as np
+
+ from flwr.common import Array, ArrayRecord, Message, MetricRecord, RecordDict
+
+ from ..exception import AggregationError
+ from .fedopt import FedOpt
+
+
+ # pylint: disable=line-too-long
+ class FedAdam(FedOpt):
+     """FedAdam - Adaptive Federated Optimization using Adam.
+
+     Implementation based on https://arxiv.org/abs/2003.00295v5
+
+     Parameters
+     ----------
+     fraction_train : float (default: 1.0)
+         Fraction of nodes used during training. In case `min_train_nodes`
+         is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
+         will still be sampled.
+     fraction_evaluate : float (default: 1.0)
+         Fraction of nodes used during validation. In case `min_evaluate_nodes`
+         is larger than `fraction_evaluate * total_connected_nodes`,
+         `min_evaluate_nodes` will still be sampled.
+     min_train_nodes : int (default: 2)
+         Minimum number of nodes used during training.
+     min_evaluate_nodes : int (default: 2)
+         Minimum number of nodes used during validation.
+     min_available_nodes : int (default: 2)
+         Minimum number of total nodes in the system.
+     weighted_by_key : str (default: "num-examples")
+         The key within each MetricRecord whose value is used as the weight when
+         computing weighted averages for both ArrayRecords and MetricRecords.
+     arrayrecord_key : str (default: "arrays")
+         Key used to store the ArrayRecord when constructing Messages.
+     configrecord_key : str (default: "config")
+         Key used to store the ConfigRecord when constructing Messages.
+     train_metrics_aggr_fn : Optional[callable] (default: None)
+         Function with signature (list[RecordDict], str) -> MetricRecord,
+         used to aggregate MetricRecords from training round replies.
+         If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+         average using the provided weight factor key.
+     evaluate_metrics_aggr_fn : Optional[callable] (default: None)
+         Function with signature (list[RecordDict], str) -> MetricRecord,
+         used to aggregate MetricRecords from evaluation round replies.
+         If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+         average using the provided weight factor key.
+     eta : float, optional
+         Server-side learning rate. Defaults to 1e-1.
+     eta_l : float, optional
+         Client-side learning rate. Defaults to 1e-1.
+     beta_1 : float, optional
+         Momentum parameter. Defaults to 0.9.
+     beta_2 : float, optional
+         Second moment parameter. Defaults to 0.99.
+     tau : float, optional
+         Controls the algorithm's degree of adaptability. Defaults to 1e-3.
+     """
+
+     # pylint: disable=too-many-arguments, too-many-locals
+     def __init__(
+         self,
+         *,
+         fraction_train: float = 1.0,
+         fraction_evaluate: float = 1.0,
+         min_train_nodes: int = 2,
+         min_evaluate_nodes: int = 2,
+         min_available_nodes: int = 2,
+         weighted_by_key: str = "num-examples",
+         arrayrecord_key: str = "arrays",
+         configrecord_key: str = "config",
+         train_metrics_aggr_fn: Optional[
+             Callable[[list[RecordDict], str], MetricRecord]
+         ] = None,
+         evaluate_metrics_aggr_fn: Optional[
+             Callable[[list[RecordDict], str], MetricRecord]
+         ] = None,
+         eta: float = 1e-1,
+         eta_l: float = 1e-1,
+         beta_1: float = 0.9,
+         beta_2: float = 0.99,
+         tau: float = 1e-3,
+     ) -> None:
+         super().__init__(
+             fraction_train=fraction_train,
+             fraction_evaluate=fraction_evaluate,
+             min_train_nodes=min_train_nodes,
+             min_evaluate_nodes=min_evaluate_nodes,
+             min_available_nodes=min_available_nodes,
+             weighted_by_key=weighted_by_key,
+             arrayrecord_key=arrayrecord_key,
+             configrecord_key=configrecord_key,
+             train_metrics_aggr_fn=train_metrics_aggr_fn,
+             evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
+             eta=eta,
+             eta_l=eta_l,
+             beta_1=beta_1,
+             beta_2=beta_2,
+             tau=tau,
+         )
+
+     def aggregate_train(
+         self,
+         server_round: int,
+         replies: Iterable[Message],
+     ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
+         """Aggregate ArrayRecords and MetricRecords in the received Messages."""
+         aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
+             server_round, replies
+         )
+
+         if aggregated_arrayrecord is None:
+             return aggregated_arrayrecord, aggregated_metrics
+
+         if self.current_arrays is None:
+             reason = (
+                 "Current arrays not set. Ensure that `configure_train` has been "
+                 "called before aggregation."
+             )
+             raise AggregationError(reason=reason)
+
+         # Compute intermediate variables
+         delta_t, m_t, aggregated_ndarrays = self._compute_deltat_and_mt(
+             aggregated_arrayrecord
+         )
+
+         # v_t
+         if not self.v_t:
+             self.v_t = {k: np.zeros_like(v) for k, v in aggregated_ndarrays.items()}
+         self.v_t = {
+             k: self.beta_2 * v + (1 - self.beta_2) * (delta_t[k] ** 2)
+             for k, v in self.v_t.items()
+         }
+
+         # Compute the bias-corrected learning rate `eta_norm` to improve convergence
+         # in the early rounds of FL training. This `eta_norm` is `\alpha_t` in Kingma &
+         # Ba, 2014 (http://arxiv.org/abs/1412.6980) "Adam: A Method for Stochastic
+         # Optimization", in the formula line right before Section 2.1.
+         eta_norm = (
+             self.eta
+             * np.sqrt(1 - np.power(self.beta_2, server_round + 1.0))
+             / (1 - np.power(self.beta_1, server_round + 1.0))
+         )
+
+         new_arrays = {
+             k: x + eta_norm * m_t[k] / (np.sqrt(self.v_t[k]) + self.tau)
+             for k, x in self.current_arrays.items()
+         }
+
+         return (
+             ArrayRecord(OrderedDict({k: Array(v) for k, v in new_arrays.items()})),
+             aggregated_metrics,
+         )
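The update above is one Adam step applied to the round's pseudo-gradient. Below is a self-contained numpy sketch of a single tensor in a single round; it assumes FedOpt's `_compute_deltat_and_mt` (defined in fedopt.py, not excerpted here) implements the standard first-moment update, i.e. delta_t = aggregated - current and m_t = beta_1 * m_t + (1 - beta_1) * delta_t:

    import numpy as np

    # One FedAdam server step for a single weight tensor, mirroring
    # aggregate_train above. Moments start at zero, as in round one.
    eta, beta_1, beta_2, tau = 1e-1, 0.9, 0.99, 1e-3
    server_round = 1

    current = np.zeros(3)                    # weights sent out this round
    aggregated = np.array([0.3, -0.1, 0.2])  # weighted average of client results

    delta_t = aggregated - current  # pseudo-gradient (assumed FedOpt form)
    m_t = (1 - beta_1) * delta_t    # first moment, zero-initialized
    v_t = (1 - beta_2) * delta_t**2 # second moment, zero-initialized

    # Bias-corrected learning rate (`\alpha_t` in Kingma & Ba, 2014)
    eta_norm = (
        eta
        * np.sqrt(1 - np.power(beta_2, server_round + 1.0))
        / (1 - np.power(beta_1, server_round + 1.0))
    )

    new = current + eta_norm * m_t / (np.sqrt(v_t) + tau)

Note that `tau` is added to the square root rather than placed under it, matching the code above.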
flwr/serverapp/strategy/fedavg.py
@@ -0,0 +1,320 @@
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Flower message-based FedAvg strategy."""
+
+
+ from collections.abc import Iterable
+ from logging import INFO, WARNING
+ from typing import Callable, Optional
+
+ from flwr.common import (
+     ArrayRecord,
+     ConfigRecord,
+     Message,
+     MessageType,
+     MetricRecord,
+     RecordDict,
+     log,
+ )
+ from flwr.server import Grid
+
+ from .strategy import Strategy
+ from .strategy_utils import (
+     aggregate_arrayrecords,
+     aggregate_metricrecords,
+     sample_nodes,
+     validate_message_reply_consistency,
+ )
+
+
+ # pylint: disable=too-many-instance-attributes
+ class FedAvg(Strategy):
+     """Federated Averaging strategy.
+
+     Implementation based on https://arxiv.org/abs/1602.05629
+
+     Parameters
+     ----------
+     fraction_train : float (default: 1.0)
+         Fraction of nodes used during training. In case `min_train_nodes`
+         is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
+         will still be sampled.
+     fraction_evaluate : float (default: 1.0)
+         Fraction of nodes used during validation. In case `min_evaluate_nodes`
+         is larger than `fraction_evaluate * total_connected_nodes`,
+         `min_evaluate_nodes` will still be sampled.
+     min_train_nodes : int (default: 2)
+         Minimum number of nodes used during training.
+     min_evaluate_nodes : int (default: 2)
+         Minimum number of nodes used during validation.
+     min_available_nodes : int (default: 2)
+         Minimum number of total nodes in the system.
+     weighted_by_key : str (default: "num-examples")
+         The key within each MetricRecord whose value is used as the weight when
+         computing weighted averages for both ArrayRecords and MetricRecords.
+     arrayrecord_key : str (default: "arrays")
+         Key used to store the ArrayRecord when constructing Messages.
+     configrecord_key : str (default: "config")
+         Key used to store the ConfigRecord when constructing Messages.
+     train_metrics_aggr_fn : Optional[callable] (default: None)
+         Function with signature (list[RecordDict], str) -> MetricRecord,
+         used to aggregate MetricRecords from training round replies.
+         If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+         average using the provided weight factor key.
+     evaluate_metrics_aggr_fn : Optional[callable] (default: None)
+         Function with signature (list[RecordDict], str) -> MetricRecord,
+         used to aggregate MetricRecords from evaluation round replies.
+         If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+         average using the provided weight factor key.
+     """
+
+     # pylint: disable=too-many-arguments,too-many-positional-arguments
+     def __init__(
+         self,
+         fraction_train: float = 1.0,
+         fraction_evaluate: float = 1.0,
+         min_train_nodes: int = 2,
+         min_evaluate_nodes: int = 2,
+         min_available_nodes: int = 2,
+         weighted_by_key: str = "num-examples",
+         arrayrecord_key: str = "arrays",
+         configrecord_key: str = "config",
+         train_metrics_aggr_fn: Optional[
+             Callable[[list[RecordDict], str], MetricRecord]
+         ] = None,
+         evaluate_metrics_aggr_fn: Optional[
+             Callable[[list[RecordDict], str], MetricRecord]
+         ] = None,
+     ) -> None:
+         self.fraction_train = fraction_train
+         self.fraction_evaluate = fraction_evaluate
+         self.min_train_nodes = min_train_nodes
+         self.min_evaluate_nodes = min_evaluate_nodes
+         self.min_available_nodes = min_available_nodes
+         self.weighted_by_key = weighted_by_key
+         self.arrayrecord_key = arrayrecord_key
+         self.configrecord_key = configrecord_key
+         self.train_metrics_aggr_fn = train_metrics_aggr_fn or aggregate_metricrecords
+         self.evaluate_metrics_aggr_fn = (
+             evaluate_metrics_aggr_fn or aggregate_metricrecords
+         )
+
+         if self.fraction_evaluate == 0.0:
+             self.min_evaluate_nodes = 0
+             log(
+                 WARNING,
+                 "fraction_evaluate is set to 0.0. "
+                 "Federated evaluation will be skipped.",
+             )
+         if self.fraction_train == 0.0:
+             self.min_train_nodes = 0
+             log(
+                 WARNING,
+                 "fraction_train is set to 0.0. Federated training will be skipped.",
+             )
+
+     def summary(self) -> None:
+         """Log summary configuration of the strategy."""
+         log(INFO, "\t├──> Sampling:")
+         log(
+             INFO,
+             "\t│\t├──Fraction: train (%.2f) | evaluate (%.2f)",
+             self.fraction_train,
+             self.fraction_evaluate,
+         )  # pylint: disable=line-too-long
+         log(
+             INFO,
+             "\t│\t├──Minimum nodes: train (%d) | evaluate (%d)",
+             self.min_train_nodes,
+             self.min_evaluate_nodes,
+         )  # pylint: disable=line-too-long
+         log(INFO, "\t│\t└──Minimum available nodes: %d", self.min_available_nodes)
+         log(INFO, "\t└──> Keys in records:")
+         log(INFO, "\t\t├── Weighted by: '%s'", self.weighted_by_key)
+         log(INFO, "\t\t├── ArrayRecord key: '%s'", self.arrayrecord_key)
+         log(INFO, "\t\t└── ConfigRecord key: '%s'", self.configrecord_key)
+
+     def _construct_messages(
+         self, record: RecordDict, node_ids: list[int], message_type: str
+     ) -> Iterable[Message]:
+         """Construct N Messages carrying the same RecordDict payload."""
+         messages = []
+         for node_id in node_ids:  # one message for each node
+             message = Message(
+                 content=record,
+                 message_type=message_type,
+                 dst_node_id=node_id,
+             )
+             messages.append(message)
+         return messages
+
+     def configure_train(
+         self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
+     ) -> Iterable[Message]:
+         """Configure the next round of federated training."""
+         # Do not configure federated train if fraction_train is 0.
+         if self.fraction_train == 0.0:
+             return []
+         # Sample nodes
+         num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_train)
+         sample_size = max(num_nodes, self.min_train_nodes)
+         node_ids, num_total = sample_nodes(grid, self.min_available_nodes, sample_size)
+         log(
+             INFO,
+             "configure_train: Sampled %s nodes (out of %s)",
+             len(node_ids),
+             len(num_total),
+         )
+         # Always inject current server round
+         config["server-round"] = server_round
+
+         # Construct messages
+         record = RecordDict(
+             {self.arrayrecord_key: arrays, self.configrecord_key: config}
+         )
+         return self._construct_messages(record, node_ids, MessageType.TRAIN)
+
+     def _check_and_log_replies(
+         self, replies: Iterable[Message], is_train: bool, validate: bool = True
+     ) -> tuple[list[Message], list[Message]]:
+         """Check replies for errors and log them.
+
+         Parameters
+         ----------
+         replies : Iterable[Message]
+             Iterable of reply Messages.
+         is_train : bool
+             Set to True if the replies are from a training round; False otherwise.
+             This impacts logging and validation behavior.
+         validate : bool (default: True)
+             Whether to validate the reply contents for consistency.
+
+         Returns
+         -------
+         tuple[list[Message], list[Message]]
+             A tuple containing two lists:
+             - Messages with valid contents.
+             - Messages with errors.
+         """
+         if not replies:
+             return [], []
+
+         # Filter messages that carry content
+         valid_replies: list[Message] = []
+         error_replies: list[Message] = []
+         for msg in replies:
+             if msg.has_error():
+                 error_replies.append(msg)
+             else:
+                 valid_replies.append(msg)
+
+         log(
+             INFO,
+             "%s: Received %s results and %s failures",
+             "aggregate_train" if is_train else "aggregate_evaluate",
+             len(valid_replies),
+             len(error_replies),
+         )
+
+         # Log errors
+         for msg in error_replies:
+             log(
+                 INFO,
+                 "\t> Received error in reply from node %d: %s",
+                 msg.metadata.src_node_id,
+                 msg.error.reason,
+             )
+
+         # Ensure expected ArrayRecords and MetricRecords are received
+         if validate and valid_replies:
+             validate_message_reply_consistency(
+                 replies=[msg.content for msg in valid_replies],
+                 weighted_by_key=self.weighted_by_key,
+                 check_arrayrecord=is_train,
+             )
+
+         return valid_replies, error_replies
+
+     def aggregate_train(
+         self,
+         server_round: int,
+         replies: Iterable[Message],
+     ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
+         """Aggregate ArrayRecords and MetricRecords in the received Messages."""
+         valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
+
+         arrays, metrics = None, None
+         if valid_replies:
+             reply_contents = [msg.content for msg in valid_replies]
+
+             # Aggregate ArrayRecords
+             arrays = aggregate_arrayrecords(
+                 reply_contents,
+                 self.weighted_by_key,
+             )
+
+             # Aggregate MetricRecords
+             metrics = self.train_metrics_aggr_fn(
+                 reply_contents,
+                 self.weighted_by_key,
+             )
+         return arrays, metrics
+
+     def configure_evaluate(
+         self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
+     ) -> Iterable[Message]:
+         """Configure the next round of federated evaluation."""
+         # Do not configure federated evaluation if fraction_evaluate is 0.
+         if self.fraction_evaluate == 0.0:
+             return []
+
+         # Sample nodes
+         num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_evaluate)
+         sample_size = max(num_nodes, self.min_evaluate_nodes)
+         node_ids, num_total = sample_nodes(grid, self.min_available_nodes, sample_size)
+         log(
+             INFO,
+             "configure_evaluate: Sampled %s nodes (out of %s)",
+             len(node_ids),
+             len(num_total),
+         )
+
+         # Always inject current server round
+         config["server-round"] = server_round
+
+         # Construct messages
+         record = RecordDict(
+             {self.arrayrecord_key: arrays, self.configrecord_key: config}
+         )
+         return self._construct_messages(record, node_ids, MessageType.EVALUATE)
+
+     def aggregate_evaluate(
+         self,
+         server_round: int,
+         replies: Iterable[Message],
+     ) -> Optional[MetricRecord]:
+         """Aggregate MetricRecords in the received Messages."""
+         valid_replies, _ = self._check_and_log_replies(replies, is_train=False)
+
+         metrics = None
+         if valid_replies:
+             reply_contents = [msg.content for msg in valid_replies]
+
+             # Aggregate MetricRecords
+             metrics = self.evaluate_metrics_aggr_fn(
+                 reply_contents,
+                 self.weighted_by_key,
+             )
+         return metrics
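Both metrics-aggregation hooks receive the raw reply contents plus the weighting key, so swapping the default weighted average for a custom reducer is a one-function change. A sketch follows, with the caveat that the record layout it assumes (each reply holding a MetricRecord with an "accuracy" entry, reachable through RecordDict's metric_records view) is up to the ClientApp, and the FedAvg import path assumes the package __init__ re-exports the class:

    from flwr.common import MetricRecord, RecordDict
    from flwr.serverapp.strategy import FedAvg

    def median_accuracy(
        contents: list[RecordDict], weighted_by_key: str
    ) -> MetricRecord:
        """Aggregate replies by median accuracy, ignoring the weights."""
        # Assumes each reply carries one MetricRecord with an "accuracy" entry
        accs = sorted(
            float(mrec["accuracy"])
            for rdict in contents
            for mrec in rdict.metric_records.values()
        )
        mid = len(accs) // 2
        median = accs[mid] if len(accs) % 2 else (accs[mid - 1] + accs[mid]) / 2
        return MetricRecord({"accuracy": median})

    # Matches the documented hook signature: (list[RecordDict], str) -> MetricRecord
    strategy = FedAvg(evaluate_metrics_aggr_fn=median_accuracy)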
flwr/serverapp/strategy/fedavgm.py
@@ -0,0 +1,198 @@
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Federated Averaging with Momentum (FedAvgM) [Hsu et al., 2019] strategy.
+
+ Paper: arxiv.org/pdf/1909.06335.pdf
+ """
+
+
+ from collections import OrderedDict
+ from collections.abc import Iterable
+ from logging import INFO
+ from typing import Callable, Optional
+
+ from flwr.common import (
+     Array,
+     ArrayRecord,
+     ConfigRecord,
+     Message,
+     MetricRecord,
+     NDArrays,
+     RecordDict,
+     log,
+ )
+ from flwr.server import Grid
+
+ from ..exception import AggregationError
+ from .fedavg import FedAvg
+
+
+ class FedAvgM(FedAvg):
+     """Federated Averaging with Momentum strategy.
+
+     Implementation based on https://arxiv.org/abs/1909.06335
+
+     Parameters
+     ----------
+     fraction_train : float (default: 1.0)
+         Fraction of nodes used during training. In case `min_train_nodes`
+         is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
+         will still be sampled.
+     fraction_evaluate : float (default: 1.0)
+         Fraction of nodes used during validation. In case `min_evaluate_nodes`
+         is larger than `fraction_evaluate * total_connected_nodes`,
+         `min_evaluate_nodes` will still be sampled.
+     min_train_nodes : int (default: 2)
+         Minimum number of nodes used during training.
+     min_evaluate_nodes : int (default: 2)
+         Minimum number of nodes used during validation.
+     min_available_nodes : int (default: 2)
+         Minimum number of total nodes in the system.
+     weighted_by_key : str (default: "num-examples")
+         The key within each MetricRecord whose value is used as the weight when
+         computing weighted averages for both ArrayRecords and MetricRecords.
+     arrayrecord_key : str (default: "arrays")
+         Key used to store the ArrayRecord when constructing Messages.
+     configrecord_key : str (default: "config")
+         Key used to store the ConfigRecord when constructing Messages.
+     train_metrics_aggr_fn : Optional[callable] (default: None)
+         Function with signature (list[RecordDict], str) -> MetricRecord,
+         used to aggregate MetricRecords from training round replies.
+         If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+         average using the provided weight factor key.
+     evaluate_metrics_aggr_fn : Optional[callable] (default: None)
+         Function with signature (list[RecordDict], str) -> MetricRecord,
+         used to aggregate MetricRecords from evaluation round replies.
+         If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+         average using the provided weight factor key.
+     server_learning_rate : float (default: 1.0)
+         Server-side learning rate used in server-side optimization.
+     server_momentum : float (default: 0.0)
+         Server-side momentum factor used for FedAvgM.
+     """
+
+     def __init__(  # pylint: disable=R0913, R0917
+         self,
+         fraction_train: float = 1.0,
+         fraction_evaluate: float = 1.0,
+         min_train_nodes: int = 2,
+         min_evaluate_nodes: int = 2,
+         min_available_nodes: int = 2,
+         weighted_by_key: str = "num-examples",
+         arrayrecord_key: str = "arrays",
+         configrecord_key: str = "config",
+         train_metrics_aggr_fn: Optional[
+             Callable[[list[RecordDict], str], MetricRecord]
+         ] = None,
+         evaluate_metrics_aggr_fn: Optional[
+             Callable[[list[RecordDict], str], MetricRecord]
+         ] = None,
+         server_learning_rate: float = 1.0,
+         server_momentum: float = 0.0,
+     ) -> None:
+         super().__init__(
+             fraction_train=fraction_train,
+             fraction_evaluate=fraction_evaluate,
+             min_train_nodes=min_train_nodes,
+             min_evaluate_nodes=min_evaluate_nodes,
+             min_available_nodes=min_available_nodes,
+             weighted_by_key=weighted_by_key,
+             arrayrecord_key=arrayrecord_key,
+             configrecord_key=configrecord_key,
+             train_metrics_aggr_fn=train_metrics_aggr_fn,
+             evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
+         )
+         self.server_learning_rate = server_learning_rate
+         self.server_momentum = server_momentum
+         self.server_opt: bool = (self.server_momentum != 0.0) or (
+             self.server_learning_rate != 1.0
+         )
+         self.current_arrays: Optional[ArrayRecord] = None
+         self.momentum_vector: Optional[NDArrays] = None
+
+     def summary(self) -> None:
+         """Log summary configuration of the strategy."""
+         opt_status = "ON" if self.server_opt else "OFF"
+         log(INFO, "\t├──> FedAvgM settings:")
+         log(INFO, "\t│\t├── Server optimization: %s", opt_status)
+         log(INFO, "\t│\t├── Server learning rate: %s", self.server_learning_rate)
+         log(INFO, "\t│\t└── Server momentum: %s", self.server_momentum)
+         super().summary()
+
+     def configure_train(
+         self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
+     ) -> Iterable[Message]:
+         """Configure the next round of federated training."""
+         if self.current_arrays is None:
+             self.current_arrays = arrays
+         return super().configure_train(server_round, arrays, config, grid)
+
+     def aggregate_train(
+         self,
+         server_round: int,
+         replies: Iterable[Message],
+     ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
+         """Aggregate ArrayRecords and MetricRecords in the received Messages."""
+         # Call FedAvg aggregate_train to perform validation and aggregation
+         aggregated_arrays, aggregated_metrics = super().aggregate_train(
+             server_round, replies
+         )
+
+         # Following the convention described in
+         # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
+         if self.server_opt and aggregated_arrays is not None:
+             # The initial parameters should be set in the `start()` method already
+             if self.current_arrays is None:
+                 raise AggregationError(
+                     "No initial parameters set for FedAvgM. "
+                     "Ensure that `configure_train` has been called before aggregation."
+                 )
+             ndarrays = self.current_arrays.to_numpy_ndarrays()
+             aggregated_ndarrays = aggregated_arrays.to_numpy_ndarrays()
+
+             # Preserve keys for arrays in ArrayRecord
+             array_keys = list(aggregated_arrays.keys())
+             aggregated_arrays.clear()
+
+             # Remember that updates are the opposite of gradients
+             pseudo_gradient = [
+                 old - new for new, old in zip(aggregated_ndarrays, ndarrays)
+             ]
+             if self.server_momentum > 0.0:
+                 if self.momentum_vector is None:
+                     # Initialize momentum vector in the first round
+                     self.momentum_vector = pseudo_gradient
+                 else:
+                     self.momentum_vector = [
+                         self.server_momentum * mv + pg
+                         for mv, pg in zip(self.momentum_vector, pseudo_gradient)
+                     ]
+
+                 # No Nesterov momentum for now
+                 pseudo_gradient = self.momentum_vector
+
+             # SGD step; convert back to ArrayRecord
+             updated_array_list = [
+                 Array(old - self.server_learning_rate * pg)
+                 for old, pg in zip(ndarrays, pseudo_gradient)
+             ]
+             aggregated_arrays = ArrayRecord(
+                 OrderedDict(zip(array_keys, updated_array_list))
+             )
+
+             # Update current weights
+             self.current_arrays = aggregated_arrays
+
+         return aggregated_arrays, aggregated_metrics
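The momentum bookkeeping above is server-side SGD with momentum, following the torch.optim.SGD convention referenced in the code: the pseudo-gradient is the previous weights minus the aggregate, folded into a momentum vector, then applied with the server learning rate. A tiny numpy sketch of two rounds for a single tensor:

    import numpy as np

    lr, momentum = 1.0, 0.9     # server_learning_rate, server_momentum
    old = np.array([1.0, 1.0])  # arrays sent to the nodes
    mv = None                   # momentum_vector, unset before round one

    for aggregated in (np.array([0.8, 1.1]), np.array([0.6, 1.2])):
        pseudo_gradient = old - aggregated  # updates are the opposite of gradients
        # Round one initializes the momentum vector; later rounds fold in history
        mv = pseudo_gradient if mv is None else momentum * mv + pseudo_gradient
        old = old - lr * mv                 # plain SGD step, no Nesterov

    # With lr == 1.0 and no momentum history, round one reproduces the aggregate
    # exactly; from round two onward, momentum amplifies the step.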