flwr 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  5. flwr/cli/cli_user_auth_interceptor.py +1 -1
  6. flwr/cli/config_utils.py +3 -3
  7. flwr/cli/constant.py +25 -8
  8. flwr/cli/log.py +9 -9
  9. flwr/cli/login/login.py +3 -3
  10. flwr/cli/ls.py +5 -5
  11. flwr/cli/new/new.py +11 -0
  12. flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
  13. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
  14. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
  15. flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
  16. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  17. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  18. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  19. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  20. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  21. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  23. flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
  24. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  26. flwr/cli/run/run.py +9 -13
  27. flwr/cli/stop.py +7 -4
  28. flwr/cli/utils.py +19 -8
  29. flwr/client/grpc_rere_client/connection.py +1 -12
  30. flwr/client/rest_client/connection.py +3 -0
  31. flwr/clientapp/__init__.py +10 -0
  32. flwr/clientapp/mod/__init__.py +26 -0
  33. flwr/clientapp/mod/centraldp_mods.py +132 -0
  34. flwr/common/args.py +20 -6
  35. flwr/common/auth_plugin/__init__.py +4 -4
  36. flwr/common/auth_plugin/auth_plugin.py +7 -7
  37. flwr/common/constant.py +23 -4
  38. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  39. flwr/common/exit/__init__.py +4 -0
  40. flwr/common/exit/exit.py +8 -1
  41. flwr/common/exit/exit_code.py +26 -7
  42. flwr/common/exit/exit_handler.py +62 -0
  43. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  44. flwr/common/grpc.py +0 -11
  45. flwr/common/inflatable_utils.py +1 -1
  46. flwr/common/logger.py +1 -1
  47. flwr/common/retry_invoker.py +30 -11
  48. flwr/common/telemetry.py +4 -0
  49. flwr/compat/server/app.py +2 -2
  50. flwr/proto/appio_pb2.py +25 -17
  51. flwr/proto/appio_pb2.pyi +46 -2
  52. flwr/proto/clientappio_pb2.py +3 -11
  53. flwr/proto/clientappio_pb2.pyi +0 -47
  54. flwr/proto/clientappio_pb2_grpc.py +19 -20
  55. flwr/proto/clientappio_pb2_grpc.pyi +10 -11
  56. flwr/proto/control_pb2.py +62 -0
  57. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
  58. flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
  59. flwr/proto/serverappio_pb2.py +2 -2
  60. flwr/proto/serverappio_pb2_grpc.py +68 -0
  61. flwr/proto/serverappio_pb2_grpc.pyi +26 -0
  62. flwr/proto/simulationio_pb2.py +4 -11
  63. flwr/proto/simulationio_pb2.pyi +0 -58
  64. flwr/proto/simulationio_pb2_grpc.py +129 -27
  65. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  66. flwr/server/app.py +129 -152
  67. flwr/server/grid/grpc_grid.py +3 -0
  68. flwr/server/grid/inmemory_grid.py +1 -0
  69. flwr/server/serverapp/app.py +157 -146
  70. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  71. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  72. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  73. flwr/server/superlink/linkstate/linkstate.py +2 -1
  74. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  75. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
  76. flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
  77. flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
  78. flwr/serverapp/__init__.py +12 -0
  79. flwr/serverapp/dp_fixed_clipping.py +352 -0
  80. flwr/serverapp/exception.py +38 -0
  81. flwr/serverapp/strategy/__init__.py +38 -0
  82. flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
  83. flwr/serverapp/strategy/fedadagrad.py +162 -0
  84. flwr/serverapp/strategy/fedadam.py +181 -0
  85. flwr/serverapp/strategy/fedavg.py +295 -0
  86. flwr/serverapp/strategy/fedopt.py +218 -0
  87. flwr/serverapp/strategy/fedyogi.py +173 -0
  88. flwr/serverapp/strategy/result.py +105 -0
  89. flwr/serverapp/strategy/strategy.py +285 -0
  90. flwr/serverapp/strategy/strategy_utils.py +251 -0
  91. flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
  92. flwr/simulation/app.py +161 -164
  93. flwr/supercore/app_utils.py +58 -0
  94. flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
  95. flwr/supercore/cli/flower_superexec.py +141 -0
  96. flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
  97. flwr/supercore/corestate/corestate.py +81 -0
  98. flwr/supercore/grpc_health/__init__.py +3 -0
  99. flwr/supercore/grpc_health/health_server.py +53 -0
  100. flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
  101. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  102. flwr/supercore/superexec/plugin/__init__.py +28 -0
  103. flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
  104. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  105. flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +4 -4
  106. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  107. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  108. flwr/supercore/superexec/run_superexec.py +185 -0
  109. flwr/superlink/servicer/__init__.py +15 -0
  110. flwr/superlink/servicer/control/__init__.py +22 -0
  111. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
  112. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +24 -29
  113. flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
  114. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +69 -30
  115. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
  116. flwr/supernode/cli/flower_supernode.py +3 -0
  117. flwr/supernode/cli/flwr_clientapp.py +18 -21
  118. flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
  119. flwr/supernode/nodestate/nodestate.py +3 -59
  120. flwr/supernode/runtime/run_clientapp.py +39 -102
  121. flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
  122. flwr/supernode/start_client_internal.py +35 -76
  123. {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/METADATA +4 -3
  124. {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/RECORD +127 -98
  125. {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
  126. flwr/proto/exec_pb2.py +0 -62
  127. flwr/superexec/app.py +0 -45
  128. flwr/superexec/deployment.py +0 -191
  129. flwr/superexec/executor.py +0 -100
  130. flwr/superexec/simulation.py +0 -129
  131. /flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
  132. {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
flwr/serverapp/strategy/fedyogi.py
@@ -0,0 +1,173 @@
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Adaptive Federated Optimization using Yogi (FedYogi) [Reddi et al., 2020] strategy.
+
+ Paper: arxiv.org/abs/2003.00295
+ """
+
+
+ from collections import OrderedDict
+ from collections.abc import Iterable
+ from typing import Callable, Optional
+
+ import numpy as np
+
+ from flwr.common import Array, ArrayRecord, Message, MetricRecord, RecordDict
+
+ from ..exception import AggregationError
+ from .fedopt import FedOpt
+
+
+ # pylint: disable=line-too-long
+ class FedYogi(FedOpt):
+     """FedYogi [Reddi et al., 2020] strategy.
+
+     Implementation based on https://arxiv.org/abs/2003.00295v5
+
+
+     Parameters
+     ----------
+     fraction_train : float (default: 1.0)
+         Fraction of nodes used during training. In case `min_train_nodes`
+         is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
+         will still be sampled.
+     fraction_evaluate : float (default: 1.0)
+         Fraction of nodes used during validation. In case `min_evaluate_nodes`
+         is larger than `fraction_evaluate * total_connected_nodes`,
+         `min_evaluate_nodes` will still be sampled.
+     min_train_nodes : int (default: 2)
+         Minimum number of nodes used during training.
+     min_evaluate_nodes : int (default: 2)
+         Minimum number of nodes used during validation.
+     min_available_nodes : int (default: 2)
+         Minimum number of total nodes in the system.
+     weighted_by_key : str (default: "num-examples")
+         The key within each MetricRecord whose value is used as the weight when
+         computing weighted averages for both ArrayRecords and MetricRecords.
+     arrayrecord_key : str (default: "arrays")
+         Key used to store the ArrayRecord when constructing Messages.
+     configrecord_key : str (default: "config")
+         Key used to store the ConfigRecord when constructing Messages.
+     train_metrics_aggr_fn : Optional[callable] (default: None)
+         Function with signature (list[RecordDict], str) -> MetricRecord,
+         used to aggregate MetricRecords from training round replies.
+         If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+         average using the provided weight factor key.
+     evaluate_metrics_aggr_fn : Optional[callable] (default: None)
+         Function with signature (list[RecordDict], str) -> MetricRecord,
+         used to aggregate MetricRecords from evaluation round replies.
+         If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+         average using the provided weight factor key.
+     eta : float, optional
+         Server-side learning rate. Defaults to 1e-2.
+     eta_l : float, optional
+         Client-side learning rate. Defaults to 0.0316.
+     beta_1 : float, optional
+         Momentum parameter. Defaults to 0.9.
+     beta_2 : float, optional
+         Second moment parameter. Defaults to 0.99.
+     tau : float, optional
+         Controls the algorithm's degree of adaptability.
+         Defaults to 1e-3.
+     """
+
+     # pylint: disable=too-many-arguments, too-many-locals
+     def __init__(
+         self,
+         *,
+         fraction_train: float = 1.0,
+         fraction_evaluate: float = 1.0,
+         min_train_nodes: int = 2,
+         min_evaluate_nodes: int = 2,
+         min_available_nodes: int = 2,
+         weighted_by_key: str = "num-examples",
+         arrayrecord_key: str = "arrays",
+         configrecord_key: str = "config",
+         train_metrics_aggr_fn: Optional[
+             Callable[[list[RecordDict], str], MetricRecord]
+         ] = None,
+         evaluate_metrics_aggr_fn: Optional[
+             Callable[[list[RecordDict], str], MetricRecord]
+         ] = None,
+         eta: float = 1e-2,
+         eta_l: float = 0.0316,
+         beta_1: float = 0.9,
+         beta_2: float = 0.99,
+         tau: float = 1e-3,
+     ) -> None:
+         super().__init__(
+             fraction_train=fraction_train,
+             fraction_evaluate=fraction_evaluate,
+             min_train_nodes=min_train_nodes,
+             min_evaluate_nodes=min_evaluate_nodes,
+             min_available_nodes=min_available_nodes,
+             weighted_by_key=weighted_by_key,
+             arrayrecord_key=arrayrecord_key,
+             configrecord_key=configrecord_key,
+             train_metrics_aggr_fn=train_metrics_aggr_fn,
+             evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
+             eta=eta,
+             eta_l=eta_l,
+             beta_1=beta_1,
+             beta_2=beta_2,
+             tau=tau,
+         )
+
+     def aggregate_train(
+         self,
+         server_round: int,
+         replies: Iterable[Message],
+     ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
+         """Aggregate ArrayRecords and MetricRecords in the received Messages."""
+         aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
+             server_round, replies
+         )
+
+         if aggregated_arrayrecord is None:
+             return aggregated_arrayrecord, aggregated_metrics
+
+         if self.current_arrays is None:
+             reason = (
+                 "Current arrays not set. Ensure that `configure_train` has been "
+                 "called before aggregation."
+             )
+             raise AggregationError(reason=reason)
+
+         # Compute intermediate variables
+         delta_t, m_t, aggregated_ndarrays = self._compute_deltat_and_mt(
+             aggregated_arrayrecord
+         )
+
+         # v_t
+         if not self.v_t:
+             self.v_t = {k: np.zeros_like(v) for k, v in aggregated_ndarrays.items()}
+         self.v_t = {
+             k: v
+             - (1.0 - self.beta_2) * (delta_t[k] ** 2) * np.sign(v - delta_t[k] ** 2)
+             for k, v in self.v_t.items()
+         }
+
+         new_arrays = {
+             k: x + self.eta * m_t[k] / (np.sqrt(self.v_t[k]) + self.tau)
+             for k, x in self.current_arrays.items()
+         }
+
+         # Update current arrays
+         self.current_arrays = new_arrays
+
+         return (
+             ArrayRecord(OrderedDict({k: Array(v) for k, v in new_arrays.items()})),
+             aggregated_metrics,
+         )
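The aggregation above applies the Yogi server-side update to the averaged client update: v_t = v_{t-1} - (1 - beta_2) * delta_t^2 * sign(v_{t-1} - delta_t^2), then x_t = x_{t-1} + eta * m_t / (sqrt(v_t) + tau). Below is a hedged usage sketch (not part of the diff) of how this strategy could be driven through the new `Strategy.start()` shown further down; it assumes `flwr.serverapp.strategy` re-exports FedYogi (as the new strategy/__init__.py suggests), that `grid` is the Grid handed to a ServerApp, and the layer names are illustrative only.

from collections import OrderedDict

import numpy as np

from flwr.common import Array, ArrayRecord, ConfigRecord
from flwr.serverapp.strategy import FedYogi


def run_fedyogi(grid):
    # Pack a toy two-tensor "model" into an ArrayRecord (names are placeholders)
    initial_arrays = ArrayRecord(
        OrderedDict(
            {
                "fc.weight": Array(np.zeros((10, 4), dtype=np.float32)),
                "fc.bias": Array(np.zeros((10,), dtype=np.float32)),
            }
        )
    )
    # Configure the strategy with the defaults documented in the class docstring
    strategy = FedYogi(fraction_train=1.0, eta=1e-2, beta_1=0.9, beta_2=0.99, tau=1e-3)
    # Run a few rounds; train_config is forwarded to ClientApps each round
    result = strategy.start(
        grid=grid,
        initial_arrays=initial_arrays,
        num_rounds=3,
        train_config=ConfigRecord({"lr": 0.0316}),
    )
    print(result)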
flwr/serverapp/strategy/result.py
@@ -0,0 +1,105 @@
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Strategy results."""
+
+
+ import pprint
+ from dataclasses import dataclass, field
+
+ from flwr.common import ArrayRecord, MetricRecord
+ from flwr.common.typing import MetricRecordValues
+
+
+ @dataclass
+ class Result:
+     """Data class carrying records generated during the execution of a strategy.
+
+     This class encapsulates the results of a federated learning strategy execution,
+     including the final global model parameters and metrics collected throughout
+     the federated training and evaluation (both federated and centralized) stages.
+
+     Attributes
+     ----------
+     arrays : ArrayRecord
+         The final global model parameters. Contains the
+         aggregated model weights/parameters that resulted from the federated
+         learning process.
+     train_metrics_clientapp : dict[int, MetricRecord]
+         Training metrics collected from ClientApps, indexed by round number.
+         Contains aggregated metrics (e.g., loss, accuracy) from the training
+         phase of each federated learning round.
+     evaluate_metrics_clientapp : dict[int, MetricRecord]
+         Evaluation metrics collected from ClientApps, indexed by round number.
+         Contains aggregated metrics (e.g. validation loss) from the evaluation
+         phase where ClientApps evaluate the global model on their local
+         validation/test data.
+     evaluate_metrics_serverapp : dict[int, MetricRecord]
+         Evaluation metrics generated at the ServerApp, indexed by round number.
+         Contains metrics from centralized evaluation performed by the ServerApp
+         (e.g., when the server evaluates the global model on a held-out dataset).
+     """
+
+     arrays: ArrayRecord = field(default_factory=ArrayRecord)
+     train_metrics_clientapp: dict[int, MetricRecord] = field(default_factory=dict)
+     evaluate_metrics_clientapp: dict[int, MetricRecord] = field(default_factory=dict)
+     evaluate_metrics_serverapp: dict[int, MetricRecord] = field(default_factory=dict)
+
+     def __repr__(self) -> str:
+         """Create a representation of the Result instance."""
+         rep = ""
+         arr_size = sum(len(array.data) for array in self.arrays.values()) / (1024**2)
+         rep += "Global Arrays:\n" + f"\tArrayRecord ({arr_size:.3f} MB)\n" + "\n"
+         rep += (
+             "Aggregated ClientApp-side Train Metrics:\n"
+             + pprint.pformat(stringify_dict(self.train_metrics_clientapp), indent=2)
+             + "\n\n"
+         )
+
+         rep += (
+             "Aggregated ClientApp-side Evaluate Metrics:\n"
+             + pprint.pformat(stringify_dict(self.evaluate_metrics_clientapp), indent=2)
+             + "\n\n"
+         )
+
+         rep += (
+             "ServerApp-side Evaluate Metrics:\n"
+             + pprint.pformat(stringify_dict(self.evaluate_metrics_serverapp), indent=2)
+             + "\n"
+         )
+
+         return rep
+
+
+ def format_value(val: MetricRecordValues) -> str:
+     """Format a value as string, applying scientific notation for floats."""
+     if isinstance(val, float):
+         return f"{val:.4e}"
+     if isinstance(val, int):
+         return str(val)
+     if isinstance(val, list):
+         return str([f"{x:.4e}" if isinstance(x, float) else str(x) for x in val])
+     return str(val)
+
+
+ def stringify_dict(d: dict[int, MetricRecord]) -> dict[int, dict[str, str]]:
+     """Return a copy of the result metrics with values converted to strings and
+     formatted accordingly."""
+     new_metrics_dict = {}
+     for k, inner in d.items():
+         new_inner = {}
+         for ik, iv in inner.items():
+             new_inner[ik] = format_value(iv)
+         new_metrics_dict[k] = new_inner
+     return new_metrics_dict
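The `__repr__` above relies on the two module-level helpers: floats are rendered in scientific notation with four decimals, ints (and anything else) fall through to plain `str()`. A small sketch of what they produce (the MetricRecord contents are made up):

from flwr.common import MetricRecord

# Hypothetical aggregated per-round metrics, keyed by round number
metrics = {1: MetricRecord({"train_loss": 0.01234, "num-examples": 5000})}
print(stringify_dict(metrics))
# {1: {'train_loss': '1.2340e-02', 'num-examples': '5000'}}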
flwr/serverapp/strategy/strategy.py
@@ -0,0 +1,285 @@
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Flower message-based strategy."""
+
+
+ import io
+ import time
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterable
+ from logging import INFO
+ from typing import Callable, Optional
+
+ from flwr.common import ArrayRecord, ConfigRecord, Message, MetricRecord, log
+ from flwr.server import Grid
+
+ from .result import Result
+ from .strategy_utils import log_strategy_start_info
+
+
+ class Strategy(ABC):
+     """Abstract base class for server strategy implementations."""
+
+     @abstractmethod
+     def configure_train(
+         self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
+     ) -> Iterable[Message]:
+         """Configure the next round of training.
+
+         Parameters
+         ----------
+         server_round : int
+             The current round of federated learning.
+         arrays : ArrayRecord
+             Current global ArrayRecord (e.g. global model) to be sent to client
+             nodes for training.
+         config : ConfigRecord
+             Configuration to be sent to client nodes for training.
+         grid : Grid
+             The Grid instance used for node sampling and communication.
+
+         Returns
+         -------
+         Iterable[Message]
+             An iterable of messages to be sent to selected client nodes for training.
+         """
+
+     @abstractmethod
+     def aggregate_train(
+         self,
+         server_round: int,
+         replies: Iterable[Message],
+     ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
+         """Aggregate training results from client nodes.
+
+         Parameters
+         ----------
+         server_round : int
+             The current round of federated learning, starting from 1.
+         replies : Iterable[Message]
+             Iterable of reply messages received from client nodes after training.
+             Each message contains ArrayRecords and MetricRecords that get aggregated.
+
+         Returns
+         -------
+         tuple[Optional[ArrayRecord], Optional[MetricRecord]]
+             A tuple containing:
+             - ArrayRecord: Aggregated ArrayRecord, or None if aggregation failed
+             - MetricRecord: Aggregated MetricRecord, or None if aggregation failed
+         """
+
+     @abstractmethod
+     def configure_evaluate(
+         self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
+     ) -> Iterable[Message]:
+         """Configure the next round of evaluation.
+
+         Parameters
+         ----------
+         server_round : int
+             The current round of federated learning.
+         arrays : ArrayRecord
+             Current global ArrayRecord (e.g. global model) to be sent to client
+             nodes for evaluation.
+         config : ConfigRecord
+             Configuration to be sent to client nodes for evaluation.
+         grid : Grid
+             The Grid instance used for node sampling and communication.
+
+         Returns
+         -------
+         Iterable[Message]
+             An iterable of messages to be sent to selected client nodes for evaluation.
+         """
+
+     @abstractmethod
+     def aggregate_evaluate(
+         self,
+         server_round: int,
+         replies: Iterable[Message],
+     ) -> Optional[MetricRecord]:
+         """Aggregate evaluation metrics from client nodes.
+
+         Parameters
+         ----------
+         server_round : int
+             The current round of federated learning.
+         replies : Iterable[Message]
+             Iterable of reply messages received from client nodes after evaluation.
+             MetricRecords in the messages are aggregated.
+
+         Returns
+         -------
+         Optional[MetricRecord]
+             Aggregated evaluation metrics from all participating clients,
+             or None if aggregation failed.
+         """
+
+     @abstractmethod
+     def summary(self) -> None:
+         """Log summary configuration of the strategy."""
+
+     # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals
+     def start(
+         self,
+         grid: Grid,
+         initial_arrays: ArrayRecord,
+         num_rounds: int = 3,
+         timeout: float = 3600,
+         train_config: Optional[ConfigRecord] = None,
+         evaluate_config: Optional[ConfigRecord] = None,
+         evaluate_fn: Optional[
+             Callable[[int, ArrayRecord], Optional[MetricRecord]]
+         ] = None,
+     ) -> Result:
+         """Execute the federated learning strategy.
+
+         Runs the complete federated learning workflow for the specified number of
+         rounds, including training, evaluation, and optional centralized evaluation.
+
+         Parameters
+         ----------
+         grid : Grid
+             The Grid instance used to send/receive Messages from nodes executing a
+             ClientApp.
+         initial_arrays : ArrayRecord
+             Initial model parameters (arrays) to be used for federated learning.
+         num_rounds : int (default: 3)
+             Number of federated learning rounds to execute.
+         timeout : float (default: 3600)
+             Timeout in seconds for waiting for node responses.
+         train_config : ConfigRecord, optional
+             Configuration to be sent to nodes during training rounds.
+             If unset, an empty ConfigRecord will be used.
+         evaluate_config : ConfigRecord, optional
+             Configuration to be sent to nodes during evaluation rounds.
+             If unset, an empty ConfigRecord will be used.
+         evaluate_fn : Callable[[int, ArrayRecord], Optional[MetricRecord]], optional
+             Optional function for centralized evaluation of the global model. Takes
+             the server round number and an ArrayRecord, and returns a MetricRecord
+             or None. If provided, it will be called before the first round and after
+             each round. Defaults to None.
+
+         Returns
+         -------
+         Result
+             Result containing the final model arrays together with the training,
+             evaluation, and global evaluation metrics (if provided) from all rounds.
+         """
+         log(INFO, "Starting %s strategy:", self.__class__.__name__)
+         log_strategy_start_info(
+             num_rounds, initial_arrays, train_config, evaluate_config
+         )
+         self.summary()
+         log(INFO, "")
+
+         # Initialize if None
+         train_config = ConfigRecord() if train_config is None else train_config
+         evaluate_config = ConfigRecord() if evaluate_config is None else evaluate_config
+         result = Result()
+
+         t_start = time.time()
+         # Evaluate starting global parameters
+         if evaluate_fn:
+             res = evaluate_fn(0, initial_arrays)
+             log(INFO, "Initial global evaluation results: %s", res)
+             if res is not None:
+                 result.evaluate_metrics_serverapp[0] = res
+
+         arrays = initial_arrays
+
+         for current_round in range(1, num_rounds + 1):
+             log(INFO, "")
+             log(INFO, "[ROUND %s/%s]", current_round, num_rounds)
+
+             # -----------------------------------------------------------------
+             # --- TRAINING (CLIENTAPP-SIDE) -----------------------------------
+             # -----------------------------------------------------------------
+
+             # Call strategy to configure training round
+             # Send messages and wait for replies
+             train_replies = grid.send_and_receive(
+                 messages=self.configure_train(
+                     current_round,
+                     arrays,
+                     train_config,
+                     grid,
+                 ),
+                 timeout=timeout,
+             )
+
+             # Aggregate train
+             agg_arrays, agg_train_metrics = self.aggregate_train(
+                 current_round,
+                 train_replies,
+             )
+
+             # Log training metrics and append to history
+             if agg_arrays is not None:
+                 result.arrays = agg_arrays
+                 arrays = agg_arrays
+             if agg_train_metrics is not None:
+                 log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_train_metrics)
+                 result.train_metrics_clientapp[current_round] = agg_train_metrics
+
+             # -----------------------------------------------------------------
+             # --- EVALUATION (CLIENTAPP-SIDE) ---------------------------------
+             # -----------------------------------------------------------------
+
+             # Call strategy to configure evaluation round
+             # Send messages and wait for replies
+             evaluate_replies = grid.send_and_receive(
+                 messages=self.configure_evaluate(
+                     current_round,
+                     arrays,
+                     evaluate_config,
+                     grid,
+                 ),
+                 timeout=timeout,
+             )
+
+             # Aggregate evaluate
+             agg_evaluate_metrics = self.aggregate_evaluate(
+                 current_round,
+                 evaluate_replies,
+             )
+
+             # Log evaluation metrics and append to history
+             if agg_evaluate_metrics is not None:
+                 log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_evaluate_metrics)
+                 result.evaluate_metrics_clientapp[current_round] = agg_evaluate_metrics
+
+             # -----------------------------------------------------------------
+             # --- EVALUATION (SERVERAPP-SIDE) ---------------------------------
+             # -----------------------------------------------------------------
+
+             # Centralized evaluation
+             if evaluate_fn:
+                 log(INFO, "Global evaluation")
+                 res = evaluate_fn(current_round, arrays)
+                 log(INFO, "\t└──> MetricRecord: %s", res)
+                 if res is not None:
+                     result.evaluate_metrics_serverapp[current_round] = res
+
+         log(INFO, "")
+         log(INFO, "Strategy execution finished in %.2fs", time.time() - t_start)
+         log(INFO, "")
+         log(INFO, "Final results:")
+         log(INFO, "")
+         for line in io.StringIO(str(result)):
+             log(INFO, "\t%s", line.strip("\n"))
+         log(INFO, "")
+
+         return result
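The only user-supplied hook in the ServerApp-side evaluation branch of `start()` is `evaluate_fn`. A hedged sketch of such a hook is shown here (not part of the diff): `evaluate_on_holdout` is a placeholder helper, and converting each Array back to a NumPy ndarray via `Array.numpy()` is an assumption about the Array API.

from typing import Optional

from flwr.common import ArrayRecord, MetricRecord


def central_evaluate(server_round: int, arrays: ArrayRecord) -> Optional[MetricRecord]:
    # Unpack the global ArrayRecord into ndarrays (Array.numpy() assumed)
    ndarrays = [array.numpy() for array in arrays.values()]
    # Score the global model on a centrally held dataset (placeholder helper)
    loss, accuracy = evaluate_on_holdout(ndarrays)
    return MetricRecord({"central_loss": loss, "central_accuracy": accuracy})

# Any strategy can then be started with:
#   strategy.start(grid, initial_arrays, num_rounds=3, evaluate_fn=central_evaluate)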