flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/app.py +2 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  6. flwr/cli/cli_user_auth_interceptor.py +1 -1
  7. flwr/cli/config_utils.py +3 -3
  8. flwr/cli/constant.py +25 -8
  9. flwr/cli/log.py +9 -9
  10. flwr/cli/login/login.py +3 -3
  11. flwr/cli/ls.py +5 -5
  12. flwr/cli/new/new.py +15 -2
  13. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  14. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
  15. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  16. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  17. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  18. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  19. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  20. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  21. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  22. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  23. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  24. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  26. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  27. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  28. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  29. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  30. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  31. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  32. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  33. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  34. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  35. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  36. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  37. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  38. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  39. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  40. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  41. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  42. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  43. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  44. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
  45. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  46. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  47. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  49. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  50. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  52. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  53. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  54. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
  55. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  56. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  57. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  58. flwr/cli/pull.py +100 -0
  59. flwr/cli/run/run.py +9 -13
  60. flwr/cli/stop.py +7 -4
  61. flwr/cli/utils.py +36 -8
  62. flwr/client/grpc_rere_client/connection.py +1 -12
  63. flwr/client/rest_client/connection.py +3 -0
  64. flwr/clientapp/__init__.py +10 -0
  65. flwr/clientapp/mod/__init__.py +29 -0
  66. flwr/clientapp/mod/centraldp_mods.py +248 -0
  67. flwr/clientapp/mod/localdp_mod.py +169 -0
  68. flwr/clientapp/typing.py +22 -0
  69. flwr/common/args.py +20 -6
  70. flwr/common/auth_plugin/__init__.py +4 -4
  71. flwr/common/auth_plugin/auth_plugin.py +7 -7
  72. flwr/common/constant.py +26 -4
  73. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  74. flwr/common/exit/__init__.py +4 -0
  75. flwr/common/exit/exit.py +8 -1
  76. flwr/common/exit/exit_code.py +30 -7
  77. flwr/common/exit/exit_handler.py +62 -0
  78. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  79. flwr/common/grpc.py +0 -11
  80. flwr/common/inflatable_utils.py +1 -1
  81. flwr/common/logger.py +1 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/retry_invoker.py +30 -11
  84. flwr/common/telemetry.py +4 -0
  85. flwr/compat/server/app.py +2 -2
  86. flwr/proto/appio_pb2.py +25 -17
  87. flwr/proto/appio_pb2.pyi +46 -2
  88. flwr/proto/clientappio_pb2.py +3 -11
  89. flwr/proto/clientappio_pb2.pyi +0 -47
  90. flwr/proto/clientappio_pb2_grpc.py +19 -20
  91. flwr/proto/clientappio_pb2_grpc.pyi +10 -11
  92. flwr/proto/control_pb2.py +66 -0
  93. flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
  94. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
  95. flwr/proto/control_pb2_grpc.pyi +106 -0
  96. flwr/proto/serverappio_pb2.py +2 -2
  97. flwr/proto/serverappio_pb2_grpc.py +68 -0
  98. flwr/proto/serverappio_pb2_grpc.pyi +26 -0
  99. flwr/proto/simulationio_pb2.py +4 -11
  100. flwr/proto/simulationio_pb2.pyi +0 -58
  101. flwr/proto/simulationio_pb2_grpc.py +129 -27
  102. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  103. flwr/server/app.py +142 -152
  104. flwr/server/grid/grpc_grid.py +3 -0
  105. flwr/server/grid/inmemory_grid.py +1 -0
  106. flwr/server/serverapp/app.py +157 -146
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  110. flwr/server/superlink/linkstate/linkstate.py +2 -1
  111. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  112. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
  113. flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
  114. flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
  115. flwr/serverapp/__init__.py +12 -0
  116. flwr/serverapp/exception.py +38 -0
  117. flwr/serverapp/strategy/__init__.py +64 -0
  118. flwr/serverapp/strategy/bulyan.py +238 -0
  119. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  120. flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
  121. flwr/serverapp/strategy/fedadagrad.py +159 -0
  122. flwr/serverapp/strategy/fedadam.py +178 -0
  123. flwr/serverapp/strategy/fedavg.py +320 -0
  124. flwr/serverapp/strategy/fedavgm.py +198 -0
  125. flwr/serverapp/strategy/fedmedian.py +105 -0
  126. flwr/serverapp/strategy/fedopt.py +218 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +170 -0
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/result.py +105 -0
  136. flwr/serverapp/strategy/strategy.py +285 -0
  137. flwr/serverapp/strategy/strategy_utils.py +299 -0
  138. flwr/simulation/app.py +161 -164
  139. flwr/simulation/run_simulation.py +25 -30
  140. flwr/supercore/app_utils.py +58 -0
  141. flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
  142. flwr/supercore/cli/flower_superexec.py +166 -0
  143. flwr/supercore/constant.py +19 -0
  144. flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
  145. flwr/supercore/corestate/corestate.py +81 -0
  146. flwr/supercore/grpc_health/__init__.py +3 -0
  147. flwr/supercore/grpc_health/health_server.py +53 -0
  148. flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
  149. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  150. flwr/supercore/superexec/plugin/__init__.py +28 -0
  151. flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
  152. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  153. flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
  154. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  155. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  156. flwr/supercore/superexec/run_superexec.py +199 -0
  157. flwr/superlink/artifact_provider/__init__.py +22 -0
  158. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  159. flwr/superlink/servicer/__init__.py +15 -0
  160. flwr/superlink/servicer/control/__init__.py +22 -0
  161. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
  162. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
  163. flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
  164. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
  165. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
  166. flwr/supernode/cli/flower_supernode.py +3 -0
  167. flwr/supernode/cli/flwr_clientapp.py +18 -21
  168. flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
  169. flwr/supernode/nodestate/nodestate.py +3 -59
  170. flwr/supernode/runtime/run_clientapp.py +39 -102
  171. flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
  172. flwr/supernode/start_client_internal.py +35 -76
  173. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
  174. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
  175. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
  176. flwr/proto/exec_pb2.py +0 -62
  177. flwr/proto/exec_pb2_grpc.pyi +0 -93
  178. flwr/superexec/app.py +0 -45
  179. flwr/superexec/deployment.py +0 -191
  180. flwr/superexec/executor.py +0 -100
  181. flwr/superexec/simulation.py +0 -129
  182. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
@@ -1,53 +1,48 @@
1
1
  """$project_name: A Flower / FlowerTune app."""
2
2
 
3
- from io import BytesIO
3
+ from collections.abc import Iterable
4
4
  from logging import INFO, WARN
5
- from typing import List, Tuple, Union
5
+ from typing import Optional
6
6
 
7
- from flwr.common import FitIns, FitRes, Parameters, log, parameters_to_ndarrays
8
- from flwr.server.client_manager import ClientManager
9
- from flwr.server.client_proxy import ClientProxy
10
- from flwr.server.strategy import FedAvg
7
+ from flwr.app import ArrayRecord, ConfigRecord, Message, MetricRecord
8
+ from flwr.common import log
9
+ from flwr.serverapp import Grid
10
+ from flwr.serverapp.strategy import FedAvg
11
11
 
12
12
 
13
13
  class FlowerTuneLlm(FedAvg):
14
14
  """Customised FedAvg strategy implementation.
15
-
15
+
16
16
  This class behaves just like FedAvg but also tracks the communication
17
- costs associated with `fit` over FL rounds.
17
+ costs associated with `train` over FL rounds.
18
18
  """
19
19
  def __init__(self, **kwargs):
20
20
  super().__init__(**kwargs)
21
21
  self.comm_tracker = CommunicationTracker()
22
22
 
23
- def configure_fit(
24
- self, server_round: int, parameters: Parameters, client_manager: ClientManager
25
- ):
23
+ def configure_train(
24
+ self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
25
+ ) -> Iterable[Message]:
26
26
  """Configure the next round of training."""
27
- return_clients = super().configure_fit(server_round, parameters, client_manager)
28
-
29
- # Test communication costs
30
- fit_ins_list = [fit_ins for _, fit_ins in return_clients]
31
- self.comm_tracker.track(fit_ins_list)
32
-
33
- return return_clients
34
-
35
- def aggregate_fit(
36
- self,
37
- server_round: int,
38
- results: List[Tuple[ClientProxy, FitRes]],
39
- failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
40
- ):
41
- """Aggregate fit results using weighted average."""
42
- # Test communication costs
43
- fit_res_list = [fit_res for _, fit_res in results]
44
- self.comm_tracker.track(fit_res_list)
45
-
46
- parameters_aggregated, metrics_aggregated = super().aggregate_fit(
47
- server_round, results, failures
48
- )
27
+ messages = super().configure_train(server_round, arrays, config, grid)
28
+
29
+ # Track communication costs
30
+ self.comm_tracker.track(messages)
31
+
32
+ return messages
33
+
34
+ def aggregate_train(
35
+ self,
36
+ server_round: int,
37
+ replies: Iterable[Message],
38
+ ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
39
+ """Aggregate ArrayRecords and MetricRecords in the received Messages."""
40
+ # Track communication costs
41
+ self.comm_tracker.track(replies)
49
42
 
50
- return parameters_aggregated, metrics_aggregated
43
+ arrays, metrics = super().aggregate_train(server_round, replies)
44
+
45
+ return arrays, metrics
51
46
 
52
47
 
53
48
  class CommunicationTracker:
@@ -55,16 +50,16 @@ class CommunicationTracker:
55
50
  def __init__(self):
56
51
  self.curr_comm_cost = 0.0
57
52
 
58
- @staticmethod
59
- def _compute_bytes(parameters):
60
- return sum([BytesIO(t).getbuffer().nbytes for t in parameters.tensors])
61
-
62
- def track(self, fit_list: List[Union[FitIns, FitRes]]):
63
- size_bytes_list = [
64
- self._compute_bytes(fit_ele.parameters)
65
- for fit_ele in fit_list
66
- ]
67
- comm_cost = sum(size_bytes_list) / 1024**2
53
+ def track(self, messages: Iterable[Message]):
54
+ comm_cost = (
55
+ sum(
56
+ record.count_bytes()
57
+ for msg in messages
58
+ if msg.has_content()
59
+ for record in msg.content.array_records.values()
60
+ )
61
+ / 1024**2
62
+ )
68
63
 
69
64
  self.curr_comm_cost += comm_cost
70
65
  log(
@@ -1,7 +1,5 @@
1
1
  """$project_name: A Flower Baseline."""
2
2
 
3
- from collections import OrderedDict
4
-
5
3
  import torch
6
4
  import torch.nn.functional as F
7
5
  from torch import nn
@@ -66,15 +64,3 @@ def test(net, testloader, device):
66
64
  accuracy = correct / len(testloader.dataset)
67
65
  loss = loss / len(testloader)
68
66
  return loss, accuracy
69
-
70
-
71
- def get_weights(net):
72
- """Extract model parameters as numpy arrays from state_dict."""
73
- return [val.cpu().numpy() for _, val in net.state_dict().items()]
74
-
75
-
76
- def set_weights(net, parameters):
77
- """Apply parameters to an existing model."""
78
- params_dict = zip(net.state_dict().keys(), parameters)
79
- state_dict = OrderedDict({k: torch.from_numpy(v) for k, v in params_dict})
80
- net.load_state_dict(state_dict, strict=True)
@@ -1,45 +1,43 @@
1
1
  """$project_name: A Flower Baseline."""
2
2
 
3
- from flwr.common import Context, Metrics, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ import torch
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
6
7
 
7
- from $import_name.model import Net, get_weights
8
+ from $import_name.model import Net
8
9
 
10
+ # Create ServerApp
11
+ app = ServerApp()
9
12
 
10
- # Define metric aggregation function
11
- def weighted_average(metrics: list[tuple[int, Metrics]]) -> Metrics:
12
- """Do weighted average of accuracy metric."""
13
- # Multiply accuracy of each client by number of examples used
14
- accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics]
15
- examples = [num_examples for num_examples, _ in metrics]
16
-
17
- # Aggregate and return custom metric (weighted average)
18
- return {"accuracy": sum(accuracies) / sum(examples)}
19
13
 
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
20
17
 
21
- def server_fn(context: Context):
22
- """Construct components that set the ServerApp behaviour."""
23
18
  # Read from config
24
19
  num_rounds = context.run_config["num-server-rounds"]
25
- fraction_fit = context.run_config["fraction-fit"]
20
+ fraction_train = context.run_config["fraction-train"]
26
21
 
27
- # Initialize model parameters
28
- ndarrays = get_weights(Net())
29
- parameters = ndarrays_to_parameters(ndarrays)
22
+ # Load global model
23
+ global_model = Net()
24
+ arrays = ArrayRecord(global_model.state_dict())
30
25
 
31
- # Define strategy
26
+ # Initialize FedAvg strategy
32
27
  strategy = FedAvg(
33
- fraction_fit=float(fraction_fit),
28
+ fraction_train=fraction_train,
34
29
  fraction_evaluate=1.0,
35
- min_available_clients=2,
36
- initial_parameters=parameters,
37
- evaluate_metrics_aggregation_fn=weighted_average,
30
+ min_available_nodes=2,
38
31
  )
39
- config = ServerConfig(num_rounds=int(num_rounds))
40
-
41
- return ServerAppComponents(strategy=strategy, config=config)
42
32
 
33
+ # Start strategy, run FedAvg for `num_rounds`
34
+ result = strategy.start(
35
+ grid=grid,
36
+ initial_arrays=arrays,
37
+ num_rounds=num_rounds,
38
+ )
43
39
 
44
- # Create ServerApp
45
- app = ServerApp(server_fn=server_fn)
40
+ # Save final model to disk
41
+ print("\nSaving final model to disk...")
42
+ state_dict = result.arrays.to_torch_state_dict()
43
+ torch.save(state_dict, "final_model.pt")
@@ -1,17 +1,22 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ import torch
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
6
7
  from transformers import AutoModelForSequenceClassification
7
8
 
8
- from $import_name.task import get_weights
9
+ # Create ServerApp
10
+ app = ServerApp()
11
+
9
12
 
13
+ @app.main()
14
+ def main(grid: Grid, context: Context) -> None:
15
+ """Main entry point for the ServerApp."""
10
16
 
11
- def server_fn(context: Context):
12
17
  # Read from config
13
18
  num_rounds = context.run_config["num-server-rounds"]
14
- fraction_fit = context.run_config["fraction-fit"]
19
+ fraction_train = context.run_config["fraction-train"]
15
20
 
16
21
  # Initialize global model
17
22
  model_name = context.run_config["model-name"]
@@ -19,20 +24,19 @@ def server_fn(context: Context):
19
24
  net = AutoModelForSequenceClassification.from_pretrained(
20
25
  model_name, num_labels=num_labels
21
26
  )
27
+ arrays = ArrayRecord(net.state_dict())
22
28
 
23
- weights = get_weights(net)
24
- initial_parameters = ndarrays_to_parameters(weights)
29
+ # Initialize FedAvg strategy
30
+ strategy = FedAvg(fraction_train=fraction_train)
25
31
 
26
- # Define strategy
27
- strategy = FedAvg(
28
- fraction_fit=fraction_fit,
29
- fraction_evaluate=1.0,
30
- initial_parameters=initial_parameters,
32
+ # Start strategy, run FedAvg for `num_rounds`
33
+ result = strategy.start(
34
+ grid=grid,
35
+ initial_arrays=arrays,
36
+ num_rounds=num_rounds,
31
37
  )
32
- config = ServerConfig(num_rounds=num_rounds)
33
-
34
- return ServerAppComponents(strategy=strategy, config=config)
35
38
 
36
-
37
- # Create ServerApp
38
- app = ServerApp(server_fn=server_fn)
39
+ # Save final model to disk
40
+ print("\nSaving final model to disk...")
41
+ state_dict = result.arrays.to_torch_state_dict()
42
+ torch.save(state_dict, "final_model.pt")
@@ -1,26 +1,39 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ import numpy as np
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
7
+
6
8
  from $import_name.task import get_params, load_model
7
9
 
10
+ # Create ServerApp
11
+ app = ServerApp()
12
+
13
+
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
8
17
 
9
- def server_fn(context: Context):
10
18
  # Read from config
11
19
  num_rounds = context.run_config["num-server-rounds"]
12
20
  input_dim = context.run_config["input-dim"]
13
21
 
14
- # Initialize global model
15
- params = get_params(load_model((input_dim,)))
16
- initial_parameters = ndarrays_to_parameters(params)
17
-
18
- # Define strategy
19
- strategy = FedAvg(initial_parameters=initial_parameters)
20
- config = ServerConfig(num_rounds=num_rounds)
22
+ # Load global model
23
+ model = load_model((input_dim,))
24
+ arrays = ArrayRecord(get_params(model))
21
25
 
22
- return ServerAppComponents(strategy=strategy, config=config)
26
+ # Initialize FedAvg strategy
27
+ strategy = FedAvg()
23
28
 
29
+ # Start strategy, run FedAvg for `num_rounds`
30
+ result = strategy.start(
31
+ grid=grid,
32
+ initial_arrays=arrays,
33
+ num_rounds=num_rounds,
34
+ )
24
35
 
25
- # Create ServerApp
26
- app = ServerApp(server_fn=server_fn)
36
+ # Save final model to disk
37
+ print("\nSaving final model to disk...")
38
+ ndarrays = result.arrays.to_numpy_ndarrays()
39
+ np.savez("final_model.npz", *ndarrays)
@@ -1,31 +1,41 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import MLP, get_params
3
+ from flwr.app import ArrayRecord, Context
4
+ from flwr.serverapp import Grid, ServerApp
5
+ from flwr.serverapp.strategy import FedAvg
7
6
 
7
+ from $import_name.task import MLP, get_params, set_params
8
8
 
9
- def server_fn(context: Context):
9
+ # Create ServerApp
10
+ app = ServerApp()
11
+
12
+
13
+ @app.main()
14
+ def main(grid: Grid, context: Context) -> None:
15
+ """Main entry point for the ServerApp."""
10
16
  # Read from config
11
17
  num_rounds = context.run_config["num-server-rounds"]
12
-
13
- num_classes = 10
14
18
  num_layers = context.run_config["num-layers"]
15
19
  input_dim = context.run_config["input-dim"]
16
20
  hidden_dim = context.run_config["hidden-dim"]
17
21
 
18
22
  # Initialize global model
19
- model = MLP(num_layers, input_dim, hidden_dim, num_classes)
23
+ model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
20
24
  params = get_params(model)
21
- initial_parameters = ndarrays_to_parameters(params)
22
-
23
- # Define strategy
24
- strategy = FedAvg(initial_parameters=initial_parameters)
25
- config = ServerConfig(num_rounds=num_rounds)
26
-
27
- return ServerAppComponents(strategy=strategy, config=config)
28
-
29
-
30
- # Create ServerApp
31
- app = ServerApp(server_fn=server_fn)
25
+ arrays = ArrayRecord(params)
26
+
27
+ # Initialize FedAvg strategy
28
+ strategy = FedAvg()
29
+
30
+ # Start strategy, run FedAvg for `num_rounds`
31
+ result = strategy.start(
32
+ grid=grid,
33
+ initial_arrays=arrays,
34
+ num_rounds=num_rounds,
35
+ )
36
+
37
+ # Save final model to disk
38
+ print("\nSaving final model to disk...")
39
+ ndarrays = result.arrays.to_numpy_ndarrays()
40
+ set_params(model, ndarrays)
41
+ model.save_weights("final_model.npz")
@@ -1,25 +1,38 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import get_dummy_model
7
-
3
+ import numpy as np
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
8
7
 
9
- def server_fn(context: Context):
10
- # Read from config
11
- num_rounds = context.run_config["num-server-rounds"]
8
+ from $import_name.task import get_dummy_model
12
9
 
13
- # Initial model
14
- model = get_dummy_model()
15
- dummy_parameters = ndarrays_to_parameters([model])
10
+ # Create ServerApp
11
+ app = ServerApp()
16
12
 
17
- # Define strategy
18
- strategy = FedAvg(initial_parameters=dummy_parameters)
19
- config = ServerConfig(num_rounds=num_rounds)
20
13
 
21
- return ServerAppComponents(strategy=strategy, config=config)
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
22
17
 
18
+ # Read run config
19
+ num_rounds: int = context.run_config["num-server-rounds"]
23
20
 
24
- # Create ServerApp
25
- app = ServerApp(server_fn=server_fn)
21
+ # Load global model
22
+ model = get_dummy_model()
23
+ arrays = ArrayRecord(model)
24
+
25
+ # Initialize FedAvg strategy
26
+ strategy = FedAvg()
27
+
28
+ # Start strategy, run FedAvg for `num_rounds`
29
+ result = strategy.start(
30
+ grid=grid,
31
+ initial_arrays=arrays,
32
+ num_rounds=num_rounds,
33
+ )
34
+
35
+ # Save final model to disk
36
+ print("\nSaving final model to disk...")
37
+ ndarrays = result.arrays.to_numpy_ndarrays()
38
+ np.savez("final_model", *ndarrays)
@@ -1,31 +1,41 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import Net, get_weights
7
-
8
-
9
- def server_fn(context: Context):
10
- # Read from config
11
- num_rounds = context.run_config["num-server-rounds"]
12
- fraction_fit = context.run_config["fraction-fit"]
13
-
14
- # Initialize model parameters
15
- ndarrays = get_weights(Net())
16
- parameters = ndarrays_to_parameters(ndarrays)
17
-
18
- # Define strategy
19
- strategy = FedAvg(
20
- fraction_fit=fraction_fit,
21
- fraction_evaluate=1.0,
22
- min_available_clients=2,
23
- initial_parameters=parameters,
24
- )
25
- config = ServerConfig(num_rounds=num_rounds)
26
-
27
- return ServerAppComponents(strategy=strategy, config=config)
3
+ import torch
4
+ from flwr.app import ArrayRecord, ConfigRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
28
7
 
8
+ from $import_name.task import Net
29
9
 
30
10
  # Create ServerApp
31
- app = ServerApp(server_fn=server_fn)
11
+ app = ServerApp()
12
+
13
+
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
17
+
18
+ # Read run config
19
+ fraction_train: float = context.run_config["fraction-train"]
20
+ num_rounds: int = context.run_config["num-server-rounds"]
21
+ lr: float = context.run_config["lr"]
22
+
23
+ # Load global model
24
+ global_model = Net()
25
+ arrays = ArrayRecord(global_model.state_dict())
26
+
27
+ # Initialize FedAvg strategy
28
+ strategy = FedAvg(fraction_train=fraction_train)
29
+
30
+ # Start strategy, run FedAvg for `num_rounds`
31
+ result = strategy.start(
32
+ grid=grid,
33
+ initial_arrays=arrays,
34
+ train_config=ConfigRecord({"lr": lr}),
35
+ num_rounds=num_rounds,
36
+ )
37
+
38
+ # Save final model to disk
39
+ print("\nSaving final model to disk...")
40
+ state_dict = result.arrays.to_torch_state_dict()
41
+ torch.save(state_dict, "final_model.pt")
@@ -0,0 +1,31 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ from flwr.common import Context, ndarrays_to_parameters
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
+ from flwr.server.strategy import FedAvg
6
+ from $import_name.task import Net, get_weights
7
+
8
+
9
+ def server_fn(context: Context):
10
+ # Read from config
11
+ num_rounds = context.run_config["num-server-rounds"]
12
+ fraction_fit = context.run_config["fraction-fit"]
13
+
14
+ # Initialize model parameters
15
+ ndarrays = get_weights(Net())
16
+ parameters = ndarrays_to_parameters(ndarrays)
17
+
18
+ # Define strategy
19
+ strategy = FedAvg(
20
+ fraction_fit=fraction_fit,
21
+ fraction_evaluate=1.0,
22
+ min_available_clients=2,
23
+ initial_parameters=parameters,
24
+ )
25
+ config = ServerConfig(num_rounds=num_rounds)
26
+
27
+ return ServerAppComponents(strategy=strategy, config=config)
28
+
29
+
30
+ # Create ServerApp
31
+ app = ServerApp(server_fn=server_fn)
@@ -1,36 +1,44 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import get_model, get_model_params, set_initial_params
3
+ import joblib
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
7
7
 
8
+ from $import_name.task import get_model, get_model_params, set_initial_params, set_model_params
8
9
 
9
- def server_fn(context: Context):
10
- # Read from config
11
- num_rounds = context.run_config["num-server-rounds"]
10
+ # Create ServerApp
11
+ app = ServerApp()
12
+
13
+
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
17
+
18
+ # Read run config
19
+ num_rounds: int = context.run_config["num-server-rounds"]
12
20
 
13
21
  # Create LogisticRegression Model
14
22
  penalty = context.run_config["penalty"]
15
23
  local_epochs = context.run_config["local-epochs"]
16
24
  model = get_model(penalty, local_epochs)
17
-
18
25
  # Setting initial parameters, akin to model.compile for keras models
19
26
  set_initial_params(model)
27
+ # Construct ArrayRecord representation
28
+ arrays = ArrayRecord(get_model_params(model))
20
29
 
21
- initial_parameters = ndarrays_to_parameters(get_model_params(model))
30
+ # Initialize FedAvg strategy
31
+ strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
22
32
 
23
- # Define strategy
24
- strategy = FedAvg(
25
- fraction_fit=1.0,
26
- fraction_evaluate=1.0,
27
- min_available_clients=2,
28
- initial_parameters=initial_parameters,
33
+ # Start strategy, run FedAvg for `num_rounds`
34
+ result = strategy.start(
35
+ grid=grid,
36
+ initial_arrays=arrays,
37
+ num_rounds=num_rounds,
29
38
  )
30
- config = ServerConfig(num_rounds=num_rounds)
31
-
32
- return ServerAppComponents(strategy=strategy, config=config)
33
39
 
34
-
35
- # Create ServerApp
36
- app = ServerApp(server_fn=server_fn)
40
+ # Save final model parameters
41
+ print("\nSaving final model to disk...")
42
+ ndarrays = result.arrays.to_numpy_ndarrays()
43
+ set_model_params(model, ndarrays)
44
+ joblib.dump(model, "logreg_model.pkl")
@@ -1,29 +1,38 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ from flwr.app import ArrayRecord, Context
4
+ from flwr.serverapp import Grid, ServerApp
5
+ from flwr.serverapp.strategy import FedAvg
6
6
 
7
7
  from $import_name.task import load_model
8
8
 
9
+ # Create ServerApp
10
+ app = ServerApp()
9
11
 
10
- def server_fn(context: Context):
11
- # Read from config
12
- num_rounds = context.run_config["num-server-rounds"]
13
12
 
14
- # Get parameters to initialize global model
15
- parameters = ndarrays_to_parameters(load_model().get_weights())
13
+ @app.main()
14
+ def main(grid: Grid, context: Context) -> None:
15
+ """Main entry point for the ServerApp."""
16
16
 
17
- # Define strategy
18
- strategy = strategy = FedAvg(
19
- fraction_fit=1.0,
20
- fraction_evaluate=1.0,
21
- min_available_clients=2,
22
- initial_parameters=parameters,
23
- )
24
- config = ServerConfig(num_rounds=num_rounds)
17
+ # Read run config
18
+ num_rounds: int = context.run_config["num-server-rounds"]
25
19
 
26
- return ServerAppComponents(strategy=strategy, config=config)
20
+ # Load global model
21
+ model = load_model()
22
+ arrays = ArrayRecord(model.get_weights())
27
23
 
28
- # Create ServerApp
29
- app = ServerApp(server_fn=server_fn)
24
+ # Initialize FedAvg strategy
25
+ strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
26
+
27
+ # Start strategy, run FedAvg for `num_rounds`
28
+ result = strategy.start(
29
+ grid=grid,
30
+ initial_arrays=arrays,
31
+ num_rounds=num_rounds,
32
+ )
33
+
34
+ # Save final model to disk
35
+ print("\nSaving final model to disk...")
36
+ ndarrays = result.arrays.to_numpy_ndarrays()
37
+ model.set_weights(ndarrays)
38
+ model.save("final_model.keras")