flwr-nightly 1.22.0.dev20250917__py3-none-any.whl → 1.22.0.dev20250918__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flwr/cli/new/new.py CHANGED
@@ -41,6 +41,7 @@ class MlFramework(str, Enum):
41
41
  JAX = "JAX"
42
42
  MLX = "MLX"
43
43
  NUMPY = "NumPy"
44
+ XGBOOST = "XGBoost"
44
45
  FLOWERTUNE = "FlowerTune"
45
46
  BASELINE = "Flower Baseline"
46
47
  PYTORCH_LEGACY_API = "PyTorch (Legacy API, deprecated)"
@@ -247,6 +248,7 @@ def new(
247
248
  MlFramework.TENSORFLOW.value,
248
249
  MlFramework.SKLEARN.value,
249
250
  MlFramework.NUMPY.value,
251
+ MlFramework.XGBOOST.value,
250
252
  "pytorch_legacy_api",
251
253
  ]
252
254
  if framework_str in frameworks_with_tasks:
@@ -0,0 +1,110 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import warnings
4
+
5
+ import numpy as np
6
+ import xgboost as xgb
7
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
8
+ from flwr.clientapp import ClientApp
9
+ from flwr.common.config import unflatten_dict
10
+
11
+ from $import_name.task import load_data, replace_keys
12
+
13
# Suppress UserWarning messages. This filter has no `module=` restriction, so
# it applies process-wide and hides every UserWarning, not only those from
# this app. NOTE(review): presumably meant to quiet noisy library warnings;
# consider narrowing it with `module=...` if that is the intent.
warnings.filterwarnings("ignore", category=UserWarning)


# Flower ClientApp; the `@app.train()` / `@app.evaluate()` handlers register on it
app = ClientApp()
18
+
19
+
20
+ def _local_boost(bst_input, num_local_round, train_dmatrix):
21
+ # Update trees based on local training data.
22
+ for i in range(num_local_round):
23
+ bst_input.update(train_dmatrix, bst_input.num_boosted_rounds())
24
+
25
+ # Bagging: extract the last N=num_local_round trees for sever aggregation
26
+ bst = bst_input[
27
+ bst_input.num_boosted_rounds()
28
+ - num_local_round : bst_input.num_boosted_rounds()
29
+ ]
30
+ return bst
31
+
32
+
33
@app.train()
def train(msg: Message, context: Context) -> Message:
    """Train XGBoost on this node's partition and reply with the model.

    In server round 1 a fresh booster is trained for `local-epochs` boosting
    rounds. In later rounds the global model received in `msg` is loaded and
    boosted for `local-epochs` more rounds, and only the trees added locally
    are kept (see `_local_boost`). The booster is serialized with
    `save_raw("json")` and returned under "arrays"; the number of training
    examples is returned under "metrics".
    """
    node_cfg = context.node_config
    train_dmatrix, _, num_train, _ = load_data(
        node_cfg["partition-id"], node_cfg["num-partitions"]
    )

    # Boosting rounds to run locally per server round
    num_local_round = context.run_config["local-epochs"]
    # Unflatten the run config into nested dicts and replace "-" with "_" in keys
    params = replace_keys(unflatten_dict(context.run_config))["params"]

    if msg.content["config"]["server-round"] == 1:
        # No global model exists yet: train from scratch
        bst = xgb.train(params, train_dmatrix, num_boost_round=num_local_round)
    else:
        # Continue boosting from the received global model
        global_bst = xgb.Booster(params=params)
        global_bst.load_model(
            bytearray(msg.content["arrays"]["0"].numpy().tobytes())
        )
        bst = _local_boost(global_bst, num_local_round, train_dmatrix)

    # The model is stored as the first item of the ArrayRecord (index "0")
    model_bytes = np.frombuffer(bst.save_raw("json"), dtype=np.uint8)
    content = RecordDict(
        {
            "arrays": ArrayRecord([model_bytes]),
            "metrics": MetricRecord({"num-examples": num_train}),
        }
    )
    return Message(content=content, reply_to=msg)
78
+
79
+
80
@app.evaluate()
def evaluate(msg: Message, context: Context) -> Message:
    """Evaluate the received global XGBoost model on local validation data.

    The reply carries a MetricRecord with "auc" (the value of the first
    metric reported by `Booster.eval_set`, which is AUC under the template's
    default `eval-metric`) and "num-examples" (the validation set size).
    """
    # Only the validation split of this node's partition is needed here
    node_cfg = context.node_config
    _, valid_dmatrix, _, num_val = load_data(
        node_cfg["partition-id"], node_cfg["num-partitions"]
    )

    # Booster parameters from the unflattened run config ("-" -> "_" in keys)
    params = replace_keys(unflatten_dict(context.run_config))["params"]

    # Rebuild the global booster from its serialized bytes (stored under "0")
    bst = xgb.Booster(params=params)
    bst.load_model(bytearray(msg.content["arrays"]["0"].numpy().tobytes()))

    # The parsing below assumes a report like "[<iter>]\tvalid-<metric>:<value>"
    # and takes the value of the first metric listed.
    last_iteration = bst.num_boosted_rounds() - 1
    report = bst.eval_set(evals=[(valid_dmatrix, "valid")], iteration=last_iteration)
    first_metric = report.split("\t")[1]
    auc = float(first_metric.split(":")[1])

    reply_metrics = MetricRecord({"auc": auc, "num-examples": num_val})
    return Message(content=RecordDict({"metrics": reply_metrics}), reply_to=msg)
@@ -0,0 +1,56 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import numpy as np
4
+ import xgboost as xgb
5
+ from flwr.app import ArrayRecord, Context
6
+ from flwr.common.config import unflatten_dict
7
+ from flwr.serverapp import Grid, ServerApp
8
+ from flwr.serverapp.strategy import FedXgbBagging
9
+
10
+ from $import_name.task import replace_keys
11
+
12
# Create ServerApp; `main` registers on it via `@app.main()`, and the
# pyproject template points its `serverapp` component at this object
app = ServerApp()
14
+
15
+
16
@app.main()
def main(grid: Grid, context: Context) -> None:
    """Run federated XGBoost bagging and save the final model to disk.

    Runs `FedXgbBagging` for `num-server-rounds` rounds, starting from an
    empty model, then loads the resulting global model into a Booster and
    writes it to `final_model.json`.
    """
    run_cfg = context.run_config
    num_rounds = run_cfg["num-server-rounds"]
    fraction_train = run_cfg["fraction-train"]
    fraction_evaluate = run_cfg["fraction-evaluate"]
    # Unflatten the run config into nested dicts and replace "-" with "_" in keys
    params = replace_keys(unflatten_dict(run_cfg))["params"]

    # The initial global model is an empty byte array: the booster is created
    # and trained on the client side in the first round. It is stored as the
    # first item of the ArrayRecord, accessible with index "0".
    initial_arrays = ArrayRecord([np.frombuffer(b"", dtype=np.uint8)])

    strategy = FedXgbBagging(
        fraction_train=fraction_train,
        fraction_evaluate=fraction_evaluate,
    )
    result = strategy.start(
        grid=grid,
        initial_arrays=initial_arrays,
        num_rounds=num_rounds,
    )

    # Deserialize the final global model and persist it
    final_bst = xgb.Booster(params=params)
    final_bst.load_model(bytearray(result.arrays["0"].numpy().tobytes()))

    print("\nSaving final model to disk...")
    final_bst.save_model("final_model.json")
@@ -0,0 +1,67 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import xgboost as xgb
4
+ from flwr_datasets import FederatedDataset
5
+ from flwr_datasets.partitioner import IidPartitioner
6
+
7
+
8
def train_test_split(partition, test_fraction, seed):
    """Split `partition` into a training and a validation subset.

    Returns
    -------
    tuple
        `(train_subset, test_subset, num_train, num_test)`, where the counts
        are the lengths of the two subsets.
    """
    splits = partition.train_test_split(test_size=test_fraction, seed=seed)
    train_part, test_part = splits["train"], splits["test"]
    return train_part, test_part, len(train_part), len(test_part)
18
+
19
+
20
def transform_dataset_to_dmatrix(data):
    """Build an `xgb.DMatrix` from the "inputs" features and "label" targets of `data`."""
    return xgb.DMatrix(data["inputs"], label=data["label"])
26
+
27
+
28
fds = None  # Cached FederatedDataset, created on the first `load_data` call


def load_data(partition_id, num_clients):
    """Load one partition of the HIGGS dataset as xgboost DMatrices.

    The "train" split of `jxie/higgs` is partitioned IID into `num_clients`
    parts. Partition `partition_id` is then split 80/20 (seed 42) into
    training and validation data.

    Returns
    -------
    tuple
        `(train_dmatrix, valid_dmatrix, num_train, num_val)`.
    """
    global fds
    if fds is None:
        # Created once and reused afterwards, so the `num_clients` of the
        # first call determines the partitioning for all later calls.
        fds = FederatedDataset(
            dataset="jxie/higgs",
            partitioners={"train": IidPartitioner(num_partitions=num_clients)},
        )

    partition = fds.load_partition(partition_id, split="train")
    partition.set_format("numpy")

    train_data, valid_data, num_train, num_val = train_test_split(
        partition, test_fraction=0.2, seed=42
    )
    return (
        transform_dataset_to_dmatrix(train_data),
        transform_dataset_to_dmatrix(valid_data),
        num_train,
        num_val,
    )
56
+
57
+
58
def replace_keys(input_dict, match="-", target="_"):
    """Return a copy of `input_dict` with `match` replaced by `target` in all keys.

    Nested dictionaries are processed recursively. Non-dict values are reused
    as-is (not copied). The input dictionary is not modified.
    """
    return {
        key.replace(match, target): (
            replace_keys(value, match, target) if isinstance(value, dict) else value
        )
        for key, value in input_dict.items()
    }
@@ -0,0 +1,61 @@
1
+ # =====================================================================
2
+ # For a full TOML configuration guide, check the Flower docs:
3
+ # https://flower.ai/docs/framework/how-to-configure-pyproject-toml.html
4
+ # =====================================================================
5
+
6
+ [build-system]
7
+ requires = ["hatchling"]
8
+ build-backend = "hatchling.build"
9
+
10
+ [project]
11
+ name = "$package_name"
12
+ version = "1.0.0"
13
+ description = ""
14
+ license = "Apache-2.0"
15
+ # Dependencies for your Flower App
16
+ dependencies = [
17
+ "flwr[simulation]>=1.22.0",
18
+ "flwr-datasets>=0.5.0",
19
+ "xgboost>=2.0.0",
20
+ ]
21
+
22
+ [tool.hatch.build.targets.wheel]
23
+ packages = ["."]
24
+
25
+ [tool.flwr.app]
26
+ publisher = "$username"
27
+
28
+ [tool.flwr.app.components]
29
+ serverapp = "$import_name.server_app:app"
30
+ clientapp = "$import_name.client_app:app"
31
+
32
+ # Custom config values accessible via `context.run_config`
33
+ [tool.flwr.app.config]
34
+ num-server-rounds = 3
35
+ fraction-train = 0.1
36
+ fraction-evaluate = 0.1
37
+ local-epochs = 1
38
+
39
+ # XGBoost parameters
40
+ params.objective = "binary:logistic"
41
+ params.eta = 0.1 # Learning rate
42
+ params.max-depth = 8
43
+ params.eval-metric = "auc"
44
+ params.nthread = 16
45
+ params.num-parallel-tree = 1
46
+ params.subsample = 1
47
+ params.tree-method = "hist"
48
+
49
+ # Default federation to use when running the app
50
+ [tool.flwr.federations]
51
+ default = "local-simulation"
52
+
53
+ # Local simulation federation with 10 virtual SuperNodes
54
+ [tool.flwr.federations.local-simulation]
55
+ options.num-supernodes = 10
56
+
57
+ # Remote federation example for use with SuperLink
58
+ [tool.flwr.federations.remote-federation]
59
+ address = "<SUPERLINK-ADDRESS>:<PORT>"
60
+ insecure = true # Remove this line to enable TLS
61
+ # root-certificates = "<PATH/TO/ca.crt>" # For TLS setup
@@ -45,6 +45,7 @@ class ExitCode:
45
45
  SUPERNODE_NODE_AUTH_KEYS_INVALID = 302
46
46
 
47
47
  # SuperExec-specific exit codes (400-499)
48
+ SUPEREXEC_INVALID_PLUGIN_CONFIG = 400
48
49
 
49
50
  # Common exit codes (600-699)
50
51
  COMMON_ADDRESS_INVALID = 600
@@ -112,6 +113,9 @@ EXIT_CODE_HELP = {
112
113
  "file and try again."
113
114
  ),
114
115
  # SuperExec-specific exit codes (400-499)
116
+ ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG: (
117
+ "The YAML configuration for the SuperExec plugin is invalid."
118
+ ),
115
119
  # Common exit codes (600-699)
116
120
  ExitCode.COMMON_ADDRESS_INVALID: (
117
121
  "Please provide a valid URL, IPv4 or IPv6 address."
@@ -27,6 +27,7 @@ from .fedmedian import FedMedian
27
27
  from .fedprox import FedProx
28
28
  from .fedtrimmedavg import FedTrimmedAvg
29
29
  from .fedxgb_bagging import FedXgbBagging
30
+ from .fedxgb_cyclic import FedXgbCyclic
30
31
  from .fedyogi import FedYogi
31
32
  from .result import Result
32
33
  from .strategy import Strategy
@@ -42,6 +43,7 @@ __all__ = [
42
43
  "FedProx",
43
44
  "FedTrimmedAvg",
44
45
  "FedXgbBagging",
46
+ "FedXgbCyclic",
45
47
  "FedYogi",
46
48
  "Result",
47
49
  "Strategy",
@@ -0,0 +1,220 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower message-based FedXgbCyclic strategy."""
16
+
17
+
18
+ from collections.abc import Iterable
19
+ from logging import INFO
20
+ from typing import Callable, Optional, cast
21
+
22
+ from flwr.common import (
23
+ ArrayRecord,
24
+ ConfigRecord,
25
+ Message,
26
+ MessageType,
27
+ MetricRecord,
28
+ RecordDict,
29
+ log,
30
+ )
31
+ from flwr.server import Grid
32
+
33
+ from .fedavg import FedAvg
34
+ from .strategy_utils import sample_nodes
35
+
36
+
37
# pylint: disable=line-too-long
class FedXgbCyclic(FedAvg):
    """Configurable FedXgbCyclic strategy implementation.

    In each round exactly one node is selected for training (and one for
    evaluation), cycling through the sampled nodes in the order in which they
    were first seen. The model returned by the training node is adopted as the
    new global model as-is; ArrayRecords are not averaged.

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Must be 1.0 or 0.0. With 1.0 all available nodes take part in the
        training rotation; with 0.0 no training messages are sent.
    fraction_evaluate : float (default: 1.0)
        Must be 1.0 or 0.0. With 1.0 all available nodes take part in the
        evaluation rotation; with 0.0 no evaluation messages are sent.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        aggregating MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.

    Raises
    ------
    ValueError
        If `fraction_train` or `fraction_evaluate` is neither 0.0 nor 1.0.
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_available_nodes: int = 2,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
    ) -> None:
        # Validate first so a misconfigured strategy fails before any setup.
        if fraction_train not in (0.0, 1.0):
            raise ValueError(
                "fraction_train can only be set to 1.0 or 0.0 for FedXgbCyclic."
            )
        if fraction_evaluate not in (0.0, 1.0):
            raise ValueError(
                "fraction_evaluate can only be set to 1.0 or 0.0 for FedXgbCyclic."
            )

        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=2,
            min_evaluate_nodes=2,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
        )

        # Persistent mapping node_id -> rotation index, assigned on first sighting
        self.registered_nodes: dict[int, int] = {}

    def _reorder_nodes(self, node_ids: list[int]) -> list[int]:
        """Return `node_ids` sorted by each node's persistent registration index.

        Each node ID is assigned the next free index in `self.registered_nodes`
        the first time it appears, so the returned order does not depend on the
        order in which the IDs are passed in.
        """
        next_index = max(self.registered_nodes.values(), default=0) + 1
        for nid in node_ids:
            if nid not in self.registered_nodes:
                self.registered_nodes[nid] = next_index
                next_index += 1

        return sorted(node_ids, key=self.registered_nodes.__getitem__)

    def _make_sampling(
        self,
        grid: Grid,
        server_round: int,
        configure_type: str,
        fraction: Optional[float] = None,
        min_nodes: Optional[int] = None,
    ) -> list[int]:
        """Select the single node whose turn it is in `server_round`.

        Parameters
        ----------
        grid : Grid
            Grid used to list and sample the connected nodes.
        server_round : int
            Current server round (1-based). Node `(server_round - 1) % N` of
            the ordered sample of size N is chosen.
        configure_type : str
            Label used as prefix of the log message.
        fraction : Optional[float] (default: None)
            Fraction of connected nodes to sample. Defaults to
            `self.fraction_train`.
        min_nodes : Optional[int] (default: None)
            Minimum sample size. Defaults to `self.min_train_nodes`.

        Returns
        -------
        list[int]
            A list holding one node ID, or an empty list if no node was sampled.
        """
        if fraction is None:
            fraction = self.fraction_train
        if min_nodes is None:
            min_nodes = self.min_train_nodes

        num_nodes = int(len(list(grid.get_node_ids())) * fraction)
        sample_size = max(num_nodes, min_nodes)
        node_ids, _ = sample_nodes(grid, self.min_available_nodes, sample_size)
        if not node_ids:
            # Nothing to rotate over; also avoids a modulo by zero below.
            return []

        # Stable order across rounds so the rotation visits nodes in turn.
        # NOTE: if the sampled set changes size between rounds, the position
        # in the rotation shifts accordingly.
        node_ids = self._reorder_nodes(node_ids)
        sampled_node_id = [node_ids[(server_round - 1) % len(node_ids)]]

        log(
            INFO,
            "%s: Sampled %s nodes (out of %s)",
            configure_type,
            len(sampled_node_id),
            len(node_ids),
        )
        return sampled_node_id

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated training.

        Sends `arrays` and `config` (with "server-round" injected) to the one
        node whose turn it is. Returns no messages if `fraction_train` is 0.0.
        """
        # Do not configure federated train if fraction_train is 0.
        if self.fraction_train == 0.0:
            return []

        sampled_node_id = self._make_sampling(
            grid,
            server_round,
            "configure_train",
            fraction=self.fraction_train,
            min_nodes=self.min_train_nodes,
        )

        # Always inject current server round
        config["server-round"] = server_round

        record = RecordDict(
            {self.arrayrecord_key: arrays, self.configrecord_key: config}
        )
        return self._construct_messages(record, sampled_node_id, MessageType.TRAIN)

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Adopt the trained model from the reply and aggregate its metrics.

        Only one node trains per round, so no averaging of ArrayRecords is
        done: the first ArrayRecord of the first valid reply becomes the new
        global model. Returns `(None, None)` if there is no valid reply.
        """
        valid_replies, _ = self._check_and_log_replies(replies, is_train=True)

        arrays, metrics = None, None
        if valid_replies:
            reply_contents = [msg.content for msg in valid_replies]
            array_record_key = next(iter(reply_contents[0].array_records.keys()))

            # Fetch the client model from current round as global model
            arrays = cast(ArrayRecord, reply_contents[0][array_record_key])

            # Aggregate MetricRecords
            metrics = self.train_metrics_aggr_fn(
                reply_contents,
                self.weighted_by_key,
            )
        return arrays, metrics

    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated evaluation.

        Sends `arrays` and `config` (with "server-round" injected) to the one
        node whose turn it is. Returns no messages if `fraction_evaluate` is
        0.0.
        """
        # Do not configure federated evaluation if fraction_evaluate is 0.
        if self.fraction_evaluate == 0.0:
            return []

        sampled_node_id = self._make_sampling(
            grid,
            server_round,
            "configure_evaluate",
            fraction=self.fraction_evaluate,
            min_nodes=self.min_evaluate_nodes,
        )

        # Always inject current server round
        config["server-round"] = server_round

        record = RecordDict(
            {self.arrayrecord_key: arrays, self.configrecord_key: config}
        )
        return self._construct_messages(record, sampled_node_id, MessageType.EVALUATE)
@@ -17,7 +17,9 @@
17
17
 
18
18
  import argparse
19
19
  from logging import INFO
20
- from typing import Optional
20
+ from typing import Any, Optional
21
+
22
+ import yaml
21
23
 
22
24
  from flwr.common import EventType, event
23
25
  from flwr.common.constant import ExecPluginType
@@ -26,6 +28,7 @@ from flwr.common.logger import log
26
28
  from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub
27
29
  from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
28
30
  from flwr.proto.simulationio_pb2_grpc import SimulationIoStub
31
+ from flwr.supercore.constant import EXEC_PLUGIN_SECTION
29
32
  from flwr.supercore.grpc_health import add_args_health
30
33
  from flwr.supercore.superexec.plugin import (
31
34
  ClientAppExecPlugin,
@@ -36,6 +39,7 @@ from flwr.supercore.superexec.plugin import (
36
39
  from flwr.supercore.superexec.run_superexec import run_superexec
37
40
 
38
41
  try:
42
+ from flwr.ee import add_ee_args_superexec
39
43
  from flwr.ee.constant import ExecEePluginType
40
44
  from flwr.ee.exec_plugin import get_ee_plugin_and_stub_class
41
45
  except ImportError:
@@ -54,6 +58,10 @@ except ImportError:
54
58
  """Get the EE plugin class and stub class based on the plugin type."""
55
59
  return None
56
60
 
61
    # pylint: disable-next=unused-argument
    def add_ee_args_superexec(parser: argparse.ArgumentParser) -> None:
        """Add EE-specific arguments to the parser.

        Fallback for when `flwr.ee.add_ee_args_superexec` cannot be imported:
        it adds no arguments and leaves `parser` unchanged.
        """
+
57
65
 
58
66
  def flower_superexec() -> None:
59
67
  """Run `flower-superexec` command."""
@@ -70,12 +78,28 @@ def flower_superexec() -> None:
70
78
  # Trigger telemetry event
71
79
  event(EventType.RUN_SUPEREXEC_ENTER, {"plugin_type": args.plugin_type})
72
80
 
81
+ # Load plugin config from YAML file if provided
82
+ plugin_config = None
83
+ if plugin_config_path := getattr(args, "plugin_config", None):
84
+ try:
85
+ with open(plugin_config_path, encoding="utf-8") as file:
86
+ yaml_config: Optional[dict[str, Any]] = yaml.safe_load(file)
87
+ if yaml_config is None or EXEC_PLUGIN_SECTION not in yaml_config:
88
+ raise ValueError(f"Missing '{EXEC_PLUGIN_SECTION}' section.")
89
+ plugin_config = yaml_config[EXEC_PLUGIN_SECTION]
90
+ except (FileNotFoundError, yaml.YAMLError, ValueError) as e:
91
+ flwr_exit(
92
+ ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG,
93
+ f"Failed to load plugin config from '{plugin_config_path}': {e!r}",
94
+ )
95
+
73
96
  # Get the plugin class and stub class based on the plugin type
74
97
  plugin_class, stub_class = _get_plugin_and_stub_class(args.plugin_type)
75
98
  run_superexec(
76
99
  plugin_class=plugin_class,
77
100
  stub_class=stub_class, # type: ignore
78
101
  appio_api_address=args.appio_api_address,
102
+ plugin_config=plugin_config,
79
103
  flwr_dir=args.flwr_dir,
80
104
  parent_pid=args.parent_pid,
81
105
  health_server_address=args.health_server_address,
@@ -122,6 +146,7 @@ def _parse_args() -> argparse.ArgumentParser:
122
146
  help="The PID of the parent process. When set, the process will terminate "
123
147
  "when the parent process exits.",
124
148
  )
149
+ add_ee_args_superexec(parser)
125
150
  add_args_health(parser)
126
151
  return parser
127
152
 
@@ -0,0 +1,19 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Constants for Flower infrastructure."""
16
+
17
+
18
# Top-level key of the SuperExec YAML config file; the value stored under it
# is handed to the exec plugin as its configuration
EXEC_PLUGIN_SECTION: str = "exec_plugin"
@@ -17,7 +17,7 @@
17
17
 
18
18
  from abc import ABC, abstractmethod
19
19
  from collections.abc import Sequence
20
- from typing import Callable, Optional
20
+ from typing import Any, Callable, Optional
21
21
 
22
22
  from flwr.common.typing import Run
23
23
 
@@ -69,3 +69,13 @@ class ExecPlugin(ABC):
69
69
  The ID of the run associated with the token, used for tracking or
70
70
  logging purposes.
71
71
  """
72
+
73
+ # This method is optional to implement
74
+ def load_config(self, yaml_config: dict[str, Any]) -> None:
75
+ """Load configuration from a YAML dictionary.
76
+
77
+ Parameters
78
+ ----------
79
+ yaml_config : dict[str, Any]
80
+ A dictionary representing the YAML configuration.
81
+ """
@@ -17,10 +17,10 @@
17
17
 
18
18
  import time
19
19
  from logging import WARN
20
- from typing import Optional, Union
20
+ from typing import Any, Optional, Union
21
21
 
22
22
  from flwr.common.config import get_flwr_dir
23
- from flwr.common.exit import register_signal_handlers
23
+ from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
24
24
  from flwr.common.grpc import create_channel, on_channel_state_change
25
25
  from flwr.common.logger import log
26
26
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
@@ -47,6 +47,7 @@ def run_superexec( # pylint: disable=R0913,R0914,R0917
47
47
  type[ClientAppIoStub], type[ServerAppIoStub], type[SimulationIoStub]
48
48
  ],
49
49
  appio_api_address: str,
50
+ plugin_config: Optional[dict[str, Any]] = None,
50
51
  flwr_dir: Optional[str] = None,
51
52
  parent_pid: Optional[int] = None,
52
53
  health_server_address: Optional[str] = None,
@@ -61,6 +62,9 @@ def run_superexec( # pylint: disable=R0913,R0914,R0917
61
62
  The gRPC stub class for the AppIO API.
62
63
  appio_api_address : str
63
64
  The address of the AppIO API.
65
+ plugin_config : Optional[dict[str, Any]] (default: None)
66
+ The configuration dictionary for the plugin. If `None`, the plugin will use
67
+ its default configuration.
64
68
  flwr_dir : Optional[str] (default: None)
65
69
  The Flower directory.
66
70
  parent_pid : Optional[int] (default: None)
@@ -113,6 +117,16 @@ def run_superexec( # pylint: disable=R0913,R0914,R0917
113
117
  get_run=get_run,
114
118
  )
115
119
 
120
+ # Load plugin configuration from file if provided
121
+ try:
122
+ if plugin_config is not None:
123
+ plugin.load_config(plugin_config)
124
+ except (KeyError, ValueError) as e:
125
+ flwr_exit(
126
+ code=ExitCode.SUPEREXEC_INVALID_PLUGIN_CONFIG,
127
+ message=f"Invalid plugin config: {e!r}",
128
+ )
129
+
116
130
  # Start the main loop
117
131
  try:
118
132
  while True:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: flwr-nightly
3
- Version: 1.22.0.dev20250917
3
+ Version: 1.22.0.dev20250918
4
4
  Summary: Flower: A Friendly Federated AI Framework
5
5
  License: Apache-2.0
6
6
  Keywords: Artificial Intelligence,Federated AI,Federated Analytics,Federated Evaluation,Federated Learning,Flower,Machine Learning
@@ -18,7 +18,7 @@ flwr/cli/login/__init__.py,sha256=B1SXKU3HCQhWfFDMJhlC7FOl8UsvH4mxysxeBnrfyUE,80
18
18
  flwr/cli/login/login.py,sha256=RM1Jiv_VFm3oz4rTHSr3D87X90lW3WzErjBBU7WviWY,4309
19
19
  flwr/cli/ls.py,sha256=3YK7cpoImJ7PbjlP_JgYRQWz1GymX2q7Reu-mKJEpao,10957
20
20
  flwr/cli/new/__init__.py,sha256=QA1E2QtzPvFCjLTUHnFnJbufuFiGyT_0Y53Wpbvg1F0,790
21
- flwr/cli/new/new.py,sha256=GvXapfYMrUQMktn1qJ3bfxfXeK0IAsxsxHBIGfPg3sE,10535
21
+ flwr/cli/new/new.py,sha256=nIuUrQSGDjI4kqnymlq-rOT0RU3AHwZrat3abqHhCwM,10598
22
22
  flwr/cli/new/templates/__init__.py,sha256=FpjWCfIySU2DB4kh0HOXLAjlZNNFDTVU4w3HoE2TzcI,725
23
23
  flwr/cli/new/templates/app/.gitignore.tpl,sha256=HZJcGQoxp7aUzaPg8Uqch3kNrIESwr9yjimDxJYgXVY,3104
24
24
  flwr/cli/new/templates/app/LICENSE.tpl,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
@@ -39,6 +39,7 @@ flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=fYoh-dTu07LkqNYvwcx
39
39
  flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl,sha256=fuxVmZpjHIueNy_aHWF81531vmi8DGu4CYjYDqmUwWo,1705
40
40
  flwr/cli/new/templates/app/code/client.sklearn.py.tpl,sha256=0qqEe-RRjkHGOH8gsD9e83ae-kyyYixhyBgzVHjYpzk,3500
41
41
  flwr/cli/new/templates/app/code/client.tensorflow.py.tpl,sha256=8o55KXpsbF_rv6o98ZNYJDCazjwMp_RPTaSzDfT7Qlw,2682
42
+ flwr/cli/new/templates/app/code/client.xgboost.py.tpl,sha256=-ipRV8gqpbEg7Mht77Yyqs1viL-3JYSVZR47I7xeG4c,3493
42
43
  flwr/cli/new/templates/app/code/dataset.baseline.py.tpl,sha256=jbd_exHAk2-Blu_kVutjPO6a_dkJQWb232zxSeXIZ1k,1453
43
44
  flwr/cli/new/templates/app/code/flwr_tune/__init__.py,sha256=Xq5fEn5yZkw6HAJi10T_3HRBoqN5_5pNqJHY4wXvD5k,748
44
45
  flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl,sha256=p6BzTbP-mXkFANiVC7iz3YlskOidWaLC341IJyrUotQ,2951
@@ -56,6 +57,7 @@ flwr/cli/new/templates/app/code/server.pytorch.py.tpl,sha256=epARqfcQ-EQsdZwaaaU
56
57
  flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl,sha256=gvBsGA_Jg9kAH8xTxjzTjMcvBtciuccOwQFbO7ey8tU,916
57
58
  flwr/cli/new/templates/app/code/server.sklearn.py.tpl,sha256=ehQ5VRgBn92WeFl6kupwJnuxSNkKvE-EvKde6A9mNQo,1377
58
59
  flwr/cli/new/templates/app/code/server.tensorflow.py.tpl,sha256=2-WTOPd-ewdLd9QmSlflIH7ix7zxAzPEOZoyiPBOy8c,1010
60
+ flwr/cli/new/templates/app/code/server.xgboost.py.tpl,sha256=fwtCRyCG2hDSH1zVMyZv7zA7wsdKNPfpugDSZjxCs5Q,1746
59
61
  flwr/cli/new/templates/app/code/strategy.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
60
62
  flwr/cli/new/templates/app/code/task.huggingface.py.tpl,sha256=piBbY3Dg60bQnCg15uzMw0QiL5SDOYX4YhQouy-X2zI,3164
61
63
  flwr/cli/new/templates/app/code/task.jax.py.tpl,sha256=Fb0XgdTAQplM-ZCusI081XA9asO3gHptH772S-Xcyy8,1525
@@ -65,6 +67,7 @@ flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=RKA5lV6O6OnVKZ2r75pbz
65
67
  flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl,sha256=XlJqA4Ix_PloO_zJLhjiN5vDj16w3I4CPVGdmbe8asE,3800
66
68
  flwr/cli/new/templates/app/code/task.sklearn.py.tpl,sha256=vHdhtMp0FHxbYafXyhDT9aKmmmA0Jvpx5Oum1Yu9lWY,1850
67
69
  flwr/cli/new/templates/app/code/task.tensorflow.py.tpl,sha256=impgWN7MfztmcWF4xh1llcZGsgTvrb1HD5ZE0t-8U08,1731
70
+ flwr/cli/new/templates/app/code/task.xgboost.py.tpl,sha256=0xO8jQvrHuB1llVDopQPOmt5Hn6rBw8umzoNwiZZs-o,2135
68
71
  flwr/cli/new/templates/app/code/utils.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
69
72
  flwr/cli/new/templates/app/pyproject.baseline.toml.tpl,sha256=mPIMPfneVYV03l8jWRgWZ0V5Kh_pJw-AMUvkhcKkmL8,3182
70
73
  flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl,sha256=wqYW4bWcf12m0U2njR995lySSesFvnHB-eSkPWz-QdM,2501
@@ -76,6 +79,7 @@ flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=SE4H23OFkQbqNU64nYf
76
79
  flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl,sha256=docQbs3MuRR-yT24lVz7N2sQL3Sj49EHuOCuRj_0djQ,1508
77
80
  flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=apauU_PUmLEbt2rjckKniEbzdRs1EnMri_qgtHtBJZ8,1484
78
81
  flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl,sha256=LQpDKJTEnRKj5Ygn5FkT44SxlnLVprkPlbrGaFf5Q50,1508
82
+ flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl,sha256=504pHibNRGFe-DLnzqHLYhKeF_n8BPMv0Xog5EfnZ0M,1661
79
83
  flwr/cli/pull.py,sha256=dHiMe6x8w8yRoFNKpjA-eiPD6eFiHz4Vah5HZrqNpuo,3364
80
84
  flwr/cli/run/__init__.py,sha256=RPyB7KbYTFl6YRiilCch6oezxrLQrl1kijV7BMGkLbA,790
81
85
  flwr/cli/run/run.py,sha256=ECa0kup9dn15O70H74QdgUsEaeErbzDqVX_U0zZO5IM,8173
@@ -127,7 +131,7 @@ flwr/common/event_log_plugin/__init__.py,sha256=ts3VAL3Fk6Grp1EK_1Qg_V-BfOof9F86
127
131
  flwr/common/event_log_plugin/event_log_plugin.py,sha256=4SkVa1Ic-sPlICJShBuggXmXDcQtWQ1KDby4kthFNF0,2064
128
132
  flwr/common/exit/__init__.py,sha256=8W7xaO1iw0vacgmQW7FTFbSh7csNv6XfsgIlnIbNF6U,978
129
133
  flwr/common/exit/exit.py,sha256=DcXJfbpW1g-pQJqSZmps-1MZydd7T7RaarghIf2e4tU,3636
130
- flwr/common/exit/exit_code.py,sha256=e8O71zIqVT1H84mNBeenTz7S39yPZSpZQm-xUenpzN4,5249
134
+ flwr/common/exit/exit_code.py,sha256=Xa1NFGny2cefZ62kZZOfT8eii__PolMWCHxYmxoSQ2s,5416
131
135
  flwr/common/exit/exit_handler.py,sha256=uzDdWwhKgc1w5csZS52b86kjmEApmDZKwMn_X0zDZZo,2126
132
136
  flwr/common/exit/signal_handler.py,sha256=wqxykrwgmpFzmEMhpnlM7RtO0PnqIvYiSB1qYahZ5Sk,3710
133
137
  flwr/common/grpc.py,sha256=nHnFC7E84pZVTvd6BhcSYWnGd0jf8t5UmGea04qvilM,9806
@@ -333,7 +337,7 @@ flwr/server/workflow/secure_aggregation/secaggplus_workflow.py,sha256=DkayCsnlAy
333
337
  flwr/serverapp/__init__.py,sha256=ZujKNXULwhWYQhFnxOOT5Wi9MRq2JCWFhAAj7ouiQ78,884
334
338
  flwr/serverapp/dp_fixed_clipping.py,sha256=wbP4W7CaUHXdll8ZSVUnTBSEWrnWM00CGk63rOR-Q2s,12133
335
339
  flwr/serverapp/exception.py,sha256=5cuH-2AafvihzosWDdDjuMmHdDqZ1XxHvCqZXNBVklw,1334
336
- flwr/serverapp/strategy/__init__.py,sha256=MJNWeBWzpQDfqhhpND5LncxPVK91kUo_Yzu6IMFdLCc,1492
340
+ flwr/serverapp/strategy/__init__.py,sha256=mt2l31EAQ9oSvBcQhk4Jj4SvTePmWzBHQxZqL1v0uhE,1552
337
341
  flwr/serverapp/strategy/dp_fixed_clipping.py,sha256=wbP4W7CaUHXdll8ZSVUnTBSEWrnWM00CGk63rOR-Q2s,12133
338
342
  flwr/serverapp/strategy/fedadagrad.py,sha256=fD65P6OEERa_pxq847e1UZpA083AcWR44XavYB0naGM,6343
339
343
  flwr/serverapp/strategy/fedadam.py,sha256=s3xPIqhopy6yPTeFxevSPnc7a6BcKnKsvo2AaO6Z_xs,7138
@@ -344,11 +348,11 @@ flwr/serverapp/strategy/fedopt.py,sha256=kqT0uV2IUE93O72XEVa1JJo61dcwbZEoT9KmYTj
344
348
  flwr/serverapp/strategy/fedprox.py,sha256=XkkLTk3XpXAj0QoAzHqAvcAlPjrNlX11ISAu5u2x6X8,7026
345
349
  flwr/serverapp/strategy/fedtrimmedavg.py,sha256=4-QxgAQGo_7vB_L7qDYy28d95OBt9MeDa92yaTRMHqk,7166
346
350
  flwr/serverapp/strategy/fedxgb_bagging.py,sha256=ktDjzov4y0BRecioq788umCEtcuwElou9olBizQKOnM,3282
351
+ flwr/serverapp/strategy/fedxgb_cyclic.py,sha256=8H8WoLdG4Fy1_dtLLE4AYiidC-Cvaw2GxySfzAb7Xj0,8774
347
352
  flwr/serverapp/strategy/fedyogi.py,sha256=1Ripr4Hi2cdeTOLiFOXtMKvOxR3BsUQwc7bbTrXN4LM,6653
348
353
  flwr/serverapp/strategy/result.py,sha256=E0Hl2VLnZAgQJjE2GDoKsK7JX-kPPU2KXc47Axt6hGw,4295
349
354
  flwr/serverapp/strategy/strategy.py,sha256=8uJGGm1ROLZERQ_dkRS7Z_rs-yK6XCE0UxXtIdFiEWk,10789
350
355
  flwr/serverapp/strategy/strategy_utils.py,sha256=hiwS7k-Hx6_c4NZXoKpHucS5CBKb7f8GppXRBSMt3Us,10851
351
- flwr/serverapp/strategy/strategy_utils_tests.py,sha256=_adS23Lrv1QA6V_3oZ7P_csMd8RqDObFeIhOkFnNtTg,10690
352
356
  flwr/simulation/__init__.py,sha256=Gg6OsP1Z-ixc3-xxzvl7j7rz2Fijy9rzyEPpxgAQCeM,1556
353
357
  flwr/simulation/app.py,sha256=b_bDyZFwBf2zpKs37Vmd5cFJSzDRE0fL-8uqA0UkAv4,10393
354
358
  flwr/simulation/legacy_app.py,sha256=nMISQqW0otJL1-2Kfd94O6BLlGS2IEmEPKTM2WGKrIs,15861
@@ -361,7 +365,8 @@ flwr/simulation/simulationio_connection.py,sha256=mzS1C6EEREwQDPceDo30anAasmTDLb
361
365
  flwr/supercore/__init__.py,sha256=pqkFoow_E6UhbBlhmoD1gmTH-33yJRhBsIZqxRPFZ7U,755
362
366
  flwr/supercore/app_utils.py,sha256=K76Zt6R670b1hUmxOsNc1WUCVYvF7lejXPcCO9K0Q0g,1753
363
367
  flwr/supercore/cli/__init__.py,sha256=EDl2aO-fuQfxSbL-T1W9RAfA2N0hpWHmqX_GSwblJbQ,845
364
- flwr/supercore/cli/flower_superexec.py,sha256=kov4uEeihf7QEUAfHEgdEvsL_8nL_fzQI9EePnRM1Ww,5012
368
+ flwr/supercore/cli/flower_superexec.py,sha256=JtqYrEWVu3BxLkjavsdohTOwvMwzuFqWP5j4Mo9dqsk,6155
369
+ flwr/supercore/constant.py,sha256=F9kRjisedaZcoyGvUITSDmIG12QDSCpo2LlM_l-q6jM,820
365
370
  flwr/supercore/corestate/__init__.py,sha256=Vau6-L_JG5QzNqtCTa9xCKGGljc09wY8avZmIjSJemg,774
366
371
  flwr/supercore/corestate/corestate.py,sha256=rDAWWeG5DcpCyQso9Z3RhwL4zr2IroPlRMcDzqoSu8s,2328
367
372
  flwr/supercore/ffs/__init__.py,sha256=U3KXwG_SplEvchat27K0LYPoPHzh-cwwT_NHsGlYMt8,908
@@ -382,10 +387,10 @@ flwr/supercore/superexec/__init__.py,sha256=XKX208hZ6a9gZ4KT9kMqfpCtp_8VGxekzKFf
382
387
  flwr/supercore/superexec/plugin/__init__.py,sha256=GNwq8uNdE8RB7ywEFRAvKjLFzgS3YXgz39-HBGsemWw,1035
383
388
  flwr/supercore/superexec/plugin/base_exec_plugin.py,sha256=fL-Ufc9Dp56OhWOzNSJUc7HumbkuSDYqZKwde2opG4g,2074
384
389
  flwr/supercore/superexec/plugin/clientapp_exec_plugin.py,sha256=9FT6ufEqV5K9g4FaAB9lVDbIv-VCH5LcxT4YKy23roE,1035
385
- flwr/supercore/superexec/plugin/exec_plugin.py,sha256=w3jmtxdv7ov_EdAgifKcm4q8nV39e2Xna4sNjqClwOM,2447
390
+ flwr/supercore/superexec/plugin/exec_plugin.py,sha256=4WtCQ4bsuFRlfCbg91ZcPAsX8htrCCo_fFh1DKo3cCQ,2764
386
391
  flwr/supercore/superexec/plugin/serverapp_exec_plugin.py,sha256=IwRzdPV-cSKwrP2krGh0De4IkAuxsmgK0WU6J-2GXqM,1035
387
392
  flwr/supercore/superexec/plugin/simulation_exec_plugin.py,sha256=upn5zE-YKkl_jTw8RzmeyQ58PU_UAlQ7CqnBXXdng8I,1060
388
- flwr/supercore/superexec/run_superexec.py,sha256=8hUlaVPVNnhePQ9OUgen4yy0fSGZAVggBGzm-33iJPw,6630
393
+ flwr/supercore/superexec/run_superexec.py,sha256=JiwKq9s_WPpk0S9MSi1lIgMZU120NOZLf4GlObHzI_k,7217
389
394
  flwr/supercore/utils.py,sha256=ebuHMbeA8eXisX0oMPqBK3hk7uVnIE_yiqWVz8YbkpQ,1324
390
395
  flwr/superlink/__init__.py,sha256=GNSuJ4-N6Z8wun2iZNlXqENt5beUyzC0Gi_tN396bbM,707
391
396
  flwr/superlink/artifact_provider/__init__.py,sha256=pgZEcVPKRE874LSu3cgy0HbwSJBIpVy_HxQOmne4PAs,810
@@ -411,7 +416,7 @@ flwr/supernode/servicer/__init__.py,sha256=lucTzre5WPK7G1YLCfaqg3rbFWdNSb7ZTt-ca
411
416
  flwr/supernode/servicer/clientappio/__init__.py,sha256=7Oy62Y_oijqF7Dxi6tpcUQyOpLc_QpIRZ83NvwmB0Yg,813
412
417
  flwr/supernode/servicer/clientappio/clientappio_servicer.py,sha256=nIHRu38EWK-rpNOkcgBRAAKwYQQWFeCwu0lkO7OPZGQ,10239
413
418
  flwr/supernode/start_client_internal.py,sha256=Y9S1-QlO2WP6eo4JvWzIpfaCoh2aoE7bjEYyxNNnlyg,20777
414
- flwr_nightly-1.22.0.dev20250917.dist-info/METADATA,sha256=NGqW090jvJZ7dg66AQQr9jKb4SF7DdF78xwqFuhtFUE,14559
415
- flwr_nightly-1.22.0.dev20250917.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
416
- flwr_nightly-1.22.0.dev20250917.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
417
- flwr_nightly-1.22.0.dev20250917.dist-info/RECORD,,
419
+ flwr_nightly-1.22.0.dev20250918.dist-info/METADATA,sha256=XdvpDYzZ_dWvMX49Pwi2gOUo0qLtrt466UABjDeyavg,14559
420
+ flwr_nightly-1.22.0.dev20250918.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
421
+ flwr_nightly-1.22.0.dev20250918.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
422
+ flwr_nightly-1.22.0.dev20250918.dist-info/RECORD,,
@@ -1,323 +0,0 @@
1
- # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Tests for message-based strategy utilities."""
16
-
17
-
18
- from collections import OrderedDict
19
- from unittest.mock import Mock
20
-
21
- import numpy as np
22
- import pytest
23
- from parameterized import parameterized
24
-
25
- from flwr.common import (
26
- Array,
27
- ArrayRecord,
28
- ConfigRecord,
29
- Message,
30
- MetricRecord,
31
- RecordDict,
32
- )
33
- from flwr.serverapp.exception import InconsistentMessageReplies
34
-
35
- from .strategy_utils import (
36
- aggregate_arrayrecords,
37
- aggregate_metricrecords,
38
- config_to_str,
39
- validate_message_reply_consistency,
40
- )
41
-
42
-
43
- def create_mock_reply(arrays: ArrayRecord, num_examples: float) -> Message:
44
- """Create a mock reply Message with default keys."""
45
- message = Mock(spec=Message)
46
- message.content = RecordDict(
47
- {"arrays": arrays, "metrics": MetricRecord({"num-examples": num_examples})}
48
- )
49
- message.has_error.side_effect = lambda: False
50
- message.has_content.side_effect = lambda: True
51
- return message
52
-
53
-
54
- def test_config_to_str() -> None:
55
- """Test that items of types bytes are masked out."""
56
- config = ConfigRecord({"a": 123, "b": [1, 2, 3], "c": b"bytes"})
57
- expected_str = "{'a': 123, 'b': [1, 2, 3], 'c': <bytes>}"
58
- assert config_to_str(config) == expected_str
59
-
60
-
61
- def test_arrayrecords_aggregation() -> None:
62
- """Test aggregation of ArrayRecords."""
63
- num_replies = 3
64
- num_arrays = 4
65
- weights = [0.25, 0.4, 0.35]
66
- np_arrays = [
67
- [np.random.randn(7, 3) for _ in range(num_arrays)] for _ in range(num_replies)
68
- ]
69
-
70
- avg_list = [
71
- np.average([lst[i] for lst in np_arrays], axis=0, weights=weights)
72
- for i in range(num_arrays)
73
- ]
74
-
75
- # Construct RecordDicts (mimicing replies)
76
- records = [
77
- RecordDict(
78
- {
79
- "arrays": ArrayRecord(np_arrays[i]),
80
- "metrics": MetricRecord({"weight": weights[i]}),
81
- }
82
- )
83
- for i in range(num_replies)
84
- ]
85
- # Execute aggregate
86
- aggrd = aggregate_arrayrecords(records, weighting_metric_name="weight")
87
-
88
- # Assert consistency
89
- assert all(np.allclose(a, b) for a, b in zip(aggrd.to_numpy_ndarrays(), avg_list))
90
- assert aggrd.object_id == ArrayRecord(avg_list).object_id
91
-
92
-
93
- def test_arrayrecords_aggregation_with_ndim_zero() -> None:
94
- """Test aggregation of ArrayRecords with 0-dim arrays."""
95
- num_replies = 3
96
- weights = [0.25, 0.4, 0.35]
97
- np_arrays = [np.array(np.random.randn()) for _ in range(num_replies)]
98
-
99
- # For 0-dimensional arrays, we just compute the weighted average directly
100
- avg_list = [np.average(np_arrays, axis=0, weights=weights)]
101
-
102
- # Construct RecordDicts (mimicing replies)
103
- records = [
104
- RecordDict(
105
- {
106
- "arrays": ArrayRecord([np_arrays[i]]),
107
- "metrics": MetricRecord({"weight": weights[i]}),
108
- }
109
- )
110
- for i in range(num_replies)
111
- ]
112
- # Execute aggregate
113
- aggrd = aggregate_arrayrecords(records, weighting_metric_name="weight")
114
-
115
- # Assert consistency
116
- assert np.isclose(aggrd.to_numpy_ndarrays()[0], avg_list[0])
117
- assert aggrd.object_id == ArrayRecord([np.array(avg_list[0])]).object_id
118
-
119
-
120
- def test_metricrecords_aggregation() -> None:
121
- """Test aggregation of MetricRecords."""
122
- num_replies = 3
123
- weights = [0.25, 0.4, 0.35]
124
- metric_records = [
125
- MetricRecord({"a": 1, "b": 2.0, "c": np.random.randn(3).tolist()})
126
- for _ in range(num_replies)
127
- ]
128
-
129
- # Compute expected aggregated MetricRecord.
130
- # For ease, we convert everything into numpy arrays, then aggregate
131
- as_np_entries = [
132
- {
133
- k: np.array(v) if isinstance(v, (int, float, list)) else v
134
- for k, v in record.items()
135
- }
136
- for record in metric_records
137
- ]
138
- avg_list = [
139
- np.average(
140
- [list(entries.values())[i] for entries in as_np_entries],
141
- axis=0,
142
- weights=weights,
143
- ).tolist()
144
- for i in range(len(as_np_entries[0]))
145
- ]
146
- expected_record = MetricRecord(dict(zip(as_np_entries[0].keys(), avg_list)))
147
- expected_record["a"] = float(expected_record["a"]) # type: ignore
148
- expected_record["b"] = float(expected_record["b"]) # type: ignore
149
-
150
- # Construct RecordDicts (mimicing replies)
151
- # Inject weighting factor
152
- records = [
153
- RecordDict(
154
- {
155
- "metrics": MetricRecord(
156
- record.__dict__["_data"] | {"weight": weights[i]}
157
- ),
158
- }
159
- )
160
- for i, record in enumerate(metric_records)
161
- ]
162
-
163
- # Execute aggregate
164
- aggrd = aggregate_metricrecords(records, weighting_metric_name="weight")
165
- # Assert
166
- assert expected_record.object_id == aggrd.object_id
167
-
168
-
169
- @parameterized.expand( # type: ignore
170
- [
171
- (
172
- True,
173
- RecordDict(
174
- {
175
- "global-model": ArrayRecord([np.random.randn(7, 3)]),
176
- "metrics": MetricRecord({"weight": 0.123}),
177
- }
178
- ),
179
- ), # Compliant
180
- (
181
- False,
182
- RecordDict(
183
- {
184
- "global-model": ArrayRecord([np.random.randn(7, 3)]),
185
- "metrics": MetricRecord({"weight": [0.123]}),
186
- }
187
- ),
188
- ), # Weighting key is not a scalar (BAD)
189
- (
190
- False,
191
- RecordDict(
192
- {
193
- "global-model": ArrayRecord([np.random.randn(7, 3)]),
194
- "metrics": MetricRecord({"loss": 0.01}),
195
- }
196
- ),
197
- ), # No weighting key in MetricRecord (BAD)
198
- (
199
- False,
200
- RecordDict({"global-model": ArrayRecord([np.random.randn(7, 3)])}),
201
- ), # No MetricsRecord (BAD)
202
- (
203
- False,
204
- RecordDict(
205
- {
206
- "global-model": ArrayRecord([np.random.randn(7, 3)]),
207
- "another-model": ArrayRecord([np.random.randn(7, 3)]),
208
- }
209
- ),
210
- ), # Two ArrayRecords (BAD)
211
- (
212
- False,
213
- RecordDict(
214
- {
215
- "global-model": ArrayRecord([np.random.randn(7, 3)]),
216
- "metrics": MetricRecord({"weight": 0.123}),
217
- "more-metrics": MetricRecord({"loss": 0.321}),
218
- }
219
- ),
220
- ), # Two MetricRecords (BAD)
221
- ]
222
- )
223
- def test_consistency_of_replies_with_matching_keys(
224
- is_valid: bool, recorddict: RecordDict
225
- ) -> None:
226
- """Test consistency in replies."""
227
- # Create dummy records
228
- records = [recorddict for _ in range(3)]
229
-
230
- if not is_valid:
231
- # Should raise InconsistentMessageReplies exception
232
- with pytest.raises(InconsistentMessageReplies):
233
- validate_message_reply_consistency(
234
- records, weighted_by_key="weight", check_arrayrecord=True
235
- )
236
- else:
237
- # Should not raise an exception
238
- validate_message_reply_consistency(
239
- records, weighted_by_key="weight", check_arrayrecord=True
240
- )
241
-
242
-
243
- @parameterized.expand( # type: ignore
244
- [
245
- (
246
- [
247
- RecordDict(
248
- {
249
- "global-model": ArrayRecord([np.random.randn(7, 3)]),
250
- "metrics": MetricRecord({"weight": 0.123}),
251
- }
252
- ),
253
- RecordDict(
254
- {
255
- "model": ArrayRecord([np.random.randn(7, 3)]),
256
- "metrics": MetricRecord({"weight": 0.123}),
257
- }
258
- ),
259
- ],
260
- ), # top-level keys don't match for ArrayRecords
261
- (
262
- [
263
- RecordDict(
264
- {
265
- "global-model": ArrayRecord(
266
- OrderedDict({"a": Array(np.random.randn(7, 3))})
267
- ),
268
- "metrics": MetricRecord({"weight": 0.123}),
269
- }
270
- ),
271
- RecordDict(
272
- {
273
- "global-model": ArrayRecord(
274
- OrderedDict({"b": Array(np.random.randn(7, 3))})
275
- ),
276
- "metrics": MetricRecord({"weight": 0.123}),
277
- }
278
- ),
279
- ],
280
- ), # top-level keys match for ArrayRecords but not those for Arrays
281
- (
282
- [
283
- RecordDict(
284
- {
285
- "global-model": ArrayRecord([np.random.randn(7, 3)]),
286
- "metrics": MetricRecord({"weight": 0.123}),
287
- }
288
- ),
289
- RecordDict(
290
- {
291
- "global-model": ArrayRecord([np.random.randn(7, 3)]),
292
- "my-metrics": MetricRecord({"weight": 0.123}),
293
- }
294
- ),
295
- ],
296
- ), # top-level keys don't match for MetricRecords
297
- (
298
- [
299
- RecordDict(
300
- {
301
- "global-model": ArrayRecord([np.random.randn(7, 3)]),
302
- "metrics": MetricRecord({"weight": 0.123}),
303
- }
304
- ),
305
- RecordDict(
306
- {
307
- "global-model": ArrayRecord([np.random.randn(7, 3)]),
308
- "my-metrics": MetricRecord({"my-weights": 0.123}),
309
- }
310
- ),
311
- ],
312
- ), # top-level keys match for MetricRecords but not inner ones
313
- ]
314
- )
315
- def test_consistency_of_replies_with_different_keys(
316
- list_records: list[RecordDict],
317
- ) -> None:
318
- """Test consistency in replies when records don't have matching keys."""
319
- # All test cases expect InconsistentMessageReplies exception to be raised
320
- with pytest.raises(InconsistentMessageReplies):
321
- validate_message_reply_consistency(
322
- list_records, weighted_by_key="weight", check_arrayrecord=True
323
- )