flwr 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  5. flwr/cli/build.py +15 -5
  6. flwr/cli/cli_user_auth_interceptor.py +1 -1
  7. flwr/cli/config_utils.py +3 -3
  8. flwr/cli/constant.py +25 -8
  9. flwr/cli/log.py +9 -9
  10. flwr/cli/login/login.py +3 -3
  11. flwr/cli/ls.py +5 -5
  12. flwr/cli/new/new.py +23 -4
  13. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  14. flwr/cli/new/templates/app/README.md.tpl +5 -0
  15. flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
  16. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
  17. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
  18. flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
  19. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -3
  20. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  21. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  22. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  23. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  24. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  25. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  26. flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
  27. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  28. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  29. flwr/cli/run/run.py +53 -50
  30. flwr/cli/stop.py +7 -4
  31. flwr/cli/utils.py +29 -11
  32. flwr/client/grpc_adapter_client/connection.py +11 -4
  33. flwr/client/grpc_rere_client/connection.py +93 -129
  34. flwr/client/rest_client/connection.py +134 -164
  35. flwr/clientapp/__init__.py +10 -0
  36. flwr/clientapp/mod/__init__.py +26 -0
  37. flwr/clientapp/mod/centraldp_mods.py +132 -0
  38. flwr/common/args.py +20 -6
  39. flwr/common/auth_plugin/__init__.py +4 -4
  40. flwr/common/auth_plugin/auth_plugin.py +7 -7
  41. flwr/common/constant.py +26 -5
  42. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  43. flwr/common/exit/__init__.py +4 -0
  44. flwr/common/exit/exit.py +8 -1
  45. flwr/common/exit/exit_code.py +42 -8
  46. flwr/common/exit/exit_handler.py +62 -0
  47. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  48. flwr/common/grpc.py +1 -1
  49. flwr/common/{inflatable_grpc_utils.py → inflatable_protobuf_utils.py} +52 -10
  50. flwr/common/inflatable_utils.py +191 -24
  51. flwr/common/logger.py +1 -1
  52. flwr/common/record/array.py +101 -22
  53. flwr/common/record/arraychunk.py +59 -0
  54. flwr/common/retry_invoker.py +30 -11
  55. flwr/common/serde.py +0 -28
  56. flwr/common/telemetry.py +4 -0
  57. flwr/compat/client/app.py +14 -31
  58. flwr/compat/server/app.py +2 -2
  59. flwr/proto/appio_pb2.py +51 -0
  60. flwr/proto/appio_pb2.pyi +195 -0
  61. flwr/proto/appio_pb2_grpc.py +4 -0
  62. flwr/proto/appio_pb2_grpc.pyi +4 -0
  63. flwr/proto/clientappio_pb2.py +4 -19
  64. flwr/proto/clientappio_pb2.pyi +0 -125
  65. flwr/proto/clientappio_pb2_grpc.py +269 -29
  66. flwr/proto/clientappio_pb2_grpc.pyi +114 -21
  67. flwr/proto/control_pb2.py +62 -0
  68. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
  69. flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
  70. flwr/proto/fleet_pb2.py +12 -20
  71. flwr/proto/fleet_pb2.pyi +6 -36
  72. flwr/proto/serverappio_pb2.py +8 -31
  73. flwr/proto/serverappio_pb2.pyi +0 -152
  74. flwr/proto/serverappio_pb2_grpc.py +107 -38
  75. flwr/proto/serverappio_pb2_grpc.pyi +47 -20
  76. flwr/proto/simulationio_pb2.py +4 -11
  77. flwr/proto/simulationio_pb2.pyi +0 -58
  78. flwr/proto/simulationio_pb2_grpc.py +129 -27
  79. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  80. flwr/server/app.py +130 -153
  81. flwr/server/fleet_event_log_interceptor.py +4 -0
  82. flwr/server/grid/grpc_grid.py +94 -54
  83. flwr/server/grid/inmemory_grid.py +1 -0
  84. flwr/server/serverapp/app.py +165 -144
  85. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +8 -0
  86. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  87. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  88. flwr/server/superlink/fleet/message_handler/message_handler.py +10 -16
  89. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -2
  90. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  91. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  92. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  93. flwr/server/superlink/linkstate/linkstate.py +2 -1
  94. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  95. flwr/server/superlink/serverappio/serverappio_grpc.py +2 -2
  96. flwr/server/superlink/serverappio/serverappio_servicer.py +95 -48
  97. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  98. flwr/server/superlink/simulation/simulationio_servicer.py +98 -22
  99. flwr/server/superlink/utils.py +0 -35
  100. flwr/serverapp/__init__.py +12 -0
  101. flwr/serverapp/dp_fixed_clipping.py +352 -0
  102. flwr/serverapp/exception.py +38 -0
  103. flwr/serverapp/strategy/__init__.py +38 -0
  104. flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
  105. flwr/serverapp/strategy/fedadagrad.py +162 -0
  106. flwr/serverapp/strategy/fedadam.py +181 -0
  107. flwr/serverapp/strategy/fedavg.py +295 -0
  108. flwr/serverapp/strategy/fedopt.py +218 -0
  109. flwr/serverapp/strategy/fedyogi.py +173 -0
  110. flwr/serverapp/strategy/result.py +105 -0
  111. flwr/serverapp/strategy/strategy.py +285 -0
  112. flwr/serverapp/strategy/strategy_utils.py +251 -0
  113. flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
  114. flwr/simulation/app.py +159 -154
  115. flwr/simulation/run_simulation.py +17 -0
  116. flwr/supercore/app_utils.py +58 -0
  117. flwr/supercore/cli/__init__.py +22 -0
  118. flwr/supercore/cli/flower_superexec.py +141 -0
  119. flwr/supercore/corestate/__init__.py +22 -0
  120. flwr/supercore/corestate/corestate.py +81 -0
  121. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  122. flwr/supercore/grpc_health/__init__.py +25 -0
  123. flwr/supercore/grpc_health/health_server.py +53 -0
  124. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  125. flwr/supercore/license_plugin/__init__.py +22 -0
  126. flwr/supercore/license_plugin/license_plugin.py +26 -0
  127. flwr/supercore/object_store/in_memory_object_store.py +31 -31
  128. flwr/supercore/object_store/object_store.py +20 -42
  129. flwr/supercore/object_store/utils.py +43 -0
  130. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  131. flwr/supercore/superexec/plugin/__init__.py +28 -0
  132. flwr/supercore/superexec/plugin/base_exec_plugin.py +53 -0
  133. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  134. flwr/supercore/superexec/plugin/exec_plugin.py +71 -0
  135. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  136. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  137. flwr/supercore/superexec/run_superexec.py +185 -0
  138. flwr/supercore/utils.py +32 -0
  139. flwr/superlink/servicer/__init__.py +15 -0
  140. flwr/superlink/servicer/control/__init__.py +22 -0
  141. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +9 -5
  142. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +39 -28
  143. flwr/superlink/servicer/control/control_license_interceptor.py +82 -0
  144. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +79 -31
  145. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +18 -10
  146. flwr/supernode/cli/flower_supernode.py +3 -7
  147. flwr/supernode/cli/flwr_clientapp.py +20 -16
  148. flwr/supernode/nodestate/in_memory_nodestate.py +13 -4
  149. flwr/supernode/nodestate/nodestate.py +3 -44
  150. flwr/supernode/runtime/run_clientapp.py +129 -115
  151. flwr/supernode/servicer/clientappio/__init__.py +1 -3
  152. flwr/supernode/servicer/clientappio/clientappio_servicer.py +217 -165
  153. flwr/supernode/start_client_internal.py +205 -148
  154. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/METADATA +5 -3
  155. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/RECORD +161 -117
  156. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
  157. flwr/common/inflatable_rest_utils.py +0 -99
  158. flwr/proto/exec_pb2.py +0 -62
  159. flwr/superexec/app.py +0 -45
  160. flwr/superexec/deployment.py +0 -192
  161. flwr/superexec/executor.py +0 -100
  162. flwr/superexec/simulation.py +0 -130
  163. /flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
  164. /flwr/{server/superlink → supercore}/ffs/__init__.py +0 -0
  165. /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  166. /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  167. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,105 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Strategy results."""
16
+
17
+
18
+ import pprint
19
+ from dataclasses import dataclass, field
20
+
21
+ from flwr.common import ArrayRecord, MetricRecord
22
+ from flwr.common.typing import MetricRecordValues
23
+
24
+
25
@dataclass
class Result:
    """Container for the records a strategy produces while it runs.

    A ``Result`` bundles the final global model parameters with the metrics that
    were collected, round by round, during federated training, federated
    evaluation, and centralized (ServerApp-side) evaluation.

    Attributes
    ----------
    arrays : ArrayRecord
        The final global model parameters, i.e. the aggregated model
        weights/parameters obtained from the federated learning process.
    train_metrics_clientapp : dict[int, MetricRecord]
        Aggregated training metrics (e.g., loss, accuracy) reported by
        ClientApps, keyed by round number.
    evaluate_metrics_clientapp : dict[int, MetricRecord]
        Aggregated evaluation metrics (e.g., validation loss) reported by
        ClientApps after evaluating the global model on their local
        validation/test data, keyed by round number.
    evaluate_metrics_serverapp : dict[int, MetricRecord]
        Metrics from centralized evaluation performed by the ServerApp (e.g.,
        evaluating the global model on a held-out dataset), keyed by round
        number.
    """

    arrays: ArrayRecord = field(default_factory=ArrayRecord)
    train_metrics_clientapp: dict[int, MetricRecord] = field(default_factory=dict)
    evaluate_metrics_clientapp: dict[int, MetricRecord] = field(default_factory=dict)
    evaluate_metrics_serverapp: dict[int, MetricRecord] = field(default_factory=dict)

    def __repr__(self) -> str:
        """Return a multi-line, human-readable summary of the stored records."""
        payload_mb = sum(len(arr.data) for arr in self.arrays.values()) / (1024**2)
        chunks = [f"Global Arrays:\n\tArrayRecord ({payload_mb:.3f} MB)\n\n"]

        # Each metrics section: a heading, the pretty-printed per-round metrics,
        # then a blank-line separator (the last section ends with one newline).
        sections = (
            (
                "Aggregated ClientApp-side Train Metrics:",
                self.train_metrics_clientapp,
                "\n\n",
            ),
            (
                "Aggregated ClientApp-side Evaluate Metrics:",
                self.evaluate_metrics_clientapp,
                "\n\n",
            ),
            (
                "ServerApp-side Evaluate Metrics:",
                self.evaluate_metrics_serverapp,
                "\n",
            ),
        )
        for heading, per_round, trailer in sections:
            pretty = pprint.pformat(stringify_dict(per_round), indent=2)
            chunks.append(f"{heading}\n{pretty}{trailer}")

        return "".join(chunks)
83
+
84
+
85
def format_value(val: MetricRecordValues) -> str:
    """Render a metric value as text, using scientific notation for floats.

    Floats are formatted as ``"{:.4e}"``; every other scalar goes through
    ``str``. A list is formatted element-wise with the same rule, and the
    resulting list of strings is then itself converted with ``str``.
    """

    def _render(item: object) -> str:
        # Only floats get scientific notation; ints (and anything else) use str()
        return f"{item:.4e}" if isinstance(item, float) else str(item)

    if isinstance(val, list):
        return str([_render(x) for x in val])
    return _render(val)
94
+
95
+
96
def stringify_dict(d: dict[int, MetricRecord]) -> dict[int, dict[str, str]]:
    """Return a copy of the metrics with each value formatted as a string.

    Round numbers and metric names are kept as-is; every metric value is
    rendered with ``format_value``.
    """
    return {
        round_num: {name: format_value(value) for name, value in metrics.items()}
        for round_num, metrics in d.items()
    }
@@ -0,0 +1,285 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower message-based strategy."""
16
+
17
+
18
+ import io
19
+ import time
20
+ from abc import ABC, abstractmethod
21
+ from collections.abc import Iterable
22
+ from logging import INFO
23
+ from typing import Callable, Optional
24
+
25
+ from flwr.common import ArrayRecord, ConfigRecord, Message, MetricRecord, log
26
+ from flwr.server import Grid
27
+
28
+ from .result import Result
29
+ from .strategy_utils import log_strategy_start_info
30
+
31
+
32
class Strategy(ABC):
    """Abstract base class for server strategy implementations.

    Subclasses define how rounds are configured (``configure_train`` /
    ``configure_evaluate``), how client replies are aggregated
    (``aggregate_train`` / ``aggregate_evaluate``) and how the strategy
    describes itself (``summary``). ``start`` drives the complete multi-round
    workflow on top of these hooks.
    """

    @abstractmethod
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        arrays : ArrayRecord
            Current global ArrayRecord (e.g. global model) to be sent to client
            nodes for training.
        config : ConfigRecord
            Configuration to be sent to client nodes for training.
        grid : Grid
            The Grid instance used for node sampling and communication.

        Returns
        -------
        Iterable[Message]
            An iterable of messages to be sent to selected client nodes for training.
        """

    @abstractmethod
    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate training results from client nodes.

        Parameters
        ----------
        server_round : int
            The current round of federated learning, starting from 1.
        replies : Iterable[Message]
            Iterable of reply messages received from client nodes after training.
            Each message contains ArrayRecords and MetricRecords that get aggregated.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            A tuple containing:
            - ArrayRecord: Aggregated ArrayRecord, or None if aggregation failed
            - MetricRecord: Aggregated MetricRecord, or None if aggregation failed
        """

    @abstractmethod
    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of evaluation.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        arrays : ArrayRecord
            Current global ArrayRecord (e.g. global model) to be sent to client
            nodes for evaluation.
        config : ConfigRecord
            Configuration to be sent to client nodes for evaluation.
        grid : Grid
            The Grid instance used for node sampling and communication.

        Returns
        -------
        Iterable[Message]
            An iterable of messages to be sent to selected client nodes for evaluation.
        """

    @abstractmethod
    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate evaluation metrics from client nodes.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        replies : Iterable[Message]
            Iterable of reply messages received from client nodes after evaluation.
            MetricRecords in the messages are aggregated.

        Returns
        -------
        Optional[MetricRecord]
            Aggregated evaluation metrics from all participating clients,
            or None if aggregation failed.
        """

    @abstractmethod
    def summary(self) -> None:
        """Log summary configuration of the strategy."""

    # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals
    def start(
        self,
        grid: Grid,
        initial_arrays: ArrayRecord,
        num_rounds: int = 3,
        timeout: float = 3600,
        train_config: Optional[ConfigRecord] = None,
        evaluate_config: Optional[ConfigRecord] = None,
        evaluate_fn: Optional[
            Callable[[int, ArrayRecord], Optional[MetricRecord]]
        ] = None,
    ) -> Result:
        """Execute the federated learning strategy.

        Runs the complete federated learning workflow for the specified number of
        rounds, including training, evaluation, and optional centralized evaluation.

        Parameters
        ----------
        grid : Grid
            The Grid instance used to send/receive Messages from nodes executing a
            ClientApp.
        initial_arrays : ArrayRecord
            Initial model parameters (arrays) to be used for federated learning.
        num_rounds : int (default: 3)
            Number of federated learning rounds to execute.
        timeout : float (default: 3600)
            Timeout in seconds for waiting for node responses.
        train_config : ConfigRecord, optional
            Configuration to be sent to nodes during training rounds.
            If unset, an empty ConfigRecord will be used.
        evaluate_config : ConfigRecord, optional
            Configuration to be sent to nodes during evaluation rounds.
            If unset, an empty ConfigRecord will be used.
        evaluate_fn : Callable[[int, ArrayRecord], Optional[MetricRecord]], optional
            Optional function for centralized evaluation of the global model. Takes
            server round number and array record, returns a MetricRecord or None. If
            provided, will be called before the first round and after each round.
            Defaults to None.

        Returns
        -------
        Result
            Result containing the final global model arrays together with the
            training metrics, evaluation metrics and centralized evaluation metrics
            (if ``evaluate_fn`` is provided) from all rounds. If no round produced
            aggregated arrays, ``arrays`` holds ``initial_arrays``.
        """
        log(INFO, "Starting %s strategy:", self.__class__.__name__)
        log_strategy_start_info(
            num_rounds, initial_arrays, train_config, evaluate_config
        )
        self.summary()
        log(INFO, "")

        # Initialize if None
        train_config = ConfigRecord() if train_config is None else train_config
        evaluate_config = ConfigRecord() if evaluate_config is None else evaluate_config
        # Seed the result with the initial arrays so `result.arrays` always holds
        # the current global model, even if no round aggregates successfully
        # (or `num_rounds` is 0).
        result = Result(arrays=initial_arrays)

        t_start = time.time()
        # Evaluate starting global parameters
        if evaluate_fn is not None:
            res = evaluate_fn(0, initial_arrays)
            log(INFO, "Initial global evaluation results: %s", res)
            if res is not None:
                result.evaluate_metrics_serverapp[0] = res

        arrays = initial_arrays

        for current_round in range(1, num_rounds + 1):
            log(INFO, "")
            log(INFO, "[ROUND %s/%s]", current_round, num_rounds)

            # -----------------------------------------------------------------
            # --- TRAINING (CLIENTAPP-SIDE) -----------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure training round
            # Send messages and wait for replies
            train_replies = grid.send_and_receive(
                messages=self.configure_train(
                    current_round,
                    arrays,
                    train_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate train
            agg_arrays, agg_train_metrics = self.aggregate_train(
                current_round,
                train_replies,
            )

            # Update the global model and record training metrics; when
            # aggregation fails (None), the previous global model is kept.
            if agg_arrays is not None:
                result.arrays = agg_arrays
                arrays = agg_arrays
            if agg_train_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_train_metrics)
                result.train_metrics_clientapp[current_round] = agg_train_metrics

            # -----------------------------------------------------------------
            # --- EVALUATION (CLIENTAPP-SIDE) ---------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure evaluation round
            # Send messages and wait for replies
            evaluate_replies = grid.send_and_receive(
                messages=self.configure_evaluate(
                    current_round,
                    arrays,
                    evaluate_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate evaluate
            agg_evaluate_metrics = self.aggregate_evaluate(
                current_round,
                evaluate_replies,
            )

            # Log evaluation metrics and append to history
            if agg_evaluate_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_evaluate_metrics)
                result.evaluate_metrics_clientapp[current_round] = agg_evaluate_metrics

            # -----------------------------------------------------------------
            # --- EVALUATION (SERVERAPP-SIDE) ---------------------------------
            # -----------------------------------------------------------------

            # Centralized evaluation
            if evaluate_fn is not None:
                log(INFO, "Global evaluation")
                res = evaluate_fn(current_round, arrays)
                log(INFO, "\t└──> MetricRecord: %s", res)
                if res is not None:
                    result.evaluate_metrics_serverapp[current_round] = res

        log(INFO, "")
        log(INFO, "Strategy execution finished in %.2fs", time.time() - t_start)
        log(INFO, "")
        log(INFO, "Final results:")
        log(INFO, "")
        # Emit the multi-line Result summary one log record per line
        for line in io.StringIO(str(result)):
            log(INFO, "\t%s", line.strip("\n"))
        log(INFO, "")

        return result
@@ -0,0 +1,251 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower message-based strategy utilities."""
16
+
17
+
18
+ import random
19
+ from collections import OrderedDict
20
+ from logging import INFO
21
+ from time import sleep
22
+ from typing import Optional, cast
23
+
24
+ import numpy as np
25
+
26
+ from flwr.common import (
27
+ Array,
28
+ ArrayRecord,
29
+ ConfigRecord,
30
+ MetricRecord,
31
+ NDArray,
32
+ RecordDict,
33
+ log,
34
+ )
35
+ from flwr.server import Grid
36
+
37
+ from ..exception import InconsistentMessageReplies
38
+
39
+
40
def config_to_str(config: ConfigRecord) -> str:
    """Render a ConfigRecord as a dict-like string with ``bytes`` values hidden.

    Keys are wrapped in single quotes, values use their default string
    formatting, and any ``bytes`` value is shown as the placeholder ``<bytes>``.
    """
    pieces = []
    for key, value in config.items():
        shown = "<bytes>" if isinstance(value, bytes) else value
        pieces.append(f"'{key}': {shown}")
    return "{" + ", ".join(pieces) + "}"
46
+
47
+
48
def log_strategy_start_info(
    num_rounds: int,
    arrays: ArrayRecord,
    train_config: Optional[ConfigRecord],
    evaluate_config: Optional[ConfigRecord],
) -> None:
    """Log an overview of the settings a strategy starts with.

    Emits INFO lines for the number of rounds, the combined size of ``arrays``
    in MB, and the train/evaluate ConfigRecords (rendered via ``config_to_str``,
    which masks ``bytes`` values). A config that is ``None`` or falsy (e.g.
    empty) is shown as ``(empty!)``.
    """

    def _describe(cfg: Optional[ConfigRecord]) -> str:
        # Falsy configs (None, or presumably an empty ConfigRecord) get a marker
        if not cfg:
            return "(empty!)"
        return config_to_str(cfg)

    log(INFO, "\t├── Number of rounds: %d", num_rounds)
    total_bytes = sum(len(arr.data) for arr in arrays.values())
    log(INFO, "\t├── ArrayRecord (%.2f MB)", total_bytes / (1024**2))
    log(INFO, "\t├── ConfigRecord (train): %s", _describe(train_config))
    log(INFO, "\t├── ConfigRecord (evaluate): %s", _describe(evaluate_config))
71
+
72
+
73
def aggregate_arrayrecords(
    records: list[RecordDict], weighting_metric_name: str
) -> ArrayRecord:
    """Compute the weighted average of all ArrayRecords in ``records``.

    Each record's weight is read under ``weighting_metric_name`` from its first
    (and, after validation, only) MetricRecord. The weights are normalised to
    sum to one, and every array is accumulated key by key as
    ``sum(weight_i * array_i)``.

    NOTE(review): if the weights sum to zero, the normalisation raises
    ``ZeroDivisionError``.
    """
    # Replies were validated beforehand, so the weighting value is a scalar
    # and can be cast to float.
    raw_weights = [
        cast(float, next(iter(rec.metric_records.values()))[weighting_metric_name])
        for rec in records
    ]
    total = sum(raw_weights)
    factors = [w / total for w in raw_weights]

    # Running weighted sum per array name
    summed: dict[str, NDArray] = {}
    for rec, factor in zip(records, factors):
        for array_record in rec.array_records.values():
            for name, array in array_record.items():
                if name in summed:
                    summed[name] += array.numpy() * factor
                else:
                    summed[name] = array.numpy() * factor

    return ArrayRecord(
        OrderedDict(
            (name, Array(np.asarray(total_arr))) for name, total_arr in summed.items()
        )
    )
106
+
107
+
108
def aggregate_metricrecords(
    records: list[RecordDict], weighting_metric_name: str
) -> MetricRecord:
    """Compute the weighted average of all MetricRecords in ``records``.

    Weights come from ``weighting_metric_name`` in each record's first (and,
    after validation, only) MetricRecord and are normalised to sum to one.
    Scalar metrics are averaged directly. List metrics are averaged
    element-wise, truncated to the shortest list seen (``zip`` semantics). The
    weighting metric itself is left out of the result.

    NOTE(review): if the weights sum to zero, the normalisation raises
    ``ZeroDivisionError``.
    """
    # Replies were validated beforehand, so the weighting value is a scalar
    # and can be cast to float.
    raw_weights = [
        cast(float, next(iter(rec.metric_records.values()))[weighting_metric_name])
        for rec in records
    ]
    total = sum(raw_weights)
    factors = [w / total for w in raw_weights]

    merged = MetricRecord()
    for rec, factor in zip(records, factors):
        for metric_record in rec.metric_records.values():
            for name, value in metric_record.items():
                # The weighting key is not part of the aggregated output
                if name == weighting_metric_name:
                    continue
                if isinstance(value, list):
                    if name in merged:
                        running = cast(list[float], merged[name])
                        merged[name] = [r + v * factor for r, v in zip(running, value)]
                    else:
                        merged[name] = [v * factor for v in value]
                elif name in merged:
                    merged[name] = cast(float, merged[name]) + value * factor
                else:
                    merged[name] = value * factor

    return merged
151
+
152
+
153
def sample_nodes(
    grid: Grid, min_available_nodes: int, sample_size: int
) -> tuple[list[int], list[int]]:
    """Wait until enough nodes are connected, then sample a subset of them.

    Parameters
    ----------
    grid : Grid
        The grid object used to query the connected node IDs.
    min_available_nodes : int
        Minimum number of connected nodes to wait for. It is raised to
        ``sample_size`` when smaller, so that enough nodes exist to draw
        the sample.
    sample_size : int
        The number of nodes to sample (without replacement).

    Returns
    -------
    tuple[list[int], list[int]]
        The sampled node IDs and the list of all connected node IDs.
    """
    # Drawing `sample_size` nodes needs at least that many to be online
    required = max(min_available_nodes, sample_size)

    # Poll the grid once per second until enough nodes are connected
    connected = list(grid.get_node_ids())
    while len(connected) < required:
        log(
            INFO,
            "Waiting for nodes to connect: %d connected (minimum required: %d).",
            len(connected),
            required,
        )
        sleep(1)
        connected = list(grid.get_node_ids())

    return random.sample(connected, sample_size), connected
192
+
193
+
194
# pylint: disable=too-many-return-statements
def validate_message_reply_consistency(
    replies: list[RecordDict], weighted_by_key: str, check_arrayrecord: bool
) -> None:
    """Validate that reply contents can be aggregated consistently.

    Checks that every reply holds exactly one MetricRecord (and, when
    ``check_arrayrecord`` is True, exactly one ArrayRecord). It also checks that
    these records have the same keys across all replies, that the MetricRecord
    contains ``weighted_by_key``, and that the value under that key is not a
    list.

    These checks ensure that Message-based strategies behave consistently with
    *Ins/*Res-based strategies.

    Parameters
    ----------
    replies : list[RecordDict]
        Contents of the reply messages to validate. An empty list is trivially
        consistent and passes without error.
    weighted_by_key : str
        Key in the MetricRecord that holds the weighting factor.
    check_arrayrecord : bool
        Whether the ArrayRecords in the replies are validated as well.

    Raises
    ------
    InconsistentMessageReplies
        If any of the checks above fails.
    """
    # Nothing to compare; returning early also avoids an IndexError on
    # `replies[0]` below.
    if not replies:
        return

    # Checking for ArrayRecord consistency
    if check_arrayrecord:
        if any(len(msg.array_records) != 1 for msg in replies):
            raise InconsistentMessageReplies(
                reason="Expected exactly one ArrayRecord in replies. "
                "Skipping aggregation."
            )

        # Ensure all keys are present in all ArrayRecords. Records are looked up
        # under the name used by the first reply; a reply lacking that name
        # yields an empty key set, which fails the comparison unless the first
        # record is empty too.
        record_key = next(iter(replies[0].array_records.keys()))
        all_keys = set(replies[0][record_key].keys())
        if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
            raise InconsistentMessageReplies(
                reason="All ArrayRecords must have the same keys for aggregation. "
                "This condition wasn't met. Skipping aggregation."
            )

    # Checking for MetricRecord consistency (zero or several records both fail)
    if any(len(msg.metric_records) != 1 for msg in replies):
        raise InconsistentMessageReplies(
            reason="Expected exactly one MetricRecord in replies. "
            "Skipping aggregation."
        )

    # Ensure all keys are present in all MetricRecords (same lookup rule as above)
    record_key = next(iter(replies[0].metric_records.keys()))
    all_keys = set(replies[0][record_key].keys())
    if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
        raise InconsistentMessageReplies(
            reason="All MetricRecords must have the same keys for aggregation. "
            "This condition wasn't met. Skipping aggregation."
        )

    # Verify the weight factor key presence in all MetricRecords
    if weighted_by_key not in all_keys:
        raise InconsistentMessageReplies(
            reason=f"Missing required key `{weighted_by_key}` in the MetricRecord of "
            "reply messages. Cannot average ArrayRecords and MetricRecords. Skipping "
            "aggregation."
        )

    # Check that it is not a list
    if any(isinstance(msg[record_key][weighted_by_key], list) for msg in replies):
        raise InconsistentMessageReplies(
            reason=f"Key `{weighted_by_key}` in the MetricRecord of reply messages "
            "must be a single value (int or float), but a list was found. Skipping "
            "aggregation."
        )
+ )