flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (237) hide show
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240315.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240315.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
@@ -0,0 +1,59 @@
1
+ """$project_name: A Flower / FlowerTune app."""
2
+
3
+ import math
4
+
5
+ import torch
6
+ from omegaconf import DictConfig
7
+ from peft import LoraConfig, get_peft_model
8
+ from peft.utils import prepare_model_for_kbit_training
9
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
10
+
11
+
12
def cosine_annealing(
    current_round: int,
    total_round: int,
    lrate_max: float = 0.001,
    lrate_min: float = 0.0,
) -> float:
    """Return the cosine-annealed learning rate for `current_round`.

    The rate equals `lrate_max` when `current_round` is 0 and reaches
    `lrate_min` when `current_round` equals `total_round`.
    """
    # Angle in [0, pi] as the rounds progress from 0 to `total_round`
    angle = math.pi * current_round / total_round
    half_range = 0.5 * (lrate_max - lrate_min)
    return lrate_min + half_range * (1 + math.cos(angle))
21
+
22
+
23
def get_model(model_cfg: "DictConfig"):
    """Load model with appropriate quantization config and other optimizations.

    Please refer to this example for `peft + BitsAndBytes`:
    https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py

    The function reads `quantization`, `name`, `gradient_checkpointing`,
    `lora.peft_lora_r` and `lora.peft_lora_alpha` from `model_cfg` and
    returns the result of `get_peft_model` applied to the loaded model.

    Raises
    ------
    ValueError
        If `model_cfg.quantization` is neither 4 nor 8.
    """
    if model_cfg.quantization == 4:
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    elif model_cfg.quantization == 8:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        raise ValueError(
            f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}"
        )

    model = AutoModelForCausalLM.from_pretrained(
        model_cfg.name,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )

    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
    )

    peft_config = LoraConfig(
        r=model_cfg.lora.peft_lora_r,
        lora_alpha=model_cfg.lora.peft_lora_alpha,
        lora_dropout=0.075,
        task_type="CAUSAL_LM",
    )

    # The generation cache is switched off whenever gradient checkpointing is on
    if model_cfg.gradient_checkpointing:
        model.config.use_cache = False

    return get_peft_model(model, peft_config)
@@ -0,0 +1,48 @@
1
+ """$project_name: A Flower / FlowerTune app."""
2
+
3
+ from $import_name.client_app import set_parameters
4
+ from $import_name.models import get_model
5
+
6
+
7
+ # Get function that will be executed by the strategy's evaluate() method
8
+ # Here we use it to save global model checkpoints
9
def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
    """Build an `evaluate` callback that checkpoints the global model.

    The returned callback always reports `(0.0, {})`; its only effect is to
    save the model on selected rounds.
    """

    def evaluate(server_round: int, parameters, config):
        # Round 0 never produces a checkpoint
        if server_round == 0:
            return 0.0, {}

        # Skip rounds that are neither the last one nor a save interval
        if server_round != total_round and server_round % save_every_round != 0:
            return 0.0, {}

        # Rebuild the model, load the global parameters into it, and save it
        checkpoint_model = get_model(model_cfg)
        set_parameters(checkpoint_model, parameters)
        checkpoint_model.save_pretrained(f"{save_path}/peft_{server_round}")

        return 0.0, {}

    return evaluate
26
+
27
+
28
def get_on_fit_config():
    """Return a function that builds the config passed to the client's fit()."""

    def fit_config_fn(server_round: int):
        # The config carries only the current round number
        return {"current_round": server_round}

    return fit_config_fn
39
+
40
+
41
def fit_weighted_average(metrics):
    """Aggregate the clients' `train_loss` as an example-weighted average.

    `metrics` is a sequence of `(num_examples, metrics_dict)` pairs.
    """
    # Weight each client's train loss by the number of examples it used
    weighted_loss_total = sum(count * m["train_loss"] for count, m in metrics)
    example_total = sum(count for count, _ in metrics)

    # Weighted average over all clients
    return {"train_loss": weighted_loss_total / example_total}
@@ -0,0 +1,11 @@
1
+ # Federated Instruction Tuning (static)
2
+ ---
3
+ dataset:
4
+ name: $dataset_name
5
+
6
+ # FL experimental settings
7
+ num_clients: $num_clients # total number of clients
8
+ num_rounds: 200
9
+ partitioner:
10
+ _target_: flwr_datasets.partitioner.IidPartitioner
11
+ num_partitions: $num_clients
@@ -0,0 +1,23 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ from flwr.common import Context
4
+ from flwr.server.strategy import FedAvg
5
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
6
+
7
+
8
def server_fn(context: Context):
    """Assemble the strategy and server config for this run."""
    # The number of rounds comes from the run configuration
    rounds = context.run_config["num-server-rounds"]

    # FedAvg with a fraction of 1.0 for both fit and evaluate
    fedavg = FedAvg(fraction_fit=1.0, fraction_evaluate=1.0)
    server_config = ServerConfig(num_rounds=rounds)

    return ServerAppComponents(strategy=fedavg, config=server_config)
20
+
21
+
22
+ # Create ServerApp
23
+ app = ServerApp(server_fn=server_fn)
@@ -0,0 +1,20 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ from flwr.common import Context
4
+ from flwr.server.strategy import FedAvg
5
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
6
+
7
+
8
def server_fn(context: Context):
    """Return the components (strategy and config) used by the ServerApp."""
    # Number of rounds is taken from the run config
    rounds = context.run_config["num-server-rounds"]

    # FedAvg with its default arguments
    fedavg = FedAvg()
    server_config = ServerConfig(num_rounds=rounds)

    return ServerAppComponents(strategy=fedavg, config=server_config)
17
+
18
+
19
+ # Create ServerApp
20
+ app = ServerApp(server_fn=server_fn)
@@ -0,0 +1,20 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ from flwr.common import Context
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
+ from flwr.server.strategy import FedAvg
6
+
7
+
8
def server_fn(context: Context):
    """Create the default FedAvg strategy and the round configuration."""
    run_config = context.run_config

    # Default FedAvg, running for the configured number of rounds
    default_strategy = FedAvg()
    round_config = ServerConfig(num_rounds=run_config["num-server-rounds"])

    return ServerAppComponents(strategy=default_strategy, config=round_config)
17
+
18
+
19
+ # Create ServerApp
20
+ app = ServerApp(server_fn=server_fn)
@@ -1,12 +1,20 @@
1
- """$project_name: A Flower / NumPy app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
- import flwr as fl
3
+ from flwr.common import Context
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
+ from flwr.server.strategy import FedAvg
4
6
 
5
- # Configure the strategy
6
- strategy = fl.server.strategy.FedAvg()
7
7
 
8
- # Flower ServerApp
9
- app = fl.server.ServerApp(
10
- config=fl.server.ServerConfig(num_rounds=1),
11
- strategy=strategy,
12
- )
8
def server_fn(context: Context):
    """Build the ServerApp components from the run configuration."""
    # How many rounds to run, as set in the run config
    total_rounds = context.run_config["num-server-rounds"]

    # Strategy with default settings plus the round count
    components = ServerAppComponents(
        strategy=FedAvg(),
        config=ServerConfig(num_rounds=total_rounds),
    )
    return components
17
+
18
+
19
+ # Create ServerApp
20
+ app = ServerApp(server_fn=server_fn)
@@ -1,28 +1,31 @@
1
- """$project_name: A Flower / PyTorch app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerConfig
3
+ from flwr.common import Context, ndarrays_to_parameters
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
5
  from flwr.server.strategy import FedAvg
6
6
 
7
- from $project_name.task import Net, get_weights
7
+ from $import_name.task import Net, get_weights
8
8
 
9
9
 
10
- # Initialize model parameters
11
- ndarrays = get_weights(Net())
12
- parameters = ndarrays_to_parameters(ndarrays)
10
+ def server_fn(context: Context):
11
+ # Read from config
12
+ num_rounds = context.run_config["num-server-rounds"]
13
+ fraction_fit = context.run_config["fraction-fit"]
13
14
 
15
+ # Initialize model parameters
16
+ ndarrays = get_weights(Net())
17
+ parameters = ndarrays_to_parameters(ndarrays)
14
18
 
15
- # Define strategy
16
- strategy = FedAvg(
17
- fraction_fit=1.0,
18
- fraction_evaluate=1.0,
19
- min_available_clients=2,
20
- initial_parameters=parameters,
21
- )
19
+ # Define strategy
20
+ strategy = FedAvg(
21
+ fraction_fit=fraction_fit,
22
+ fraction_evaluate=1.0,
23
+ min_available_clients=2,
24
+ initial_parameters=parameters,
25
+ )
26
+ config = ServerConfig(num_rounds=num_rounds)
22
27
 
28
+ return ServerAppComponents(strategy=strategy, config=config)
23
29
 
24
30
  # Create ServerApp
25
- app = ServerApp(
26
- config=ServerConfig(num_rounds=3),
27
- strategy=strategy,
28
- )
31
+ app = ServerApp(server_fn=server_fn)
@@ -0,0 +1,24 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ from flwr.common import Context
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
+ from flwr.server.strategy import FedAvg
6
+
7
+
8
def server_fn(context: Context):
    """Assemble the FedAvg strategy and server config for this run."""
    # Round count is read from the run configuration
    rounds = context.run_config["num-server-rounds"]

    # Keyword arguments passed to FedAvg
    strategy_kwargs = {
        "fraction_fit": 1.0,
        "fraction_evaluate": 1.0,
        "min_available_clients": 2,
    }
    fedavg = FedAvg(**strategy_kwargs)
    server_config = ServerConfig(num_rounds=rounds)

    return ServerAppComponents(strategy=fedavg, config=server_config)
21
+
22
+
23
+ # Create ServerApp
24
+ app = ServerApp(server_fn=server_fn)
@@ -1 +1,29 @@
1
- """$project_name: A Flower / TensorFlow app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ from flwr.common import Context, ndarrays_to_parameters
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
+ from flwr.server.strategy import FedAvg
6
+
7
+ from $import_name.task import load_model
8
+
9
+
10
def server_fn(context: Context):
    """Build the strategy and server config for a run.

    Reads `num-server-rounds` from `context.run_config` and passes the
    weights returned by `load_model().get_weights()` to the strategy as
    its initial parameters.
    """
    # Read from config
    num_rounds = context.run_config["num-server-rounds"]

    # Get parameters to initialize global model
    parameters = ndarrays_to_parameters(load_model().get_weights())

    # Define strategy
    strategy = FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_available_clients=2,
        initial_parameters=parameters,
    )
    config = ServerConfig(num_rounds=num_rounds)

    return ServerAppComponents(strategy=strategy, config=config)
27
+
28
+ # Create ServerApp
29
+ app = ServerApp(server_fn=server_fn)
@@ -0,0 +1,99 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import warnings
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+ from evaluate import load as load_metric
8
+ from torch.optim import AdamW
9
+ from torch.utils.data import DataLoader
10
+ from transformers import AutoTokenizer, DataCollatorWithPadding
11
+
12
+ from flwr_datasets import FederatedDataset
13
+ from flwr_datasets.partitioner import IidPartitioner
14
+
15
+
16
+ warnings.filterwarnings("ignore", category=UserWarning)
17
+ DEVICE = torch.device("cpu")
18
+ CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint
19
+
20
+
21
+ fds = None # Cache FederatedDataset
22
+
23
+
24
def load_data(partition_id: int, num_partitions: int):
    """Load one IMDB partition and return its train and test dataloaders."""
    global fds
    # Create the `FederatedDataset` on first use and reuse it afterwards
    if fds is None:
        fds = FederatedDataset(
            dataset="stanfordnlp/imdb",
            partitioners={"train": IidPartitioner(num_partitions=num_partitions)},
        )

    # Split this partition: 80% train, 20% test
    splits = fds.load_partition(partition_id).train_test_split(
        test_size=0.2, seed=42
    )

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True)

    # Tokenize, drop the raw text, and rename the label column
    splits = (
        splits.map(tokenize_function, batched=True)
        .remove_columns("text")
        .rename_column("label", "labels")
    )

    collate = DataCollatorWithPadding(tokenizer=tokenizer)
    trainloader = DataLoader(
        splits["train"], shuffle=True, batch_size=32, collate_fn=collate
    )
    testloader = DataLoader(splits["test"], batch_size=32, collate_fn=collate)

    return trainloader, testloader
60
+
61
+
62
def train(net, trainloader, epochs):
    """Run `epochs` passes over `trainloader`, updating `net` with AdamW."""
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            # Move every tensor of the batch to the module-level DEVICE
            inputs = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
            net(**inputs).loss.backward()
            optimizer.step()
            optimizer.zero_grad()
73
+
74
+
75
def test(net, testloader):
    """Evaluate `net` on `testloader` and return `(loss, accuracy)`.

    The returned loss is the average of the per-batch losses reported by the
    model; accuracy is computed with the `accuracy` metric.
    """
    metric = load_metric("accuracy")
    loss = 0
    net.eval()
    for batch in testloader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        logits = outputs.logits
        loss += outputs.loss.item()
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    # One loss value is accumulated per batch, so average over the number of
    # batches rather than the number of samples
    loss /= len(testloader)
    accuracy = metric.compute()["accuracy"]
    return loss, accuracy
90
+
91
+
92
def get_weights(net):
    """Return the entries of `net.state_dict()` as a list of NumPy arrays."""
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]
94
+
95
+
96
def set_weights(net, parameters):
    """Load `parameters` into `net`, pairing them with its state-dict keys in order."""
    state_dict = OrderedDict()
    for name, array in zip(net.state_dict().keys(), parameters):
        state_dict[name] = torch.tensor(array)
    net.load_state_dict(state_dict, strict=True)
@@ -0,0 +1,57 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from sklearn.datasets import make_regression
6
+ from sklearn.model_selection import train_test_split
7
+ import numpy as np
8
+
9
+ key = jax.random.PRNGKey(0)
10
+
11
+
12
def load_data():
    """Generate a 3-feature regression dataset and split it into train and test.

    Returns `(X_train, y_train, X_test, y_test)`.
    """
    features, targets = make_regression(n_features=3, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(features, targets)
    return X_train, y_train, X_test, y_test
17
+
18
+
19
def load_model(model_shape):
    """Create the initial model parameters.

    Returns a dict with a scalar bias `"b"` and a weight array `"w"` of shape
    `model_shape`, both drawn with `jax.random.uniform`.
    """
    # Derive one subkey per draw from the module-level key; reusing a single
    # key for several draws is discouraged by JAX's PRNG design
    bias_key, weight_key = jax.random.split(key)
    params = {
        "b": jax.random.uniform(bias_key),
        "w": jax.random.uniform(weight_key, model_shape),
    }
    return params
23
+
24
+
25
def loss_fn(params, X, y):
    """Return the mean squared error of the linear model `X @ w + b` against `y`."""
    predictions = jnp.dot(X, params["w"]) + params["b"]
    residuals = predictions - y
    return jnp.mean(jnp.square(residuals))
29
+
30
+
31
def train(params, grad_fn, X, y):
    """Run 50 gradient-descent steps and return the updated parameters.

    Returns `(params, loss, num_examples)`, where `loss` is `loss_fn`
    evaluated on `(X, y)` with the final parameters and `num_examples` is
    `X.shape[0]`.
    """
    num_examples = X.shape[0]
    for _ in range(50):
        grads = grad_fn(params, X, y)
        # Step every leaf of the parameter tree against its gradient
        params = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, params, grads)
    # Only the loss after the last step is returned, so compute it once
    loss = loss_fn(params, X, y)
    return params, loss, num_examples
39
+
40
+
41
def evaluation(params, grad_fn, X_test, y_test):
    """Return the test loss and the number of test examples.

    `grad_fn` is accepted for signature compatibility but is not used.
    """
    num_examples = X_test.shape[0]
    # `loss_fn` already returns the mean squared error, so it must not be
    # squared and averaged a second time
    loss_test = loss_fn(params, X_test, y_test)
    return loss_test, num_examples
46
+
47
+
48
def get_params(params):
    """Return the values of `params` as a list of NumPy arrays, in dict order."""
    return [np.array(value) for value in params.values()]
53
+
54
+
55
def set_params(local_params, global_params):
    """Overwrite the values of `local_params` in place with `global_params`.

    Values are paired with the existing keys of `local_params` in order;
    nothing is returned.
    """
    # Only existing keys are reassigned, so the dict never changes size while
    # its keys are being iterated and no intermediate list is needed
    for name, value in zip(local_params.keys(), global_params):
        local_params[name] = value
@@ -0,0 +1,102 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+ import numpy as np
6
+ from datasets.utils.logging import disable_progress_bar
7
+ from flwr_datasets import FederatedDataset
8
+ from flwr_datasets.partitioner import IidPartitioner
9
+
10
+
11
+ disable_progress_bar()
12
+
13
+
14
+ class MLP(nn.Module):
15
+ """A simple MLP."""
16
+
17
+ def __init__(
18
+ self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
19
+ ):
20
+ super().__init__()
21
+ layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
22
+ self.layers = [
23
+ nn.Linear(idim, odim)
24
+ for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
25
+ ]
26
+
27
+ def __call__(self, x):
28
+ for l in self.layers[:-1]:
29
+ x = mx.maximum(l(x), 0.0)
30
+ return self.layers[-1](x)
31
+
32
+
33
def loss_fn(model, X, y):
    """Return the mean cross-entropy of `model(X)` against `y`."""
    logits = model(X)
    per_example = nn.losses.cross_entropy(logits, y)
    return mx.mean(per_example)
35
+
36
+
37
def eval_fn(model, X, y):
    """Return the mean of `argmax(model(X), axis=1) == y`."""
    predicted = mx.argmax(model(X), axis=1)
    return mx.mean(predicted == y)
39
+
40
+
41
def batch_iterate(batch_size, X, y):
    """Yield `(X, y)` batches of up to `batch_size` items in random order."""
    # Draw one random ordering of the indices and slice batches from it
    order = mx.array(np.random.permutation(y.size))
    for start in range(0, y.size, batch_size):
        stop = start + batch_size
        ids = order[start:stop]
        yield X[ids], y[ids]
46
+
47
+
48
+ fds = None # Cache FederatedDataset
49
+
50
+
51
def load_data(partition_id: int, num_partitions: int):
    """Load one MNIST partition as train/test images and labels.

    Returns `(train_images, train_labels, test_images, test_labels)`, each
    converted with `mx.array`.
    """
    global fds
    # Create the `FederatedDataset` on first use and reuse it afterwards
    if fds is None:
        fds = FederatedDataset(
            dataset="ylecun/mnist",
            partitioners={"train": IidPartitioner(num_partitions=num_partitions)},
            trust_remote_code=True,
        )

    splits = fds.load_partition(partition_id).train_test_split(
        test_size=0.2, seed=42
    )
    splits["train"].set_format("numpy")
    splits["test"].set_format("numpy")

    def flatten_and_scale(img):
        # Reshape to 28 * 28 columns, cast to float32, and divide by 255
        return {"img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0}

    train_partition = splits["train"].map(flatten_and_scale, input_columns="image")
    test_partition = splits["test"].map(flatten_and_scale, input_columns="image")

    arrays = (
        train_partition["img"],
        train_partition["label"].astype(np.uint32),
        test_partition["img"],
        test_partition["label"].astype(np.uint32),
    )

    train_images, train_labels, test_images, test_labels = map(mx.array, arrays)
    return train_images, train_labels, test_images, test_labels
89
+
90
+
91
def get_params(model):
    """Return the values of every entry in `model.parameters()["layers"]` as NumPy arrays."""
    arrays = []
    for layer in model.parameters()["layers"]:
        arrays.extend(np.array(value) for value in layer.values())
    return arrays
94
+
95
+
96
def set_params(model, parameters):
    """Update `model` from a flat list laid out as weight, bias, weight, bias, ..."""
    layers = []
    # Consume the flat list two entries at a time: a weight, then its bias
    for i in range(0, len(parameters), 2):
        weight, bias = parameters[i], parameters[i + 1]
        layers.append({"weight": mx.array(weight), "bias": mx.array(bias)})
    model.update({"layers": layers})
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / PyTorch app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from collections import OrderedDict
4
4
 
@@ -6,11 +6,9 @@ import torch
6
6
  import torch.nn as nn
7
7
  import torch.nn.functional as F
8
8
  from torch.utils.data import DataLoader
9
- from torchvision.datasets import CIFAR10
10
9
  from torchvision.transforms import Compose, Normalize, ToTensor
11
10
  from flwr_datasets import FederatedDataset
12
-
13
- DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11
+ from flwr_datasets.partitioner import IidPartitioner
14
12
 
15
13
 
16
14
  class Net(nn.Module):
@@ -34,12 +32,22 @@ class Net(nn.Module):
34
32
  return self.fc3(x)
35
33
 
36
34
 
37
- def load_data(partition_id, num_partitions):
35
+ fds = None # Cache FederatedDataset
36
+
37
+
38
+ def load_data(partition_id: int, num_partitions: int):
38
39
  """Load partition CIFAR10 data."""
39
- fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
40
+ # Only initialize `FederatedDataset` once
41
+ global fds
42
+ if fds is None:
43
+ partitioner = IidPartitioner(num_partitions=num_partitions)
44
+ fds = FederatedDataset(
45
+ dataset="uoft-cs/cifar10",
46
+ partitioners={"train": partitioner},
47
+ )
40
48
  partition = fds.load_partition(partition_id)
41
49
  # Divide data on each node: 80% train, 20% test
42
- partition_train_test = partition.train_test_split(test_size=0.2)
50
+ partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
43
51
  pytorch_transforms = Compose(
44
52
  [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
45
53
  )
@@ -55,44 +63,41 @@ def load_data(partition_id, num_partitions):
55
63
  return trainloader, testloader
56
64
 
57
65
 
58
- def train(net, trainloader, valloader, epochs, device):
66
+ def train(net, trainloader, epochs, device):
59
67
  """Train the model on the training set."""
60
68
  net.to(device) # move model to GPU if available
61
69
  criterion = torch.nn.CrossEntropyLoss().to(device)
62
- optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
70
+ optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
63
71
  net.train()
72
+ running_loss = 0.0
64
73
  for _ in range(epochs):
65
74
  for batch in trainloader:
66
75
  images = batch["img"]
67
76
  labels = batch["label"]
68
77
  optimizer.zero_grad()
69
- criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
78
+ loss = criterion(net(images.to(device)), labels.to(device))
79
+ loss.backward()
70
80
  optimizer.step()
81
+ running_loss += loss.item()
71
82
 
72
- train_loss, train_acc = test(net, trainloader)
73
- val_loss, val_acc = test(net, valloader)
74
-
75
- results = {
76
- "train_loss": train_loss,
77
- "train_accuracy": train_acc,
78
- "val_loss": val_loss,
79
- "val_accuracy": val_acc,
80
- }
81
- return results
83
+ avg_trainloss = running_loss / len(trainloader)
84
+ return avg_trainloss
82
85
 
83
86
 
84
- def test(net, testloader):
87
+ def test(net, testloader, device):
85
88
  """Validate the model on the test set."""
89
+ net.to(device)
86
90
  criterion = torch.nn.CrossEntropyLoss()
87
91
  correct, loss = 0, 0.0
88
92
  with torch.no_grad():
89
93
  for batch in testloader:
90
- images = batch["img"].to(DEVICE)
91
- labels = batch["label"].to(DEVICE)
94
+ images = batch["img"].to(device)
95
+ labels = batch["label"].to(device)
92
96
  outputs = net(images)
93
97
  loss += criterion(outputs, labels).item()
94
98
  correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
95
99
  accuracy = correct / len(testloader.dataset)
100
+ loss = loss / len(testloader)
96
101
  return loss, accuracy
97
102
 
98
103