flwr-nightly 1.22.0.dev20250913__py3-none-any.whl → 1.22.0.dev20250915__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flwr/cli/new/new.py CHANGED
@@ -35,7 +35,6 @@ class MlFramework(str, Enum):
35
35
  """Available frameworks."""
36
36
 
37
37
  PYTORCH = "PyTorch"
38
- PYTORCH_MSG_API = "PyTorch (Message API)"
39
38
  TENSORFLOW = "TensorFlow"
40
39
  SKLEARN = "sklearn"
41
40
  HUGGINGFACE = "HuggingFace"
@@ -44,6 +43,7 @@ class MlFramework(str, Enum):
44
43
  NUMPY = "NumPy"
45
44
  FLOWERTUNE = "FlowerTune"
46
45
  BASELINE = "Flower Baseline"
46
+ PYTORCH_LEGACY_API = "PyTorch (Legacy API, deprecated)"
47
47
 
48
48
 
49
49
  class LlmChallengeName(str, Enum):
@@ -155,8 +155,8 @@ def new(
155
155
  if framework_str == MlFramework.BASELINE:
156
156
  framework_str = "baseline"
157
157
 
158
- if framework_str == MlFramework.PYTORCH_MSG_API:
159
- framework_str = "pytorch_msg_api"
158
+ if framework_str == MlFramework.PYTORCH_LEGACY_API:
159
+ framework_str = "pytorch_legacy_api"
160
160
 
161
161
  print(
162
162
  typer.style(
@@ -247,14 +247,14 @@ def new(
247
247
  MlFramework.TENSORFLOW.value,
248
248
  MlFramework.SKLEARN.value,
249
249
  MlFramework.NUMPY.value,
250
- "pytorch_msg_api",
250
+ "pytorch_legacy_api",
251
251
  ]
252
252
  if framework_str in frameworks_with_tasks:
253
253
  files[f"{import_name}/task.py"] = {
254
254
  "template": f"app/code/task.{template_name}.py.tpl"
255
255
  }
256
256
 
257
- if framework_str == "pytorch_msg_api":
257
+ if framework_str == "pytorch_legacy_api":
258
258
  # Use custom __init__ that better captures name of framework
259
259
  files[f"{import_name}/__init__.py"] = {
260
260
  "template": f"app/code/__init__.{framework_str}.py.tpl"
@@ -1,55 +1,80 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import torch
4
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
+ from flwr.clientapp import ClientApp
4
6
 
5
- from flwr.client import ClientApp, NumPyClient
6
- from flwr.common import Context
7
- from $import_name.task import Net, get_weights, load_data, set_weights, test, train
8
-
9
-
10
- # Define Flower Client and client_fn
11
- class FlowerClient(NumPyClient):
12
- def __init__(self, net, trainloader, valloader, local_epochs):
13
- self.net = net
14
- self.trainloader = trainloader
15
- self.valloader = valloader
16
- self.local_epochs = local_epochs
17
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
- self.net.to(self.device)
19
-
20
- def fit(self, parameters, config):
21
- set_weights(self.net, parameters)
22
- train_loss = train(
23
- self.net,
24
- self.trainloader,
25
- self.local_epochs,
26
- self.device,
27
- )
28
- return (
29
- get_weights(self.net),
30
- len(self.trainloader.dataset),
31
- {"train_loss": train_loss},
32
- )
33
-
34
- def evaluate(self, parameters, config):
35
- set_weights(self.net, parameters)
36
- loss, accuracy = test(self.net, self.valloader, self.device)
37
- return loss, len(self.valloader.dataset), {"accuracy": accuracy}
38
-
39
-
40
- def client_fn(context: Context):
41
- # Load model and data
42
- net = Net()
7
+ from $import_name.task import Net, load_data
8
+ from $import_name.task import test as test_fn
9
+ from $import_name.task import train as train_fn
10
+
11
+ # Flower ClientApp
12
+ app = ClientApp()
13
+
14
+
15
+ @app.train()
16
+ def train(msg: Message, context: Context):
17
+ """Train the model on local data."""
18
+
19
+ # Load the model and initialize it with the received weights
20
+ model = Net()
21
+ model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
22
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
+ model.to(device)
24
+
25
+ # Load the data
43
26
  partition_id = context.node_config["partition-id"]
44
27
  num_partitions = context.node_config["num-partitions"]
45
- trainloader, valloader = load_data(partition_id, num_partitions)
46
- local_epochs = context.run_config["local-epochs"]
28
+ trainloader, _ = load_data(partition_id, num_partitions)
47
29
 
48
- # Return Client instance
49
- return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
30
+ # Call the training function
31
+ train_loss = train_fn(
32
+ model,
33
+ trainloader,
34
+ context.run_config["local-epochs"],
35
+ msg.content["config"]["lr"],
36
+ device,
37
+ )
50
38
 
39
+ # Construct and return reply Message
40
+ model_record = ArrayRecord(model.state_dict())
41
+ metrics = {
42
+ "train_loss": train_loss,
43
+ "num-examples": len(trainloader.dataset),
44
+ }
45
+ metric_record = MetricRecord(metrics)
46
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
47
+ return Message(content=content, reply_to=msg)
51
48
 
52
- # Flower ClientApp
53
- app = ClientApp(
54
- client_fn,
55
- )
49
+
50
+ @app.evaluate()
51
+ def evaluate(msg: Message, context: Context):
52
+ """Evaluate the model on local data."""
53
+
54
+ # Load the model and initialize it with the received weights
55
+ model = Net()
56
+ model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
57
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
+ model.to(device)
59
+
60
+ # Load the data
61
+ partition_id = context.node_config["partition-id"]
62
+ num_partitions = context.node_config["num-partitions"]
63
+ _, valloader = load_data(partition_id, num_partitions)
64
+
65
+ # Call the evaluation function
66
+ eval_loss, eval_acc = test_fn(
67
+ model,
68
+ valloader,
69
+ device,
70
+ )
71
+
72
+ # Construct and return reply Message
73
+ metrics = {
74
+ "eval_loss": eval_loss,
75
+ "eval_acc": eval_acc,
76
+ "num-examples": len(valloader.dataset),
77
+ }
78
+ metric_record = MetricRecord(metrics)
79
+ content = RecordDict({"metrics": metric_record})
80
+ return Message(content=content, reply_to=msg)
@@ -0,0 +1,55 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import torch
4
+
5
+ from flwr.client import ClientApp, NumPyClient
6
+ from flwr.common import Context
7
+ from $import_name.task import Net, get_weights, load_data, set_weights, test, train
8
+
9
+
10
+ # Define Flower Client and client_fn
11
+ class FlowerClient(NumPyClient):
12
+ def __init__(self, net, trainloader, valloader, local_epochs):
13
+ self.net = net
14
+ self.trainloader = trainloader
15
+ self.valloader = valloader
16
+ self.local_epochs = local_epochs
17
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
+ self.net.to(self.device)
19
+
20
+ def fit(self, parameters, config):
21
+ set_weights(self.net, parameters)
22
+ train_loss = train(
23
+ self.net,
24
+ self.trainloader,
25
+ self.local_epochs,
26
+ self.device,
27
+ )
28
+ return (
29
+ get_weights(self.net),
30
+ len(self.trainloader.dataset),
31
+ {"train_loss": train_loss},
32
+ )
33
+
34
+ def evaluate(self, parameters, config):
35
+ set_weights(self.net, parameters)
36
+ loss, accuracy = test(self.net, self.valloader, self.device)
37
+ return loss, len(self.valloader.dataset), {"accuracy": accuracy}
38
+
39
+
40
+ def client_fn(context: Context):
41
+ # Load model and data
42
+ net = Net()
43
+ partition_id = context.node_config["partition-id"]
44
+ num_partitions = context.node_config["num-partitions"]
45
+ trainloader, valloader = load_data(partition_id, num_partitions)
46
+ local_epochs = context.run_config["local-epochs"]
47
+
48
+ # Return Client instance
49
+ return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
50
+
51
+
52
+ # Flower ClientApp
53
+ app = ClientApp(
54
+ client_fn,
55
+ )
@@ -1,31 +1,41 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import Net, get_weights
7
-
8
-
9
- def server_fn(context: Context):
10
- # Read from config
11
- num_rounds = context.run_config["num-server-rounds"]
12
- fraction_fit = context.run_config["fraction-fit"]
13
-
14
- # Initialize model parameters
15
- ndarrays = get_weights(Net())
16
- parameters = ndarrays_to_parameters(ndarrays)
17
-
18
- # Define strategy
19
- strategy = FedAvg(
20
- fraction_fit=fraction_fit,
21
- fraction_evaluate=1.0,
22
- min_available_clients=2,
23
- initial_parameters=parameters,
24
- )
25
- config = ServerConfig(num_rounds=num_rounds)
26
-
27
- return ServerAppComponents(strategy=strategy, config=config)
3
+ import torch
4
+ from flwr.app import ArrayRecord, ConfigRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
28
7
 
8
+ from $import_name.task import Net
29
9
 
30
10
  # Create ServerApp
31
- app = ServerApp(server_fn=server_fn)
11
+ app = ServerApp()
12
+
13
+
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
17
+
18
+ # Read run config
19
+ fraction_train: float = context.run_config["fraction-train"]
20
+ num_rounds: int = context.run_config["num-server-rounds"]
21
+ lr: float = context.run_config["lr"]
22
+
23
+ # Load global model
24
+ global_model = Net()
25
+ arrays = ArrayRecord(global_model.state_dict())
26
+
27
+ # Initialize FedAvg strategy
28
+ strategy = FedAvg(fraction_train=fraction_train)
29
+
30
+ # Start strategy, run FedAvg for `num_rounds`
31
+ result = strategy.start(
32
+ grid=grid,
33
+ initial_arrays=arrays,
34
+ train_config=ConfigRecord({"lr": lr}),
35
+ num_rounds=num_rounds,
36
+ )
37
+
38
+ # Save final model to disk
39
+ print("\nSaving final model to disk...")
40
+ state_dict = result.arrays.to_torch_state_dict()
41
+ torch.save(state_dict, "final_model.pt")
@@ -0,0 +1,31 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ from flwr.common import Context, ndarrays_to_parameters
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
+ from flwr.server.strategy import FedAvg
6
+ from $import_name.task import Net, get_weights
7
+
8
+
9
+ def server_fn(context: Context):
10
+ # Read from config
11
+ num_rounds = context.run_config["num-server-rounds"]
12
+ fraction_fit = context.run_config["fraction-fit"]
13
+
14
+ # Initialize model parameters
15
+ ndarrays = get_weights(Net())
16
+ parameters = ndarrays_to_parameters(ndarrays)
17
+
18
+ # Define strategy
19
+ strategy = FedAvg(
20
+ fraction_fit=fraction_fit,
21
+ fraction_evaluate=1.0,
22
+ min_available_clients=2,
23
+ initial_parameters=parameters,
24
+ )
25
+ config = ServerConfig(num_rounds=num_rounds)
26
+
27
+ return ServerAppComponents(strategy=strategy, config=config)
28
+
29
+
30
+ # Create ServerApp
31
+ app = ServerApp(server_fn=server_fn)
@@ -1,7 +1,5 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from collections import OrderedDict
4
-
5
3
  import torch
6
4
  import torch.nn as nn
7
5
  import torch.nn.functional as F
@@ -34,6 +32,14 @@ class Net(nn.Module):
34
32
 
35
33
  fds = None # Cache FederatedDataset
36
34
 
35
+ pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
36
+
37
+
38
+ def apply_transforms(batch):
39
+ """Apply transforms to the partition from FederatedDataset."""
40
+ batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
41
+ return batch
42
+
37
43
 
38
44
  def load_data(partition_id: int, num_partitions: int):
39
45
  """Load partition CIFAR10 data."""
@@ -48,38 +54,29 @@ def load_data(partition_id: int, num_partitions: int):
48
54
  partition = fds.load_partition(partition_id)
49
55
  # Divide data on each node: 80% train, 20% test
50
56
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
51
- pytorch_transforms = Compose(
52
- [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
53
- )
54
-
55
- def apply_transforms(batch):
56
- """Apply transforms to the partition from FederatedDataset."""
57
- batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
58
- return batch
59
-
57
+ # Construct dataloaders
60
58
  partition_train_test = partition_train_test.with_transform(apply_transforms)
61
59
  trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
62
60
  testloader = DataLoader(partition_train_test["test"], batch_size=32)
63
61
  return trainloader, testloader
64
62
 
65
63
 
66
- def train(net, trainloader, epochs, device):
64
+ def train(net, trainloader, epochs, lr, device):
67
65
  """Train the model on the training set."""
68
66
  net.to(device) # move model to GPU if available
69
67
  criterion = torch.nn.CrossEntropyLoss().to(device)
70
- optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
68
+ optimizer = torch.optim.Adam(net.parameters(), lr=lr)
71
69
  net.train()
72
70
  running_loss = 0.0
73
71
  for _ in range(epochs):
74
72
  for batch in trainloader:
75
- images = batch["img"]
76
- labels = batch["label"]
73
+ images = batch["img"].to(device)
74
+ labels = batch["label"].to(device)
77
75
  optimizer.zero_grad()
78
- loss = criterion(net(images.to(device)), labels.to(device))
76
+ loss = criterion(net(images), labels)
79
77
  loss.backward()
80
78
  optimizer.step()
81
79
  running_loss += loss.item()
82
-
83
80
  avg_trainloss = running_loss / len(trainloader)
84
81
  return avg_trainloss
85
82
 
@@ -99,13 +96,3 @@ def test(net, testloader, device):
99
96
  accuracy = correct / len(testloader.dataset)
100
97
  loss = loss / len(testloader)
101
98
  return loss, accuracy
102
-
103
-
104
- def get_weights(net):
105
- return [val.cpu().numpy() for _, val in net.state_dict().items()]
106
-
107
-
108
- def set_weights(net, parameters):
109
- params_dict = zip(net.state_dict().keys(), parameters)
110
- state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
111
- net.load_state_dict(state_dict, strict=True)
@@ -1,5 +1,7 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
+ from collections import OrderedDict
4
+
3
5
  import torch
4
6
  import torch.nn as nn
5
7
  import torch.nn.functional as F
@@ -32,14 +34,6 @@ class Net(nn.Module):
32
34
 
33
35
  fds = None # Cache FederatedDataset
34
36
 
35
- pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
36
-
37
-
38
- def apply_transforms(batch):
39
- """Apply transforms to the partition from FederatedDataset."""
40
- batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
41
- return batch
42
-
43
37
 
44
38
  def load_data(partition_id: int, num_partitions: int):
45
39
  """Load partition CIFAR10 data."""
@@ -54,29 +48,38 @@ def load_data(partition_id: int, num_partitions: int):
54
48
  partition = fds.load_partition(partition_id)
55
49
  # Divide data on each node: 80% train, 20% test
56
50
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
57
- # Construct dataloaders
51
+ pytorch_transforms = Compose(
52
+ [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
53
+ )
54
+
55
+ def apply_transforms(batch):
56
+ """Apply transforms to the partition from FederatedDataset."""
57
+ batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
58
+ return batch
59
+
58
60
  partition_train_test = partition_train_test.with_transform(apply_transforms)
59
61
  trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
60
62
  testloader = DataLoader(partition_train_test["test"], batch_size=32)
61
63
  return trainloader, testloader
62
64
 
63
65
 
64
- def train(net, trainloader, epochs, lr, device):
66
+ def train(net, trainloader, epochs, device):
65
67
  """Train the model on the training set."""
66
68
  net.to(device) # move model to GPU if available
67
69
  criterion = torch.nn.CrossEntropyLoss().to(device)
68
- optimizer = torch.optim.Adam(net.parameters(), lr=lr)
70
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
69
71
  net.train()
70
72
  running_loss = 0.0
71
73
  for _ in range(epochs):
72
74
  for batch in trainloader:
73
- images = batch["img"].to(device)
74
- labels = batch["label"].to(device)
75
+ images = batch["img"]
76
+ labels = batch["label"]
75
77
  optimizer.zero_grad()
76
- loss = criterion(net(images), labels)
78
+ loss = criterion(net(images.to(device)), labels.to(device))
77
79
  loss.backward()
78
80
  optimizer.step()
79
81
  running_loss += loss.item()
82
+
80
83
  avg_trainloss = running_loss / len(trainloader)
81
84
  return avg_trainloss
82
85
 
@@ -96,3 +99,13 @@ def test(net, testloader, device):
96
99
  accuracy = correct / len(testloader.dataset)
97
100
  loss = loss / len(testloader)
98
101
  return loss, accuracy
102
+
103
+
104
+ def get_weights(net):
105
+ return [val.cpu().numpy() for _, val in net.state_dict().items()]
106
+
107
+
108
+ def set_weights(net, parameters):
109
+ params_dict = zip(net.state_dict().keys(), parameters)
110
+ state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
111
+ net.load_state_dict(state_dict, strict=True)
@@ -27,7 +27,6 @@ packages = ["."]
27
27
  publisher = "$username"
28
28
 
29
29
  # Point to your ServerApp and ClientApp objects
30
- # Format: "<module>:<object>"
31
30
  [tool.flwr.app.components]
32
31
  serverapp = "$import_name.server_app:app"
33
32
  clientapp = "$import_name.client_app:app"
@@ -35,8 +34,9 @@ clientapp = "$import_name.client_app:app"
35
34
  # Custom config values accessible via `context.run_config`
36
35
  [tool.flwr.app.config]
37
36
  num-server-rounds = 3
38
- fraction-fit = 0.5
37
+ fraction-train = 0.5
39
38
  local-epochs = 1
39
+ lr = 0.01
40
40
 
41
41
  # Default federation to use when running the app
42
42
  [tool.flwr.federations]
@@ -27,6 +27,7 @@ packages = ["."]
27
27
  publisher = "$username"
28
28
 
29
29
  # Point to your ServerApp and ClientApp objects
30
+ # Format: "<module>:<object>"
30
31
  [tool.flwr.app.components]
31
32
  serverapp = "$import_name.server_app:app"
32
33
  clientapp = "$import_name.client_app:app"
@@ -34,9 +35,8 @@ clientapp = "$import_name.client_app:app"
34
35
  # Custom config values accessible via `context.run_config`
35
36
  [tool.flwr.app.config]
36
37
  num-server-rounds = 3
37
- fraction-train = 0.5
38
+ fraction-fit = 0.5
38
39
  local-epochs = 1
39
- lr = 0.01
40
40
 
41
41
  # Default federation to use when running the app
42
42
  [tool.flwr.federations]
@@ -22,6 +22,7 @@ from .dp_fixed_clipping import (
22
22
  from .fedadagrad import FedAdagrad
23
23
  from .fedadam import FedAdam
24
24
  from .fedavg import FedAvg
25
+ from .fedxgb_bagging import FedXgbBagging
25
26
  from .fedyogi import FedYogi
26
27
  from .result import Result
27
28
  from .strategy import Strategy
@@ -32,6 +33,7 @@ __all__ = [
32
33
  "FedAdagrad",
33
34
  "FedAdam",
34
35
  "FedAvg",
36
+ "FedXgbBagging",
35
37
  "FedYogi",
36
38
  "Result",
37
39
  "Strategy",
@@ -0,0 +1,82 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower message-based FedXgbBagging strategy."""
16
+ from collections.abc import Iterable
17
+ from typing import Optional, cast
18
+
19
+ import numpy as np
20
+
21
+ from flwr.common import ArrayRecord, ConfigRecord, Message, MetricRecord
22
+ from flwr.server import Grid
23
+
24
+ from ..exception import InconsistentMessageReplies
25
+ from .fedavg import FedAvg
26
+ from .strategy_utils import aggregate_bagging
27
+
28
+
29
+ # pylint: disable=line-too-long
30
+ class FedXgbBagging(FedAvg):
31
+ """Configurable FedXgbBagging strategy implementation."""
32
+
33
+ current_bst: Optional[bytes] = None
34
+
35
+ def _ensure_single_array(self, arrays: ArrayRecord) -> None:
36
+ """Check that ensures there's only one Array in the ArrayRecord."""
37
+ n = len(arrays)
38
+ if n != 1:
39
+ raise InconsistentMessageReplies(
40
+ reason="Expected exactly one Array in ArrayRecord. "
41
+ "Skipping aggregation."
42
+ )
43
+
44
+ def configure_train(
45
+ self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
46
+ ) -> Iterable[Message]:
47
+ """Configure the next round of federated training."""
48
+ self._ensure_single_array(arrays)
49
+ # Keep track of array record being communicated
50
+ self.current_bst = arrays["0"].numpy().tobytes()
51
+ return super().configure_train(server_round, arrays, config, grid)
52
+
53
+ def aggregate_train(
54
+ self,
55
+ server_round: int,
56
+ replies: Iterable[Message],
57
+ ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
58
+ """Aggregate ArrayRecords and MetricRecords in the received Messages."""
59
+ valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
60
+
61
+ arrays, metrics = None, None
62
+ if valid_replies:
63
+ reply_contents = [msg.content for msg in valid_replies]
64
+ array_record_key = next(iter(reply_contents[0].array_records.keys()))
65
+
66
+ # Aggregate ArrayRecords
67
+ for content in reply_contents:
68
+ self._ensure_single_array(cast(ArrayRecord, content[array_record_key]))
69
+ bst = content[array_record_key]["0"].numpy().tobytes() # type: ignore[union-attr]
70
+
71
+ if self.current_bst is not None:
72
+ self.current_bst = aggregate_bagging(self.current_bst, bst)
73
+
74
+ if self.current_bst is not None:
75
+ arrays = ArrayRecord([np.frombuffer(self.current_bst, dtype=np.uint8)])
76
+
77
+ # Aggregate MetricRecords
78
+ metrics = self.train_metrics_aggr_fn(
79
+ reply_contents,
80
+ self.weighted_by_key,
81
+ )
82
+ return arrays, metrics
@@ -15,6 +15,7 @@
15
15
  """Flower message-based strategy utilities."""
16
16
 
17
17
 
18
+ import json
18
19
  import random
19
20
  from collections import OrderedDict
20
21
  from logging import INFO
@@ -249,3 +250,50 @@ def validate_message_reply_consistency(
249
250
  "must be a single value (int or float), but a list was found. Skipping "
250
251
  "aggregation."
251
252
  )
253
+
254
+
255
+ def aggregate_bagging(
256
+ bst_prev_org: bytes,
257
+ bst_curr_org: bytes,
258
+ ) -> bytes:
259
+ """Conduct bagging aggregation for given trees."""
260
+ if bst_prev_org == b"":
261
+ return bst_curr_org
262
+
263
+ # Get the tree numbers
264
+ tree_num_prev, _ = _get_tree_nums(bst_prev_org)
265
+ _, paral_tree_num_curr = _get_tree_nums(bst_curr_org)
266
+
267
+ bst_prev = json.loads(bytearray(bst_prev_org))
268
+ bst_curr = json.loads(bytearray(bst_curr_org))
269
+
270
+ previous_model = bst_prev["learner"]["gradient_booster"]["model"]
271
+ previous_model["gbtree_model_param"]["num_trees"] = str(
272
+ tree_num_prev + paral_tree_num_curr
273
+ )
274
+ iteration_indptr = previous_model["iteration_indptr"]
275
+ previous_model["iteration_indptr"].append(
276
+ iteration_indptr[-1] + paral_tree_num_curr
277
+ )
278
+
279
+ # Aggregate new trees
280
+ trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
281
+ for tree_count in range(paral_tree_num_curr):
282
+ trees_curr[tree_count]["id"] = tree_num_prev + tree_count
283
+ previous_model["trees"].append(trees_curr[tree_count])
284
+ previous_model["tree_info"].append(0)
285
+
286
+ bst_prev_bytes = bytes(json.dumps(bst_prev), "utf-8")
287
+
288
+ return bst_prev_bytes
289
+
290
+
291
+ def _get_tree_nums(xgb_model_org: bytes) -> tuple[int, int]:
292
+ xgb_model = json.loads(bytearray(xgb_model_org))
293
+
294
+ # Access model parameters
295
+ model_param = xgb_model["learner"]["gradient_booster"]["model"][
296
+ "gbtree_model_param"
297
+ ]
298
+ # Return the number of trees and the number of parallel trees
299
+ return int(model_param["num_trees"]), int(model_param["num_parallel_tree"])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: flwr-nightly
3
- Version: 1.22.0.dev20250913
3
+ Version: 1.22.0.dev20250915
4
4
  Summary: Flower: A Friendly Federated AI Framework
5
5
  License: Apache-2.0
6
6
  Keywords: Artificial Intelligence,Federated AI,Federated Analytics,Federated Evaluation,Federated Learning,Flower,Machine Learning
@@ -18,7 +18,7 @@ flwr/cli/login/__init__.py,sha256=B1SXKU3HCQhWfFDMJhlC7FOl8UsvH4mxysxeBnrfyUE,80
18
18
  flwr/cli/login/login.py,sha256=RM1Jiv_VFm3oz4rTHSr3D87X90lW3WzErjBBU7WviWY,4309
19
19
  flwr/cli/ls.py,sha256=3YK7cpoImJ7PbjlP_JgYRQWz1GymX2q7Reu-mKJEpao,10957
20
20
  flwr/cli/new/__init__.py,sha256=QA1E2QtzPvFCjLTUHnFnJbufuFiGyT_0Y53Wpbvg1F0,790
21
- flwr/cli/new/new.py,sha256=46QuAi7Act3_TbD0IkejUhognXPXlo2r3LRPvN8pEkA,10503
21
+ flwr/cli/new/new.py,sha256=KyTs9Fbm4eoJ5DohhuTkYNJJX5rDC0p-YTPtNatYXrI,10529
22
22
  flwr/cli/new/templates/__init__.py,sha256=FpjWCfIySU2DB4kh0HOXLAjlZNNFDTVU4w3HoE2TzcI,725
23
23
  flwr/cli/new/templates/app/.gitignore.tpl,sha256=HZJcGQoxp7aUzaPg8Uqch3kNrIESwr9yjimDxJYgXVY,3104
24
24
  flwr/cli/new/templates/app/LICENSE.tpl,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
@@ -29,14 +29,14 @@ flwr/cli/new/templates/app/__init__.py,sha256=LbR0ksGiF566JcHM_H5m1Tc4-oYUEilWFl
29
29
  flwr/cli/new/templates/app/code/__init__.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
30
30
  flwr/cli/new/templates/app/code/__init__.py,sha256=zXa2YU1swzHxOKDQbwlDMEwVPOUswVeosjkiXNMTgFo,736
31
31
  flwr/cli/new/templates/app/code/__init__.py.tpl,sha256=J0Gn74E7khpLyKJVNqOPu7ev93vkcu1PZugsbxtABMw,52
32
- flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl,sha256=mKIS8MK_X8T9NlmcX1-_c9Bbexc-ueqDIBI7uN6c4dE,45
32
+ flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl,sha256=mKIS8MK_X8T9NlmcX1-_c9Bbexc-ueqDIBI7uN6c4dE,45
33
33
  flwr/cli/new/templates/app/code/client.baseline.py.tpl,sha256=IYlCZqnaxT2ucP1ReffRNohOkYwNrhtrnDoQBBcrThY,1901
34
34
  flwr/cli/new/templates/app/code/client.huggingface.py.tpl,sha256=SIZZ3s-6u8IU8cFfsqu6ZU8zjhfI1m1SWauOSUcW8TA,3015
35
35
  flwr/cli/new/templates/app/code/client.jax.py.tpl,sha256=uFCIPwAHYiRAgh2W3nRni_Oig02ZzRF-ofUG5O19zcE,2125
36
36
  flwr/cli/new/templates/app/code/client.mlx.py.tpl,sha256=CHU2IBIzI2YENZZuvTsAlSdL94DK19wMYMIhr-JgwZ8,3422
37
37
  flwr/cli/new/templates/app/code/client.numpy.py.tpl,sha256=1_WEoOPe9jJeK-7FZgYuDUqY8mC0vxgqA83d-h201Gk,1381
38
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=fuxVmZpjHIueNy_aHWF81531vmi8DGu4CYjYDqmUwWo,1705
39
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl,sha256=fYoh-dTu07LkqNYvwcxQnbgVvH4Yo4eiGEcyHECbsnU,2473
38
+ flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=fYoh-dTu07LkqNYvwcxQnbgVvH4Yo4eiGEcyHECbsnU,2473
39
+ flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl,sha256=fuxVmZpjHIueNy_aHWF81531vmi8DGu4CYjYDqmUwWo,1705
40
40
  flwr/cli/new/templates/app/code/client.sklearn.py.tpl,sha256=0qqEe-RRjkHGOH8gsD9e83ae-kyyYixhyBgzVHjYpzk,3500
41
41
  flwr/cli/new/templates/app/code/client.tensorflow.py.tpl,sha256=8o55KXpsbF_rv6o98ZNYJDCazjwMp_RPTaSzDfT7Qlw,2682
42
42
  flwr/cli/new/templates/app/code/dataset.baseline.py.tpl,sha256=jbd_exHAk2-Blu_kVutjPO6a_dkJQWb232zxSeXIZ1k,1453
@@ -52,8 +52,8 @@ flwr/cli/new/templates/app/code/server.huggingface.py.tpl,sha256=_2Mv-SqGSMf7sMd
52
52
  flwr/cli/new/templates/app/code/server.jax.py.tpl,sha256=RW-rh7ogcJ3_BD66bJxTw-ZoP7c-4SK8hVHc-e0SSVY,1029
53
53
  flwr/cli/new/templates/app/code/server.mlx.py.tpl,sha256=J8rIe6RL2ndODVJD79xShRKBH70HljFSCi4s_RJ-xLQ,1200
54
54
  flwr/cli/new/templates/app/code/server.numpy.py.tpl,sha256=T3hcKbPw3uL5lXEP-MuVJXIBXjzva5sWJXfpQqarUwA,955
55
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl,sha256=gvBsGA_Jg9kAH8xTxjzTjMcvBtciuccOwQFbO7ey8tU,916
56
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl,sha256=epARqfcQ-EQsdZwaaaUp5y4OSTBT6CiFGlNRocw-23A,1158
55
+ flwr/cli/new/templates/app/code/server.pytorch.py.tpl,sha256=epARqfcQ-EQsdZwaaaUp5y4OSTBT6CiFGlNRocw-23A,1158
56
+ flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl,sha256=gvBsGA_Jg9kAH8xTxjzTjMcvBtciuccOwQFbO7ey8tU,916
57
57
  flwr/cli/new/templates/app/code/server.sklearn.py.tpl,sha256=ehQ5VRgBn92WeFl6kupwJnuxSNkKvE-EvKde6A9mNQo,1377
58
58
  flwr/cli/new/templates/app/code/server.tensorflow.py.tpl,sha256=2-WTOPd-ewdLd9QmSlflIH7ix7zxAzPEOZoyiPBOy8c,1010
59
59
  flwr/cli/new/templates/app/code/strategy.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
@@ -61,8 +61,8 @@ flwr/cli/new/templates/app/code/task.huggingface.py.tpl,sha256=piBbY3Dg60bQnCg15
61
61
  flwr/cli/new/templates/app/code/task.jax.py.tpl,sha256=Fb0XgdTAQplM-ZCusI081XA9asO3gHptH772S-Xcyy8,1525
62
62
  flwr/cli/new/templates/app/code/task.mlx.py.tpl,sha256=YxH5z4s5kOh5_9DIY9pvzqURckLDfgdanTA68_iM_Wo,2946
63
63
  flwr/cli/new/templates/app/code/task.numpy.py.tpl,sha256=CwUJPnN3z6GjP8-KVGWzx7RYRJsl0wLFZ72xscvl3RM,126
64
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=XlJqA4Ix_PloO_zJLhjiN5vDj16w3I4CPVGdmbe8asE,3800
65
- flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl,sha256=RKA5lV6O6OnVKZ2r75pbzwy9arg5o2lzXqG2kNrLIUU,3446
64
+ flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=RKA5lV6O6OnVKZ2r75pbzwy9arg5o2lzXqG2kNrLIUU,3446
65
+ flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl,sha256=XlJqA4Ix_PloO_zJLhjiN5vDj16w3I4CPVGdmbe8asE,3800
66
66
  flwr/cli/new/templates/app/code/task.sklearn.py.tpl,sha256=vHdhtMp0FHxbYafXyhDT9aKmmmA0Jvpx5Oum1Yu9lWY,1850
67
67
  flwr/cli/new/templates/app/code/task.tensorflow.py.tpl,sha256=impgWN7MfztmcWF4xh1llcZGsgTvrb1HD5ZE0t-8U08,1731
68
68
  flwr/cli/new/templates/app/code/utils.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
@@ -72,8 +72,8 @@ flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl,sha256=xHGF38i7oFpvnFv
72
72
  flwr/cli/new/templates/app/pyproject.jax.toml.tpl,sha256=fdDhwmPoMirJ095cU_vFCBf0ILQlAoa1fdnHb2LM1yk,1471
73
73
  flwr/cli/new/templates/app/pyproject.mlx.toml.tpl,sha256=PAjPT2v06sBZxacNiyMJloDwocCK5tFcGQmMXOoBqc8,1542
74
74
  flwr/cli/new/templates/app/pyproject.numpy.toml.tpl,sha256=Kb_O2iQfzwc6FTy3fWqtQYc3FwY6x9SUgQPGqZR_ILg,1409
75
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=docQbs3MuRR-yT24lVz7N2sQL3Sj49EHuOCuRj_0djQ,1508
76
- flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl,sha256=SE4H23OFkQbqNU64nYf38igqrT4cJGA7XxEtSnNxJqg,1490
75
+ flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=SE4H23OFkQbqNU64nYf38igqrT4cJGA7XxEtSnNxJqg,1490
76
+ flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl,sha256=docQbs3MuRR-yT24lVz7N2sQL3Sj49EHuOCuRj_0djQ,1508
77
77
  flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=apauU_PUmLEbt2rjckKniEbzdRs1EnMri_qgtHtBJZ8,1484
78
78
  flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl,sha256=LQpDKJTEnRKj5Ygn5FkT44SxlnLVprkPlbrGaFf5Q50,1508
79
79
  flwr/cli/run/__init__.py,sha256=RPyB7KbYTFl6YRiilCch6oezxrLQrl1kijV7BMGkLbA,790
@@ -332,16 +332,17 @@ flwr/server/workflow/secure_aggregation/secaggplus_workflow.py,sha256=DkayCsnlAy
332
332
  flwr/serverapp/__init__.py,sha256=ZujKNXULwhWYQhFnxOOT5Wi9MRq2JCWFhAAj7ouiQ78,884
333
333
  flwr/serverapp/dp_fixed_clipping.py,sha256=wbP4W7CaUHXdll8ZSVUnTBSEWrnWM00CGk63rOR-Q2s,12133
334
334
  flwr/serverapp/exception.py,sha256=5cuH-2AafvihzosWDdDjuMmHdDqZ1XxHvCqZXNBVklw,1334
335
- flwr/serverapp/strategy/__init__.py,sha256=yAYBZUkp4aNmcTLsvormEc9HyO34oEoFN45LiHgujE0,1229
335
+ flwr/serverapp/strategy/__init__.py,sha256=0ldxlooz4a5yewUbQJGVrW9awrrIcFDIrNR4yZgpfKw,1292
336
336
  flwr/serverapp/strategy/dp_fixed_clipping.py,sha256=wbP4W7CaUHXdll8ZSVUnTBSEWrnWM00CGk63rOR-Q2s,12133
337
337
  flwr/serverapp/strategy/fedadagrad.py,sha256=fD65P6OEERa_pxq847e1UZpA083AcWR44XavYB0naGM,6343
338
338
  flwr/serverapp/strategy/fedadam.py,sha256=s3xPIqhopy6yPTeFxevSPnc7a6BcKnKsvo2AaO6Z_xs,7138
339
339
  flwr/serverapp/strategy/fedavg.py,sha256=53L06lZLkbGV0TRZrUWvPaocvFTT1PAhTvu9UkKq1zE,11294
340
340
  flwr/serverapp/strategy/fedopt.py,sha256=kqT0uV2IUE93O72XEVa1JJo61dcwbZEoT9KmYTjR2tE,8477
341
+ flwr/serverapp/strategy/fedxgb_bagging.py,sha256=ktDjzov4y0BRecioq788umCEtcuwElou9olBizQKOnM,3282
341
342
  flwr/serverapp/strategy/fedyogi.py,sha256=1Ripr4Hi2cdeTOLiFOXtMKvOxR3BsUQwc7bbTrXN4LM,6653
342
343
  flwr/serverapp/strategy/result.py,sha256=E0Hl2VLnZAgQJjE2GDoKsK7JX-kPPU2KXc47Axt6hGw,4295
343
344
  flwr/serverapp/strategy/strategy.py,sha256=8uJGGm1ROLZERQ_dkRS7Z_rs-yK6XCE0UxXtIdFiEWk,10789
344
- flwr/serverapp/strategy/strategy_utils.py,sha256=9ga93Se21I_k7zYiw343EMC2qCTQ8rUG5ZEm8HVEuFs,9246
345
+ flwr/serverapp/strategy/strategy_utils.py,sha256=hiwS7k-Hx6_c4NZXoKpHucS5CBKb7f8GppXRBSMt3Us,10851
345
346
  flwr/serverapp/strategy/strategy_utils_tests.py,sha256=o32XHujd9PLCB-YZMI2AttWLlvUXHe9yuxgiCrCkpgU,10209
346
347
  flwr/simulation/__init__.py,sha256=Gg6OsP1Z-ixc3-xxzvl7j7rz2Fijy9rzyEPpxgAQCeM,1556
347
348
  flwr/simulation/app.py,sha256=LbGLMvN9Ap119yBqsUcNNmVLRnCySnr4VechqcQ1hpA,10401
@@ -403,7 +404,7 @@ flwr/supernode/servicer/__init__.py,sha256=lucTzre5WPK7G1YLCfaqg3rbFWdNSb7ZTt-ca
403
404
  flwr/supernode/servicer/clientappio/__init__.py,sha256=7Oy62Y_oijqF7Dxi6tpcUQyOpLc_QpIRZ83NvwmB0Yg,813
404
405
  flwr/supernode/servicer/clientappio/clientappio_servicer.py,sha256=nIHRu38EWK-rpNOkcgBRAAKwYQQWFeCwu0lkO7OPZGQ,10239
405
406
  flwr/supernode/start_client_internal.py,sha256=Y9S1-QlO2WP6eo4JvWzIpfaCoh2aoE7bjEYyxNNnlyg,20777
406
- flwr_nightly-1.22.0.dev20250913.dist-info/METADATA,sha256=taZ5hyFAPFrevCeD1fE30C3M-BaOJVn2vpR-z-f_eA8,15967
407
- flwr_nightly-1.22.0.dev20250913.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
408
- flwr_nightly-1.22.0.dev20250913.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
409
- flwr_nightly-1.22.0.dev20250913.dist-info/RECORD,,
407
+ flwr_nightly-1.22.0.dev20250915.dist-info/METADATA,sha256=FBo-ub8Rc1rRhLrioWMroybBDcoP9t7v6vBqdE9U3do,15967
408
+ flwr_nightly-1.22.0.dev20250915.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
409
+ flwr_nightly-1.22.0.dev20250915.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
410
+ flwr_nightly-1.22.0.dev20250915.dist-info/RECORD,,
@@ -1,80 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
- from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
- from flwr.clientapp import ClientApp
6
-
7
- from $import_name.task import Net, load_data
8
- from $import_name.task import test as test_fn
9
- from $import_name.task import train as train_fn
10
-
11
- # Flower ClientApp
12
- app = ClientApp()
13
-
14
-
15
- @app.train()
16
- def train(msg: Message, context: Context):
17
- """Train the model on local data."""
18
-
19
- # Load the model and initialize it with the received weights
20
- model = Net()
21
- model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
22
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
- model.to(device)
24
-
25
- # Load the data
26
- partition_id = context.node_config["partition-id"]
27
- num_partitions = context.node_config["num-partitions"]
28
- trainloader, _ = load_data(partition_id, num_partitions)
29
-
30
- # Call the training function
31
- train_loss = train_fn(
32
- model,
33
- trainloader,
34
- context.run_config["local-epochs"],
35
- msg.content["config"]["lr"],
36
- device,
37
- )
38
-
39
- # Construct and return reply Message
40
- model_record = ArrayRecord(model.state_dict())
41
- metrics = {
42
- "train_loss": train_loss,
43
- "num-examples": len(trainloader.dataset),
44
- }
45
- metric_record = MetricRecord(metrics)
46
- content = RecordDict({"arrays": model_record, "metrics": metric_record})
47
- return Message(content=content, reply_to=msg)
48
-
49
-
50
- @app.evaluate()
51
- def evaluate(msg: Message, context: Context):
52
- """Evaluate the model on local data."""
53
-
54
- # Load the model and initialize it with the received weights
55
- model = Net()
56
- model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
57
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
- model.to(device)
59
-
60
- # Load the data
61
- partition_id = context.node_config["partition-id"]
62
- num_partitions = context.node_config["num-partitions"]
63
- _, valloader = load_data(partition_id, num_partitions)
64
-
65
- # Call the evaluation function
66
- eval_loss, eval_acc = test_fn(
67
- model,
68
- valloader,
69
- device,
70
- )
71
-
72
- # Construct and return reply Message
73
- metrics = {
74
- "eval_loss": eval_loss,
75
- "eval_acc": eval_acc,
76
- "num-examples": len(valloader.dataset),
77
- }
78
- metric_record = MetricRecord(metrics)
79
- content = RecordDict({"metrics": metric_record})
80
- return Message(content=content, reply_to=msg)
@@ -1,41 +0,0 @@
1
- """$project_name: A Flower / $framework_str app."""
2
-
3
- import torch
4
- from flwr.app import ArrayRecord, ConfigRecord, Context
5
- from flwr.serverapp import Grid, ServerApp
6
- from flwr.serverapp.strategy import FedAvg
7
-
8
- from $import_name.task import Net
9
-
10
- # Create ServerApp
11
- app = ServerApp()
12
-
13
-
14
- @app.main()
15
- def main(grid: Grid, context: Context) -> None:
16
- """Main entry point for the ServerApp."""
17
-
18
- # Read run config
19
- fraction_train: float = context.run_config["fraction-train"]
20
- num_rounds: int = context.run_config["num-server-rounds"]
21
- lr: float = context.run_config["lr"]
22
-
23
- # Load global model
24
- global_model = Net()
25
- arrays = ArrayRecord(global_model.state_dict())
26
-
27
- # Initialize FedAvg strategy
28
- strategy = FedAvg(fraction_train=fraction_train)
29
-
30
- # Start strategy, run FedAvg for `num_rounds`
31
- result = strategy.start(
32
- grid=grid,
33
- initial_arrays=arrays,
34
- train_config=ConfigRecord({"lr": lr}),
35
- num_rounds=num_rounds,
36
- )
37
-
38
- # Save final model to disk
39
- print("\nSaving final model to disk...")
40
- state_dict = result.arrays.to_torch_state_dict()
41
- torch.save(state_dict, "final_model.pt")