flwr-nightly 1.22.0.dev20250913__py3-none-any.whl → 1.22.0.dev20250915__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/new.py +5 -5
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +2 -2
- flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +2 -2
- flwr/serverapp/strategy/__init__.py +2 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +82 -0
- flwr/serverapp/strategy/strategy_utils.py +48 -0
- {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250915.dist-info}/METADATA +1 -1
- {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250915.dist-info}/RECORD +17 -16
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
- /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
- {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250915.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250915.dist-info}/entry_points.txt +0 -0
flwr/cli/new/new.py
CHANGED
@@ -35,7 +35,6 @@ class MlFramework(str, Enum):
|
|
35
35
|
"""Available frameworks."""
|
36
36
|
|
37
37
|
PYTORCH = "PyTorch"
|
38
|
-
PYTORCH_MSG_API = "PyTorch (Message API)"
|
39
38
|
TENSORFLOW = "TensorFlow"
|
40
39
|
SKLEARN = "sklearn"
|
41
40
|
HUGGINGFACE = "HuggingFace"
|
@@ -44,6 +43,7 @@ class MlFramework(str, Enum):
|
|
44
43
|
NUMPY = "NumPy"
|
45
44
|
FLOWERTUNE = "FlowerTune"
|
46
45
|
BASELINE = "Flower Baseline"
|
46
|
+
PYTORCH_LEGACY_API = "PyTorch (Legacy API, deprecated)"
|
47
47
|
|
48
48
|
|
49
49
|
class LlmChallengeName(str, Enum):
|
@@ -155,8 +155,8 @@ def new(
|
|
155
155
|
if framework_str == MlFramework.BASELINE:
|
156
156
|
framework_str = "baseline"
|
157
157
|
|
158
|
-
if framework_str == MlFramework.
|
159
|
-
framework_str = "
|
158
|
+
if framework_str == MlFramework.PYTORCH_LEGACY_API:
|
159
|
+
framework_str = "pytorch_legacy_api"
|
160
160
|
|
161
161
|
print(
|
162
162
|
typer.style(
|
@@ -247,14 +247,14 @@ def new(
|
|
247
247
|
MlFramework.TENSORFLOW.value,
|
248
248
|
MlFramework.SKLEARN.value,
|
249
249
|
MlFramework.NUMPY.value,
|
250
|
-
"
|
250
|
+
"pytorch_legacy_api",
|
251
251
|
]
|
252
252
|
if framework_str in frameworks_with_tasks:
|
253
253
|
files[f"{import_name}/task.py"] = {
|
254
254
|
"template": f"app/code/task.{template_name}.py.tpl"
|
255
255
|
}
|
256
256
|
|
257
|
-
if framework_str == "
|
257
|
+
if framework_str == "pytorch_legacy_api":
|
258
258
|
# Use custom __init__ that better captures name of framework
|
259
259
|
files[f"{import_name}/__init__.py"] = {
|
260
260
|
"template": f"app/code/__init__.{framework_str}.py.tpl"
|
@@ -1,55 +1,80 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
3
|
import torch
|
4
|
+
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
5
|
+
from flwr.clientapp import ClientApp
|
4
6
|
|
5
|
-
from
|
6
|
-
from
|
7
|
-
from $import_name.task import
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
self.trainloader,
|
25
|
-
self.local_epochs,
|
26
|
-
self.device,
|
27
|
-
)
|
28
|
-
return (
|
29
|
-
get_weights(self.net),
|
30
|
-
len(self.trainloader.dataset),
|
31
|
-
{"train_loss": train_loss},
|
32
|
-
)
|
33
|
-
|
34
|
-
def evaluate(self, parameters, config):
|
35
|
-
set_weights(self.net, parameters)
|
36
|
-
loss, accuracy = test(self.net, self.valloader, self.device)
|
37
|
-
return loss, len(self.valloader.dataset), {"accuracy": accuracy}
|
38
|
-
|
39
|
-
|
40
|
-
def client_fn(context: Context):
|
41
|
-
# Load model and data
|
42
|
-
net = Net()
|
7
|
+
from $import_name.task import Net, load_data
|
8
|
+
from $import_name.task import test as test_fn
|
9
|
+
from $import_name.task import train as train_fn
|
10
|
+
|
11
|
+
# Flower ClientApp
|
12
|
+
app = ClientApp()
|
13
|
+
|
14
|
+
|
15
|
+
@app.train()
|
16
|
+
def train(msg: Message, context: Context):
|
17
|
+
"""Train the model on local data."""
|
18
|
+
|
19
|
+
# Load the model and initialize it with the received weights
|
20
|
+
model = Net()
|
21
|
+
model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
|
22
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
23
|
+
model.to(device)
|
24
|
+
|
25
|
+
# Load the data
|
43
26
|
partition_id = context.node_config["partition-id"]
|
44
27
|
num_partitions = context.node_config["num-partitions"]
|
45
|
-
trainloader,
|
46
|
-
local_epochs = context.run_config["local-epochs"]
|
28
|
+
trainloader, _ = load_data(partition_id, num_partitions)
|
47
29
|
|
48
|
-
#
|
49
|
-
|
30
|
+
# Call the training function
|
31
|
+
train_loss = train_fn(
|
32
|
+
model,
|
33
|
+
trainloader,
|
34
|
+
context.run_config["local-epochs"],
|
35
|
+
msg.content["config"]["lr"],
|
36
|
+
device,
|
37
|
+
)
|
50
38
|
|
39
|
+
# Construct and return reply Message
|
40
|
+
model_record = ArrayRecord(model.state_dict())
|
41
|
+
metrics = {
|
42
|
+
"train_loss": train_loss,
|
43
|
+
"num-examples": len(trainloader.dataset),
|
44
|
+
}
|
45
|
+
metric_record = MetricRecord(metrics)
|
46
|
+
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
47
|
+
return Message(content=content, reply_to=msg)
|
51
48
|
|
52
|
-
|
53
|
-
app
|
54
|
-
|
55
|
-
|
49
|
+
|
50
|
+
@app.evaluate()
|
51
|
+
def evaluate(msg: Message, context: Context):
|
52
|
+
"""Evaluate the model on local data."""
|
53
|
+
|
54
|
+
# Load the model and initialize it with the received weights
|
55
|
+
model = Net()
|
56
|
+
model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
|
57
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
58
|
+
model.to(device)
|
59
|
+
|
60
|
+
# Load the data
|
61
|
+
partition_id = context.node_config["partition-id"]
|
62
|
+
num_partitions = context.node_config["num-partitions"]
|
63
|
+
_, valloader = load_data(partition_id, num_partitions)
|
64
|
+
|
65
|
+
# Call the evaluation function
|
66
|
+
eval_loss, eval_acc = test_fn(
|
67
|
+
model,
|
68
|
+
valloader,
|
69
|
+
device,
|
70
|
+
)
|
71
|
+
|
72
|
+
# Construct and return reply Message
|
73
|
+
metrics = {
|
74
|
+
"eval_loss": eval_loss,
|
75
|
+
"eval_acc": eval_acc,
|
76
|
+
"num-examples": len(valloader.dataset),
|
77
|
+
}
|
78
|
+
metric_record = MetricRecord(metrics)
|
79
|
+
content = RecordDict({"metrics": metric_record})
|
80
|
+
return Message(content=content, reply_to=msg)
|
@@ -0,0 +1,55 @@
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
2
|
+
|
3
|
+
import torch
|
4
|
+
|
5
|
+
from flwr.client import ClientApp, NumPyClient
|
6
|
+
from flwr.common import Context
|
7
|
+
from $import_name.task import Net, get_weights, load_data, set_weights, test, train
|
8
|
+
|
9
|
+
|
10
|
+
# Define Flower Client and client_fn
|
11
|
+
class FlowerClient(NumPyClient):
|
12
|
+
def __init__(self, net, trainloader, valloader, local_epochs):
|
13
|
+
self.net = net
|
14
|
+
self.trainloader = trainloader
|
15
|
+
self.valloader = valloader
|
16
|
+
self.local_epochs = local_epochs
|
17
|
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
18
|
+
self.net.to(self.device)
|
19
|
+
|
20
|
+
def fit(self, parameters, config):
|
21
|
+
set_weights(self.net, parameters)
|
22
|
+
train_loss = train(
|
23
|
+
self.net,
|
24
|
+
self.trainloader,
|
25
|
+
self.local_epochs,
|
26
|
+
self.device,
|
27
|
+
)
|
28
|
+
return (
|
29
|
+
get_weights(self.net),
|
30
|
+
len(self.trainloader.dataset),
|
31
|
+
{"train_loss": train_loss},
|
32
|
+
)
|
33
|
+
|
34
|
+
def evaluate(self, parameters, config):
|
35
|
+
set_weights(self.net, parameters)
|
36
|
+
loss, accuracy = test(self.net, self.valloader, self.device)
|
37
|
+
return loss, len(self.valloader.dataset), {"accuracy": accuracy}
|
38
|
+
|
39
|
+
|
40
|
+
def client_fn(context: Context):
|
41
|
+
# Load model and data
|
42
|
+
net = Net()
|
43
|
+
partition_id = context.node_config["partition-id"]
|
44
|
+
num_partitions = context.node_config["num-partitions"]
|
45
|
+
trainloader, valloader = load_data(partition_id, num_partitions)
|
46
|
+
local_epochs = context.run_config["local-epochs"]
|
47
|
+
|
48
|
+
# Return Client instance
|
49
|
+
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
|
50
|
+
|
51
|
+
|
52
|
+
# Flower ClientApp
|
53
|
+
app = ClientApp(
|
54
|
+
client_fn,
|
55
|
+
)
|
@@ -1,31 +1,41 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
|
4
|
-
from flwr.
|
5
|
-
from flwr.
|
6
|
-
from
|
7
|
-
|
8
|
-
|
9
|
-
def server_fn(context: Context):
|
10
|
-
# Read from config
|
11
|
-
num_rounds = context.run_config["num-server-rounds"]
|
12
|
-
fraction_fit = context.run_config["fraction-fit"]
|
13
|
-
|
14
|
-
# Initialize model parameters
|
15
|
-
ndarrays = get_weights(Net())
|
16
|
-
parameters = ndarrays_to_parameters(ndarrays)
|
17
|
-
|
18
|
-
# Define strategy
|
19
|
-
strategy = FedAvg(
|
20
|
-
fraction_fit=fraction_fit,
|
21
|
-
fraction_evaluate=1.0,
|
22
|
-
min_available_clients=2,
|
23
|
-
initial_parameters=parameters,
|
24
|
-
)
|
25
|
-
config = ServerConfig(num_rounds=num_rounds)
|
26
|
-
|
27
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
3
|
+
import torch
|
4
|
+
from flwr.app import ArrayRecord, ConfigRecord, Context
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
28
7
|
|
8
|
+
from $import_name.task import Net
|
29
9
|
|
30
10
|
# Create ServerApp
|
31
|
-
app = ServerApp(
|
11
|
+
app = ServerApp()
|
12
|
+
|
13
|
+
|
14
|
+
@app.main()
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
16
|
+
"""Main entry point for the ServerApp."""
|
17
|
+
|
18
|
+
# Read run config
|
19
|
+
fraction_train: float = context.run_config["fraction-train"]
|
20
|
+
num_rounds: int = context.run_config["num-server-rounds"]
|
21
|
+
lr: float = context.run_config["lr"]
|
22
|
+
|
23
|
+
# Load global model
|
24
|
+
global_model = Net()
|
25
|
+
arrays = ArrayRecord(global_model.state_dict())
|
26
|
+
|
27
|
+
# Initialize FedAvg strategy
|
28
|
+
strategy = FedAvg(fraction_train=fraction_train)
|
29
|
+
|
30
|
+
# Start strategy, run FedAvg for `num_rounds`
|
31
|
+
result = strategy.start(
|
32
|
+
grid=grid,
|
33
|
+
initial_arrays=arrays,
|
34
|
+
train_config=ConfigRecord({"lr": lr}),
|
35
|
+
num_rounds=num_rounds,
|
36
|
+
)
|
37
|
+
|
38
|
+
# Save final model to disk
|
39
|
+
print("\nSaving final model to disk...")
|
40
|
+
state_dict = result.arrays.to_torch_state_dict()
|
41
|
+
torch.save(state_dict, "final_model.pt")
|
@@ -0,0 +1,31 @@
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
2
|
+
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
5
|
+
from flwr.server.strategy import FedAvg
|
6
|
+
from $import_name.task import Net, get_weights
|
7
|
+
|
8
|
+
|
9
|
+
def server_fn(context: Context):
|
10
|
+
# Read from config
|
11
|
+
num_rounds = context.run_config["num-server-rounds"]
|
12
|
+
fraction_fit = context.run_config["fraction-fit"]
|
13
|
+
|
14
|
+
# Initialize model parameters
|
15
|
+
ndarrays = get_weights(Net())
|
16
|
+
parameters = ndarrays_to_parameters(ndarrays)
|
17
|
+
|
18
|
+
# Define strategy
|
19
|
+
strategy = FedAvg(
|
20
|
+
fraction_fit=fraction_fit,
|
21
|
+
fraction_evaluate=1.0,
|
22
|
+
min_available_clients=2,
|
23
|
+
initial_parameters=parameters,
|
24
|
+
)
|
25
|
+
config = ServerConfig(num_rounds=num_rounds)
|
26
|
+
|
27
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
28
|
+
|
29
|
+
|
30
|
+
# Create ServerApp
|
31
|
+
app = ServerApp(server_fn=server_fn)
|
@@ -1,7 +1,5 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
-
from collections import OrderedDict
|
4
|
-
|
5
3
|
import torch
|
6
4
|
import torch.nn as nn
|
7
5
|
import torch.nn.functional as F
|
@@ -34,6 +32,14 @@ class Net(nn.Module):
|
|
34
32
|
|
35
33
|
fds = None # Cache FederatedDataset
|
36
34
|
|
35
|
+
pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
36
|
+
|
37
|
+
|
38
|
+
def apply_transforms(batch):
|
39
|
+
"""Apply transforms to the partition from FederatedDataset."""
|
40
|
+
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
|
41
|
+
return batch
|
42
|
+
|
37
43
|
|
38
44
|
def load_data(partition_id: int, num_partitions: int):
|
39
45
|
"""Load partition CIFAR10 data."""
|
@@ -48,38 +54,29 @@ def load_data(partition_id: int, num_partitions: int):
|
|
48
54
|
partition = fds.load_partition(partition_id)
|
49
55
|
# Divide data on each node: 80% train, 20% test
|
50
56
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
51
|
-
|
52
|
-
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
|
53
|
-
)
|
54
|
-
|
55
|
-
def apply_transforms(batch):
|
56
|
-
"""Apply transforms to the partition from FederatedDataset."""
|
57
|
-
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
|
58
|
-
return batch
|
59
|
-
|
57
|
+
# Construct dataloaders
|
60
58
|
partition_train_test = partition_train_test.with_transform(apply_transforms)
|
61
59
|
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
|
62
60
|
testloader = DataLoader(partition_train_test["test"], batch_size=32)
|
63
61
|
return trainloader, testloader
|
64
62
|
|
65
63
|
|
66
|
-
def train(net, trainloader, epochs, device):
|
64
|
+
def train(net, trainloader, epochs, lr, device):
|
67
65
|
"""Train the model on the training set."""
|
68
66
|
net.to(device) # move model to GPU if available
|
69
67
|
criterion = torch.nn.CrossEntropyLoss().to(device)
|
70
|
-
optimizer = torch.optim.Adam(net.parameters(), lr=
|
68
|
+
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
|
71
69
|
net.train()
|
72
70
|
running_loss = 0.0
|
73
71
|
for _ in range(epochs):
|
74
72
|
for batch in trainloader:
|
75
|
-
images = batch["img"]
|
76
|
-
labels = batch["label"]
|
73
|
+
images = batch["img"].to(device)
|
74
|
+
labels = batch["label"].to(device)
|
77
75
|
optimizer.zero_grad()
|
78
|
-
loss = criterion(net(images
|
76
|
+
loss = criterion(net(images), labels)
|
79
77
|
loss.backward()
|
80
78
|
optimizer.step()
|
81
79
|
running_loss += loss.item()
|
82
|
-
|
83
80
|
avg_trainloss = running_loss / len(trainloader)
|
84
81
|
return avg_trainloss
|
85
82
|
|
@@ -99,13 +96,3 @@ def test(net, testloader, device):
|
|
99
96
|
accuracy = correct / len(testloader.dataset)
|
100
97
|
loss = loss / len(testloader)
|
101
98
|
return loss, accuracy
|
102
|
-
|
103
|
-
|
104
|
-
def get_weights(net):
|
105
|
-
return [val.cpu().numpy() for _, val in net.state_dict().items()]
|
106
|
-
|
107
|
-
|
108
|
-
def set_weights(net, parameters):
|
109
|
-
params_dict = zip(net.state_dict().keys(), parameters)
|
110
|
-
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
|
111
|
-
net.load_state_dict(state_dict, strict=True)
|
flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl}
RENAMED
@@ -1,5 +1,7 @@
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
2
2
|
|
3
|
+
from collections import OrderedDict
|
4
|
+
|
3
5
|
import torch
|
4
6
|
import torch.nn as nn
|
5
7
|
import torch.nn.functional as F
|
@@ -32,14 +34,6 @@ class Net(nn.Module):
|
|
32
34
|
|
33
35
|
fds = None # Cache FederatedDataset
|
34
36
|
|
35
|
-
pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
36
|
-
|
37
|
-
|
38
|
-
def apply_transforms(batch):
|
39
|
-
"""Apply transforms to the partition from FederatedDataset."""
|
40
|
-
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
|
41
|
-
return batch
|
42
|
-
|
43
37
|
|
44
38
|
def load_data(partition_id: int, num_partitions: int):
|
45
39
|
"""Load partition CIFAR10 data."""
|
@@ -54,29 +48,38 @@ def load_data(partition_id: int, num_partitions: int):
|
|
54
48
|
partition = fds.load_partition(partition_id)
|
55
49
|
# Divide data on each node: 80% train, 20% test
|
56
50
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
57
|
-
|
51
|
+
pytorch_transforms = Compose(
|
52
|
+
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
|
53
|
+
)
|
54
|
+
|
55
|
+
def apply_transforms(batch):
|
56
|
+
"""Apply transforms to the partition from FederatedDataset."""
|
57
|
+
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
|
58
|
+
return batch
|
59
|
+
|
58
60
|
partition_train_test = partition_train_test.with_transform(apply_transforms)
|
59
61
|
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
|
60
62
|
testloader = DataLoader(partition_train_test["test"], batch_size=32)
|
61
63
|
return trainloader, testloader
|
62
64
|
|
63
65
|
|
64
|
-
def train(net, trainloader, epochs,
|
66
|
+
def train(net, trainloader, epochs, device):
|
65
67
|
"""Train the model on the training set."""
|
66
68
|
net.to(device) # move model to GPU if available
|
67
69
|
criterion = torch.nn.CrossEntropyLoss().to(device)
|
68
|
-
optimizer = torch.optim.Adam(net.parameters(), lr=
|
70
|
+
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
|
69
71
|
net.train()
|
70
72
|
running_loss = 0.0
|
71
73
|
for _ in range(epochs):
|
72
74
|
for batch in trainloader:
|
73
|
-
images = batch["img"]
|
74
|
-
labels = batch["label"]
|
75
|
+
images = batch["img"]
|
76
|
+
labels = batch["label"]
|
75
77
|
optimizer.zero_grad()
|
76
|
-
loss = criterion(net(images), labels)
|
78
|
+
loss = criterion(net(images.to(device)), labels.to(device))
|
77
79
|
loss.backward()
|
78
80
|
optimizer.step()
|
79
81
|
running_loss += loss.item()
|
82
|
+
|
80
83
|
avg_trainloss = running_loss / len(trainloader)
|
81
84
|
return avg_trainloss
|
82
85
|
|
@@ -96,3 +99,13 @@ def test(net, testloader, device):
|
|
96
99
|
accuracy = correct / len(testloader.dataset)
|
97
100
|
loss = loss / len(testloader)
|
98
101
|
return loss, accuracy
|
102
|
+
|
103
|
+
|
104
|
+
def get_weights(net):
|
105
|
+
return [val.cpu().numpy() for _, val in net.state_dict().items()]
|
106
|
+
|
107
|
+
|
108
|
+
def set_weights(net, parameters):
|
109
|
+
params_dict = zip(net.state_dict().keys(), parameters)
|
110
|
+
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
|
111
|
+
net.load_state_dict(state_dict, strict=True)
|
@@ -27,7 +27,6 @@ packages = ["."]
|
|
27
27
|
publisher = "$username"
|
28
28
|
|
29
29
|
# Point to your ServerApp and ClientApp objects
|
30
|
-
# Format: "<module>:<object>"
|
31
30
|
[tool.flwr.app.components]
|
32
31
|
serverapp = "$import_name.server_app:app"
|
33
32
|
clientapp = "$import_name.client_app:app"
|
@@ -35,8 +34,9 @@ clientapp = "$import_name.client_app:app"
|
|
35
34
|
# Custom config values accessible via `context.run_config`
|
36
35
|
[tool.flwr.app.config]
|
37
36
|
num-server-rounds = 3
|
38
|
-
fraction-
|
37
|
+
fraction-train = 0.5
|
39
38
|
local-epochs = 1
|
39
|
+
lr = 0.01
|
40
40
|
|
41
41
|
# Default federation to use when running the app
|
42
42
|
[tool.flwr.federations]
|
@@ -27,6 +27,7 @@ packages = ["."]
|
|
27
27
|
publisher = "$username"
|
28
28
|
|
29
29
|
# Point to your ServerApp and ClientApp objects
|
30
|
+
# Format: "<module>:<object>"
|
30
31
|
[tool.flwr.app.components]
|
31
32
|
serverapp = "$import_name.server_app:app"
|
32
33
|
clientapp = "$import_name.client_app:app"
|
@@ -34,9 +35,8 @@ clientapp = "$import_name.client_app:app"
|
|
34
35
|
# Custom config values accessible via `context.run_config`
|
35
36
|
[tool.flwr.app.config]
|
36
37
|
num-server-rounds = 3
|
37
|
-
fraction-
|
38
|
+
fraction-fit = 0.5
|
38
39
|
local-epochs = 1
|
39
|
-
lr = 0.01
|
40
40
|
|
41
41
|
# Default federation to use when running the app
|
42
42
|
[tool.flwr.federations]
|
@@ -22,6 +22,7 @@ from .dp_fixed_clipping import (
|
|
22
22
|
from .fedadagrad import FedAdagrad
|
23
23
|
from .fedadam import FedAdam
|
24
24
|
from .fedavg import FedAvg
|
25
|
+
from .fedxgb_bagging import FedXgbBagging
|
25
26
|
from .fedyogi import FedYogi
|
26
27
|
from .result import Result
|
27
28
|
from .strategy import Strategy
|
@@ -32,6 +33,7 @@ __all__ = [
|
|
32
33
|
"FedAdagrad",
|
33
34
|
"FedAdam",
|
34
35
|
"FedAvg",
|
36
|
+
"FedXgbBagging",
|
35
37
|
"FedYogi",
|
36
38
|
"Result",
|
37
39
|
"Strategy",
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower message-based FedXgbBagging strategy."""
|
16
|
+
from collections.abc import Iterable
|
17
|
+
from typing import Optional, cast
|
18
|
+
|
19
|
+
import numpy as np
|
20
|
+
|
21
|
+
from flwr.common import ArrayRecord, ConfigRecord, Message, MetricRecord
|
22
|
+
from flwr.server import Grid
|
23
|
+
|
24
|
+
from ..exception import InconsistentMessageReplies
|
25
|
+
from .fedavg import FedAvg
|
26
|
+
from .strategy_utils import aggregate_bagging
|
27
|
+
|
28
|
+
|
29
|
+
# pylint: disable=line-too-long
|
30
|
+
class FedXgbBagging(FedAvg):
|
31
|
+
"""Configurable FedXgbBagging strategy implementation."""
|
32
|
+
|
33
|
+
current_bst: Optional[bytes] = None
|
34
|
+
|
35
|
+
def _ensure_single_array(self, arrays: ArrayRecord) -> None:
|
36
|
+
"""Check that ensures there's only one Array in the ArrayRecord."""
|
37
|
+
n = len(arrays)
|
38
|
+
if n != 1:
|
39
|
+
raise InconsistentMessageReplies(
|
40
|
+
reason="Expected exactly one Array in ArrayRecord. "
|
41
|
+
"Skipping aggregation."
|
42
|
+
)
|
43
|
+
|
44
|
+
def configure_train(
|
45
|
+
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
|
46
|
+
) -> Iterable[Message]:
|
47
|
+
"""Configure the next round of federated training."""
|
48
|
+
self._ensure_single_array(arrays)
|
49
|
+
# Keep track of array record being communicated
|
50
|
+
self.current_bst = arrays["0"].numpy().tobytes()
|
51
|
+
return super().configure_train(server_round, arrays, config, grid)
|
52
|
+
|
53
|
+
def aggregate_train(
|
54
|
+
self,
|
55
|
+
server_round: int,
|
56
|
+
replies: Iterable[Message],
|
57
|
+
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
|
58
|
+
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
59
|
+
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
60
|
+
|
61
|
+
arrays, metrics = None, None
|
62
|
+
if valid_replies:
|
63
|
+
reply_contents = [msg.content for msg in valid_replies]
|
64
|
+
array_record_key = next(iter(reply_contents[0].array_records.keys()))
|
65
|
+
|
66
|
+
# Aggregate ArrayRecords
|
67
|
+
for content in reply_contents:
|
68
|
+
self._ensure_single_array(cast(ArrayRecord, content[array_record_key]))
|
69
|
+
bst = content[array_record_key]["0"].numpy().tobytes() # type: ignore[union-attr]
|
70
|
+
|
71
|
+
if self.current_bst is not None:
|
72
|
+
self.current_bst = aggregate_bagging(self.current_bst, bst)
|
73
|
+
|
74
|
+
if self.current_bst is not None:
|
75
|
+
arrays = ArrayRecord([np.frombuffer(self.current_bst, dtype=np.uint8)])
|
76
|
+
|
77
|
+
# Aggregate MetricRecords
|
78
|
+
metrics = self.train_metrics_aggr_fn(
|
79
|
+
reply_contents,
|
80
|
+
self.weighted_by_key,
|
81
|
+
)
|
82
|
+
return arrays, metrics
|
@@ -15,6 +15,7 @@
|
|
15
15
|
"""Flower message-based strategy utilities."""
|
16
16
|
|
17
17
|
|
18
|
+
import json
|
18
19
|
import random
|
19
20
|
from collections import OrderedDict
|
20
21
|
from logging import INFO
|
@@ -249,3 +250,50 @@ def validate_message_reply_consistency(
|
|
249
250
|
"must be a single value (int or float), but a list was found. Skipping "
|
250
251
|
"aggregation."
|
251
252
|
)
|
253
|
+
|
254
|
+
|
255
|
+
def aggregate_bagging(
|
256
|
+
bst_prev_org: bytes,
|
257
|
+
bst_curr_org: bytes,
|
258
|
+
) -> bytes:
|
259
|
+
"""Conduct bagging aggregation for given trees."""
|
260
|
+
if bst_prev_org == b"":
|
261
|
+
return bst_curr_org
|
262
|
+
|
263
|
+
# Get the tree numbers
|
264
|
+
tree_num_prev, _ = _get_tree_nums(bst_prev_org)
|
265
|
+
_, paral_tree_num_curr = _get_tree_nums(bst_curr_org)
|
266
|
+
|
267
|
+
bst_prev = json.loads(bytearray(bst_prev_org))
|
268
|
+
bst_curr = json.loads(bytearray(bst_curr_org))
|
269
|
+
|
270
|
+
previous_model = bst_prev["learner"]["gradient_booster"]["model"]
|
271
|
+
previous_model["gbtree_model_param"]["num_trees"] = str(
|
272
|
+
tree_num_prev + paral_tree_num_curr
|
273
|
+
)
|
274
|
+
iteration_indptr = previous_model["iteration_indptr"]
|
275
|
+
previous_model["iteration_indptr"].append(
|
276
|
+
iteration_indptr[-1] + paral_tree_num_curr
|
277
|
+
)
|
278
|
+
|
279
|
+
# Aggregate new trees
|
280
|
+
trees_curr = bst_curr["learner"]["gradient_booster"]["model"]["trees"]
|
281
|
+
for tree_count in range(paral_tree_num_curr):
|
282
|
+
trees_curr[tree_count]["id"] = tree_num_prev + tree_count
|
283
|
+
previous_model["trees"].append(trees_curr[tree_count])
|
284
|
+
previous_model["tree_info"].append(0)
|
285
|
+
|
286
|
+
bst_prev_bytes = bytes(json.dumps(bst_prev), "utf-8")
|
287
|
+
|
288
|
+
return bst_prev_bytes
|
289
|
+
|
290
|
+
|
291
|
+
def _get_tree_nums(xgb_model_org: bytes) -> tuple[int, int]:
|
292
|
+
xgb_model = json.loads(bytearray(xgb_model_org))
|
293
|
+
|
294
|
+
# Access model parameters
|
295
|
+
model_param = xgb_model["learner"]["gradient_booster"]["model"][
|
296
|
+
"gbtree_model_param"
|
297
|
+
]
|
298
|
+
# Return the number of trees and the number of parallel trees
|
299
|
+
return int(model_param["num_trees"]), int(model_param["num_parallel_tree"])
|
{flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250915.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: flwr-nightly
|
3
|
-
Version: 1.22.0.
|
3
|
+
Version: 1.22.0.dev20250915
|
4
4
|
Summary: Flower: A Friendly Federated AI Framework
|
5
5
|
License: Apache-2.0
|
6
6
|
Keywords: Artificial Intelligence,Federated AI,Federated Analytics,Federated Evaluation,Federated Learning,Flower,Machine Learning
|
{flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250915.dist-info}/RECORD
RENAMED
@@ -18,7 +18,7 @@ flwr/cli/login/__init__.py,sha256=B1SXKU3HCQhWfFDMJhlC7FOl8UsvH4mxysxeBnrfyUE,80
|
|
18
18
|
flwr/cli/login/login.py,sha256=RM1Jiv_VFm3oz4rTHSr3D87X90lW3WzErjBBU7WviWY,4309
|
19
19
|
flwr/cli/ls.py,sha256=3YK7cpoImJ7PbjlP_JgYRQWz1GymX2q7Reu-mKJEpao,10957
|
20
20
|
flwr/cli/new/__init__.py,sha256=QA1E2QtzPvFCjLTUHnFnJbufuFiGyT_0Y53Wpbvg1F0,790
|
21
|
-
flwr/cli/new/new.py,sha256=
|
21
|
+
flwr/cli/new/new.py,sha256=KyTs9Fbm4eoJ5DohhuTkYNJJX5rDC0p-YTPtNatYXrI,10529
|
22
22
|
flwr/cli/new/templates/__init__.py,sha256=FpjWCfIySU2DB4kh0HOXLAjlZNNFDTVU4w3HoE2TzcI,725
|
23
23
|
flwr/cli/new/templates/app/.gitignore.tpl,sha256=HZJcGQoxp7aUzaPg8Uqch3kNrIESwr9yjimDxJYgXVY,3104
|
24
24
|
flwr/cli/new/templates/app/LICENSE.tpl,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
@@ -29,14 +29,14 @@ flwr/cli/new/templates/app/__init__.py,sha256=LbR0ksGiF566JcHM_H5m1Tc4-oYUEilWFl
|
|
29
29
|
flwr/cli/new/templates/app/code/__init__.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
|
30
30
|
flwr/cli/new/templates/app/code/__init__.py,sha256=zXa2YU1swzHxOKDQbwlDMEwVPOUswVeosjkiXNMTgFo,736
|
31
31
|
flwr/cli/new/templates/app/code/__init__.py.tpl,sha256=J0Gn74E7khpLyKJVNqOPu7ev93vkcu1PZugsbxtABMw,52
|
32
|
-
flwr/cli/new/templates/app/code/__init__.
|
32
|
+
flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl,sha256=mKIS8MK_X8T9NlmcX1-_c9Bbexc-ueqDIBI7uN6c4dE,45
|
33
33
|
flwr/cli/new/templates/app/code/client.baseline.py.tpl,sha256=IYlCZqnaxT2ucP1ReffRNohOkYwNrhtrnDoQBBcrThY,1901
|
34
34
|
flwr/cli/new/templates/app/code/client.huggingface.py.tpl,sha256=SIZZ3s-6u8IU8cFfsqu6ZU8zjhfI1m1SWauOSUcW8TA,3015
|
35
35
|
flwr/cli/new/templates/app/code/client.jax.py.tpl,sha256=uFCIPwAHYiRAgh2W3nRni_Oig02ZzRF-ofUG5O19zcE,2125
|
36
36
|
flwr/cli/new/templates/app/code/client.mlx.py.tpl,sha256=CHU2IBIzI2YENZZuvTsAlSdL94DK19wMYMIhr-JgwZ8,3422
|
37
37
|
flwr/cli/new/templates/app/code/client.numpy.py.tpl,sha256=1_WEoOPe9jJeK-7FZgYuDUqY8mC0vxgqA83d-h201Gk,1381
|
38
|
-
flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=
|
39
|
-
flwr/cli/new/templates/app/code/client.
|
38
|
+
flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=fYoh-dTu07LkqNYvwcxQnbgVvH4Yo4eiGEcyHECbsnU,2473
|
39
|
+
flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl,sha256=fuxVmZpjHIueNy_aHWF81531vmi8DGu4CYjYDqmUwWo,1705
|
40
40
|
flwr/cli/new/templates/app/code/client.sklearn.py.tpl,sha256=0qqEe-RRjkHGOH8gsD9e83ae-kyyYixhyBgzVHjYpzk,3500
|
41
41
|
flwr/cli/new/templates/app/code/client.tensorflow.py.tpl,sha256=8o55KXpsbF_rv6o98ZNYJDCazjwMp_RPTaSzDfT7Qlw,2682
|
42
42
|
flwr/cli/new/templates/app/code/dataset.baseline.py.tpl,sha256=jbd_exHAk2-Blu_kVutjPO6a_dkJQWb232zxSeXIZ1k,1453
|
@@ -52,8 +52,8 @@ flwr/cli/new/templates/app/code/server.huggingface.py.tpl,sha256=_2Mv-SqGSMf7sMd
|
|
52
52
|
flwr/cli/new/templates/app/code/server.jax.py.tpl,sha256=RW-rh7ogcJ3_BD66bJxTw-ZoP7c-4SK8hVHc-e0SSVY,1029
|
53
53
|
flwr/cli/new/templates/app/code/server.mlx.py.tpl,sha256=J8rIe6RL2ndODVJD79xShRKBH70HljFSCi4s_RJ-xLQ,1200
|
54
54
|
flwr/cli/new/templates/app/code/server.numpy.py.tpl,sha256=T3hcKbPw3uL5lXEP-MuVJXIBXjzva5sWJXfpQqarUwA,955
|
55
|
-
flwr/cli/new/templates/app/code/server.pytorch.py.tpl,sha256=
|
56
|
-
flwr/cli/new/templates/app/code/server.
|
55
|
+
flwr/cli/new/templates/app/code/server.pytorch.py.tpl,sha256=epARqfcQ-EQsdZwaaaUp5y4OSTBT6CiFGlNRocw-23A,1158
|
56
|
+
flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl,sha256=gvBsGA_Jg9kAH8xTxjzTjMcvBtciuccOwQFbO7ey8tU,916
|
57
57
|
flwr/cli/new/templates/app/code/server.sklearn.py.tpl,sha256=ehQ5VRgBn92WeFl6kupwJnuxSNkKvE-EvKde6A9mNQo,1377
|
58
58
|
flwr/cli/new/templates/app/code/server.tensorflow.py.tpl,sha256=2-WTOPd-ewdLd9QmSlflIH7ix7zxAzPEOZoyiPBOy8c,1010
|
59
59
|
flwr/cli/new/templates/app/code/strategy.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
|
@@ -61,8 +61,8 @@ flwr/cli/new/templates/app/code/task.huggingface.py.tpl,sha256=piBbY3Dg60bQnCg15
|
|
61
61
|
flwr/cli/new/templates/app/code/task.jax.py.tpl,sha256=Fb0XgdTAQplM-ZCusI081XA9asO3gHptH772S-Xcyy8,1525
|
62
62
|
flwr/cli/new/templates/app/code/task.mlx.py.tpl,sha256=YxH5z4s5kOh5_9DIY9pvzqURckLDfgdanTA68_iM_Wo,2946
|
63
63
|
flwr/cli/new/templates/app/code/task.numpy.py.tpl,sha256=CwUJPnN3z6GjP8-KVGWzx7RYRJsl0wLFZ72xscvl3RM,126
|
64
|
-
flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=
|
65
|
-
flwr/cli/new/templates/app/code/task.
|
64
|
+
flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=RKA5lV6O6OnVKZ2r75pbzwy9arg5o2lzXqG2kNrLIUU,3446
|
65
|
+
flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl,sha256=XlJqA4Ix_PloO_zJLhjiN5vDj16w3I4CPVGdmbe8asE,3800
|
66
66
|
flwr/cli/new/templates/app/code/task.sklearn.py.tpl,sha256=vHdhtMp0FHxbYafXyhDT9aKmmmA0Jvpx5Oum1Yu9lWY,1850
|
67
67
|
flwr/cli/new/templates/app/code/task.tensorflow.py.tpl,sha256=impgWN7MfztmcWF4xh1llcZGsgTvrb1HD5ZE0t-8U08,1731
|
68
68
|
flwr/cli/new/templates/app/code/utils.baseline.py.tpl,sha256=YkHAgppUeD2BnBoGfVB6dEvBfjuIPGsU1gw4CiUi3qA,40
|
@@ -72,8 +72,8 @@ flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl,sha256=xHGF38i7oFpvnFv
|
|
72
72
|
flwr/cli/new/templates/app/pyproject.jax.toml.tpl,sha256=fdDhwmPoMirJ095cU_vFCBf0ILQlAoa1fdnHb2LM1yk,1471
|
73
73
|
flwr/cli/new/templates/app/pyproject.mlx.toml.tpl,sha256=PAjPT2v06sBZxacNiyMJloDwocCK5tFcGQmMXOoBqc8,1542
|
74
74
|
flwr/cli/new/templates/app/pyproject.numpy.toml.tpl,sha256=Kb_O2iQfzwc6FTy3fWqtQYc3FwY6x9SUgQPGqZR_ILg,1409
|
75
|
-
flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=
|
76
|
-
flwr/cli/new/templates/app/pyproject.
|
75
|
+
flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=SE4H23OFkQbqNU64nYf38igqrT4cJGA7XxEtSnNxJqg,1490
|
76
|
+
flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl,sha256=docQbs3MuRR-yT24lVz7N2sQL3Sj49EHuOCuRj_0djQ,1508
|
77
77
|
flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=apauU_PUmLEbt2rjckKniEbzdRs1EnMri_qgtHtBJZ8,1484
|
78
78
|
flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl,sha256=LQpDKJTEnRKj5Ygn5FkT44SxlnLVprkPlbrGaFf5Q50,1508
|
79
79
|
flwr/cli/run/__init__.py,sha256=RPyB7KbYTFl6YRiilCch6oezxrLQrl1kijV7BMGkLbA,790
|
@@ -332,16 +332,17 @@ flwr/server/workflow/secure_aggregation/secaggplus_workflow.py,sha256=DkayCsnlAy
|
|
332
332
|
flwr/serverapp/__init__.py,sha256=ZujKNXULwhWYQhFnxOOT5Wi9MRq2JCWFhAAj7ouiQ78,884
|
333
333
|
flwr/serverapp/dp_fixed_clipping.py,sha256=wbP4W7CaUHXdll8ZSVUnTBSEWrnWM00CGk63rOR-Q2s,12133
|
334
334
|
flwr/serverapp/exception.py,sha256=5cuH-2AafvihzosWDdDjuMmHdDqZ1XxHvCqZXNBVklw,1334
|
335
|
-
flwr/serverapp/strategy/__init__.py,sha256=
|
335
|
+
flwr/serverapp/strategy/__init__.py,sha256=0ldxlooz4a5yewUbQJGVrW9awrrIcFDIrNR4yZgpfKw,1292
|
336
336
|
flwr/serverapp/strategy/dp_fixed_clipping.py,sha256=wbP4W7CaUHXdll8ZSVUnTBSEWrnWM00CGk63rOR-Q2s,12133
|
337
337
|
flwr/serverapp/strategy/fedadagrad.py,sha256=fD65P6OEERa_pxq847e1UZpA083AcWR44XavYB0naGM,6343
|
338
338
|
flwr/serverapp/strategy/fedadam.py,sha256=s3xPIqhopy6yPTeFxevSPnc7a6BcKnKsvo2AaO6Z_xs,7138
|
339
339
|
flwr/serverapp/strategy/fedavg.py,sha256=53L06lZLkbGV0TRZrUWvPaocvFTT1PAhTvu9UkKq1zE,11294
|
340
340
|
flwr/serverapp/strategy/fedopt.py,sha256=kqT0uV2IUE93O72XEVa1JJo61dcwbZEoT9KmYTjR2tE,8477
|
341
|
+
flwr/serverapp/strategy/fedxgb_bagging.py,sha256=ktDjzov4y0BRecioq788umCEtcuwElou9olBizQKOnM,3282
|
341
342
|
flwr/serverapp/strategy/fedyogi.py,sha256=1Ripr4Hi2cdeTOLiFOXtMKvOxR3BsUQwc7bbTrXN4LM,6653
|
342
343
|
flwr/serverapp/strategy/result.py,sha256=E0Hl2VLnZAgQJjE2GDoKsK7JX-kPPU2KXc47Axt6hGw,4295
|
343
344
|
flwr/serverapp/strategy/strategy.py,sha256=8uJGGm1ROLZERQ_dkRS7Z_rs-yK6XCE0UxXtIdFiEWk,10789
|
344
|
-
flwr/serverapp/strategy/strategy_utils.py,sha256=
|
345
|
+
flwr/serverapp/strategy/strategy_utils.py,sha256=hiwS7k-Hx6_c4NZXoKpHucS5CBKb7f8GppXRBSMt3Us,10851
|
345
346
|
flwr/serverapp/strategy/strategy_utils_tests.py,sha256=o32XHujd9PLCB-YZMI2AttWLlvUXHe9yuxgiCrCkpgU,10209
|
346
347
|
flwr/simulation/__init__.py,sha256=Gg6OsP1Z-ixc3-xxzvl7j7rz2Fijy9rzyEPpxgAQCeM,1556
|
347
348
|
flwr/simulation/app.py,sha256=LbGLMvN9Ap119yBqsUcNNmVLRnCySnr4VechqcQ1hpA,10401
|
@@ -403,7 +404,7 @@ flwr/supernode/servicer/__init__.py,sha256=lucTzre5WPK7G1YLCfaqg3rbFWdNSb7ZTt-ca
|
|
403
404
|
flwr/supernode/servicer/clientappio/__init__.py,sha256=7Oy62Y_oijqF7Dxi6tpcUQyOpLc_QpIRZ83NvwmB0Yg,813
|
404
405
|
flwr/supernode/servicer/clientappio/clientappio_servicer.py,sha256=nIHRu38EWK-rpNOkcgBRAAKwYQQWFeCwu0lkO7OPZGQ,10239
|
405
406
|
flwr/supernode/start_client_internal.py,sha256=Y9S1-QlO2WP6eo4JvWzIpfaCoh2aoE7bjEYyxNNnlyg,20777
|
406
|
-
flwr_nightly-1.22.0.
|
407
|
-
flwr_nightly-1.22.0.
|
408
|
-
flwr_nightly-1.22.0.
|
409
|
-
flwr_nightly-1.22.0.
|
407
|
+
flwr_nightly-1.22.0.dev20250915.dist-info/METADATA,sha256=FBo-ub8Rc1rRhLrioWMroybBDcoP9t7v6vBqdE9U3do,15967
|
408
|
+
flwr_nightly-1.22.0.dev20250915.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
409
|
+
flwr_nightly-1.22.0.dev20250915.dist-info/entry_points.txt,sha256=hxHD2ixb_vJFDOlZV-zB4Ao32_BQlL34ftsDh1GXv14,420
|
410
|
+
flwr_nightly-1.22.0.dev20250915.dist-info/RECORD,,
|
@@ -1,80 +0,0 @@
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
2
|
-
|
3
|
-
import torch
|
4
|
-
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
5
|
-
from flwr.clientapp import ClientApp
|
6
|
-
|
7
|
-
from $import_name.task import Net, load_data
|
8
|
-
from $import_name.task import test as test_fn
|
9
|
-
from $import_name.task import train as train_fn
|
10
|
-
|
11
|
-
# Flower ClientApp
|
12
|
-
app = ClientApp()
|
13
|
-
|
14
|
-
|
15
|
-
@app.train()
|
16
|
-
def train(msg: Message, context: Context):
|
17
|
-
"""Train the model on local data."""
|
18
|
-
|
19
|
-
# Load the model and initialize it with the received weights
|
20
|
-
model = Net()
|
21
|
-
model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
|
22
|
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
23
|
-
model.to(device)
|
24
|
-
|
25
|
-
# Load the data
|
26
|
-
partition_id = context.node_config["partition-id"]
|
27
|
-
num_partitions = context.node_config["num-partitions"]
|
28
|
-
trainloader, _ = load_data(partition_id, num_partitions)
|
29
|
-
|
30
|
-
# Call the training function
|
31
|
-
train_loss = train_fn(
|
32
|
-
model,
|
33
|
-
trainloader,
|
34
|
-
context.run_config["local-epochs"],
|
35
|
-
msg.content["config"]["lr"],
|
36
|
-
device,
|
37
|
-
)
|
38
|
-
|
39
|
-
# Construct and return reply Message
|
40
|
-
model_record = ArrayRecord(model.state_dict())
|
41
|
-
metrics = {
|
42
|
-
"train_loss": train_loss,
|
43
|
-
"num-examples": len(trainloader.dataset),
|
44
|
-
}
|
45
|
-
metric_record = MetricRecord(metrics)
|
46
|
-
content = RecordDict({"arrays": model_record, "metrics": metric_record})
|
47
|
-
return Message(content=content, reply_to=msg)
|
48
|
-
|
49
|
-
|
50
|
-
@app.evaluate()
|
51
|
-
def evaluate(msg: Message, context: Context):
|
52
|
-
"""Evaluate the model on local data."""
|
53
|
-
|
54
|
-
# Load the model and initialize it with the received weights
|
55
|
-
model = Net()
|
56
|
-
model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
|
57
|
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
58
|
-
model.to(device)
|
59
|
-
|
60
|
-
# Load the data
|
61
|
-
partition_id = context.node_config["partition-id"]
|
62
|
-
num_partitions = context.node_config["num-partitions"]
|
63
|
-
_, valloader = load_data(partition_id, num_partitions)
|
64
|
-
|
65
|
-
# Call the evaluation function
|
66
|
-
eval_loss, eval_acc = test_fn(
|
67
|
-
model,
|
68
|
-
valloader,
|
69
|
-
device,
|
70
|
-
)
|
71
|
-
|
72
|
-
# Construct and return reply Message
|
73
|
-
metrics = {
|
74
|
-
"eval_loss": eval_loss,
|
75
|
-
"eval_acc": eval_acc,
|
76
|
-
"num-examples": len(valloader.dataset),
|
77
|
-
}
|
78
|
-
metric_record = MetricRecord(metrics)
|
79
|
-
content = RecordDict({"metrics": metric_record})
|
80
|
-
return Message(content=content, reply_to=msg)
|
@@ -1,41 +0,0 @@
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
2
|
-
|
3
|
-
import torch
|
4
|
-
from flwr.app import ArrayRecord, ConfigRecord, Context
|
5
|
-
from flwr.serverapp import Grid, ServerApp
|
6
|
-
from flwr.serverapp.strategy import FedAvg
|
7
|
-
|
8
|
-
from $import_name.task import Net
|
9
|
-
|
10
|
-
# Create ServerApp
|
11
|
-
app = ServerApp()
|
12
|
-
|
13
|
-
|
14
|
-
@app.main()
|
15
|
-
def main(grid: Grid, context: Context) -> None:
|
16
|
-
"""Main entry point for the ServerApp."""
|
17
|
-
|
18
|
-
# Read run config
|
19
|
-
fraction_train: float = context.run_config["fraction-train"]
|
20
|
-
num_rounds: int = context.run_config["num-server-rounds"]
|
21
|
-
lr: float = context.run_config["lr"]
|
22
|
-
|
23
|
-
# Load global model
|
24
|
-
global_model = Net()
|
25
|
-
arrays = ArrayRecord(global_model.state_dict())
|
26
|
-
|
27
|
-
# Initialize FedAvg strategy
|
28
|
-
strategy = FedAvg(fraction_train=fraction_train)
|
29
|
-
|
30
|
-
# Start strategy, run FedAvg for `num_rounds`
|
31
|
-
result = strategy.start(
|
32
|
-
grid=grid,
|
33
|
-
initial_arrays=arrays,
|
34
|
-
train_config=ConfigRecord({"lr": lr}),
|
35
|
-
num_rounds=num_rounds,
|
36
|
-
)
|
37
|
-
|
38
|
-
# Save final model to disk
|
39
|
-
print("\nSaving final model to disk...")
|
40
|
-
state_dict = result.arrays.to_torch_state_dict()
|
41
|
-
torch.save(state_dict, "final_model.pt")
|
File without changes
|
{flwr_nightly-1.22.0.dev20250913.dist-info → flwr_nightly-1.22.0.dev20250915.dist-info}/WHEEL
RENAMED
File without changes
|
File without changes
|