flwr-nightly 1.22.0.dev20250916__py3-none-any.whl → 1.22.0.dev20250918__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +2 -0
- flwr/cli/new/new.py +4 -2
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/utils.py +17 -0
- flwr/common/constant.py +2 -0
- flwr/common/exit/exit_code.py +4 -0
- flwr/proto/control_pb2.py +7 -3
- flwr/proto/control_pb2.pyi +24 -0
- flwr/proto/control_pb2_grpc.py +34 -0
- flwr/proto/control_pb2_grpc.pyi +13 -0
- flwr/server/app.py +13 -0
- flwr/serverapp/strategy/__init__.py +4 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/run_simulation.py +25 -30
- flwr/supercore/cli/flower_superexec.py +26 -1
- flwr/supercore/constant.py +19 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
- flwr/supercore/superexec/run_superexec.py +16 -2
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/superlink/servicer/control/control_grpc.py +3 -0
- flwr/superlink/servicer/control/control_servicer.py +59 -2
- {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/METADATA +1 -1
- {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/RECORD +42 -33
- flwr/serverapp/strategy/strategy_utils_tests.py +0 -323
- {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/entry_points.txt +0 -0
flwr/proto/control_pb2_grpc.py
CHANGED
@@ -44,6 +44,11 @@ class ControlStub(object):
|
|
44
44
|
request_serializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensRequest.SerializeToString,
|
45
45
|
response_deserializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.FromString,
|
46
46
|
)
|
47
|
+
self.PullArtifacts = channel.unary_unary(
|
48
|
+
'/flwr.proto.Control/PullArtifacts',
|
49
|
+
request_serializer=flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.SerializeToString,
|
50
|
+
response_deserializer=flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.FromString,
|
51
|
+
)
|
47
52
|
|
48
53
|
|
49
54
|
class ControlServicer(object):
|
@@ -91,6 +96,13 @@ class ControlServicer(object):
|
|
91
96
|
context.set_details('Method not implemented!')
|
92
97
|
raise NotImplementedError('Method not implemented!')
|
93
98
|
|
99
|
+
def PullArtifacts(self, request, context):
    """Pull artifacts generated during a run (flwr pull).

    Default implementation of the generated servicer: it does not handle
    the request. It marks the call as ``UNIMPLEMENTED`` on the gRPC
    context, sets the details string, and then raises
    ``NotImplementedError``.
    """
    # Report the status to the caller through the gRPC context first
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
    context.set_details('Method not implemented!')
    raise NotImplementedError('Method not implemented!')
|
105
|
+
|
94
106
|
|
95
107
|
def add_ControlServicer_to_server(servicer, server):
|
96
108
|
rpc_method_handlers = {
|
@@ -124,6 +136,11 @@ def add_ControlServicer_to_server(servicer, server):
|
|
124
136
|
request_deserializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensRequest.FromString,
|
125
137
|
response_serializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.SerializeToString,
|
126
138
|
),
|
139
|
+
'PullArtifacts': grpc.unary_unary_rpc_method_handler(
|
140
|
+
servicer.PullArtifacts,
|
141
|
+
request_deserializer=flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.FromString,
|
142
|
+
response_serializer=flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.SerializeToString,
|
143
|
+
),
|
127
144
|
}
|
128
145
|
generic_handler = grpc.method_handlers_generic_handler(
|
129
146
|
'flwr.proto.Control', rpc_method_handlers)
|
@@ -235,3 +252,20 @@ class Control(object):
|
|
235
252
|
flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.FromString,
|
236
253
|
options, channel_credentials,
|
237
254
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
255
|
+
|
256
|
+
@staticmethod
def PullArtifacts(request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None):
    """Invoke the ``/flwr.proto.Control/PullArtifacts`` RPC on ``target``.

    Thin wrapper around ``grpc.experimental.unary_unary``: the request is
    serialized with ``PullArtifactsRequest.SerializeToString`` and the
    reply is parsed with ``PullArtifactsResponse.FromString``. All other
    arguments are forwarded unchanged to ``grpc.experimental.unary_unary``.
    """
    return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/PullArtifacts',
        flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.SerializeToString,
        flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.FromString,
        options, channel_credentials,
        insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
flwr/proto/control_pb2_grpc.pyi
CHANGED
@@ -39,6 +39,11 @@ class ControlStub:
|
|
39
39
|
flwr.proto.control_pb2.GetAuthTokensResponse]
|
40
40
|
"""Get auth tokens upon request"""
|
41
41
|
|
42
|
+
PullArtifacts: grpc.UnaryUnaryMultiCallable[
|
43
|
+
flwr.proto.control_pb2.PullArtifactsRequest,
|
44
|
+
flwr.proto.control_pb2.PullArtifactsResponse]
|
45
|
+
"""Pull artifacts generated during a run (flwr pull)"""
|
46
|
+
|
42
47
|
|
43
48
|
class ControlServicer(metaclass=abc.ABCMeta):
|
44
49
|
@abc.abstractmethod
|
@@ -89,5 +94,13 @@ class ControlServicer(metaclass=abc.ABCMeta):
|
|
89
94
|
"""Get auth tokens upon request"""
|
90
95
|
pass
|
91
96
|
|
97
|
+
@abc.abstractmethod
def PullArtifacts(self,
    request: flwr.proto.control_pb2.PullArtifactsRequest,
    context: grpc.ServicerContext,
) -> flwr.proto.control_pb2.PullArtifactsResponse:
    """Pull artifacts generated during a run (flwr pull).

    Declared abstract in this type stub: it takes a
    ``PullArtifactsRequest`` plus the gRPC servicer context and returns a
    ``PullArtifactsResponse``.
    """
    pass
|
104
|
+
|
92
105
|
|
93
106
|
def add_ControlServicer_to_server(servicer: ControlServicer, server: grpc.Server) -> None: ...
|
flwr/server/app.py
CHANGED
@@ -71,6 +71,7 @@ from flwr.server.fleet_event_log_interceptor import FleetEventLogInterceptor
|
|
71
71
|
from flwr.supercore.ffs import FfsFactory
|
72
72
|
from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_no_tls
|
73
73
|
from flwr.supercore.object_store import ObjectStoreFactory
|
74
|
+
from flwr.superlink.artifact_provider import ArtifactProvider
|
74
75
|
from flwr.superlink.servicer.control import run_control_api_grpc
|
75
76
|
|
76
77
|
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
@@ -91,6 +92,7 @@ try:
|
|
91
92
|
get_control_auth_plugins,
|
92
93
|
get_control_authz_plugins,
|
93
94
|
get_control_event_log_writer_plugins,
|
95
|
+
get_ee_artifact_provider,
|
94
96
|
get_fleet_event_log_writer_plugins,
|
95
97
|
)
|
96
98
|
except ImportError:
|
@@ -113,6 +115,10 @@ except ImportError:
|
|
113
115
|
"No event log writer plugins are currently supported."
|
114
116
|
)
|
115
117
|
|
118
|
+
def get_ee_artifact_provider(config_path: str) -> ArtifactProvider:
    """Return the EE artifact provider.

    This definition never returns a provider; it always raises.
    NOTE(review): it appears to be the fallback defined under
    ``except ImportError:`` for when the real implementation cannot be
    imported - confirm against the enclosing ``try`` block.

    Parameters
    ----------
    config_path : str
        Artifact provider configuration value supplied by the caller
        (unused by this definition).

    Raises
    ------
    NotImplementedError
        Always, because no artifact provider is supported here.
    """
    raise NotImplementedError("No artifact provider is currently supported.")
|
121
|
+
|
116
122
|
def get_fleet_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
|
117
123
|
"""Return all Fleet API event log writer plugins."""
|
118
124
|
raise NotImplementedError(
|
@@ -199,6 +205,12 @@ def run_superlink() -> None:
|
|
199
205
|
if args.enable_event_log:
|
200
206
|
event_log_plugin = _try_obtain_control_event_log_writer_plugin()
|
201
207
|
|
208
|
+
# Load artifact provider if the args.artifact_provider_config is provided
|
209
|
+
artifact_provider = None
|
210
|
+
if cfg_path := getattr(args, "artifact_provider_config", None):
|
211
|
+
log(WARN, "The `--artifact-provider-config` flag is highly experimental.")
|
212
|
+
artifact_provider = get_ee_artifact_provider(cfg_path)
|
213
|
+
|
202
214
|
# Initialize StateFactory
|
203
215
|
state_factory = LinkStateFactory(args.database)
|
204
216
|
|
@@ -220,6 +232,7 @@ def run_superlink() -> None:
|
|
220
232
|
auth_plugin=auth_plugin,
|
221
233
|
authz_plugin=authz_plugin,
|
222
234
|
event_log_plugin=event_log_plugin,
|
235
|
+
artifact_provider=artifact_provider,
|
223
236
|
)
|
224
237
|
grpc_servers = [control_server]
|
225
238
|
bckg_threads: list[threading.Thread] = []
|
@@ -24,8 +24,10 @@ from .fedadam import FedAdam
|
|
24
24
|
from .fedavg import FedAvg
|
25
25
|
from .fedavgm import FedAvgM
|
26
26
|
from .fedmedian import FedMedian
|
27
|
+
from .fedprox import FedProx
|
27
28
|
from .fedtrimmedavg import FedTrimmedAvg
|
28
29
|
from .fedxgb_bagging import FedXgbBagging
|
30
|
+
from .fedxgb_cyclic import FedXgbCyclic
|
29
31
|
from .fedyogi import FedYogi
|
30
32
|
from .result import Result
|
31
33
|
from .strategy import Strategy
|
@@ -38,8 +40,10 @@ __all__ = [
|
|
38
40
|
"FedAvg",
|
39
41
|
"FedAvgM",
|
40
42
|
"FedMedian",
|
43
|
+
"FedProx",
|
41
44
|
"FedTrimmedAvg",
|
42
45
|
"FedXgbBagging",
|
46
|
+
"FedXgbCyclic",
|
43
47
|
"FedYogi",
|
44
48
|
"Result",
|
45
49
|
"Strategy",
|
@@ -0,0 +1,174 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Federated Optimization (FedProx) [Li et al., 2018] strategy.
|
16
|
+
|
17
|
+
Paper: arxiv.org/abs/1812.06127
|
18
|
+
"""
|
19
|
+
|
20
|
+
|
21
|
+
from collections.abc import Iterable
|
22
|
+
from logging import INFO, WARN
|
23
|
+
from typing import Callable, Optional
|
24
|
+
|
25
|
+
from flwr.common import (
|
26
|
+
ArrayRecord,
|
27
|
+
ConfigRecord,
|
28
|
+
Message,
|
29
|
+
MetricRecord,
|
30
|
+
RecordDict,
|
31
|
+
log,
|
32
|
+
)
|
33
|
+
from flwr.server import Grid
|
34
|
+
|
35
|
+
from .fedavg import FedAvg
|
36
|
+
|
37
|
+
|
38
|
+
class FedProx(FedAvg):
    r"""FedProx strategy (Li et al., 2018, https://arxiv.org/abs/1812.06127).

    On the server side this strategy reuses everything from ``FedAvg``; the
    only addition is that the proximal coefficient is shipped to the clients.
    The regularization itself **MUST** be implemented by each client, which
    adds the following term to its local training loss:

    .. math::
        \frac{\mu}{2} || w - w^t ||^2

    Here $w^t$ are the global parameters received for the round and $w$ are
    the local weights being optimized.

    The coefficient is written into the ``ConfigRecord`` passed to
    ``configure_train`` under the key ``"proximal-mu"``, so a client can read
    it from the incoming message.

    PyTorch example. A plain loss such as:

    .. code:: python
        loss = criterion(net(inputs), labels)

    becomes:

    .. code:: python
        # Read the proximal coefficient sent by the strategy
        mu = msg.content["config"]["proximal-mu"]

        # Accumulate the squared L2 distance to the global parameters
        proximal_term = 0.0
        for local_weights, global_weights in zip(net.parameters(), global_params):
            proximal_term += (local_weights - global_weights).norm(2) ** 2

        # Regularized loss
        loss = criterion(net(inputs), labels) + (mu / 2) * proximal_term

    ``global_params`` is a snapshot of the model parameters taken **after**
    the received global weights were applied and **before** local training
    starts:

    .. code:: python
        global_params = copy.deepcopy(net).parameters()

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. If `min_train_nodes` exceeds
        `fraction_train * total_connected_nodes`, `min_train_nodes` nodes
        are sampled anyway.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. If `min_evaluate_nodes`
        exceeds `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` nodes are sampled anyway.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        Key within each MetricRecord whose value serves as the weight when
        computing weighted averages for both ArrayRecords and MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord used
        to aggregate MetricRecords from training round replies. If `None`,
        defaults to `aggregate_metricrecords`, a weighted average using the
        provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord used
        to aggregate MetricRecords from evaluation round replies. If `None`,
        defaults to `aggregate_metricrecords`, a weighted average using the
        provided weight factor key.
    proximal_mu : float (default: 0.0)
        Weight of the proximal term. With 0.0 the strategy is equivalent to
        FedAvg; larger values regularize more strongly, i.e. client
        parameters are kept closer to the server parameters during training.
    """

    def __init__(  # pylint: disable=R0913, R0917
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        proximal_mu: float = 0.0,
    ) -> None:
        # Every option except `proximal_mu` is forwarded to FedAvg untouched
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
        )
        self.proximal_mu = proximal_mu

        # A zero coefficient removes the proximal term entirely; tell the user
        if proximal_mu == 0.0:
            notice = (
                "FedProx initialized with `proximal_mu=0.0`. "
                "This makes the strategy equivalent to FedAvg."
            )
            log(WARN, notice)

    def summary(self) -> None:
        """Log the FedProx-specific setting, then the inherited summary."""
        log(INFO, "\t├──> FedProx settings:")
        log(INFO, "\t|\t└── Proximal mu: %s", self.proximal_mu)
        super().summary()

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training.

        Stores the proximal coefficient in ``config`` under ``"proximal-mu"``
        (the passed record is modified in place) and delegates message
        construction to ``FedAvg.configure_train``.
        """
        config["proximal-mu"] = self.proximal_mu
        messages = super().configure_train(server_round, arrays, config, grid)
        return messages
|
@@ -0,0 +1,220 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower message-based FedXgbCyclic strategy."""
|
16
|
+
|
17
|
+
|
18
|
+
from collections.abc import Iterable
|
19
|
+
from logging import INFO
|
20
|
+
from typing import Callable, Optional, cast
|
21
|
+
|
22
|
+
from flwr.common import (
|
23
|
+
ArrayRecord,
|
24
|
+
ConfigRecord,
|
25
|
+
Message,
|
26
|
+
MessageType,
|
27
|
+
MetricRecord,
|
28
|
+
RecordDict,
|
29
|
+
log,
|
30
|
+
)
|
31
|
+
from flwr.server import Grid
|
32
|
+
|
33
|
+
from .fedavg import FedAvg
|
34
|
+
from .strategy_utils import sample_nodes
|
35
|
+
|
36
|
+
|
37
|
+
# pylint: disable=line-too-long
|
38
|
+
class FedXgbCyclic(FedAvg):
    """Cyclic XGBoost strategy: exactly one node is selected per round.

    Nodes are visited in a stable order. Each node id receives a persistent
    position the first time it is seen, and round ``r`` selects the node at
    position ``(r - 1) % number_of_sampled_nodes``. After training, the
    model returned by that node becomes the new global model (there is no
    averaging of arrays).

    ``min_train_nodes`` and ``min_evaluate_nodes`` are not configurable;
    both are fixed to 2 when calling ``FedAvg.__init__``.

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Must be 1.0 or 0.0. With 1.0 all connected nodes take part in the
        training rotation; with 0.0 no training messages are produced.
    fraction_evaluate : float (default: 1.0)
        Must be 1.0 or 0.0. With 1.0 all connected nodes take part in the
        evaluation rotation; with 0.0 no evaluation messages are produced.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        Key within each MetricRecord whose value serves as the weight when
        computing weighted averages of MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord used
        to aggregate MetricRecords from training round replies. If `None`,
        defaults to `aggregate_metricrecords`, a weighted average using the
        provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord used
        to aggregate MetricRecords from evaluation round replies. If `None`,
        defaults to `aggregate_metricrecords`, a weighted average using the
        provided weight factor key.

    Raises
    ------
    ValueError
        If `fraction_train` or `fraction_evaluate` is neither 0.0 nor 1.0.
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
    ) -> None:
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=2,
            min_evaluate_nodes=2,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
        )

        # node id -> persistent 1-based position in the rotation
        self.registered_nodes: dict[int, int] = {}

        if fraction_train not in (0.0, 1.0):
            raise ValueError(
                "fraction_train can only be set to 1.0 or 0.0 for FedXgbCyclic."
            )
        if fraction_evaluate not in (0.0, 1.0):
            raise ValueError(
                "fraction_evaluate can only be set to 1.0 or 0.0 for FedXgbCyclic."
            )

    def _reorder_nodes(self, node_ids: list[int]) -> list[int]:
        """Return `node_ids` ordered by their persistent rotation position.

        A node id gets the next free position in `self.registered_nodes` the
        first time it appears and keeps it afterwards, so the relative order
        of known nodes never changes between calls.
        """
        # Register unseen nodes after the highest position handed out so far
        next_index = max(self.registered_nodes.values(), default=0) + 1
        for nid in node_ids:
            if nid not in self.registered_nodes:
                self.registered_nodes[nid] = next_index
                next_index += 1

        # Positions are unique per node id, so sorting by position already
        # gives the final order; compacting them to 1..N would not change it.
        return sorted(node_ids, key=self.registered_nodes.__getitem__)

    def _make_sampling(
        self, grid: Grid, server_round: int, configure_type: str
    ) -> list[int]:
        """Sample nodes using the Grid and pick the one for `server_round`.

        `configure_type` is the name of the calling phase. It is used in the
        log line and selects the evaluation settings when it equals
        ``"configure_evaluate"``; any other value uses the training settings.

        Returns a list holding the single selected node id, or an empty list
        if sampling returned no nodes.
        """
        # Use the settings of the phase actually being configured
        if configure_type == "configure_evaluate":
            fraction, min_nodes = self.fraction_evaluate, self.min_evaluate_nodes
        else:
            fraction, min_nodes = self.fraction_train, self.min_train_nodes

        # Sample nodes
        num_nodes = int(len(list(grid.get_node_ids())) * fraction)
        sample_size = max(num_nodes, min_nodes)
        node_ids, _ = sample_nodes(grid, self.min_available_nodes, sample_size)

        # Nothing to rotate over; also avoids a modulo by zero below
        if not node_ids:
            log(INFO, "%s: Sampled %s nodes (out of %s)", configure_type, 0, 0)
            return []

        # Re-order node_ids
        node_ids = self._reorder_nodes(node_ids)

        # Sample the clients sequentially given server_round
        sampled_idx = (server_round - 1) % len(node_ids)
        sampled_node_id = [node_ids[sampled_idx]]

        log(
            INFO,
            "%s: Sampled %s nodes (out of %s)",
            configure_type,
            len(sampled_node_id),
            len(node_ids),
        )
        return sampled_node_id

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training.

        Returns no messages when `fraction_train` is 0.0. Otherwise writes
        `server_round` into `config` under ``"server-round"`` (in place) and
        builds a TRAIN message for the single selected node.
        """
        # The constructor accepts 0.0; honour it instead of ignoring it
        if self.fraction_train == 0.0:
            return []

        # Sample one node
        sampled_node_id = self._make_sampling(grid, server_round, "configure_train")

        # Always inject current server round
        config["server-round"] = server_round

        # Construct messages
        record = RecordDict(
            {self.arrayrecord_key: arrays, self.configrecord_key: config}
        )
        return self._construct_messages(record, sampled_node_id, MessageType.TRAIN)

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        The ArrayRecord of the first valid reply is returned as-is as the new
        global model; MetricRecords go through `train_metrics_aggr_fn`. Both
        results are `None` when there is no valid reply.
        """
        valid_replies, _ = self._check_and_log_replies(replies, is_train=True)

        arrays, metrics = None, None
        if valid_replies:
            reply_contents = [msg.content for msg in valid_replies]
            array_record_key = next(iter(reply_contents[0].array_records.keys()))

            # Fetch the client model from current round as global model
            arrays = cast(ArrayRecord, reply_contents[0][array_record_key])

            # Aggregate MetricRecords
            metrics = self.train_metrics_aggr_fn(
                reply_contents,
                self.weighted_by_key,
            )
        return arrays, metrics

    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated evaluation.

        Returns no messages when `fraction_evaluate` is 0.0. Otherwise writes
        `server_round` into `config` under ``"server-round"`` (in place) and
        builds an EVALUATE message for the single selected node.
        """
        # The constructor accepts 0.0; honour it instead of ignoring it
        if self.fraction_evaluate == 0.0:
            return []

        # Sample one node
        sampled_node_id = self._make_sampling(grid, server_round, "configure_evaluate")

        # Always inject current server round
        config["server-round"] = server_round

        # Construct messages
        record = RecordDict(
            {self.arrayrecord_key: arrays, self.configrecord_key: config}
        )
        return self._construct_messages(record, sampled_node_id, MessageType.EVALUATE)
|
flwr/simulation/app.py
CHANGED
@@ -245,7 +245,7 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
245
245
|
run=run,
|
246
246
|
enable_tf_gpu_growth=enable_tf_gpu_growth,
|
247
247
|
verbose_logging=verbose,
|
248
|
-
|
248
|
+
server_app_context=context,
|
249
249
|
is_app=True,
|
250
250
|
exit_event=EventType.FLWR_SIMULATION_RUN_LEAVE,
|
251
251
|
)
|