flwr 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +4 -1
- flwr/app/__init__.py +28 -0
- flwr/app/exception.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +3 -3
- flwr/cli/constant.py +25 -8
- flwr/cli/log.py +9 -9
- flwr/cli/login/login.py +3 -3
- flwr/cli/ls.py +5 -5
- flwr/cli/new/new.py +11 -0
- flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
- flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +9 -13
- flwr/cli/stop.py +7 -4
- flwr/cli/utils.py +19 -8
- flwr/client/grpc_rere_client/connection.py +1 -12
- flwr/client/rest_client/connection.py +3 -0
- flwr/clientapp/__init__.py +10 -0
- flwr/clientapp/mod/__init__.py +26 -0
- flwr/clientapp/mod/centraldp_mods.py +132 -0
- flwr/common/args.py +20 -6
- flwr/common/auth_plugin/__init__.py +4 -4
- flwr/common/auth_plugin/auth_plugin.py +7 -7
- flwr/common/constant.py +23 -4
- flwr/common/event_log_plugin/event_log_plugin.py +1 -1
- flwr/common/exit/__init__.py +4 -0
- flwr/common/exit/exit.py +8 -1
- flwr/common/exit/exit_code.py +26 -7
- flwr/common/exit/exit_handler.py +62 -0
- flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
- flwr/common/grpc.py +0 -11
- flwr/common/inflatable_utils.py +1 -1
- flwr/common/logger.py +1 -1
- flwr/common/retry_invoker.py +30 -11
- flwr/common/telemetry.py +4 -0
- flwr/compat/server/app.py +2 -2
- flwr/proto/appio_pb2.py +25 -17
- flwr/proto/appio_pb2.pyi +46 -2
- flwr/proto/clientappio_pb2.py +3 -11
- flwr/proto/clientappio_pb2.pyi +0 -47
- flwr/proto/clientappio_pb2_grpc.py +19 -20
- flwr/proto/clientappio_pb2_grpc.pyi +10 -11
- flwr/proto/control_pb2.py +62 -0
- flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
- flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
- flwr/proto/serverappio_pb2.py +2 -2
- flwr/proto/serverappio_pb2_grpc.py +68 -0
- flwr/proto/serverappio_pb2_grpc.pyi +26 -0
- flwr/proto/simulationio_pb2.py +4 -11
- flwr/proto/simulationio_pb2.pyi +0 -58
- flwr/proto/simulationio_pb2_grpc.py +129 -27
- flwr/proto/simulationio_pb2_grpc.pyi +52 -13
- flwr/server/app.py +129 -152
- flwr/server/grid/grpc_grid.py +3 -0
- flwr/server/grid/inmemory_grid.py +1 -0
- flwr/server/serverapp/app.py +157 -146
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
- flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
- flwr/serverapp/__init__.py +12 -0
- flwr/serverapp/dp_fixed_clipping.py +352 -0
- flwr/serverapp/exception.py +38 -0
- flwr/serverapp/strategy/__init__.py +38 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
- flwr/serverapp/strategy/fedadagrad.py +162 -0
- flwr/serverapp/strategy/fedadam.py +181 -0
- flwr/serverapp/strategy/fedavg.py +295 -0
- flwr/serverapp/strategy/fedopt.py +218 -0
- flwr/serverapp/strategy/fedyogi.py +173 -0
- flwr/serverapp/strategy/result.py +105 -0
- flwr/serverapp/strategy/strategy.py +285 -0
- flwr/serverapp/strategy/strategy_utils.py +251 -0
- flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
- flwr/simulation/app.py +161 -164
- flwr/supercore/app_utils.py +58 -0
- flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
- flwr/supercore/cli/flower_superexec.py +141 -0
- flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
- flwr/supercore/corestate/corestate.py +81 -0
- flwr/supercore/grpc_health/__init__.py +3 -0
- flwr/supercore/grpc_health/health_server.py +53 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
- flwr/{superexec → supercore/superexec}/__init__.py +1 -1
- flwr/supercore/superexec/plugin/__init__.py +28 -0
- flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
- flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
- flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +4 -4
- flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
- flwr/supercore/superexec/run_superexec.py +185 -0
- flwr/superlink/servicer/__init__.py +15 -0
- flwr/superlink/servicer/control/__init__.py +22 -0
- flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
- flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +24 -29
- flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
- flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +69 -30
- flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
- flwr/supernode/cli/flower_supernode.py +3 -0
- flwr/supernode/cli/flwr_clientapp.py +18 -21
- flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
- flwr/supernode/nodestate/nodestate.py +3 -59
- flwr/supernode/runtime/run_clientapp.py +39 -102
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
- flwr/supernode/start_client_internal.py +35 -76
- {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/METADATA +4 -3
- {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/RECORD +127 -98
- {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
- flwr/proto/exec_pb2.py +0 -62
- flwr/superexec/app.py +0 -45
- flwr/superexec/deployment.py +0 -191
- flwr/superexec/executor.py +0 -100
- flwr/superexec/simulation.py +0 -129
- /flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
- {flwr-1.20.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Adaptive Federated Optimization using Yogi (FedYogi) [Reddi et al., 2020] strategy.
|
|
16
|
+
|
|
17
|
+
Paper: arxiv.org/abs/2003.00295
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
from collections import OrderedDict
|
|
22
|
+
from collections.abc import Iterable
|
|
23
|
+
from typing import Callable, Optional
|
|
24
|
+
|
|
25
|
+
import numpy as np
|
|
26
|
+
|
|
27
|
+
from flwr.common import Array, ArrayRecord, Message, MetricRecord, RecordDict
|
|
28
|
+
|
|
29
|
+
from ..exception import AggregationError
|
|
30
|
+
from .fedopt import FedOpt
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# pylint: disable=line-too-long
|
|
34
|
+
class FedYogi(FedOpt):
    """FedYogi [Reddi et al., 2020] strategy.

    Implementation based on https://arxiv.org/abs/2003.00295v5

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for both ArrayRecords and MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    eta : float, optional
        Server-side learning rate. Defaults to 1e-2.
    eta_l : float, optional
        Client-side learning rate. Defaults to 0.0316.
    beta_1 : float, optional
        Momentum parameter. Defaults to 0.9.
    beta_2 : float, optional
        Second moment parameter. Defaults to 0.99.
    tau : float, optional
        Controls the algorithm's degree of adaptability.
        Defaults to 1e-3.
    """

    # pylint: disable=too-many-arguments, too-many-locals
    def __init__(
        self,
        *,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        eta: float = 1e-2,
        eta_l: float = 0.0316,
        beta_1: float = 0.9,
        beta_2: float = 0.99,
        tau: float = 1e-3,
    ) -> None:
        # All configuration is stored by the FedOpt base class; FedYogi only
        # overrides how the server-side update is computed.
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
            eta=eta,
            eta_l=eta_l,
            beta_1=beta_1,
            beta_2=beta_2,
            tau=tau,
        )

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        The parent strategy's aggregation result is post-processed with the
        Yogi server-side update: the second-moment estimate ``v_t`` is updated
        with the Yogi rule, then each global array ``x`` is replaced by
        ``x + eta * m_t / (sqrt(v_t) + tau)``.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        replies : Iterable[Message]
            Reply messages received from client nodes after training.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            The updated global ArrayRecord and the aggregated MetricRecord. If
            the parent aggregation returns no ArrayRecord, the parent's result
            is returned unchanged.

        Raises
        ------
        AggregationError
            If ``self.current_arrays`` is not set, i.e. ``configure_train`` was
            not called before aggregation.
        """
        aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
            server_round, replies
        )

        # Nothing to apply the server-side update to: pass the result through
        if aggregated_arrayrecord is None:
            return aggregated_arrayrecord, aggregated_metrics

        if self.current_arrays is None:
            reason = (
                "Current arrays not set. Ensure that `configure_train` has been "
                "called before aggregation."
            )
            raise AggregationError(reason=reason)

        # Intermediate variables delta_t and m_t, computed by the FedOpt helper
        delta_t, m_t, aggregated_ndarrays = self._compute_deltat_and_mt(
            aggregated_arrayrecord
        )

        # Lazily initialise the second-moment estimate with zeros on first use
        if not self.v_t:
            self.v_t = {k: np.zeros_like(v) for k, v in aggregated_ndarrays.items()}

        # Yogi second-moment update (Reddi et al., 2020):
        #   v_t = v_{t-1} - (1 - beta_2) * delta_t^2 * sign(v_{t-1} - delta_t^2)
        updated_v_t = {}
        for k, v in self.v_t.items():
            delta_sq = delta_t[k] ** 2
            updated_v_t[k] = v - (1.0 - self.beta_2) * delta_sq * np.sign(v - delta_sq)
        self.v_t = updated_v_t

        # Server-side step: x <- x + eta * m_t / (sqrt(v_t) + tau)
        new_arrays = {
            k: x + self.eta * m_t[k] / (np.sqrt(self.v_t[k]) + self.tau)
            for k, x in self.current_arrays.items()
        }

        # Persist the updated global arrays on the strategy
        self.current_arrays = new_arrays

        return (
            ArrayRecord(OrderedDict({k: Array(v) for k, v in new_arrays.items()})),
            aggregated_metrics,
        )
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Strategy results."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import pprint
|
|
19
|
+
from dataclasses import dataclass, field
|
|
20
|
+
|
|
21
|
+
from flwr.common import ArrayRecord, MetricRecord
|
|
22
|
+
from flwr.common.typing import MetricRecordValues
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class Result:
    """Records produced while a federated learning strategy runs.

    Bundles the final global arrays together with the per-round metric
    histories gathered during federated training, federated evaluation and
    centralized (ServerApp-side) evaluation.

    Attributes
    ----------
    arrays : ArrayRecord
        The final global model parameters, i.e. the aggregated model
        weights/parameters that resulted from the federated learning process.
    train_metrics_clientapp : dict[int, MetricRecord]
        Training metrics collected from ClientApps, keyed by round number.
        Holds aggregated metrics (e.g. loss, accuracy) from the training phase
        of each federated learning round.
    evaluate_metrics_clientapp : dict[int, MetricRecord]
        Evaluation metrics collected from ClientApps, keyed by round number.
        Holds aggregated metrics (e.g. validation loss) from the phase where
        ClientApps evaluate the global model on their local validation/test data.
    evaluate_metrics_serverapp : dict[int, MetricRecord]
        Evaluation metrics produced by the ServerApp, keyed by round number.
        Holds metrics from centralized evaluation performed by the ServerApp
        (e.g. evaluating the global model on a held-out dataset).
    """

    # Final global arrays; an empty ArrayRecord until one is assigned
    arrays: ArrayRecord = field(default_factory=ArrayRecord)
    # Round number -> aggregated ClientApp-side training metrics
    train_metrics_clientapp: dict[int, MetricRecord] = field(default_factory=dict)
    # Round number -> aggregated ClientApp-side evaluation metrics
    evaluate_metrics_clientapp: dict[int, MetricRecord] = field(default_factory=dict)
    # Round number -> ServerApp-side (centralized) evaluation metrics
    evaluate_metrics_serverapp: dict[int, MetricRecord] = field(default_factory=dict)

    def __repr__(self) -> str:
        """Render the array size and every metric history as readable text."""
        size_mb = sum(len(arr.data) for arr in self.arrays.values()) / (1024**2)

        metric_sections = (
            ("Aggregated ClientApp-side Train Metrics:\n", self.train_metrics_clientapp),
            (
                "Aggregated ClientApp-side Evaluate Metrics:\n",
                self.evaluate_metrics_clientapp,
            ),
            ("ServerApp-side Evaluate Metrics:\n", self.evaluate_metrics_serverapp),
        )

        # Each section ends with one newline; joining with "\n" inserts the
        # blank line that separates consecutive sections.
        parts = ["Global Arrays:\n" + f"\tArrayRecord ({size_mb:.3f} MB)\n"]
        for title, history in metric_sections:
            parts.append(
                title + pprint.pformat(stringify_dict(history), indent=2) + "\n"
            )
        return "\n".join(parts)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def format_value(val: MetricRecordValues) -> str:
    """Render a metric value as text, using 4-digit scientific notation for floats.

    Scalars are rendered directly. Lists are rendered element-wise (floats in
    scientific notation, anything else via ``str``) and the resulting list of
    strings is then passed through ``str``. Any other value falls back to ``str``.
    """

    def as_text(item: object) -> str:
        # Floats get scientific notation; ints (and everything else) use str()
        return f"{item:.4e}" if isinstance(item, float) else str(item)

    if not isinstance(val, list):
        return as_text(val)
    return str([as_text(item) for item in val])
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def stringify_dict(d: dict[int, MetricRecord]) -> dict[int, dict[str, str]]:
    """Return a copy of a metrics history with every value formatted as a string.

    Parameters
    ----------
    d : dict[int, MetricRecord]
        Mapping from round number to the MetricRecord collected in that round.

    Returns
    -------
    dict[int, dict[str, str]]
        A new mapping with the same round numbers and metric keys, where each
        metric value has been converted with ``format_value``. The input is not
        modified.
    """
    return {
        rnd: {key: format_value(value) for key, value in record.items()}
        for rnd, record in d.items()
    }
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower message-based strategy."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import io
|
|
19
|
+
import time
|
|
20
|
+
from abc import ABC, abstractmethod
|
|
21
|
+
from collections.abc import Iterable
|
|
22
|
+
from logging import INFO
|
|
23
|
+
from typing import Callable, Optional
|
|
24
|
+
|
|
25
|
+
from flwr.common import ArrayRecord, ConfigRecord, Message, MetricRecord, log
|
|
26
|
+
from flwr.server import Grid
|
|
27
|
+
|
|
28
|
+
from .result import Result
|
|
29
|
+
from .strategy_utils import log_strategy_start_info
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Strategy(ABC):
    """Abstract base class for server strategy implementations.

    Subclasses implement the training/evaluation configure and aggregate hooks
    plus ``summary``; ``start`` drives the round loop using those hooks.
    """

    @abstractmethod
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        arrays : ArrayRecord
            Current global ArrayRecord (e.g. global model) to be sent to client
            nodes for training.
        config : ConfigRecord
            Configuration to be sent to client nodes for training.
        grid : Grid
            The Grid instance used for node sampling and communication.

        Returns
        -------
        Iterable[Message]
            An iterable of messages to be sent to selected client nodes for training.
        """

    @abstractmethod
    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate training results from client nodes.

        Parameters
        ----------
        server_round : int
            The current round of federated learning, starting from 1.
        replies : Iterable[Message]
            Iterable of reply messages received from client nodes after training.
            Each message contains ArrayRecords and MetricRecords that get aggregated.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            A tuple containing:
            - ArrayRecord: Aggregated ArrayRecord, or None if aggregation failed
            - MetricRecord: Aggregated MetricRecord, or None if aggregation failed
        """

    @abstractmethod
    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of evaluation.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        arrays : ArrayRecord
            Current global ArrayRecord (e.g. global model) to be sent to client
            nodes for evaluation.
        config : ConfigRecord
            Configuration to be sent to client nodes for evaluation.
        grid : Grid
            The Grid instance used for node sampling and communication.

        Returns
        -------
        Iterable[Message]
            An iterable of messages to be sent to selected client nodes for evaluation.
        """

    @abstractmethod
    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate evaluation metrics from client nodes.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        replies : Iterable[Message]
            Iterable of reply messages received from client nodes after evaluation.
            MetricRecords in the messages are aggregated.

        Returns
        -------
        Optional[MetricRecord]
            Aggregated evaluation metrics from all participating clients,
            or None if aggregation failed.
        """

    @abstractmethod
    def summary(self) -> None:
        """Log summary configuration of the strategy."""

    # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals
    def start(
        self,
        grid: Grid,
        initial_arrays: ArrayRecord,
        num_rounds: int = 3,
        timeout: float = 3600,
        train_config: Optional[ConfigRecord] = None,
        evaluate_config: Optional[ConfigRecord] = None,
        evaluate_fn: Optional[
            Callable[[int, ArrayRecord], Optional[MetricRecord]]
        ] = None,
    ) -> Result:
        """Execute the federated learning strategy.

        Runs the complete federated learning workflow for the specified number of
        rounds, including training, evaluation, and optional centralized evaluation.

        Parameters
        ----------
        grid : Grid
            The Grid instance used to send/receive Messages from nodes executing a
            ClientApp.
        initial_arrays : ArrayRecord
            Initial model parameters (arrays) to be used for federated learning.
        num_rounds : int (default: 3)
            Number of federated learning rounds to execute.
        timeout : float (default: 3600)
            Timeout in seconds for waiting for node responses.
        train_config : ConfigRecord, optional
            Configuration to be sent to nodes during training rounds.
            If unset, an empty ConfigRecord will be used.
        evaluate_config : ConfigRecord, optional
            Configuration to be sent to nodes during evaluation rounds.
            If unset, an empty ConfigRecord will be used.
        evaluate_fn : Callable[[int, ArrayRecord], Optional[MetricRecord]], optional
            Optional function for centralized evaluation of the global model. Takes
            server round number and array record, returns a MetricRecord or None. If
            provided, will be called before the first round and after each round.
            Defaults to None.

        Returns
        -------
        Result
            Result containing final model arrays and also training metrics,
            evaluation metrics and global evaluation metrics (if provided) from all
            rounds. ``Result.arrays`` holds the most recent ArrayRecord returned by
            ``aggregate_train``; it stays empty if no round produced one.
        """
        log(INFO, "Starting %s strategy:", self.__class__.__name__)
        log_strategy_start_info(
            num_rounds, initial_arrays, train_config, evaluate_config
        )
        self.summary()
        log(INFO, "")

        # Initialize if None
        train_config = ConfigRecord() if train_config is None else train_config
        evaluate_config = ConfigRecord() if evaluate_config is None else evaluate_config
        result = Result()

        # Monotonic clock so the reported duration is immune to wall-clock changes
        t_start = time.perf_counter()
        # Evaluate starting global parameters
        if evaluate_fn:
            res = evaluate_fn(0, initial_arrays)
            log(INFO, "Initial global evaluation results: %s", res)
            if res is not None:
                result.evaluate_metrics_serverapp[0] = res

        # Current global arrays; replaced whenever a round aggregates new ones
        arrays = initial_arrays

        for current_round in range(1, num_rounds + 1):
            log(INFO, "")
            log(INFO, "[ROUND %s/%s]", current_round, num_rounds)

            # -----------------------------------------------------------------
            # --- TRAINING (CLIENTAPP-SIDE) -----------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure training round
            # Send messages and wait for replies
            train_replies = grid.send_and_receive(
                messages=self.configure_train(
                    current_round,
                    arrays,
                    train_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate train
            agg_arrays, agg_train_metrics = self.aggregate_train(
                current_round,
                train_replies,
            )

            # Log training metrics and append to history
            if agg_arrays is not None:
                result.arrays = agg_arrays
                arrays = agg_arrays
            if agg_train_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_train_metrics)
                result.train_metrics_clientapp[current_round] = agg_train_metrics

            # -----------------------------------------------------------------
            # --- EVALUATION (CLIENTAPP-SIDE) ---------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure evaluation round
            # Send messages and wait for replies
            evaluate_replies = grid.send_and_receive(
                messages=self.configure_evaluate(
                    current_round,
                    arrays,
                    evaluate_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate evaluate
            agg_evaluate_metrics = self.aggregate_evaluate(
                current_round,
                evaluate_replies,
            )

            # Log evaluation metrics and append to history
            if agg_evaluate_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_evaluate_metrics)
                result.evaluate_metrics_clientapp[current_round] = agg_evaluate_metrics

            # -----------------------------------------------------------------
            # --- EVALUATION (SERVERAPP-SIDE) ---------------------------------
            # -----------------------------------------------------------------

            # Centralized evaluation
            if evaluate_fn:
                log(INFO, "Global evaluation")
                res = evaluate_fn(current_round, arrays)
                log(INFO, "\t└──> MetricRecord: %s", res)
                if res is not None:
                    result.evaluate_metrics_serverapp[current_round] = res

        log(INFO, "")
        log(
            INFO,
            "Strategy execution finished in %.2fs",
            time.perf_counter() - t_start,
        )
        log(INFO, "")
        log(INFO, "Final results:")
        log(INFO, "")
        # Emit the multi-line Result summary one log record per line
        for line in io.StringIO(str(result)):
            log(INFO, "\t%s", line.strip("\n"))
        log(INFO, "")

        return result
|