flwr-nightly 1.22.0.dev20250913__py3-none-any.whl → 1.22.0.dev20250916__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. flwr/cli/new/new.py +5 -5
  2. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  3. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  4. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  5. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  6. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  7. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  8. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +2 -2
  9. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +2 -2
  10. flwr/serverapp/strategy/__init__.py +8 -0
  11. flwr/serverapp/strategy/fedavg.py +23 -2
  12. flwr/serverapp/strategy/fedavgm.py +198 -0
  13. flwr/serverapp/strategy/fedmedian.py +71 -0
  14. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  15. flwr/serverapp/strategy/fedxgb_bagging.py +82 -0
  16. flwr/serverapp/strategy/strategy_utils.py +48 -0
  17. flwr/serverapp/strategy/strategy_utils_tests.py +20 -1
  18. {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/METADATA +6 -16
  19. {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/RECORD +22 -18
  20. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  21. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  22. flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  23. {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/WHEEL +0 -0
  24. {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/entry_points.txt +0 -0
flwr/cli/new/new.py CHANGED
@@ -35,7 +35,6 @@ class MlFramework(str, Enum):
35
35
  """Available frameworks."""
36
36
 
37
37
  PYTORCH = "PyTorch"
38
- PYTORCH_MSG_API = "PyTorch (Message API)"
39
38
  TENSORFLOW = "TensorFlow"
40
39
  SKLEARN = "sklearn"
41
40
  HUGGINGFACE = "HuggingFace"
@@ -44,6 +43,7 @@ class MlFramework(str, Enum):
44
43
  NUMPY = "NumPy"
45
44
  FLOWERTUNE = "FlowerTune"
46
45
  BASELINE = "Flower Baseline"
46
+ PYTORCH_LEGACY_API = "PyTorch (Legacy API, deprecated)"
47
47
 
48
48
 
49
49
  class LlmChallengeName(str, Enum):
@@ -155,8 +155,8 @@ def new(
155
155
  if framework_str == MlFramework.BASELINE:
156
156
  framework_str = "baseline"
157
157
 
158
- if framework_str == MlFramework.PYTORCH_MSG_API:
159
- framework_str = "pytorch_msg_api"
158
+ if framework_str == MlFramework.PYTORCH_LEGACY_API:
159
+ framework_str = "pytorch_legacy_api"
160
160
 
161
161
  print(
162
162
  typer.style(
@@ -247,14 +247,14 @@ def new(
247
247
  MlFramework.TENSORFLOW.value,
248
248
  MlFramework.SKLEARN.value,
249
249
  MlFramework.NUMPY.value,
250
- "pytorch_msg_api",
250
+ "pytorch_legacy_api",
251
251
  ]
252
252
  if framework_str in frameworks_with_tasks:
253
253
  files[f"{import_name}/task.py"] = {
254
254
  "template": f"app/code/task.{template_name}.py.tpl"
255
255
  }
256
256
 
257
- if framework_str == "pytorch_msg_api":
257
+ if framework_str == "pytorch_legacy_api":
258
258
  # Use custom __init__ that better captures name of framework
259
259
  files[f"{import_name}/__init__.py"] = {
260
260
  "template": f"app/code/__init__.{framework_str}.py.tpl"
@@ -1,55 +1,80 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import torch
4
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
+ from flwr.clientapp import ClientApp
4
6
 
5
- from flwr.client import ClientApp, NumPyClient
6
- from flwr.common import Context
7
- from $import_name.task import Net, get_weights, load_data, set_weights, test, train
8
-
9
-
10
- # Define Flower Client and client_fn
11
- class FlowerClient(NumPyClient):
12
- def __init__(self, net, trainloader, valloader, local_epochs):
13
- self.net = net
14
- self.trainloader = trainloader
15
- self.valloader = valloader
16
- self.local_epochs = local_epochs
17
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
- self.net.to(self.device)
19
-
20
- def fit(self, parameters, config):
21
- set_weights(self.net, parameters)
22
- train_loss = train(
23
- self.net,
24
- self.trainloader,
25
- self.local_epochs,
26
- self.device,
27
- )
28
- return (
29
- get_weights(self.net),
30
- len(self.trainloader.dataset),
31
- {"train_loss": train_loss},
32
- )
33
-
34
- def evaluate(self, parameters, config):
35
- set_weights(self.net, parameters)
36
- loss, accuracy = test(self.net, self.valloader, self.device)
37
- return loss, len(self.valloader.dataset), {"accuracy": accuracy}
38
-
39
-
40
- def client_fn(context: Context):
41
- # Load model and data
42
- net = Net()
7
+ from $import_name.task import Net, load_data
8
+ from $import_name.task import test as test_fn
9
+ from $import_name.task import train as train_fn
10
+
11
+ # Flower ClientApp
12
+ app = ClientApp()
13
+
14
+
15
+ @app.train()
16
+ def train(msg: Message, context: Context):
17
+ """Train the model on local data."""
18
+
19
+ # Load the model and initialize it with the received weights
20
+ model = Net()
21
+ model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
22
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
+ model.to(device)
24
+
25
+ # Load the data
43
26
  partition_id = context.node_config["partition-id"]
44
27
  num_partitions = context.node_config["num-partitions"]
45
- trainloader, valloader = load_data(partition_id, num_partitions)
46
- local_epochs = context.run_config["local-epochs"]
28
+ trainloader, _ = load_data(partition_id, num_partitions)
47
29
 
48
- # Return Client instance
49
- return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
30
+ # Call the training function
31
+ train_loss = train_fn(
32
+ model,
33
+ trainloader,
34
+ context.run_config["local-epochs"],
35
+ msg.content["config"]["lr"],
36
+ device,
37
+ )
50
38
 
39
+ # Construct and return reply Message
40
+ model_record = ArrayRecord(model.state_dict())
41
+ metrics = {
42
+ "train_loss": train_loss,
43
+ "num-examples": len(trainloader.dataset),
44
+ }
45
+ metric_record = MetricRecord(metrics)
46
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
47
+ return Message(content=content, reply_to=msg)
51
48
 
52
- # Flower ClientApp
53
- app = ClientApp(
54
- client_fn,
55
- )
49
+
50
+ @app.evaluate()
51
+ def evaluate(msg: Message, context: Context):
52
+ """Evaluate the model on local data."""
53
+
54
+ # Load the model and initialize it with the received weights
55
+ model = Net()
56
+ model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
57
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
+ model.to(device)
59
+
60
+ # Load the data
61
+ partition_id = context.node_config["partition-id"]
62
+ num_partitions = context.node_config["num-partitions"]
63
+ _, valloader = load_data(partition_id, num_partitions)
64
+
65
+ # Call the evaluation function
66
+ eval_loss, eval_acc = test_fn(
67
+ model,
68
+ valloader,
69
+ device,
70
+ )
71
+
72
+ # Construct and return reply Message
73
+ metrics = {
74
+ "eval_loss": eval_loss,
75
+ "eval_acc": eval_acc,
76
+ "num-examples": len(valloader.dataset),
77
+ }
78
+ metric_record = MetricRecord(metrics)
79
+ content = RecordDict({"metrics": metric_record})
80
+ return Message(content=content, reply_to=msg)
@@ -0,0 +1,55 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import torch
4
+
5
+ from flwr.client import ClientApp, NumPyClient
6
+ from flwr.common import Context
7
+ from $import_name.task import Net, get_weights, load_data, set_weights, test, train
8
+
9
+
10
+ # Define Flower Client and client_fn
11
+ class FlowerClient(NumPyClient):
12
+ def __init__(self, net, trainloader, valloader, local_epochs):
13
+ self.net = net
14
+ self.trainloader = trainloader
15
+ self.valloader = valloader
16
+ self.local_epochs = local_epochs
17
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
+ self.net.to(self.device)
19
+
20
+ def fit(self, parameters, config):
21
+ set_weights(self.net, parameters)
22
+ train_loss = train(
23
+ self.net,
24
+ self.trainloader,
25
+ self.local_epochs,
26
+ self.device,
27
+ )
28
+ return (
29
+ get_weights(self.net),
30
+ len(self.trainloader.dataset),
31
+ {"train_loss": train_loss},
32
+ )
33
+
34
+ def evaluate(self, parameters, config):
35
+ set_weights(self.net, parameters)
36
+ loss, accuracy = test(self.net, self.valloader, self.device)
37
+ return loss, len(self.valloader.dataset), {"accuracy": accuracy}
38
+
39
+
40
+ def client_fn(context: Context):
41
+ # Load model and data
42
+ net = Net()
43
+ partition_id = context.node_config["partition-id"]
44
+ num_partitions = context.node_config["num-partitions"]
45
+ trainloader, valloader = load_data(partition_id, num_partitions)
46
+ local_epochs = context.run_config["local-epochs"]
47
+
48
+ # Return Client instance
49
+ return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
50
+
51
+
52
+ # Flower ClientApp
53
+ app = ClientApp(
54
+ client_fn,
55
+ )
@@ -1,31 +1,41 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import Net, get_weights
7
-
8
-
9
- def server_fn(context: Context):
10
- # Read from config
11
- num_rounds = context.run_config["num-server-rounds"]
12
- fraction_fit = context.run_config["fraction-fit"]
13
-
14
- # Initialize model parameters
15
- ndarrays = get_weights(Net())
16
- parameters = ndarrays_to_parameters(ndarrays)
17
-
18
- # Define strategy
19
- strategy = FedAvg(
20
- fraction_fit=fraction_fit,
21
- fraction_evaluate=1.0,
22
- min_available_clients=2,
23
- initial_parameters=parameters,
24
- )
25
- config = ServerConfig(num_rounds=num_rounds)
26
-
27
- return ServerAppComponents(strategy=strategy, config=config)
3
+ import torch
4
+ from flwr.app import ArrayRecord, ConfigRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
28
7
 
8
+ from $import_name.task import Net
29
9
 
30
10
  # Create ServerApp
31
- app = ServerApp(server_fn=server_fn)
11
+ app = ServerApp()
12
+
13
+
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
17
+
18
+ # Read run config
19
+ fraction_train: float = context.run_config["fraction-train"]
20
+ num_rounds: int = context.run_config["num-server-rounds"]
21
+ lr: float = context.run_config["lr"]
22
+
23
+ # Load global model
24
+ global_model = Net()
25
+ arrays = ArrayRecord(global_model.state_dict())
26
+
27
+ # Initialize FedAvg strategy
28
+ strategy = FedAvg(fraction_train=fraction_train)
29
+
30
+ # Start strategy, run FedAvg for `num_rounds`
31
+ result = strategy.start(
32
+ grid=grid,
33
+ initial_arrays=arrays,
34
+ train_config=ConfigRecord({"lr": lr}),
35
+ num_rounds=num_rounds,
36
+ )
37
+
38
+ # Save final model to disk
39
+ print("\nSaving final model to disk...")
40
+ state_dict = result.arrays.to_torch_state_dict()
41
+ torch.save(state_dict, "final_model.pt")
@@ -0,0 +1,31 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ from flwr.common import Context, ndarrays_to_parameters
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
+ from flwr.server.strategy import FedAvg
6
+ from $import_name.task import Net, get_weights
7
+
8
+
9
+ def server_fn(context: Context):
10
+ # Read from config
11
+ num_rounds = context.run_config["num-server-rounds"]
12
+ fraction_fit = context.run_config["fraction-fit"]
13
+
14
+ # Initialize model parameters
15
+ ndarrays = get_weights(Net())
16
+ parameters = ndarrays_to_parameters(ndarrays)
17
+
18
+ # Define strategy
19
+ strategy = FedAvg(
20
+ fraction_fit=fraction_fit,
21
+ fraction_evaluate=1.0,
22
+ min_available_clients=2,
23
+ initial_parameters=parameters,
24
+ )
25
+ config = ServerConfig(num_rounds=num_rounds)
26
+
27
+ return ServerAppComponents(strategy=strategy, config=config)
28
+
29
+
30
+ # Create ServerApp
31
+ app = ServerApp(server_fn=server_fn)
@@ -1,7 +1,5 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from collections import OrderedDict
4
-
5
3
  import torch
6
4
  import torch.nn as nn
7
5
  import torch.nn.functional as F
@@ -34,6 +32,14 @@ class Net(nn.Module):
34
32
 
35
33
  fds = None # Cache FederatedDataset
36
34
 
35
+ pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
36
+
37
+
38
+ def apply_transforms(batch):
39
+ """Apply transforms to the partition from FederatedDataset."""
40
+ batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
41
+ return batch
42
+
37
43
 
38
44
  def load_data(partition_id: int, num_partitions: int):
39
45
  """Load partition CIFAR10 data."""
@@ -48,38 +54,29 @@ def load_data(partition_id: int, num_partitions: int):
48
54
  partition = fds.load_partition(partition_id)
49
55
  # Divide data on each node: 80% train, 20% test
50
56
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
51
- pytorch_transforms = Compose(
52
- [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
53
- )
54
-
55
- def apply_transforms(batch):
56
- """Apply transforms to the partition from FederatedDataset."""
57
- batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
58
- return batch
59
-
57
+ # Construct dataloaders
60
58
  partition_train_test = partition_train_test.with_transform(apply_transforms)
61
59
  trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
62
60
  testloader = DataLoader(partition_train_test["test"], batch_size=32)
63
61
  return trainloader, testloader
64
62
 
65
63
 
66
- def train(net, trainloader, epochs, device):
64
+ def train(net, trainloader, epochs, lr, device):
67
65
  """Train the model on the training set."""
68
66
  net.to(device) # move model to GPU if available
69
67
  criterion = torch.nn.CrossEntropyLoss().to(device)
70
- optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
68
+ optimizer = torch.optim.Adam(net.parameters(), lr=lr)
71
69
  net.train()
72
70
  running_loss = 0.0
73
71
  for _ in range(epochs):
74
72
  for batch in trainloader:
75
- images = batch["img"]
76
- labels = batch["label"]
73
+ images = batch["img"].to(device)
74
+ labels = batch["label"].to(device)
77
75
  optimizer.zero_grad()
78
- loss = criterion(net(images.to(device)), labels.to(device))
76
+ loss = criterion(net(images), labels)
79
77
  loss.backward()
80
78
  optimizer.step()
81
79
  running_loss += loss.item()
82
-
83
80
  avg_trainloss = running_loss / len(trainloader)
84
81
  return avg_trainloss
85
82
 
@@ -99,13 +96,3 @@ def test(net, testloader, device):
99
96
  accuracy = correct / len(testloader.dataset)
100
97
  loss = loss / len(testloader)
101
98
  return loss, accuracy
102
-
103
-
104
- def get_weights(net):
105
- return [val.cpu().numpy() for _, val in net.state_dict().items()]
106
-
107
-
108
- def set_weights(net, parameters):
109
- params_dict = zip(net.state_dict().keys(), parameters)
110
- state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
111
- net.load_state_dict(state_dict, strict=True)
@@ -1,5 +1,7 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
+ from collections import OrderedDict
4
+
3
5
  import torch
4
6
  import torch.nn as nn
5
7
  import torch.nn.functional as F
@@ -32,14 +34,6 @@ class Net(nn.Module):
32
34
 
33
35
  fds = None # Cache FederatedDataset
34
36
 
35
- pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
36
-
37
-
38
- def apply_transforms(batch):
39
- """Apply transforms to the partition from FederatedDataset."""
40
- batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
41
- return batch
42
-
43
37
 
44
38
  def load_data(partition_id: int, num_partitions: int):
45
39
  """Load partition CIFAR10 data."""
@@ -54,29 +48,38 @@ def load_data(partition_id: int, num_partitions: int):
54
48
  partition = fds.load_partition(partition_id)
55
49
  # Divide data on each node: 80% train, 20% test
56
50
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
57
- # Construct dataloaders
51
+ pytorch_transforms = Compose(
52
+ [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
53
+ )
54
+
55
+ def apply_transforms(batch):
56
+ """Apply transforms to the partition from FederatedDataset."""
57
+ batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
58
+ return batch
59
+
58
60
  partition_train_test = partition_train_test.with_transform(apply_transforms)
59
61
  trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
60
62
  testloader = DataLoader(partition_train_test["test"], batch_size=32)
61
63
  return trainloader, testloader
62
64
 
63
65
 
64
- def train(net, trainloader, epochs, lr, device):
66
+ def train(net, trainloader, epochs, device):
65
67
  """Train the model on the training set."""
66
68
  net.to(device) # move model to GPU if available
67
69
  criterion = torch.nn.CrossEntropyLoss().to(device)
68
- optimizer = torch.optim.Adam(net.parameters(), lr=lr)
70
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
69
71
  net.train()
70
72
  running_loss = 0.0
71
73
  for _ in range(epochs):
72
74
  for batch in trainloader:
73
- images = batch["img"].to(device)
74
- labels = batch["label"].to(device)
75
+ images = batch["img"]
76
+ labels = batch["label"]
75
77
  optimizer.zero_grad()
76
- loss = criterion(net(images), labels)
78
+ loss = criterion(net(images.to(device)), labels.to(device))
77
79
  loss.backward()
78
80
  optimizer.step()
79
81
  running_loss += loss.item()
82
+
80
83
  avg_trainloss = running_loss / len(trainloader)
81
84
  return avg_trainloss
82
85
 
@@ -96,3 +99,13 @@ def test(net, testloader, device):
96
99
  accuracy = correct / len(testloader.dataset)
97
100
  loss = loss / len(testloader)
98
101
  return loss, accuracy
102
+
103
+
104
+ def get_weights(net):
105
+ return [val.cpu().numpy() for _, val in net.state_dict().items()]
106
+
107
+
108
+ def set_weights(net, parameters):
109
+ params_dict = zip(net.state_dict().keys(), parameters)
110
+ state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
111
+ net.load_state_dict(state_dict, strict=True)
@@ -27,7 +27,6 @@ packages = ["."]
27
27
  publisher = "$username"
28
28
 
29
29
  # Point to your ServerApp and ClientApp objects
30
- # Format: "<module>:<object>"
31
30
  [tool.flwr.app.components]
32
31
  serverapp = "$import_name.server_app:app"
33
32
  clientapp = "$import_name.client_app:app"
@@ -35,8 +34,9 @@ clientapp = "$import_name.client_app:app"
35
34
  # Custom config values accessible via `context.run_config`
36
35
  [tool.flwr.app.config]
37
36
  num-server-rounds = 3
38
- fraction-fit = 0.5
37
+ fraction-train = 0.5
39
38
  local-epochs = 1
39
+ lr = 0.01
40
40
 
41
41
  # Default federation to use when running the app
42
42
  [tool.flwr.federations]
@@ -27,6 +27,7 @@ packages = ["."]
27
27
  publisher = "$username"
28
28
 
29
29
  # Point to your ServerApp and ClientApp objects
30
+ # Format: "<module>:<object>"
30
31
  [tool.flwr.app.components]
31
32
  serverapp = "$import_name.server_app:app"
32
33
  clientapp = "$import_name.client_app:app"
@@ -34,9 +35,8 @@ clientapp = "$import_name.client_app:app"
34
35
  # Custom config values accessible via `context.run_config`
35
36
  [tool.flwr.app.config]
36
37
  num-server-rounds = 3
37
- fraction-train = 0.5
38
+ fraction-fit = 0.5
38
39
  local-epochs = 1
39
- lr = 0.01
40
40
 
41
41
  # Default federation to use when running the app
42
42
  [tool.flwr.federations]
@@ -22,6 +22,10 @@ from .dp_fixed_clipping import (
22
22
  from .fedadagrad import FedAdagrad
23
23
  from .fedadam import FedAdam
24
24
  from .fedavg import FedAvg
25
+ from .fedavgm import FedAvgM
26
+ from .fedmedian import FedMedian
27
+ from .fedtrimmedavg import FedTrimmedAvg
28
+ from .fedxgb_bagging import FedXgbBagging
25
29
  from .fedyogi import FedYogi
26
30
  from .result import Result
27
31
  from .strategy import Strategy
@@ -32,6 +36,10 @@ __all__ = [
32
36
  "FedAdagrad",
33
37
  "FedAdam",
34
38
  "FedAvg",
39
+ "FedAvgM",
40
+ "FedMedian",
41
+ "FedTrimmedAvg",
42
+ "FedXgbBagging",
35
43
  "FedYogi",
36
44
  "Result",
37
45
  "Strategy",
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  from collections.abc import Iterable
19
- from logging import INFO
19
+ from logging import INFO, WARNING
20
20
  from typing import Callable, Optional
21
21
 
22
22
  from flwr.common import (
@@ -67,7 +67,7 @@ class FedAvg(Strategy):
67
67
  arrayrecord_key : str (default: "arrays")
68
68
  Key used to store the ArrayRecord when constructing Messages.
69
69
  configrecord_key : str (default: "config")
70
- Key used to store the ConfigRecord when constructing Messages.
70
+ Key used to store the ConfigRecord when constructing Messages.
71
71
  train_metrics_aggr_fn : Optional[callable] (default: None)
72
72
  Function with signature (list[RecordDict], str) -> MetricRecord,
73
73
  used to aggregate MetricRecords from training round replies.
@@ -111,6 +111,20 @@ class FedAvg(Strategy):
111
111
  evaluate_metrics_aggr_fn or aggregate_metricrecords
112
112
  )
113
113
 
114
+ if self.fraction_evaluate == 0.0:
115
+ self.min_evaluate_nodes = 0
116
+ log(
117
+ WARNING,
118
+ "fraction_evaluate is set to 0.0. "
119
+ "Federated evaluation will be skipped.",
120
+ )
121
+ if self.fraction_train == 0.0:
122
+ self.min_train_nodes = 0
123
+ log(
124
+ WARNING,
125
+ "fraction_train is set to 0.0. Federated training will be skipped.",
126
+ )
127
+
114
128
  def summary(self) -> None:
115
129
  """Log summary configuration of the strategy."""
116
130
  log(INFO, "\t├──> Sampling:")
@@ -150,6 +164,9 @@ class FedAvg(Strategy):
150
164
  self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
151
165
  ) -> Iterable[Message]:
152
166
  """Configure the next round of federated training."""
167
+ # Do not configure federated train if fraction_train is 0.
168
+ if self.fraction_train == 0.0:
169
+ return []
153
170
  # Sample nodes
154
171
  num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_train)
155
172
  sample_size = max(num_nodes, self.min_train_nodes)
@@ -259,6 +276,10 @@ class FedAvg(Strategy):
259
276
  self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
260
277
  ) -> Iterable[Message]:
261
278
  """Configure the next round of federated evaluation."""
279
+ # Do not configure federated evaluation if fraction_evaluate is 0.
280
+ if self.fraction_evaluate == 0.0:
281
+ return []
282
+
262
283
  # Sample nodes
263
284
  num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_evaluate)
264
285
  sample_size = max(num_nodes, self.min_evaluate_nodes)