flwr 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. flwr/cli/app.py +17 -1
  2. flwr/cli/auth_plugin/__init__.py +15 -6
  3. flwr/cli/auth_plugin/auth_plugin.py +95 -0
  4. flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
  6. flwr/cli/build.py +118 -47
  7. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
  8. flwr/cli/log.py +2 -2
  9. flwr/cli/login/login.py +34 -23
  10. flwr/cli/ls.py +13 -9
  11. flwr/cli/new/new.py +196 -42
  12. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  13. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  14. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  15. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  16. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  17. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  18. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  19. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  20. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  21. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  22. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  24. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  25. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  26. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  27. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  28. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  29. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  30. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  31. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  32. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  33. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  34. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  35. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  36. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  37. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  38. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  39. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  40. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  41. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  42. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  43. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  44. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  45. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  46. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  47. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  49. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  50. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  52. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
  53. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  54. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  55. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  56. flwr/cli/pull.py +100 -0
  57. flwr/cli/run/run.py +11 -7
  58. flwr/cli/stop.py +2 -2
  59. flwr/cli/supernode/__init__.py +25 -0
  60. flwr/cli/supernode/ls.py +260 -0
  61. flwr/cli/supernode/register.py +185 -0
  62. flwr/cli/supernode/unregister.py +138 -0
  63. flwr/cli/utils.py +109 -69
  64. flwr/client/__init__.py +2 -1
  65. flwr/client/grpc_adapter_client/connection.py +6 -8
  66. flwr/client/grpc_rere_client/connection.py +59 -31
  67. flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
  68. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
  69. flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
  70. flwr/client/rest_client/connection.py +82 -37
  71. flwr/clientapp/__init__.py +1 -2
  72. flwr/clientapp/mod/__init__.py +4 -1
  73. flwr/clientapp/mod/centraldp_mods.py +156 -40
  74. flwr/clientapp/mod/localdp_mod.py +169 -0
  75. flwr/clientapp/typing.py +22 -0
  76. flwr/{client/clientapp → clientapp}/utils.py +1 -1
  77. flwr/common/constant.py +56 -13
  78. flwr/common/exit/exit_code.py +24 -10
  79. flwr/common/inflatable_utils.py +10 -10
  80. flwr/common/record/array.py +3 -3
  81. flwr/common/record/arrayrecord.py +10 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  84. flwr/common/serde.py +4 -2
  85. flwr/common/typing.py +7 -6
  86. flwr/compat/client/app.py +1 -1
  87. flwr/compat/client/grpc_client/connection.py +2 -2
  88. flwr/proto/control_pb2.py +48 -31
  89. flwr/proto/control_pb2.pyi +95 -5
  90. flwr/proto/control_pb2_grpc.py +136 -0
  91. flwr/proto/control_pb2_grpc.pyi +52 -0
  92. flwr/proto/fab_pb2.py +11 -7
  93. flwr/proto/fab_pb2.pyi +21 -1
  94. flwr/proto/fleet_pb2.py +31 -23
  95. flwr/proto/fleet_pb2.pyi +63 -23
  96. flwr/proto/fleet_pb2_grpc.py +98 -28
  97. flwr/proto/fleet_pb2_grpc.pyi +45 -13
  98. flwr/proto/node_pb2.py +3 -1
  99. flwr/proto/node_pb2.pyi +48 -0
  100. flwr/server/app.py +152 -114
  101. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
  102. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
  103. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
  104. flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
  105. flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
  106. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +18 -5
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
  110. flwr/server/superlink/linkstate/linkstate.py +107 -24
  111. flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
  112. flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
  113. flwr/server/superlink/linkstate/utils.py +3 -54
  114. flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
  115. flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
  116. flwr/server/utils/validator.py +2 -3
  117. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
  118. flwr/serverapp/strategy/__init__.py +26 -0
  119. flwr/serverapp/strategy/bulyan.py +238 -0
  120. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  121. flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
  122. flwr/serverapp/strategy/fedadagrad.py +0 -3
  123. flwr/serverapp/strategy/fedadam.py +0 -3
  124. flwr/serverapp/strategy/fedavg.py +89 -64
  125. flwr/serverapp/strategy/fedavgm.py +198 -0
  126. flwr/serverapp/strategy/fedmedian.py +105 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +0 -3
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/strategy_utils.py +48 -0
  136. flwr/simulation/app.py +1 -1
  137. flwr/simulation/ray_transport/ray_actor.py +1 -1
  138. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  139. flwr/simulation/run_simulation.py +28 -32
  140. flwr/supercore/cli/flower_superexec.py +26 -1
  141. flwr/supercore/constant.py +41 -0
  142. flwr/supercore/object_store/in_memory_object_store.py +0 -4
  143. flwr/supercore/object_store/object_store_factory.py +26 -6
  144. flwr/supercore/object_store/sqlite_object_store.py +252 -0
  145. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  146. flwr/supercore/primitives/asymmetric.py +117 -0
  147. flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
  148. flwr/supercore/sqlite_mixin.py +156 -0
  149. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  150. flwr/supercore/superexec/run_superexec.py +16 -2
  151. flwr/supercore/utils.py +20 -0
  152. flwr/superlink/artifact_provider/__init__.py +22 -0
  153. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  154. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  155. flwr/superlink/auth_plugin/auth_plugin.py +91 -0
  156. flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
  157. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
  158. flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
  159. flwr/superlink/servicer/control/control_grpc.py +16 -11
  160. flwr/superlink/servicer/control/control_servicer.py +207 -58
  161. flwr/supernode/cli/flower_supernode.py +19 -26
  162. flwr/supernode/runtime/run_clientapp.py +2 -2
  163. flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
  164. flwr/supernode/start_client_internal.py +17 -9
  165. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
  166. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
  167. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  168. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  169. flwr/common/auth_plugin/auth_plugin.py +0 -149
  170. flwr/serverapp/dp_fixed_clipping.py +0 -352
  171. flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
  172. /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  173. /flwr/{client → clientapp}/client_app.py +0 -0
  174. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
  175. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
@@ -1,36 +1,44 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import get_model, get_model_params, set_initial_params
3
+ import joblib
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
7
7
 
8
+ from $import_name.task import get_model, get_model_params, set_initial_params, set_model_params
8
9
 
9
- def server_fn(context: Context):
10
- # Read from config
11
- num_rounds = context.run_config["num-server-rounds"]
10
+ # Create ServerApp
11
+ app = ServerApp()
12
+
13
+
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
17
+
18
+ # Read run config
19
+ num_rounds: int = context.run_config["num-server-rounds"]
12
20
 
13
21
  # Create LogisticRegression Model
14
22
  penalty = context.run_config["penalty"]
15
23
  local_epochs = context.run_config["local-epochs"]
16
24
  model = get_model(penalty, local_epochs)
17
-
18
25
  # Setting initial parameters, akin to model.compile for keras models
19
26
  set_initial_params(model)
27
+ # Construct ArrayRecord representation
28
+ arrays = ArrayRecord(get_model_params(model))
20
29
 
21
- initial_parameters = ndarrays_to_parameters(get_model_params(model))
30
+ # Initialize FedAvg strategy
31
+ strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
22
32
 
23
- # Define strategy
24
- strategy = FedAvg(
25
- fraction_fit=1.0,
26
- fraction_evaluate=1.0,
27
- min_available_clients=2,
28
- initial_parameters=initial_parameters,
33
+ # Start strategy, run FedAvg for `num_rounds`
34
+ result = strategy.start(
35
+ grid=grid,
36
+ initial_arrays=arrays,
37
+ num_rounds=num_rounds,
29
38
  )
30
- config = ServerConfig(num_rounds=num_rounds)
31
-
32
- return ServerAppComponents(strategy=strategy, config=config)
33
39
 
34
-
35
- # Create ServerApp
36
- app = ServerApp(server_fn=server_fn)
40
+ # Save final model parameters
41
+ print("\nSaving final model to disk...")
42
+ ndarrays = result.arrays.to_numpy_ndarrays()
43
+ set_model_params(model, ndarrays)
44
+ joblib.dump(model, "logreg_model.pkl")
@@ -1,29 +1,38 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ from flwr.app import ArrayRecord, Context
4
+ from flwr.serverapp import Grid, ServerApp
5
+ from flwr.serverapp.strategy import FedAvg
6
6
 
7
7
  from $import_name.task import load_model
8
8
 
9
+ # Create ServerApp
10
+ app = ServerApp()
9
11
 
10
- def server_fn(context: Context):
11
- # Read from config
12
- num_rounds = context.run_config["num-server-rounds"]
13
12
 
14
- # Get parameters to initialize global model
15
- parameters = ndarrays_to_parameters(load_model().get_weights())
13
+ @app.main()
14
+ def main(grid: Grid, context: Context) -> None:
15
+ """Main entry point for the ServerApp."""
16
16
 
17
- # Define strategy
18
- strategy = strategy = FedAvg(
19
- fraction_fit=1.0,
20
- fraction_evaluate=1.0,
21
- min_available_clients=2,
22
- initial_parameters=parameters,
23
- )
24
- config = ServerConfig(num_rounds=num_rounds)
17
+ # Read run config
18
+ num_rounds: int = context.run_config["num-server-rounds"]
25
19
 
26
- return ServerAppComponents(strategy=strategy, config=config)
20
+ # Load global model
21
+ model = load_model()
22
+ arrays = ArrayRecord(model.get_weights())
27
23
 
28
- # Create ServerApp
29
- app = ServerApp(server_fn=server_fn)
24
+ # Initialize FedAvg strategy
25
+ strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
26
+
27
+ # Start strategy, run FedAvg for `num_rounds`
28
+ result = strategy.start(
29
+ grid=grid,
30
+ initial_arrays=arrays,
31
+ num_rounds=num_rounds,
32
+ )
33
+
34
+ # Save final model to disk
35
+ print("\nSaving final model to disk...")
36
+ ndarrays = result.arrays.to_numpy_ndarrays()
37
+ model.set_weights(ndarrays)
38
+ model.save("final_model.keras")
@@ -0,0 +1,56 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import numpy as np
4
+ import xgboost as xgb
5
+ from flwr.app import ArrayRecord, Context
6
+ from flwr.common.config import unflatten_dict
7
+ from flwr.serverapp import Grid, ServerApp
8
+ from flwr.serverapp.strategy import FedXgbBagging
9
+
10
+ from $import_name.task import replace_keys
11
+
12
+ # Create ServerApp
13
+ app = ServerApp()
14
+
15
+
16
+ @app.main()
17
+ def main(grid: Grid, context: Context) -> None:
18
+ # Read run config
19
+ num_rounds = context.run_config["num-server-rounds"]
20
+ fraction_train = context.run_config["fraction-train"]
21
+ fraction_evaluate = context.run_config["fraction-evaluate"]
22
+ # Flatted config dict and replace "-" with "_"
23
+ cfg = replace_keys(unflatten_dict(context.run_config))
24
+ params = cfg["params"]
25
+
26
+ # Init global model
27
+ # Init with an empty object; the XGBooster will be created
28
+ # and trained on the client side.
29
+ global_model = b""
30
+ # Note: we store the model as the first item in a list into ArrayRecord,
31
+ # which can be accessed using index ["0"].
32
+ arrays = ArrayRecord([np.frombuffer(global_model, dtype=np.uint8)])
33
+
34
+ # Initialize FedXgbBagging strategy
35
+ strategy = FedXgbBagging(
36
+ fraction_train=fraction_train,
37
+ fraction_evaluate=fraction_evaluate,
38
+ )
39
+
40
+ # Start strategy, run FedXgbBagging for `num_rounds`
41
+ result = strategy.start(
42
+ grid=grid,
43
+ initial_arrays=arrays,
44
+ num_rounds=num_rounds,
45
+ )
46
+
47
+ # Save final model to disk
48
+ bst = xgb.Booster(params=params)
49
+ global_model = bytearray(result.arrays["0"].numpy().tobytes())
50
+
51
+ # Load global model into booster
52
+ bst.load_model(global_model)
53
+
54
+ # Save model
55
+ print("\nSaving final model to disk...")
56
+ bst.save_model("final_model.json")
@@ -1,7 +1,6 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import warnings
4
- from collections import OrderedDict
5
4
 
6
5
  import torch
7
6
  import transformers
@@ -62,17 +61,24 @@ def load_data(partition_id: int, num_partitions: int, model_name: str):
62
61
  return trainloader, testloader
63
62
 
64
63
 
65
- def train(net, trainloader, epochs, device):
64
+ def train(net, trainloader, num_steps, device):
66
65
  optimizer = AdamW(net.parameters(), lr=5e-5)
67
66
  net.train()
68
- for _ in range(epochs):
69
- for batch in trainloader:
70
- batch = {k: v.to(device) for k, v in batch.items()}
71
- outputs = net(**batch)
72
- loss = outputs.loss
73
- loss.backward()
74
- optimizer.step()
75
- optimizer.zero_grad()
67
+ running_loss = 0.0
68
+ step_cnt = 0
69
+ for batch in trainloader:
70
+ batch = {k: v.to(device) for k, v in batch.items()}
71
+ outputs = net(**batch)
72
+ loss = outputs.loss
73
+ loss.backward()
74
+ optimizer.step()
75
+ optimizer.zero_grad()
76
+ running_loss += loss.item()
77
+ step_cnt += 1
78
+ if step_cnt >= num_steps:
79
+ break
80
+ avg_trainloss = running_loss / step_cnt
81
+ return avg_trainloss
76
82
 
77
83
 
78
84
  def test(net, testloader, device):
@@ -90,13 +96,3 @@ def test(net, testloader, device):
90
96
  loss /= len(testloader.dataset)
91
97
  accuracy = metric.compute()["accuracy"]
92
98
  return loss, accuracy
93
-
94
-
95
- def get_weights(net):
96
- return [val.cpu().numpy() for _, val in net.state_dict().items()]
97
-
98
-
99
- def set_weights(net, parameters):
100
- params_dict = zip(net.state_dict().keys(), parameters)
101
- state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
102
- net.load_state_dict(state_dict, strict=True)
@@ -31,7 +31,7 @@ def loss_fn(params, X, y):
31
31
  def train(params, grad_fn, X, y):
32
32
  loss = 1_000_000
33
33
  num_examples = X.shape[0]
34
- for epochs in range(50):
34
+ for _ in range(50):
35
35
  grads = grad_fn(params, X, y)
36
36
  params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
37
37
  loss = loss_fn(params, X, y)
@@ -4,4 +4,4 @@ import numpy as np
4
4
 
5
5
 
6
6
  def get_dummy_model():
7
- return np.ones((1, 1))
7
+ return [np.ones((1, 1))]
@@ -1,7 +1,5 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from collections import OrderedDict
4
-
5
3
  import torch
6
4
  import torch.nn as nn
7
5
  import torch.nn.functional as F
@@ -34,6 +32,14 @@ class Net(nn.Module):
34
32
 
35
33
  fds = None # Cache FederatedDataset
36
34
 
35
+ pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
36
+
37
+
38
+ def apply_transforms(batch):
39
+ """Apply transforms to the partition from FederatedDataset."""
40
+ batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
41
+ return batch
42
+
37
43
 
38
44
  def load_data(partition_id: int, num_partitions: int):
39
45
  """Load partition CIFAR10 data."""
@@ -48,38 +54,29 @@ def load_data(partition_id: int, num_partitions: int):
48
54
  partition = fds.load_partition(partition_id)
49
55
  # Divide data on each node: 80% train, 20% test
50
56
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
51
- pytorch_transforms = Compose(
52
- [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
53
- )
54
-
55
- def apply_transforms(batch):
56
- """Apply transforms to the partition from FederatedDataset."""
57
- batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
58
- return batch
59
-
57
+ # Construct dataloaders
60
58
  partition_train_test = partition_train_test.with_transform(apply_transforms)
61
59
  trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
62
60
  testloader = DataLoader(partition_train_test["test"], batch_size=32)
63
61
  return trainloader, testloader
64
62
 
65
63
 
66
- def train(net, trainloader, epochs, device):
64
+ def train(net, trainloader, epochs, lr, device):
67
65
  """Train the model on the training set."""
68
66
  net.to(device) # move model to GPU if available
69
67
  criterion = torch.nn.CrossEntropyLoss().to(device)
70
- optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
68
+ optimizer = torch.optim.Adam(net.parameters(), lr=lr)
71
69
  net.train()
72
70
  running_loss = 0.0
73
71
  for _ in range(epochs):
74
72
  for batch in trainloader:
75
- images = batch["img"]
76
- labels = batch["label"]
73
+ images = batch["img"].to(device)
74
+ labels = batch["label"].to(device)
77
75
  optimizer.zero_grad()
78
- loss = criterion(net(images.to(device)), labels.to(device))
76
+ loss = criterion(net(images), labels)
79
77
  loss.backward()
80
78
  optimizer.step()
81
79
  running_loss += loss.item()
82
-
83
80
  avg_trainloss = running_loss / len(trainloader)
84
81
  return avg_trainloss
85
82
 
@@ -99,13 +96,3 @@ def test(net, testloader, device):
99
96
  accuracy = correct / len(testloader.dataset)
100
97
  loss = loss / len(testloader)
101
98
  return loss, accuracy
102
-
103
-
104
- def get_weights(net):
105
- return [val.cpu().numpy() for _, val in net.state_dict().items()]
106
-
107
-
108
- def set_weights(net, parameters):
109
- params_dict = zip(net.state_dict().keys(), parameters)
110
- state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
111
- net.load_state_dict(state_dict, strict=True)
@@ -1,5 +1,7 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
+ from collections import OrderedDict
4
+
3
5
  import torch
4
6
  import torch.nn as nn
5
7
  import torch.nn.functional as F
@@ -32,14 +34,6 @@ class Net(nn.Module):
32
34
 
33
35
  fds = None # Cache FederatedDataset
34
36
 
35
- pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
36
-
37
-
38
- def apply_transforms(batch):
39
- """Apply transforms to the partition from FederatedDataset."""
40
- batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
41
- return batch
42
-
43
37
 
44
38
  def load_data(partition_id: int, num_partitions: int):
45
39
  """Load partition CIFAR10 data."""
@@ -54,29 +48,38 @@ def load_data(partition_id: int, num_partitions: int):
54
48
  partition = fds.load_partition(partition_id)
55
49
  # Divide data on each node: 80% train, 20% test
56
50
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
57
- # Construct dataloaders
51
+ pytorch_transforms = Compose(
52
+ [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
53
+ )
54
+
55
+ def apply_transforms(batch):
56
+ """Apply transforms to the partition from FederatedDataset."""
57
+ batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
58
+ return batch
59
+
58
60
  partition_train_test = partition_train_test.with_transform(apply_transforms)
59
61
  trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
60
62
  testloader = DataLoader(partition_train_test["test"], batch_size=32)
61
63
  return trainloader, testloader
62
64
 
63
65
 
64
- def train(net, trainloader, epochs, lr, device):
66
+ def train(net, trainloader, epochs, device):
65
67
  """Train the model on the training set."""
66
68
  net.to(device) # move model to GPU if available
67
69
  criterion = torch.nn.CrossEntropyLoss().to(device)
68
- optimizer = torch.optim.Adam(net.parameters(), lr=lr)
70
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
69
71
  net.train()
70
72
  running_loss = 0.0
71
73
  for _ in range(epochs):
72
74
  for batch in trainloader:
73
- images = batch["img"].to(device)
74
- labels = batch["label"].to(device)
75
+ images = batch["img"]
76
+ labels = batch["label"]
75
77
  optimizer.zero_grad()
76
- loss = criterion(net(images), labels)
78
+ loss = criterion(net(images.to(device)), labels.to(device))
77
79
  loss.backward()
78
80
  optimizer.step()
79
81
  running_loss += loss.item()
82
+
80
83
  avg_trainloss = running_loss / len(trainloader)
81
84
  return avg_trainloss
82
85
 
@@ -96,3 +99,13 @@ def test(net, testloader, device):
96
99
  accuracy = correct / len(testloader.dataset)
97
100
  loss = loss / len(testloader)
98
101
  return loss, accuracy
102
+
103
+
104
+ def get_weights(net):
105
+ return [val.cpu().numpy() for _, val in net.state_dict().items()]
106
+
107
+
108
+ def set_weights(net, parameters):
109
+ params_dict = zip(net.state_dict().keys(), parameters)
110
+ state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
111
+ net.load_state_dict(state_dict, strict=True)
@@ -3,10 +3,9 @@
3
3
  import os
4
4
 
5
5
  import keras
6
- from keras import layers
7
6
  from flwr_datasets import FederatedDataset
8
7
  from flwr_datasets.partitioner import IidPartitioner
9
-
8
+ from keras import layers
10
9
 
11
10
  # Make TensorFlow log less verbose
12
11
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -0,0 +1,67 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import xgboost as xgb
4
+ from flwr_datasets import FederatedDataset
5
+ from flwr_datasets.partitioner import IidPartitioner
6
+
7
+
8
+ def train_test_split(partition, test_fraction, seed):
9
+ """Split the data into train and validation set given split rate."""
10
+ train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
11
+ partition_train = train_test["train"]
12
+ partition_test = train_test["test"]
13
+
14
+ num_train = len(partition_train)
15
+ num_test = len(partition_test)
16
+
17
+ return partition_train, partition_test, num_train, num_test
18
+
19
+
20
+ def transform_dataset_to_dmatrix(data):
21
+ """Transform dataset to DMatrix format for xgboost."""
22
+ x = data["inputs"]
23
+ y = data["label"]
24
+ new_data = xgb.DMatrix(x, label=y)
25
+ return new_data
26
+
27
+
28
+ fds = None # Cache FederatedDataset
29
+
30
+
31
+ def load_data(partition_id, num_clients):
32
+ """Load partition HIGGS data."""
33
+ # Only initialize `FederatedDataset` once
34
+ global fds
35
+ if fds is None:
36
+ partitioner = IidPartitioner(num_partitions=num_clients)
37
+ fds = FederatedDataset(
38
+ dataset="jxie/higgs",
39
+ partitioners={"train": partitioner},
40
+ )
41
+
42
+ # Load the partition for this `partition_id`
43
+ partition = fds.load_partition(partition_id, split="train")
44
+ partition.set_format("numpy")
45
+
46
+ # Train/test splitting
47
+ train_data, valid_data, num_train, num_val = train_test_split(
48
+ partition, test_fraction=0.2, seed=42
49
+ )
50
+
51
+ # Reformat data to DMatrix for xgboost
52
+ train_dmatrix = transform_dataset_to_dmatrix(train_data)
53
+ valid_dmatrix = transform_dataset_to_dmatrix(valid_data)
54
+
55
+ return train_dmatrix, valid_dmatrix, num_train, num_val
56
+
57
+
58
+ def replace_keys(input_dict, match="-", target="_"):
59
+ """Recursively replace match string with target string in dictionary keys."""
60
+ new_dict = {}
61
+ for key, value in input_dict.items():
62
+ new_key = key.replace(match, target)
63
+ if isinstance(value, dict):
64
+ new_dict[new_key] = replace_keys(value, match, target)
65
+ else:
66
+ new_dict[new_key] = value
67
+ return new_dict
@@ -14,10 +14,10 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.21.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
- "torch==2.7.1",
20
- "torchvision==0.22.1",
19
+ "torch==2.8.0",
20
+ "torchvision==0.23.0",
21
21
  ]
22
22
 
23
23
  [tool.hatch.metadata]
@@ -132,7 +132,7 @@ clientapp = "$import_name.client_app:app"
132
132
  # Custom config values accessible via `context.run_config`
133
133
  [tool.flwr.app.config]
134
134
  num-server-rounds = 3
135
- fraction-fit = 0.5
135
+ fraction-train = 0.5
136
136
  local-epochs = 1
137
137
 
138
138
  # Default federation to use when running the app
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.21.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets>=0.5.0",
19
19
  "torch==2.4.0",
20
20
  "trl==0.8.1",
@@ -61,7 +61,7 @@ train.training-arguments.save-steps = 1000
61
61
  train.training-arguments.save-total-limit = 10
62
62
  train.training-arguments.gradient-checkpointing = true
63
63
  train.training-arguments.lr-scheduler-type = "constant"
64
- strategy.fraction-fit = $fraction_fit
64
+ strategy.fraction-train = $fraction_train
65
65
  strategy.fraction-evaluate = 0.0
66
66
  num-server-rounds = 200
67
67
 
@@ -14,9 +14,9 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.21.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets>=0.5.0",
19
- "torch==2.7.1",
19
+ "torch>=2.7.1",
20
20
  "transformers>=4.30.0,<5.0",
21
21
  "evaluate>=0.4.0,<1.0",
22
22
  "datasets>=2.0.0, <3.0",
@@ -38,8 +38,8 @@ clientapp = "$import_name.client_app:app"
38
38
  # Custom config values accessible via `context.run_config`
39
39
  [tool.flwr.app.config]
40
40
  num-server-rounds = 3
41
- fraction-fit = 0.5
42
- local-epochs = 1
41
+ fraction-train = 0.5
42
+ local-steps = 5
43
43
  model-name = "prajjwal1/bert-tiny" # Set a larger model if you have access to more GPU resources
44
44
  num-labels = 2
45
45
 
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.21.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "jax==0.4.30",
19
19
  "jaxlib==0.4.30",
20
20
  "scikit-learn==1.6.1",
@@ -14,9 +14,9 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.21.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
- "mlx==0.26.5",
19
+ "mlx==0.29.0",
20
20
  ]
21
21
 
22
22
  [tool.hatch.build.targets.wheel]
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.21.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "numpy>=2.0.2",
19
19
  ]
20
20
 
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.21.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "torch==2.7.1",
20
20
  "torchvision==0.22.1",
@@ -27,7 +27,6 @@ packages = ["."]
27
27
  publisher = "$username"
28
28
 
29
29
  # Point to your ServerApp and ClientApp objects
30
- # Format: "<module>:<object>"
31
30
  [tool.flwr.app.components]
32
31
  serverapp = "$import_name.server_app:app"
33
32
  clientapp = "$import_name.client_app:app"
@@ -35,8 +34,9 @@ clientapp = "$import_name.client_app:app"
35
34
  # Custom config values accessible via `context.run_config`
36
35
  [tool.flwr.app.config]
37
36
  num-server-rounds = 3
38
- fraction-fit = 0.5
37
+ fraction-train = 0.5
39
38
  local-epochs = 1
39
+ lr = 0.01
40
40
 
41
41
  # Default federation to use when running the app
42
42
  [tool.flwr.federations]
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.21.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "torch==2.7.1",
20
20
  "torchvision==0.22.1",
@@ -27,6 +27,7 @@ packages = ["."]
27
27
  publisher = "$username"
28
28
 
29
29
  # Point to your ServerApp and ClientApp objects
30
+ # Format: "<module>:<object>"
30
31
  [tool.flwr.app.components]
31
32
  serverapp = "$import_name.server_app:app"
32
33
  clientapp = "$import_name.client_app:app"
@@ -34,9 +35,8 @@ clientapp = "$import_name.client_app:app"
34
35
  # Custom config values accessible via `context.run_config`
35
36
  [tool.flwr.app.config]
36
37
  num-server-rounds = 3
37
- fraction-train = 0.5
38
+ fraction-fit = 0.5
38
39
  local-epochs = 1
39
- lr = 0.01
40
40
 
41
41
  # Default federation to use when running the app
42
42
  [tool.flwr.federations]
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.21.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "scikit-learn>=1.6.1",
20
20
  ]
@@ -14,7 +14,7 @@ description = ""
14
14
  license = "Apache-2.0"
15
15
  # Dependencies for your Flower App
16
16
  dependencies = [
17
- "flwr[simulation]>=1.21.0",
17
+ "flwr[simulation]>=1.23.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
19
  "tensorflow>=2.11.1,<2.18.0",
20
20
  ]