flwr-nightly 1.22.0.dev20250919__py3-none-any.whl → 1.23.0.dev20250921__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25) hide show
  1. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  2. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  3. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  12. flwr/clientapp/mod/__init__.py +2 -0
  13. flwr/clientapp/mod/centraldp_mods.py +6 -6
  14. flwr/clientapp/mod/localdp_mod.py +169 -0
  15. flwr/serverapp/strategy/__init__.py +4 -0
  16. flwr/serverapp/strategy/bulyan.py +238 -0
  17. flwr/serverapp/strategy/fedmedian.py +34 -0
  18. flwr/serverapp/strategy/fedxgb_bagging.py +36 -1
  19. flwr/serverapp/strategy/fedxgb_cyclic.py +1 -1
  20. flwr/serverapp/strategy/krum.py +7 -125
  21. flwr/serverapp/strategy/multikrum.py +247 -0
  22. {flwr_nightly-1.22.0.dev20250919.dist-info → flwr_nightly-1.23.0.dev20250921.dist-info}/METADATA +1 -1
  23. {flwr_nightly-1.22.0.dev20250919.dist-info → flwr_nightly-1.23.0.dev20250921.dist-info}/RECORD +25 -22
  24. {flwr_nightly-1.22.0.dev20250919.dist-info → flwr_nightly-1.23.0.dev20250921.dist-info}/WHEEL +0 -0
  25. {flwr_nightly-1.22.0.dev20250919.dist-info → flwr_nightly-1.23.0.dev20250921.dist-info}/entry_points.txt +0 -0
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.22.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "torch==2.8.0",
20
20
  "torchvision==0.23.0",
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.22.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets>=0.5.0",
19
19
  "torch==2.4.0",
20
20
  "trl==0.8.1",
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.22.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets>=0.5.0",
19
19
  "torch>=2.7.1",
20
20
  "transformers>=4.30.0,<5.0",
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.22.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "jax==0.4.30",
19
19
  "jaxlib==0.4.30",
20
20
  "scikit-learn==1.6.1",
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.22.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "mlx==0.29.0",
20
20
  ]
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.22.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "numpy>=2.0.2",
19
19
  ]
20
20
 
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.22.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "torch==2.7.1",
20
20
  "torchvision==0.22.1",
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.22.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "torch==2.7.1",
20
20
  "torchvision==0.22.1",
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.22.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "scikit-learn>=1.6.1",
20
20
  ]
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.22.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "tensorflow>=2.11.1,<2.18.0",
20
20
  ]
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.22.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets>=0.5.0",
19
19
  "xgboost>=2.0.0",
20
20
  ]
@@ -18,8 +18,10 @@
18
18
  from flwr.client.mod.comms_mods import arrays_size_mod, message_size_mod
19
19
 
20
20
  from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
21
+ from .localdp_mod import LocalDpMod
21
22
 
22
23
  __all__ = [
24
+ "LocalDpMod",
23
25
  "adaptiveclipping_mod",
24
26
  "arrays_size_mod",
25
27
  "fixedclipping_mod",
@@ -89,7 +89,7 @@ def fixedclipping_mod(
89
89
  iter(out_msg.content.array_records.items())
90
90
  )
91
91
  # Ensure keys in returned ArrayRecord match those in the one sent from server
92
- if set(original_array_record.keys()) != set(client_to_server_arrecord.keys()):
92
+ if list(original_array_record.keys()) != list(client_to_server_arrecord.keys()):
93
93
  return _handle_array_key_mismatch_err("fixedclipping_mod", out_msg)
94
94
 
95
95
  client_to_server_ndarrays = client_to_server_arrecord.to_numpy_ndarrays()
@@ -158,6 +158,10 @@ def adaptiveclipping_mod(
158
158
  # Call inner app
159
159
  out_msg = call_next(msg, ctxt)
160
160
 
161
+ # Check if the msg has error
162
+ if out_msg.has_error():
163
+ return out_msg
164
+
161
165
  # Ensure reply has a single ArrayRecord
162
166
  if len(out_msg.content.array_records) != 1:
163
167
  return _handle_multi_record_err("adaptiveclipping_mod", out_msg, ArrayRecord)
@@ -166,16 +170,12 @@ def adaptiveclipping_mod(
166
170
  if len(out_msg.content.metric_records) != 1:
167
171
  return _handle_multi_record_err("adaptiveclipping_mod", out_msg, MetricRecord)
168
172
 
169
- # Check if the msg has error
170
- if out_msg.has_error():
171
- return out_msg
172
-
173
173
  new_array_record_key, client_to_server_arrecord = next(
174
174
  iter(out_msg.content.array_records.items())
175
175
  )
176
176
 
177
177
  # Ensure keys in returned ArrayRecord match those in the one sent from server
178
- if set(original_array_record.keys()) != set(client_to_server_arrecord.keys()):
178
+ if list(original_array_record.keys()) != list(client_to_server_arrecord.keys()):
179
179
  return _handle_array_key_mismatch_err("adaptiveclipping_mod", out_msg)
180
180
 
181
181
  client_to_server_ndarrays = client_to_server_arrecord.to_numpy_ndarrays()
@@ -0,0 +1,169 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Local DP modifier."""
16
+
17
+
18
+ from collections import OrderedDict
19
+ from logging import INFO
20
+
21
+ import numpy as np
22
+
23
+ from flwr.clientapp.typing import ClientAppCallable
24
+ from flwr.common import Array, ArrayRecord
25
+ from flwr.common.context import Context
26
+ from flwr.common.differential_privacy import (
27
+ add_gaussian_noise_inplace,
28
+ compute_clip_model_update,
29
+ )
30
+ from flwr.common.logger import log
31
+ from flwr.common.message import Message
32
+
33
+ from .centraldp_mods import _handle_array_key_mismatch_err, _handle_multi_record_err
34
+
35
+
36
class LocalDpMod:
    """Modifier for local differential privacy.

    This mod clips the client model updates and
    adds noise to the params before sending them to the server.

    It operates on messages of type `MessageType.TRAIN`; messages of any other
    type are forwarded to the next mod (or the ClientApp) unchanged.

    Parameters
    ----------
    clipping_norm : float
        The value of the clipping norm. Must be positive.
    sensitivity : float
        The sensitivity of the client model. Must be non-negative.
    epsilon : float
        The privacy budget. Must be positive.
        Smaller value of epsilon indicates a higher level of privacy protection.
    delta : float
        The failure probability. Must lie in the open interval (0, 1).
        The probability that the privacy mechanism
        fails to provide the desired level of privacy.
        A smaller value of delta indicates a stricter privacy guarantee.

    Raises
    ------
    ValueError
        If any of the arguments is outside its valid range.

    Notes
    -----
    The standard deviation of the added Gaussian noise is
    ``sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon``.

    Examples
    --------
    Create an instance of the local DP mod and add it to the client-side mods::

        local_dp_mod = LocalDpMod( ... )
        app = fl.client.ClientApp(mods=[local_dp_mod])
    """

    def __init__(
        self, clipping_norm: float, sensitivity: float, epsilon: float, delta: float
    ) -> None:
        if clipping_norm <= 0:
            raise ValueError("The clipping norm should be a positive value.")

        if sensitivity < 0:
            raise ValueError("The sensitivity should be a non-negative value.")

        # epsilon == 0 would divide by zero when computing the noise scale
        if epsilon <= 0:
            raise ValueError("Epsilon should be a positive value.")

        # delta == 0 makes log(1.25 / delta) diverge; delta >= 1 is not a
        # meaningful failure probability (and delta >= 1.25 gives zero/NaN noise)
        if not 0 < delta < 1:
            raise ValueError("Delta should be a value in the open interval (0, 1).")

        self.clipping_norm = clipping_norm
        self.sensitivity = sensitivity
        self.epsilon = epsilon
        self.delta = delta

    def __call__(
        self, msg: Message, ctxt: Context, call_next: ClientAppCallable
    ) -> Message:
        """Perform local DP on the client model parameters.

        Non-TRAIN messages are passed to `call_next` and its reply is returned
        untouched.

        Parameters
        ----------
        msg : Message
            The message received from the ServerApp.
        ctxt : Context
            The context of the ClientApp.
        call_next : ClientAppCallable
            The callable to call the next mod (or the ClientApp) in the chain.

        Returns
        -------
        Message
            The modified message to be sent back to the server, or an error
            reply if the incoming/outgoing records do not have the expected
            structure.
        """
        # pylint: disable-next=import-outside-toplevel
        from flwr.common import MessageType

        # Only training replies carry model updates that need privatizing;
        # e.g. evaluate replies usually contain no ArrayRecord at all.
        if msg.metadata.message_type != MessageType.TRAIN:
            return call_next(msg, ctxt)

        if len(msg.content.array_records) != 1:
            return _handle_multi_record_err("LocalDpMod", msg, ArrayRecord)

        # Keep the ArrayRecord sent by the server: the client update is
        # clipped relative to these parameters.
        original_array_record = next(iter(msg.content.array_records.values()))

        # Call inner app
        out_msg = call_next(msg, ctxt)

        # Forward error replies as-is
        if out_msg.has_error():
            return out_msg

        # Ensure reply has a single ArrayRecord
        if len(out_msg.content.array_records) != 1:
            return _handle_multi_record_err("LocalDpMod", out_msg, ArrayRecord)

        new_array_record_key, client_to_server_arrecord = next(
            iter(out_msg.content.array_records.items())
        )

        # Keys (and their order) must match so arrays can be paired positionally
        if list(original_array_record.keys()) != list(client_to_server_arrecord.keys()):
            return _handle_array_key_mismatch_err("LocalDpMod", out_msg)

        client_to_server_ndarrays = client_to_server_arrecord.to_numpy_ndarrays()

        # Clip the client update. The return value is ignored and
        # `client_to_server_ndarrays` is used afterwards, so this call is
        # expected to update it in place.
        compute_clip_model_update(
            client_to_server_ndarrays,
            original_array_record.to_numpy_ndarrays(),
            self.clipping_norm,
        )
        log(
            INFO,
            "LocalDpMod: parameters are clipped by value: %.4f.",
            self.clipping_norm,
        )

        # Gaussian mechanism noise scale
        std_dev = (
            self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon
        )
        add_gaussian_noise_inplace(
            client_to_server_ndarrays,
            std_dev,
        )
        log(
            INFO,
            "LocalDpMod: local DP noise with %.4f stddev added to parameters",
            std_dev,
        )

        # Replace outgoing ArrayRecord's Arrays while preserving their keys
        out_msg.content[new_array_record_key] = ArrayRecord(
            OrderedDict(
                zip(
                    client_to_server_arrecord.keys(),
                    map(Array, client_to_server_ndarrays),
                )
            )
        )
        return out_msg
@@ -15,6 +15,7 @@
15
15
  """ServerApp strategies."""
16
16
 
17
17
 
18
+ from .bulyan import Bulyan
18
19
  from .dp_adaptive_clipping import (
19
20
  DifferentialPrivacyClientSideAdaptiveClipping,
20
21
  DifferentialPrivacyServerSideAdaptiveClipping,
@@ -34,11 +35,13 @@ from .fedxgb_bagging import FedXgbBagging
34
35
  from .fedxgb_cyclic import FedXgbCyclic
35
36
  from .fedyogi import FedYogi
36
37
  from .krum import Krum
38
+ from .multikrum import MultiKrum
37
39
  from .qfedavg import QFedAvg
38
40
  from .result import Result
39
41
  from .strategy import Strategy
40
42
 
41
43
  __all__ = [
44
+ "Bulyan",
42
45
  "DifferentialPrivacyClientSideAdaptiveClipping",
43
46
  "DifferentialPrivacyClientSideFixedClipping",
44
47
  "DifferentialPrivacyServerSideAdaptiveClipping",
@@ -54,6 +57,7 @@ __all__ = [
54
57
  "FedXgbCyclic",
55
58
  "FedYogi",
56
59
  "Krum",
60
+ "MultiKrum",
57
61
  "QFedAvg",
58
62
  "Result",
59
63
  "Strategy",
@@ -0,0 +1,238 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Bulyan [El Mhamdi et al., 2018] strategy.
16
+
17
+ Paper: arxiv.org/abs/1802.07927
18
+ """
19
+
20
+
21
+ from collections import OrderedDict
22
+ from collections.abc import Iterable
23
+ from logging import INFO, WARN
24
+ from typing import Callable, Optional, cast
25
+
26
+ import numpy as np
27
+
28
+ from flwr.common import (
29
+ Array,
30
+ ArrayRecord,
31
+ Message,
32
+ MetricRecord,
33
+ NDArrays,
34
+ RecordDict,
35
+ log,
36
+ )
37
+
38
+ from .fedavg import FedAvg
39
+ from .multikrum import select_multikrum
40
+
41
+
42
+ # pylint: disable=too-many-instance-attributes
43
class Bulyan(FedAvg):
    """Bulyan strategy.

    Implementation based on https://arxiv.org/abs/1802.07927.

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    num_malicious_nodes : int (default: 0)
        Number of malicious nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    selection_rule : Optional[Callable] (default: None)
        Function with signature (list[RecordDict], int, int) -> list[RecordDict].
        The inputs are:
        - a list of contents from reply messages,
        - the assumed number of malicious nodes (`num_malicious_nodes`),
        - the number of nodes to select (`num_nodes_to_select`).

        The function should implement a Byzantine-resilient selection rule that
        serves as the first step of Bulyan. If None, defaults to `select_multikrum`,
        which selects nodes according to the Multi-Krum algorithm.
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        num_malicious_nodes: int = 0,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        selection_rule: Optional[
            Callable[[list[RecordDict], int, int], list[RecordDict]]
        ] = None,
    ) -> None:
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
        )
        self.num_malicious_nodes = num_malicious_nodes
        self.selection_rule = selection_rule or select_multikrum

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        # Any callable is accepted as `selection_rule`; not all of them (e.g.
        # functools.partial objects, callable instances) define `__name__`.
        rule_name = getattr(self.selection_rule, "__name__", repr(self.selection_rule))
        log(INFO, "\t├──> Bulyan settings:")
        log(INFO, "\t│\t├── Number of malicious nodes: %d", self.num_malicious_nodes)
        log(INFO, "\t│\t└── Selection rule: %s", rule_name)
        super().summary()

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        With ``n`` valid replies and ``f = num_malicious_nodes``, the
        ``selection_rule`` first keeps ``theta = n - 2f`` replies. Then, for
        every coordinate, the ``beta = theta - 2f`` selected values closest to
        the coordinate-wise median are averaged. MetricRecords are aggregated
        over the selected replies only.

        Returns ``(None, None)`` if fewer than ``4f + 3`` valid replies were
        received.
        """
        valid_replies, _ = self._check_and_log_replies(replies, is_train=True)

        # Check if sufficient replies have been received
        num_required = 4 * self.num_malicious_nodes + 3
        if len(valid_replies) < num_required:
            log(
                WARN,
                "Insufficient replies, skipping Bulyan aggregation: "
                "Required at least %d (4*num_malicious_nodes + 3), but received %d.",
                num_required,
                len(valid_replies),
            )
            return None, None

        reply_contents = [msg.content for msg in valid_replies]

        # Compute theta and beta
        theta = len(valid_replies) - 2 * self.num_malicious_nodes
        beta = theta - 2 * self.num_malicious_nodes

        # Byzantine-resilient selection rule
        selected_contents = self.selection_rule(
            reply_contents, self.num_malicious_nodes, theta
        )

        # Capture the array keys before conversion: `keep_input=False`
        # presumably lets the conversion drop the input arrays to save memory.
        key = next(iter(selected_contents[0].array_records.keys()))
        array_keys = list(selected_contents[0][key].keys())
        selected_ndarrays = [
            cast(ArrayRecord, ctnt[key]).to_numpy_ndarrays(keep_input=False)
            for ctnt in selected_contents
        ]

        # Coordinate-wise median across the selected replies, layer by layer
        median_ndarrays = [np.median(arr, axis=0) for arr in zip(*selected_ndarrays)]

        # Aggregate the beta closest weights element-wise
        aggregated_ndarrays = aggregate_n_closest_weights(
            median_ndarrays, selected_ndarrays, beta
        )

        # Convert to ArrayRecord
        arrays = ArrayRecord(
            OrderedDict(zip(array_keys, map(Array, aggregated_ndarrays)))
        )

        # Aggregate MetricRecords
        metrics = self.train_metrics_aggr_fn(
            selected_contents,
            self.weighted_by_key,
        )
        return arrays, metrics
196
+
197
+
198
def aggregate_n_closest_weights(
    ref_weights: NDArrays, weights_list: list[NDArrays], beta: int
) -> NDArrays:
    """Compute the element-wise mean of the `beta` closest weight arrays.

    For each element (i-th coordinate), the output is the average of the
    `beta` values, one per entry of `weights_list`, whose absolute difference
    to the corresponding coordinate of `ref_weights` is smallest.

    Parameters
    ----------
    ref_weights : NDArrays
        Reference weights used to compute distances.
    weights_list : list[NDArrays]
        List of weight arrays (e.g., from selected nodes). Every entry must
        have the same layers (count and shapes) as `ref_weights`.
    beta : int
        Number of closest weight arrays to include in the averaging.

    Returns
    -------
    aggregated_weights : NDArrays
        Element-wise average of the `beta` closest weight arrays to the
        reference weights.

    Raises
    ------
    ValueError
        If `beta` is not in the range ``[1, len(weights_list)]``.
    """
    num_models = len(weights_list)
    # beta <= 0 would average an empty/partial selection (NaN or wrong values);
    # beta > num_models cannot be satisfied.
    if not 1 <= beta <= num_models:
        raise ValueError(
            "`beta` must be between 1 and the number of weight arrays "
            f"({num_models}), but got {beta}."
        )

    aggregated_weights = []
    for layer_id, ref_layer in enumerate(ref_weights):
        # Shape: (n_models, *layer_shape)
        layer_stack = np.stack([weights[layer_id] for weights in weights_list])

        # Compute absolute differences: shape (n_models, *layer_shape)
        diffs = np.abs(layer_stack - ref_layer)

        # Indices of the `beta` smallest differences per coordinate
        idx = np.argpartition(diffs, beta - 1, axis=0)[:beta]

        # Gather the closest weights
        closest = np.take_along_axis(layer_stack, idx, axis=0)

        # Average them
        aggregated_weights.append(np.mean(closest, axis=0))

    return aggregated_weights
@@ -32,6 +32,40 @@ class FedMedian(FedAvg):
32
32
  """Federated Median (FedMedian) strategy.
33
33
 
34
34
  Implementation based on https://arxiv.org/pdf/1803.01498v1
35
+
36
+ Parameters
37
+ ----------
38
+ fraction_train : float (default: 1.0)
39
+ Fraction of nodes used during training. In case `min_train_nodes`
40
+ is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
41
+ will still be sampled.
42
+ fraction_evaluate : float (default: 1.0)
43
+ Fraction of nodes used during validation. In case `min_evaluate_nodes`
44
+ is larger than `fraction_evaluate * total_connected_nodes`,
45
+ `min_evaluate_nodes` will still be sampled.
46
+ min_train_nodes : int (default: 2)
47
+ Minimum number of nodes used during training.
48
+ min_evaluate_nodes : int (default: 2)
49
+ Minimum number of nodes used during validation.
50
+ min_available_nodes : int (default: 2)
51
+ Minimum number of total nodes in the system.
52
+ weighted_by_key : str (default: "num-examples")
53
+ The key within each MetricRecord whose value is used as the weight when
54
+ computing weighted averages for MetricRecords.
55
+ arrayrecord_key : str (default: "arrays")
56
+ Key used to store the ArrayRecord when constructing Messages.
57
+ configrecord_key : str (default: "config")
58
+ Key used to store the ConfigRecord when constructing Messages.
59
+ train_metrics_aggr_fn : Optional[callable] (default: None)
60
+ Function with signature (list[RecordDict], str) -> MetricRecord,
61
+ used to aggregate MetricRecords from training round replies.
62
+ If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
63
+ average using the provided weight factor key.
64
+ evaluate_metrics_aggr_fn : Optional[callable] (default: None)
65
+ Function with signature (list[RecordDict], str) -> MetricRecord,
66
+ used to aggregate MetricRecords from training round replies.
67
+ If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
68
+ average using the provided weight factor key.
35
69
  """
36
70
 
37
71
  def aggregate_train(
@@ -28,7 +28,42 @@ from .strategy_utils import aggregate_bagging
28
28
 
29
29
  # pylint: disable=line-too-long
30
30
  class FedXgbBagging(FedAvg):
31
- """Configurable FedXgbBagging strategy implementation."""
31
+ """Configurable FedXgbBagging strategy implementation.
32
+
33
+ Parameters
34
+ ----------
35
+ fraction_train : float (default: 1.0)
36
+ Fraction of nodes used during training. In case `min_train_nodes`
37
+ is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
38
+ will still be sampled.
39
+ fraction_evaluate : float (default: 1.0)
40
+ Fraction of nodes used during validation. In case `min_evaluate_nodes`
41
+ is larger than `fraction_evaluate * total_connected_nodes`,
42
+ `min_evaluate_nodes` will still be sampled.
43
+ min_train_nodes : int (default: 2)
44
+ Minimum number of nodes used during training.
45
+ min_evaluate_nodes : int (default: 2)
46
+ Minimum number of nodes used during validation.
47
+ min_available_nodes : int (default: 2)
48
+ Minimum number of total nodes in the system.
49
+ weighted_by_key : str (default: "num-examples")
50
+ The key within each MetricRecord whose value is used as the weight when
51
+ computing weighted averages for MetricRecords.
52
+ arrayrecord_key : str (default: "arrays")
53
+ Key used to store the ArrayRecord when constructing Messages.
54
+ configrecord_key : str (default: "config")
55
+ Key used to store the ConfigRecord when constructing Messages.
56
+ train_metrics_aggr_fn : Optional[callable] (default: None)
57
+ Function with signature (list[RecordDict], str) -> MetricRecord,
58
+ used to aggregate MetricRecords from training round replies.
59
+ If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
60
+ average using the provided weight factor key.
61
+ evaluate_metrics_aggr_fn : Optional[callable] (default: None)
62
+ Function with signature (list[RecordDict], str) -> MetricRecord,
63
+ used to aggregate MetricRecords from training round replies.
64
+ If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
65
+ average using the provided weight factor key.
66
+ """
32
67
 
33
68
  current_bst: Optional[bytes] = None
34
69
 
@@ -52,7 +52,7 @@ class FedXgbCyclic(FedAvg):
52
52
  Minimum number of total nodes in the system.
53
53
  weighted_by_key : str (default: "num-examples")
54
54
  The key within each MetricRecord whose value is used as the weight when
55
- computing weighted averages for both ArrayRecords and MetricRecords.
55
+ computing weighted averages for MetricRecords.
56
56
  arrayrecord_key : str (default: "arrays")
57
57
  Key used to store the ArrayRecord when constructing Messages.
58
58
  configrecord_key : str (default: "config")
@@ -20,20 +20,16 @@ Paper: proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-P
20
20
  """
21
21
 
22
22
 
23
- from collections.abc import Iterable
24
23
  from logging import INFO
25
24
  from typing import Callable, Optional
26
25
 
27
- import numpy as np
26
+ from flwr.common import MetricRecord, RecordDict, log
28
27
 
29
- from flwr.common import ArrayRecord, Message, MetricRecord, NDArray, RecordDict, log
30
-
31
- from .fedavg import FedAvg
32
- from .strategy_utils import aggregate_arrayrecords
28
+ from .multikrum import MultiKrum
33
29
 
34
30
 
35
31
  # pylint: disable=too-many-instance-attributes
36
- class Krum(FedAvg):
32
+ class Krum(MultiKrum):
37
33
  """Krum [Blanchard et al., 2017] strategy.
38
34
 
39
35
  Implementation based on https://arxiv.org/abs/1703.02757
@@ -56,12 +52,9 @@ class Krum(FedAvg):
56
52
  Minimum number of total nodes in the system.
57
53
  num_malicious_nodes : int (default: 0)
58
54
  Number of malicious nodes in the system. Defaults to 0.
59
- num_nodes_to_keep : int (default: 0)
60
- Number of nodes to keep before averaging (MultiKrum). Defaults to 0, in
61
- that case classical Krum is applied.
62
55
  weighted_by_key : str (default: "num-examples")
63
56
  The key within each MetricRecord whose value is used as the weight when
64
- computing weighted averages for both ArrayRecords and MetricRecords.
57
+ computing weighted averages for MetricRecords.
65
58
  arrayrecord_key : str (default: "arrays")
66
59
  Key used to store the ArrayRecord when constructing Messages.
67
60
  configrecord_key : str (default: "config")
@@ -87,7 +80,6 @@ class Krum(FedAvg):
87
80
  min_evaluate_nodes: int = 2,
88
81
  min_available_nodes: int = 2,
89
82
  num_malicious_nodes: int = 0,
90
- num_nodes_to_keep: int = 0,
91
83
  weighted_by_key: str = "num-examples",
92
84
  arrayrecord_key: str = "arrays",
93
85
  configrecord_key: str = "config",
@@ -105,126 +97,16 @@ class Krum(FedAvg):
105
97
  min_evaluate_nodes=min_evaluate_nodes,
106
98
  min_available_nodes=min_available_nodes,
107
99
  weighted_by_key=weighted_by_key,
100
+ num_malicious_nodes=num_malicious_nodes,
101
+ num_nodes_to_select=1, # Krum selects 1 node
108
102
  arrayrecord_key=arrayrecord_key,
109
103
  configrecord_key=configrecord_key,
110
104
  train_metrics_aggr_fn=train_metrics_aggr_fn,
111
105
  evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
112
106
  )
113
- self.num_malicious_nodes = num_malicious_nodes
114
- self.num_nodes_to_keep = num_nodes_to_keep
115
107
 
116
108
  def summary(self) -> None:
117
109
  """Log summary configuration of the strategy."""
118
110
  log(INFO, "\t├──> Krum settings:")
119
- log(INFO, "\t│\t├── Number of malicious nodes: %d", self.num_malicious_nodes)
120
- log(INFO, "\t│\t└── Number of nodes to keep: %d", self.num_nodes_to_keep)
111
+ log(INFO, "\t│\t└── Number of malicious nodes: %d", self.num_malicious_nodes)
121
112
  super().summary()
122
-
123
- def _compute_distances(self, records: list[ArrayRecord]) -> NDArray:
124
- """Compute distances between ArrayRecords.
125
-
126
- Parameters
127
- ----------
128
- records : list[ArrayRecord]
129
- A list of ArrayRecords (arrays received in replies)
130
-
131
- Returns
132
- -------
133
- NDArray
134
- A 2D array representing the distance matrix of squared distances
135
- between input ArrayRecords
136
- """
137
- flat_w = np.array(
138
- [
139
- np.concatenate(rec.to_numpy_ndarrays(), axis=None).ravel()
140
- for rec in records
141
- ]
142
- )
143
- distance_matrix = np.zeros((len(records), len(records)))
144
- for i, flat_w_i in enumerate(flat_w):
145
- for j, flat_w_j in enumerate(flat_w):
146
- delta = flat_w_i - flat_w_j
147
- norm = np.linalg.norm(delta)
148
- distance_matrix[i, j] = norm**2
149
- return distance_matrix
150
-
151
- def _krum(self, replies: list[RecordDict]) -> list[RecordDict]:
152
- """Select the set of RecordDicts to aggregate using the Krum or MultiKrum
153
- algorithm.
154
-
155
- For each node, computes the sum of squared distances to its n-f-2 closest
156
- parameter vectors, where n is the number of nodes and f is the number of
157
- malicious nodes. The node(s) with the lowest score(s) are selected for
158
- aggregation.
159
-
160
- Parameters
161
- ----------
162
- replies : list[RecordDict]
163
- List of RecordDicts, each containing an ArrayRecord representing model
164
- parameters from a client.
165
-
166
- Returns
167
- -------
168
- list[RecordDict]
169
- List of RecordDicts selected for aggregation. If `num_nodes_to_keep` > 0,
170
- returns the top `num_nodes_to_keep` RecordDicts (MultiKrum); otherwise,
171
- returns the single RecordDict with the lowest score (Krum).
172
- """
173
- # Construct list of ArrayRecord objects from replies
174
- # Recall aggregate_train first ensures replies only contain one ArrayRecord
175
- array_records = [list(reply.array_records.values())[0] for reply in replies]
176
- distance_matrix = self._compute_distances(array_records)
177
-
178
- # For each node, take the n-f-2 closest parameters vectors
179
- num_closest = max(1, len(array_records) - self.num_malicious_nodes - 2)
180
- closest_indices = []
181
- for distance in distance_matrix:
182
- closest_indices.append(
183
- np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203
184
- )
185
-
186
- # Compute the score for each node, that is the sum of the distances
187
- # of the n-f-2 closest parameters vectors
188
- scores = [
189
- np.sum(distance_matrix[i, closest_indices[i]])
190
- for i in range(len(distance_matrix))
191
- ]
192
-
193
- # Return RecordDicts that should be aggregated
194
- if self.num_nodes_to_keep > 0:
195
- # Choose to_keep nodes and return their average (MultiKrum)
196
- best_indices = np.argsort(scores)[::-1][
197
- len(scores) - self.num_nodes_to_keep :
198
- ] # noqa: E203
199
- return [replies[i] for i in best_indices]
200
-
201
- # Return the RecordDict with the ArrayRecord that minimize the score (Krum)
202
- return [replies[np.argmin(scores)]]
203
-
204
- def aggregate_train(
205
- self,
206
- server_round: int,
207
- replies: Iterable[Message],
208
- ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
209
- """Aggregate ArrayRecords and MetricRecords in the received Messages."""
210
- valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
211
-
212
- arrays, metrics = None, None
213
- if valid_replies:
214
- reply_contents = [msg.content for msg in valid_replies]
215
-
216
- # Krum
217
- replies_to_aggregate = self._krum(reply_contents)
218
-
219
- # Aggregate ArrayRecords
220
- arrays = aggregate_arrayrecords(
221
- replies_to_aggregate,
222
- self.weighted_by_key,
223
- )
224
-
225
- # Aggregate MetricRecords
226
- metrics = self.train_metrics_aggr_fn(
227
- replies_to_aggregate,
228
- self.weighted_by_key,
229
- )
230
- return arrays, metrics
@@ -0,0 +1,247 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent.
16
+
17
+ [Blanchard et al., 2017].
18
+
19
+ Paper: proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf
20
+ """
21
+
22
+
23
+ from collections.abc import Iterable
24
+ from logging import INFO
25
+ from typing import Callable, Optional, cast
26
+
27
+ import numpy as np
28
+
29
+ from flwr.common import ArrayRecord, Message, MetricRecord, NDArray, RecordDict, log
30
+
31
+ from .fedavg import FedAvg
32
+ from .strategy_utils import aggregate_arrayrecords
33
+
34
+
35
# pylint: disable=too-many-instance-attributes
class MultiKrum(FedAvg):
    """MultiKrum [Blanchard et al., 2017] strategy.

    Implementation based on https://arxiv.org/abs/1703.02757

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    num_malicious_nodes : int (default: 0)
        Number of malicious nodes in the system. Defaults to 0.
    num_nodes_to_select : int (default: 1)
        Number of nodes to select before averaging.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for both ArrayRecords and MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.

    Notes
    -----
    MultiKrum is a generalization of Krum. If `num_nodes_to_select` is set to 1,
    MultiKrum will reduce to classical Krum.
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        num_malicious_nodes: int = 0,
        num_nodes_to_select: int = 1,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
    ) -> None:
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
        )
        # Parameters of the (Multi)Krum selection step used in `aggregate_train`
        self.num_malicious_nodes = num_malicious_nodes
        self.num_nodes_to_select = num_nodes_to_select

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        log(INFO, "\t├──> MultiKrum settings:")
        log(INFO, "\t│\t├── Number of malicious nodes: %d", self.num_malicious_nodes)
        log(INFO, "\t│\t└── Number of nodes to select: %d", self.num_nodes_to_select)
        super().summary()

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        Valid replies are first filtered with the (Multi)Krum selection rule
        (`select_multikrum`). Only the selected replies are then aggregated,
        for both the ArrayRecords and the MetricRecords.

        Parameters
        ----------
        server_round : int
            The current round of federated learning (unused here).
        replies : Iterable[Message]
            Reply messages received from the training nodes.

        Returns
        -------
        tuple[Optional[ArrayRecord], Optional[MetricRecord]]
            The aggregated arrays and metrics, or `(None, None)` when there are
            no valid replies.
        """
        valid_replies, _ = self._check_and_log_replies(replies, is_train=True)

        arrays, metrics = None, None
        if valid_replies:
            reply_contents = [msg.content for msg in valid_replies]

            # Krum or MultiKrum selection
            replies_to_aggregate = select_multikrum(
                reply_contents,
                num_malicious_nodes=self.num_malicious_nodes,
                num_nodes_to_select=self.num_nodes_to_select,
            )

            # Aggregate ArrayRecords of the selected replies only
            arrays = aggregate_arrayrecords(
                replies_to_aggregate,
                self.weighted_by_key,
            )

            # Aggregate MetricRecords of the selected replies only
            metrics = self.train_metrics_aggr_fn(
                replies_to_aggregate,
                self.weighted_by_key,
            )
        return arrays, metrics
158
+
159
+
160
def compute_distances(records: list[ArrayRecord]) -> NDArray:
    """Compute squared L2 distances between ArrayRecords.

    Each ArrayRecord is flattened into a single vector. The pairwise squared
    Euclidean distances between all vectors are then computed.

    Parameters
    ----------
    records : list[ArrayRecord]
        A list of ArrayRecords (arrays received in replies). All records must
        flatten to vectors of the same length.

    Returns
    -------
    NDArray
        A 2D ``(n, n)`` float64 array of squared L2 distances between the input
        ArrayRecords. All entries are non-negative and the diagonal is exactly
        zero. An empty ``(0, 0)`` array is returned for an empty input.
    """
    if not records:
        return np.zeros((0, 0), dtype=np.float64)

    # Formula: ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y
    # Flatten records and stack them into a matrix of shape (n, d), with n the
    # number of records and d the dimension of the model. Promote to float64:
    # the formula subtracts large, nearly equal quantities. With low-precision
    # (e.g. float32) or integer weights, that would lose most of the signal
    # or overflow.
    flat_w = np.stack(
        [
            np.concatenate(rec.to_numpy_ndarrays(), axis=None)
            .ravel()
            .astype(np.float64, copy=False)
            for rec in records
        ],
        axis=0,
    )

    # Compute squared norms of each vector
    norms: NDArray = np.square(flat_w).sum(axis=1)  # shape (n,)

    # Use broadcasting to compute pairwise distances
    distance_matrix: NDArray = norms[:, None] + norms[None, :] - 2 * flat_w @ flat_w.T

    # Rounding in the formula above can leave tiny negative entries and a
    # non-zero diagonal. Squared distances are non-negative, and d(x, x) == 0.
    np.maximum(distance_matrix, 0.0, out=distance_matrix)
    np.fill_diagonal(distance_matrix, 0.0)
    return distance_matrix
187
+
188
+
189
def select_multikrum(
    contents: list[RecordDict],
    num_malicious_nodes: int,
    num_nodes_to_select: int,
) -> list[RecordDict]:
    """Select the set of RecordDicts to aggregate using the Krum or MultiKrum algorithm.

    For each node, computes the sum of squared L2 distances to its n-f-2 closest
    *other* parameter vectors, where n is the number of nodes and f is the number
    of malicious nodes. The node(s) with the lowest score(s) are selected for
    aggregation.

    Parameters
    ----------
    contents : list[RecordDict]
        List of contents from reply messages, where each content is a RecordDict
        containing an ArrayRecord of model parameters from a node (client).
    num_malicious_nodes : int
        Number of malicious nodes in the system.
    num_nodes_to_select : int
        Number of client updates to select.
        - If 1, the algorithm reduces to classical Krum (selecting a single update).
        - If >1, Multi-Krum is applied (selecting multiple updates).

    Returns
    -------
    list[RecordDict]
        Selected contents following the Krum or Multi-Krum algorithm, ordered
        from lowest to highest score. Empty if `contents` is empty.

    Notes
    -----
    If `num_nodes_to_select` is set to 1, Multi-Krum reduces to classical Krum
    and only a single RecordDict is selected. The robustness guarantees of the
    paper assume n > 2f + 2. This condition is not enforced here.
    """
    if not contents:
        return []

    # Construct list of ArrayRecord objects from replies.
    # Recall aggregate_train first ensures replies only contain one ArrayRecord,
    # so every reply is read under the key found in the first one.
    record_key = list(contents[0].array_records.keys())[0]
    array_records = [cast(ArrayRecord, reply[record_key]) for reply in contents]
    distance_matrix = compute_distances(array_records)
    num_nodes = len(array_records)

    # For each node, take the n-f-2 closest parameter vectors of the *other*
    # nodes (at least one, but never more than the n-1 other nodes that exist).
    num_closest = min(max(1, num_nodes - num_malicious_nodes - 2), num_nodes - 1)

    # Exclude each node's own distance explicitly rather than assuming it sorts
    # first. With duplicate updates or floating-point error, another node can
    # tie with or beat the self-distance. That would drop a real neighbour from
    # the score and count the node itself instead.
    others = np.asarray(distance_matrix, dtype=np.float64).copy()
    np.fill_diagonal(others, np.inf)

    # Score of each node: sum of the distances to its num_closest neighbours
    scores = np.sort(others, axis=1)[:, :num_closest].sum(axis=1)

    # Choose the num_nodes_to_select lowest-scoring nodes (MultiKrum) and return
    # their updates. A stable sort breaks score ties in favour of earlier replies.
    best_indices = np.argsort(scores, kind="stable")[:num_nodes_to_select]
    return [contents[i] for i in best_indices]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: flwr-nightly
3
- Version: 1.22.0.dev20250919
3
+ Version: 1.23.0.dev20250921
4
4
  Summary: Flower: A Friendly Federated AI Framework
5
5
  License: Apache-2.0
6
6
  Keywords: Artificial Intelligence,Federated AI,Federated Analytics,Federated Evaluation,Federated Learning,Flower,Machine Learning
@@ -69,17 +69,17 @@ flwr/cli/new/templates/app/code/task.sklearn.py.tpl,sha256=vHdhtMp0FHxbYafXyhDT9
69
69
  flwr/cli/new/templates/app/code/task.tensorflow.py.tpl,sha256=impgWN7MfztmcWF4xh1llcZGsgTvrb1HD5ZE0t-8U08,1731
70
70
  flwr/cli/new/templates/app/code/task.xgboost.py.tpl,sha256=0xO8jQvrHuB1llVDopQPOmt5Hn6rBw8umzoNwiZZs-o,2135
71
71
  flwr/cli/new/templates/app/code/utils.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
72
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl,sha256=mPIMPfneVYV03l8jWRgWZ0V5Kh_pJw-AMUvkhcKkmL8,3182
73
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl,sha256=wqYW4bWcf12m0U2njR995lySSesFvnHB-eSkPWz-QdM,2501
74
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl,sha256=xHGF38i7oFpvnFvqfqLdtc08CkHRYsenbLz3q1dhCXk,2020
75
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl,sha256=fdDhwmPoMirJ095cU_vFCBf0ILQlAoa1fdnHb2LM1yk,1471
76
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl,sha256=PAjPT2v06sBZxacNiyMJloDwocCK5tFcGQmMXOoBqc8,1542
77
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl,sha256=Kb_O2iQfzwc6FTy3fWqtQYc3FwY6x9SUgQPGqZR_ILg,1409
78
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=SE4H23OFkQbqNU64nYf38igqrT4cJGA7XxEtSnNxJqg,1490
79
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl,sha256=docQbs3MuRR-yT24lVz7N2sQL3Sj49EHuOCuRj_0djQ,1508
80
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=apauU_PUmLEbt2rjckKniEbzdRs1EnMri_qgtHtBJZ8,1484
81
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl,sha256=LQpDKJTEnRKj5Ygn5FkT44SxlnLVprkPlbrGaFf5Q50,1508
82
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl,sha256=504pHibNRGFe-DLnzqHLYhKeF_n8BPMv0Xog5EfnZ0M,1661
72
+ flwr/cli/new/templates/app/pyproject.baseline.toml.tpl,sha256=7t57rL9zAjCC0Vd1T2LMkPXvl2x0V0iVDTap4oC66LY,3182
73
+ flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl,sha256=iwUlPEC2DsDuIuY3FZmEDxqKkUqcWTLFZXxObFGAXEU,2501
74
+ flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl,sha256=PiQemjUmmLWCdjiZylXFTZer7VhqxqT4_BXbDYw1uh4,2020
75
+ flwr/cli/new/templates/app/pyproject.jax.toml.tpl,sha256=S6PvrDW1JPGnVs6Nfn2e6GHg0taZKt3bTmM-5KKyUpo,1471
76
+ flwr/cli/new/templates/app/pyproject.mlx.toml.tpl,sha256=K00iuPOuqJBGgNE-eFiv5-bahM6PwKtf-6g7lH7FBx8,1542
77
+ flwr/cli/new/templates/app/pyproject.numpy.toml.tpl,sha256=1ptnw1a7wcb3UxbDVquudZ79Z6oTbJl6q33KAGq0cEo,1409
78
+ flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=TEqf7kWd4tsN3TbNNosN0o2DAWn4W2OFHqipCf-5gCU,1490
79
+ flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl,sha256=ZWW31QEdrqIvIIWn0UFEYUkgufgDnWQtMSgDGOnWgv8,1508
80
+ flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=Elh5bUTVBNRG1aDTrgQcZgjfrXcIaVXH00UduNJDZ2U,1484
81
+ flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl,sha256=g7SYiAJr6Uhwg4ZsQI2qsLVeizU8vd0CzR0_jka99_A,1508
82
+ flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl,sha256=yAJ9jL2q6U_XXYwwUT9fpIqIKFuQR_Kgg82GpWfQ5J8,1661
83
83
  flwr/cli/pull.py,sha256=dHiMe6x8w8yRoFNKpjA-eiPD6eFiHz4Vah5HZrqNpuo,3364
84
84
  flwr/cli/run/__init__.py,sha256=RPyB7KbYTFl6YRiilCch6oezxrLQrl1kijV7BMGkLbA,790
85
85
  flwr/cli/run/run.py,sha256=ECa0kup9dn15O70H74QdgUsEaeErbzDqVX_U0zZO5IM,8173
@@ -113,8 +113,9 @@ flwr/client/rest_client/connection.py,sha256=fyiS1aXTv71jWczx7mSco94LYJTBXgTF-p2
113
113
  flwr/client/run_info_store.py,sha256=MaJ3UQ-07hWtK67wnWu0zR29jrk0fsfgJX506dvEOfE,4042
114
114
  flwr/client/typing.py,sha256=Jw3rawDzI_-ZDcRmEQcs5gZModY7oeQlEeltYsdOhlU,1048
115
115
  flwr/clientapp/__init__.py,sha256=uoTjvIynfGvMhsmc7iYK-5qJ0AdKKjCbx7WTKc0KeSk,828
116
- flwr/clientapp/mod/__init__.py,sha256=LZveV-U2YZVEH4FyAZMXMT-aD1Hc0I5miIkDr7p6uQQ,970
117
- flwr/clientapp/mod/centraldp_mods.py,sha256=a-F-ELs3lt_wtmLl8900ExJiIY792cPCrmwmJKRrerI,8950
116
+ flwr/clientapp/mod/__init__.py,sha256=w18mDPUmjlm5P_er2GJ7l6Zbh8qyiHfMKoVqXxaTxdE,1024
117
+ flwr/clientapp/mod/centraldp_mods.py,sha256=bL2NMxbin9nj2OYSHbuOvIPk0rKH6xyAxG-pbtP9ISY,8954
118
+ flwr/clientapp/mod/localdp_mod.py,sha256=L3kQK3kjc-vwlhJcBwFIK6EiwBYlpHQkv4jiR7K9abQ,5642
118
119
  flwr/clientapp/typing.py,sha256=x1GvXWy112RqZh27liJqz-yZ7SSCOwiOSmAQsUxk9MY,853
119
120
  flwr/common/__init__.py,sha256=5GCLVk399Az_rTJHNticRlL0Sl_oPw_j5_LuFKfX7-M,4171
120
121
  flwr/common/address.py,sha256=9JucdTwlc-jpeJkRKeUboZoacUtErwSVtnDR9kAtLqE,4119
@@ -337,21 +338,23 @@ flwr/server/workflow/secure_aggregation/secagg_workflow.py,sha256=b_pKk7gmbahwyj
337
338
  flwr/server/workflow/secure_aggregation/secaggplus_workflow.py,sha256=DkayCsnlAya6Y2PZsueLgoUCMRtV-GbnW08RfWx_SXM,29460
338
339
  flwr/serverapp/__init__.py,sha256=ZujKNXULwhWYQhFnxOOT5Wi9MRq2JCWFhAAj7ouiQ78,884
339
340
  flwr/serverapp/exception.py,sha256=5cuH-2AafvihzosWDdDjuMmHdDqZ1XxHvCqZXNBVklw,1334
340
- flwr/serverapp/strategy/__init__.py,sha256=QQFa0uMXWrSCTVbd7Ixk_48U6o3K-g4nLYYJUhEVbfo,1877
341
+ flwr/serverapp/strategy/__init__.py,sha256=dezK2TKSffjjBVXW18ATRxJLTuQ7I2M1dPuNi5y-_6c,1968
342
+ flwr/serverapp/strategy/bulyan.py,sha256=lSFOZof5NDzDFS6206rp6V_08LsitSHIHOOckcLt4_E,9306
341
343
  flwr/serverapp/strategy/dp_adaptive_clipping.py,sha256=mssiVGMgfJw8DeP6_pBSZUKWmaXvYeG-B-p7RSt2tAU,13600
342
344
  flwr/serverapp/strategy/dp_fixed_clipping.py,sha256=C_faT0ggzeUB2bGv_r1Vss-fv7-UhDrpmfiHATESI0w,12832
343
345
  flwr/serverapp/strategy/fedadagrad.py,sha256=faFsuKZziPTCLeNrJOyKbPTNo-1xrIZOz7SWT5rdjJs,6269
344
346
  flwr/serverapp/strategy/fedadam.py,sha256=NsY_V6TGFAfCeA9vmqaLpvB_T5siJEtKozKGdxJssAI,7064
345
347
  flwr/serverapp/strategy/fedavg.py,sha256=Bq_nlmngzJbjqX1fF1mevXGVN6-pwglHv-6yNrs6lkA,12035
346
348
  flwr/serverapp/strategy/fedavgm.py,sha256=FtFmBGLzuUQ_7JWk85Xh19d8sP0YDwqczGTliGzZyGs,8333
347
- flwr/serverapp/strategy/fedmedian.py,sha256=b31Dk0LQBbQxi_f-jeSbWHI7iOBugcuBSN2Az-_a75E,2596
349
+ flwr/serverapp/strategy/fedmedian.py,sha256=yhGg6WGWYEbn3oYMfnCBm1F7v9u5LHYVsSAYvdI9Pns,4498
348
350
  flwr/serverapp/strategy/fedopt.py,sha256=kqT0uV2IUE93O72XEVa1JJo61dcwbZEoT9KmYTjR2tE,8477
349
351
  flwr/serverapp/strategy/fedprox.py,sha256=J1KrcE5DFko6i4608iICv1G0t9MPXspjibPd-SF_HT8,7028
350
352
  flwr/serverapp/strategy/fedtrimmedavg.py,sha256=58xDPc_YO41QM8jXn0gZ79PFzO8zo3Mh3UlkF0UBbIA,7168
351
- flwr/serverapp/strategy/fedxgb_bagging.py,sha256=ktDjzov4y0BRecioq788umCEtcuwElou9olBizQKOnM,3282
352
- flwr/serverapp/strategy/fedxgb_cyclic.py,sha256=8H8WoLdG4Fy1_dtLLE4AYiidC-Cvaw2GxySfzAb7Xj0,8774
353
+ flwr/serverapp/strategy/fedxgb_bagging.py,sha256=QAnsQXTE1qzj1qGkZ8mCOfmSKJ2pO_gG9YgmJH-EV6s,5189
354
+ flwr/serverapp/strategy/fedxgb_cyclic.py,sha256=yedhRJ7dEz-Yi5yEiS9zki_LKHPHAk962PWCYLLDLoY,8752
353
355
  flwr/serverapp/strategy/fedyogi.py,sha256=Y9RFBQaNch3fPgGXF7OfnTH6eOpavZxpMWxWVIC9_SY,6579
354
- flwr/serverapp/strategy/krum.py,sha256=iUM7MFCKQcSqUO3Eu4JKWnMc8NV0WMQW9dZXm4onQ-s,9490
356
+ flwr/serverapp/strategy/krum.py,sha256=kcy8TQuJC_qkYJ7jtn26lx2VYYBHlQC_LjNWOuCw9ZQ,4848
357
+ flwr/serverapp/strategy/multikrum.py,sha256=bv44TzMQHh1NEp3ilt0ANoR5xHJ4kpb0qxzSzXD_lx0,9865
355
358
  flwr/serverapp/strategy/qfedavg.py,sha256=EM1tO_ovkybOBeW-h1PYX0lszCUAVHT6hUpwXykAEps,10204
356
359
  flwr/serverapp/strategy/result.py,sha256=E0Hl2VLnZAgQJjE2GDoKsK7JX-kPPU2KXc47Axt6hGw,4295
357
360
  flwr/serverapp/strategy/strategy.py,sha256=8uJGGm1ROLZERQ_dkRS7Z_rs-yK6XCE0UxXtIdFiEWk,10789
@@ -419,7 +422,7 @@ flwr/supernode/servicer/__init__.py,sha256=lucTzre5WPK7G1YLCfaqg3rbFWdNSb7ZTt-ca
419
422
  flwr/supernode/servicer/clientappio/__init__.py,sha256=7Oy62Y_oijqF7Dxi6tpcUQyOpLc_QpIRZ83NvwmB0Yg,813
420
423
  flwr/supernode/servicer/clientappio/clientappio_servicer.py,sha256=nIHRu38EWK-rpNOkcgBRAAKwYQQWFeCwu0lkO7OPZGQ,10239
421
424
  flwr/supernode/start_client_internal.py,sha256=Y9S1-QlO2WP6eo4JvWzIpfaCoh2aoE7bjEYyxNNnlyg,20777
422
- flwr_nightly-1.22.0.dev20250919.dist-info/METADATA,sha256=RuNCBGYrX4Df-vdIVKJ6Do1uDfeKomIIJ557ZlE8uLI,14559
423
- flwr_nightly-1.22.0.dev20250919.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
424
- flwr_nightly-1.22.0.dev20250919.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
425
- flwr_nightly-1.22.0.dev20250919.dist-info/RECORD,,
425
+ flwr_nightly-1.23.0.dev20250921.dist-info/METADATA,sha256=zNvKQmiC8a35nT_ngTXES5yhX0VkHdKKjvaFIMRr4QM,14559
426
+ flwr_nightly-1.23.0.dev20250921.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
427
+ flwr_nightly-1.23.0.dev20250921.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
428
+ flwr_nightly-1.23.0.dev20250921.dist-info/RECORD,,