flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- flwr/__init__.py +4 -1
- flwr/app/__init__.py +28 -0
- flwr/app/exception.py +31 -0
- flwr/cli/app.py +2 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +3 -3
- flwr/cli/constant.py +25 -8
- flwr/cli/log.py +9 -9
- flwr/cli/login/login.py +3 -3
- flwr/cli/ls.py +5 -5
- flwr/cli/new/new.py +15 -2
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/run/run.py +9 -13
- flwr/cli/stop.py +7 -4
- flwr/cli/utils.py +36 -8
- flwr/client/grpc_rere_client/connection.py +1 -12
- flwr/client/rest_client/connection.py +3 -0
- flwr/clientapp/__init__.py +10 -0
- flwr/clientapp/mod/__init__.py +29 -0
- flwr/clientapp/mod/centraldp_mods.py +248 -0
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/common/args.py +20 -6
- flwr/common/auth_plugin/__init__.py +4 -4
- flwr/common/auth_plugin/auth_plugin.py +7 -7
- flwr/common/constant.py +26 -4
- flwr/common/event_log_plugin/event_log_plugin.py +1 -1
- flwr/common/exit/__init__.py +4 -0
- flwr/common/exit/exit.py +8 -1
- flwr/common/exit/exit_code.py +30 -7
- flwr/common/exit/exit_handler.py +62 -0
- flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
- flwr/common/grpc.py +0 -11
- flwr/common/inflatable_utils.py +1 -1
- flwr/common/logger.py +1 -1
- flwr/common/record/typeddict.py +12 -0
- flwr/common/retry_invoker.py +30 -11
- flwr/common/telemetry.py +4 -0
- flwr/compat/server/app.py +2 -2
- flwr/proto/appio_pb2.py +25 -17
- flwr/proto/appio_pb2.pyi +46 -2
- flwr/proto/clientappio_pb2.py +3 -11
- flwr/proto/clientappio_pb2.pyi +0 -47
- flwr/proto/clientappio_pb2_grpc.py +19 -20
- flwr/proto/clientappio_pb2_grpc.pyi +10 -11
- flwr/proto/control_pb2.py +66 -0
- flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
- flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
- flwr/proto/control_pb2_grpc.pyi +106 -0
- flwr/proto/serverappio_pb2.py +2 -2
- flwr/proto/serverappio_pb2_grpc.py +68 -0
- flwr/proto/serverappio_pb2_grpc.pyi +26 -0
- flwr/proto/simulationio_pb2.py +4 -11
- flwr/proto/simulationio_pb2.pyi +0 -58
- flwr/proto/simulationio_pb2_grpc.py +129 -27
- flwr/proto/simulationio_pb2_grpc.pyi +52 -13
- flwr/server/app.py +142 -152
- flwr/server/grid/grpc_grid.py +3 -0
- flwr/server/grid/inmemory_grid.py +1 -0
- flwr/server/serverapp/app.py +157 -146
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
- flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
- flwr/serverapp/__init__.py +12 -0
- flwr/serverapp/exception.py +38 -0
- flwr/serverapp/strategy/__init__.py +64 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
- flwr/serverapp/strategy/fedadagrad.py +159 -0
- flwr/serverapp/strategy/fedadam.py +178 -0
- flwr/serverapp/strategy/fedavg.py +320 -0
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedopt.py +218 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +170 -0
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/result.py +105 -0
- flwr/serverapp/strategy/strategy.py +285 -0
- flwr/serverapp/strategy/strategy_utils.py +299 -0
- flwr/simulation/app.py +161 -164
- flwr/simulation/run_simulation.py +25 -30
- flwr/supercore/app_utils.py +58 -0
- flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
- flwr/supercore/cli/flower_superexec.py +166 -0
- flwr/supercore/constant.py +19 -0
- flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
- flwr/supercore/corestate/corestate.py +81 -0
- flwr/supercore/grpc_health/__init__.py +3 -0
- flwr/supercore/grpc_health/health_server.py +53 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
- flwr/{superexec → supercore/superexec}/__init__.py +1 -1
- flwr/supercore/superexec/plugin/__init__.py +28 -0
- flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
- flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
- flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
- flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
- flwr/supercore/superexec/run_superexec.py +199 -0
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/superlink/servicer/__init__.py +15 -0
- flwr/superlink/servicer/control/__init__.py +22 -0
- flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
- flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
- flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
- flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
- flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
- flwr/supernode/cli/flower_supernode.py +3 -0
- flwr/supernode/cli/flwr_clientapp.py +18 -21
- flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
- flwr/supernode/nodestate/nodestate.py +3 -59
- flwr/supernode/runtime/run_clientapp.py +39 -102
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
- flwr/supernode/start_client_internal.py +35 -76
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
- flwr/proto/exec_pb2.py +0 -62
- flwr/proto/exec_pb2_grpc.pyi +0 -93
- flwr/superexec/app.py +0 -45
- flwr/superexec/deployment.py +0 -191
- flwr/superexec/executor.py +0 -100
- flwr/superexec/simulation.py +0 -129
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0

```diff
--- /dev/null
+++ b/flwr/cli/new/templates/app/code/server.xgboost.py.tpl
@@ -0,0 +1,56 @@
+"""$project_name: A Flower / $framework_str app."""
+
+import numpy as np
+import xgboost as xgb
+from flwr.app import ArrayRecord, Context
+from flwr.common.config import unflatten_dict
+from flwr.serverapp import Grid, ServerApp
+from flwr.serverapp.strategy import FedXgbBagging
+
+from $import_name.task import replace_keys
+
+# Create ServerApp
+app = ServerApp()
+
+
+@app.main()
+def main(grid: Grid, context: Context) -> None:
+    # Read run config
+    num_rounds = context.run_config["num-server-rounds"]
+    fraction_train = context.run_config["fraction-train"]
+    fraction_evaluate = context.run_config["fraction-evaluate"]
+    # Flatted config dict and replace "-" with "_"
+    cfg = replace_keys(unflatten_dict(context.run_config))
+    params = cfg["params"]
+
+    # Init global model
+    # Init with an empty object; the XGBooster will be created
+    # and trained on the client side.
+    global_model = b""
+    # Note: we store the model as the first item in a list into ArrayRecord,
+    # which can be accessed using index ["0"].
+    arrays = ArrayRecord([np.frombuffer(global_model, dtype=np.uint8)])
+
+    # Initialize FedXgbBagging strategy
+    strategy = FedXgbBagging(
+        fraction_train=fraction_train,
+        fraction_evaluate=fraction_evaluate,
+    )
+
+    # Start strategy, run FedXgbBagging for `num_rounds`
+    result = strategy.start(
+        grid=grid,
+        initial_arrays=arrays,
+        num_rounds=num_rounds,
+    )
+
+    # Save final model to disk
+    bst = xgb.Booster(params=params)
+    global_model = bytearray(result.arrays["0"].numpy().tobytes())
+
+    # Load global model into booster
+    bst.load_model(global_model)
+
+    # Save model
+    print("\nSaving final model to disk...")
+    bst.save_model("final_model.json")
```
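For context, a minimal sketch (not part of the diff) of the bytes-to-`ArrayRecord` round-trip this new server template relies on. It assumes only the `flwr.app.ArrayRecord` API imported in the template above; the placeholder bytes stand in for a serialized XGBoost model.

```python
import numpy as np

from flwr.app import ArrayRecord

raw = b"\x00\x01\x02"  # placeholder for a serialized XGBoost model

# Wrap the raw bytes as a single uint8 array; as the first (and only) list
# item it becomes entry "0" of the record, matching the template's comment
arrays = ArrayRecord([np.frombuffer(raw, dtype=np.uint8)])

# Recover the original bytes, as the template does after strategy.start()
recovered = arrays["0"].numpy().tobytes()
assert recovered == raw
```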
```diff
--- a/flwr/cli/new/templates/app/code/task.huggingface.py.tpl
+++ b/flwr/cli/new/templates/app/code/task.huggingface.py.tpl
@@ -1,7 +1,6 @@
 """$project_name: A Flower / $framework_str app."""
 
 import warnings
-from collections import OrderedDict
 
 import torch
 import transformers
@@ -62,17 +61,24 @@ def load_data(partition_id: int, num_partitions: int, model_name: str):
     return trainloader, testloader
 
 
-def train(net, trainloader, epochs, device):
+def train(net, trainloader, num_steps, device):
     optimizer = AdamW(net.parameters(), lr=5e-5)
     net.train()
-    for _ in range(epochs):
-        for batch in trainloader:
-            batch = {k: v.to(device) for k, v in batch.items()}
-            outputs = net(**batch)
-            loss = outputs.loss
-            loss.backward()
-            optimizer.step()
-            optimizer.zero_grad()
+    running_loss = 0.0
+    step_cnt = 0
+    for batch in trainloader:
+        batch = {k: v.to(device) for k, v in batch.items()}
+        outputs = net(**batch)
+        loss = outputs.loss
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        running_loss += loss.item()
+        step_cnt += 1
+        if step_cnt >= num_steps:
+            break
+    avg_trainloss = running_loss / step_cnt
+    return avg_trainloss
 
 
 def test(net, testloader, device):
@@ -90,13 +96,3 @@ def test(net, testloader, device):
     loss /= len(testloader.dataset)
     accuracy = metric.compute()["accuracy"]
     return loss, accuracy
-
-
-def get_weights(net):
-    return [val.cpu().numpy() for _, val in net.state_dict().items()]
-
-
-def set_weights(net, parameters):
-    params_dict = zip(net.state_dict().keys(), parameters)
-    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
-    net.load_state_dict(state_dict, strict=True)
```
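To see the new step-capping behavior without downloading a real transformer, here is an illustrative sketch (not part of the template; `StubModel` and the synthetic batches are stand-ins) that drives the rewritten `train()` above for exactly `num_steps` optimizer steps:

```python
import torch
import torch.nn as nn


class StubModel(nn.Module):
    """Mimic a HF model: calling net(**batch) returns an object with .loss."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, input_ids, labels):
        logits = self.linear(input_ids)
        loss = nn.functional.cross_entropy(logits, labels)
        return type("Output", (), {"loss": loss})()


batches = [
    {"input_ids": torch.randn(8, 4), "labels": torch.randint(0, 2, (8,))}
    for _ in range(100)
]

# Regardless of the loader length, only `num_steps` optimizer steps run,
# matching the local-epochs -> local-steps rename in the pyproject hunks below
avg_loss = train(StubModel(), batches, num_steps=5, device="cpu")
```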
```diff
--- a/flwr/cli/new/templates/app/code/task.jax.py.tpl
+++ b/flwr/cli/new/templates/app/code/task.jax.py.tpl
@@ -31,7 +31,7 @@ def loss_fn(params, X, y):
 def train(params, grad_fn, X, y):
     loss = 1_000_000
     num_examples = X.shape[0]
-    for i in range(50):
+    for _ in range(50):
         grads = grad_fn(params, X, y)
         params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
         loss = loss_fn(params, X, y)
```
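For readers new to the functional update style used here, a small sketch (illustrative values only) of what the `jax.tree.map` line does on a parameter pytree:

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.2, 0.2]), "b": jnp.array(1.0)}

# One SGD step with learning rate 0.05, applied leaf-wise across the pytree
params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
# params -> {"w": [0.99, 1.99], "b": 0.45}
```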
```diff
--- a/flwr/cli/new/templates/app/code/task.pytorch.py.tpl
+++ b/flwr/cli/new/templates/app/code/task.pytorch.py.tpl
@@ -1,7 +1,5 @@
 """$project_name: A Flower / $framework_str app."""
 
-from collections import OrderedDict
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -34,6 +32,14 @@ class Net(nn.Module):
 
 fds = None  # Cache FederatedDataset
 
+pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+
+def apply_transforms(batch):
+    """Apply transforms to the partition from FederatedDataset."""
+    batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
+    return batch
+
 
 def load_data(partition_id: int, num_partitions: int):
     """Load partition CIFAR10 data."""
@@ -48,38 +54,29 @@ def load_data(partition_id: int, num_partitions: int):
     partition = fds.load_partition(partition_id)
     # Divide data on each node: 80% train, 20% test
     partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
-    pytorch_transforms = Compose(
-        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
-    )
-
-    def apply_transforms(batch):
-        """Apply transforms to the partition from FederatedDataset."""
-        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
-        return batch
-
+    # Construct dataloaders
     partition_train_test = partition_train_test.with_transform(apply_transforms)
     trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
     testloader = DataLoader(partition_train_test["test"], batch_size=32)
     return trainloader, testloader
 
 
-def train(net, trainloader, epochs, device):
+def train(net, trainloader, epochs, lr, device):
     """Train the model on the training set."""
     net.to(device)  # move model to GPU if available
     criterion = torch.nn.CrossEntropyLoss().to(device)
-    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
+    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
     net.train()
     running_loss = 0.0
     for _ in range(epochs):
         for batch in trainloader:
-            images = batch["img"]
-            labels = batch["label"]
+            images = batch["img"].to(device)
+            labels = batch["label"].to(device)
             optimizer.zero_grad()
-            loss = criterion(net(images.to(device)), labels.to(device))
+            loss = criterion(net(images), labels)
             loss.backward()
             optimizer.step()
             running_loss += loss.item()
-
     avg_trainloss = running_loss / len(trainloader)
     return avg_trainloss
 
@@ -99,13 +96,3 @@ def test(net, testloader, device):
     accuracy = correct / len(testloader.dataset)
     loss = loss / len(testloader)
     return loss, accuracy
-
-
-def get_weights(net):
-    return [val.cpu().numpy() for _, val in net.state_dict().items()]
-
-
-def set_weights(net, parameters):
-    params_dict = zip(net.state_dict().keys(), parameters)
-    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
-    net.load_state_dict(state_dict, strict=True)
```
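A quick sketch of exercising the updated `train()` signature, which now takes an explicit `lr` (fed by the new `lr` run-config value added to `pyproject.pytorch.toml.tpl` below). `Net` and `train` are the template symbols from the hunks above; the data is a synthetic stand-in for a CIFAR-10 partition:

```python
import torch
from torch.utils.data import DataLoader

# Eight fake 32x32 RGB images with labels; default collation batches the dicts
data = [
    {"img": torch.randn(3, 32, 32), "label": torch.tensor(idx % 10)}
    for idx in range(8)
]
loader = DataLoader(data, batch_size=4)

avg_loss = train(Net(), loader, epochs=1, lr=0.01, device=torch.device("cpu"))
```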
```diff
--- /dev/null
+++ b/flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl
@@ -0,0 +1,111 @@
+"""$project_name: A Flower / $framework_str app."""
+
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
+from torch.utils.data import DataLoader
+from torchvision.transforms import Compose, Normalize, ToTensor
+
+
+class Net(nn.Module):
+    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 5 * 5)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+
+fds = None  # Cache FederatedDataset
+
+
+def load_data(partition_id: int, num_partitions: int):
+    """Load partition CIFAR10 data."""
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="uoft-cs/cifar10",
+            partitioners={"train": partitioner},
+        )
+    partition = fds.load_partition(partition_id)
+    # Divide data on each node: 80% train, 20% test
+    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
+    pytorch_transforms = Compose(
+        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )
+
+    def apply_transforms(batch):
+        """Apply transforms to the partition from FederatedDataset."""
+        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
+        return batch
+
+    partition_train_test = partition_train_test.with_transform(apply_transforms)
+    trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
+    testloader = DataLoader(partition_train_test["test"], batch_size=32)
+    return trainloader, testloader
+
+
+def train(net, trainloader, epochs, device):
+    """Train the model on the training set."""
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss().to(device)
+    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
+    net.train()
+    running_loss = 0.0
+    for _ in range(epochs):
+        for batch in trainloader:
+            images = batch["img"]
+            labels = batch["label"]
+            optimizer.zero_grad()
+            loss = criterion(net(images.to(device)), labels.to(device))
+            loss.backward()
+            optimizer.step()
+            running_loss += loss.item()
+
+    avg_trainloss = running_loss / len(trainloader)
+    return avg_trainloss
+
+
+def test(net, testloader, device):
+    """Validate the model on the test set."""
+    net.to(device)
+    criterion = torch.nn.CrossEntropyLoss()
+    correct, loss = 0, 0.0
+    with torch.no_grad():
+        for batch in testloader:
+            images = batch["img"].to(device)
+            labels = batch["label"].to(device)
+            outputs = net(images)
+            loss += criterion(outputs, labels).item()
+            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
+    accuracy = correct / len(testloader.dataset)
+    loss = loss / len(testloader)
+    return loss, accuracy
+
+
+def get_weights(net):
+    return [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+
+def set_weights(net, parameters):
+    params_dict = zip(net.state_dict().keys(), parameters)
+    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+    net.load_state_dict(state_dict, strict=True)
```
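The legacy-API template keeps the NumPy-based weight round-trip that the updated `task.pytorch.py.tpl` above removes. A short sketch using the `Net`, `get_weights`, and `set_weights` symbols defined in this hunk:

```python
import torch

net_a, net_b = Net(), Net()
set_weights(net_b, get_weights(net_a))  # copy weights via NumPy ndarrays

# Both models now hold identical parameters
assert all(
    torch.equal(a, b)
    for a, b in zip(net_a.state_dict().values(), net_b.state_dict().values())
)
```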
```diff
--- a/flwr/cli/new/templates/app/code/task.tensorflow.py.tpl
+++ b/flwr/cli/new/templates/app/code/task.tensorflow.py.tpl
@@ -3,10 +3,9 @@
 import os
 
 import keras
-from keras import layers
 from flwr_datasets import FederatedDataset
 from flwr_datasets.partitioner import IidPartitioner
-
+from keras import layers
 
 # Make TensorFlow log less verbose
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
```
```diff
--- /dev/null
+++ b/flwr/cli/new/templates/app/code/task.xgboost.py.tpl
@@ -0,0 +1,67 @@
+"""$project_name: A Flower / $framework_str app."""
+
+import xgboost as xgb
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
+
+
+def train_test_split(partition, test_fraction, seed):
+    """Split the data into train and validation set given split rate."""
+    train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
+    partition_train = train_test["train"]
+    partition_test = train_test["test"]
+
+    num_train = len(partition_train)
+    num_test = len(partition_test)
+
+    return partition_train, partition_test, num_train, num_test
+
+
+def transform_dataset_to_dmatrix(data):
+    """Transform dataset to DMatrix format for xgboost."""
+    x = data["inputs"]
+    y = data["label"]
+    new_data = xgb.DMatrix(x, label=y)
+    return new_data
+
+
+fds = None  # Cache FederatedDataset
+
+
+def load_data(partition_id, num_clients):
+    """Load partition HIGGS data."""
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_clients)
+        fds = FederatedDataset(
+            dataset="jxie/higgs",
+            partitioners={"train": partitioner},
+        )
+
+    # Load the partition for this `partition_id`
+    partition = fds.load_partition(partition_id, split="train")
+    partition.set_format("numpy")
+
+    # Train/test splitting
+    train_data, valid_data, num_train, num_val = train_test_split(
+        partition, test_fraction=0.2, seed=42
+    )
+
+    # Reformat data to DMatrix for xgboost
+    train_dmatrix = transform_dataset_to_dmatrix(train_data)
+    valid_dmatrix = transform_dataset_to_dmatrix(valid_data)
+
+    return train_dmatrix, valid_dmatrix, num_train, num_val
+
+
+def replace_keys(input_dict, match="-", target="_"):
+    """Recursively replace match string with target string in dictionary keys."""
+    new_dict = {}
+    for key, value in input_dict.items():
+        new_key = key.replace(match, target)
+        if isinstance(value, dict):
+            new_dict[new_key] = replace_keys(value, match, target)
+        else:
+            new_dict[new_key] = value
+    return new_dict
```
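How the `replace_keys(unflatten_dict(...))` pipeline used by `server.xgboost.py.tpl` treats the dotted, dash-separated run-config keys: a sketch assuming `replace_keys` as defined in this hunk and `unflatten_dict` from `flwr.common.config` (imported in the server template above), with illustrative key names taken from `pyproject.xgboost.toml.tpl` further down.

```python
from flwr.common.config import unflatten_dict

run_config = {"num-server-rounds": 3, "params.max-depth": 8, "params.eta": 0.1}

cfg = replace_keys(unflatten_dict(run_config))
# unflatten_dict nests on "."   -> {"params": {"max-depth": 8, "eta": 0.1}, ...}
# replace_keys swaps "-" for "_" -> {"params": {"max_depth": 8, "eta": 0.1}, ...}
assert cfg["params"]["max_depth"] == 8
```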
```diff
--- a/flwr/cli/new/templates/app/pyproject.baseline.toml.tpl
+++ b/flwr/cli/new/templates/app/pyproject.baseline.toml.tpl
@@ -14,10 +14,10 @@ description = ""
 license = "Apache-2.0"
 # Dependencies for your Flower App
 dependencies = [
-    "flwr[simulation]>=1.20.0",
+    "flwr[simulation]>=1.22.0",
     "flwr-datasets[vision]>=0.5.0",
-    "torch==2.7.1",
-    "torchvision==0.22.1",
+    "torch==2.8.0",
+    "torchvision==0.23.0",
 ]
 
 [tool.hatch.metadata]
@@ -132,7 +132,7 @@ clientapp = "$import_name.client_app:app"
 # Custom config values accessible via `context.run_config`
 [tool.flwr.app.config]
 num-server-rounds = 3
-fraction-fit = 0.5
+fraction-train = 0.5
 local-epochs = 1
 
 # Default federation to use when running the app
```
```diff
--- a/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl
+++ b/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl
@@ -14,7 +14,7 @@ description = ""
 license = "Apache-2.0"
 # Dependencies for your Flower App
 dependencies = [
-    "flwr[simulation]>=1.20.0",
+    "flwr[simulation]>=1.22.0",
     "flwr-datasets>=0.5.0",
     "torch==2.4.0",
     "trl==0.8.1",
@@ -61,7 +61,7 @@ train.training-arguments.save-steps = 1000
 train.training-arguments.save-total-limit = 10
 train.training-arguments.gradient-checkpointing = true
 train.training-arguments.lr-scheduler-type = "constant"
-strategy.fraction-fit = $fraction_fit
+strategy.fraction-train = $fraction_train
 strategy.fraction-evaluate = 0.0
 num-server-rounds = 200
 
```
```diff
--- a/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl
+++ b/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl
@@ -14,9 +14,9 @@ description = ""
 license = "Apache-2.0"
 # Dependencies for your Flower App
 dependencies = [
-    "flwr[simulation]>=1.20.0",
+    "flwr[simulation]>=1.22.0",
     "flwr-datasets>=0.5.0",
-    "torch==2.4.0",
+    "torch>=2.7.1",
     "transformers>=4.30.0,<5.0",
     "evaluate>=0.4.0,<1.0",
     "datasets>=2.0.0, <3.0",
@@ -38,8 +38,8 @@ clientapp = "$import_name.client_app:app"
 # Custom config values accessible via `context.run_config`
 [tool.flwr.app.config]
 num-server-rounds = 3
-fraction-fit = 0.5
-local-epochs = 1
+fraction-train = 0.5
+local-steps = 5
 model-name = "prajjwal1/bert-tiny" # Set a larger model if you have access to more GPU resources
 num-labels = 2
 
```
```diff
--- a/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl
+++ b/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl
@@ -14,9 +14,9 @@ description = ""
 license = "Apache-2.0"
 # Dependencies for your Flower App
 dependencies = [
-    "flwr[simulation]>=1.20.0",
+    "flwr[simulation]>=1.22.0",
     "flwr-datasets[vision]>=0.5.0",
-    "mlx==0.…",
+    "mlx==0.29.0",
 ]
 
 [tool.hatch.build.targets.wheel]
```
```diff
--- a/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
+++ b/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
@@ -14,7 +14,7 @@ description = ""
 license = "Apache-2.0"
 # Dependencies for your Flower App
 dependencies = [
-    "flwr[simulation]>=1.20.0",
+    "flwr[simulation]>=1.22.0",
     "flwr-datasets[vision]>=0.5.0",
     "torch==2.7.1",
     "torchvision==0.22.1",
@@ -27,7 +27,6 @@ packages = ["."]
 publisher = "$username"
 
 # Point to your ServerApp and ClientApp objects
-# Format: "<module>:<object>"
 [tool.flwr.app.components]
 serverapp = "$import_name.server_app:app"
 clientapp = "$import_name.client_app:app"
@@ -35,8 +34,9 @@ clientapp = "$import_name.client_app:app"
 # Custom config values accessible via `context.run_config`
 [tool.flwr.app.config]
 num-server-rounds = 3
-fraction-fit = 0.5
+fraction-train = 0.5
 local-epochs = 1
+lr = 0.01
 
 # Default federation to use when running the app
 [tool.flwr.federations]
```
```diff
--- /dev/null
+++ b/flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl
@@ -0,0 +1,53 @@
+# =====================================================================
+# For a full TOML configuration guide, check the Flower docs:
+# https://flower.ai/docs/framework/how-to-configure-pyproject-toml.html
+# =====================================================================
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "$package_name"
+version = "1.0.0"
+description = ""
+license = "Apache-2.0"
+# Dependencies for your Flower App
+dependencies = [
+    "flwr[simulation]>=1.22.0",
+    "flwr-datasets[vision]>=0.5.0",
+    "torch==2.7.1",
+    "torchvision==0.22.1",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[tool.flwr.app]
+publisher = "$username"
+
+# Point to your ServerApp and ClientApp objects
+# Format: "<module>:<object>"
+[tool.flwr.app.components]
+serverapp = "$import_name.server_app:app"
+clientapp = "$import_name.client_app:app"
+
+# Custom config values accessible via `context.run_config`
+[tool.flwr.app.config]
+num-server-rounds = 3
+fraction-fit = 0.5
+local-epochs = 1
+
+# Default federation to use when running the app
+[tool.flwr.federations]
+default = "local-simulation"
+
+# Local simulation federation with 10 virtual SuperNodes
+[tool.flwr.federations.local-simulation]
+options.num-supernodes = 10
+
+# Remote federation example for use with SuperLink
+[tool.flwr.federations.remote-federation]
+address = "<SUPERLINK-ADDRESS>:<PORT>"
+insecure = true # Remove this line to enable TLS
+# root-certificates = "<PATH/TO/ca.crt>" # For TLS setup
```
```diff
--- /dev/null
+++ b/flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl
@@ -0,0 +1,61 @@
+# =====================================================================
+# For a full TOML configuration guide, check the Flower docs:
+# https://flower.ai/docs/framework/how-to-configure-pyproject-toml.html
+# =====================================================================
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "$package_name"
+version = "1.0.0"
+description = ""
+license = "Apache-2.0"
+# Dependencies for your Flower App
+dependencies = [
+    "flwr[simulation]>=1.22.0",
+    "flwr-datasets>=0.5.0",
+    "xgboost>=2.0.0",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[tool.flwr.app]
+publisher = "$username"
+
+[tool.flwr.app.components]
+serverapp = "$import_name.server_app:app"
+clientapp = "$import_name.client_app:app"
+
+# Custom config values accessible via `context.run_config`
+[tool.flwr.app.config]
+num-server-rounds = 3
+fraction-train = 0.1
+fraction-evaluate = 0.1
+local-epochs = 1
+
+# XGBoost parameters
+params.objective = "binary:logistic"
+params.eta = 0.1 # Learning rate
+params.max-depth = 8
+params.eval-metric = "auc"
+params.nthread = 16
+params.num-parallel-tree = 1
+params.subsample = 1
+params.tree-method = "hist"
+
+# Default federation to use when running the app
+[tool.flwr.federations]
+default = "local-simulation"
+
+# Local simulation federation with 10 virtual SuperNodes
+[tool.flwr.federations.local-simulation]
+options.num-supernodes = 10
+
+# Remote federation example for use with SuperLink
+[tool.flwr.federations.remote-federation]
+address = "<SUPERLINK-ADDRESS>:<PORT>"
+insecure = true # Remove this line to enable TLS
+# root-certificates = "<PATH/TO/ca.crt>" # For TLS setup
```