flwr-nightly 1.22.0.dev20250920__py3-none-any.whl → 1.23.0.dev20250922__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/clientapp/mod/__init__.py +2 -0
- flwr/clientapp/mod/centraldp_mods.py +6 -6
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/serverapp/strategy/__init__.py +4 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/fedmedian.py +34 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +36 -1
- flwr/serverapp/strategy/fedxgb_cyclic.py +1 -1
- flwr/serverapp/strategy/krum.py +7 -125
- flwr/serverapp/strategy/multikrum.py +247 -0
- {flwr_nightly-1.22.0.dev20250920.dist-info → flwr_nightly-1.23.0.dev20250922.dist-info}/METADATA +1 -1
- {flwr_nightly-1.22.0.dev20250920.dist-info → flwr_nightly-1.23.0.dev20250922.dist-info}/RECORD +25 -22
- {flwr_nightly-1.22.0.dev20250920.dist-info → flwr_nightly-1.23.0.dev20250922.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.22.0.dev20250920.dist-info → flwr_nightly-1.23.0.dev20250922.dist-info}/entry_points.txt +0 -0
flwr/clientapp/mod/__init__.py
CHANGED
@@ -18,8 +18,10 @@
|
|
18
18
|
from flwr.client.mod.comms_mods import arrays_size_mod, message_size_mod
|
19
19
|
|
20
20
|
from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
|
21
|
+
from .localdp_mod import LocalDpMod
|
21
22
|
|
22
23
|
__all__ = [
|
24
|
+
"LocalDpMod",
|
23
25
|
"adaptiveclipping_mod",
|
24
26
|
"arrays_size_mod",
|
25
27
|
"fixedclipping_mod",
|
flwr/clientapp/mod/centraldp_mods.py
CHANGED
@@ -89,7 +89,7 @@ def fixedclipping_mod(
|
|
89
89
|
iter(out_msg.content.array_records.items())
|
90
90
|
)
|
91
91
|
# Ensure keys in returned ArrayRecord match those in the one sent from server
|
92
|
-
if
|
92
|
+
if list(original_array_record.keys()) != list(client_to_server_arrecord.keys()):
|
93
93
|
return _handle_array_key_mismatch_err("fixedclipping_mod", out_msg)
|
94
94
|
|
95
95
|
client_to_server_ndarrays = client_to_server_arrecord.to_numpy_ndarrays()
|
@@ -158,6 +158,10 @@ def adaptiveclipping_mod(
|
|
158
158
|
# Call inner app
|
159
159
|
out_msg = call_next(msg, ctxt)
|
160
160
|
|
161
|
+
# Check if the msg has error
|
162
|
+
if out_msg.has_error():
|
163
|
+
return out_msg
|
164
|
+
|
161
165
|
# Ensure reply has a single ArrayRecord
|
162
166
|
if len(out_msg.content.array_records) != 1:
|
163
167
|
return _handle_multi_record_err("adaptiveclipping_mod", out_msg, ArrayRecord)
|
@@ -166,16 +170,12 @@ def adaptiveclipping_mod(
|
|
166
170
|
if len(out_msg.content.metric_records) != 1:
|
167
171
|
return _handle_multi_record_err("adaptiveclipping_mod", out_msg, MetricRecord)
|
168
172
|
|
169
|
-
# Check if the msg has error
|
170
|
-
if out_msg.has_error():
|
171
|
-
return out_msg
|
172
|
-
|
173
173
|
new_array_record_key, client_to_server_arrecord = next(
|
174
174
|
iter(out_msg.content.array_records.items())
|
175
175
|
)
|
176
176
|
|
177
177
|
# Ensure keys in returned ArrayRecord match those in the one sent from server
|
178
|
-
if
|
178
|
+
if list(original_array_record.keys()) != list(client_to_server_arrecord.keys()):
|
179
179
|
return _handle_array_key_mismatch_err("adaptiveclipping_mod", out_msg)
|
180
180
|
|
181
181
|
client_to_server_ndarrays = client_to_server_arrecord.to_numpy_ndarrays()
|
flwr/clientapp/mod/localdp_mod.py
ADDED
@@ -0,0 +1,169 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Local DP modifier."""
|
16
|
+
|
17
|
+
|
18
|
+
from collections import OrderedDict
|
19
|
+
from logging import INFO
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
|
23
|
+
from flwr.clientapp.typing import ClientAppCallable
|
24
|
+
from flwr.common import Array, ArrayRecord
|
25
|
+
from flwr.common.context import Context
|
26
|
+
from flwr.common.differential_privacy import (
|
27
|
+
add_gaussian_noise_inplace,
|
28
|
+
compute_clip_model_update,
|
29
|
+
)
|
30
|
+
from flwr.common.logger import log
|
31
|
+
from flwr.common.message import Message
|
32
|
+
|
33
|
+
from .centraldp_mods import _handle_array_key_mismatch_err, _handle_multi_record_err
|
34
|
+
|
35
|
+
|
36
|
+
class LocalDpMod:
|
37
|
+
"""Modifier for local differential privacy.
|
38
|
+
|
39
|
+
This mod clips the client model updates and
|
40
|
+
adds noise to the params before sending them to the server.
|
41
|
+
|
42
|
+
It operates on messages of type `MessageType.TRAIN`.
|
43
|
+
|
44
|
+
Parameters
|
45
|
+
----------
|
46
|
+
clipping_norm : float
|
47
|
+
The value of the clipping norm.
|
48
|
+
sensitivity : float
|
49
|
+
The sensitivity of the client model.
|
50
|
+
epsilon : float
|
51
|
+
The privacy budget.
|
52
|
+
Smaller value of epsilon indicates a higher level of privacy protection.
|
53
|
+
delta : float
|
54
|
+
The failure probability.
|
55
|
+
The probability that the privacy mechanism
|
56
|
+
fails to provide the desired level of privacy.
|
57
|
+
A smaller value of delta indicates a stricter privacy guarantee.
|
58
|
+
|
59
|
+
Examples
|
60
|
+
--------
|
61
|
+
Create an instance of the local DP mod and add it to the client-side mods::
|
62
|
+
|
63
|
+
local_dp_mod = LocalDpMod( ... )
|
64
|
+
app = fl.client.ClientApp(mods=[local_dp_mod])
|
65
|
+
"""
|
66
|
+
|
67
|
+
def __init__(
|
68
|
+
self, clipping_norm: float, sensitivity: float, epsilon: float, delta: float
|
69
|
+
) -> None:
|
70
|
+
if clipping_norm <= 0:
|
71
|
+
raise ValueError("The clipping norm should be a positive value.")
|
72
|
+
|
73
|
+
if sensitivity < 0:
|
74
|
+
raise ValueError("The sensitivity should be a non-negative value.")
|
75
|
+
|
76
|
+
if epsilon < 0:
|
77
|
+
raise ValueError("Epsilon should be a non-negative value.")
|
78
|
+
|
79
|
+
if delta < 0:
|
80
|
+
raise ValueError("Delta should be a non-negative value.")
|
81
|
+
|
82
|
+
self.clipping_norm = clipping_norm
|
83
|
+
self.sensitivity = sensitivity
|
84
|
+
self.epsilon = epsilon
|
85
|
+
self.delta = delta
|
86
|
+
|
87
|
+
def __call__(
|
88
|
+
self, msg: Message, ctxt: Context, call_next: ClientAppCallable
|
89
|
+
) -> Message:
|
90
|
+
"""Perform local DP on the client model parameters.
|
91
|
+
|
92
|
+
Parameters
|
93
|
+
----------
|
94
|
+
msg : Message
|
95
|
+
The message received from the ServerApp.
|
96
|
+
ctxt : Context
|
97
|
+
The context of the ClientApp.
|
98
|
+
call_next : ClientAppCallable
|
99
|
+
The callable to call the next mod (or the ClientApp) in the chain.
|
100
|
+
|
101
|
+
Returns
|
102
|
+
-------
|
103
|
+
Message
|
104
|
+
The modified message to be sent back to the server.
|
105
|
+
"""
|
106
|
+
if len(msg.content.array_records) != 1:
|
107
|
+
return _handle_multi_record_err("LocalDpMod", msg, ArrayRecord)
|
108
|
+
|
109
|
+
# Record array record communicated to client and clipping norm
|
110
|
+
original_array_record = next(iter(msg.content.array_records.values()))
|
111
|
+
|
112
|
+
# Call inner app
|
113
|
+
out_msg = call_next(msg, ctxt)
|
114
|
+
|
115
|
+
# Check if the msg has error
|
116
|
+
if out_msg.has_error():
|
117
|
+
return out_msg
|
118
|
+
|
119
|
+
# Ensure reply has a single ArrayRecord
|
120
|
+
if len(out_msg.content.array_records) != 1:
|
121
|
+
return _handle_multi_record_err("LocalDpMod", out_msg, ArrayRecord)
|
122
|
+
|
123
|
+
new_array_record_key, client_to_server_arrecord = next(
|
124
|
+
iter(out_msg.content.array_records.items())
|
125
|
+
)
|
126
|
+
|
127
|
+
# Ensure keys in returned ArrayRecord match those in the one sent from server
|
128
|
+
if list(original_array_record.keys()) != list(client_to_server_arrecord.keys()):
|
129
|
+
return _handle_array_key_mismatch_err("LocalDpMod", out_msg)
|
130
|
+
|
131
|
+
client_to_server_ndarrays = client_to_server_arrecord.to_numpy_ndarrays()
|
132
|
+
|
133
|
+
# Clip the client update
|
134
|
+
compute_clip_model_update(
|
135
|
+
client_to_server_ndarrays,
|
136
|
+
original_array_record.to_numpy_ndarrays(),
|
137
|
+
self.clipping_norm,
|
138
|
+
)
|
139
|
+
log(
|
140
|
+
INFO,
|
141
|
+
"LocalDpMod: parameters are clipped by value: %.4f.",
|
142
|
+
self.clipping_norm,
|
143
|
+
)
|
144
|
+
|
145
|
+
std_dev = (
|
146
|
+
self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon
|
147
|
+
)
|
148
|
+
add_gaussian_noise_inplace(
|
149
|
+
client_to_server_ndarrays,
|
150
|
+
std_dev,
|
151
|
+
)
|
152
|
+
log(
|
153
|
+
INFO,
|
154
|
+
"LocalDpMod: local DP noise with %.4f stddev added to parameters",
|
155
|
+
std_dev,
|
156
|
+
)
|
157
|
+
|
158
|
+
# Replace outgoing ArrayRecord's Array while preserving their keys
|
159
|
+
out_msg.content[new_array_record_key] = ArrayRecord(
|
160
|
+
OrderedDict(
|
161
|
+
{
|
162
|
+
k: Array(v)
|
163
|
+
for k, v in zip(
|
164
|
+
client_to_server_arrecord.keys(), client_to_server_ndarrays
|
165
|
+
)
|
166
|
+
}
|
167
|
+
)
|
168
|
+
)
|
169
|
+
return out_msg
|
flwr/serverapp/strategy/__init__.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
"""ServerApp strategies."""
|
16
16
|
|
17
17
|
|
18
|
+
from .bulyan import Bulyan
|
18
19
|
from .dp_adaptive_clipping import (
|
19
20
|
DifferentialPrivacyClientSideAdaptiveClipping,
|
20
21
|
DifferentialPrivacyServerSideAdaptiveClipping,
|
@@ -34,11 +35,13 @@ from .fedxgb_bagging import FedXgbBagging
|
|
34
35
|
from .fedxgb_cyclic import FedXgbCyclic
|
35
36
|
from .fedyogi import FedYogi
|
36
37
|
from .krum import Krum
|
38
|
+
from .multikrum import MultiKrum
|
37
39
|
from .qfedavg import QFedAvg
|
38
40
|
from .result import Result
|
39
41
|
from .strategy import Strategy
|
40
42
|
|
41
43
|
__all__ = [
|
44
|
+
"Bulyan",
|
42
45
|
"DifferentialPrivacyClientSideAdaptiveClipping",
|
43
46
|
"DifferentialPrivacyClientSideFixedClipping",
|
44
47
|
"DifferentialPrivacyServerSideAdaptiveClipping",
|
@@ -54,6 +57,7 @@ __all__ = [
|
|
54
57
|
"FedXgbCyclic",
|
55
58
|
"FedYogi",
|
56
59
|
"Krum",
|
60
|
+
"MultiKrum",
|
57
61
|
"QFedAvg",
|
58
62
|
"Result",
|
59
63
|
"Strategy",
|
flwr/serverapp/strategy/bulyan.py
ADDED
@@ -0,0 +1,238 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Bulyan [El Mhamdi et al., 2018] strategy.
|
16
|
+
|
17
|
+
Paper: arxiv.org/abs/1802.07927
|
18
|
+
"""
|
19
|
+
|
20
|
+
|
21
|
+
from collections import OrderedDict
|
22
|
+
from collections.abc import Iterable
|
23
|
+
from logging import INFO, WARN
|
24
|
+
from typing import Callable, Optional, cast
|
25
|
+
|
26
|
+
import numpy as np
|
27
|
+
|
28
|
+
from flwr.common import (
|
29
|
+
Array,
|
30
|
+
ArrayRecord,
|
31
|
+
Message,
|
32
|
+
MetricRecord,
|
33
|
+
NDArrays,
|
34
|
+
RecordDict,
|
35
|
+
log,
|
36
|
+
)
|
37
|
+
|
38
|
+
from .fedavg import FedAvg
|
39
|
+
from .multikrum import select_multikrum
|
40
|
+
|
41
|
+
|
42
|
+
# pylint: disable=too-many-instance-attributes
|
43
|
+
class Bulyan(FedAvg):
|
44
|
+
"""Bulyan strategy.
|
45
|
+
|
46
|
+
Implementation based on https://arxiv.org/abs/1802.07927.
|
47
|
+
|
48
|
+
Parameters
|
49
|
+
----------
|
50
|
+
fraction_train : float (default: 1.0)
|
51
|
+
Fraction of nodes used during training. In case `min_train_nodes`
|
52
|
+
is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
|
53
|
+
will still be sampled.
|
54
|
+
fraction_evaluate : float (default: 1.0)
|
55
|
+
Fraction of nodes used during validation. In case `min_evaluate_nodes`
|
56
|
+
is larger than `fraction_evaluate * total_connected_nodes`,
|
57
|
+
`min_evaluate_nodes` will still be sampled.
|
58
|
+
min_train_nodes : int (default: 2)
|
59
|
+
Minimum number of nodes used during training.
|
60
|
+
min_evaluate_nodes : int (default: 2)
|
61
|
+
Minimum number of nodes used during validation.
|
62
|
+
min_available_nodes : int (default: 2)
|
63
|
+
Minimum number of total nodes in the system.
|
64
|
+
num_malicious_nodes : int (default: 0)
|
65
|
+
Number of malicious nodes in the system.
|
66
|
+
weighted_by_key : str (default: "num-examples")
|
67
|
+
The key within each MetricRecord whose value is used as the weight when
|
68
|
+
computing weighted averages for MetricRecords.
|
69
|
+
arrayrecord_key : str (default: "arrays")
|
70
|
+
Key used to store the ArrayRecord when constructing Messages.
|
71
|
+
configrecord_key : str (default: "config")
|
72
|
+
Key used to store the ConfigRecord when constructing Messages.
|
73
|
+
train_metrics_aggr_fn : Optional[callable] (default: None)
|
74
|
+
Function with signature (list[RecordDict], str) -> MetricRecord,
|
75
|
+
used to aggregate MetricRecords from training round replies.
|
76
|
+
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
|
77
|
+
average using the provided weight factor key.
|
78
|
+
evaluate_metrics_aggr_fn : Optional[callable] (default: None)
|
79
|
+
Function with signature (list[RecordDict], str) -> MetricRecord,
|
80
|
+
used to aggregate MetricRecords from training round replies.
|
81
|
+
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
|
82
|
+
average using the provided weight factor key.
|
83
|
+
selection_rule : Optional[Callable] (default: None)
|
84
|
+
Function with signature (list[RecordDict], int, int) -> list[RecordDict].
|
85
|
+
The inputs are:
|
86
|
+
- a list of contents from reply messages,
|
87
|
+
- the assumed number of malicious nodes (`num_malicious_nodes`),
|
88
|
+
- the number of nodes to select (`num_nodes_to_select`).
|
89
|
+
|
90
|
+
The function should implement a Byzantine-resilient selection rule that
|
91
|
+
serves as the first step of Bulyan. If None, defaults to `select_multikrum`,
|
92
|
+
which selects nodes according to the Multi-Krum algorithm.
|
93
|
+
"""
|
94
|
+
|
95
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
96
|
+
def __init__(
|
97
|
+
self,
|
98
|
+
fraction_train: float = 1.0,
|
99
|
+
fraction_evaluate: float = 1.0,
|
100
|
+
min_train_nodes: int = 2,
|
101
|
+
min_evaluate_nodes: int = 2,
|
102
|
+
min_available_nodes: int = 2,
|
103
|
+
num_malicious_nodes: int = 0,
|
104
|
+
weighted_by_key: str = "num-examples",
|
105
|
+
arrayrecord_key: str = "arrays",
|
106
|
+
configrecord_key: str = "config",
|
107
|
+
train_metrics_aggr_fn: Optional[
|
108
|
+
Callable[[list[RecordDict], str], MetricRecord]
|
109
|
+
] = None,
|
110
|
+
evaluate_metrics_aggr_fn: Optional[
|
111
|
+
Callable[[list[RecordDict], str], MetricRecord]
|
112
|
+
] = None,
|
113
|
+
selection_rule: Optional[
|
114
|
+
Callable[[list[RecordDict], int, int], list[RecordDict]]
|
115
|
+
] = None,
|
116
|
+
) -> None:
|
117
|
+
super().__init__(
|
118
|
+
fraction_train=fraction_train,
|
119
|
+
fraction_evaluate=fraction_evaluate,
|
120
|
+
min_train_nodes=min_train_nodes,
|
121
|
+
min_evaluate_nodes=min_evaluate_nodes,
|
122
|
+
min_available_nodes=min_available_nodes,
|
123
|
+
weighted_by_key=weighted_by_key,
|
124
|
+
arrayrecord_key=arrayrecord_key,
|
125
|
+
configrecord_key=configrecord_key,
|
126
|
+
train_metrics_aggr_fn=train_metrics_aggr_fn,
|
127
|
+
evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
|
128
|
+
)
|
129
|
+
self.num_malicious_nodes = num_malicious_nodes
|
130
|
+
self.selection_rule = selection_rule or select_multikrum
|
131
|
+
|
132
|
+
def summary(self) -> None:
|
133
|
+
"""Log summary configuration of the strategy."""
|
134
|
+
log(INFO, "\t├──> Bulyan settings:")
|
135
|
+
log(INFO, "\t│\t├── Number of malicious nodes: %d", self.num_malicious_nodes)
|
136
|
+
log(INFO, "\t│\t└── Selection rule: %s", self.selection_rule.__name__)
|
137
|
+
super().summary()
|
138
|
+
|
139
|
+
def aggregate_train(
|
140
|
+
self,
|
141
|
+
server_round: int,
|
142
|
+
replies: Iterable[Message],
|
143
|
+
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
|
144
|
+
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
145
|
+
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
146
|
+
|
147
|
+
# Check if sufficient replies have been received
|
148
|
+
if len(valid_replies) < 4 * self.num_malicious_nodes + 3:
|
149
|
+
log(
|
150
|
+
WARN,
|
151
|
+
"Insufficient replies, skipping Bulyan aggregation: "
|
152
|
+
"Required at least %d (4*num_malicious_nodes + 3), but received %d.",
|
153
|
+
4 * self.num_malicious_nodes + 3,
|
154
|
+
len(valid_replies),
|
155
|
+
)
|
156
|
+
return None, None
|
157
|
+
|
158
|
+
reply_contents = [msg.content for msg in valid_replies]
|
159
|
+
|
160
|
+
# Compute theta and beta
|
161
|
+
theta = len(valid_replies) - 2 * self.num_malicious_nodes
|
162
|
+
beta = theta - 2 * self.num_malicious_nodes
|
163
|
+
|
164
|
+
# Byzantine-resilient selection rule
|
165
|
+
selected_contents = self.selection_rule(
|
166
|
+
reply_contents, self.num_malicious_nodes, theta
|
167
|
+
)
|
168
|
+
|
169
|
+
# Convert each ArrayRecord to a list of NDArray for easier computation
|
170
|
+
key = list(selected_contents[0].array_records.keys())[0]
|
171
|
+
array_keys = list(selected_contents[0][key].keys())
|
172
|
+
selected_ndarrays = [
|
173
|
+
cast(ArrayRecord, ctnt[key]).to_numpy_ndarrays(keep_input=False)
|
174
|
+
for ctnt in selected_contents
|
175
|
+
]
|
176
|
+
|
177
|
+
# Compute median
|
178
|
+
median_ndarrays = [np.median(arr, axis=0) for arr in zip(*selected_ndarrays)]
|
179
|
+
|
180
|
+
# Aggregate the beta closest weights element-wise
|
181
|
+
aggregated_ndarrays = aggregate_n_closest_weights(
|
182
|
+
median_ndarrays, selected_ndarrays, beta
|
183
|
+
)
|
184
|
+
|
185
|
+
# Convert to ArrayRecord
|
186
|
+
arrays = ArrayRecord(
|
187
|
+
OrderedDict(zip(array_keys, map(Array, aggregated_ndarrays)))
|
188
|
+
)
|
189
|
+
|
190
|
+
# Aggregate MetricRecords
|
191
|
+
metrics = self.train_metrics_aggr_fn(
|
192
|
+
selected_contents,
|
193
|
+
self.weighted_by_key,
|
194
|
+
)
|
195
|
+
return arrays, metrics
|
196
|
+
|
197
|
+
|
198
|
+
def aggregate_n_closest_weights(
|
199
|
+
ref_weights: NDArrays, weights_list: list[NDArrays], beta: int
|
200
|
+
) -> NDArrays:
|
201
|
+
"""Compute the element-wise mean of the `beta` closest weight arrays.
|
202
|
+
|
203
|
+
For each element (i-th coordinate), the output is the average of the
|
204
|
+
`beta` weight arrays that are closest to the reference weights.
|
205
|
+
|
206
|
+
Parameters
|
207
|
+
----------
|
208
|
+
ref_weights : NDArrays
|
209
|
+
Reference weights used to compute distances.
|
210
|
+
weights_list : list[NDArrays]
|
211
|
+
List of weight arrays (e.g., from selected nodes).
|
212
|
+
beta : int
|
213
|
+
Number of closest weight arrays to include in the averaging.
|
214
|
+
|
215
|
+
Returns
|
216
|
+
-------
|
217
|
+
aggregated_weights : NDArrays
|
218
|
+
Element-wise average of the `beta` closest weight arrays to the
|
219
|
+
reference weights.
|
220
|
+
"""
|
221
|
+
aggregated_weights = []
|
222
|
+
for layer_id, ref_layer in enumerate(ref_weights):
|
223
|
+
# Shape: (n_models, *layer_shape)
|
224
|
+
layer_stack = np.stack([weights[layer_id] for weights in weights_list])
|
225
|
+
|
226
|
+
# Compute absolute differences: shape (n_models, *layer_shape)
|
227
|
+
diffs = np.abs(layer_stack - ref_layer)
|
228
|
+
|
229
|
+
# Find indices of `beta` smallest per coordinate
|
230
|
+
idx = np.argpartition(diffs, beta - 1, axis=0)[:beta]
|
231
|
+
|
232
|
+
# Gather the closest weights
|
233
|
+
closest = np.take_along_axis(layer_stack, idx, axis=0)
|
234
|
+
|
235
|
+
# Average them
|
236
|
+
aggregated_weights.append(np.mean(closest, axis=0))
|
237
|
+
|
238
|
+
return aggregated_weights
|
flwr/serverapp/strategy/fedmedian.py
CHANGED
@@ -32,6 +32,40 @@ class FedMedian(FedAvg):
|
|
32
32
|
"""Federated Median (FedMedian) strategy.
|
33
33
|
|
34
34
|
Implementation based on https://arxiv.org/pdf/1803.01498v1
|
35
|
+
|
36
|
+
Parameters
|
37
|
+
----------
|
38
|
+
fraction_train : float (default: 1.0)
|
39
|
+
Fraction of nodes used during training. In case `min_train_nodes`
|
40
|
+
is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
|
41
|
+
will still be sampled.
|
42
|
+
fraction_evaluate : float (default: 1.0)
|
43
|
+
Fraction of nodes used during validation. In case `min_evaluate_nodes`
|
44
|
+
is larger than `fraction_evaluate * total_connected_nodes`,
|
45
|
+
`min_evaluate_nodes` will still be sampled.
|
46
|
+
min_train_nodes : int (default: 2)
|
47
|
+
Minimum number of nodes used during training.
|
48
|
+
min_evaluate_nodes : int (default: 2)
|
49
|
+
Minimum number of nodes used during validation.
|
50
|
+
min_available_nodes : int (default: 2)
|
51
|
+
Minimum number of total nodes in the system.
|
52
|
+
weighted_by_key : str (default: "num-examples")
|
53
|
+
The key within each MetricRecord whose value is used as the weight when
|
54
|
+
computing weighted averages for MetricRecords.
|
55
|
+
arrayrecord_key : str (default: "arrays")
|
56
|
+
Key used to store the ArrayRecord when constructing Messages.
|
57
|
+
configrecord_key : str (default: "config")
|
58
|
+
Key used to store the ConfigRecord when constructing Messages.
|
59
|
+
train_metrics_aggr_fn : Optional[callable] (default: None)
|
60
|
+
Function with signature (list[RecordDict], str) -> MetricRecord,
|
61
|
+
used to aggregate MetricRecords from training round replies.
|
62
|
+
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
|
63
|
+
average using the provided weight factor key.
|
64
|
+
evaluate_metrics_aggr_fn : Optional[callable] (default: None)
|
65
|
+
Function with signature (list[RecordDict], str) -> MetricRecord,
|
66
|
+
used to aggregate MetricRecords from training round replies.
|
67
|
+
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
|
68
|
+
average using the provided weight factor key.
|
35
69
|
"""
|
36
70
|
|
37
71
|
def aggregate_train(
|
flwr/serverapp/strategy/fedxgb_bagging.py
CHANGED
@@ -28,7 +28,42 @@ from .strategy_utils import aggregate_bagging
|
|
28
28
|
|
29
29
|
# pylint: disable=line-too-long
|
30
30
|
class FedXgbBagging(FedAvg):
|
31
|
-
"""Configurable FedXgbBagging strategy implementation.
|
31
|
+
"""Configurable FedXgbBagging strategy implementation.
|
32
|
+
|
33
|
+
Parameters
|
34
|
+
----------
|
35
|
+
fraction_train : float (default: 1.0)
|
36
|
+
Fraction of nodes used during training. In case `min_train_nodes`
|
37
|
+
is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
|
38
|
+
will still be sampled.
|
39
|
+
fraction_evaluate : float (default: 1.0)
|
40
|
+
Fraction of nodes used during validation. In case `min_evaluate_nodes`
|
41
|
+
is larger than `fraction_evaluate * total_connected_nodes`,
|
42
|
+
`min_evaluate_nodes` will still be sampled.
|
43
|
+
min_train_nodes : int (default: 2)
|
44
|
+
Minimum number of nodes used during training.
|
45
|
+
min_evaluate_nodes : int (default: 2)
|
46
|
+
Minimum number of nodes used during validation.
|
47
|
+
min_available_nodes : int (default: 2)
|
48
|
+
Minimum number of total nodes in the system.
|
49
|
+
weighted_by_key : str (default: "num-examples")
|
50
|
+
The key within each MetricRecord whose value is used as the weight when
|
51
|
+
computing weighted averages for MetricRecords.
|
52
|
+
arrayrecord_key : str (default: "arrays")
|
53
|
+
Key used to store the ArrayRecord when constructing Messages.
|
54
|
+
configrecord_key : str (default: "config")
|
55
|
+
Key used to store the ConfigRecord when constructing Messages.
|
56
|
+
train_metrics_aggr_fn : Optional[callable] (default: None)
|
57
|
+
Function with signature (list[RecordDict], str) -> MetricRecord,
|
58
|
+
used to aggregate MetricRecords from training round replies.
|
59
|
+
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
|
60
|
+
average using the provided weight factor key.
|
61
|
+
evaluate_metrics_aggr_fn : Optional[callable] (default: None)
|
62
|
+
Function with signature (list[RecordDict], str) -> MetricRecord,
|
63
|
+
used to aggregate MetricRecords from training round replies.
|
64
|
+
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
|
65
|
+
average using the provided weight factor key.
|
66
|
+
"""
|
32
67
|
|
33
68
|
current_bst: Optional[bytes] = None
|
34
69
|
|
flwr/serverapp/strategy/fedxgb_cyclic.py
CHANGED
@@ -52,7 +52,7 @@ class FedXgbCyclic(FedAvg):
|
|
52
52
|
Minimum number of total nodes in the system.
|
53
53
|
weighted_by_key : str (default: "num-examples")
|
54
54
|
The key within each MetricRecord whose value is used as the weight when
|
55
|
-
computing weighted averages for
|
55
|
+
computing weighted averages for MetricRecords.
|
56
56
|
arrayrecord_key : str (default: "arrays")
|
57
57
|
Key used to store the ArrayRecord when constructing Messages.
|
58
58
|
configrecord_key : str (default: "config")
|
flwr/serverapp/strategy/krum.py
CHANGED
@@ -20,20 +20,16 @@ Paper: proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-P
|
|
20
20
|
"""
|
21
21
|
|
22
22
|
|
23
|
-
from collections.abc import Iterable
|
24
23
|
from logging import INFO
|
25
24
|
from typing import Callable, Optional
|
26
25
|
|
27
|
-
import
|
26
|
+
from flwr.common import MetricRecord, RecordDict, log
|
28
27
|
|
29
|
-
from
|
30
|
-
|
31
|
-
from .fedavg import FedAvg
|
32
|
-
from .strategy_utils import aggregate_arrayrecords
|
28
|
+
from .multikrum import MultiKrum
|
33
29
|
|
34
30
|
|
35
31
|
# pylint: disable=too-many-instance-attributes
|
36
|
-
class Krum(
|
32
|
+
class Krum(MultiKrum):
|
37
33
|
"""Krum [Blanchard et al., 2017] strategy.
|
38
34
|
|
39
35
|
Implementation based on https://arxiv.org/abs/1703.02757
|
@@ -56,12 +52,9 @@ class Krum(FedAvg):
|
|
56
52
|
Minimum number of total nodes in the system.
|
57
53
|
num_malicious_nodes : int (default: 0)
|
58
54
|
Number of malicious nodes in the system. Defaults to 0.
|
59
|
-
num_nodes_to_keep : int (default: 0)
|
60
|
-
Number of nodes to keep before averaging (MultiKrum). Defaults to 0, in
|
61
|
-
that case classical Krum is applied.
|
62
55
|
weighted_by_key : str (default: "num-examples")
|
63
56
|
The key within each MetricRecord whose value is used as the weight when
|
64
|
-
computing weighted averages for
|
57
|
+
computing weighted averages for MetricRecords.
|
65
58
|
arrayrecord_key : str (default: "arrays")
|
66
59
|
Key used to store the ArrayRecord when constructing Messages.
|
67
60
|
configrecord_key : str (default: "config")
|
@@ -87,7 +80,6 @@ class Krum(FedAvg):
|
|
87
80
|
min_evaluate_nodes: int = 2,
|
88
81
|
min_available_nodes: int = 2,
|
89
82
|
num_malicious_nodes: int = 0,
|
90
|
-
num_nodes_to_keep: int = 0,
|
91
83
|
weighted_by_key: str = "num-examples",
|
92
84
|
arrayrecord_key: str = "arrays",
|
93
85
|
configrecord_key: str = "config",
|
@@ -105,126 +97,16 @@ class Krum(FedAvg):
|
|
105
97
|
min_evaluate_nodes=min_evaluate_nodes,
|
106
98
|
min_available_nodes=min_available_nodes,
|
107
99
|
weighted_by_key=weighted_by_key,
|
100
|
+
num_malicious_nodes=num_malicious_nodes,
|
101
|
+
num_nodes_to_select=1, # Krum selects 1 node
|
108
102
|
arrayrecord_key=arrayrecord_key,
|
109
103
|
configrecord_key=configrecord_key,
|
110
104
|
train_metrics_aggr_fn=train_metrics_aggr_fn,
|
111
105
|
evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
|
112
106
|
)
|
113
|
-
self.num_malicious_nodes = num_malicious_nodes
|
114
|
-
self.num_nodes_to_keep = num_nodes_to_keep
|
115
107
|
|
116
108
|
def summary(self) -> None:
|
117
109
|
"""Log summary configuration of the strategy."""
|
118
110
|
log(INFO, "\t├──> Krum settings:")
|
119
|
-
log(INFO, "\t│\t
|
120
|
-
log(INFO, "\t│\t└── Number of nodes to keep: %d", self.num_nodes_to_keep)
|
111
|
+
log(INFO, "\t│\t└── Number of malicious nodes: %d", self.num_malicious_nodes)
|
121
112
|
super().summary()
|
122
|
-
|
123
|
-
def _compute_distances(self, records: list[ArrayRecord]) -> NDArray:
|
124
|
-
"""Compute distances between ArrayRecords.
|
125
|
-
|
126
|
-
Parameters
|
127
|
-
----------
|
128
|
-
records : list[ArrayRecord]
|
129
|
-
A list of ArrayRecords (arrays received in replies)
|
130
|
-
|
131
|
-
Returns
|
132
|
-
-------
|
133
|
-
NDArray
|
134
|
-
A 2D array representing the distance matrix of squared distances
|
135
|
-
between input ArrayRecords
|
136
|
-
"""
|
137
|
-
flat_w = np.array(
|
138
|
-
[
|
139
|
-
np.concatenate(rec.to_numpy_ndarrays(), axis=None).ravel()
|
140
|
-
for rec in records
|
141
|
-
]
|
142
|
-
)
|
143
|
-
distance_matrix = np.zeros((len(records), len(records)))
|
144
|
-
for i, flat_w_i in enumerate(flat_w):
|
145
|
-
for j, flat_w_j in enumerate(flat_w):
|
146
|
-
delta = flat_w_i - flat_w_j
|
147
|
-
norm = np.linalg.norm(delta)
|
148
|
-
distance_matrix[i, j] = norm**2
|
149
|
-
return distance_matrix
|
150
|
-
|
151
|
-
def _krum(self, replies: list[RecordDict]) -> list[RecordDict]:
|
152
|
-
"""Select the set of RecordDicts to aggregate using the Krum or MultiKrum
|
153
|
-
algorithm.
|
154
|
-
|
155
|
-
For each node, computes the sum of squared distances to its n-f-2 closest
|
156
|
-
parameter vectors, where n is the number of nodes and f is the number of
|
157
|
-
malicious nodes. The node(s) with the lowest score(s) are selected for
|
158
|
-
aggregation.
|
159
|
-
|
160
|
-
Parameters
|
161
|
-
----------
|
162
|
-
replies : list[RecordDict]
|
163
|
-
List of RecordDicts, each containing an ArrayRecord representing model
|
164
|
-
parameters from a client.
|
165
|
-
|
166
|
-
Returns
|
167
|
-
-------
|
168
|
-
list[RecordDict]
|
169
|
-
List of RecordDicts selected for aggregation. If `num_nodes_to_keep` > 0,
|
170
|
-
returns the top `num_nodes_to_keep` RecordDicts (MultiKrum); otherwise,
|
171
|
-
returns the single RecordDict with the lowest score (Krum).
|
172
|
-
"""
|
173
|
-
# Construct list of ArrayRecord objects from replies
|
174
|
-
# Recall aggregate_train first ensures replies only contain one ArrayRecord
|
175
|
-
array_records = [list(reply.array_records.values())[0] for reply in replies]
|
176
|
-
distance_matrix = self._compute_distances(array_records)
|
177
|
-
|
178
|
-
# For each node, take the n-f-2 closest parameters vectors
|
179
|
-
num_closest = max(1, len(array_records) - self.num_malicious_nodes - 2)
|
180
|
-
closest_indices = []
|
181
|
-
for distance in distance_matrix:
|
182
|
-
closest_indices.append(
|
183
|
-
np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203
|
184
|
-
)
|
185
|
-
|
186
|
-
# Compute the score for each node, that is the sum of the distances
|
187
|
-
# of the n-f-2 closest parameters vectors
|
188
|
-
scores = [
|
189
|
-
np.sum(distance_matrix[i, closest_indices[i]])
|
190
|
-
for i in range(len(distance_matrix))
|
191
|
-
]
|
192
|
-
|
193
|
-
# Return RecordDicts that should be aggregated
|
194
|
-
if self.num_nodes_to_keep > 0:
|
195
|
-
# Choose to_keep nodes and return their average (MultiKrum)
|
196
|
-
best_indices = np.argsort(scores)[::-1][
|
197
|
-
len(scores) - self.num_nodes_to_keep :
|
198
|
-
] # noqa: E203
|
199
|
-
return [replies[i] for i in best_indices]
|
200
|
-
|
201
|
-
# Return the RecordDict with the ArrayRecord that minimize the score (Krum)
|
202
|
-
return [replies[np.argmin(scores)]]
|
203
|
-
|
204
|
-
def aggregate_train(
|
205
|
-
self,
|
206
|
-
server_round: int,
|
207
|
-
replies: Iterable[Message],
|
208
|
-
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
|
209
|
-
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
210
|
-
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
211
|
-
|
212
|
-
arrays, metrics = None, None
|
213
|
-
if valid_replies:
|
214
|
-
reply_contents = [msg.content for msg in valid_replies]
|
215
|
-
|
216
|
-
# Krum
|
217
|
-
replies_to_aggregate = self._krum(reply_contents)
|
218
|
-
|
219
|
-
# Aggregate ArrayRecords
|
220
|
-
arrays = aggregate_arrayrecords(
|
221
|
-
replies_to_aggregate,
|
222
|
-
self.weighted_by_key,
|
223
|
-
)
|
224
|
-
|
225
|
-
# Aggregate MetricRecords
|
226
|
-
metrics = self.train_metrics_aggr_fn(
|
227
|
-
replies_to_aggregate,
|
228
|
-
self.weighted_by_key,
|
229
|
-
)
|
230
|
-
return arrays, metrics
|
@@ -0,0 +1,247 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent.
|
16
|
+
|
17
|
+
[Blanchard et al., 2017].
|
18
|
+
|
19
|
+
Paper: proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf
|
20
|
+
"""
|
21
|
+
|
22
|
+
|
23
|
+
from collections.abc import Iterable
|
24
|
+
from logging import INFO
|
25
|
+
from typing import Callable, Optional, cast
|
26
|
+
|
27
|
+
import numpy as np
|
28
|
+
|
29
|
+
from flwr.common import ArrayRecord, Message, MetricRecord, NDArray, RecordDict, log
|
30
|
+
|
31
|
+
from .fedavg import FedAvg
|
32
|
+
from .strategy_utils import aggregate_arrayrecords
|
33
|
+
|
34
|
+
|
35
|
+
# pylint: disable=too-many-instance-attributes
|
36
|
+
class MultiKrum(FedAvg):
|
37
|
+
"""MultiKrum [Blanchard et al., 2017] strategy.
|
38
|
+
|
39
|
+
Implementation based on https://arxiv.org/abs/1703.02757
|
40
|
+
|
41
|
+
Parameters
|
42
|
+
----------
|
43
|
+
fraction_train : float (default: 1.0)
|
44
|
+
Fraction of nodes used during training. In case `min_train_nodes`
|
45
|
+
is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
|
46
|
+
will still be sampled.
|
47
|
+
fraction_evaluate : float (default: 1.0)
|
48
|
+
Fraction of nodes used during validation. In case `min_evaluate_nodes`
|
49
|
+
is larger than `fraction_evaluate * total_connected_nodes`,
|
50
|
+
`min_evaluate_nodes` will still be sampled.
|
51
|
+
min_train_nodes : int (default: 2)
|
52
|
+
Minimum number of nodes used during training.
|
53
|
+
min_evaluate_nodes : int (default: 2)
|
54
|
+
Minimum number of nodes used during validation.
|
55
|
+
min_available_nodes : int (default: 2)
|
56
|
+
Minimum number of total nodes in the system.
|
57
|
+
num_malicious_nodes : int (default: 0)
|
58
|
+
Number of malicious nodes in the system. Defaults to 0.
|
59
|
+
num_nodes_to_select : int (default: 1)
|
60
|
+
Number of nodes to select before averaging.
|
61
|
+
weighted_by_key : str (default: "num-examples")
|
62
|
+
The key within each MetricRecord whose value is used as the weight when
|
63
|
+
computing weighted averages for both ArrayRecords and MetricRecords.
|
64
|
+
arrayrecord_key : str (default: "arrays")
|
65
|
+
Key used to store the ArrayRecord when constructing Messages.
|
66
|
+
configrecord_key : str (default: "config")
|
67
|
+
Key used to store the ConfigRecord when constructing Messages.
|
68
|
+
train_metrics_aggr_fn : Optional[callable] (default: None)
|
69
|
+
Function with signature (list[RecordDict], str) -> MetricRecord,
|
70
|
+
used to aggregate MetricRecords from training round replies.
|
71
|
+
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
|
72
|
+
average using the provided weight factor key.
|
73
|
+
evaluate_metrics_aggr_fn : Optional[callable] (default: None)
|
74
|
+
Function with signature (list[RecordDict], str) -> MetricRecord,
|
75
|
+
used to aggregate MetricRecords from training round replies.
|
76
|
+
If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
|
77
|
+
average using the provided weight factor key.
|
78
|
+
|
79
|
+
Notes
|
80
|
+
-----
|
81
|
+
MultiKrum is a generalization of Krum. If `num_nodes_to_select` is set to 1,
|
82
|
+
MultiKrum will reduce to classical Krum.
|
83
|
+
"""
|
84
|
+
|
85
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
86
|
+
def __init__(
|
87
|
+
self,
|
88
|
+
fraction_train: float = 1.0,
|
89
|
+
fraction_evaluate: float = 1.0,
|
90
|
+
min_train_nodes: int = 2,
|
91
|
+
min_evaluate_nodes: int = 2,
|
92
|
+
min_available_nodes: int = 2,
|
93
|
+
num_malicious_nodes: int = 0,
|
94
|
+
num_nodes_to_select: int = 1,
|
95
|
+
weighted_by_key: str = "num-examples",
|
96
|
+
arrayrecord_key: str = "arrays",
|
97
|
+
configrecord_key: str = "config",
|
98
|
+
train_metrics_aggr_fn: Optional[
|
99
|
+
Callable[[list[RecordDict], str], MetricRecord]
|
100
|
+
] = None,
|
101
|
+
evaluate_metrics_aggr_fn: Optional[
|
102
|
+
Callable[[list[RecordDict], str], MetricRecord]
|
103
|
+
] = None,
|
104
|
+
) -> None:
|
105
|
+
super().__init__(
|
106
|
+
fraction_train=fraction_train,
|
107
|
+
fraction_evaluate=fraction_evaluate,
|
108
|
+
min_train_nodes=min_train_nodes,
|
109
|
+
min_evaluate_nodes=min_evaluate_nodes,
|
110
|
+
min_available_nodes=min_available_nodes,
|
111
|
+
weighted_by_key=weighted_by_key,
|
112
|
+
arrayrecord_key=arrayrecord_key,
|
113
|
+
configrecord_key=configrecord_key,
|
114
|
+
train_metrics_aggr_fn=train_metrics_aggr_fn,
|
115
|
+
evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
|
116
|
+
)
|
117
|
+
self.num_malicious_nodes = num_malicious_nodes
|
118
|
+
self.num_nodes_to_select = num_nodes_to_select
|
119
|
+
|
120
|
+
def summary(self) -> None:
|
121
|
+
"""Log summary configuration of the strategy."""
|
122
|
+
log(INFO, "\t├──> MultiKrum settings:")
|
123
|
+
log(INFO, "\t│\t├── Number of malicious nodes: %d", self.num_malicious_nodes)
|
124
|
+
log(INFO, "\t│\t└── Number of nodes to select: %d", self.num_nodes_to_select)
|
125
|
+
super().summary()
|
126
|
+
|
127
|
+
def aggregate_train(
|
128
|
+
self,
|
129
|
+
server_round: int,
|
130
|
+
replies: Iterable[Message],
|
131
|
+
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
|
132
|
+
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
133
|
+
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
134
|
+
|
135
|
+
arrays, metrics = None, None
|
136
|
+
if valid_replies:
|
137
|
+
reply_contents = [msg.content for msg in valid_replies]
|
138
|
+
|
139
|
+
# Krum or MultiKrum selection
|
140
|
+
replies_to_aggregate = select_multikrum(
|
141
|
+
reply_contents,
|
142
|
+
num_malicious_nodes=self.num_malicious_nodes,
|
143
|
+
num_nodes_to_select=self.num_nodes_to_select,
|
144
|
+
)
|
145
|
+
|
146
|
+
# Aggregate ArrayRecords
|
147
|
+
arrays = aggregate_arrayrecords(
|
148
|
+
replies_to_aggregate,
|
149
|
+
self.weighted_by_key,
|
150
|
+
)
|
151
|
+
|
152
|
+
# Aggregate MetricRecords
|
153
|
+
metrics = self.train_metrics_aggr_fn(
|
154
|
+
replies_to_aggregate,
|
155
|
+
self.weighted_by_key,
|
156
|
+
)
|
157
|
+
return arrays, metrics
|
158
|
+
|
159
|
+
|
160
|
+
def compute_distances(records: list[ArrayRecord]) -> NDArray:
|
161
|
+
"""Compute squared L2 distances between ArrayRecords.
|
162
|
+
|
163
|
+
Parameters
|
164
|
+
----------
|
165
|
+
records : list[ArrayRecord]
|
166
|
+
A list of ArrayRecords (arrays received in replies)
|
167
|
+
|
168
|
+
Returns
|
169
|
+
-------
|
170
|
+
NDArray
|
171
|
+
A 2D array representing the distance matrix of squared L2 distances
|
172
|
+
between input ArrayRecords
|
173
|
+
"""
|
174
|
+
# Formula: ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y
|
175
|
+
# Flatten records and stack them into a matrix
|
176
|
+
flat_w = np.stack(
|
177
|
+
[np.concatenate(rec.to_numpy_ndarrays(), axis=None).ravel() for rec in records],
|
178
|
+
axis=0,
|
179
|
+
) # shape: (n, d) with n number of records and d the dimension of model
|
180
|
+
|
181
|
+
# Compute squared norms of each vector
|
182
|
+
norms: NDArray = np.square(flat_w).sum(axis=1) # shape (n,)
|
183
|
+
|
184
|
+
# Use broadcasting to compute pairwise distances
|
185
|
+
distance_matrix: NDArray = norms[:, None] + norms[None, :] - 2 * flat_w @ flat_w.T
|
186
|
+
return distance_matrix
|
187
|
+
|
188
|
+
|
189
|
+
def select_multikrum(
|
190
|
+
contents: list[RecordDict],
|
191
|
+
num_malicious_nodes: int,
|
192
|
+
num_nodes_to_select: int,
|
193
|
+
) -> list[RecordDict]:
|
194
|
+
"""Select the set of RecordDicts to aggregate using the Krum or MultiKrum algorithm.
|
195
|
+
|
196
|
+
For each node, computes the sum of squared L2 distances to its n-f-2 closest
|
197
|
+
parameter vectors, where n is the number of nodes and f is the number of
|
198
|
+
malicious nodes. The node(s) with the lowest score(s) are selected for
|
199
|
+
aggregation.
|
200
|
+
|
201
|
+
Parameters
|
202
|
+
----------
|
203
|
+
contents : list[RecordDict]
|
204
|
+
List of contents from reply messages, where each content is a RecordDict
|
205
|
+
containing an ArrayRecord of model parameters from a node (client).
|
206
|
+
num_malicious_nodes : int
|
207
|
+
Number of malicious nodes in the system.
|
208
|
+
num_nodes_to_select : int
|
209
|
+
Number of client updates to select.
|
210
|
+
- If 1, the algorithm reduces to classical Krum (selecting a single update).
|
211
|
+
- If >1, Multi-Krum is applied (selecting multiple updates).
|
212
|
+
|
213
|
+
Returns
|
214
|
+
-------
|
215
|
+
list[RecordDict]
|
216
|
+
Selected contents following the Krum or Multi-Krum algorithm.
|
217
|
+
|
218
|
+
Notes
|
219
|
+
-----
|
220
|
+
If `num_nodes_to_select` is set to 1, Multi-Krum reduces to classical Krum
|
221
|
+
and only a single RecordDict is selected.
|
222
|
+
"""
|
223
|
+
# Construct list of ArrayRecord objects from replies
|
224
|
+
record_key = list(contents[0].array_records.keys())[0]
|
225
|
+
# Recall aggregate_train first ensures replies only contain one ArrayRecord
|
226
|
+
array_records = [cast(ArrayRecord, reply[record_key]) for reply in contents]
|
227
|
+
distance_matrix = compute_distances(array_records)
|
228
|
+
|
229
|
+
# For each node, take the n-f-2 closest parameters vectors
|
230
|
+
num_closest = max(1, len(array_records) - num_malicious_nodes - 2)
|
231
|
+
closest_indices = []
|
232
|
+
for distance in distance_matrix:
|
233
|
+
closest_indices.append(
|
234
|
+
np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203
|
235
|
+
)
|
236
|
+
|
237
|
+
# Compute the score for each node, that is the sum of the distances
|
238
|
+
# of the n-f-2 closest parameters vectors
|
239
|
+
scores = [
|
240
|
+
np.sum(distance_matrix[i, closest_indices[i]])
|
241
|
+
for i in range(len(distance_matrix))
|
242
|
+
]
|
243
|
+
|
244
|
+
# Choose the num_nodes_to_select lowest-scoring nodes (MultiKrum)
|
245
|
+
# and return their updates
|
246
|
+
best_indices = np.argsort(scores)[:num_nodes_to_select]
|
247
|
+
return [contents[i] for i in best_indices]
|
{flwr_nightly-1.22.0.dev20250920.dist-info → flwr_nightly-1.23.0.dev20250922.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: flwr-nightly
|
3
|
-
Version: 1.
|
3
|
+
Version: 1.23.0.dev20250922
|
4
4
|
Summary: Flower: A Friendly Federated AI Framework
|
5
5
|
License: Apache-2.0
|
6
6
|
Keywords: Artificial Intelligence,Federated AI,Federated Analytics,Federated Evaluation,Federated Learning,Flower,Machine Learning
|
{flwr_nightly-1.22.0.dev20250920.dist-info → flwr_nightly-1.23.0.dev20250922.dist-info}/RECORD
RENAMED
@@ -69,17 +69,17 @@ flwr/cli/new/templates/app/code/task.sklearn.py.tpl,sha256=vHdhtMp0FHxbYafXyhDT9
|
|
69
69
|
flwr/cli/new/templates/app/code/task.tensorflow.py.tpl,sha256=impgWN7MfztmcWF4xh1llcZGsgTvrb1HD5ZE0t-8U08,1731
|
70
70
|
flwr/cli/new/templates/app/code/task.xgboost.py.tpl,sha256=0xO8jQvrHuB1llVDopQPOmt5Hn6rBw8umzoNwiZZs-o,2135
|
71
71
|
flwr/cli/new/templates/app/code/utils.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
|
72
|
-
flwr/cli/new/templates/app/pyproject.baseline.toml.tpl,sha256=
|
73
|
-
flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl,sha256=
|
74
|
-
flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl,sha256=
|
75
|
-
flwr/cli/new/templates/app/pyproject.jax.toml.tpl,sha256=
|
76
|
-
flwr/cli/new/templates/app/pyproject.mlx.toml.tpl,sha256=
|
77
|
-
flwr/cli/new/templates/app/pyproject.numpy.toml.tpl,sha256=
|
78
|
-
flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=
|
79
|
-
flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl,sha256=
|
80
|
-
flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=
|
81
|
-
flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl,sha256=
|
82
|
-
flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl,sha256=
|
72
|
+
flwr/cli/new/templates/app/pyproject.baseline.toml.tpl,sha256=7t57rL9zAjCC0Vd1T2LMkPXvl2x0V0iVDTap4oC66LY,3182
|
73
|
+
flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl,sha256=iwUlPEC2DsDuIuY3FZmEDxqKkUqcWTLFZXxObFGAXEU,2501
|
74
|
+
flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl,sha256=PiQemjUmmLWCdjiZylXFTZer7VhqxqT4_BXbDYw1uh4,2020
|
75
|
+
flwr/cli/new/templates/app/pyproject.jax.toml.tpl,sha256=S6PvrDW1JPGnVs6Nfn2e6GHg0taZKt3bTmM-5KKyUpo,1471
|
76
|
+
flwr/cli/new/templates/app/pyproject.mlx.toml.tpl,sha256=K00iuPOuqJBGgNE-eFiv5-bahM6PwKtf-6g7lH7FBx8,1542
|
77
|
+
flwr/cli/new/templates/app/pyproject.numpy.toml.tpl,sha256=1ptnw1a7wcb3UxbDVquudZ79Z6oTbJl6q33KAGq0cEo,1409
|
78
|
+
flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=TEqf7kWd4tsN3TbNNosN0o2DAWn4W2OFHqipCf-5gCU,1490
|
79
|
+
flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl,sha256=ZWW31QEdrqIvIIWn0UFEYUkgufgDnWQtMSgDGOnWgv8,1508
|
80
|
+
flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=Elh5bUTVBNRG1aDTrgQcZgjfrXcIaVXH00UduNJDZ2U,1484
|
81
|
+
flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl,sha256=g7SYiAJr6Uhwg4ZsQI2qsLVeizU8vd0CzR0_jka99_A,1508
|
82
|
+
flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl,sha256=yAJ9jL2q6U_XXYwwUT9fpIqIKFuQR_Kgg82GpWfQ5J8,1661
|
83
83
|
flwr/cli/pull.py,sha256=dHiMe6x8w8yRoFNKpjA-eiPD6eFiHz4Vah5HZrqNpuo,3364
|
84
84
|
flwr/cli/run/__init__.py,sha256=RPyB7KbYTFl6YRiilCch6oezxrLQrl1kijV7BMGkLbA,790
|
85
85
|
flwr/cli/run/run.py,sha256=ECa0kup9dn15O70H74QdgUsEaeErbzDqVX_U0zZO5IM,8173
|
@@ -113,8 +113,9 @@ flwr/client/rest_client/connection.py,sha256=fyiS1aXTv71jWczx7mSco94LYJTBXgTF-p2
|
|
113
113
|
flwr/client/run_info_store.py,sha256=MaJ3UQ-07hWtK67wnWu0zR29jrk0fsfgJX506dvEOfE,4042
|
114
114
|
flwr/client/typing.py,sha256=Jw3rawDzI_-ZDcRmEQcs5gZModY7oeQlEeltYsdOhlU,1048
|
115
115
|
flwr/clientapp/__init__.py,sha256=uoTjvIynfGvMhsmc7iYK-5qJ0AdKKjCbx7WTKc0KeSk,828
|
116
|
-
flwr/clientapp/mod/__init__.py,sha256=
|
117
|
-
flwr/clientapp/mod/centraldp_mods.py,sha256=
|
116
|
+
flwr/clientapp/mod/__init__.py,sha256=w18mDPUmjlm5P_er2GJ7l6Zbh8qyiHfMKoVqXxaTxdE,1024
|
117
|
+
flwr/clientapp/mod/centraldp_mods.py,sha256=bL2NMxbin9nj2OYSHbuOvIPk0rKH6xyAxG-pbtP9ISY,8954
|
118
|
+
flwr/clientapp/mod/localdp_mod.py,sha256=L3kQK3kjc-vwlhJcBwFIK6EiwBYlpHQkv4jiR7K9abQ,5642
|
118
119
|
flwr/clientapp/typing.py,sha256=x1GvXWy112RqZh27liJqz-yZ7SSCOwiOSmAQsUxk9MY,853
|
119
120
|
flwr/common/__init__.py,sha256=5GCLVk399Az_rTJHNticRlL0Sl_oPw_j5_LuFKfX7-M,4171
|
120
121
|
flwr/common/address.py,sha256=9JucdTwlc-jpeJkRKeUboZoacUtErwSVtnDR9kAtLqE,4119
|
@@ -337,21 +338,23 @@ flwr/server/workflow/secure_aggregation/secagg_workflow.py,sha256=b_pKk7gmbahwyj
|
|
337
338
|
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py,sha256=DkayCsnlAya6Y2PZsueLgoUCMRtV-GbnW08RfWx_SXM,29460
|
338
339
|
flwr/serverapp/__init__.py,sha256=ZujKNXULwhWYQhFnxOOT5Wi9MRq2JCWFhAAj7ouiQ78,884
|
339
340
|
flwr/serverapp/exception.py,sha256=5cuH-2AafvihzosWDdDjuMmHdDqZ1XxHvCqZXNBVklw,1334
|
340
|
-
flwr/serverapp/strategy/__init__.py,sha256=
|
341
|
+
flwr/serverapp/strategy/__init__.py,sha256=dezK2TKSffjjBVXW18ATRxJLTuQ7I2M1dPuNi5y-_6c,1968
|
342
|
+
flwr/serverapp/strategy/bulyan.py,sha256=lSFOZof5NDzDFS6206rp6V_08LsitSHIHOOckcLt4_E,9306
|
341
343
|
flwr/serverapp/strategy/dp_adaptive_clipping.py,sha256=mssiVGMgfJw8DeP6_pBSZUKWmaXvYeG-B-p7RSt2tAU,13600
|
342
344
|
flwr/serverapp/strategy/dp_fixed_clipping.py,sha256=C_faT0ggzeUB2bGv_r1Vss-fv7-UhDrpmfiHATESI0w,12832
|
343
345
|
flwr/serverapp/strategy/fedadagrad.py,sha256=faFsuKZziPTCLeNrJOyKbPTNo-1xrIZOz7SWT5rdjJs,6269
|
344
346
|
flwr/serverapp/strategy/fedadam.py,sha256=NsY_V6TGFAfCeA9vmqaLpvB_T5siJEtKozKGdxJssAI,7064
|
345
347
|
flwr/serverapp/strategy/fedavg.py,sha256=Bq_nlmngzJbjqX1fF1mevXGVN6-pwglHv-6yNrs6lkA,12035
|
346
348
|
flwr/serverapp/strategy/fedavgm.py,sha256=FtFmBGLzuUQ_7JWk85Xh19d8sP0YDwqczGTliGzZyGs,8333
|
347
|
-
flwr/serverapp/strategy/fedmedian.py,sha256=
|
349
|
+
flwr/serverapp/strategy/fedmedian.py,sha256=yhGg6WGWYEbn3oYMfnCBm1F7v9u5LHYVsSAYvdI9Pns,4498
|
348
350
|
flwr/serverapp/strategy/fedopt.py,sha256=kqT0uV2IUE93O72XEVa1JJo61dcwbZEoT9KmYTjR2tE,8477
|
349
351
|
flwr/serverapp/strategy/fedprox.py,sha256=J1KrcE5DFko6i4608iICv1G0t9MPXspjibPd-SF_HT8,7028
|
350
352
|
flwr/serverapp/strategy/fedtrimmedavg.py,sha256=58xDPc_YO41QM8jXn0gZ79PFzO8zo3Mh3UlkF0UBbIA,7168
|
351
|
-
flwr/serverapp/strategy/fedxgb_bagging.py,sha256=
|
352
|
-
flwr/serverapp/strategy/fedxgb_cyclic.py,sha256=
|
353
|
+
flwr/serverapp/strategy/fedxgb_bagging.py,sha256=QAnsQXTE1qzj1qGkZ8mCOfmSKJ2pO_gG9YgmJH-EV6s,5189
|
354
|
+
flwr/serverapp/strategy/fedxgb_cyclic.py,sha256=yedhRJ7dEz-Yi5yEiS9zki_LKHPHAk962PWCYLLDLoY,8752
|
353
355
|
flwr/serverapp/strategy/fedyogi.py,sha256=Y9RFBQaNch3fPgGXF7OfnTH6eOpavZxpMWxWVIC9_SY,6579
|
354
|
-
flwr/serverapp/strategy/krum.py,sha256=
|
356
|
+
flwr/serverapp/strategy/krum.py,sha256=kcy8TQuJC_qkYJ7jtn26lx2VYYBHlQC_LjNWOuCw9ZQ,4848
|
357
|
+
flwr/serverapp/strategy/multikrum.py,sha256=bv44TzMQHh1NEp3ilt0ANoR5xHJ4kpb0qxzSzXD_lx0,9865
|
355
358
|
flwr/serverapp/strategy/qfedavg.py,sha256=EM1tO_ovkybOBeW-h1PYX0lszCUAVHT6hUpwXykAEps,10204
|
356
359
|
flwr/serverapp/strategy/result.py,sha256=E0Hl2VLnZAgQJjE2GDoKsK7JX-kPPU2KXc47Axt6hGw,4295
|
357
360
|
flwr/serverapp/strategy/strategy.py,sha256=8uJGGm1ROLZERQ_dkRS7Z_rs-yK6XCE0UxXtIdFiEWk,10789
|
@@ -419,7 +422,7 @@ flwr/supernode/servicer/__init__.py,sha256=lucTzre5WPK7G1YLCfaqg3rbFWdNSb7ZTt-ca
|
|
419
422
|
flwr/supernode/servicer/clientappio/__init__.py,sha256=7Oy62Y_oijqF7Dxi6tpcUQyOpLc_QpIRZ83NvwmB0Yg,813
|
420
423
|
flwr/supernode/servicer/clientappio/clientappio_servicer.py,sha256=nIHRu38EWK-rpNOkcgBRAAKwYQQWFeCwu0lkO7OPZGQ,10239
|
421
424
|
flwr/supernode/start_client_internal.py,sha256=Y9S1-QlO2WP6eo4JvWzIpfaCoh2aoE7bjEYyxNNnlyg,20777
|
422
|
-
flwr_nightly-1.
|
423
|
-
flwr_nightly-1.
|
424
|
-
flwr_nightly-1.
|
425
|
-
flwr_nightly-1.
|
425
|
+
flwr_nightly-1.23.0.dev20250922.dist-info/METADATA,sha256=oINfkQPJKThCiCdU1LRyzabquHXkxkPYfRGIVF67slI,14559
|
426
|
+
flwr_nightly-1.23.0.dev20250922.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
427
|
+
flwr_nightly-1.23.0.dev20250922.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
|
428
|
+
flwr_nightly-1.23.0.dev20250922.dist-info/RECORD,,
|
{flwr_nightly-1.22.0.dev20250920.dist-info → flwr_nightly-1.23.0.dev20250922.dist-info}/WHEEL
RENAMED
File without changes
|
File without changes
|