flwr 1.21.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/new/new.py +9 -7
  3. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  4. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  5. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  10. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  11. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  12. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  13. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  14. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  15. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  16. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  17. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  18. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  19. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  20. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  21. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  22. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  23. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  24. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  25. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  26. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  27. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  28. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  29. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  30. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  31. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  32. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  33. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  34. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  35. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  36. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  37. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  38. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  39. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  40. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  41. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  42. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  43. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  46. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  47. flwr/cli/pull.py +100 -0
  48. flwr/cli/utils.py +17 -0
  49. flwr/clientapp/mod/__init__.py +4 -1
  50. flwr/clientapp/mod/centraldp_mods.py +156 -40
  51. flwr/clientapp/mod/localdp_mod.py +169 -0
  52. flwr/clientapp/typing.py +22 -0
  53. flwr/common/constant.py +3 -0
  54. flwr/common/exit/exit_code.py +4 -0
  55. flwr/common/record/typeddict.py +12 -0
  56. flwr/proto/control_pb2.py +7 -3
  57. flwr/proto/control_pb2.pyi +24 -0
  58. flwr/proto/control_pb2_grpc.py +34 -0
  59. flwr/proto/control_pb2_grpc.pyi +13 -0
  60. flwr/server/app.py +13 -0
  61. flwr/serverapp/strategy/__init__.py +26 -0
  62. flwr/serverapp/strategy/bulyan.py +238 -0
  63. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  64. flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
  65. flwr/serverapp/strategy/fedadagrad.py +0 -3
  66. flwr/serverapp/strategy/fedadam.py +0 -3
  67. flwr/serverapp/strategy/fedavg.py +89 -64
  68. flwr/serverapp/strategy/fedavgm.py +198 -0
  69. flwr/serverapp/strategy/fedmedian.py +105 -0
  70. flwr/serverapp/strategy/fedprox.py +174 -0
  71. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  72. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  73. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  74. flwr/serverapp/strategy/fedyogi.py +0 -3
  75. flwr/serverapp/strategy/krum.py +112 -0
  76. flwr/serverapp/strategy/multikrum.py +247 -0
  77. flwr/serverapp/strategy/qfedavg.py +252 -0
  78. flwr/serverapp/strategy/strategy_utils.py +48 -0
  79. flwr/simulation/app.py +1 -1
  80. flwr/simulation/run_simulation.py +25 -30
  81. flwr/supercore/cli/flower_superexec.py +26 -1
  82. flwr/supercore/constant.py +19 -0
  83. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  84. flwr/supercore/superexec/run_superexec.py +16 -2
  85. flwr/superlink/artifact_provider/__init__.py +22 -0
  86. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  87. flwr/superlink/servicer/control/control_grpc.py +3 -0
  88. flwr/superlink/servicer/control/control_servicer.py +59 -2
  89. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/METADATA +6 -16
  90. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/RECORD +93 -74
  91. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  92. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  93. flwr/serverapp/dp_fixed_clipping.py +0 -352
  94. flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
  95. /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  96. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
  97. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,252 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Fair Resource Allocation in Federated Learning [Li et al., 2020] strategy.
16
+
17
+ Paper: openreview.net/pdf?id=ByexElSYDr
18
+ """
19
+
20
+
21
+ from collections import OrderedDict
22
+ from collections.abc import Iterable
23
+ from logging import INFO
24
+ from typing import Callable, Optional
25
+
26
+ import numpy as np
27
+
28
+ from flwr.common import (
29
+ Array,
30
+ ArrayRecord,
31
+ ConfigRecord,
32
+ Message,
33
+ MetricRecord,
34
+ NDArray,
35
+ RecordDict,
36
+ )
37
+ from flwr.common.logger import log
38
+ from flwr.server import Grid
39
+
40
+ from ..exception import AggregationError
41
+ from .fedavg import FedAvg
42
+
43
+
44
class QFedAvg(FedAvg):
    """Q-FedAvg strategy.

    Implementation based on openreview.net/pdf?id=ByexElSYDr

    Parameters
    ----------
    client_learning_rate : float
        Local learning rate used by clients during training. This value is used by
        the strategy to approximate the base Lipschitz constant L, via
        L = 1 / client_learning_rate.
    q : float (default: 0.1)
        The parameter q that controls the degree of fairness of the algorithm. Please
        tune this parameter based on your use case.
        When set to 0, q-FedAvg is equivalent to FedAvg.
    train_loss_key : str (default: "train_loss")
        The key within the MetricRecord whose value is used as the training loss when
        aggregating ArrayRecords following q-FedAvg.
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    """

    def __init__(  # pylint: disable=R0913, R0917
        self,
        client_learning_rate: float,
        q: float = 0.1,
        train_loss_key: str = "train_loss",
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
    ) -> None:
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
        )
        self.q = q
        self.client_learning_rate = client_learning_rate
        self.train_loss_key = train_loss_key
        # Snapshot of the global model taken in `configure_train`; needed by
        # `aggregate_train` to compute per-client pseudo-gradients L * (w - w_k)
        self.current_arrays: Optional[ArrayRecord] = None

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        log(INFO, "\t├──> q-FedAvg settings:")
        log(INFO, "\t│\t├── client_learning_rate: %s", self.client_learning_rate)
        log(INFO, "\t│\t├── q: %s", self.q)
        log(INFO, "\t│\t└── train_loss_key: '%s'", self.train_loss_key)
        super().summary()

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training."""
        # Keep a copy of the global arrays so `aggregate_train` can compute
        # deltas against the model that was actually sent out this round
        self.current_arrays = arrays.copy()
        return super().configure_train(server_round, arrays, config, grid)

    def aggregate_train(  # pylint: disable=too-many-locals
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages."""
        # Validate replies and keep only those passing consistency checks
        valid_replies, _ = self._check_and_log_replies(replies, is_train=True)

        if not valid_replies:
            return None, None

        # Compute estimate of Lipschitz constant L
        L = 1.0 / self.client_learning_rate  # pylint: disable=C0103

        # q-FedAvg aggregation
        if self.current_arrays is None:
            raise AggregationError(
                "Current global model weights are not available. Make sure to call "
                "`configure_train` before calling `aggregate_train`."
            )
        array_keys = list(self.current_arrays.keys())  # Preserve keys
        global_weights = self.current_arrays.to_numpy_ndarrays(keep_input=False)
        sum_delta = None
        sum_h = 0.0

        for msg in valid_replies:
            # Extract local weights and training loss from Message
            local_weights = get_local_weights(msg)
            loss = get_train_loss(msg, self.train_loss_key)

            # Compute delta and h
            delta, h = compute_delta_and_h(
                global_weights, local_weights, self.q, L, loss
            )

            # Compute sum of deltas and sum of h
            if sum_delta is None:
                sum_delta = delta
            else:
                sum_delta = [sd + d for sd, d in zip(sum_delta, delta)]
            sum_h += h

        # Compute new global weights and convert to Array type
        # `np.asarray` can convert numpy scalars to 0-dim arrays
        assert sum_delta is not None  # Make mypy happy
        array_list = [
            Array(np.asarray(gw - (d / sum_h)))
            for gw, d in zip(global_weights, sum_delta)
        ]

        # Aggregate MetricRecords
        metrics = self.train_metrics_aggr_fn(
            [msg.content for msg in valid_replies],
            self.weighted_by_key,
        )
        return ArrayRecord(OrderedDict(zip(array_keys, array_list))), metrics
204
+
205
+
206
def get_train_loss(msg: Message, loss_key: str) -> float:
    """Return the training loss carried in a reply ``Message``.

    Reads the first MetricRecord of the message content and looks up
    ``loss_key``; raises ``AggregationError`` when the key is absent or does
    not map to a numeric value.
    """
    # Use the first (and expected only) MetricRecord in the reply content
    records = msg.content.metric_records
    metrics = list(records.values())[0]
    loss = metrics.get(loss_key)
    if loss is None or not isinstance(loss, (int, float)):
        raise AggregationError(
            "Missing or invalid training loss. "
            f"The strategy expected a float value for the key '{loss_key}' "
            "as the training loss in each MetricRecord from the clients. "
            f"Ensure that '{loss_key}' is present and maps to a valid float."
        )
    return float(loss)
217
+
218
+
219
def get_local_weights(msg: Message) -> list[NDArray]:
    """Return the model arrays carried in a reply ``Message`` as ndarrays."""
    # Use the first (and expected only) ArrayRecord in the reply content
    record_list = list(msg.content.array_records.values())
    first_record = record_list[0]
    # keep_input=False releases the serialized buffers during conversion
    return first_record.to_numpy_ndarrays(keep_input=False)
223
+
224
+
225
def l2_norm(ndarrays: list[NDArray]) -> float:
    """Return the squared L2 norm over all entries of the given arrays."""
    total = 0.0
    for arr in ndarrays:
        # Add the sum of squared entries of this array
        total += np.sum(np.square(arr))
    return float(total)
228
+
229
+
230
def compute_delta_and_h(
    global_weights: list[NDArray],
    local_weights: list[NDArray],
    q: float,
    L: float,  # Lipschitz constant # pylint: disable=C0103
    loss: float,
) -> tuple[list[NDArray], float]:
    """Compute delta and h used in q-FedAvg aggregation.

    NOTE: ``local_weights`` is overwritten in place and reused as the returned
    ``delta`` to avoid allocating intermediate arrays — do not use it after
    calling this function. ``global_weights`` is left unmodified.

    Parameters
    ----------
    global_weights : list[NDArray]
        Current global model weights ``w``.
    local_weights : list[NDArray]
        Client-updated weights ``w_k`` (consumed in place).
    q : float
        Fairness parameter of q-FedAvg.
    L : float
        Estimate of the Lipschitz constant (1 / client learning rate).
    loss : float
        Training loss reported by the client.

    Returns
    -------
    tuple[list[NDArray], float]
        ``delta_k = loss^q * L * (w - w_k)`` and the scalar ``h_k``.
    """
    # Compute gradient_k = L * (w - w_k)
    for gw, lw in zip(global_weights, local_weights):
        np.subtract(gw, lw, out=lw)  # in place: lw now holds (w - w_k)
        lw *= L
    grad = local_weights  # After in-place operations, local_weights is now grad
    # Squared L2 norm of the scaled gradient, i.e. ||L * (w - w_k)||^2
    norm = l2_norm(grad)
    # Compute delta_k = loss_k^q * gradient_k
    # (the +1e-10 guards 0^q and 0^(q-1) when the reported loss is exactly 0)
    loss_pow_q: float = np.float_power(loss + 1e-10, q)
    for g in grad:
        g *= loss_pow_q
    delta = grad  # After in-place multiplication, grad is now delta
    # h_k = q * loss^(q-1) * ||grad||^2 + L * loss^q
    h = q * np.float_power(loss + 1e-10, q - 1) * norm + L * loss_pow_q
    return delta, h
@@ -15,6 +15,7 @@
15
15
  """Flower message-based strategy utilities."""
16
16
 
17
17
 
18
+ import json
18
19
  import random
19
20
  from collections import OrderedDict
20
21
  from logging import INFO
@@ -249,3 +250,50 @@ def validate_message_reply_consistency(
249
250
  "must be a single value (int or float), but a list was found. Skipping "
250
251
  "aggregation."
251
252
  )
253
+
254
+
255
def aggregate_bagging(
    bst_prev_org: bytes,
    bst_curr_org: bytes,
) -> bytes:
    """Conduct bagging aggregation for given trees."""
    # First round: nothing to merge into, return the incoming model as-is
    if bst_prev_org == b"":
        return bst_curr_org

    # Number of trees already aggregated, and number arriving in this update
    num_trees_prev, _ = _get_tree_nums(bst_prev_org)
    _, num_parallel_curr = _get_tree_nums(bst_curr_org)

    prev_model_json = json.loads(bytearray(bst_prev_org))
    curr_model_json = json.loads(bytearray(bst_curr_org))

    merged = prev_model_json["learner"]["gradient_booster"]["model"]
    merged["gbtree_model_param"]["num_trees"] = str(
        num_trees_prev + num_parallel_curr
    )
    # Extend the iteration index by one boosting round
    merged["iteration_indptr"].append(
        merged["iteration_indptr"][-1] + num_parallel_curr
    )

    # Append the incoming trees, renumbering their ids so they follow the
    # previously aggregated ones
    incoming_trees = curr_model_json["learner"]["gradient_booster"]["model"]["trees"]
    for offset in range(num_parallel_curr):
        tree = incoming_trees[offset]
        tree["id"] = num_trees_prev + offset
        merged["trees"].append(tree)
        merged["tree_info"].append(0)

    return bytes(json.dumps(prev_model_json), "utf-8")


def _get_tree_nums(xgb_model_org: bytes) -> tuple[int, int]:
    """Return (num_trees, num_parallel_tree) read from a serialized model."""
    model_json = json.loads(bytearray(xgb_model_org))
    # Access model parameters
    params = model_json["learner"]["gradient_booster"]["model"][
        "gbtree_model_param"
    ]
    return int(params["num_trees"]), int(params["num_parallel_tree"])
flwr/simulation/app.py CHANGED
@@ -245,7 +245,7 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
245
245
  run=run,
246
246
  enable_tf_gpu_growth=enable_tf_gpu_growth,
247
247
  verbose_logging=verbose,
248
- server_app_run_config=fused_config,
248
+ server_app_context=context,
249
249
  is_app=True,
250
250
  exit_event=EventType.FLWR_SIMULATION_RUN_LEAVE,
251
251
  )
@@ -143,6 +143,15 @@ def run_simulation_from_cli() -> None:
143
143
  run = Run.create_empty(run_id)
144
144
  run.override_config = override_config
145
145
 
146
+ # Create Context
147
+ server_app_context = Context(
148
+ run_id=run_id,
149
+ node_id=0,
150
+ node_config=UserConfig(),
151
+ state=RecordDict(),
152
+ run_config=fused_config,
153
+ )
154
+
146
155
  _ = _run_simulation(
147
156
  server_app_attr=server_app_attr,
148
157
  client_app_attr=client_app_attr,
@@ -153,7 +162,7 @@ def run_simulation_from_cli() -> None:
153
162
  run=run,
154
163
  enable_tf_gpu_growth=args.enable_tf_gpu_growth,
155
164
  verbose_logging=args.verbose,
156
- server_app_run_config=fused_config,
165
+ server_app_context=server_app_context,
157
166
  is_app=True,
158
167
  exit_event=EventType.CLI_FLOWER_SIMULATION_LEAVE,
159
168
  )
@@ -241,13 +250,12 @@ def run_simulation(
241
250
  def run_serverapp_th(
242
251
  server_app_attr: Optional[str],
243
252
  server_app: Optional[ServerApp],
244
- server_app_run_config: UserConfig,
253
+ server_app_context: Context,
245
254
  grid: Grid,
246
255
  app_dir: str,
247
256
  f_stop: threading.Event,
248
257
  has_exception: threading.Event,
249
258
  enable_tf_gpu_growth: bool,
250
- run_id: int,
251
259
  ctx_queue: "Queue[Context]",
252
260
  ) -> threading.Thread:
253
261
  """Run SeverApp in a thread."""
@@ -258,7 +266,6 @@ def run_serverapp_th(
258
266
  exception_event: threading.Event,
259
267
  _grid: Grid,
260
268
  _server_app_dir: str,
261
- _server_app_run_config: UserConfig,
262
269
  _server_app_attr: Optional[str],
263
270
  _server_app: Optional[ServerApp],
264
271
  _ctx_queue: "Queue[Context]",
@@ -272,19 +279,10 @@ def run_serverapp_th(
272
279
  log(INFO, "Enabling GPU growth for Tensorflow on the server thread.")
273
280
  enable_gpu_growth()
274
281
 
275
- # Initialize Context
276
- context = Context(
277
- run_id=run_id,
278
- node_id=0,
279
- node_config={},
280
- state=RecordDict(),
281
- run_config=_server_app_run_config,
282
- )
283
-
284
282
  # Run ServerApp
285
283
  updated_context = _run(
286
284
  grid=_grid,
287
- context=context,
285
+ context=server_app_context,
288
286
  server_app_dir=_server_app_dir,
289
287
  server_app_attr=_server_app_attr,
290
288
  loaded_server_app=_server_app,
@@ -310,7 +308,6 @@ def run_serverapp_th(
310
308
  has_exception,
311
309
  grid,
312
310
  app_dir,
313
- server_app_run_config,
314
311
  server_app_attr,
315
312
  server_app,
316
313
  ctx_queue,
@@ -335,7 +332,7 @@ def _main_loop(
335
332
  client_app_attr: Optional[str] = None,
336
333
  server_app: Optional[ServerApp] = None,
337
334
  server_app_attr: Optional[str] = None,
338
- server_app_run_config: Optional[UserConfig] = None,
335
+ server_app_context: Optional[Context] = None,
339
336
  ) -> Context:
340
337
  """Start ServerApp on a separate thread, then launch Simulation Engine."""
341
338
  # Initialize StateFactory
@@ -346,13 +343,15 @@ def _main_loop(
346
343
  server_app_thread_has_exception = threading.Event()
347
344
  serverapp_th = None
348
345
  success = True
349
- updated_context = Context(
350
- run_id=run.run_id,
351
- node_id=0,
352
- node_config=UserConfig(),
353
- state=RecordDict(),
354
- run_config=UserConfig(),
355
- )
346
+ if server_app_context is None:
347
+ server_app_context = Context(
348
+ run_id=run.run_id,
349
+ node_id=0,
350
+ node_config=UserConfig(),
351
+ state=RecordDict(),
352
+ run_config=UserConfig(),
353
+ )
354
+ updated_context = server_app_context
356
355
  try:
357
356
  # Register run
358
357
  log(DEBUG, "Pre-registering run with id %s", run.run_id)
@@ -361,9 +360,6 @@ def _main_loop(
361
360
  run.running_at = run.starting_at
362
361
  state_factory.state().run_ids[run.run_id] = RunRecord(run=run) # type: ignore
363
362
 
364
- if server_app_run_config is None:
365
- server_app_run_config = {}
366
-
367
363
  # Initialize Grid
368
364
  grid = InMemoryGrid(state_factory=state_factory)
369
365
  grid.set_run(run_id=run.run_id)
@@ -373,13 +369,12 @@ def _main_loop(
373
369
  serverapp_th = run_serverapp_th(
374
370
  server_app_attr=server_app_attr,
375
371
  server_app=server_app,
376
- server_app_run_config=server_app_run_config,
372
+ server_app_context=server_app_context,
377
373
  grid=grid,
378
374
  app_dir=app_dir,
379
375
  f_stop=f_stop,
380
376
  has_exception=server_app_thread_has_exception,
381
377
  enable_tf_gpu_growth=enable_tf_gpu_growth,
382
- run_id=run.run_id,
383
378
  ctx_queue=output_context_queue,
384
379
  )
385
380
 
@@ -438,7 +433,7 @@ def _run_simulation(
438
433
  backend_config: Optional[BackendConfig] = None,
439
434
  client_app_attr: Optional[str] = None,
440
435
  server_app_attr: Optional[str] = None,
441
- server_app_run_config: Optional[UserConfig] = None,
436
+ server_app_context: Optional[Context] = None,
442
437
  app_dir: str = "",
443
438
  flwr_dir: Optional[str] = None,
444
439
  run: Optional[Run] = None,
@@ -502,7 +497,7 @@ def _run_simulation(
502
497
  client_app_attr,
503
498
  server_app,
504
499
  server_app_attr,
505
- server_app_run_config,
500
+ server_app_context,
506
501
  )
507
502
  # Detect if there is an Asyncio event loop already running.
508
503
  # If yes, disable logger propagation. In environmnets
@@ -17,7 +17,9 @@
17
17
 
18
18
  import argparse
19
19
  from logging import INFO
20
- from typing import Optional
20
+ from typing import Any, Optional
21
+
22
+ import yaml
21
23
 
22
24
  from flwr.common import EventType, event
23
25
  from flwr.common.constant import ExecPluginType
@@ -26,6 +28,7 @@ from flwr.common.logger import log
26
28
  from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub
27
29
  from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
28
30
  from flwr.proto.simulationio_pb2_grpc import SimulationIoStub
31
+ from flwr.supercore.constant import EXEC_PLUGIN_SECTION
29
32
  from flwr.supercore.grpc_health import add_args_health
30
33
  from flwr.supercore.superexec.plugin import (
31
34
  ClientAppExecPlugin,
@@ -36,6 +39,7 @@ from flwr.supercore.superexec.plugin import (
36
39
  from flwr.supercore.superexec.run_superexec import run_superexec
37
40
 
38
41
  try:
42
+ from flwr.ee import add_ee_args_superexec
39
43
  from flwr.ee.constant import ExecEePluginType
40
44
  from flwr.ee.exec_plugin import get_ee_plugin_and_stub_class
41
45
  except ImportError:
@@ -54,6 +58,10 @@ except ImportError:
54
58
  """Get the EE plugin class and stub class based on the plugin type."""
55
59
  return None
56
60
 
61
+ # pylint: disable-next=unused-argument
62
+ def add_ee_args_superexec(parser: argparse.ArgumentParser) -> None:
63
+ """Add EE-specific arguments to the parser."""
64
+
57
65
 
58
66
  def flower_superexec() -> None:
59
67
  """Run `flower-superexec` command."""
@@ -70,12 +78,28 @@ def flower_superexec() -> None:
70
78
  # Trigger telemetry event
71
79
  event(EventType.RUN_SUPEREXEC_ENTER, {"plugin_type": args.plugin_type})
72
80
 
81
+ # Load plugin config from YAML file if provided
82
+ plugin_config = None
83
+ if plugin_config_path := getattr(args, "plugin_config", None):
84
+ try:
85
+ with open(plugin_config_path, encoding="utf-8") as file:
86
+ yaml_config: Optional[dict[str, Any]] = yaml.safe_load(file)
87
+ if yaml_config is None or EXEC_PLUGIN_SECTION not in yaml_config:
88
+ raise ValueError(f"Missing '{EXEC_PLUGIN_SECTION}' section.")
89
+ plugin_config = yaml_config[EXEC_PLUGIN_SECTION]
90
+ except (FileNotFoundError, yaml.YAMLError, ValueError) as e:
91
+ flwr_exit(
92
+ ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG,
93
+ f"Failed to load plugin config from '{plugin_config_path}': {e!r}",
94
+ )
95
+
73
96
  # Get the plugin class and stub class based on the plugin type
74
97
  plugin_class, stub_class = _get_plugin_and_stub_class(args.plugin_type)
75
98
  run_superexec(
76
99
  plugin_class=plugin_class,
77
100
  stub_class=stub_class, # type: ignore
78
101
  appio_api_address=args.appio_api_address,
102
+ plugin_config=plugin_config,
79
103
  flwr_dir=args.flwr_dir,
80
104
  parent_pid=args.parent_pid,
81
105
  health_server_address=args.health_server_address,
@@ -122,6 +146,7 @@ def _parse_args() -> argparse.ArgumentParser:
122
146
  help="The PID of the parent process. When set, the process will terminate "
123
147
  "when the parent process exits.",
124
148
  )
149
+ add_ee_args_superexec(parser)
125
150
  add_args_health(parser)
126
151
  return parser
127
152
 
@@ -0,0 +1,19 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Constants for Flower infrastructure."""
16
+
17
+
18
+ # Top-level key in YAML config for exec plugin settings
19
+ EXEC_PLUGIN_SECTION = "exec_plugin"
@@ -17,7 +17,7 @@
17
17
 
18
18
  from abc import ABC, abstractmethod
19
19
  from collections.abc import Sequence
20
- from typing import Callable, Optional
20
+ from typing import Any, Callable, Optional
21
21
 
22
22
  from flwr.common.typing import Run
23
23
 
@@ -69,3 +69,13 @@ class ExecPlugin(ABC):
69
69
  The ID of the run associated with the token, used for tracking or
70
70
  logging purposes.
71
71
  """
72
+
73
    # This method is optional to implement
    def load_config(self, yaml_config: dict[str, Any]) -> None:
        """Load configuration from a YAML dictionary.

        The default implementation is a no-op; subclasses may override it to
        consume plugin-specific settings. NOTE(review): the SuperExec runner
        catches ``KeyError`` and ``ValueError`` raised here and exits with
        ``SUPEREXEC_INVALID_PLUGIN_CONFIG`` — raise one of those for invalid
        configs.

        Parameters
        ----------
        yaml_config : dict[str, Any]
            A dictionary representing the YAML configuration.
        """
@@ -17,10 +17,10 @@
17
17
 
18
18
  import time
19
19
  from logging import WARN
20
- from typing import Optional, Union
20
+ from typing import Any, Optional, Union
21
21
 
22
22
  from flwr.common.config import get_flwr_dir
23
- from flwr.common.exit import register_signal_handlers
23
+ from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
24
24
  from flwr.common.grpc import create_channel, on_channel_state_change
25
25
  from flwr.common.logger import log
26
26
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
@@ -47,6 +47,7 @@ def run_superexec( # pylint: disable=R0913,R0914,R0917
47
47
  type[ClientAppIoStub], type[ServerAppIoStub], type[SimulationIoStub]
48
48
  ],
49
49
  appio_api_address: str,
50
+ plugin_config: Optional[dict[str, Any]] = None,
50
51
  flwr_dir: Optional[str] = None,
51
52
  parent_pid: Optional[int] = None,
52
53
  health_server_address: Optional[str] = None,
@@ -61,6 +62,9 @@ def run_superexec( # pylint: disable=R0913,R0914,R0917
61
62
  The gRPC stub class for the AppIO API.
62
63
  appio_api_address : str
63
64
  The address of the AppIO API.
65
+ plugin_config : Optional[dict[str, Any]] (default: None)
66
+ The configuration dictionary for the plugin. If `None`, the plugin will use
67
+ its default configuration.
64
68
  flwr_dir : Optional[str] (default: None)
65
69
  The Flower directory.
66
70
  parent_pid : Optional[int] (default: None)
@@ -113,6 +117,16 @@ def run_superexec( # pylint: disable=R0913,R0914,R0917
113
117
  get_run=get_run,
114
118
  )
115
119
 
120
+ # Load plugin configuration from file if provided
121
+ try:
122
+ if plugin_config is not None:
123
+ plugin.load_config(plugin_config)
124
+ except (KeyError, ValueError) as e:
125
+ flwr_exit(
126
+ code=ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG,
127
+ message=f"Invalid plugin config: {e!r}",
128
+ )
129
+
116
130
  # Start the main loop
117
131
  try:
118
132
  while True:
@@ -0,0 +1,22 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ArtifactProvider for SuperLink."""
16
+
17
+
18
+ from .artifact_provider import ArtifactProvider
19
+
20
+ __all__ = [
21
+ "ArtifactProvider",
22
+ ]