flwr-nightly 1.21.0.dev20250826__py3-none-any.whl → 1.21.0.dev20250827__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/common/exit/exit_code.py +13 -0
- flwr/serverapp/__init__.py +10 -0
- flwr/serverapp/fedavg.py +292 -0
- flwr/serverapp/result.py +30 -0
- flwr/serverapp/strategy.py +286 -0
- flwr/serverapp/strategy_utils.py +256 -0
- flwr/serverapp/strategy_utils_tests.py +277 -0
- {flwr_nightly-1.21.0.dev20250826.dist-info → flwr_nightly-1.21.0.dev20250827.dist-info}/METADATA +1 -1
- {flwr_nightly-1.21.0.dev20250826.dist-info → flwr_nightly-1.21.0.dev20250827.dist-info}/RECORD +11 -6
- {flwr_nightly-1.21.0.dev20250826.dist-info → flwr_nightly-1.21.0.dev20250827.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.21.0.dev20250826.dist-info → flwr_nightly-1.21.0.dev20250827.dist-info}/entry_points.txt +0 -0
flwr/common/exit/exit_code.py
CHANGED
@@ -34,6 +34,9 @@ class ExitCode:
|
|
34
34
|
SUPERLINK_LICENSE_URL_INVALID = 103
|
35
35
|
SUPERLINK_INVALID_ARGS = 104
|
36
36
|
|
37
|
+
# ServerApp-specific exit codes (200-299)
|
38
|
+
SERVERAPP_STRATEGY_PRECONDITION_UNMET = 200
|
39
|
+
|
37
40
|
# SuperNode-specific exit codes (300-399)
|
38
41
|
SUPERNODE_REST_ADDRESS_INVALID = 300
|
39
42
|
SUPERNODE_NODE_AUTH_KEYS_REQUIRED = 301
|
@@ -76,6 +79,16 @@ EXIT_CODE_HELP = {
|
|
76
79
|
"Invalid arguments provided to SuperLink. Use `--help` check for the correct "
|
77
80
|
"usage. Alternatively, check the documentation."
|
78
81
|
),
|
82
|
+
# ServerApp-specific exit codes (200-299)
|
83
|
+
ExitCode.SERVERAPP_STRATEGY_PRECONDITION_UNMET: (
|
84
|
+
"The strategy received replies that cannot be aggregated. Please ensure all "
|
85
|
+
"replies returned by ClientApps have one `ArrayRecord` (none when replies are "
|
86
|
+
"from a round of federated evaluation, i.e. when message type is "
|
87
|
+
"`MessageType.EVALUATE`) and one `MetricRecord`. The records in all replies "
|
88
|
+
"must use identical keys. In addition, if the strategy expects a key to "
|
89
|
+
"perform weighted average (e.g. in FedAvg) please ensure the returned "
|
90
|
+
"MetricRecord from ClientApps do include this key."
|
91
|
+
),
|
79
92
|
# SuperNode-specific exit codes (300-399)
|
80
93
|
ExitCode.SUPERNODE_REST_ADDRESS_INVALID: (
|
81
94
|
"When using the REST API, please provide `https://` or "
|
flwr/serverapp/__init__.py
CHANGED
@@ -13,3 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
"""Public Flower ServerApp APIs."""
|
16
|
+
|
17
|
+
from .fedavg import FedAvg
|
18
|
+
from .result import Result
|
19
|
+
from .strategy import Strategy
|
20
|
+
|
21
|
+
# Public names re-exported by `from flwr.serverapp import *`
__all__ = ["FedAvg", "Result", "Strategy"]
|
flwr/serverapp/fedavg.py
ADDED
@@ -0,0 +1,292 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower message-based FedAvg strategy."""
|
16
|
+
|
17
|
+
|
18
|
+
from collections.abc import Iterable
|
19
|
+
from logging import INFO
|
20
|
+
from typing import Callable, Optional
|
21
|
+
|
22
|
+
from flwr.common import (
|
23
|
+
ArrayRecord,
|
24
|
+
ConfigRecord,
|
25
|
+
Message,
|
26
|
+
MessageType,
|
27
|
+
MetricRecord,
|
28
|
+
RecordDict,
|
29
|
+
log,
|
30
|
+
)
|
31
|
+
from flwr.server import Grid
|
32
|
+
|
33
|
+
from .strategy import Strategy
|
34
|
+
from .strategy_utils import (
|
35
|
+
aggregate_arrayrecords,
|
36
|
+
aggregate_metricrecords,
|
37
|
+
sample_nodes,
|
38
|
+
validate_message_reply_consistency,
|
39
|
+
)
|
40
|
+
|
41
|
+
|
42
|
+
# pylint: disable=too-many-instance-attributes
class FedAvg(Strategy):
    """Federated Averaging strategy.

    Implementation based on https://arxiv.org/abs/1602.05629

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for both ArrayRecords and MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
    ) -> None:
        self.fraction_train = fraction_train
        self.fraction_evaluate = fraction_evaluate
        self.min_train_nodes = min_train_nodes
        self.min_evaluate_nodes = min_evaluate_nodes
        self.min_available_nodes = min_available_nodes
        self.weighted_by_key = weighted_by_key
        self.arrayrecord_key = arrayrecord_key
        self.configrecord_key = configrecord_key
        # Fall back to the default weighted-average aggregation when not provided
        self.train_metrics_aggr_fn = train_metrics_aggr_fn or aggregate_metricrecords
        self.evaluate_metrics_aggr_fn = (
            evaluate_metrics_aggr_fn or aggregate_metricrecords
        )

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        log(INFO, "\t├──> Sampling:")
        log(
            INFO,
            "\t│\t├──Fraction: train (%.2f) | evaluate (%.2f)",
            self.fraction_train,
            self.fraction_evaluate,
        )  # pylint: disable=line-too-long
        log(
            INFO,
            "\t│\t├──Minimum nodes: train (%d) | evaluate (%d)",
            self.min_train_nodes,
            self.min_evaluate_nodes,
        )  # pylint: disable=line-too-long
        log(INFO, "\t│\t└──Minimum available nodes: %d", self.min_available_nodes)
        log(INFO, "\t└──> Keys in records:")
        log(INFO, "\t\t├── Weighted by: '%s'", self.weighted_by_key)
        log(INFO, "\t\t├── ArrayRecord key: '%s'", self.arrayrecord_key)
        log(INFO, "\t\t└── ConfigRecord key: '%s'", self.configrecord_key)

    def _construct_messages(
        self, record: RecordDict, node_ids: list[int], message_type: str
    ) -> Iterable[Message]:
        """Construct one Message per node, all carrying the same RecordDict."""
        return [
            Message(content=record, message_type=message_type, dst_node_id=node_id)
            for node_id in node_ids
        ]

    def _sample_node_ids(
        self, grid: Grid, fraction: float, min_nodes: int, caller: str
    ) -> list[int]:
        """Select the node IDs for one round and log how many were chosen.

        The requested sample size is ``fraction`` of the node IDs reported by
        ``grid.get_node_ids()``, raised to ``min_nodes`` if that is larger. The
        actual selection is delegated to ``sample_nodes`` together with
        ``self.min_available_nodes``. ``caller`` only prefixes the log line.
        """
        num_nodes = int(len(list(grid.get_node_ids())) * fraction)
        sample_size = max(num_nodes, min_nodes)
        # Second element is a list whose length is reported as the total
        node_ids, all_node_ids = sample_nodes(
            grid, self.min_available_nodes, sample_size
        )
        log(
            INFO,
            "%s: Sampled %s nodes (out of %s)",
            caller,
            len(node_ids),
            len(all_node_ids),
        )
        return node_ids

    def _collect_valid_replies(
        self, replies: Iterable[Message], caller: str
    ) -> list[RecordDict]:
        """Return the content of every reply that does not carry an error.

        Each error reply is logged with its source node ID, followed by one
        summary line (prefixed with ``caller``) with the number of successful
        and failed replies.
        """
        num_errors = 0
        replies_with_content: list[RecordDict] = []
        for msg in replies:
            if msg.has_error():
                log(
                    INFO,
                    "Received error in reply from node %d: %s",
                    msg.metadata.src_node_id,
                    msg.error,
                )
                num_errors += 1
            else:
                replies_with_content.append(msg.content)

        # `replies_with_content` holds only error-free replies, so its length
        # already is the number of successful results
        log(
            INFO,
            "%s: Received %s results and %s failures",
            caller,
            len(replies_with_content),
            num_errors,
        )
        return replies_with_content

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training.

        Samples nodes, injects ``server_round`` into ``config`` under the key
        ``"server-round"`` and builds one ``MessageType.TRAIN`` message per
        sampled node carrying ``arrays`` and ``config``.
        """
        node_ids = self._sample_node_ids(
            grid, self.fraction_train, self.min_train_nodes, "configure_train"
        )

        # Always inject current server round (mutates the passed ConfigRecord)
        config["server-round"] = server_round

        # Construct messages
        record = RecordDict(
            {self.arrayrecord_key: arrays, self.configrecord_key: config}
        )
        return self._construct_messages(record, node_ids, MessageType.TRAIN)

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        Returns ``(None, None)`` when no reply carries content, i.e. there were
        no replies or every reply carried an error. Inconsistent replies are
        presumably reported by ``validate_message_reply_consistency`` raising
        ``InconsistentMessageReplies`` (caught in ``Strategy.start``) — confirm.
        """
        replies_with_content = self._collect_valid_replies(replies, "aggregate_train")
        if not replies_with_content:
            # Aggregating an empty list would yield an empty ArrayRecord rather
            # than None, which callers would adopt as the new global arrays
            return None, None

        # Ensure expected ArrayRecords and MetricRecords are received
        validate_message_reply_consistency(
            replies=replies_with_content,
            weighted_by_key=self.weighted_by_key,
            check_arrayrecord=True,
        )

        # Aggregate ArrayRecords
        arrays = aggregate_arrayrecords(
            replies_with_content,
            self.weighted_by_key,
        )

        # Aggregate MetricRecords
        metrics = self.train_metrics_aggr_fn(
            replies_with_content,
            self.weighted_by_key,
        )
        return arrays, metrics

    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated evaluation.

        Samples nodes, injects ``server_round`` into ``config`` under the key
        ``"server-round"`` and builds one ``MessageType.EVALUATE`` message per
        sampled node carrying ``arrays`` and ``config``.
        """
        node_ids = self._sample_node_ids(
            grid, self.fraction_evaluate, self.min_evaluate_nodes, "configure_evaluate"
        )

        # Always inject current server round (mutates the passed ConfigRecord)
        config["server-round"] = server_round

        # Construct messages
        record = RecordDict(
            {self.arrayrecord_key: arrays, self.configrecord_key: config}
        )
        return self._construct_messages(record, node_ids, MessageType.EVALUATE)

    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate MetricRecords in the received Messages.

        Returns ``None`` when no reply carries content, i.e. there were no
        replies or every reply carried an error.
        """
        replies_with_content = self._collect_valid_replies(
            replies, "aggregate_evaluate"
        )
        if not replies_with_content:
            return None

        # Ensure expected MetricRecords are received (no ArrayRecord expected)
        validate_message_reply_consistency(
            replies=replies_with_content,
            weighted_by_key=self.weighted_by_key,
            check_arrayrecord=False,
        )

        # Aggregate MetricRecords
        metrics = self.evaluate_metrics_aggr_fn(
            replies_with_content,
            self.weighted_by_key,
        )
        return metrics
|
flwr/serverapp/result.py
ADDED
@@ -0,0 +1,30 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Strategy results."""
|
16
|
+
|
17
|
+
|
18
|
+
from dataclasses import dataclass, field
|
19
|
+
|
20
|
+
from flwr.common import ArrayRecord, MetricRecord
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass
class Result:
    """Records collected while a strategy is executed.

    As filled in by ``Strategy.start``, each ``dict`` field maps a server round
    number to the MetricRecord produced in that round. ``evaluate_metrics_serverapp``
    additionally uses key ``0`` for the evaluation of the initial arrays.
    """

    # Arrays from the most recent aggregation (empty until one succeeds)
    arrays: ArrayRecord = field(default_factory=ArrayRecord)
    # Aggregated metrics from ClientApp training replies, per round
    train_metrics_clientapp: dict[int, MetricRecord] = field(default_factory=dict)
    # Aggregated metrics from ClientApp evaluation replies, per round
    evaluate_metrics_clientapp: dict[int, MetricRecord] = field(default_factory=dict)
    # Metrics returned by the centralized `evaluate_fn` run on the ServerApp
    evaluate_metrics_serverapp: dict[int, MetricRecord] = field(default_factory=dict)
|
flwr/serverapp/strategy.py
ADDED
@@ -0,0 +1,286 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower message-based strategy."""
|
16
|
+
|
17
|
+
|
18
|
+
import time
|
19
|
+
from abc import ABC, abstractmethod
|
20
|
+
from collections.abc import Iterable
|
21
|
+
from logging import INFO
|
22
|
+
from typing import Callable, Optional
|
23
|
+
|
24
|
+
from flwr.common import ArrayRecord, ConfigRecord, Message, MetricRecord, log
|
25
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
26
|
+
from flwr.server import Grid
|
27
|
+
|
28
|
+
from .result import Result
|
29
|
+
from .strategy_utils import InconsistentMessageReplies, log_strategy_start_info
|
30
|
+
|
31
|
+
|
32
|
+
class Strategy(ABC):
    """Abstract base class for server strategy implementations.

    Subclasses implement the ``configure_*``/``aggregate_*`` hooks and
    ``summary``; ``start`` drives those hooks for a number of rounds.
    """

    @abstractmethod
    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        arrays : ArrayRecord
            Current global ArrayRecord (e.g. global model) to be sent to client
            nodes for training.
        config : ConfigRecord
            Configuration to be sent to client nodes for training.
        grid : Grid
            The Grid instance used for node sampling and communication.

        Returns
        -------
        Iterable[Message]
            An iterable of messages to be sent to selected client nodes for training.
        """

    @abstractmethod
    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate training results from client nodes.

        Parameters
        ----------
        server_round : int
            The current round of federated learning, starting from 1.
        replies : Iterable[Message]
            Iterable of reply messages received from client nodes after training.
            Each message contains ArrayRecords and MetricRecords that get aggregated.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            A tuple containing:
            - ArrayRecord: Aggregated ArrayRecord, or None if aggregation failed
            - MetricRecord: Aggregated MetricRecord, or None if aggregation failed
        """

    @abstractmethod
    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of evaluation.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        arrays : ArrayRecord
            Current global ArrayRecord (e.g. global model) to be sent to client
            nodes for evaluation.
        config : ConfigRecord
            Configuration to be sent to client nodes for evaluation.
        grid : Grid
            The Grid instance used for node sampling and communication.

        Returns
        -------
        Iterable[Message]
            An iterable of messages to be sent to selected client nodes for evaluation.
        """

    @abstractmethod
    def aggregate_evaluate(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> Optional[MetricRecord]:
        """Aggregate evaluation metrics from client nodes.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        replies : Iterable[Message]
            Iterable of reply messages received from client nodes after evaluation.
            MetricRecords in the messages are aggregated.

        Returns
        -------
        Optional[MetricRecord]
            Aggregated evaluation metrics from all participating clients,
            or None if aggregation failed.
        """

    @abstractmethod
    def summary(self) -> None:
        """Log summary configuration of the strategy."""

    # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals
    def start(
        self,
        grid: Grid,
        initial_arrays: ArrayRecord,
        num_rounds: int = 3,
        timeout: float = 3600,
        train_config: Optional[ConfigRecord] = None,
        evaluate_config: Optional[ConfigRecord] = None,
        evaluate_fn: Optional[
            Callable[[int, ArrayRecord], Optional[MetricRecord]]
        ] = None,
    ) -> Result:
        """Execute the federated learning strategy.

        Runs the complete federated learning workflow for the specified number of
        rounds, including training, evaluation, and optional centralized evaluation.

        Parameters
        ----------
        grid : Grid
            The Grid instance used to send/receive Messages from nodes executing a
            ClientApp.
        initial_arrays : ArrayRecord
            Initial model parameters (arrays) to be used for federated learning.
        num_rounds : int (default: 3)
            Number of federated learning rounds to execute.
        timeout : float (default: 3600)
            Timeout in seconds for waiting for node responses.
        train_config : ConfigRecord, optional
            Configuration to be sent to nodes during training rounds.
            If unset, an empty ConfigRecord will be used.
        evaluate_config : ConfigRecord, optional
            Configuration to be sent to nodes during evaluation rounds.
            If unset, an empty ConfigRecord will be used.
        evaluate_fn : Callable[[int, ArrayRecord], Optional[MetricRecord]], optional
            Optional function for centralized evaluation of the global model. Takes
            server round number and array record, returns a MetricRecord (or None
            to record nothing for that round). If provided, will be called before
            the first round (with round 0) and after each round. Defaults to None.

        Returns
        -------
        Result
            Result containing final model arrays and also training metrics,
            evaluation metrics and global evaluation metrics (if provided) from
            all rounds.
        """
        log(INFO, "Starting %s strategy:", self.__class__.__name__)
        log_strategy_start_info(
            num_rounds, initial_arrays, train_config, evaluate_config
        )
        self.summary()
        log(INFO, "")

        # Initialize if None
        train_config = ConfigRecord() if train_config is None else train_config
        evaluate_config = ConfigRecord() if evaluate_config is None else evaluate_config
        result = Result()

        # Monotonic clock: elapsed time is unaffected by system clock changes
        t_start = time.monotonic()

        # Evaluate starting global parameters
        if evaluate_fn is not None:
            res = evaluate_fn(0, initial_arrays)
            log(INFO, "Initial global evaluation results: %s", res)
            if res is not None:
                result.evaluate_metrics_serverapp[0] = res

        arrays = initial_arrays

        for current_round in range(1, num_rounds + 1):
            log(INFO, "")
            log(INFO, "[ROUND %s/%s]", current_round, num_rounds)

            # -----------------------------------------------------------------
            # --- TRAINING ----------------------------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure training round
            # Send messages and wait for replies
            train_replies = grid.send_and_receive(
                messages=self.configure_train(
                    current_round,
                    arrays,
                    train_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate train; inconsistent replies terminate the run with a
            # dedicated exit code
            try:
                agg_arrays, agg_train_metrics = self.aggregate_train(
                    current_round,
                    train_replies,
                )
            except InconsistentMessageReplies as e:
                flwr_exit(
                    ExitCode.SERVERAPP_STRATEGY_PRECONDITION_UNMET, message=str(e)
                )

            # Adopt the aggregated arrays (if any) and store training metrics
            if agg_arrays is not None:
                result.arrays = agg_arrays
                arrays = agg_arrays
            if agg_train_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_train_metrics)
                result.train_metrics_clientapp[current_round] = agg_train_metrics

            # -----------------------------------------------------------------
            # --- EVALUATION (LOCAL) ------------------------------------------
            # -----------------------------------------------------------------

            # Call strategy to configure evaluation round
            # Send messages and wait for replies
            evaluate_replies = grid.send_and_receive(
                messages=self.configure_evaluate(
                    current_round,
                    arrays,
                    evaluate_config,
                    grid,
                ),
                timeout=timeout,
            )

            # Aggregate evaluate
            try:
                agg_evaluate_metrics = self.aggregate_evaluate(
                    current_round,
                    evaluate_replies,
                )
            except InconsistentMessageReplies as e:
                flwr_exit(
                    ExitCode.SERVERAPP_STRATEGY_PRECONDITION_UNMET, message=str(e)
                )

            # Log evaluation metrics and store them in the result
            if agg_evaluate_metrics is not None:
                log(INFO, "\t└──> Aggregated MetricRecord: %s", agg_evaluate_metrics)
                result.evaluate_metrics_clientapp[current_round] = agg_evaluate_metrics

            # -----------------------------------------------------------------
            # --- EVALUATION (GLOBAL) -----------------------------------------
            # -----------------------------------------------------------------

            # Centralized evaluation
            if evaluate_fn is not None:
                log(INFO, "Global evaluation")
                res = evaluate_fn(current_round, arrays)
                log(INFO, "\t└──> MetricRecord: %s", res)
                if res is not None:
                    result.evaluate_metrics_serverapp[current_round] = res

        log(INFO, "")
        log(INFO, "Strategy execution finished in %.2fs", time.monotonic() - t_start)
        log(INFO, "")

        return result
|
flwr/serverapp/strategy_utils.py
ADDED
@@ -0,0 +1,256 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower message-based strategy utilities."""
|
16
|
+
|
17
|
+
|
18
|
+
import random
|
19
|
+
from collections import OrderedDict
|
20
|
+
from logging import INFO
|
21
|
+
from time import sleep
|
22
|
+
from typing import Optional, cast
|
23
|
+
|
24
|
+
from flwr.common import (
|
25
|
+
Array,
|
26
|
+
ArrayRecord,
|
27
|
+
ConfigRecord,
|
28
|
+
MetricRecord,
|
29
|
+
NDArray,
|
30
|
+
RecordDict,
|
31
|
+
log,
|
32
|
+
)
|
33
|
+
from flwr.server import Grid
|
34
|
+
|
35
|
+
|
36
|
+
# Raised by strategy helpers when received replies cannot be aggregated
class InconsistentMessageReplies(Exception):
    """Signal that replies are inconsistent, so aggregation must be skipped.

    Parameters
    ----------
    reason : str
        Human-readable explanation; it becomes ``str(exc)`` and ``exc.args[0]``.
    """

    def __init__(self, reason: str):
        super().__init__(reason)
|
43
|
+
|
44
|
+
|
45
|
+
def config_to_str(config: ConfigRecord) -> str:
    """Render a ConfigRecord as a dict-like string, hiding raw ``bytes`` values.

    Keys are single-quoted; values are shown via ``str()``, except ``bytes``
    values, which are replaced by the placeholder ``<bytes>``.
    """
    pieces = []
    for name, val in config.items():
        shown = "<bytes>" if isinstance(val, bytes) else val
        pieces.append(f"'{name}': {shown}")
    return "{" + ", ".join(pieces) + "}"
|
51
|
+
|
52
|
+
|
53
|
+
def log_strategy_start_info(
    num_rounds: int,
    arrays: ArrayRecord,
    train_config: Optional[ConfigRecord],
    evaluate_config: Optional[ConfigRecord],
) -> None:
    """Log the round count, ArrayRecord size and both ConfigRecords at start.

    The ArrayRecord size is the summed length of each array's ``data`` in MiB.
    A missing or empty ConfigRecord is shown as ``(empty!)``.
    """
    log(INFO, "\t├── Number of rounds: %d", num_rounds)

    payload_bytes = sum(len(arr.data) for arr in arrays.values())
    log(INFO, "\t├── ArrayRecord (%.2f MB)", payload_bytes / (1024**2))

    config_lines = (
        ("\t├── ConfigRecord (train): %s", train_config),
        ("\t├── ConfigRecord (evaluate): %s", evaluate_config),
    )
    for template, cfg in config_lines:
        log(INFO, template, config_to_str(cfg) if cfg else "(empty!)")
|
76
|
+
|
77
|
+
|
78
|
+
def aggregate_arrayrecords(
|
79
|
+
records: list[RecordDict], weighting_metric_name: str
|
80
|
+
) -> ArrayRecord:
|
81
|
+
"""Perform weighted aggregation all ArrayRecords using a specific key."""
|
82
|
+
# Retrieve weighting factor from MetricRecord
|
83
|
+
weights: list[float] = []
|
84
|
+
for record in records:
|
85
|
+
# Get the first (and only) MetricRecord in the record
|
86
|
+
metricrecord = next(iter(record.metric_records.values()))
|
87
|
+
# Because replies have been checked for consistency,
|
88
|
+
# we can safely cast the weighting factor to float
|
89
|
+
w = cast(float, metricrecord[weighting_metric_name])
|
90
|
+
weights.append(w)
|
91
|
+
|
92
|
+
# Average
|
93
|
+
total_weight = sum(weights)
|
94
|
+
weight_factors = [w / total_weight for w in weights]
|
95
|
+
|
96
|
+
# Perform weighted aggregation
|
97
|
+
aggregated_np_arrays: dict[str, NDArray] = {}
|
98
|
+
|
99
|
+
for record, weight in zip(records, weight_factors):
|
100
|
+
for record_item in record.array_records.values():
|
101
|
+
# aggregate in-place
|
102
|
+
for key, value in record_item.items():
|
103
|
+
if key not in aggregated_np_arrays:
|
104
|
+
aggregated_np_arrays[key] = value.numpy() * weight
|
105
|
+
else:
|
106
|
+
aggregated_np_arrays[key] += value.numpy() * weight
|
107
|
+
|
108
|
+
return ArrayRecord(
|
109
|
+
OrderedDict({k: Array(v) for k, v in aggregated_np_arrays.items()})
|
110
|
+
)
|
111
|
+
|
112
|
+
|
113
|
+
def aggregate_metricrecords(
    records: list[RecordDict], weighting_metric_name: str
) -> MetricRecord:
    """Compute the weighted average of the MetricRecords contained in ``records``.

    Every RecordDict contributes with the scalar stored under
    ``weighting_metric_name`` in its first MetricRecord. The weights are divided
    by their sum. Scalar metrics are averaged directly, list-valued metrics
    element-wise. The weighting metric itself is left out of the result.

    Parameters
    ----------
    records : list[RecordDict]
        Replies to aggregate, already checked for consistency by the caller.
    weighting_metric_name : str
        Key in each MetricRecord holding that reply's weighting factor.

    Returns
    -------
    MetricRecord
        The weighted average of every metric except the weighting one.
    """
    # One weight per reply, read from its first (and only) MetricRecord.
    # Replies have been checked for consistency, so the float cast is safe.
    raw_weights = [
        cast(float, next(iter(rd.metric_records.values()))[weighting_metric_name])
        for rd in records
    ]
    weight_sum = sum(raw_weights)
    normalized = [w / weight_sum for w in raw_weights]

    result = MetricRecord()
    for rd, factor in zip(records, normalized):
        for metric_record in rd.metric_records.values():
            for name, metric in metric_record.items():
                # The weighting factor itself is not part of the output
                if name == weighting_metric_name:
                    continue
                seen = name in result
                if isinstance(metric, list):
                    if seen:
                        previous = cast(list[float], result[name])
                        result[name] = [
                            acc + item * factor
                            for acc, item in zip(previous, metric)
                        ]
                    else:
                        result[name] = [item * factor for item in metric]
                elif seen:
                    result[name] = cast(float, result[name]) + metric * factor
                else:
                    result[name] = metric * factor

    return result
|
156
|
+
|
157
|
+
|
158
|
+
def sample_nodes(
    grid: Grid, min_available_nodes: int, sample_size: int
) -> tuple[list[int], list[int]]:
    """Sample the specified number of nodes using the Grid.

    Polls ``grid.get_node_ids()`` once per second until at least
    ``max(min_available_nodes, sample_size)`` nodes are connected. It then
    draws ``sample_size`` distinct node IDs with ``random.sample``.

    Parameters
    ----------
    grid : Grid
        The grid object.
    min_available_nodes : int
        The minimum number of available nodes to sample from.
    sample_size : int
        The number of nodes to sample.

    Returns
    -------
    tuple[list[int], list[int]]
        A tuple containing the sampled node IDs and the list
        of all connected node IDs.

    Raises
    ------
    ValueError
        If ``sample_size`` is negative. This is raised immediately rather than
        after waiting for nodes to connect.
    """
    if sample_size < 0:
        raise ValueError(f"`sample_size` must be non-negative, got {sample_size}.")

    # Sampling without replacement needs at least `sample_size` nodes online
    min_available_nodes = max(min_available_nodes, sample_size)

    # Wait for min_available_nodes to be online
    while len(all_nodes := list(grid.get_node_ids())) < min_available_nodes:
        log(
            INFO,
            "Waiting for nodes to connect: %d connected (minimum required: %d).",
            len(all_nodes),
            min_available_nodes,
        )
        sleep(1)

    sampled_nodes = random.sample(all_nodes, sample_size)
    return sampled_nodes, all_nodes
|
197
|
+
|
198
|
+
|
199
|
+
# pylint: disable=too-many-return-statements
def validate_message_reply_consistency(
    replies: list[RecordDict], weighted_by_key: str, check_arrayrecord: bool
) -> None:
    """Validate that replies contain exactly one ArrayRecord and one MetricRecord, and
    that the MetricRecord includes a weight factor key.

    These checks ensure that Message-based strategies behave consistently with
    *Ins/*Res-based strategies.

    Parameters
    ----------
    replies : list[RecordDict]
        Contents of the reply messages. Must be non-empty: the first reply is
        the reference the other replies are compared against.
    weighted_by_key : str
        Key that every MetricRecord must contain, holding a single (non-list)
        value.
    check_arrayrecord : bool
        If True, also require exactly one ArrayRecord per reply. All replies
        must use the same record name and the same array keys.

    Raises
    ------
    InconsistentMessageReplies
        If any of the checks above fails.
    """
    # Checking for ArrayRecord consistency
    if check_arrayrecord:
        if any(len(msg.array_records) != 1 for msg in replies):
            raise InconsistentMessageReplies(
                reason="Expected exactly one ArrayRecord in replies. "
                "Skipping aggregation."
            )

        # Ensure all keys are present in all ArrayRecords. Replies that store
        # their ArrayRecord under a different name fall back to `{}` here and
        # are reported as a key mismatch.
        record_key = next(iter(replies[0].array_records.keys()))
        all_keys = set(replies[0][record_key].keys())
        if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
            raise InconsistentMessageReplies(
                reason="All ArrayRecords must have the same keys for aggregation. "
                "This condition wasn't met. Skipping aggregation."
            )

    # Checking for MetricRecord consistency. This also rejects replies that
    # carry no MetricRecord at all, so the message must not say "found more".
    if any(len(msg.metric_records) != 1 for msg in replies):
        raise InconsistentMessageReplies(
            reason="Expected exactly one MetricRecord in replies. "
            "Skipping aggregation."
        )

    # Ensure all keys are present in all MetricRecords
    record_key = next(iter(replies[0].metric_records.keys()))
    all_keys = set(replies[0][record_key].keys())
    if any(set(msg.get(record_key, {}).keys()) != all_keys for msg in replies[1:]):
        raise InconsistentMessageReplies(
            reason="All MetricRecords must have the same keys for aggregation. "
            "This condition wasn't met. Skipping aggregation."
        )

    # Verify the weight factor key presence. Checking the reference reply's
    # keys suffices because all MetricRecords were shown to share them.
    if weighted_by_key not in all_keys:
        raise InconsistentMessageReplies(
            reason=f"Missing required key `{weighted_by_key}` in the MetricRecord of "
            "reply messages. Cannot average ArrayRecords and MetricRecords. Skipping "
            "aggregation."
        )

    # Check that it is not a list
    if any(isinstance(msg[record_key][weighted_by_key], list) for msg in replies):
        raise InconsistentMessageReplies(
            reason=f"Key `{weighted_by_key}` in the MetricRecord of reply messages "
            "must be a single value (int or float), but a list was found. Skipping "
            "aggregation."
        )
|
@@ -0,0 +1,277 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Tests for message-based strategy utilities."""
|
16
|
+
|
17
|
+
|
18
|
+
from collections import OrderedDict
|
19
|
+
|
20
|
+
import numpy as np
|
21
|
+
import pytest
|
22
|
+
from parameterized import parameterized
|
23
|
+
|
24
|
+
from flwr.common import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
|
25
|
+
|
26
|
+
from .strategy_utils import (
|
27
|
+
InconsistentMessageReplies,
|
28
|
+
aggregate_arrayrecords,
|
29
|
+
aggregate_metricrecords,
|
30
|
+
config_to_str,
|
31
|
+
validate_message_reply_consistency,
|
32
|
+
)
|
33
|
+
|
34
|
+
|
35
|
+
def test_config_to_str() -> None:
    """Test that values of type ``bytes`` are masked out by ``config_to_str``."""
    record = ConfigRecord({"a": 123, "b": [1, 2, 3], "c": b"bytes"})
    rendered = config_to_str(record)
    assert rendered == "{'a': 123, 'b': [1, 2, 3], 'c': <bytes>}"
|
40
|
+
|
41
|
+
|
42
|
+
def test_arrayrecords_aggregation() -> None:
    """Test that ArrayRecords are averaged with the per-reply weights."""
    n_replies, n_arrays = 3, 4
    weights = [0.25, 0.4, 0.35]
    per_reply_arrays = [
        [np.random.randn(7, 3) for _ in range(n_arrays)] for _ in range(n_replies)
    ]

    # Reference result from NumPy, computed one array position at a time
    expected = []
    for idx in range(n_arrays):
        stacked = [arrays[idx] for arrays in per_reply_arrays]
        expected.append(np.average(stacked, axis=0, weights=weights))

    # Build RecordDicts mimicking replies
    replies = [
        RecordDict(
            {
                "arrays": ArrayRecord(arrays),
                "metrics": MetricRecord({"weight": w}),
            }
        )
        for arrays, w in zip(per_reply_arrays, weights)
    ]

    result = aggregate_arrayrecords(replies, weighting_metric_name="weight")

    # Values must match numerically, and the records must be identical
    for got, want in zip(result.to_numpy_ndarrays(), expected):
        assert np.allclose(got, want)
    assert result.object_id == ArrayRecord(expected).object_id
|
72
|
+
|
73
|
+
|
74
|
+
def test_metricrecords_aggregation() -> None:
    """Test aggregation of MetricRecords."""
    num_replies = 3
    weights = [0.25, 0.4, 0.35]
    metric_records = [
        MetricRecord({"a": 1, "b": 2.0, "c": np.random.randn(3).tolist()})
        for _ in range(num_replies)
    ]

    # Compute expected aggregated MetricRecord.
    # For ease, we convert everything into numpy arrays, then aggregate
    as_np_entries = [
        {
            k: np.array(v) if isinstance(v, (int, float, list)) else v
            for k, v in record.items()
        }
        for record in metric_records
    ]
    avg_list = [
        np.average(
            [list(entries.values())[i] for entries in as_np_entries],
            axis=0,
            weights=weights,
        ).tolist()
        for i in range(len(as_np_entries[0]))
    ]
    expected_record = MetricRecord(dict(zip(as_np_entries[0].keys(), avg_list)))
    expected_record["a"] = float(expected_record["a"])  # type: ignore
    expected_record["b"] = float(expected_record["b"])  # type: ignore

    # Construct RecordDicts (mimicking replies) and inject the weighting factor.
    # Copy the metrics via the public mapping API rather than reaching into
    # the private `_data` attribute of MetricRecord.
    records = [
        RecordDict(
            {"metrics": MetricRecord(dict(record.items()) | {"weight": weight})}
        )
        for record, weight in zip(metric_records, weights)
    ]

    # Execute aggregate
    aggrd = aggregate_metricrecords(records, weighting_metric_name="weight")
    # Assert
    assert expected_record.object_id == aggrd.object_id
|
121
|
+
|
122
|
+
|
123
|
+
@parameterized.expand(  # type: ignore
    [
        (
            True,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"weight": 0.123}),
                }
            ),
        ),  # Compliant
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"weight": [0.123]}),
                }
            ),
        ),  # Weighting key is not a scalar (BAD)
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"loss": 0.01}),
                }
            ),
        ),  # No weighting key in MetricRecord (BAD)
        (
            False,
            RecordDict({"global-model": ArrayRecord([np.random.randn(7, 3)])}),
        ),  # No MetricRecord (BAD)
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "another-model": ArrayRecord([np.random.randn(7, 3)]),
                }
            ),
        ),  # Two ArrayRecords (BAD)
        (
            False,
            RecordDict(
                {
                    "global-model": ArrayRecord([np.random.randn(7, 3)]),
                    "metrics": MetricRecord({"weight": 0.123}),
                    "more-metrics": MetricRecord({"loss": 0.321}),
                }
            ),
        ),  # Two MetricRecords (BAD)
    ]
)
def test_consistency_of_replies_with_matching_keys(
    is_valid: bool, recorddict: RecordDict
) -> None:
    """Check validation of replies that all share one RecordDict layout."""
    # Three references to the same RecordDict: any failure stems from its layout
    replies = [recorddict] * 3

    if is_valid:
        # Must not raise
        validate_message_reply_consistency(
            replies, weighted_by_key="weight", check_arrayrecord=True
        )
        return

    with pytest.raises(InconsistentMessageReplies):
        validate_message_reply_consistency(
            replies, weighted_by_key="weight", check_arrayrecord=True
        )
|
195
|
+
|
196
|
+
|
197
|
+
@parameterized.expand(  # type: ignore
    [
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
            ],
        ),  # top-level keys don't match for ArrayRecords
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord(
                            OrderedDict({"a": Array(np.random.randn(7, 3))})
                        ),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "global-model": ArrayRecord(
                            OrderedDict({"b": Array(np.random.randn(7, 3))})
                        ),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
            ],
        ),  # top-level keys match for ArrayRecords but not those for Arrays
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "my-metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
            ],
        ),  # top-level keys don't match for MetricRecords
        (
            [
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        "metrics": MetricRecord({"weight": 0.123}),
                    }
                ),
                RecordDict(
                    {
                        "global-model": ArrayRecord([np.random.randn(7, 3)]),
                        # Same record name as the first reply, so that only the
                        # inner metric keys differ (otherwise this duplicates
                        # the previous case)
                        "metrics": MetricRecord({"my-weights": 0.123}),
                    }
                ),
            ],
        ),  # top-level keys match for MetricRecords but not inner ones
    ]
)
def test_consistency_of_replies_with_different_keys(
    list_records: list[RecordDict],
) -> None:
    """Test consistency in replies when records don't have matching keys."""
    # All test cases expect InconsistentMessageReplies exception to be raised
    with pytest.raises(InconsistentMessageReplies):
        validate_message_reply_consistency(
            list_records, weighted_by_key="weight", check_arrayrecord=True
        )
|
{flwr_nightly-1.21.0.dev20250826.dist-info → flwr_nightly-1.21.0.dev20250827.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: flwr-nightly
|
3
|
-
Version: 1.21.0.
|
3
|
+
Version: 1.21.0.dev20250827
|
4
4
|
Summary: Flower: A Friendly Federated AI Framework
|
5
5
|
License: Apache-2.0
|
6
6
|
Keywords: Artificial Intelligence,Federated AI,Federated Analytics,Federated Evaluation,Federated Learning,Flower,Machine Learning
|
{flwr_nightly-1.21.0.dev20250826.dist-info → flwr_nightly-1.21.0.dev20250827.dist-info}/RECORD
RENAMED
@@ -118,7 +118,7 @@ flwr/common/event_log_plugin/__init__.py,sha256=ts3VAL3Fk6Grp1EK_1Qg_V-BfOof9F86
|
|
118
118
|
flwr/common/event_log_plugin/event_log_plugin.py,sha256=4SkVa1Ic-sPlICJShBuggXmXDcQtWQ1KDby4kthFNF0,2064
|
119
119
|
flwr/common/exit/__init__.py,sha256=-ZOJYLaNnR729a7VzZiFsLiqngzKQh3xc27svYStZ_Q,826
|
120
120
|
flwr/common/exit/exit.py,sha256=dWZgznSZkg8Tr_Bh9jRHGUhlWk2228q5XFIK98Zr4Tc,3531
|
121
|
-
flwr/common/exit/exit_code.py,sha256=
|
121
|
+
flwr/common/exit/exit_code.py,sha256=K0QSQg5elE3xl3OHBaMu1vClAXpyxqBEo6t_weFkh7I,4910
|
122
122
|
flwr/common/exit_handlers.py,sha256=IaqJ60fXZuu7McaRYnoYKtlbH9t4Yl9goNExKqtmQbs,4304
|
123
123
|
flwr/common/grpc.py,sha256=y70hUFvXkIf3l03xOhlb7qhS6W1UJZRSZqCdB0ir0v8,10381
|
124
124
|
flwr/common/heartbeat.py,sha256=SyEpNDnmJ0lni0cWO67rcoJVKasCLmkNHm3dKLeNrLU,5749
|
@@ -320,7 +320,12 @@ flwr/server/workflow/default_workflows.py,sha256=RlD26dXbSksY-23f3ZspnN1YU1DOhDY
|
|
320
320
|
flwr/server/workflow/secure_aggregation/__init__.py,sha256=vGkycLb65CxdaMkKsANxQE6AS4urfZKvwcS3r1Vln_c,880
|
321
321
|
flwr/server/workflow/secure_aggregation/secagg_workflow.py,sha256=b_pKk7gmbahwyj0ftOOLXvu-AMtRHEc82N9PJTEO8dc,5839
|
322
322
|
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py,sha256=DkayCsnlAya6Y2PZsueLgoUCMRtV-GbnW08RfWx_SXM,29460
|
323
|
-
flwr/serverapp/__init__.py,sha256=
|
323
|
+
flwr/serverapp/__init__.py,sha256=SRPsqsa4pOfcF9J3_i1hb9KJi3z4KDTTCqCTwv7DcK0,864
|
324
|
+
flwr/serverapp/fedavg.py,sha256=Z051Z3XBYmaMzIKRn5uSlqb9FrRTUAXxuoMurMZn3PE,10861
|
325
|
+
flwr/serverapp/result.py,sha256=rw1ZoCGBosSVSNrTLLUFMxP1XzDwJWWsn1qdBR7JtlI,1229
|
326
|
+
flwr/serverapp/strategy.py,sha256=1mxxtA5Pyg9lZ1d3g4OCL-m8YR_0E3HUGl8Gv5BGOXY,10982
|
327
|
+
flwr/serverapp/strategy_utils.py,sha256=P2DO3pcrDTDYcrjkmYuL79Bbv2boj7T4bZ42EeRTyYk,9412
|
328
|
+
flwr/serverapp/strategy_utils_tests.py,sha256=taG6HwApwutkjUuMY3R8Ib48Xepw6g5xl9HEB_-leoY,9232
|
324
329
|
flwr/simulation/__init__.py,sha256=Gg6OsP1Z-ixc3-xxzvl7j7rz2Fijy9rzyEPpxgAQCeM,1556
|
325
330
|
flwr/simulation/app.py,sha256=LbGLMvN9Ap119yBqsUcNNmVLRnCySnr4VechqcQ1hpA,10401
|
326
331
|
flwr/simulation/legacy_app.py,sha256=nMISQqW0otJL1-2Kfd94O6BLlGS2IEmEPKTM2WGKrIs,15861
|
@@ -380,7 +385,7 @@ flwr/supernode/servicer/__init__.py,sha256=lucTzre5WPK7G1YLCfaqg3rbFWdNSb7ZTt-ca
|
|
380
385
|
flwr/supernode/servicer/clientappio/__init__.py,sha256=7Oy62Y_oijqF7Dxi6tpcUQyOpLc_QpIRZ83NvwmB0Yg,813
|
381
386
|
flwr/supernode/servicer/clientappio/clientappio_servicer.py,sha256=nIHRu38EWK-rpNOkcgBRAAKwYQQWFeCwu0lkO7OPZGQ,10239
|
382
387
|
flwr/supernode/start_client_internal.py,sha256=ftS8GOyT9M1tOWpbobN_Xrz4xwPAPOvsTGiWSfzhheE,20269
|
383
|
-
flwr_nightly-1.21.0.
|
384
|
-
flwr_nightly-1.21.0.
|
385
|
-
flwr_nightly-1.21.0.
|
386
|
-
flwr_nightly-1.21.0.
|
388
|
+
flwr_nightly-1.21.0.dev20250827.dist-info/METADATA,sha256=bBuFboanPjg8v12eYhwPbkoGwCxFJbwgOnTUsaZG2sU,15967
|
389
|
+
flwr_nightly-1.21.0.dev20250827.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
390
|
+
flwr_nightly-1.21.0.dev20250827.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
|
391
|
+
flwr_nightly-1.21.0.dev20250827.dist-info/RECORD,,
|
{flwr_nightly-1.21.0.dev20250826.dist-info → flwr_nightly-1.21.0.dev20250827.dist-info}/WHEEL
RENAMED
File without changes
|
File without changes
|