flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/app.py +2 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  6. flwr/cli/cli_user_auth_interceptor.py +1 -1
  7. flwr/cli/config_utils.py +3 -3
  8. flwr/cli/constant.py +25 -8
  9. flwr/cli/log.py +9 -9
  10. flwr/cli/login/login.py +3 -3
  11. flwr/cli/ls.py +5 -5
  12. flwr/cli/new/new.py +15 -2
  13. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  14. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
  15. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  16. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  17. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  18. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  19. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  20. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  21. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  22. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  23. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  24. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  26. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  27. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  28. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  29. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  30. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  31. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  32. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  33. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  34. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  35. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  36. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  37. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  38. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  39. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  40. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  41. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  42. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  43. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  44. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
  45. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  46. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  47. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  49. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  50. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  52. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  53. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  54. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
  55. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  56. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  57. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  58. flwr/cli/pull.py +100 -0
  59. flwr/cli/run/run.py +9 -13
  60. flwr/cli/stop.py +7 -4
  61. flwr/cli/utils.py +36 -8
  62. flwr/client/grpc_rere_client/connection.py +1 -12
  63. flwr/client/rest_client/connection.py +3 -0
  64. flwr/clientapp/__init__.py +10 -0
  65. flwr/clientapp/mod/__init__.py +29 -0
  66. flwr/clientapp/mod/centraldp_mods.py +248 -0
  67. flwr/clientapp/mod/localdp_mod.py +169 -0
  68. flwr/clientapp/typing.py +22 -0
  69. flwr/common/args.py +20 -6
  70. flwr/common/auth_plugin/__init__.py +4 -4
  71. flwr/common/auth_plugin/auth_plugin.py +7 -7
  72. flwr/common/constant.py +26 -4
  73. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  74. flwr/common/exit/__init__.py +4 -0
  75. flwr/common/exit/exit.py +8 -1
  76. flwr/common/exit/exit_code.py +30 -7
  77. flwr/common/exit/exit_handler.py +62 -0
  78. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  79. flwr/common/grpc.py +0 -11
  80. flwr/common/inflatable_utils.py +1 -1
  81. flwr/common/logger.py +1 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/retry_invoker.py +30 -11
  84. flwr/common/telemetry.py +4 -0
  85. flwr/compat/server/app.py +2 -2
  86. flwr/proto/appio_pb2.py +25 -17
  87. flwr/proto/appio_pb2.pyi +46 -2
  88. flwr/proto/clientappio_pb2.py +3 -11
  89. flwr/proto/clientappio_pb2.pyi +0 -47
  90. flwr/proto/clientappio_pb2_grpc.py +19 -20
  91. flwr/proto/clientappio_pb2_grpc.pyi +10 -11
  92. flwr/proto/control_pb2.py +66 -0
  93. flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
  94. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
  95. flwr/proto/control_pb2_grpc.pyi +106 -0
  96. flwr/proto/serverappio_pb2.py +2 -2
  97. flwr/proto/serverappio_pb2_grpc.py +68 -0
  98. flwr/proto/serverappio_pb2_grpc.pyi +26 -0
  99. flwr/proto/simulationio_pb2.py +4 -11
  100. flwr/proto/simulationio_pb2.pyi +0 -58
  101. flwr/proto/simulationio_pb2_grpc.py +129 -27
  102. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  103. flwr/server/app.py +142 -152
  104. flwr/server/grid/grpc_grid.py +3 -0
  105. flwr/server/grid/inmemory_grid.py +1 -0
  106. flwr/server/serverapp/app.py +157 -146
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  110. flwr/server/superlink/linkstate/linkstate.py +2 -1
  111. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  112. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
  113. flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
  114. flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
  115. flwr/serverapp/__init__.py +12 -0
  116. flwr/serverapp/exception.py +38 -0
  117. flwr/serverapp/strategy/__init__.py +64 -0
  118. flwr/serverapp/strategy/bulyan.py +238 -0
  119. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  120. flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
  121. flwr/serverapp/strategy/fedadagrad.py +159 -0
  122. flwr/serverapp/strategy/fedadam.py +178 -0
  123. flwr/serverapp/strategy/fedavg.py +320 -0
  124. flwr/serverapp/strategy/fedavgm.py +198 -0
  125. flwr/serverapp/strategy/fedmedian.py +105 -0
  126. flwr/serverapp/strategy/fedopt.py +218 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +170 -0
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/result.py +105 -0
  136. flwr/serverapp/strategy/strategy.py +285 -0
  137. flwr/serverapp/strategy/strategy_utils.py +299 -0
  138. flwr/simulation/app.py +161 -164
  139. flwr/simulation/run_simulation.py +25 -30
  140. flwr/supercore/app_utils.py +58 -0
  141. flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
  142. flwr/supercore/cli/flower_superexec.py +166 -0
  143. flwr/supercore/constant.py +19 -0
  144. flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
  145. flwr/supercore/corestate/corestate.py +81 -0
  146. flwr/supercore/grpc_health/__init__.py +3 -0
  147. flwr/supercore/grpc_health/health_server.py +53 -0
  148. flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
  149. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  150. flwr/supercore/superexec/plugin/__init__.py +28 -0
  151. flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
  152. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  153. flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
  154. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  155. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  156. flwr/supercore/superexec/run_superexec.py +199 -0
  157. flwr/superlink/artifact_provider/__init__.py +22 -0
  158. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  159. flwr/superlink/servicer/__init__.py +15 -0
  160. flwr/superlink/servicer/control/__init__.py +22 -0
  161. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
  162. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
  163. flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
  164. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
  165. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
  166. flwr/supernode/cli/flower_supernode.py +3 -0
  167. flwr/supernode/cli/flwr_clientapp.py +18 -21
  168. flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
  169. flwr/supernode/nodestate/nodestate.py +3 -59
  170. flwr/supernode/runtime/run_clientapp.py +39 -102
  171. flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
  172. flwr/supernode/start_client_internal.py +35 -76
  173. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
  174. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
  175. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
  176. flwr/proto/exec_pb2.py +0 -62
  177. flwr/proto/exec_pb2_grpc.pyi +0 -93
  178. flwr/superexec/app.py +0 -45
  179. flwr/superexec/deployment.py +0 -191
  180. flwr/superexec/executor.py +0 -100
  181. flwr/superexec/simulation.py +0 -129
  182. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,285 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower message-based strategy."""
16
+
17
+
18
+ import io
19
+ import time
20
+ from abc import ABC, abstractmethod
21
+ from collections.abc import Iterable
22
+ from logging import INFO
23
+ from typing import Callable, Optional
24
+
25
+ from flwr.common import ArrayRecord, ConfigRecord, Message, MetricRecord, log
26
+ from flwr.server import Grid
27
+
28
+ from .result import Result
29
+ from .strategy_utils import log_strategy_start_info
30
+
31
+
32
class Strategy(ABC):
    """Abstract base class for server strategy implementations."""

    @abstractmethod
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        arrays : ArrayRecord
            Current global ArrayRecord (e.g. global model) to be sent to client
            nodes for training.
        config : ConfigRecord
            Configuration to be sent to clients nodes for training.
        grid : Grid
            The Grid instance used for node sampling and communication.

        Returns
        -------
        Iterable[Message]
            An iterable of messages to be sent to selected client nodes for training.
        """

    @abstractmethod
    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate training results from client nodes.

        Parameters
        ----------
        server_round : int
            The current round of federated learning, starting from 1.
        replies : Iterable[Message]
            Iterable of reply messages received from client nodes after training.
            Each message contains ArrayRecords and MetricRecords that get aggregated.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            A tuple containing:
            - ArrayRecord: Aggregated ArrayRecord, or None if aggregation failed
            - MetricRecord: Aggregated MetricRecord, or None if aggregation failed
        """

    @abstractmethod
    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of evaluation.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        arrays : ArrayRecord
            Current global ArrayRecord (e.g. global model) to be sent to client
            nodes for evaluation.
        config : ConfigRecord
            Configuration to be sent to clients nodes for evaluation.
        grid : Grid
            The Grid instance used for node sampling and communication.

        Returns
        -------
        Iterable[Message]
            An iterable of messages to be sent to selected client nodes for evaluation.
        """

    @abstractmethod
    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate evaluation metrics from client nodes.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        replies : Iterable[Message]
            Iterable of reply messages received from client nodes after evaluation.
            MetricRecords in the messages are aggregated.

        Returns
        -------
        Optional[MetricRecord]
            Aggregated evaluation metrics from all participating clients,
            or None if aggregation failed.
        """

    @abstractmethod
    def summary(self) -> None:
        """Log summary configuration of the strategy."""

    # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals
    def start(
        self,
        grid: Grid,
        initial_arrays: ArrayRecord,
        num_rounds: int = 3,
        timeout: float = 3600,
        train_config: Optional[ConfigRecord] = None,
        evaluate_config: Optional[ConfigRecord] = None,
        evaluate_fn: Optional[
            Callable[[int, ArrayRecord], Optional[MetricRecord]]
        ] = None,
    ) -> Result:
        """Execute the federated learning strategy.

        Runs the complete federated learning workflow for the specified number of
        rounds, including training, evaluation, and optional centralized evaluation.

        Parameters
        ----------
        grid : Grid
            The Grid instance used to send/receive Messages from nodes executing a
            ClientApp.
        initial_arrays : ArrayRecord
            Initial model parameters (arrays) to be used for federated learning.
        num_rounds : int (default: 3)
            Number of federated learning rounds to execute.
        timeout : float (default: 3600)
            Timeout in seconds for waiting for node responses.
        train_config : ConfigRecord, optional
            Configuration to be sent to nodes during training rounds.
            If unset, an empty ConfigRecord will be used.
        evaluate_config : ConfigRecord, optional
            Configuration to be sent to nodes during evaluation rounds.
            If unset, an empty ConfigRecord will be used.
        evaluate_fn : Callable[[int, ArrayRecord], Optional[MetricRecord]], optional
            Optional function for centralized evaluation of the global model. Takes
            server round number and array record, returns a MetricRecord or None. If
            provided, will be called before the first round and after each round.
            Defaults to None.

        Returns
        -------
        Result
            Result containing final model arrays and also training metrics, evaluation
            metrics and global evaluation metrics (if provided) from all rounds.
        """
        log(INFO, "Starting %s strategy:", self.__class__.__name__)
        log_strategy_start_info(
            num_rounds, initial_arrays, train_config, evaluate_config
        )
        self.summary()
        log(INFO, "")

        # Initialize if None
        train_config = ConfigRecord() if train_config is None else train_config
        evaluate_config = ConfigRecord() if evaluate_config is None else evaluate_config
        result = Result()

        t_start = time.time()
        # Evaluate starting global parameters (round 0 entry in the history)
        if evaluate_fn:
            res = evaluate_fn(0, initial_arrays)
            log(INFO, "Initial global evaluation results: %s", res)
            if res is not None:
                result.evaluate_metrics_serverapp[0] = res

        arrays = initial_arrays

        for current_round in range(1, num_rounds + 1):
            log(INFO, "")
            log(INFO, "[ROUND %s/%s]", current_round, num_rounds)

            # -----------------------------------------------------------------
            # --- TRAINING (CLIENTAPP-SIDE) -----------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure training round
            # Send messages and wait for replies
            train_replies = grid.send_and_receive(
                messages=self.configure_train(
                    current_round,
                    arrays,
                    train_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate train
            agg_arrays, agg_train_metrics = self.aggregate_train(
                current_round,
                train_replies,
            )

            # Log training metrics and append to history
            # NOTE: if aggregation failed (None), the previous global arrays
            # are kept and reused for the next round.
            if agg_arrays is not None:
                result.arrays = agg_arrays
                arrays = agg_arrays
            if agg_train_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_train_metrics)
                result.train_metrics_clientapp[current_round] = agg_train_metrics

            # -----------------------------------------------------------------
            # --- EVALUATION (CLIENTAPP-SIDE) ---------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure evaluation round
            # Send messages and wait for replies
            evaluate_replies = grid.send_and_receive(
                messages=self.configure_evaluate(
                    current_round,
                    arrays,
                    evaluate_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate evaluate
            agg_evaluate_metrics = self.aggregate_evaluate(
                current_round,
                evaluate_replies,
            )

            # Log evaluation metrics and append to history
            if agg_evaluate_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_evaluate_metrics)
                result.evaluate_metrics_clientapp[current_round] = agg_evaluate_metrics

            # -----------------------------------------------------------------
            # --- EVALUATION (SERVERAPP-SIDE) ---------------------------------
            # -----------------------------------------------------------------

            # Centralized evaluation
            if evaluate_fn:
                log(INFO, "Global evaluation")
                res = evaluate_fn(current_round, arrays)
                log(INFO, "\t└──> MetricRecord: %s", res)
                if res is not None:
                    result.evaluate_metrics_serverapp[current_round] = res

        log(INFO, "")
        log(INFO, "Strategy execution finished in %.2fs", time.time() - t_start)
        log(INFO, "")
        log(INFO, "Final results:")
        log(INFO, "")
        # Log the multi-line Result summary one line at a time
        for line in io.StringIO(str(result)):
            log(INFO, "\t%s", line.strip("\n"))
        log(INFO, "")

        return result
@@ -0,0 +1,299 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower message-based strategy utilities."""
16
+
17
+
18
+ import json
19
+ import random
20
+ from collections import OrderedDict
21
+ from logging import INFO
22
+ from time import sleep
23
+ from typing import Optional, cast
24
+
25
+ import numpy as np
26
+
27
+ from flwr.common import (
28
+ Array,
29
+ ArrayRecord,
30
+ ConfigRecord,
31
+ MetricRecord,
32
+ NDArray,
33
+ RecordDict,
34
+ log,
35
+ )
36
+ from flwr.server import Grid
37
+
38
+ from ..exception import InconsistentMessageReplies
39
+
40
+
41
def config_to_str(config: ConfigRecord) -> str:
    """Render a ConfigRecord as a dict-like string, masking raw byte values."""
    parts: list[str] = []
    for key, value in config.items():
        shown = "<bytes>" if isinstance(value, bytes) else value
        parts.append(f"'{key}': {shown}")
    return "{" + ", ".join(parts) + "}"
47
+
48
+
49
def log_strategy_start_info(
    num_rounds: int,
    arrays: ArrayRecord,
    train_config: Optional[ConfigRecord],
    evaluate_config: Optional[ConfigRecord],
) -> None:
    """Log information about the strategy start."""
    # Total payload size of the global arrays, in megabytes
    size_mb = sum(len(array.data) for array in arrays.values()) / (1024**2)
    # Empty/None configs are reported explicitly instead of rendered
    train_repr = config_to_str(train_config) if train_config else "(empty!)"
    evaluate_repr = config_to_str(evaluate_config) if evaluate_config else "(empty!)"
    log(INFO, "\t├── Number of rounds: %d", num_rounds)
    log(INFO, "\t├── ArrayRecord (%.2f MB)", size_mb)
    log(INFO, "\t├── ConfigRecord (train): %s", train_repr)
    log(INFO, "\t├── ConfigRecord (evaluate): %s", evaluate_repr)
72
+
73
+
74
def aggregate_arrayrecords(
    records: list[RecordDict], weighting_metric_name: str
) -> ArrayRecord:
    """Perform weighted aggregation all ArrayRecords using a specific key."""
    # Extract the weighting factor from each reply's single MetricRecord.
    # Replies have been checked for consistency, so the cast to float is safe.
    raw_weights = [
        cast(float, next(iter(rec.metric_records.values()))[weighting_metric_name])
        for rec in records
    ]

    # Normalize weights so they sum to 1
    total = sum(raw_weights)
    factors = [w / total for w in raw_weights]

    # Weighted in-place accumulation of every array across all records
    accumulator: dict[str, NDArray] = {}
    for rec, factor in zip(records, factors):
        for arr_record in rec.array_records.values():
            for name, arr in arr_record.items():
                contribution = arr.numpy() * factor
                if name in accumulator:
                    accumulator[name] += contribution
                else:
                    accumulator[name] = contribution

    return ArrayRecord(
        OrderedDict({k: Array(np.asarray(v)) for k, v in accumulator.items()})
    )
107
+
108
+
109
def aggregate_metricrecords(
    records: list[RecordDict], weighting_metric_name: str
) -> MetricRecord:
    """Perform weighted aggregation all MetricRecords using a specific key."""
    # Extract the weighting factor from each reply's single MetricRecord.
    # Replies have been checked for consistency, so the cast to float is safe.
    raw_weights = [
        cast(float, next(iter(rec.metric_records.values()))[weighting_metric_name])
        for rec in records
    ]

    # Normalize weights so they sum to 1
    total = sum(raw_weights)
    factors = [w / total for w in raw_weights]

    aggregated = MetricRecord()
    for rec, factor in zip(records, factors):
        for metric_record in rec.metric_records.values():
            for key, value in metric_record.items():
                # The weighting key itself is excluded from the output
                if key == weighting_metric_name:
                    continue
                if key in aggregated:
                    # Accumulate onto the running weighted sum
                    if isinstance(value, list):
                        running = cast(list[float], aggregated[key])
                        aggregated[key] = [
                            acc + item * factor
                            for acc, item in zip(running, value)
                        ]
                    else:
                        aggregated[key] = (
                            cast(float, aggregated[key]) + value * factor
                        )
                else:
                    # First contribution for this key
                    if isinstance(value, list):
                        aggregated[key] = [item * factor for item in value]
                    else:
                        aggregated[key] = value * factor

    return aggregated
152
+
153
+
154
def sample_nodes(
    grid: Grid, min_available_nodes: int, sample_size: int
) -> tuple[list[int], list[int]]:
    """Sample the specified number of nodes using the Grid.

    Blocks (polling once per second) until at least
    ``max(min_available_nodes, sample_size)`` nodes are connected, then samples
    ``sample_size`` node IDs uniformly at random without replacement.

    Parameters
    ----------
    grid : Grid
        The grid object.
    min_available_nodes : int
        The minimum number of available nodes to sample from.
    sample_size : int
        The number of nodes to sample.

    Returns
    -------
    tuple[list[int], list[int]]
        A tuple containing the sampled node IDs and the list
        of all connected node IDs.
    """
    # Ensure min_available_nodes is at least as large as sample_size;
    # otherwise random.sample below could be asked for more nodes than exist
    min_available_nodes = max(min_available_nodes, sample_size)

    # Wait for min_available_nodes to be online
    while len(all_nodes := list(grid.get_node_ids())) < min_available_nodes:
        log(
            INFO,
            "Waiting for nodes to connect: %d connected (minimum required: %d).",
            len(all_nodes),
            min_available_nodes,
        )
        sleep(1)

    # Sample nodes
    sampled_nodes = random.sample(all_nodes, sample_size)

    return sampled_nodes, all_nodes
193
+
194
+
195
# pylint: disable=too-many-return-statements
def validate_message_reply_consistency(
    replies: list[RecordDict], weighted_by_key: str, check_arrayrecord: bool
) -> None:
    """Validate that replies contain exactly one ArrayRecord and one MetricRecord, and
    that the MetricRecord includes a weight factor key.

    These checks ensure that Message-based strategies behave consistently with
    *Ins/*Res-based strategies.

    Raises ``InconsistentMessageReplies`` on the first violation found.
    NOTE(review): assumes ``replies`` is non-empty — ``replies[0]`` below raises
    IndexError otherwise; callers should guard against empty reply lists.
    """
    # Checking for ArrayRecord consistency (only when aggregating arrays)
    if check_arrayrecord:
        if any(len(msg.array_records) != 1 for msg in replies):
            raise InconsistentMessageReplies(
                reason="Expected exactly one ArrayRecord in replies. "
                "Skipping aggregation."
            )

        # Ensure all keys are present in all ArrayRecords
        # (key sets are compared against the first reply's ArrayRecord)
        record_key = next(iter(replies[0].array_records.keys()))
        all_keys = set(replies[0][record_key].keys())
        if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
            raise InconsistentMessageReplies(
                reason="All ArrayRecords must have the same keys for aggregation. "
                "This condition wasn't met. Skipping aggregation."
            )

    # Checking for MetricRecord consistency
    if any(len(msg.metric_records) != 1 for msg in replies):
        raise InconsistentMessageReplies(
            reason="Expected exactly one MetricRecord in replies, but found more. "
            "Skipping aggregation."
        )

    # Ensure all keys are present in all MetricRecords
    # (record_key/all_keys are intentionally rebound to the MetricRecord here)
    record_key = next(iter(replies[0].metric_records.keys()))
    all_keys = set(replies[0][record_key].keys())
    if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
        raise InconsistentMessageReplies(
            reason="All MetricRecords must have the same keys for aggregation. "
            "This condition wasn't met. Skipping aggregation."
        )

    # Verify the weight factor key presence in all MetricRecords
    # (all replies share the same key set at this point, so one check suffices)
    if weighted_by_key not in all_keys:
        raise InconsistentMessageReplies(
            reason=f"Missing required key `{weighted_by_key}` in the MetricRecord of "
            "reply messages. Cannot average ArrayRecords and MetricRecords. Skipping "
            "aggregation."
        )

    # Check that the weighting value is a scalar, not a list
    if any(isinstance(msg[record_key][weighted_by_key], list) for msg in replies):
        raise InconsistentMessageReplies(
            reason=f"Key `{weighted_by_key}` in the MetricRecord of reply messages "
            "must be a single value (int or float), but a list was found. Skipping "
            "aggregation."
        )
253
+
254
+
255
+ def aggregate_bagging(
256
+ bst_prev_org: bytes,
257
+ bst_curr_org: bytes,
258
+ ) -> bytes:
259
+ """Conduct bagging aggregation for given trees."""
260
+ if bst_prev_org == b"":
261
+ return bst_curr_org
262
+
263
+ # Get the tree numbers
264
+ tree_num_prev, _ = _get_tree_nums(bst_prev_org)
265
+ _, paral_tree_num_curr = _get_tree_nums(bst_curr_org)
266
+
267
+ bst_prev = json.loads(bytearray(bst_prev_org))
268
+ bst_curr = json.loads(bytearray(bst_curr_org))
269
+
270
+ previous_model = bst_prev["learner"]["gradient_booster"]["model"]
271
+ previous_model["gbtree_model_param"]["num_trees"] = str(
272
+ tree_num_prev + paral_tree_num_curr
273
+ )
274
+ iteration_indptr = previous_model["iteration_indptr"]
275
+ previous_model["iteration_indptr"].append(
276
+ iteration_indptr[-1] + paral_tree_num_curr
277
+ )
278
+
279
+ # Aggregate new trees
280
+ trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
281
+ for tree_count in range(paral_tree_num_curr):
282
+ trees_curr[tree_count]["id"] = tree_num_prev + tree_count
283
+ previous_model["trees"].append(trees_curr[tree_count])
284
+ previous_model["tree_info"].append(0)
285
+
286
+ bst_prev_bytes = bytes(json.dumps(bst_prev), "utf-8")
287
+
288
+ return bst_prev_bytes
289
+
290
+
291
+ def _get_tree_nums(xgb_model_org: bytes) -> tuple[int, int]:
292
+ xgb_model = json.loads(bytearray(xgb_model_org))
293
+
294
+ # Access model parameters
295
+ model_param = xgb_model["learner"]["gradient_booster"]["model"][
296
+ "gbtree_model_param"
297
+ ]
298
+ # Return the number of trees and the number of parallel trees
299
+ return int(model_param["num_trees"]), int(model_param["num_parallel_tree"])