flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/app.py +2 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  6. flwr/cli/cli_user_auth_interceptor.py +1 -1
  7. flwr/cli/config_utils.py +3 -3
  8. flwr/cli/constant.py +25 -8
  9. flwr/cli/log.py +9 -9
  10. flwr/cli/login/login.py +3 -3
  11. flwr/cli/ls.py +5 -5
  12. flwr/cli/new/new.py +15 -2
  13. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  14. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
  15. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  16. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  17. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  18. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  19. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  20. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  21. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  22. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  23. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  24. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  26. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  27. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  28. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  29. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  30. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  31. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  32. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  33. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  34. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  35. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  36. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  37. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  38. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  39. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  40. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  41. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  42. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  43. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  44. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
  45. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  46. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  47. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  49. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  50. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  52. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  53. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  54. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
  55. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  56. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  57. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  58. flwr/cli/pull.py +100 -0
  59. flwr/cli/run/run.py +9 -13
  60. flwr/cli/stop.py +7 -4
  61. flwr/cli/utils.py +36 -8
  62. flwr/client/grpc_rere_client/connection.py +1 -12
  63. flwr/client/rest_client/connection.py +3 -0
  64. flwr/clientapp/__init__.py +10 -0
  65. flwr/clientapp/mod/__init__.py +29 -0
  66. flwr/clientapp/mod/centraldp_mods.py +248 -0
  67. flwr/clientapp/mod/localdp_mod.py +169 -0
  68. flwr/clientapp/typing.py +22 -0
  69. flwr/common/args.py +20 -6
  70. flwr/common/auth_plugin/__init__.py +4 -4
  71. flwr/common/auth_plugin/auth_plugin.py +7 -7
  72. flwr/common/constant.py +26 -4
  73. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  74. flwr/common/exit/__init__.py +4 -0
  75. flwr/common/exit/exit.py +8 -1
  76. flwr/common/exit/exit_code.py +30 -7
  77. flwr/common/exit/exit_handler.py +62 -0
  78. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  79. flwr/common/grpc.py +0 -11
  80. flwr/common/inflatable_utils.py +1 -1
  81. flwr/common/logger.py +1 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/retry_invoker.py +30 -11
  84. flwr/common/telemetry.py +4 -0
  85. flwr/compat/server/app.py +2 -2
  86. flwr/proto/appio_pb2.py +25 -17
  87. flwr/proto/appio_pb2.pyi +46 -2
  88. flwr/proto/clientappio_pb2.py +3 -11
  89. flwr/proto/clientappio_pb2.pyi +0 -47
  90. flwr/proto/clientappio_pb2_grpc.py +19 -20
  91. flwr/proto/clientappio_pb2_grpc.pyi +10 -11
  92. flwr/proto/control_pb2.py +66 -0
  93. flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
  94. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
  95. flwr/proto/control_pb2_grpc.pyi +106 -0
  96. flwr/proto/serverappio_pb2.py +2 -2
  97. flwr/proto/serverappio_pb2_grpc.py +68 -0
  98. flwr/proto/serverappio_pb2_grpc.pyi +26 -0
  99. flwr/proto/simulationio_pb2.py +4 -11
  100. flwr/proto/simulationio_pb2.pyi +0 -58
  101. flwr/proto/simulationio_pb2_grpc.py +129 -27
  102. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  103. flwr/server/app.py +142 -152
  104. flwr/server/grid/grpc_grid.py +3 -0
  105. flwr/server/grid/inmemory_grid.py +1 -0
  106. flwr/server/serverapp/app.py +157 -146
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  110. flwr/server/superlink/linkstate/linkstate.py +2 -1
  111. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  112. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
  113. flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
  114. flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
  115. flwr/serverapp/__init__.py +12 -0
  116. flwr/serverapp/exception.py +38 -0
  117. flwr/serverapp/strategy/__init__.py +64 -0
  118. flwr/serverapp/strategy/bulyan.py +238 -0
  119. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  120. flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
  121. flwr/serverapp/strategy/fedadagrad.py +159 -0
  122. flwr/serverapp/strategy/fedadam.py +178 -0
  123. flwr/serverapp/strategy/fedavg.py +320 -0
  124. flwr/serverapp/strategy/fedavgm.py +198 -0
  125. flwr/serverapp/strategy/fedmedian.py +105 -0
  126. flwr/serverapp/strategy/fedopt.py +218 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +170 -0
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/result.py +105 -0
  136. flwr/serverapp/strategy/strategy.py +285 -0
  137. flwr/serverapp/strategy/strategy_utils.py +299 -0
  138. flwr/simulation/app.py +161 -164
  139. flwr/simulation/run_simulation.py +25 -30
  140. flwr/supercore/app_utils.py +58 -0
  141. flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
  142. flwr/supercore/cli/flower_superexec.py +166 -0
  143. flwr/supercore/constant.py +19 -0
  144. flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
  145. flwr/supercore/corestate/corestate.py +81 -0
  146. flwr/supercore/grpc_health/__init__.py +3 -0
  147. flwr/supercore/grpc_health/health_server.py +53 -0
  148. flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
  149. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  150. flwr/supercore/superexec/plugin/__init__.py +28 -0
  151. flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
  152. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  153. flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
  154. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  155. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  156. flwr/supercore/superexec/run_superexec.py +199 -0
  157. flwr/superlink/artifact_provider/__init__.py +22 -0
  158. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  159. flwr/superlink/servicer/__init__.py +15 -0
  160. flwr/superlink/servicer/control/__init__.py +22 -0
  161. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
  162. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
  163. flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
  164. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
  165. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
  166. flwr/supernode/cli/flower_supernode.py +3 -0
  167. flwr/supernode/cli/flwr_clientapp.py +18 -21
  168. flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
  169. flwr/supernode/nodestate/nodestate.py +3 -59
  170. flwr/supernode/runtime/run_clientapp.py +39 -102
  171. flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
  172. flwr/supernode/start_client_internal.py +35 -76
  173. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
  174. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
  175. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
  176. flwr/proto/exec_pb2.py +0 -62
  177. flwr/proto/exec_pb2_grpc.pyi +0 -93
  178. flwr/superexec/app.py +0 -45
  179. flwr/superexec/deployment.py +0 -191
  180. flwr/superexec/executor.py +0 -100
  181. flwr/superexec/simulation.py +0 -129
  182. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,56 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import numpy as np
4
+ import xgboost as xgb
5
+ from flwr.app import ArrayRecord, Context
6
+ from flwr.common.config import unflatten_dict
7
+ from flwr.serverapp import Grid, ServerApp
8
+ from flwr.serverapp.strategy import FedXgbBagging
9
+
10
+ from $import_name.task import replace_keys
11
+
12
# Create ServerApp
app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Run federated XGBoost bagging and save the aggregated model to disk.

    Reads ``num-server-rounds``, ``fraction-train``, ``fraction-evaluate`` and
    the ``params`` section from the run config, runs ``FedXgbBagging`` for the
    configured number of rounds, then loads the resulting model bytes into an
    ``xgb.Booster`` and writes it to ``final_model.json``.
    """
    run_config = context.run_config
    num_rounds = run_config["num-server-rounds"]
    fraction_train = run_config["fraction-train"]
    fraction_evaluate = run_config["fraction-evaluate"]
    # Unflatten the run config into nested dicts (the code reads the nested
    # "params" section) and replace "-" with "_" in every key.
    params = replace_keys(unflatten_dict(run_config))["params"]

    # The booster is created and trained on the clients, so the global model
    # starts out as an empty byte buffer. It is stored as the first (and only)
    # array of the ArrayRecord, reachable under key "0".
    initial_arrays = ArrayRecord([np.frombuffer(b"", dtype=np.uint8)])

    strategy = FedXgbBagging(
        fraction_train=fraction_train,
        fraction_evaluate=fraction_evaluate,
    )
    result = strategy.start(
        grid=grid,
        initial_arrays=initial_arrays,
        num_rounds=num_rounds,
    )

    # Rebuild a Booster from the aggregated model bytes and persist it
    bst = xgb.Booster(params=params)
    model_bytes = bytearray(result.arrays["0"].numpy().tobytes())
    bst.load_model(model_bytes)

    print("\nSaving final model to disk...")
    bst.save_model("final_model.json")
@@ -1,7 +1,6 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import warnings
4
- from collections import OrderedDict
5
4
 
6
5
  import torch
7
6
  import transformers
@@ -62,17 +61,24 @@ def load_data(partition_id: int, num_partitions: int, model_name: str):
62
61
  return trainloader, testloader
63
62
 
64
63
 
65
def train(net, trainloader, num_steps, device):
    """Fine-tune ``net`` for at most ``num_steps`` optimizer steps.

    Args:
        net: Model called as ``net(**batch)``; its output must expose ``.loss``.
        trainloader: Iterable of dict batches whose values support ``.to``.
        num_steps: Maximum number of optimizer steps. Values <= 0 train nothing.
            A single pass over ``trainloader`` is made, so fewer steps are taken
            if the loader runs out first.
        device: Device every batch tensor is moved to.

    Returns:
        Mean training loss over the steps actually taken, or ``0.0`` when no
        step was taken (empty loader or ``num_steps <= 0``).
    """
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    running_loss = 0.0
    step_cnt = 0
    if num_steps > 0:
        for batch in trainloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            running_loss += loss.item()
            step_cnt += 1
            # Break right after the last step so no extra batch is fetched
            if step_cnt >= num_steps:
                break
    # Avoid ZeroDivisionError when no step was taken
    avg_trainloss = running_loss / step_cnt if step_cnt else 0.0
    return avg_trainloss
76
82
 
77
83
 
78
84
  def test(net, testloader, device):
@@ -90,13 +96,3 @@ def test(net, testloader, device):
90
96
  loss /= len(testloader.dataset)
91
97
  accuracy = metric.compute()["accuracy"]
92
98
  return loss, accuracy
93
-
94
-
95
- def get_weights(net):
96
- return [val.cpu().numpy() for _, val in net.state_dict().items()]
97
-
98
-
99
- def set_weights(net, parameters):
100
- params_dict = zip(net.state_dict().keys(), parameters)
101
- state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
102
- net.load_state_dict(state_dict, strict=True)
@@ -31,7 +31,7 @@ def loss_fn(params, X, y):
31
31
  def train(params, grad_fn, X, y):
32
32
  loss = 1_000_000
33
33
  num_examples = X.shape[0]
34
- for epochs in range(50):
34
+ for _ in range(50):
35
35
  grads = grad_fn(params, X, y)
36
36
  params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
37
37
  loss = loss_fn(params, X, y)
@@ -4,4 +4,4 @@ import numpy as np
4
4
 
5
5
 
6
6
def get_dummy_model():
    """Return a dummy model: a list holding a single 1x1 array of ones."""
    dummy_layer = np.ones(shape=(1, 1))
    return [dummy_layer]
@@ -1,7 +1,5 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from collections import OrderedDict
4
-
5
3
  import torch
6
4
  import torch.nn as nn
7
5
  import torch.nn.functional as F
@@ -34,6 +32,14 @@ class Net(nn.Module):
34
32
 
35
33
  fds = None # Cache FederatedDataset
36
34
 
35
pytorch_transforms = Compose(
    [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)


def apply_transforms(batch):
    """Run every image in ``batch["img"]`` through ``pytorch_transforms``.

    The batch dict is updated in place and also returned, so it can be used
    with ``Dataset.with_transform``.
    """
    batch["img"] = list(map(pytorch_transforms, batch["img"]))
    return batch
42
+
37
43
 
38
44
  def load_data(partition_id: int, num_partitions: int):
39
45
  """Load partition CIFAR10 data."""
@@ -48,38 +54,29 @@ def load_data(partition_id: int, num_partitions: int):
48
54
  partition = fds.load_partition(partition_id)
49
55
  # Divide data on each node: 80% train, 20% test
50
56
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
51
- pytorch_transforms = Compose(
52
- [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
53
- )
54
-
55
- def apply_transforms(batch):
56
- """Apply transforms to the partition from FederatedDataset."""
57
- batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
58
- return batch
59
-
57
+ # Construct dataloaders
60
58
  partition_train_test = partition_train_test.with_transform(apply_transforms)
61
59
  trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
62
60
  testloader = DataLoader(partition_train_test["test"], batch_size=32)
63
61
  return trainloader, testloader
64
62
 
65
63
 
66
- def train(net, trainloader, epochs, device):
64
+ def train(net, trainloader, epochs, lr, device):
67
65
  """Train the model on the training set."""
68
66
  net.to(device) # move model to GPU if available
69
67
  criterion = torch.nn.CrossEntropyLoss().to(device)
70
- optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
68
+ optimizer = torch.optim.Adam(net.parameters(), lr=lr)
71
69
  net.train()
72
70
  running_loss = 0.0
73
71
  for _ in range(epochs):
74
72
  for batch in trainloader:
75
- images = batch["img"]
76
- labels = batch["label"]
73
+ images = batch["img"].to(device)
74
+ labels = batch["label"].to(device)
77
75
  optimizer.zero_grad()
78
- loss = criterion(net(images.to(device)), labels.to(device))
76
+ loss = criterion(net(images), labels)
79
77
  loss.backward()
80
78
  optimizer.step()
81
79
  running_loss += loss.item()
82
-
83
80
  avg_trainloss = running_loss / len(trainloader)
84
81
  return avg_trainloss
85
82
 
@@ -99,13 +96,3 @@ def test(net, testloader, device):
99
96
  accuracy = correct / len(testloader.dataset)
100
97
  loss = loss / len(testloader)
101
98
  return loss, accuracy
102
-
103
-
104
- def get_weights(net):
105
- return [val.cpu().numpy() for _, val in net.state_dict().items()]
106
-
107
-
108
- def set_weights(net, parameters):
109
- params_dict = zip(net.state_dict().keys(), parameters)
110
- state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
111
- net.load_state_dict(state_dict, strict=True)
@@ -0,0 +1,111 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from flwr_datasets import FederatedDataset
9
+ from flwr_datasets.partitioner import IidPartitioner
10
+ from torch.utils.data import DataLoader
11
+ from torchvision.transforms import Compose, Normalize, ToTensor
12
+
13
+
14
class Net(nn.Module):
    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

    def __init__(self):
        super().__init__()
        # Registration order defines the state_dict key order; keep it stable.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Two conv -> ReLU -> max-pool stages
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        # Two hidden fully connected layers with ReLU, then 10 output logits
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)
33
+
34
+
35
fds = None  # Cache FederatedDataset


def load_data(partition_id: int, num_partitions: int):
    """Load partition CIFAR10 data.

    The ``FederatedDataset`` is built on the first call only and reused
    afterwards. The selected partition is split 80/20 (seed 42) into train and
    test sets, which are returned as DataLoaders with batch size 32 (train
    shuffled).
    """
    global fds
    if fds is None:
        fds = FederatedDataset(
            dataset="uoft-cs/cifar10",
            partitioners={"train": IidPartitioner(num_partitions=num_partitions)},
        )
    splits = fds.load_partition(partition_id).train_test_split(
        test_size=0.2, seed=42
    )
    to_normalized_tensor = Compose(
        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        """Apply transforms to the partition from FederatedDataset."""
        batch["img"] = [to_normalized_tensor(img) for img in batch["img"]]
        return batch

    splits = splits.with_transform(apply_transforms)
    return (
        DataLoader(splits["train"], batch_size=32, shuffle=True),
        DataLoader(splits["test"], batch_size=32),
    )
64
+
65
+
66
def train(net, trainloader, epochs, device, lr=0.01):
    """Train the model on the training set.

    Args:
        net: Classifier mapping ``batch["img"]`` to logits.
        trainloader: Iterable of dict batches with ``"img"`` and ``"label"``
            tensors.
        epochs: Number of passes over ``trainloader``.
        device: Device the model and batches are moved to.
        lr: Learning rate for the Adam optimizer (default ``0.01``, the
            previously hard-coded value).

    Returns:
        Mean loss per optimizer step across all epochs, or ``0.0`` when no
        step was taken (empty loader or ``epochs <= 0``).
    """
    net.to(device)  # move model to GPU if available
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    net.train()
    running_loss = 0.0
    num_batches = 0
    for _ in range(epochs):
        for batch in trainloader:
            images = batch["img"].to(device)
            labels = batch["label"].to(device)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
    # Divide by every step taken: dividing by len(trainloader) alone would
    # inflate the result by a factor of `epochs`, and crash on an empty loader.
    avg_trainloss = running_loss / num_batches if num_batches else 0.0
    return avg_trainloss
85
+
86
+
87
+ def test(net, testloader, device):
88
+ """Validate the model on the test set."""
89
+ net.to(device)
90
+ criterion = torch.nn.CrossEntropyLoss()
91
+ correct, loss = 0, 0.0
92
+ with torch.no_grad():
93
+ for batch in testloader:
94
+ images = batch["img"].to(device)
95
+ labels = batch["label"].to(device)
96
+ outputs = net(images)
97
+ loss += criterion(outputs, labels).item()
98
+ correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
99
+ accuracy = correct / len(testloader.dataset)
100
+ loss = loss / len(testloader)
101
+ return loss, accuracy
102
+
103
+
104
def get_weights(net):
    """Return the values of ``net.state_dict()`` as numpy arrays, in order."""
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]


def set_weights(net, parameters):
    """Load ``parameters`` (ordered like ``get_weights`` output) into ``net``.

    Arrays are paired positionally with the state_dict keys and loaded with
    ``strict=True``.
    """
    keys = net.state_dict().keys()
    state_dict = OrderedDict(
        (name, torch.tensor(array)) for name, array in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)
@@ -3,10 +3,9 @@
3
3
  import os
4
4
 
5
5
  import keras
6
- from keras import layers
7
6
  from flwr_datasets import FederatedDataset
8
7
  from flwr_datasets.partitioner import IidPartitioner
9
-
8
+ from keras import layers
10
9
 
11
10
  # Make TensorFlow log less verbose
12
11
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -0,0 +1,67 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import xgboost as xgb
4
+ from flwr_datasets import FederatedDataset
5
+ from flwr_datasets.partitioner import IidPartitioner
6
+
7
+
8
def train_test_split(partition, test_fraction, seed):
    """Split ``partition`` into train and validation subsets.

    Returns:
        ``(train_subset, test_subset, num_train, num_test)``, where the
        counts are the ``len`` of each subset.
    """
    splits = partition.train_test_split(test_size=test_fraction, seed=seed)
    train_part, test_part = splits["train"], splits["test"]
    return train_part, test_part, len(train_part), len(test_part)
18
+
19
+
20
def transform_dataset_to_dmatrix(data):
    """Build an ``xgb.DMatrix`` from the ``"inputs"`` and ``"label"`` columns of ``data``."""
    return xgb.DMatrix(data["inputs"], label=data["label"])
26
+
27
+
28
fds = None  # Cache FederatedDataset


def load_data(partition_id, num_clients):
    """Load partition HIGGS data.

    The ``FederatedDataset`` is built on the first call only and reused
    afterwards. The partition is formatted as numpy, split 80/20 (seed 42),
    and both splits are converted to ``xgb.DMatrix``.

    Returns:
        ``(train_dmatrix, valid_dmatrix, num_train, num_val)``.
    """
    global fds
    if fds is None:
        fds = FederatedDataset(
            dataset="jxie/higgs",
            partitioners={"train": IidPartitioner(num_partitions=num_clients)},
        )

    partition = fds.load_partition(partition_id, split="train")
    partition.set_format("numpy")

    train_data, valid_data, num_train, num_val = train_test_split(
        partition, test_fraction=0.2, seed=42
    )
    return (
        transform_dataset_to_dmatrix(train_data),
        transform_dataset_to_dmatrix(valid_data),
        num_train,
        num_val,
    )
56
+
57
+
58
def replace_keys(input_dict, match="-", target="_"):
    """Return a copy of ``input_dict`` with ``match`` replaced by ``target`` in keys.

    Nested dicts are processed recursively. Non-dict values are kept as-is
    (not copied). If two keys collide after replacement, the later one wins.
    """
    return {
        key.replace(match, target): (
            replace_keys(value, match, target) if isinstance(value, dict) else value
        )
        for key, value in input_dict.items()
    }
@@ -14,10 +14,10 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.20.0",
17
+ "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
- "torch==2.7.1",
20
- "torchvision==0.22.1",
19
+ "torch==2.8.0",
20
+ "torchvision==0.23.0",
21
21
  ]
22
22
 
23
23
  [tool.hatch.metadata]
@@ -132,7 +132,7 @@ clientapp = "$import_name.client_app:app"
132
132
  # Custom config values accessible via `context.run_config`
133
133
  [tool.flwr.app.config]
134
134
  num-server-rounds = 3
135
- fraction-fit = 0.5
135
+ fraction-train = 0.5
136
136
  local-epochs = 1
137
137
 
138
138
  # Default federation to use when running the app
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.20.0",
17
+ "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets>=0.5.0",
19
19
  "torch==2.4.0",
20
20
  "trl==0.8.1",
@@ -61,7 +61,7 @@ train.training-arguments.save-steps = 1000
61
61
  train.training-arguments.save-total-limit = 10
62
62
  train.training-arguments.gradient-checkpointing = true
63
63
  train.training-arguments.lr-scheduler-type = "constant"
64
- strategy.fraction-fit = $fraction_fit
64
+ strategy.fraction-train = $fraction_train
65
65
  strategy.fraction-evaluate = 0.0
66
66
  num-server-rounds = 200
67
67
 
@@ -14,9 +14,9 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.20.0",
17
+ "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets>=0.5.0",
19
- "torch==2.7.1",
19
+ "torch>=2.7.1",
20
20
  "transformers>=4.30.0,<5.0",
21
21
  "evaluate>=0.4.0,<1.0",
22
22
  "datasets>=2.0.0, <3.0",
@@ -38,8 +38,8 @@ clientapp = "$import_name.client_app:app"
38
38
  # Custom config values accessible via `context.run_config`
39
39
  [tool.flwr.app.config]
40
40
  num-server-rounds = 3
41
- fraction-fit = 0.5
42
- local-epochs = 1
41
+ fraction-train = 0.5
42
+ local-steps = 5
43
43
  model-name = "prajjwal1/bert-tiny" # Set a larger model if you have access to more GPU resources
44
44
  num-labels = 2
45
45
 
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.20.0",
17
+ "flwr[simulation]>=1.22.0",
18
18
  "jax==0.4.30",
19
19
  "jaxlib==0.4.30",
20
20
  "scikit-learn==1.6.1",
@@ -14,9 +14,9 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.20.0",
17
+ "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
- "mlx==0.26.5",
19
+ "mlx==0.29.0",
20
20
  ]
21
21
 
22
22
  [tool.hatch.build.targets.wheel]
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.20.0",
17
+ "flwr[simulation]>=1.22.0",
18
18
  "numpy>=2.0.2",
19
19
  ]
20
20
 
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.20.0",
17
+ "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "torch==2.7.1",
20
20
  "torchvision==0.22.1",
@@ -27,7 +27,6 @@ packages = ["."]
27
27
  publisher = "$username"
28
28
 
29
29
  # Point to your ServerApp and ClientApp objects
30
- # Format: "<module>:<object>"
31
30
  [tool.flwr.app.components]
32
31
  serverapp = "$import_name.server_app:app"
33
32
  clientapp = "$import_name.client_app:app"
@@ -35,8 +34,9 @@ clientapp = "$import_name.client_app:app"
35
34
  # Custom config values accessible via `context.run_config`
36
35
  [tool.flwr.app.config]
37
36
  num-server-rounds = 3
38
- fraction-fit = 0.5
37
+ fraction-train = 0.5
39
38
  local-epochs = 1
39
+ lr = 0.01
40
40
 
41
41
  # Default federation to use when running the app
42
42
  [tool.flwr.federations]
@@ -0,0 +1,53 @@
1
+ # =====================================================================
2
+ # For a full TOML configuration guide, check the Flower docs:
3
+ # https://flower.ai/docs/framework/how-to-configure-pyproject-toml.html
4
+ # =====================================================================
5
+
6
+ [build-system]
7
+ requires = ["hatchling"]
8
+ build-backend = "hatchling.build"
9
+
10
+ [project]
11
+ name = "$package_name"
12
+ version = "1.0.0"
13
+ description = ""
14
+ license = "Apache-2.0"
15
+ # Dependencies for your Flower App
16
+ dependencies = [
17
+ "flwr[simulation]>=1.22.0",
18
+ "flwr-datasets[vision]>=0.5.0",
19
+ "torch==2.7.1",
20
+ "torchvision==0.22.1",
21
+ ]
22
+
23
+ [tool.hatch.build.targets.wheel]
24
+ packages = ["."]
25
+
26
+ [tool.flwr.app]
27
+ publisher = "$username"
28
+
29
+ # Point to your ServerApp and ClientApp objects
30
+ # Format: "<module>:<object>"
31
+ [tool.flwr.app.components]
32
+ serverapp = "$import_name.server_app:app"
33
+ clientapp = "$import_name.client_app:app"
34
+
35
+ # Custom config values accessible via `context.run_config`
36
+ [tool.flwr.app.config]
37
+ num-server-rounds = 3
38
+ fraction-fit = 0.5
39
+ local-epochs = 1
40
+
41
+ # Default federation to use when running the app
42
+ [tool.flwr.federations]
43
+ default = "local-simulation"
44
+
45
+ # Local simulation federation with 10 virtual SuperNodes
46
+ [tool.flwr.federations.local-simulation]
47
+ options.num-supernodes = 10
48
+
49
+ # Remote federation example for use with SuperLink
50
+ [tool.flwr.federations.remote-federation]
51
+ address = "<SUPERLINK-ADDRESS>:<PORT>"
52
+ insecure = true # Remove this line to enable TLS
53
+ # root-certificates = "<PATH/TO/ca.crt>" # For TLS setup
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.20.0",
17
+ "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "scikit-learn>=1.6.1",
20
20
  ]
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.20.0",
17
+ "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "tensorflow>=2.11.1,<2.18.0",
20
20
  ]
@@ -0,0 +1,61 @@
1
+ # =====================================================================
2
+ # For a full TOML configuration guide, check the Flower docs:
3
+ # https://flower.ai/docs/framework/how-to-configure-pyproject-toml.html
4
+ # =====================================================================
5
+
6
+ [build-system]
7
+ requires = ["hatchling"]
8
+ build-backend = "hatchling.build"
9
+
10
+ [project]
11
+ name = "$package_name"
12
+ version = "1.0.0"
13
+ description = ""
14
+ license = "Apache-2.0"
15
+ # Dependencies for your Flower App
16
+ dependencies = [
17
+ "flwr[simulation]>=1.22.0",
18
+ "flwr-datasets>=0.5.0",
19
+ "xgboost>=2.0.0",
20
+ ]
21
+
22
+ [tool.hatch.build.targets.wheel]
23
+ packages = ["."]
24
+
25
+ [tool.flwr.app]
26
+ publisher = "$username"
27
+
28
+ [tool.flwr.app.components]
29
+ serverapp = "$import_name.server_app:app"
30
+ clientapp = "$import_name.client_app:app"
31
+
32
+ # Custom config values accessible via `context.run_config`
33
+ [tool.flwr.app.config]
34
+ num-server-rounds = 3
35
+ fraction-train = 0.1
36
+ fraction-evaluate = 0.1
37
+ local-epochs = 1
38
+
39
+ # XGBoost parameters
40
+ params.objective = "binary:logistic"
41
+ params.eta = 0.1 # Learning rate
42
+ params.max-depth = 8
43
+ params.eval-metric = "auc"
44
+ params.nthread = 16
45
+ params.num-parallel-tree = 1
46
+ params.subsample = 1
47
+ params.tree-method = "hist"
48
+
49
+ # Default federation to use when running the app
50
+ [tool.flwr.federations]
51
+ default = "local-simulation"
52
+
53
+ # Local simulation federation with 10 virtual SuperNodes
54
+ [tool.flwr.federations.local-simulation]
55
+ options.num-supernodes = 10
56
+
57
+ # Remote federation example for use with SuperLink
58
+ [tool.flwr.federations.remote-federation]
59
+ address = "<SUPERLINK-ADDRESS>:<PORT>"
60
+ insecure = true # Remove this line to enable TLS
61
+ # root-certificates = "<PATH/TO/ca.crt>" # For TLS setup