flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (237) hide show
  1. flwr/cli/app.py +7 -0
  2. flwr/cli/build.py +150 -0
  3. flwr/cli/config_utils.py +219 -0
  4. flwr/cli/example.py +3 -1
  5. flwr/cli/install.py +227 -0
  6. flwr/cli/new/new.py +179 -48
  7. flwr/cli/new/templates/app/.gitignore.tpl +160 -0
  8. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  9. flwr/cli/new/templates/app/README.md.tpl +1 -5
  10. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  11. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
  12. flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
  13. flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
  14. flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
  15. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
  16. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
  17. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
  18. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  19. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
  20. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
  21. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  22. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  24. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  26. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  27. flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
  28. flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
  29. flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
  30. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  31. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
  32. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  33. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
  34. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  35. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  36. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
  37. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  38. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
  39. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  40. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
  41. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  42. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  43. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  46. flwr/cli/run/run.py +168 -17
  47. flwr/cli/utils.py +75 -4
  48. flwr/client/__init__.py +6 -1
  49. flwr/client/app.py +239 -248
  50. flwr/client/client_app.py +70 -9
  51. flwr/client/dpfedavg_numpy_client.py +1 -1
  52. flwr/client/grpc_adapter_client/__init__.py +15 -0
  53. flwr/client/grpc_adapter_client/connection.py +97 -0
  54. flwr/client/grpc_client/connection.py +18 -5
  55. flwr/client/grpc_rere_client/__init__.py +1 -1
  56. flwr/client/grpc_rere_client/client_interceptor.py +158 -0
  57. flwr/client/grpc_rere_client/connection.py +127 -33
  58. flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
  59. flwr/client/heartbeat.py +74 -0
  60. flwr/client/message_handler/__init__.py +1 -1
  61. flwr/client/message_handler/message_handler.py +7 -7
  62. flwr/client/mod/__init__.py +5 -5
  63. flwr/client/mod/centraldp_mods.py +4 -2
  64. flwr/client/mod/comms_mods.py +4 -4
  65. flwr/client/mod/localdp_mod.py +9 -4
  66. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  67. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  68. flwr/client/mod/utils.py +1 -1
  69. flwr/client/node_state.py +60 -10
  70. flwr/client/node_state_tests.py +4 -3
  71. flwr/client/rest_client/__init__.py +1 -1
  72. flwr/client/rest_client/connection.py +177 -157
  73. flwr/client/supernode/__init__.py +26 -0
  74. flwr/client/supernode/app.py +464 -0
  75. flwr/client/typing.py +1 -0
  76. flwr/common/__init__.py +13 -11
  77. flwr/common/address.py +1 -1
  78. flwr/common/config.py +193 -0
  79. flwr/common/constant.py +42 -1
  80. flwr/common/context.py +26 -1
  81. flwr/common/date.py +1 -1
  82. flwr/common/dp.py +1 -1
  83. flwr/common/grpc.py +6 -2
  84. flwr/common/logger.py +79 -8
  85. flwr/common/message.py +167 -105
  86. flwr/common/object_ref.py +126 -25
  87. flwr/common/record/__init__.py +1 -1
  88. flwr/common/record/parametersrecord.py +0 -1
  89. flwr/common/record/recordset.py +78 -27
  90. flwr/common/recordset_compat.py +8 -1
  91. flwr/common/retry_invoker.py +25 -13
  92. flwr/common/secure_aggregation/__init__.py +1 -1
  93. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  94. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  95. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
  96. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  97. flwr/common/secure_aggregation/quantization.py +1 -1
  98. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  99. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  100. flwr/common/serde.py +209 -3
  101. flwr/common/telemetry.py +25 -0
  102. flwr/common/typing.py +38 -0
  103. flwr/common/version.py +14 -0
  104. flwr/proto/clientappio_pb2.py +41 -0
  105. flwr/proto/clientappio_pb2.pyi +110 -0
  106. flwr/proto/clientappio_pb2_grpc.py +101 -0
  107. flwr/proto/clientappio_pb2_grpc.pyi +40 -0
  108. flwr/proto/common_pb2.py +36 -0
  109. flwr/proto/common_pb2.pyi +121 -0
  110. flwr/proto/common_pb2_grpc.py +4 -0
  111. flwr/proto/common_pb2_grpc.pyi +4 -0
  112. flwr/proto/driver_pb2.py +26 -19
  113. flwr/proto/driver_pb2.pyi +34 -0
  114. flwr/proto/driver_pb2_grpc.py +70 -0
  115. flwr/proto/driver_pb2_grpc.pyi +28 -0
  116. flwr/proto/exec_pb2.py +43 -0
  117. flwr/proto/exec_pb2.pyi +95 -0
  118. flwr/proto/exec_pb2_grpc.py +101 -0
  119. flwr/proto/exec_pb2_grpc.pyi +41 -0
  120. flwr/proto/fab_pb2.py +30 -0
  121. flwr/proto/fab_pb2.pyi +56 -0
  122. flwr/proto/fab_pb2_grpc.py +4 -0
  123. flwr/proto/fab_pb2_grpc.pyi +4 -0
  124. flwr/proto/fleet_pb2.py +29 -23
  125. flwr/proto/fleet_pb2.pyi +33 -0
  126. flwr/proto/fleet_pb2_grpc.py +102 -0
  127. flwr/proto/fleet_pb2_grpc.pyi +35 -0
  128. flwr/proto/grpcadapter_pb2.py +32 -0
  129. flwr/proto/grpcadapter_pb2.pyi +43 -0
  130. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  131. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  132. flwr/proto/message_pb2.py +41 -0
  133. flwr/proto/message_pb2.pyi +122 -0
  134. flwr/proto/message_pb2_grpc.py +4 -0
  135. flwr/proto/message_pb2_grpc.pyi +4 -0
  136. flwr/proto/run_pb2.py +35 -0
  137. flwr/proto/run_pb2.pyi +76 -0
  138. flwr/proto/run_pb2_grpc.py +4 -0
  139. flwr/proto/run_pb2_grpc.pyi +4 -0
  140. flwr/proto/task_pb2.py +7 -8
  141. flwr/proto/task_pb2.pyi +8 -5
  142. flwr/server/__init__.py +4 -8
  143. flwr/server/app.py +298 -350
  144. flwr/server/compat/app.py +6 -57
  145. flwr/server/compat/app_utils.py +5 -4
  146. flwr/server/compat/driver_client_proxy.py +29 -48
  147. flwr/server/compat/legacy_context.py +5 -4
  148. flwr/server/driver/__init__.py +2 -0
  149. flwr/server/driver/driver.py +22 -132
  150. flwr/server/driver/grpc_driver.py +224 -74
  151. flwr/server/driver/inmemory_driver.py +183 -0
  152. flwr/server/history.py +20 -20
  153. flwr/server/run_serverapp.py +121 -34
  154. flwr/server/server.py +11 -7
  155. flwr/server/server_app.py +59 -10
  156. flwr/server/serverapp_components.py +52 -0
  157. flwr/server/strategy/__init__.py +2 -2
  158. flwr/server/strategy/bulyan.py +1 -1
  159. flwr/server/strategy/dp_adaptive_clipping.py +3 -3
  160. flwr/server/strategy/dp_fixed_clipping.py +4 -3
  161. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  162. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  163. flwr/server/strategy/fedadagrad.py +1 -1
  164. flwr/server/strategy/fedadam.py +1 -1
  165. flwr/server/strategy/fedavg_android.py +1 -1
  166. flwr/server/strategy/fedavgm.py +1 -1
  167. flwr/server/strategy/fedmedian.py +1 -1
  168. flwr/server/strategy/fedopt.py +1 -1
  169. flwr/server/strategy/fedprox.py +1 -1
  170. flwr/server/strategy/fedxgb_bagging.py +1 -1
  171. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  172. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  173. flwr/server/strategy/fedyogi.py +1 -1
  174. flwr/server/strategy/krum.py +1 -1
  175. flwr/server/strategy/qfedavg.py +1 -1
  176. flwr/server/superlink/driver/__init__.py +1 -1
  177. flwr/server/superlink/driver/driver_grpc.py +1 -1
  178. flwr/server/superlink/driver/driver_servicer.py +51 -4
  179. flwr/server/superlink/ffs/__init__.py +24 -0
  180. flwr/server/superlink/ffs/disk_ffs.py +104 -0
  181. flwr/server/superlink/ffs/ffs.py +79 -0
  182. flwr/server/superlink/fleet/__init__.py +1 -1
  183. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  184. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  185. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  186. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  187. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
  190. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
  192. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
  193. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  194. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  195. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  196. flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
  197. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  198. flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
  199. flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
  200. flwr/server/superlink/fleet/vce/vce_api.py +190 -127
  201. flwr/server/superlink/state/__init__.py +1 -1
  202. flwr/server/superlink/state/in_memory_state.py +159 -42
  203. flwr/server/superlink/state/sqlite_state.py +243 -39
  204. flwr/server/superlink/state/state.py +81 -6
  205. flwr/server/superlink/state/state_factory.py +11 -2
  206. flwr/server/superlink/state/utils.py +62 -0
  207. flwr/server/typing.py +2 -0
  208. flwr/server/utils/__init__.py +1 -1
  209. flwr/server/utils/tensorboard.py +1 -1
  210. flwr/server/utils/validator.py +23 -9
  211. flwr/server/workflow/default_workflows.py +67 -25
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
  213. flwr/simulation/__init__.py +7 -4
  214. flwr/simulation/app.py +67 -36
  215. flwr/simulation/ray_transport/__init__.py +1 -1
  216. flwr/simulation/ray_transport/ray_actor.py +20 -46
  217. flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
  218. flwr/simulation/run_simulation.py +308 -92
  219. flwr/superexec/__init__.py +21 -0
  220. flwr/superexec/app.py +184 -0
  221. flwr/superexec/deployment.py +185 -0
  222. flwr/superexec/exec_grpc.py +55 -0
  223. flwr/superexec/exec_servicer.py +70 -0
  224. flwr/superexec/executor.py +75 -0
  225. flwr/superexec/simulation.py +193 -0
  226. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
  227. flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
  228. flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
  229. flwr/cli/flower_toml.py +0 -140
  230. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  231. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  232. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  233. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  234. flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
  235. flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
  236. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
  237. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
@@ -0,0 +1,93 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+ import mlx.optimizers as optim
6
+ from flwr.client import NumPyClient, ClientApp
7
+ from flwr.common import Context
8
+
9
+ from $import_name.task import (
10
+ batch_iterate,
11
+ eval_fn,
12
+ get_params,
13
+ load_data,
14
+ loss_fn,
15
+ set_params,
16
+ MLP,
17
+ )
18
+
19
+
20
+ # Define Flower Client and client_fn
21
+ class FlowerClient(NumPyClient):
22
+ def __init__(
23
+ self,
24
+ data,
25
+ num_layers,
26
+ hidden_dim,
27
+ num_classes,
28
+ batch_size,
29
+ learning_rate,
30
+ num_epochs,
31
+ ):
32
+ self.num_layers = num_layers
33
+ self.hidden_dim = hidden_dim
34
+ self.num_classes = num_classes
35
+ self.batch_size = batch_size
36
+ self.learning_rate = learning_rate
37
+ self.num_epochs = num_epochs
38
+
39
+ self.train_images, self.train_labels, self.test_images, self.test_labels = data
40
+ self.model = MLP(
41
+ num_layers, self.train_images.shape[-1], hidden_dim, num_classes
42
+ )
43
+ self.optimizer = optim.SGD(learning_rate=learning_rate)
44
+ self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
45
+ self.num_epochs = num_epochs
46
+ self.batch_size = batch_size
47
+
48
+ def get_parameters(self, config):
49
+ return get_params(self.model)
50
+
51
+ def set_parameters(self, parameters):
52
+ set_params(self.model, parameters)
53
+
54
+ def fit(self, parameters, config):
55
+ self.set_parameters(parameters)
56
+ for _ in range(self.num_epochs):
57
+ for X, y in batch_iterate(
58
+ self.batch_size, self.train_images, self.train_labels
59
+ ):
60
+ _, grads = self.loss_and_grad_fn(self.model, X, y)
61
+ self.optimizer.update(self.model, grads)
62
+ mx.eval(self.model.parameters(), self.optimizer.state)
63
+ return self.get_parameters(config={}), len(self.train_images), {}
64
+
65
+ def evaluate(self, parameters, config):
66
+ self.set_parameters(parameters)
67
+ accuracy = eval_fn(self.model, self.test_images, self.test_labels)
68
+ loss = loss_fn(self.model, self.test_images, self.test_labels)
69
+ return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
70
+
71
+
72
+ def client_fn(context: Context):
73
+ partition_id = context.node_config["partition-id"]
74
+ num_partitions = context.node_config["num-partitions"]
75
+ data = load_data(partition_id, num_partitions)
76
+
77
+ num_layers = context.run_config["num-layers"]
78
+ hidden_dim = context.run_config["hidden-dim"]
79
+ num_classes = 10
80
+ batch_size = context.run_config["batch-size"]
81
+ learning_rate = context.run_config["lr"]
82
+ num_epochs = context.run_config["local-epochs"]
83
+
84
+ # Return Client instance
85
+ return FlowerClient(
86
+ data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs
87
+ ).to_client()
88
+
89
+
90
+ # Flower ClientApp
91
+ app = ClientApp(
92
+ client_fn,
93
+ )
@@ -1,6 +1,7 @@
1
- """$project_name: A Flower / NumPy app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from flwr.client import NumPyClient, ClientApp
4
+ from flwr.common import Context
4
5
  import numpy as np
5
6
 
6
7
 
@@ -15,7 +16,7 @@ class FlowerClient(NumPyClient):
15
16
  return float(0.0), 1, {"accuracy": float(1.0)}
16
17
 
17
18
 
18
- def client_fn(cid: str):
19
+ def client_fn(context: Context):
19
20
  return FlowerClient().to_client()
20
21
 
21
22
 
@@ -1,10 +1,11 @@
1
- """$project_name: A Flower / PyTorch app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
+ import torch
3
4
  from flwr.client import NumPyClient, ClientApp
5
+ from flwr.common import Context
4
6
 
5
- from $project_name.task import (
7
+ from $import_name.task import (
6
8
  Net,
7
- DEVICE,
8
9
  load_data,
9
10
  get_weights,
10
11
  set_weights,
@@ -15,29 +16,40 @@ from $project_name.task import (
15
16
 
16
17
  # Define Flower Client and client_fn
17
18
  class FlowerClient(NumPyClient):
18
- def __init__(self, net, trainloader, valloader):
19
+ def __init__(self, net, trainloader, valloader, local_epochs):
19
20
  self.net = net
20
21
  self.trainloader = trainloader
21
22
  self.valloader = valloader
23
+ self.local_epochs = local_epochs
24
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
+ self.net.to(self.device)
22
26
 
23
27
  def fit(self, parameters, config):
24
28
  set_weights(self.net, parameters)
25
- results = train(self.net, self.trainloader, self.valloader, 1, DEVICE)
26
- return get_weights(self.net), len(self.trainloader.dataset), results
29
+ train_loss = train(
30
+ self.net,
31
+ self.trainloader,
32
+ self.local_epochs,
33
+ self.device,
34
+ )
35
+ return get_weights(self.net), len(self.trainloader.dataset), {"train_loss": train_loss}
27
36
 
28
37
  def evaluate(self, parameters, config):
29
38
  set_weights(self.net, parameters)
30
- loss, accuracy = test(self.net, self.valloader)
39
+ loss, accuracy = test(self.net, self.valloader, self.device)
31
40
  return loss, len(self.valloader.dataset), {"accuracy": accuracy}
32
41
 
33
42
 
34
- def client_fn(cid):
43
+ def client_fn(context: Context):
35
44
  # Load model and data
36
- net = Net().to(DEVICE)
37
- trainloader, valloader = load_data(int(cid), 2)
45
+ net = Net()
46
+ partition_id = context.node_config["partition-id"]
47
+ num_partitions = context.node_config["num-partitions"]
48
+ trainloader, valloader = load_data(partition_id, num_partitions)
49
+ local_epochs = context.run_config["local-epochs"]
38
50
 
39
51
  # Return Client instance
40
- return FlowerClient(net, trainloader, valloader).to_client()
52
+ return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
41
53
 
42
54
 
43
55
  # Flower ClientApp
@@ -0,0 +1,97 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import warnings
4
+
5
+ import numpy as np
6
+ from flwr.client import NumPyClient, ClientApp
7
+ from flwr.common import Context
8
+ from flwr_datasets import FederatedDataset
9
+ from sklearn.linear_model import LogisticRegression
10
+ from sklearn.metrics import log_loss
11
+
12
+
13
+ def get_model_parameters(model):
14
+ if model.fit_intercept:
15
+ params = [
16
+ model.coef_,
17
+ model.intercept_,
18
+ ]
19
+ else:
20
+ params = [model.coef_]
21
+ return params
22
+
23
+
24
+ def set_model_params(model, params):
25
+ model.coef_ = params[0]
26
+ if model.fit_intercept:
27
+ model.intercept_ = params[1]
28
+ return model
29
+
30
+
31
+ def set_initial_params(model):
32
+ n_classes = 10 # MNIST has 10 classes
33
+ n_features = 784 # Number of features in dataset
34
+ model.classes_ = np.array([i for i in range(10)])
35
+
36
+ model.coef_ = np.zeros((n_classes, n_features))
37
+ if model.fit_intercept:
38
+ model.intercept_ = np.zeros((n_classes,))
39
+
40
+
41
+ class FlowerClient(NumPyClient):
42
+ def __init__(self, model, X_train, X_test, y_train, y_test):
43
+ self.model = model
44
+ self.X_train = X_train
45
+ self.X_test = X_test
46
+ self.y_train = y_train
47
+ self.y_test = y_test
48
+
49
+ def get_parameters(self, config):
50
+ return get_model_parameters(self.model)
51
+
52
+ def fit(self, parameters, config):
53
+ set_model_params(self.model, parameters)
54
+
55
+ # Ignore convergence failure due to low local epochs
56
+ with warnings.catch_warnings():
57
+ warnings.simplefilter("ignore")
58
+ self.model.fit(self.X_train, self.y_train)
59
+
60
+ return get_model_parameters(self.model), len(self.X_train), {}
61
+
62
+ def evaluate(self, parameters, config):
63
+ set_model_params(self.model, parameters)
64
+
65
+ loss = log_loss(self.y_test, self.model.predict_proba(self.X_test))
66
+ accuracy = self.model.score(self.X_test, self.y_test)
67
+
68
+ return loss, len(self.X_test), {"accuracy": accuracy}
69
+
70
+
71
+ def client_fn(context: Context):
72
+ partition_id = context.node_config["partition-id"]
73
+ num_partitions = context.node_config["num-partitions"]
74
+ fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
75
+ dataset = fds.load_partition(partition_id, "train").with_format("numpy")
76
+
77
+ X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
78
+
79
+ # Split the on edge data: 80% train, 20% test
80
+ X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
81
+ y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
82
+
83
+ # Create LogisticRegression Model
84
+ model = LogisticRegression(
85
+ penalty="l2",
86
+ max_iter=1, # local epoch
87
+ warm_start=True, # prevent refreshing weights when fitting
88
+ )
89
+
90
+ # Setting initial parameters, akin to model.compile for keras models
91
+ set_initial_params(model)
92
+
93
+ return FlowerClient(model, X_train, X_test, y_train, y_test).to_client()
94
+
95
+
96
+ # Flower ClientApp
97
+ app = ClientApp(client_fn=client_fn)
@@ -1 +1,60 @@
1
- """$project_name: A Flower / TensorFlow app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ from flwr.client import NumPyClient, ClientApp
4
+ from flwr.common import Context
5
+
6
+ from $import_name.task import load_data, load_model
7
+
8
+
9
+ # Define Flower Client and client_fn
10
+ class FlowerClient(NumPyClient):
11
+ def __init__(
12
+ self, model, data, epochs, batch_size, verbose
13
+ ):
14
+ self.model = model
15
+ self.x_train, self.y_train, self.x_test, self.y_test = data
16
+ self.epochs = epochs
17
+ self.batch_size = batch_size
18
+ self.verbose = verbose
19
+
20
+ def get_parameters(self, config):
21
+ return self.model.get_weights()
22
+
23
+ def fit(self, parameters, config):
24
+ self.model.set_weights(parameters)
25
+ self.model.fit(
26
+ self.x_train,
27
+ self.y_train,
28
+ epochs=self.epochs,
29
+ batch_size=self.batch_size,
30
+ verbose=self.verbose,
31
+ )
32
+ return self.model.get_weights(), len(self.x_train), {}
33
+
34
+ def evaluate(self, parameters, config):
35
+ self.model.set_weights(parameters)
36
+ loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
37
+ return loss, len(self.x_test), {"accuracy": accuracy}
38
+
39
+
40
+ def client_fn(context: Context):
41
+ # Load model and data
42
+ net = load_model()
43
+
44
+ partition_id = context.node_config["partition-id"]
45
+ num_partitions = context.node_config["num-partitions"]
46
+ data = load_data(partition_id, num_partitions)
47
+ epochs = context.run_config["local-epochs"]
48
+ batch_size = context.run_config["batch-size"]
49
+ verbose = context.run_config.get("verbose")
50
+
51
+ # Return Client instance
52
+ return FlowerClient(
53
+ net, data, epochs, batch_size, verbose
54
+ ).to_client()
55
+
56
+
57
+ # Flower ClientApp
58
+ app = ClientApp(
59
+ client_fn=client_fn,
60
+ )
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower CLI `new` command app / code / flwr_tune templates."""
@@ -0,0 +1,89 @@
1
+ """$project_name: A Flower / FlowerTune app."""
2
+
3
+ import os
4
+ import warnings
5
+ from datetime import datetime
6
+
7
+ from flwr_datasets import FederatedDataset
8
+ from hydra import compose, initialize
9
+ from hydra.utils import instantiate
10
+
11
+ from flwr.client import ClientApp
12
+ from flwr.common import Context, ndarrays_to_parameters
13
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
14
+
15
+ from $import_name.client_app import gen_client_fn, get_parameters
16
+ from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
17
+ from $import_name.models import get_model
18
+ from $import_name.server_app import fit_weighted_average, get_evaluate_fn, get_on_fit_config
19
+
20
+ # Avoid warnings
21
+ warnings.filterwarnings("ignore", category=UserWarning)
22
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
23
+ os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
24
+
25
+ # Initialise regular config
26
+ with initialize(config_path="conf", version_base="1.1"):
27
+ cfg = compose(config_name="config")
28
+
29
+ # Initialise static config
30
+ with initialize(config_path="conf", version_base="1.1"):
31
+ cfg_static = compose(config_name="static_config")
32
+
33
+ cfg.train.num_rounds = cfg_static.num_rounds
34
+
35
+ # Create output directory given current timestamp
36
+ current_time = datetime.now()
37
+ folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
38
+ save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
39
+ os.makedirs(save_path, exist_ok=True)
40
+
41
+ # Partition dataset and get dataloaders
42
+ partitioner = instantiate(cfg_static.partitioner)
43
+ fds = FederatedDataset(
44
+ dataset=cfg_static.dataset.name, partitioners={"train": partitioner}
45
+ )
46
+ (
47
+ tokenizer,
48
+ data_collator,
49
+ formatting_prompts_func,
50
+ ) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
51
+
52
+ # ClientApp for Flower Next
53
+ client = ClientApp(
54
+ client_fn=gen_client_fn(
55
+ fds,
56
+ tokenizer,
57
+ formatting_prompts_func,
58
+ data_collator,
59
+ cfg.model,
60
+ cfg.train,
61
+ save_path,
62
+ ),
63
+ )
64
+
65
+ # Get initial model weights
66
+ init_model = get_model(cfg.model)
67
+ init_model_parameters = get_parameters(init_model)
68
+ init_model_parameters = ndarrays_to_parameters(init_model_parameters)
69
+
70
+ def server_fn(context: Context):
71
+ # Instantiate strategy according to config. Here we pass other arguments
72
+ # that are only defined at runtime.
73
+ strategy = instantiate(
74
+ cfg.strategy,
75
+ on_fit_config_fn=get_on_fit_config(),
76
+ fit_metrics_aggregation_fn=fit_weighted_average,
77
+ initial_parameters=init_model_parameters,
78
+ evaluate_fn=get_evaluate_fn(
79
+ cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
80
+ ),
81
+ )
82
+
83
+ config = ServerConfig(num_rounds=cfg_static.num_rounds)
84
+
85
+ return ServerAppComponents(strategy=strategy, config=config)
86
+
87
+
88
+ # ServerApp for Flower Next
89
+ server = ServerApp(server_fn=server_fn)
@@ -0,0 +1,126 @@
1
+ """$project_name: A Flower / FlowerTune app."""
2
+
3
+ from collections import OrderedDict
4
+ from typing import Callable, Dict, Tuple
5
+
6
+ import torch
7
+ from omegaconf import DictConfig
8
+ from peft import get_peft_model_state_dict, set_peft_model_state_dict
9
+ from transformers import TrainingArguments
10
+ from trl import SFTTrainer
11
+
12
+ from flwr.client import NumPyClient
13
+ from flwr.common import Context
14
+ from flwr.common.typing import NDArrays, Scalar
15
+ from $import_name.dataset import reformat
16
+ from $import_name.models import cosine_annealing, get_model
17
+
18
+
19
+ # pylint: disable=too-many-arguments
20
+ # pylint: disable=too-many-instance-attributes
21
+ class FlowerClient(NumPyClient):
22
+ """Standard Flower client for CNN training."""
23
+
24
+ def __init__(
25
+ self,
26
+ model_cfg: DictConfig,
27
+ train_cfg: DictConfig,
28
+ trainset,
29
+ tokenizer,
30
+ formatting_prompts_func,
31
+ data_collator,
32
+ save_path,
33
+ ): # pylint: disable=too-many-arguments
34
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
35
+ self.train_cfg = train_cfg
36
+ self.training_argumnets = TrainingArguments(**train_cfg.training_arguments)
37
+ self.tokenizer = tokenizer
38
+ self.formatting_prompts_func = formatting_prompts_func
39
+ self.data_collator = data_collator
40
+ self.save_path = save_path
41
+
42
+ # instantiate model
43
+ self.model = get_model(model_cfg)
44
+
45
+ self.trainset = trainset
46
+
47
+ def fit(
48
+ self, parameters: NDArrays, config: Dict[str, Scalar]
49
+ ) -> Tuple[NDArrays, int, Dict]:
50
+ """Implement distributed fit function for a given client."""
51
+ set_parameters(self.model, parameters)
52
+
53
+ new_lr = cosine_annealing(
54
+ int(config["current_round"]),
55
+ self.train_cfg.num_rounds,
56
+ self.train_cfg.learning_rate_max,
57
+ self.train_cfg.learning_rate_min,
58
+ )
59
+
60
+ self.training_argumnets.learning_rate = new_lr
61
+ self.training_argumnets.output_dir = self.save_path
62
+
63
+ # Construct trainer
64
+ trainer = SFTTrainer(
65
+ model=self.model,
66
+ tokenizer=self.tokenizer,
67
+ args=self.training_argumnets,
68
+ max_seq_length=self.train_cfg.seq_length,
69
+ train_dataset=self.trainset,
70
+ formatting_func=self.formatting_prompts_func,
71
+ data_collator=self.data_collator,
72
+ )
73
+
74
+ # Do local training
75
+ results = trainer.train()
76
+
77
+ return (
78
+ get_parameters(self.model),
79
+ len(self.trainset),
80
+ {"train_loss": results.training_loss},
81
+ )
82
+
83
+
84
+ def set_parameters(model, parameters: NDArrays) -> None:
85
+ """Change the parameters of the model using the given ones."""
86
+ peft_state_dict_keys = get_peft_model_state_dict(model).keys()
87
+ params_dict = zip(peft_state_dict_keys, parameters)
88
+ state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
89
+ set_peft_model_state_dict(model, state_dict)
90
+
91
+
92
+ def get_parameters(model) -> NDArrays:
93
+ """Return the parameters of the current net."""
94
+ state_dict = get_peft_model_state_dict(model)
95
+ return [val.cpu().numpy() for _, val in state_dict.items()]
96
+
97
+
98
+ def gen_client_fn(
99
+ fds,
100
+ tokenizer,
101
+ formatting_prompts_func,
102
+ data_collator,
103
+ model_cfg: DictConfig,
104
+ train_cfg: DictConfig,
105
+ save_path: str,
106
+ ) -> Callable[[Context], FlowerClient]: # pylint: disable=too-many-arguments
107
+ """Generate the client function that creates the Flower Clients."""
108
+
109
+ def client_fn(context: Context) -> FlowerClient:
110
+ """Create a Flower client representing a single organization."""
111
+ # Let's get the partition corresponding to the i-th client
112
+ partition_id = context.node_config["partition-id"]
113
+ client_trainset = fds.load_partition(partition_id, "train")
114
+ client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")
115
+
116
+ return FlowerClient(
117
+ model_cfg,
118
+ train_cfg,
119
+ client_trainset,
120
+ tokenizer,
121
+ formatting_prompts_func,
122
+ data_collator,
123
+ save_path,
124
+ ).to_client()
125
+
126
+ return client_fn
@@ -0,0 +1,34 @@
1
+ # Federated Instruction Tuning
2
+ ---
3
+ model:
4
+ name: "mistralai/Mistral-7B-v0.3"
5
+ quantization: 4 # 8 or 4 if you want to do quantization with BitsAndBytes
6
+ gradient_checkpointing: True
7
+ lora:
8
+ peft_lora_r: 32
9
+ peft_lora_alpha: 64
10
+
11
+ train:
12
+ num_rounds: null
13
+ save_every_round: 5
14
+ learning_rate_max: 5e-5
15
+ learning_rate_min: 1e-6
16
+ seq_length: 512
17
+ training_arguments:
18
+ output_dir: null # to be set by hydra
19
+ learning_rate: null # to be set by the client
20
+ per_device_train_batch_size: 16
21
+ gradient_accumulation_steps: 1
22
+ logging_steps: 10
23
+ num_train_epochs: 3
24
+ max_steps: 10
25
+ report_to: null
26
+ save_steps: 1000
27
+ save_total_limit: 10
28
+ gradient_checkpointing: True
29
+ lr_scheduler_type: "constant"
30
+
31
+ strategy:
32
+ _target_: flwr.server.strategy.FedAvg
33
+ fraction_fit: $fraction_fit
34
+ fraction_evaluate: 0.0 # no client evaluation
@@ -0,0 +1,57 @@
1
+ """$project_name: A Flower / FlowerTune app."""
2
+
3
+ from transformers import AutoTokenizer
4
+ from trl import DataCollatorForCompletionOnlyLM
5
+
6
+
7
+ def formatting_prompts_func(example):
8
+ """Construct prompts."""
9
+ output_texts = []
10
+ # Constructing a standard Alpaca
11
+ # (https://github.com/tatsu-lab/stanford_alpaca#data-release) prompt
12
+ mssg = (
13
+ "Below is an instruction that describes a task. "
14
+ "Write a response that appropriately completes the request."
15
+ )
16
+ for i in range(len(example["instruction"])):
17
+ text = (
18
+ f"{mssg}\n### Instruction:\n{example['instruction'][i]}\n"
19
+ f"### Response: {example['response'][i]}"
20
+ )
21
+ output_texts.append(text)
22
+ return output_texts
23
+
24
+
25
+ def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
26
+ """Get tokenizer, data_collator and prompt formatting."""
27
+ # From: https://huggingface.co/docs/trl/en/sft_trainer
28
+ tokenizer = AutoTokenizer.from_pretrained(
29
+ model_name, use_fast=True, padding_side="right"
30
+ )
31
+ tokenizer.pad_token = tokenizer.eos_token
32
+ response_template_with_context = "\n### Response:" # alpaca response tag
33
+ response_template_ids = tokenizer.encode(
34
+ response_template_with_context, add_special_tokens=False
35
+ )[2:]
36
+ data_collator = DataCollatorForCompletionOnlyLM(
37
+ response_template_ids, tokenizer=tokenizer
38
+ )
39
+
40
+ return tokenizer, data_collator, formatting_prompts_func
41
+
42
+
43
+ def formatting(dataset):
44
+ """Format dataset."""
45
+ dataset["instruction"] = dataset["instruction"] + " " + dataset["input"]
46
+ return dataset
47
+
48
+
49
+ def reformat(dataset, llm_task):
50
+ """Reformat datasets."""
51
+ dataset = dataset.rename_column("output", "response")
52
+ if llm_task == "finance" or llm_task == "code":
53
+ dataset = dataset.map(formatting, remove_columns=["input"])
54
+ if llm_task == "medical":
55
+ dataset = dataset.remove_columns(["instruction"])
56
+ dataset = dataset.rename_column("input", "instruction")
57
+ return dataset