flwr-nightly 1.22.0.dev20250913__py3-none-any.whl → 1.22.0.dev20250916__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/new.py +5 -5
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +2 -2
- flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +2 -2
- flwr/serverapp/strategy/__init__.py +8 -0
- flwr/serverapp/strategy/fedavg.py +23 -2
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +71 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +82 -0
- flwr/serverapp/strategy/strategy_utils.py +48 -0
- flwr/serverapp/strategy/strategy_utils_tests.py +20 -1
- {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/METADATA +6 -16
- {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/RECORD +22 -18
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
- flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
- {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250916.dist-info}/entry_points.txt +0 -0
flwr/cli/new/new.py
CHANGED
@@ -35,7 +35,6 @@ class MlFramework(str, Enum):
|
|
35
35
|
"""Available frameworks."""
|
36
36
|
|
37
37
|
PYTORCH = "PyTorch"
|
38
|
-
PYTORCH_MSG_API = "PyTorch (Message API)"
|
39
38
|
TENSORFLOW = "TensorFlow"
|
40
39
|
SKLEARN = "sklearn"
|
41
40
|
HUGGINGFACE = "HuggingFace"
|
@@ -44,6 +43,7 @@ class MlFramework(str, Enum):
|
|
44
43
|
NUMPY = "NumPy"
|
45
44
|
FLOWERTUNE = "FlowerTune"
|
46
45
|
BASELINE = "Flower Baseline"
|
46
|
+
PYTORCH_LEGACY_API = "PyTorch (Legacy API, deprecated)"
|
47
47
|
|
48
48
|
|
49
49
|
class LlmChallengeName(str, Enum):
|
@@ -155,8 +155,8 @@ def new(
|
|
155
155
|
if framework_str == MlFramework.BASELINE:
|
156
156
|
framework_str = "baseline"
|
157
157
|
|
158
|
-
if framework_str == MlFramework.PYTORCH_MSG_API:
|
159
|
-
framework_str = "pytorch_msg_api"
|
158
|
+
if framework_str == MlFramework.PYTORCH_LEGACY_API:
|
159
|
+
framework_str = "pytorch_legacy_api"
|
160
160
|
|
161
161
|
print(
|
162
162
|
typer.style(
|
@@ -247,14 +247,14 @@ def new(
|
|
247
247
|
MlFramework.TENSORFLOW.value,
|
248
248
|
MlFramework.SKLEARN.value,
|
249
249
|
MlFramework.NUMPY.value,
|
250
|
-
"
|
250
|
+
"pytorch_legacy_api",
|
251
251
|
]
|
252
252
|
if framework_str in frameworks_with_tasks:
|
253
253
|
files[f"{import_name}/task.py"] = {
|
254
254
|
"template": f"app/code/task.{template_name}.py.tpl"
|
255
255
|
}
|
256
256
|
|
257
|
-
if framework_str == "pytorch_msg_api":
|
257
|
+
if framework_str == "pytorch_legacy_api":
|
258
258
|
# Use custom __init__ that better captures name of framework
|
259
259
|
files[f"{import_name}/__init__.py"] = {
|
260
260
|
"template": f"app/code/__init__.{framework_str}.py.tpl"
|
@@ -1,55 +1,80 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
3
|
import torch
|
4
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
5
|
+
from flwr.clientapp import ClientApp
|
4
6
|
|
5
|
-
from flwr.client import ClientApp, NumPyClient
|
6
|
-
from flwr.common import Context
|
7
|
-
from $import_name.task import Net, get_weights, load_data, set_weights, test, train
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
self.trainloader,
|
25
|
-
self.local_epochs,
|
26
|
-
self.device,
|
27
|
-
)
|
28
|
-
return (
|
29
|
-
get_weights(self.net),
|
30
|
-
len(self.trainloader.dataset),
|
31
|
-
{"train_loss": train_loss},
|
32
|
-
)
|
33
|
-
|
34
|
-
def evaluate(self, parameters, config):
|
35
|
-
set_weights(self.net, parameters)
|
36
|
-
loss, accuracy = test(self.net, self.valloader, self.device)
|
37
|
-
return loss, len(self.valloader.dataset), {"accuracy": accuracy}
|
38
|
-
|
39
|
-
|
40
|
-
def client_fn(context: Context):
|
41
|
-
# Load model and data
|
42
|
-
net = Net()
|
7
|
+
from $import_name.task import Net, load_data
|
8
|
+
from $import_name.task import test as test_fn
|
9
|
+
from $import_name.task import train as train_fn
|
10
|
+
|
11
|
+
# Flower ClientApp
|
12
|
+
app = ClientApp()
|
13
|
+
|
14
|
+
|
15
|
+
@app.train()
|
16
|
+
def train(msg: Message, context: Context):
|
17
|
+
"""Train the model on local data."""
|
18
|
+
|
19
|
+
# Load the model and initialize it with the received weights
|
20
|
+
model = Net()
|
21
|
+
model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
|
22
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
23
|
+
model.to(device)
|
24
|
+
|
25
|
+
# Load the data
|
43
26
|
partition_id = context.node_config["partition-id"]
|
44
27
|
num_partitions = context.node_config["num-partitions"]
|
45
|
-
trainloader, valloader = load_data(partition_id, num_partitions)
|
46
|
-
local_epochs = context.run_config["local-epochs"]
|
28
|
+
trainloader, _ = load_data(partition_id, num_partitions)
|
47
29
|
|
48
|
-
#
|
49
|
-
|
30
|
+
# Call the training function
|
31
|
+
train_loss = train_fn(
|
32
|
+
model,
|
33
|
+
trainloader,
|
34
|
+
context.run_config["local-epochs"],
|
35
|
+
msg.content["config"]["lr"],
|
36
|
+
device,
|
37
|
+
)
|
50
38
|
|
39
|
+
# Construct and return reply Message
|
40
|
+
model_record = ArrayRecord(model.state_dict())
|
41
|
+
metrics = {
|
42
|
+
"train_loss": train_loss,
|
43
|
+
"num-examples": len(trainloader.dataset),
|
44
|
+
}
|
45
|
+
metric_record = MetricRecord(metrics)
|
46
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
47
|
+
return Message(content=content, reply_to=msg)
|
51
48
|
|
52
|
-
|
53
|
-
app
|
54
|
-
|
55
|
-
|
49
|
+
|
50
|
+
@app.evaluate()
|
51
|
+
def evaluate(msg: Message, context: Context):
|
52
|
+
"""Evaluate the model on local data."""
|
53
|
+
|
54
|
+
# Load the model and initialize it with the received weights
|
55
|
+
model = Net()
|
56
|
+
model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
|
57
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
58
|
+
model.to(device)
|
59
|
+
|
60
|
+
# Load the data
|
61
|
+
partition_id = context.node_config["partition-id"]
|
62
|
+
num_partitions = context.node_config["num-partitions"]
|
63
|
+
_, valloader = load_data(partition_id, num_partitions)
|
64
|
+
|
65
|
+
# Call the evaluation function
|
66
|
+
eval_loss, eval_acc = test_fn(
|
67
|
+
model,
|
68
|
+
valloader,
|
69
|
+
device,
|
70
|
+
)
|
71
|
+
|
72
|
+
# Construct and return reply Message
|
73
|
+
metrics = {
|
74
|
+
"eval_loss": eval_loss,
|
75
|
+
"eval_acc": eval_acc,
|
76
|
+
"num-examples": len(valloader.dataset),
|
77
|
+
}
|
78
|
+
metric_record = MetricRecord(metrics)
|
79
|
+
content = RecordDict({"metrics": metric_record})
|
80
|
+
return Message(content=content, reply_to=msg)
|
@@ -0,0 +1,55 @@
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
2
|
+
|
3
|
+
import torch
|
4
|
+
|
5
|
+
from flwr.client import ClientApp, NumPyClient
|
6
|
+
from flwr.common import Context
|
7
|
+
from $import_name.task import Net, get_weights, load_data, set_weights, test, train
|
8
|
+
|
9
|
+
|
10
|
+
# Define Flower Client and client_fn
|
11
|
+
class FlowerClient(NumPyClient):
|
12
|
+
def __init__(self, net, trainloader, valloader, local_epochs):
|
13
|
+
self.net = net
|
14
|
+
self.trainloader = trainloader
|
15
|
+
self.valloader = valloader
|
16
|
+
self.local_epochs = local_epochs
|
17
|
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
18
|
+
self.net.to(self.device)
|
19
|
+
|
20
|
+
def fit(self, parameters, config):
|
21
|
+
set_weights(self.net, parameters)
|
22
|
+
train_loss = train(
|
23
|
+
self.net,
|
24
|
+
self.trainloader,
|
25
|
+
self.local_epochs,
|
26
|
+
self.device,
|
27
|
+
)
|
28
|
+
return (
|
29
|
+
get_weights(self.net),
|
30
|
+
len(self.trainloader.dataset),
|
31
|
+
{"train_loss": train_loss},
|
32
|
+
)
|
33
|
+
|
34
|
+
def evaluate(self, parameters, config):
|
35
|
+
set_weights(self.net, parameters)
|
36
|
+
loss, accuracy = test(self.net, self.valloader, self.device)
|
37
|
+
return loss, len(self.valloader.dataset), {"accuracy": accuracy}
|
38
|
+
|
39
|
+
|
40
|
+
def client_fn(context: Context):
|
41
|
+
# Load model and data
|
42
|
+
net = Net()
|
43
|
+
partition_id = context.node_config["partition-id"]
|
44
|
+
num_partitions = context.node_config["num-partitions"]
|
45
|
+
trainloader, valloader = load_data(partition_id, num_partitions)
|
46
|
+
local_epochs = context.run_config["local-epochs"]
|
47
|
+
|
48
|
+
# Return Client instance
|
49
|
+
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
|
50
|
+
|
51
|
+
|
52
|
+
# Flower ClientApp
|
53
|
+
app = ClientApp(
|
54
|
+
client_fn,
|
55
|
+
)
|
@@ -1,31 +1,41 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
|
4
|
-
from flwr.
|
5
|
-
from flwr.
|
6
|
-
from
|
7
|
-
|
8
|
-
|
9
|
-
def server_fn(context: Context):
|
10
|
-
# Read from config
|
11
|
-
num_rounds = context.run_config["num-server-rounds"]
|
12
|
-
fraction_fit = context.run_config["fraction-fit"]
|
13
|
-
|
14
|
-
# Initialize model parameters
|
15
|
-
ndarrays = get_weights(Net())
|
16
|
-
parameters = ndarrays_to_parameters(ndarrays)
|
17
|
-
|
18
|
-
# Define strategy
|
19
|
-
strategy = FedAvg(
|
20
|
-
fraction_fit=fraction_fit,
|
21
|
-
fraction_evaluate=1.0,
|
22
|
-
min_available_clients=2,
|
23
|
-
initial_parameters=parameters,
|
24
|
-
)
|
25
|
-
config = ServerConfig(num_rounds=num_rounds)
|
26
|
-
|
27
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
3
|
+
import torch
|
4
|
+
from flwr.app import ArrayRecord, ConfigRecord, Context
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
28
7
|
|
8
|
+
from $import_name.task import Net
|
29
9
|
|
30
10
|
# Create ServerApp
|
31
|
-
app = ServerApp(server_fn=server_fn)
|
11
|
+
app = ServerApp()
|
12
|
+
|
13
|
+
|
14
|
+
@app.main()
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
16
|
+
"""Main entry point for the ServerApp."""
|
17
|
+
|
18
|
+
# Read run config
|
19
|
+
fraction_train: float = context.run_config["fraction-train"]
|
20
|
+
num_rounds: int = context.run_config["num-server-rounds"]
|
21
|
+
lr: float = context.run_config["lr"]
|
22
|
+
|
23
|
+
# Load global model
|
24
|
+
global_model = Net()
|
25
|
+
arrays = ArrayRecord(global_model.state_dict())
|
26
|
+
|
27
|
+
# Initialize FedAvg strategy
|
28
|
+
strategy = FedAvg(fraction_train=fraction_train)
|
29
|
+
|
30
|
+
# Start strategy, run FedAvg for `num_rounds`
|
31
|
+
result = strategy.start(
|
32
|
+
grid=grid,
|
33
|
+
initial_arrays=arrays,
|
34
|
+
train_config=ConfigRecord({"lr": lr}),
|
35
|
+
num_rounds=num_rounds,
|
36
|
+
)
|
37
|
+
|
38
|
+
# Save final model to disk
|
39
|
+
print("\nSaving final model to disk...")
|
40
|
+
state_dict = result.arrays.to_torch_state_dict()
|
41
|
+
torch.save(state_dict, "final_model.pt")
|
@@ -0,0 +1,31 @@
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
2
|
+
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
5
|
+
from flwr.server.strategy import FedAvg
|
6
|
+
from $import_name.task import Net, get_weights
|
7
|
+
|
8
|
+
|
9
|
+
def server_fn(context: Context):
|
10
|
+
# Read from config
|
11
|
+
num_rounds = context.run_config["num-server-rounds"]
|
12
|
+
fraction_fit = context.run_config["fraction-fit"]
|
13
|
+
|
14
|
+
# Initialize model parameters
|
15
|
+
ndarrays = get_weights(Net())
|
16
|
+
parameters = ndarrays_to_parameters(ndarrays)
|
17
|
+
|
18
|
+
# Define strategy
|
19
|
+
strategy = FedAvg(
|
20
|
+
fraction_fit=fraction_fit,
|
21
|
+
fraction_evaluate=1.0,
|
22
|
+
min_available_clients=2,
|
23
|
+
initial_parameters=parameters,
|
24
|
+
)
|
25
|
+
config = ServerConfig(num_rounds=num_rounds)
|
26
|
+
|
27
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
28
|
+
|
29
|
+
|
30
|
+
# Create ServerApp
|
31
|
+
app = ServerApp(server_fn=server_fn)
|
@@ -1,7 +1,5 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
from collections import OrderedDict
|
4
|
-
|
5
3
|
import torch
|
6
4
|
import torch.nn as nn
|
7
5
|
import torch.nn.functional as F
|
@@ -34,6 +32,14 @@ class Net(nn.Module):
|
|
34
32
|
|
35
33
|
fds = None # Cache FederatedDataset
|
36
34
|
|
35
|
+
pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
36
|
+
|
37
|
+
|
38
|
+
def apply_transforms(batch):
|
39
|
+
"""Apply transforms to the partition from FederatedDataset."""
|
40
|
+
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
|
41
|
+
return batch
|
42
|
+
|
37
43
|
|
38
44
|
def load_data(partition_id: int, num_partitions: int):
|
39
45
|
"""Load partition CIFAR10 data."""
|
@@ -48,38 +54,29 @@ def load_data(partition_id: int, num_partitions: int):
|
|
48
54
|
partition = fds.load_partition(partition_id)
|
49
55
|
# Divide data on each node: 80% train, 20% test
|
50
56
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
51
|
-
|
52
|
-
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
|
53
|
-
)
|
54
|
-
|
55
|
-
def apply_transforms(batch):
|
56
|
-
"""Apply transforms to the partition from FederatedDataset."""
|
57
|
-
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
|
58
|
-
return batch
|
59
|
-
|
57
|
+
# Construct dataloaders
|
60
58
|
partition_train_test = partition_train_test.with_transform(apply_transforms)
|
61
59
|
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
|
62
60
|
testloader = DataLoader(partition_train_test["test"], batch_size=32)
|
63
61
|
return trainloader, testloader
|
64
62
|
|
65
63
|
|
66
|
-
def train(net, trainloader, epochs, device):
|
64
|
+
def train(net, trainloader, epochs, lr, device):
|
67
65
|
"""Train the model on the training set."""
|
68
66
|
net.to(device) # move model to GPU if available
|
69
67
|
criterion = torch.nn.CrossEntropyLoss().to(device)
|
70
|
-
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
|
68
|
+
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
|
71
69
|
net.train()
|
72
70
|
running_loss = 0.0
|
73
71
|
for _ in range(epochs):
|
74
72
|
for batch in trainloader:
|
75
|
-
images = batch["img"]
|
76
|
-
labels = batch["label"]
|
73
|
+
images = batch["img"].to(device)
|
74
|
+
labels = batch["label"].to(device)
|
77
75
|
optimizer.zero_grad()
|
78
|
-
loss = criterion(net(images.to(device)), labels.to(device))
|
76
|
+
loss = criterion(net(images), labels)
|
79
77
|
loss.backward()
|
80
78
|
optimizer.step()
|
81
79
|
running_loss += loss.item()
|
82
|
-
|
83
80
|
avg_trainloss = running_loss / len(trainloader)
|
84
81
|
return avg_trainloss
|
85
82
|
|
@@ -99,13 +96,3 @@ def test(net, testloader, device):
|
|
99
96
|
accuracy = correct / len(testloader.dataset)
|
100
97
|
loss = loss / len(testloader)
|
101
98
|
return loss, accuracy
|
102
|
-
|
103
|
-
|
104
|
-
def get_weights(net):
|
105
|
-
return [val.cpu().numpy() for _, val in net.state_dict().items()]
|
106
|
-
|
107
|
-
|
108
|
-
def set_weights(net, parameters):
|
109
|
-
params_dict = zip(net.state_dict().keys(), parameters)
|
110
|
-
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
|
111
|
-
net.load_state_dict(state_dict, strict=True)
|
flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl}
RENAMED
@@ -1,5 +1,7 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
+
from collections import OrderedDict
|
4
|
+
|
3
5
|
import torch
|
4
6
|
import torch.nn as nn
|
5
7
|
import torch.nn.functional as F
|
@@ -32,14 +34,6 @@ class Net(nn.Module):
|
|
32
34
|
|
33
35
|
fds = None # Cache FederatedDataset
|
34
36
|
|
35
|
-
pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
36
|
-
|
37
|
-
|
38
|
-
def apply_transforms(batch):
|
39
|
-
"""Apply transforms to the partition from FederatedDataset."""
|
40
|
-
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
|
41
|
-
return batch
|
42
|
-
|
43
37
|
|
44
38
|
def load_data(partition_id: int, num_partitions: int):
|
45
39
|
"""Load partition CIFAR10 data."""
|
@@ -54,29 +48,38 @@ def load_data(partition_id: int, num_partitions: int):
|
|
54
48
|
partition = fds.load_partition(partition_id)
|
55
49
|
# Divide data on each node: 80% train, 20% test
|
56
50
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
57
|
-
|
51
|
+
pytorch_transforms = Compose(
|
52
|
+
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
|
53
|
+
)
|
54
|
+
|
55
|
+
def apply_transforms(batch):
|
56
|
+
"""Apply transforms to the partition from FederatedDataset."""
|
57
|
+
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
|
58
|
+
return batch
|
59
|
+
|
58
60
|
partition_train_test = partition_train_test.with_transform(apply_transforms)
|
59
61
|
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
|
60
62
|
testloader = DataLoader(partition_train_test["test"], batch_size=32)
|
61
63
|
return trainloader, testloader
|
62
64
|
|
63
65
|
|
64
|
-
def train(net, trainloader, epochs, lr, device):
|
66
|
+
def train(net, trainloader, epochs, device):
|
65
67
|
"""Train the model on the training set."""
|
66
68
|
net.to(device) # move model to GPU if available
|
67
69
|
criterion = torch.nn.CrossEntropyLoss().to(device)
|
68
|
-
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
|
70
|
+
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
|
69
71
|
net.train()
|
70
72
|
running_loss = 0.0
|
71
73
|
for _ in range(epochs):
|
72
74
|
for batch in trainloader:
|
73
|
-
images = batch["img"]
|
74
|
-
labels = batch["label"]
|
75
|
+
images = batch["img"]
|
76
|
+
labels = batch["label"]
|
75
77
|
optimizer.zero_grad()
|
76
|
-
loss = criterion(net(images), labels)
|
78
|
+
loss = criterion(net(images.to(device)), labels.to(device))
|
77
79
|
loss.backward()
|
78
80
|
optimizer.step()
|
79
81
|
running_loss += loss.item()
|
82
|
+
|
80
83
|
avg_trainloss = running_loss / len(trainloader)
|
81
84
|
return avg_trainloss
|
82
85
|
|
@@ -96,3 +99,13 @@ def test(net, testloader, device):
|
|
96
99
|
accuracy = correct / len(testloader.dataset)
|
97
100
|
loss = loss / len(testloader)
|
98
101
|
return loss, accuracy
|
102
|
+
|
103
|
+
|
104
|
+
def get_weights(net):
|
105
|
+
return [val.cpu().numpy() for _, val in net.state_dict().items()]
|
106
|
+
|
107
|
+
|
108
|
+
def set_weights(net, parameters):
|
109
|
+
params_dict = zip(net.state_dict().keys(), parameters)
|
110
|
+
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
|
111
|
+
net.load_state_dict(state_dict, strict=True)
|
@@ -27,7 +27,6 @@ packages = ["."]
|
|
27
27
|
publisher = "$username"
|
28
28
|
|
29
29
|
# Point to your ServerApp and ClientApp objects
|
30
|
-
# Format: "<module>:<object>"
|
31
30
|
[tool.flwr.app.components]
|
32
31
|
serverapp = "$import_name.server_app:app"
|
33
32
|
clientapp = "$import_name.client_app:app"
|
@@ -35,8 +34,9 @@ clientapp = "$import_name.client_app:app"
|
|
35
34
|
# Custom config values accessible via `context.run_config`
|
36
35
|
[tool.flwr.app.config]
|
37
36
|
num-server-rounds = 3
|
38
|
-
fraction-fit = 0.5
|
37
|
+
fraction-train = 0.5
|
39
38
|
local-epochs = 1
|
39
|
+
lr = 0.01
|
40
40
|
|
41
41
|
# Default federation to use when running the app
|
42
42
|
[tool.flwr.federations]
|
@@ -27,6 +27,7 @@ packages = ["."]
|
|
27
27
|
publisher = "$username"
|
28
28
|
|
29
29
|
# Point to your ServerApp and ClientApp objects
|
30
|
+
# Format: "<module>:<object>"
|
30
31
|
[tool.flwr.app.components]
|
31
32
|
serverapp = "$import_name.server_app:app"
|
32
33
|
clientapp = "$import_name.client_app:app"
|
@@ -34,9 +35,8 @@ clientapp = "$import_name.client_app:app"
|
|
34
35
|
# Custom config values accessible via `context.run_config`
|
35
36
|
[tool.flwr.app.config]
|
36
37
|
num-server-rounds = 3
|
37
|
-
fraction-train = 0.5
|
38
|
+
fraction-fit = 0.5
|
38
39
|
local-epochs = 1
|
39
|
-
lr = 0.01
|
40
40
|
|
41
41
|
# Default federation to use when running the app
|
42
42
|
[tool.flwr.federations]
|
@@ -22,6 +22,10 @@ from .dp_fixed_clipping import (
|
|
22
22
|
from .fedadagrad import FedAdagrad
|
23
23
|
from .fedadam import FedAdam
|
24
24
|
from .fedavg import FedAvg
|
25
|
+
from .fedavgm import FedAvgM
|
26
|
+
from .fedmedian import FedMedian
|
27
|
+
from .fedtrimmedavg import FedTrimmedAvg
|
28
|
+
from .fedxgb_bagging import FedXgbBagging
|
25
29
|
from .fedyogi import FedYogi
|
26
30
|
from .result import Result
|
27
31
|
from .strategy import Strategy
|
@@ -32,6 +36,10 @@ __all__ = [
|
|
32
36
|
"FedAdagrad",
|
33
37
|
"FedAdam",
|
34
38
|
"FedAvg",
|
39
|
+
"FedAvgM",
|
40
|
+
"FedMedian",
|
41
|
+
"FedTrimmedAvg",
|
42
|
+
"FedXgbBagging",
|
35
43
|
"FedYogi",
|
36
44
|
"Result",
|
37
45
|
"Strategy",
|
@@ -16,7 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from collections.abc import Iterable
|
19
|
-
from logging import INFO
|
19
|
+
from logging import INFO, WARNING
|
20
20
|
from typing import Callable, Optional
|
21
21
|
|
22
22
|
from flwr.common import (
|
@@ -67,7 +67,7 @@ class FedAvg(Strategy):
|
|
67
67
|
arrayrecord_key : str (default: "arrays")
|
68
68
|
Key used to store the ArrayRecord when constructing Messages.
|
69
69
|
configrecord_key : str (default: "config")
|
70
|
-
|
70
|
+
Key used to store the ConfigRecord when constructing Messages.
|
71
71
|
train_metrics_aggr_fn : Optional[callable] (default: None)
|
72
72
|
Function with signature (list[RecordDict], str) -> MetricRecord,
|
73
73
|
used to aggregate MetricRecords from training round replies.
|
@@ -111,6 +111,20 @@ class FedAvg(Strategy):
|
|
111
111
|
evaluate_metrics_aggr_fn or aggregate_metricrecords
|
112
112
|
)
|
113
113
|
|
114
|
+
if self.fraction_evaluate == 0.0:
|
115
|
+
self.min_evaluate_nodes = 0
|
116
|
+
log(
|
117
|
+
WARNING,
|
118
|
+
"fraction_evaluate is set to 0.0. "
|
119
|
+
"Federated evaluation will be skipped.",
|
120
|
+
)
|
121
|
+
if self.fraction_train == 0.0:
|
122
|
+
self.min_train_nodes = 0
|
123
|
+
log(
|
124
|
+
WARNING,
|
125
|
+
"fraction_train is set to 0.0. Federated training will be skipped.",
|
126
|
+
)
|
127
|
+
|
114
128
|
def summary(self) -> None:
|
115
129
|
"""Log summary configuration of the strategy."""
|
116
130
|
log(INFO, "\t├──> Sampling:")
|
@@ -150,6 +164,9 @@ class FedAvg(Strategy):
|
|
150
164
|
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
|
151
165
|
) -> Iterable[Message]:
|
152
166
|
"""Configure the next round of federated training."""
|
167
|
+
# Do not configure federated train if fraction_train is 0.
|
168
|
+
if self.fraction_train == 0.0:
|
169
|
+
return []
|
153
170
|
# Sample nodes
|
154
171
|
num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_train)
|
155
172
|
sample_size = max(num_nodes, self.min_train_nodes)
|
@@ -259,6 +276,10 @@ class FedAvg(Strategy):
|
|
259
276
|
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
|
260
277
|
) -> Iterable[Message]:
|
261
278
|
"""Configure the next round of federated evaluation."""
|
279
|
+
# Do not configure federated evaluation if fraction_evaluate is 0.
|
280
|
+
if self.fraction_evaluate == 0.0:
|
281
|
+
return []
|
282
|
+
|
262
283
|
# Sample nodes
|
263
284
|
num_nodes = int(len(list(grid.get_node_ids())) * self.fraction_evaluate)
|
264
285
|
sample_size = max(num_nodes, self.min_evaluate_nodes)
|