flwr-nightly 1.22.0.dev20250916__py3-none-any.whl → 1.22.0.dev20250918__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/new/new.py +4 -2
  3. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  4. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  5. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  6. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  7. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  8. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  9. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  10. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  11. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  12. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +3 -3
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  16. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  17. flwr/cli/pull.py +100 -0
  18. flwr/cli/utils.py +17 -0
  19. flwr/common/constant.py +2 -0
  20. flwr/common/exit/exit_code.py +4 -0
  21. flwr/proto/control_pb2.py +7 -3
  22. flwr/proto/control_pb2.pyi +24 -0
  23. flwr/proto/control_pb2_grpc.py +34 -0
  24. flwr/proto/control_pb2_grpc.pyi +13 -0
  25. flwr/server/app.py +13 -0
  26. flwr/serverapp/strategy/__init__.py +4 -0
  27. flwr/serverapp/strategy/fedprox.py +174 -0
  28. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  29. flwr/simulation/app.py +1 -1
  30. flwr/simulation/run_simulation.py +25 -30
  31. flwr/supercore/cli/flower_superexec.py +26 -1
  32. flwr/supercore/constant.py +19 -0
  33. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  34. flwr/supercore/superexec/run_superexec.py +16 -2
  35. flwr/superlink/artifact_provider/__init__.py +22 -0
  36. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  37. flwr/superlink/servicer/control/control_grpc.py +3 -0
  38. flwr/superlink/servicer/control/control_servicer.py +59 -2
  39. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/METADATA +1 -1
  40. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/RECORD +42 -33
  41. flwr/serverapp/strategy/strategy_utils_tests.py +0 -323
  42. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/WHEEL +0 -0
  43. {flwr_nightly-1.22.0.dev20250916.dist-info → flwr_nightly-1.22.0.dev20250918.dist-info}/entry_points.txt +0 -0
flwr/proto/control_pb2_grpc.py CHANGED
@@ -44,6 +44,11 @@ class ControlStub(object):
                 request_serializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensRequest.SerializeToString,
                 response_deserializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.FromString,
                 )
+        self.PullArtifacts = channel.unary_unary(
+                '/flwr.proto.Control/PullArtifacts',
+                request_serializer=flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.SerializeToString,
+                response_deserializer=flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.FromString,
+                )
 
 
 class ControlServicer(object):
@@ -91,6 +96,13 @@ class ControlServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
+    def PullArtifacts(self, request, context):
+        """Pull artifacts generated during a run (flwr pull)
+        """
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
 
 def add_ControlServicer_to_server(servicer, server):
     rpc_method_handlers = {
@@ -124,6 +136,11 @@ def add_ControlServicer_to_server(servicer, server):
                     request_deserializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensRequest.FromString,
                     response_serializer=flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.SerializeToString,
             ),
+            'PullArtifacts': grpc.unary_unary_rpc_method_handler(
+                    servicer.PullArtifacts,
+                    request_deserializer=flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.FromString,
+                    response_serializer=flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.SerializeToString,
+            ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
             'flwr.proto.Control', rpc_method_handlers)
@@ -235,3 +252,20 @@ class Control(object):
             flwr_dot_proto_dot_control__pb2.GetAuthTokensResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def PullArtifacts(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/PullArtifacts',
+            flwr_dot_proto_dot_control__pb2.PullArtifactsRequest.SerializeToString,
+            flwr_dot_proto_dot_control__pb2.PullArtifactsResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
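With the stub in place, pulling artifacts from a SuperLink reduces to a single unary call. A minimal sketch, assuming an insecure local channel on port 9093 and leaving the request empty because the PullArtifactsRequest schema is not part of this diff:

# Sketch only: the address and the empty request are assumptions; consult
# flwr/proto/control.proto for the actual PullArtifactsRequest fields.
import grpc

from flwr.proto import control_pb2, control_pb2_grpc

with grpc.insecure_channel("127.0.0.1:9093") as channel:
    stub = control_pb2_grpc.ControlStub(channel)
    request = control_pb2.PullArtifactsRequest()  # fields omitted on purpose
    response = stub.PullArtifacts(request)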
flwr/proto/control_pb2_grpc.pyi CHANGED
@@ -39,6 +39,11 @@ class ControlStub:
         flwr.proto.control_pb2.GetAuthTokensResponse]
     """Get auth tokens upon request"""
 
+    PullArtifacts: grpc.UnaryUnaryMultiCallable[
+        flwr.proto.control_pb2.PullArtifactsRequest,
+        flwr.proto.control_pb2.PullArtifactsResponse]
+    """Pull artifacts generated during a run (flwr pull)"""
+
 
 class ControlServicer(metaclass=abc.ABCMeta):
     @abc.abstractmethod
@@ -89,5 +94,13 @@ class ControlServicer(metaclass=abc.ABCMeta):
         """Get auth tokens upon request"""
        pass
 
+    @abc.abstractmethod
+    def PullArtifacts(self,
+        request: flwr.proto.control_pb2.PullArtifactsRequest,
+        context: grpc.ServicerContext,
+    ) -> flwr.proto.control_pb2.PullArtifactsResponse:
+        """Pull artifacts generated during a run (flwr pull)"""
+        pass
+
 
 def add_ControlServicer_to_server(servicer: ControlServicer, server: grpc.Server) -> None: ...
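The type stub declares PullArtifacts as abstract for type-checkers, while the runtime base class in control_pb2_grpc.py answers UNIMPLEMENTED by default. The real handler ships in control_servicer.py (+59 -2 in this diff, not reproduced here); a minimal override that matches the declared signature could look roughly like this, with the empty response a placeholder since the message fields are not shown:

# Sketch only: the actual logic lives in
# flwr/superlink/servicer/control/control_servicer.py (+59 -2 in this diff).
import grpc

from flwr.proto import control_pb2, control_pb2_grpc


class SketchControlServicer(control_pb2_grpc.ControlServicer):
    def PullArtifacts(
        self,
        request: control_pb2.PullArtifactsRequest,
        context: grpc.ServicerContext,
    ) -> control_pb2.PullArtifactsResponse:
        """Pull artifacts generated during a run (flwr pull)."""
        return control_pb2.PullArtifactsResponse()  # placeholder response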
flwr/server/app.py CHANGED
@@ -71,6 +71,7 @@ from flwr.server.fleet_event_log_interceptor import FleetEventLogInterceptor
 from flwr.supercore.ffs import FfsFactory
 from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_no_tls
 from flwr.supercore.object_store import ObjectStoreFactory
+from flwr.superlink.artifact_provider import ArtifactProvider
 from flwr.superlink.servicer.control import run_control_api_grpc
 
 from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
@@ -91,6 +92,7 @@ try:
         get_control_auth_plugins,
         get_control_authz_plugins,
         get_control_event_log_writer_plugins,
+        get_ee_artifact_provider,
         get_fleet_event_log_writer_plugins,
     )
 except ImportError:
@@ -113,6 +115,10 @@ except ImportError:
             "No event log writer plugins are currently supported."
         )
 
+    def get_ee_artifact_provider(config_path: str) -> ArtifactProvider:
+        """Return the EE artifact provider."""
+        raise NotImplementedError("No artifact provider is currently supported.")
+
     def get_fleet_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
         """Return all Fleet API event log writer plugins."""
         raise NotImplementedError(
@@ -199,6 +205,12 @@ def run_superlink() -> None:
     if args.enable_event_log:
         event_log_plugin = _try_obtain_control_event_log_writer_plugin()
 
+    # Load artifact provider if the args.artifact_provider_config is provided
+    artifact_provider = None
+    if cfg_path := getattr(args, "artifact_provider_config", None):
+        log(WARN, "The `--artifact-provider-config` flag is highly experimental.")
+        artifact_provider = get_ee_artifact_provider(cfg_path)
+
     # Initialize StateFactory
     state_factory = LinkStateFactory(args.database)
 
@@ -220,6 +232,7 @@ def run_superlink() -> None:
         auth_plugin=auth_plugin,
         authz_plugin=authz_plugin,
         event_log_plugin=event_log_plugin,
+        artifact_provider=artifact_provider,
     )
     grpc_servers = [control_server]
     bckg_threads: list[threading.Thread] = []
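These hunks follow the optional-plugin pattern app.py already uses for auth and event-log plugins: import the enterprise (EE) helper inside the existing try block, and define a raising stub in the except ImportError branch. Reduced to its essentials (the flwr.ee module path is a stand-in, since the diff does not show the import source):

# Pattern sketch; `flwr.ee` is a stand-in for the real EE import source.
try:
    from flwr.ee import get_ee_artifact_provider
except ImportError:

    def get_ee_artifact_provider(config_path: str):
        """Fallback used when no EE build provides an artifact provider."""
        raise NotImplementedError("No artifact provider is currently supported.")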
flwr/serverapp/strategy/__init__.py CHANGED
@@ -24,8 +24,10 @@ from .fedadam import FedAdam
 from .fedavg import FedAvg
 from .fedavgm import FedAvgM
 from .fedmedian import FedMedian
+from .fedprox import FedProx
 from .fedtrimmedavg import FedTrimmedAvg
 from .fedxgb_bagging import FedXgbBagging
+from .fedxgb_cyclic import FedXgbCyclic
 from .fedyogi import FedYogi
 from .result import Result
 from .strategy import Strategy
@@ -38,8 +40,10 @@ __all__ = [
     "FedAvg",
     "FedAvgM",
     "FedMedian",
+    "FedProx",
     "FedTrimmedAvg",
     "FedXgbBagging",
+    "FedXgbCyclic",
     "FedYogi",
     "Result",
     "Strategy",
flwr/serverapp/strategy/fedprox.py ADDED
@@ -0,0 +1,174 @@
+# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Federated Optimization (FedProx) [Li et al., 2018] strategy.
+
+Paper: arxiv.org/abs/1812.06127
+"""
+
+
+from collections.abc import Iterable
+from logging import INFO, WARN
+from typing import Callable, Optional
+
+from flwr.common import (
+    ArrayRecord,
+    ConfigRecord,
+    Message,
+    MetricRecord,
+    RecordDict,
+    log,
+)
+from flwr.server import Grid
+
+from .fedavg import FedAvg
+
+
+class FedProx(FedAvg):
+    r"""Federated Optimization strategy.
+
+    Implementation based on https://arxiv.org/abs/1812.06127
+
+    FedProx extends FedAvg by introducing a proximal term into the client-side
+    optimization objective. The strategy itself behaves identically to FedAvg
+    on the server side, but each client **MUST** add a proximal regularization
+    term to its local loss function during training:
+
+    .. math::
+        \frac{\mu}{2} || w - w^t ||^2
+
+    Where $w^t$ denotes the global parameters and $w$ denotes the local weights
+    being optimized.
+
+    This strategy sends the proximal term weight inside the ``ConfigRecord`` as
+    part of the ``configure_train`` method under key ``"proximal-mu"``. The client
+    can then use this value to add the proximal term to the loss function.
+
+    In PyTorch, for example, the loss would go from:
+
+    .. code:: python
+
+        loss = criterion(net(inputs), labels)
+
+    To:
+
+    .. code:: python
+
+        # Get proximal term weight from message
+        mu = msg.content["config"]["proximal-mu"]
+
+        # Compute proximal term
+        proximal_term = 0.0
+        for local_weights, global_weights in zip(net.parameters(), global_params):
+            proximal_term += (local_weights - global_weights).norm(2)
+
+        # Update loss
+        loss = criterion(net(inputs), labels) + (mu / 2) * proximal_term
+
+    With ``global_params`` being a copy of the model parameters, created **after**
+    applying the received global weights but **before** local training begins.
+
+    .. code:: python
+
+        global_params = copy.deepcopy(net).parameters()
+
+    Parameters
+    ----------
+    fraction_train : float (default: 1.0)
+        Fraction of nodes used during training. In case `min_train_nodes`
+        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
+        will still be sampled.
+    fraction_evaluate : float (default: 1.0)
+        Fraction of nodes used during validation. In case `min_evaluate_nodes`
+        is larger than `fraction_evaluate * total_connected_nodes`,
+        `min_evaluate_nodes` will still be sampled.
+    min_train_nodes : int (default: 2)
+        Minimum number of nodes used during training.
+    min_evaluate_nodes : int (default: 2)
+        Minimum number of nodes used during validation.
+    min_available_nodes : int (default: 2)
+        Minimum number of total nodes in the system.
+    weighted_by_key : str (default: "num-examples")
+        The key within each MetricRecord whose value is used as the weight when
+        computing weighted averages for both ArrayRecords and MetricRecords.
+    arrayrecord_key : str (default: "arrays")
+        Key used to store the ArrayRecord when constructing Messages.
+    configrecord_key : str (default: "config")
+        Key used to store the ConfigRecord when constructing Messages.
+    train_metrics_aggr_fn : Optional[callable] (default: None)
+        Function with signature (list[RecordDict], str) -> MetricRecord,
+        used to aggregate MetricRecords from training round replies.
+        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+        average using the provided weight factor key.
+    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
+        Function with signature (list[RecordDict], str) -> MetricRecord,
+        used to aggregate MetricRecords from evaluation round replies.
+        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+        average using the provided weight factor key.
+    proximal_mu : float (default: 0.0)
+        The weight of the proximal term used in the optimization. 0.0 makes
+        this strategy equivalent to FedAvg, and the higher the coefficient, the more
+        regularization will be used (that is, the client parameters will need to be
+        closer to the server parameters during training).
+    """
+
+    def __init__(  # pylint: disable=R0913, R0917
+        self,
+        fraction_train: float = 1.0,
+        fraction_evaluate: float = 1.0,
+        min_train_nodes: int = 2,
+        min_evaluate_nodes: int = 2,
+        min_available_nodes: int = 2,
+        weighted_by_key: str = "num-examples",
+        arrayrecord_key: str = "arrays",
+        configrecord_key: str = "config",
+        train_metrics_aggr_fn: Optional[
+            Callable[[list[RecordDict], str], MetricRecord]
+        ] = None,
+        evaluate_metrics_aggr_fn: Optional[
+            Callable[[list[RecordDict], str], MetricRecord]
+        ] = None,
+        proximal_mu: float = 0.0,
+    ) -> None:
+        super().__init__(
+            fraction_train=fraction_train,
+            fraction_evaluate=fraction_evaluate,
+            min_train_nodes=min_train_nodes,
+            min_evaluate_nodes=min_evaluate_nodes,
+            min_available_nodes=min_available_nodes,
+            weighted_by_key=weighted_by_key,
+            arrayrecord_key=arrayrecord_key,
+            configrecord_key=configrecord_key,
+            train_metrics_aggr_fn=train_metrics_aggr_fn,
+            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
+        )
+        self.proximal_mu = proximal_mu
+
+        if self.proximal_mu == 0.0:
+            log(
+                WARN,
+                "FedProx initialized with `proximal_mu=0.0`. "
+                "This makes the strategy equivalent to FedAvg.",
+            )
+
+    def summary(self) -> None:
+        """Log summary configuration of the strategy."""
+        log(INFO, "\t├──> FedProx settings:")
+        log(INFO, "\t|\t└── Proximal mu: %s", self.proximal_mu)
+        super().summary()
+
+    def configure_train(
+        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
+    ) -> Iterable[Message]:
+        """Configure the next round of federated training."""
+        # Inject proximal term weight into config
+        config["proximal-mu"] = self.proximal_mu
+        return super().configure_train(server_round, arrays, config, grid)
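The docstring above spells out the client-side contract. Putting its pieces together, a training loop on a PyTorch client might read as follows; net, trainloader, and msg are stand-ins for the client's model, data loader, and incoming Flower Message, and the optimizer settings are arbitrary:

import torch

def train_fedprox(net, trainloader, msg, lr=0.01):
    """Sketch of FedProx client-side training (names are stand-ins)."""
    # Weight of the proximal term, injected by FedProx.configure_train
    mu = msg.content["config"]["proximal-mu"]
    # Snapshot of the global weights, taken before local training starts
    global_params = [p.detach().clone() for p in net.parameters()]
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    for inputs, labels in trainloader:
        optimizer.zero_grad()
        # Proximal term keeps local weights close to the global snapshot
        proximal_term = 0.0
        for local_w, global_w in zip(net.parameters(), global_params):
            proximal_term += (local_w - global_w).norm(2)
        loss = criterion(net(inputs), labels) + (mu / 2) * proximal_term
        loss.backward()
        optimizer.step()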
flwr/serverapp/strategy/fedxgb_cyclic.py ADDED
@@ -0,0 +1,220 @@
+# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower message-based FedXgbCyclic strategy."""
+
+
+from collections.abc import Iterable
+from logging import INFO
+from typing import Callable, Optional, cast
+
+from flwr.common import (
+    ArrayRecord,
+    ConfigRecord,
+    Message,
+    MessageType,
+    MetricRecord,
+    RecordDict,
+    log,
+)
+from flwr.server import Grid
+
+from .fedavg import FedAvg
+from .strategy_utils import sample_nodes
+
+
+# pylint: disable=line-too-long
+class FedXgbCyclic(FedAvg):
+    """Configurable FedXgbCyclic strategy implementation.
+
+    Parameters
+    ----------
+    fraction_train : float (default: 1.0)
+        Fraction of nodes used during training. In case `min_train_nodes`
+        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
+        will still be sampled.
+    fraction_evaluate : float (default: 1.0)
+        Fraction of nodes used during validation. In case `min_evaluate_nodes`
+        is larger than `fraction_evaluate * total_connected_nodes`,
+        `min_evaluate_nodes` will still be sampled.
+    min_available_nodes : int (default: 2)
+        Minimum number of total nodes in the system.
+    weighted_by_key : str (default: "num-examples")
+        The key within each MetricRecord whose value is used as the weight when
+        computing weighted averages for both ArrayRecords and MetricRecords.
+    arrayrecord_key : str (default: "arrays")
+        Key used to store the ArrayRecord when constructing Messages.
+    configrecord_key : str (default: "config")
+        Key used to store the ConfigRecord when constructing Messages.
+    train_metrics_aggr_fn : Optional[callable] (default: None)
+        Function with signature (list[RecordDict], str) -> MetricRecord,
+        used to aggregate MetricRecords from training round replies.
+        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+        average using the provided weight factor key.
+    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
+        Function with signature (list[RecordDict], str) -> MetricRecord,
+        used to aggregate MetricRecords from evaluation round replies.
+        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
+        average using the provided weight factor key.
+    """
+
+    # pylint: disable=too-many-arguments,too-many-positional-arguments
+    def __init__(
+        self,
+        fraction_train: float = 1.0,
+        fraction_evaluate: float = 1.0,
+        min_available_nodes: int = 2,
+        weighted_by_key: str = "num-examples",
+        arrayrecord_key: str = "arrays",
+        configrecord_key: str = "config",
+        train_metrics_aggr_fn: Optional[
+            Callable[[list[RecordDict], str], MetricRecord]
+        ] = None,
+        evaluate_metrics_aggr_fn: Optional[
+            Callable[[list[RecordDict], str], MetricRecord]
+        ] = None,
+    ) -> None:
+        super().__init__(
+            fraction_train=fraction_train,
+            fraction_evaluate=fraction_evaluate,
+            min_train_nodes=2,
+            min_evaluate_nodes=2,
+            min_available_nodes=min_available_nodes,
+            weighted_by_key=weighted_by_key,
+            arrayrecord_key=arrayrecord_key,
+            configrecord_key=configrecord_key,
+            train_metrics_aggr_fn=train_metrics_aggr_fn,
+            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
+        )
+
+        self.registered_nodes: dict[int, int] = {}
+
+        if fraction_train not in (0.0, 1.0):
+            raise ValueError(
+                "fraction_train can only be set to 1.0 or 0.0 for FedXgbCyclic."
+            )
+        if fraction_evaluate not in (0.0, 1.0):
+            raise ValueError(
+                "fraction_evaluate can only be set to 1.0 or 0.0 for FedXgbCyclic."
+            )
+
+    def _reorder_nodes(self, node_ids: list[int]) -> list[int]:
+        """Re-order node ids based on registered nodes.
+
+        Each node ID is assigned a persistent index in `self.registered_nodes`
+        the first time it appears. The input list is then reordered according
+        to these stored indices, and the result is compacted into ascending
+        order (1..N) for the current call.
+        """
+        # Assign new indices to unknown nodes
+        next_index = max(self.registered_nodes.values(), default=0) + 1
+        for nid in node_ids:
+            if nid not in self.registered_nodes:
+                self.registered_nodes[nid] = next_index
+                next_index += 1
+
+        # Sort node_ids by their stored indices
+        sorted_by_index = sorted(node_ids, key=lambda x: self.registered_nodes[x])
+
+        # Compact re-map of indices just for this output list
+        unique_indices = sorted(self.registered_nodes[nid] for nid in sorted_by_index)
+        remap = {old: new for new, old in enumerate(unique_indices, start=1)}
+
+        # Build the result list ordered by compact indices
+        result_list = [
+            nid
+            for _, nid in sorted(
+                (remap[self.registered_nodes[nid]], nid) for nid in sorted_by_index
+            )
+        ]
+        return result_list
+
+    def _make_sampling(
+        self, grid: Grid, server_round: int, configure_type: str
+    ) -> list[int]:
+        """Sample nodes using the Grid."""
+        # Sample nodes
+        num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_train)
+        sample_size = max(num_nodes, self.min_train_nodes)
+        node_ids, _ = sample_nodes(grid, self.min_available_nodes, sample_size)
+
+        # Re-order node_ids
+        node_ids = self._reorder_nodes(node_ids)
+
+        # Sample the clients sequentially given server_round
+        sampled_idx = (server_round - 1) % len(node_ids)
+        sampled_node_id = [node_ids[sampled_idx]]
+
+        log(
+            INFO,
+            "%s: Sampled %s nodes (out of %s)",
+            configure_type,
+            len(sampled_node_id),
+            len(node_ids),
+        )
+        return sampled_node_id
+
+    def configure_train(
+        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
+    ) -> Iterable[Message]:
+        """Configure the next round of federated training."""
+        # Sample one node
+        sampled_node_id = self._make_sampling(grid, server_round, "configure_train")
+
+        # Always inject current server round
+        config["server-round"] = server_round
+
+        # Construct messages
+        record = RecordDict(
+            {self.arrayrecord_key: arrays, self.configrecord_key: config}
+        )
+        return self._construct_messages(record, sampled_node_id, MessageType.TRAIN)
+
+    def aggregate_train(
+        self,
+        server_round: int,
+        replies: Iterable[Message],
+    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
+        """Aggregate ArrayRecords and MetricRecords in the received Messages."""
+        valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
+
+        arrays, metrics = None, None
+        if valid_replies:
+            reply_contents = [msg.content for msg in valid_replies]
+            array_record_key = next(iter(reply_contents[0].array_records.keys()))
+
+            # Fetch the client model from current round as global model
+            arrays = cast(ArrayRecord, reply_contents[0][array_record_key])
+
+            # Aggregate MetricRecords
+            metrics = self.train_metrics_aggr_fn(
+                reply_contents,
+                self.weighted_by_key,
+            )
+        return arrays, metrics
+
+    def configure_evaluate(
+        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
+    ) -> Iterable[Message]:
+        """Configure the next round of federated evaluation."""
+        # Sample one node
+        sampled_node_id = self._make_sampling(grid, server_round, "configure_evaluate")
+
+        # Always inject current server round
+        config["server-round"] = server_round
+
+        # Construct messages
+        record = RecordDict(
+            {self.arrayrecord_key: arrays, self.configrecord_key: config}
+        )
+        return self._construct_messages(record, sampled_node_id, MessageType.EVALUATE)
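_reorder_nodes gives every node a sticky position the first time it connects, so the cyclic schedule stays stable across rounds even when the sampled set changes; _make_sampling then picks index (server_round - 1) % len(node_ids) from that order. Because the compact re-map preserves relative order, the method boils down to sorting by first-seen index, as this standalone sketch of the same bookkeeping shows:

# Standalone sketch of the bookkeeping in _reorder_nodes (behaviorally
# equivalent: the compact re-map above is order-preserving).
registered_nodes: dict[int, int] = {}

def reorder_nodes(node_ids: list[int]) -> list[int]:
    # Assign a persistent index to each node the first time it appears
    next_index = max(registered_nodes.values(), default=0) + 1
    for nid in node_ids:
        if nid not in registered_nodes:
            registered_nodes[nid] = next_index
            next_index += 1
    # Reorder the current list by those stored indices
    return sorted(node_ids, key=lambda nid: registered_nodes[nid])

print(reorder_nodes([42, 7, 99]))  # first call: [42, 7, 99] (first-seen order)
print(reorder_nodes([99, 7]))      # later call: [7, 99] (positions are sticky)
# Round r then trains/evaluates the node at index (r - 1) % len(order)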
flwr/simulation/app.py CHANGED
@@ -245,7 +245,7 @@ def run_simulation_process(  # pylint: disable=R0913, R0914, R0915, R0917, W0212
         run=run,
         enable_tf_gpu_growth=enable_tf_gpu_growth,
         verbose_logging=verbose,
-        server_app_run_config=fused_config,
+        server_app_context=context,
         is_app=True,
         exit_event=EventType.FLWR_SIMULATION_RUN_LEAVE,
     )