flwr 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +4 -1
- flwr/app/__init__.py +28 -0
- flwr/app/exception.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
- flwr/cli/build.py +15 -5
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +3 -3
- flwr/cli/constant.py +25 -8
- flwr/cli/log.py +9 -9
- flwr/cli/login/login.py +3 -3
- flwr/cli/ls.py +5 -5
- flwr/cli/new/new.py +23 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
- flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -3
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +53 -50
- flwr/cli/stop.py +7 -4
- flwr/cli/utils.py +29 -11
- flwr/client/grpc_adapter_client/connection.py +11 -4
- flwr/client/grpc_rere_client/connection.py +93 -129
- flwr/client/rest_client/connection.py +134 -164
- flwr/clientapp/__init__.py +10 -0
- flwr/clientapp/mod/__init__.py +26 -0
- flwr/clientapp/mod/centraldp_mods.py +132 -0
- flwr/common/args.py +20 -6
- flwr/common/auth_plugin/__init__.py +4 -4
- flwr/common/auth_plugin/auth_plugin.py +7 -7
- flwr/common/constant.py +26 -5
- flwr/common/event_log_plugin/event_log_plugin.py +1 -1
- flwr/common/exit/__init__.py +4 -0
- flwr/common/exit/exit.py +8 -1
- flwr/common/exit/exit_code.py +42 -8
- flwr/common/exit/exit_handler.py +62 -0
- flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
- flwr/common/grpc.py +1 -1
- flwr/common/{inflatable_grpc_utils.py → inflatable_protobuf_utils.py} +52 -10
- flwr/common/inflatable_utils.py +191 -24
- flwr/common/logger.py +1 -1
- flwr/common/record/array.py +101 -22
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/retry_invoker.py +30 -11
- flwr/common/serde.py +0 -28
- flwr/common/telemetry.py +4 -0
- flwr/compat/client/app.py +14 -31
- flwr/compat/server/app.py +2 -2
- flwr/proto/appio_pb2.py +51 -0
- flwr/proto/appio_pb2.pyi +195 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +4 -19
- flwr/proto/clientappio_pb2.pyi +0 -125
- flwr/proto/clientappio_pb2_grpc.py +269 -29
- flwr/proto/clientappio_pb2_grpc.pyi +114 -21
- flwr/proto/control_pb2.py +62 -0
- flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
- flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
- flwr/proto/fleet_pb2.py +12 -20
- flwr/proto/fleet_pb2.pyi +6 -36
- flwr/proto/serverappio_pb2.py +8 -31
- flwr/proto/serverappio_pb2.pyi +0 -152
- flwr/proto/serverappio_pb2_grpc.py +107 -38
- flwr/proto/serverappio_pb2_grpc.pyi +47 -20
- flwr/proto/simulationio_pb2.py +4 -11
- flwr/proto/simulationio_pb2.pyi +0 -58
- flwr/proto/simulationio_pb2_grpc.py +129 -27
- flwr/proto/simulationio_pb2_grpc.pyi +52 -13
- flwr/server/app.py +130 -153
- flwr/server/fleet_event_log_interceptor.py +4 -0
- flwr/server/grid/grpc_grid.py +94 -54
- flwr/server/grid/inmemory_grid.py +1 -0
- flwr/server/serverapp/app.py +165 -144
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +8 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +10 -16
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
- flwr/server/superlink/serverappio/serverappio_grpc.py +2 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +95 -48
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +98 -22
- flwr/server/superlink/utils.py +0 -35
- flwr/serverapp/__init__.py +12 -0
- flwr/serverapp/dp_fixed_clipping.py +352 -0
- flwr/serverapp/exception.py +38 -0
- flwr/serverapp/strategy/__init__.py +38 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
- flwr/serverapp/strategy/fedadagrad.py +162 -0
- flwr/serverapp/strategy/fedadam.py +181 -0
- flwr/serverapp/strategy/fedavg.py +295 -0
- flwr/serverapp/strategy/fedopt.py +218 -0
- flwr/serverapp/strategy/fedyogi.py +173 -0
- flwr/serverapp/strategy/result.py +105 -0
- flwr/serverapp/strategy/strategy.py +285 -0
- flwr/serverapp/strategy/strategy_utils.py +251 -0
- flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
- flwr/simulation/app.py +159 -154
- flwr/simulation/run_simulation.py +17 -0
- flwr/supercore/app_utils.py +58 -0
- flwr/supercore/cli/__init__.py +22 -0
- flwr/supercore/cli/flower_superexec.py +141 -0
- flwr/supercore/corestate/__init__.py +22 -0
- flwr/supercore/corestate/corestate.py +81 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +25 -0
- flwr/supercore/grpc_health/health_server.py +53 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/in_memory_object_store.py +31 -31
- flwr/supercore/object_store/object_store.py +20 -42
- flwr/supercore/object_store/utils.py +43 -0
- flwr/{superexec → supercore/superexec}/__init__.py +1 -1
- flwr/supercore/superexec/plugin/__init__.py +28 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +53 -0
- flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +71 -0
- flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
- flwr/supercore/superexec/run_superexec.py +185 -0
- flwr/supercore/utils.py +32 -0
- flwr/superlink/servicer/__init__.py +15 -0
- flwr/superlink/servicer/control/__init__.py +22 -0
- flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +9 -5
- flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +39 -28
- flwr/superlink/servicer/control/control_license_interceptor.py +82 -0
- flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +79 -31
- flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +18 -10
- flwr/supernode/cli/flower_supernode.py +3 -7
- flwr/supernode/cli/flwr_clientapp.py +20 -16
- flwr/supernode/nodestate/in_memory_nodestate.py +13 -4
- flwr/supernode/nodestate/nodestate.py +3 -44
- flwr/supernode/runtime/run_clientapp.py +129 -115
- flwr/supernode/servicer/clientappio/__init__.py +1 -3
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +217 -165
- flwr/supernode/start_client_internal.py +205 -148
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/METADATA +5 -3
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/RECORD +161 -117
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
- flwr/common/inflatable_rest_utils.py +0 -99
- flwr/proto/exec_pb2.py +0 -62
- flwr/superexec/app.py +0 -45
- flwr/superexec/deployment.py +0 -192
- flwr/superexec/executor.py +0 -100
- flwr/superexec/simulation.py +0 -130
- /flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
- /flwr/{server/superlink → supercore}/ffs/__init__.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Strategy results."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import pprint
|
|
19
|
+
from dataclasses import dataclass, field
|
|
20
|
+
|
|
21
|
+
from flwr.common import ArrayRecord, MetricRecord
|
|
22
|
+
from flwr.common.typing import MetricRecordValues
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class Result:
    """Records produced while a strategy runs.

    Bundles the final global model parameters with the metrics gathered during
    federated training, federated evaluation and centralized (ServerApp-side)
    evaluation.

    Attributes
    ----------
    arrays : ArrayRecord
        Global model parameters at the end of the run, i.e. the aggregated
        weights/parameters produced by the federated learning process.
    train_metrics_clientapp : dict[int, MetricRecord]
        Aggregated training metrics reported by ClientApps (e.g. loss,
        accuracy), keyed by round number.
    evaluate_metrics_clientapp : dict[int, MetricRecord]
        Aggregated evaluation metrics reported by ClientApps after evaluating
        the global model on their local validation/test data, keyed by round
        number.
    evaluate_metrics_serverapp : dict[int, MetricRecord]
        Metrics from centralized evaluation performed by the ServerApp (e.g. on
        a held-out dataset), keyed by round number.
    """

    arrays: ArrayRecord = field(default_factory=ArrayRecord)
    train_metrics_clientapp: dict[int, MetricRecord] = field(default_factory=dict)
    evaluate_metrics_clientapp: dict[int, MetricRecord] = field(default_factory=dict)
    evaluate_metrics_serverapp: dict[int, MetricRecord] = field(default_factory=dict)

    def __repr__(self) -> str:
        """Summarize the stored arrays and metrics as a multi-line string."""

        def _pretty(metrics):  # type: ignore[no-untyped-def]
            # Format every metric value as a string, then pretty-print the dict
            return pprint.pformat(stringify_dict(metrics), indent=2)

        # Total payload size of all arrays, expressed in MiB
        size_mb = sum(len(arr.data) for arr in self.arrays.values()) / (1024**2)
        parts = [
            "Global Arrays:\n",
            f"\tArrayRecord ({size_mb:.3f} MB)\n",
            "\n",
            "Aggregated ClientApp-side Train Metrics:\n",
            _pretty(self.train_metrics_clientapp),
            "\n\n",
            "Aggregated ClientApp-side Evaluate Metrics:\n",
            _pretty(self.evaluate_metrics_clientapp),
            "\n\n",
            "ServerApp-side Evaluate Metrics:\n",
            _pretty(self.evaluate_metrics_serverapp),
            "\n",
        ]
        return "".join(parts)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def format_value(val: MetricRecordValues) -> str:
    """Render a metric value as text, using scientific notation for floats.

    Floats become ``'d.dddde±xx'``; lists are rendered as the string form of a
    list whose elements were each formatted the same way; anything else is
    passed through ``str()``.
    """

    def _one(item: object) -> str:
        # Floats get 4-digit scientific notation, everything else plain str()
        return f"{item:.4e}" if isinstance(item, float) else str(item)

    if isinstance(val, list):
        return str([_one(item) for item in val])
    return _one(val)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def stringify_dict(d: dict[int, MetricRecord]) -> dict[int, dict[str, str]]:
    """Return a copy of the metrics with every value formatted as a string.

    Parameters
    ----------
    d : dict[int, MetricRecord]
        Metrics keyed by an integer (the round number, as used by `Result`).

    Returns
    -------
    dict[int, dict[str, str]]
        A new dictionary with the same keys, where each MetricRecord is
        replaced by a plain dict mapping metric names to values formatted with
        `format_value` (floats in scientific notation). The input is not
        modified.
    """
    return {
        key: {name: format_value(value) for name, value in record.items()}
        for key, record in d.items()
    }
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower message-based strategy."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import io
|
|
19
|
+
import time
|
|
20
|
+
from abc import ABC, abstractmethod
|
|
21
|
+
from collections.abc import Iterable
|
|
22
|
+
from logging import INFO
|
|
23
|
+
from typing import Callable, Optional
|
|
24
|
+
|
|
25
|
+
from flwr.common import ArrayRecord, ConfigRecord, Message, MetricRecord, log
|
|
26
|
+
from flwr.server import Grid
|
|
27
|
+
|
|
28
|
+
from .result import Result
|
|
29
|
+
from .strategy_utils import log_strategy_start_info
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Strategy(ABC):
    """Abstract base class for server strategy implementations."""

    @abstractmethod
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        arrays : ArrayRecord
            Current global ArrayRecord (e.g. global model) to be sent to client
            nodes for training.
        config : ConfigRecord
            Configuration to be sent to client nodes for training.
        grid : Grid
            The Grid instance used for node sampling and communication.

        Returns
        -------
        Iterable[Message]
            An iterable of messages to be sent to selected client nodes for training.
        """

    @abstractmethod
    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate training results from client nodes.

        Parameters
        ----------
        server_round : int
            The current round of federated learning, starting from 1.
        replies : Iterable[Message]
            Iterable of reply messages received from client nodes after training.
            Each message contains ArrayRecords and MetricRecords that get aggregated.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            A tuple containing:
            - ArrayRecord: Aggregated ArrayRecord, or None if aggregation failed
            - MetricRecord: Aggregated MetricRecord, or None if aggregation failed
        """

    @abstractmethod
    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of evaluation.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        arrays : ArrayRecord
            Current global ArrayRecord (e.g. global model) to be sent to client
            nodes for evaluation.
        config : ConfigRecord
            Configuration to be sent to client nodes for evaluation.
        grid : Grid
            The Grid instance used for node sampling and communication.

        Returns
        -------
        Iterable[Message]
            An iterable of messages to be sent to selected client nodes for evaluation.
        """

    @abstractmethod
    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate evaluation metrics from client nodes.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        replies : Iterable[Message]
            Iterable of reply messages received from client nodes after evaluation.
            MetricRecords in the messages are aggregated.

        Returns
        -------
        Optional[MetricRecord]
            Aggregated evaluation metrics from all participating clients,
            or None if aggregation failed.
        """

    @abstractmethod
    def summary(self) -> None:
        """Log summary configuration of the strategy."""

    # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals
    def start(
        self,
        grid: Grid,
        initial_arrays: ArrayRecord,
        num_rounds: int = 3,
        timeout: float = 3600,
        train_config: Optional[ConfigRecord] = None,
        evaluate_config: Optional[ConfigRecord] = None,
        evaluate_fn: Optional[
            Callable[[int, ArrayRecord], Optional[MetricRecord]]
        ] = None,
    ) -> Result:
        """Execute the federated learning strategy.

        Runs the complete federated learning workflow for the specified number of
        rounds, including training, evaluation, and optional centralized evaluation.

        Parameters
        ----------
        grid : Grid
            The Grid instance used to send/receive Messages from nodes executing a
            ClientApp.
        initial_arrays : ArrayRecord
            Initial model parameters (arrays) to be used for federated learning.
        num_rounds : int (default: 3)
            Number of federated learning rounds to execute.
        timeout : float (default: 3600)
            Timeout in seconds for waiting for node responses.
        train_config : ConfigRecord, optional
            Configuration to be sent to nodes during training rounds.
            If unset, an empty ConfigRecord will be used.
        evaluate_config : ConfigRecord, optional
            Configuration to be sent to nodes during evaluation rounds.
            If unset, an empty ConfigRecord will be used.
        evaluate_fn : Callable[[int, ArrayRecord], Optional[MetricRecord]], optional
            Optional function for centralized evaluation of the global model. Takes
            server round number and array record, returns a MetricRecord or None. If
            provided, will be called before the first round and after each round.
            Defaults to None.

        Returns
        -------
        Result
            Result containing the final model arrays as well as the training
            metrics, evaluation metrics and centralized (ServerApp-side)
            evaluation metrics (if `evaluate_fn` is provided) from all rounds.
            `Result.arrays` holds the most recently aggregated arrays; it is an
            empty ArrayRecord if no round produced aggregated arrays.
        """
        log(INFO, "Starting %s strategy:", self.__class__.__name__)
        log_strategy_start_info(
            num_rounds, initial_arrays, train_config, evaluate_config
        )
        self.summary()
        log(INFO, "")

        # Initialize if None
        train_config = ConfigRecord() if train_config is None else train_config
        evaluate_config = ConfigRecord() if evaluate_config is None else evaluate_config
        result = Result()

        # Only used to measure elapsed time: use a monotonic high-resolution
        # clock so system clock adjustments cannot distort the reported duration
        t_start = time.perf_counter()
        # Evaluate starting global parameters
        if evaluate_fn:
            res = evaluate_fn(0, initial_arrays)
            log(INFO, "Initial global evaluation results: %s", res)
            if res is not None:
                result.evaluate_metrics_serverapp[0] = res

        arrays = initial_arrays

        for current_round in range(1, num_rounds + 1):
            log(INFO, "")
            log(INFO, "[ROUND %s/%s]", current_round, num_rounds)

            # -----------------------------------------------------------------
            # --- TRAINING (CLIENTAPP-SIDE) -----------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure training round
            # Send messages and wait for replies
            train_replies = grid.send_and_receive(
                messages=self.configure_train(
                    current_round,
                    arrays,
                    train_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate train
            agg_arrays, agg_train_metrics = self.aggregate_train(
                current_round,
                train_replies,
            )

            # Log training metrics and append to history
            if agg_arrays is not None:
                # Only replace the global arrays when aggregation produced new
                # ones; otherwise the next round reuses the current arrays
                result.arrays = agg_arrays
                arrays = agg_arrays
            if agg_train_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_train_metrics)
                result.train_metrics_clientapp[current_round] = agg_train_metrics

            # -----------------------------------------------------------------
            # --- EVALUATION (CLIENTAPP-SIDE) ---------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure evaluation round
            # Send messages and wait for replies
            evaluate_replies = grid.send_and_receive(
                messages=self.configure_evaluate(
                    current_round,
                    arrays,
                    evaluate_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate evaluate
            agg_evaluate_metrics = self.aggregate_evaluate(
                current_round,
                evaluate_replies,
            )

            # Log evaluation metrics and append to history
            if agg_evaluate_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_evaluate_metrics)
                result.evaluate_metrics_clientapp[current_round] = agg_evaluate_metrics

            # -----------------------------------------------------------------
            # --- EVALUATION (SERVERAPP-SIDE) ---------------------------------
            # -----------------------------------------------------------------

            # Centralized evaluation
            if evaluate_fn:
                log(INFO, "Global evaluation")
                res = evaluate_fn(current_round, arrays)
                log(INFO, "\t└──> MetricRecord: %s", res)
                if res is not None:
                    result.evaluate_metrics_serverapp[current_round] = res

        log(INFO, "")
        elapsed = time.perf_counter() - t_start
        log(INFO, "Strategy execution finished in %.2fs", elapsed)
        log(INFO, "")
        log(INFO, "Final results:")
        log(INFO, "")
        # Log the multi-line Result summary one line at a time
        for line in io.StringIO(str(result)):
            log(INFO, "\t%s", line.strip("\n"))
        log(INFO, "")

        return result
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower message-based strategy utilities."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import random
|
|
19
|
+
from collections import OrderedDict
|
|
20
|
+
from logging import INFO
|
|
21
|
+
from time import sleep
|
|
22
|
+
from typing import Optional, cast
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
|
|
26
|
+
from flwr.common import (
|
|
27
|
+
Array,
|
|
28
|
+
ArrayRecord,
|
|
29
|
+
ConfigRecord,
|
|
30
|
+
MetricRecord,
|
|
31
|
+
NDArray,
|
|
32
|
+
RecordDict,
|
|
33
|
+
log,
|
|
34
|
+
)
|
|
35
|
+
from flwr.server import Grid
|
|
36
|
+
|
|
37
|
+
from ..exception import InconsistentMessageReplies
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def config_to_str(config: ConfigRecord) -> str:
    """Convert a ConfigRecord to a string representation masking bytes.

    Scalar ``bytes`` values, and ``bytes`` elements inside list values, are
    replaced by the ``<bytes>`` placeholder so binary payloads never end up in
    log output. All other values are rendered with ``str()``.

    Parameters
    ----------
    config : ConfigRecord
        The ConfigRecord to render.

    Returns
    -------
    str
        A dict-like string, e.g. ``{'lr': 0.1, 'weights': <bytes>}``.
    """

    def _mask(value: object) -> object:
        if isinstance(value, bytes):
            return "<bytes>"
        if isinstance(value, list) and any(isinstance(x, bytes) for x in value):
            # List values may also carry bytes; mask each of those elements
            # while keeping list-style rendering for any other entries
            inner = ", ".join(
                "<bytes>" if isinstance(x, bytes) else repr(x) for x in value
            )
            return f"[{inner}]"
        return value

    content = ", ".join(f"'{k}': {_mask(v)}" for k, v in config.items())
    return f"{{{content}}}"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def log_strategy_start_info(
    num_rounds: int,
    arrays: ArrayRecord,
    train_config: Optional[ConfigRecord],
    evaluate_config: Optional[ConfigRecord],
) -> None:
    """Log the key settings a strategy is started with.

    Emits one INFO line each for the number of rounds, the total size of the
    initial arrays (in MiB), and the train/evaluate ConfigRecords. A missing or
    empty ConfigRecord is reported as ``(empty!)``.
    """
    log(INFO, "\t├── Number of rounds: %d", num_rounds)

    total_bytes = sum(len(arr.data) for arr in arrays.values())
    log(INFO, "\t├── ArrayRecord (%.2f MB)", total_bytes / (1024**2))

    config_lines = (
        ("\t├── ConfigRecord (train): %s", train_config),
        ("\t├── ConfigRecord (evaluate): %s", evaluate_config),
    )
    for fmt, cfg in config_lines:
        # Falsy covers both None and an empty ConfigRecord
        log(INFO, fmt, config_to_str(cfg) if cfg else "(empty!)")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def aggregate_arrayrecords(
    records: list[RecordDict], weighting_metric_name: str
) -> ArrayRecord:
    """Compute the weighted average of the ArrayRecords in ``records``.

    Each RecordDict is weighted by the value stored under
    ``weighting_metric_name`` in its first MetricRecord; the weights are
    normalized to sum to one. Arrays sharing the same key across all
    ArrayRecords are accumulated element-wise.

    Raises
    ------
    ZeroDivisionError
        If ``records`` is non-empty and the weights sum to zero.
    """
    # Replies were checked for consistency beforehand, so the weighting entry
    # can safely be treated as a float
    raw_weights = [
        cast(float, next(iter(rd.metric_records.values()))[weighting_metric_name])
        for rd in records
    ]
    total = sum(raw_weights)
    normalized = [w / total for w in raw_weights]

    accumulated: dict[str, NDArray] = {}
    for rd, factor in zip(records, normalized):
        for arr_record in rd.array_records.values():
            for name, arr in arr_record.items():
                contribution = arr.numpy() * factor
                if name in accumulated:
                    # In-place add onto the buffer created for the first record
                    accumulated[name] += contribution
                else:
                    accumulated[name] = contribution

    return ArrayRecord(
        OrderedDict((name, Array(np.asarray(acc))) for name, acc in accumulated.items())
    )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def aggregate_metricrecords(
    records: list[RecordDict], weighting_metric_name: str
) -> MetricRecord:
    """Compute the weighted average of the MetricRecords in ``records``.

    Each RecordDict is weighted by the value stored under
    ``weighting_metric_name`` in its first MetricRecord; the weights are
    normalized to sum to one. Scalar metrics are averaged directly and list
    metrics element-wise. The weighting metric itself is left out of the
    returned MetricRecord.

    Raises
    ------
    ZeroDivisionError
        If ``records`` is non-empty and the weights sum to zero.
    """
    # Replies were checked for consistency beforehand, so the weighting entry
    # can safely be treated as a float
    raw_weights = [
        cast(float, next(iter(rd.metric_records.values()))[weighting_metric_name])
        for rd in records
    ]
    total = sum(raw_weights)
    normalized = [w / total for w in raw_weights]

    combined = MetricRecord()
    for rd, factor in zip(records, normalized):
        for metric_record in rd.metric_records.values():
            for name, val in metric_record.items():
                if name == weighting_metric_name:
                    # The weighting entry is not part of the averaged output
                    continue
                if isinstance(val, list):
                    scaled_list = [item * factor for item in val]
                    if name in combined:
                        running = cast(list[float], combined[name])
                        scaled_list = [r + s for r, s in zip(running, scaled_list)]
                    combined[name] = scaled_list
                else:
                    scaled_value = val * factor
                    if name in combined:
                        scaled_value = cast(float, combined[name]) + scaled_value
                    combined[name] = scaled_value

    return combined
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def sample_nodes(
    grid: Grid, min_available_nodes: int, sample_size: int
) -> tuple[list[int], list[int]]:
    """Sample the specified number of nodes using the Grid.

    Blocks until at least ``max(min_available_nodes, sample_size)`` nodes are
    connected. It polls ``grid.get_node_ids()`` once per second, with no timeout.
    It then draws ``sample_size`` node IDs at random, without replacement, from
    the connected nodes.

    Parameters
    ----------
    grid : Grid
        The grid object.
    min_available_nodes : int
        The minimum number of available nodes to sample from. It is raised to
        ``sample_size`` if smaller.
    sample_size : int
        The number of nodes to sample.

    Returns
    -------
    tuple[list[int], list[int]]
        A tuple containing the sampled node IDs and the list
        of all connected node IDs.
    """
    # Ensure min_available_nodes is at least as large as sample_size, so that
    # random.sample() below always has enough nodes to draw from
    min_available_nodes = max(min_available_nodes, sample_size)

    # Wait (indefinitely) for min_available_nodes to be online
    while len(all_nodes := list(grid.get_node_ids())) < min_available_nodes:
        log(
            INFO,
            "Waiting for nodes to connect: %d connected (minimum required: %d).",
            len(all_nodes),
            min_available_nodes,
        )
        sleep(1)

    # Sample nodes without replacement
    sampled_nodes = random.sample(all_nodes, sample_size)

    return sampled_nodes, all_nodes
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# pylint: disable=too-many-return-statements
|
|
195
|
+
def validate_message_reply_consistency(
    replies: list[RecordDict], weighted_by_key: str, check_arrayrecord: bool
) -> None:
    """Validate that replies contain exactly one ArrayRecord and one MetricRecord, and
    that the MetricRecord includes a weight factor key.

    These checks ensure that Message-based strategies behave consistently with
    *Ins/*Res-based strategies.

    Parameters
    ----------
    replies : list[RecordDict]
        Contents of the reply messages to validate. An empty list is trivially
        consistent and passes validation.
    weighted_by_key : str
        Key that every MetricRecord must contain. Its value must be a single
        value, not a list.
    check_arrayrecord : bool
        Whether to also require exactly one ArrayRecord per reply, with the
        same keys across all replies.

    Raises
    ------
    InconsistentMessageReplies
        If any of the checks above fails.
    """
    # Nothing to compare; also avoids an IndexError on `replies[0]` below
    if not replies:
        return

    # Checking for ArrayRecord consistency
    if check_arrayrecord:
        if any(len(msg.array_records) != 1 for msg in replies):
            raise InconsistentMessageReplies(
                reason="Expected exactly one ArrayRecord in replies. "
                "Skipping aggregation."
            )

        # Ensure all keys are present in all ArrayRecords. A reply without a
        # record under the first reply's record name is treated as having no keys.
        record_key = next(iter(replies[0].array_records.keys()))
        all_keys = set(replies[0][record_key].keys())
        if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
            raise InconsistentMessageReplies(
                reason="All ArrayRecords must have the same keys for aggregation. "
                "This condition wasn't met. Skipping aggregation."
            )

    # Checking for MetricRecord consistency (both zero and several fail)
    if any(len(msg.metric_records) != 1 for msg in replies):
        raise InconsistentMessageReplies(
            reason="Expected exactly one MetricRecord in replies. "
            "Skipping aggregation."
        )

    # Ensure all keys are present in all MetricRecords
    record_key = next(iter(replies[0].metric_records.keys()))
    all_keys = set(replies[0][record_key].keys())
    if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
        raise InconsistentMessageReplies(
            reason="All MetricRecords must have the same keys for aggregation. "
            "This condition wasn't met. Skipping aggregation."
        )

    # Verify the weight factor key presence in all MetricRecords. Checking the
    # first reply suffices because all MetricRecords share the same keys.
    if weighted_by_key not in all_keys:
        raise InconsistentMessageReplies(
            reason=f"Missing required key `{weighted_by_key}` in the MetricRecord of "
            "reply messages. Cannot average ArrayRecords and MetricRecords. Skipping "
            "aggregation."
        )

    # Check that the weight factor is a single value, not a list
    if any(isinstance(msg[record_key][weighted_by_key], list) for msg in replies):
        raise InconsistentMessageReplies(
            reason=f"Key `{weighted_by_key}` in the MetricRecord of reply messages "
            "must be a single value (int or float), but a list was found. Skipping "
            "aggregation."
        )
|