flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182) hide show
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/app.py +2 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  6. flwr/cli/cli_user_auth_interceptor.py +1 -1
  7. flwr/cli/config_utils.py +3 -3
  8. flwr/cli/constant.py +25 -8
  9. flwr/cli/log.py +9 -9
  10. flwr/cli/login/login.py +3 -3
  11. flwr/cli/ls.py +5 -5
  12. flwr/cli/new/new.py +15 -2
  13. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  14. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
  15. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  16. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  17. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  18. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  19. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  20. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  21. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  22. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  23. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  24. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  25. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  26. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  27. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  28. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  29. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  30. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  31. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  32. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  33. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  34. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  35. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  36. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  37. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  38. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  39. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  40. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  41. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  42. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  43. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  44. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
  45. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  46. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  47. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  49. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  50. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  52. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  53. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  54. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
  55. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  56. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  57. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  58. flwr/cli/pull.py +100 -0
  59. flwr/cli/run/run.py +9 -13
  60. flwr/cli/stop.py +7 -4
  61. flwr/cli/utils.py +36 -8
  62. flwr/client/grpc_rere_client/connection.py +1 -12
  63. flwr/client/rest_client/connection.py +3 -0
  64. flwr/clientapp/__init__.py +10 -0
  65. flwr/clientapp/mod/__init__.py +29 -0
  66. flwr/clientapp/mod/centraldp_mods.py +248 -0
  67. flwr/clientapp/mod/localdp_mod.py +169 -0
  68. flwr/clientapp/typing.py +22 -0
  69. flwr/common/args.py +20 -6
  70. flwr/common/auth_plugin/__init__.py +4 -4
  71. flwr/common/auth_plugin/auth_plugin.py +7 -7
  72. flwr/common/constant.py +26 -4
  73. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  74. flwr/common/exit/__init__.py +4 -0
  75. flwr/common/exit/exit.py +8 -1
  76. flwr/common/exit/exit_code.py +30 -7
  77. flwr/common/exit/exit_handler.py +62 -0
  78. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  79. flwr/common/grpc.py +0 -11
  80. flwr/common/inflatable_utils.py +1 -1
  81. flwr/common/logger.py +1 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/retry_invoker.py +30 -11
  84. flwr/common/telemetry.py +4 -0
  85. flwr/compat/server/app.py +2 -2
  86. flwr/proto/appio_pb2.py +25 -17
  87. flwr/proto/appio_pb2.pyi +46 -2
  88. flwr/proto/clientappio_pb2.py +3 -11
  89. flwr/proto/clientappio_pb2.pyi +0 -47
  90. flwr/proto/clientappio_pb2_grpc.py +19 -20
  91. flwr/proto/clientappio_pb2_grpc.pyi +10 -11
  92. flwr/proto/control_pb2.py +66 -0
  93. flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
  94. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
  95. flwr/proto/control_pb2_grpc.pyi +106 -0
  96. flwr/proto/serverappio_pb2.py +2 -2
  97. flwr/proto/serverappio_pb2_grpc.py +68 -0
  98. flwr/proto/serverappio_pb2_grpc.pyi +26 -0
  99. flwr/proto/simulationio_pb2.py +4 -11
  100. flwr/proto/simulationio_pb2.pyi +0 -58
  101. flwr/proto/simulationio_pb2_grpc.py +129 -27
  102. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  103. flwr/server/app.py +142 -152
  104. flwr/server/grid/grpc_grid.py +3 -0
  105. flwr/server/grid/inmemory_grid.py +1 -0
  106. flwr/server/serverapp/app.py +157 -146
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  110. flwr/server/superlink/linkstate/linkstate.py +2 -1
  111. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  112. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
  113. flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
  114. flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
  115. flwr/serverapp/__init__.py +12 -0
  116. flwr/serverapp/exception.py +38 -0
  117. flwr/serverapp/strategy/__init__.py +64 -0
  118. flwr/serverapp/strategy/bulyan.py +238 -0
  119. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  120. flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
  121. flwr/serverapp/strategy/fedadagrad.py +159 -0
  122. flwr/serverapp/strategy/fedadam.py +178 -0
  123. flwr/serverapp/strategy/fedavg.py +320 -0
  124. flwr/serverapp/strategy/fedavgm.py +198 -0
  125. flwr/serverapp/strategy/fedmedian.py +105 -0
  126. flwr/serverapp/strategy/fedopt.py +218 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +170 -0
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/result.py +105 -0
  136. flwr/serverapp/strategy/strategy.py +285 -0
  137. flwr/serverapp/strategy/strategy_utils.py +299 -0
  138. flwr/simulation/app.py +161 -164
  139. flwr/simulation/run_simulation.py +25 -30
  140. flwr/supercore/app_utils.py +58 -0
  141. flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
  142. flwr/supercore/cli/flower_superexec.py +166 -0
  143. flwr/supercore/constant.py +19 -0
  144. flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
  145. flwr/supercore/corestate/corestate.py +81 -0
  146. flwr/supercore/grpc_health/__init__.py +3 -0
  147. flwr/supercore/grpc_health/health_server.py +53 -0
  148. flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
  149. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  150. flwr/supercore/superexec/plugin/__init__.py +28 -0
  151. flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
  152. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  153. flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
  154. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  155. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  156. flwr/supercore/superexec/run_superexec.py +199 -0
  157. flwr/superlink/artifact_provider/__init__.py +22 -0
  158. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  159. flwr/superlink/servicer/__init__.py +15 -0
  160. flwr/superlink/servicer/control/__init__.py +22 -0
  161. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
  162. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
  163. flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
  164. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
  165. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
  166. flwr/supernode/cli/flower_supernode.py +3 -0
  167. flwr/supernode/cli/flwr_clientapp.py +18 -21
  168. flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
  169. flwr/supernode/nodestate/nodestate.py +3 -59
  170. flwr/supernode/runtime/run_clientapp.py +39 -102
  171. flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
  172. flwr/supernode/start_client_internal.py +35 -76
  173. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
  174. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
  175. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
  176. flwr/proto/exec_pb2.py +0 -62
  177. flwr/proto/exec_pb2_grpc.pyi +0 -93
  178. flwr/superexec/app.py +0 -45
  179. flwr/superexec/deployment.py +0 -191
  180. flwr/superexec/executor.py +0 -100
  181. flwr/superexec/simulation.py +0 -129
  182. {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
@@ -1,50 +1,71 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import jax
4
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
+ from flwr.clientapp import ClientApp
4
6
 
5
- from flwr.client import ClientApp, NumPyClient
6
- from flwr.common import Context
7
- from $import_name.task import (
8
- evaluation,
9
- get_params,
10
- load_data,
11
- load_model,
12
- loss_fn,
13
- set_params,
14
- train,
15
- )
16
-
17
-
18
- # Define Flower Client and client_fn
19
- class FlowerClient(NumPyClient):
20
- def __init__(self, input_dim):
21
- self.train_x, self.train_y, self.test_x, self.test_y = load_data()
22
- self.grad_fn = jax.grad(loss_fn)
23
- self.params = load_model((input_dim,))
24
-
25
- def fit(self, parameters, config):
26
- set_params(self.params, parameters)
27
- self.params, loss, num_examples = train(
28
- self.params, self.grad_fn, self.train_x, self.train_y
29
- )
30
- return get_params(self.params), num_examples, {"loss": float(loss)}
31
-
32
- def evaluate(self, parameters, config):
33
- set_params(self.params, parameters)
34
- loss, num_examples = evaluation(
35
- self.params, self.grad_fn, self.test_x, self.test_y
36
- )
37
- return float(loss), num_examples, {"loss": float(loss)}
38
-
39
-
40
- def client_fn(context: Context):
7
+ from $import_name.task import evaluation as evaluation_fn
8
+ from $import_name.task import get_params, load_data, load_model, loss_fn, set_params
9
+ from $import_name.task import train as train_fn
10
+
11
+ # Flower ClientApp
12
+ app = ClientApp()
13
+
14
+
15
+ @app.train()
16
+ def train(msg: Message, context: Context):
17
+ """Train the model on local data."""
18
+
19
+ # Read from config
41
20
  input_dim = context.run_config["input-dim"]
42
21
 
43
- # Return Client instance
44
- return FlowerClient(input_dim).to_client()
22
+ # Load data and model
23
+ train_x, train_y, _, _ = load_data()
24
+ model = load_model((input_dim,))
25
+ grad_fn = jax.grad(loss_fn)
45
26
 
27
+ # Set model parameters
28
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
29
+ set_params(model, ndarrays)
46
30
 
47
- # Flower ClientApp
48
- app = ClientApp(
49
- client_fn,
50
- )
31
+ # Train the model on local data
32
+ model, loss, num_examples = train_fn(model, grad_fn, train_x, train_y)
33
+
34
+ # Construct and return reply Message
35
+ model_record = ArrayRecord(get_params(model))
36
+ metrics = {
37
+ "train_loss": float(loss),
38
+ "num-examples": num_examples,
39
+ }
40
+ metric_record = MetricRecord(metrics)
41
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
42
+ return Message(content=content, reply_to=msg)
43
+
44
+
45
+ @app.evaluate()
46
+ def evaluate(msg: Message, context: Context):
47
+ """Evaluate the model on local data."""
48
+
49
+ # Read from config
50
+ input_dim = context.run_config["input-dim"]
51
+
52
+ # Load data and model
53
+ _, _, test_x, test_y = load_data()
54
+ model = load_model((input_dim,))
55
+ grad_fn = jax.grad(loss_fn)
56
+
57
+ # Set model parameters
58
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
59
+ set_params(model, ndarrays)
60
+
61
+ # Evaluate the model on local data
62
+ loss, num_examples = evaluation_fn(model, grad_fn, test_x, test_y)
63
+
64
+ # Construct and return reply Message
65
+ metrics = {
66
+ "test_loss": float(loss),
67
+ "num-examples": num_examples,
68
+ }
69
+ metric_record = MetricRecord(metrics)
70
+ content = RecordDict({"metrics": metric_record})
71
+ return Message(content=content, reply_to=msg)
@@ -3,10 +3,9 @@
3
3
  import mlx.core as mx
4
4
  import mlx.nn as nn
5
5
  import mlx.optimizers as optim
6
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
7
+ from flwr.clientapp import ClientApp
6
8
 
7
- from flwr.client import ClientApp, NumPyClient
8
- from flwr.common import Context
9
- from flwr.common.config import UserConfig
10
9
  from $import_name.task import (
11
10
  MLP,
12
11
  batch_iterate,
@@ -17,57 +16,87 @@ from $import_name.task import (
17
16
  set_params,
18
17
  )
19
18
 
19
+ # Flower ClientApp
20
+ app = ClientApp()
21
+
22
+
23
+ @app.train()
24
+ def train(msg: Message, context: Context):
25
+ """Train the model on local data."""
26
+
27
+ # Read config
28
+ num_layers = context.run_config["num-layers"]
29
+ input_dim = context.run_config["input-dim"]
30
+ hidden_dim = context.run_config["hidden-dim"]
31
+ batch_size = context.run_config["batch-size"]
32
+ learning_rate = context.run_config["lr"]
33
+ num_epochs = context.run_config["local-epochs"]
20
34
 
21
- # Define Flower Client and client_fn
22
- class FlowerClient(NumPyClient):
23
- def __init__(
24
- self,
25
- data,
26
- run_config: UserConfig,
27
- num_classes,
28
- ):
29
- num_layers = run_config["num-layers"]
30
- hidden_dim = run_config["hidden-dim"]
31
- input_dim = run_config["input-dim"]
32
- batch_size = run_config["batch-size"]
33
- learning_rate = run_config["lr"]
34
- self.num_epochs = run_config["local-epochs"]
35
-
36
- self.train_images, self.train_labels, self.test_images, self.test_labels = data
37
- self.model = MLP(num_layers, input_dim, hidden_dim, num_classes)
38
- self.optimizer = optim.SGD(learning_rate=learning_rate)
39
- self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
40
- self.batch_size = batch_size
41
-
42
- def fit(self, parameters, config):
43
- set_params(self.model, parameters)
44
- for _ in range(self.num_epochs):
45
- for X, y in batch_iterate(
46
- self.batch_size, self.train_images, self.train_labels
47
- ):
48
- _, grads = self.loss_and_grad_fn(self.model, X, y)
49
- self.optimizer.update(self.model, grads)
50
- mx.eval(self.model.parameters(), self.optimizer.state)
51
- return get_params(self.model), len(self.train_images), {}
52
-
53
- def evaluate(self, parameters, config):
54
- set_params(self.model, parameters)
55
- accuracy = eval_fn(self.model, self.test_images, self.test_labels)
56
- loss = loss_fn(self.model, self.test_images, self.test_labels)
57
- return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
58
-
59
-
60
- def client_fn(context: Context):
35
+ # Instantiate model and apply global parameters
36
+ model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
37
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
38
+ set_params(model, ndarrays)
39
+
40
+ # Define optimizer and loss function
41
+ optimizer = optim.SGD(learning_rate=learning_rate)
42
+ loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
43
+
44
+ # Load data
61
45
  partition_id = context.node_config["partition-id"]
62
46
  num_partitions = context.node_config["num-partitions"]
63
- data = load_data(partition_id, num_partitions)
64
- num_classes = 10
47
+ train_images, train_labels, _, _ = load_data(partition_id, num_partitions)
65
48
 
66
- # Return Client instance
67
- return FlowerClient(data, context.run_config, num_classes).to_client()
49
+ # Train the model on local data
50
+ for _ in range(num_epochs):
51
+ for X, y in batch_iterate(batch_size, train_images, train_labels):
52
+ _, grads = loss_and_grad_fn(model, X, y)
53
+ optimizer.update(model, grads)
54
+ mx.eval(model.parameters(), optimizer.state)
68
55
 
56
+ # Compute train accuracy and loss
57
+ accuracy = eval_fn(model, train_images, train_labels)
58
+ loss = loss_fn(model, train_images, train_labels)
59
+ # Construct and return reply Message
60
+ model_record = ArrayRecord(get_params(model))
61
+ metrics = {
62
+ "num-examples": len(train_images),
63
+ "accuracy": float(accuracy.item()),
64
+ "loss": float(loss.item()),
65
+ }
66
+ metric_record = MetricRecord(metrics)
67
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
68
+ return Message(content=content, reply_to=msg)
69
69
 
70
- # Flower ClientApp
71
- app = ClientApp(
72
- client_fn,
73
- )
70
+
71
+ @app.evaluate()
72
+ def evaluate(msg: Message, context: Context):
73
+ """Evaluate the model on local data."""
74
+
75
+ # Read config
76
+ num_layers = context.run_config["num-layers"]
77
+ input_dim = context.run_config["input-dim"]
78
+ hidden_dim = context.run_config["hidden-dim"]
79
+
80
+ # Instantiate model and apply global parameters
81
+ model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
82
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
83
+ set_params(model, ndarrays)
84
+
85
+ # Load data
86
+ partition_id = context.node_config["partition-id"]
87
+ num_partitions = context.node_config["num-partitions"]
88
+ _, _, test_images, test_labels = load_data(partition_id, num_partitions)
89
+
90
+ # Evaluate the model on local data
91
+ accuracy = eval_fn(model, test_images, test_labels)
92
+ loss = loss_fn(model, test_images, test_labels)
93
+
94
+ # Construct and return reply Message
95
+ metrics = {
96
+ "num-examples": len(test_images),
97
+ "accuracy": float(accuracy.item()),
98
+ "loss": float(loss.item()),
99
+ }
100
+ metric_record = MetricRecord(metrics)
101
+ content = RecordDict({"metrics": metric_record})
102
+ return Message(content=content, reply_to=msg)
@@ -1,23 +1,46 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.client import ClientApp, NumPyClient
4
- from flwr.common import Context
5
- from $import_name.task import get_dummy_model
3
+ import numpy as np
4
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
+ from flwr.clientapp import ClientApp
6
6
 
7
+ # Flower ClientApp
8
+ app = ClientApp()
7
9
 
8
- class FlowerClient(NumPyClient):
9
10
 
10
- def fit(self, parameters, config):
11
- model = get_dummy_model()
12
- return [model], 1, {}
11
+ @app.train()
12
+ def train(msg: Message, context: Context):
13
+ """Train the model on local data."""
13
14
 
14
- def evaluate(self, parameters, config):
15
- return float(0.0), 1, {"accuracy": float(1.0)}
15
+ # The model is the global arrays
16
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
16
17
 
18
+ # Simulate local training (here we just add random noise to model parameters)
19
+ model = [m + np.random.rand(*m.shape) for m in ndarrays]
17
20
 
18
- def client_fn(context: Context):
19
- return FlowerClient().to_client()
21
+ # Construct and return reply Message
22
+ model_record = ArrayRecord(model)
23
+ metrics = {
24
+ "random_metric": np.random.rand(),
25
+ "num-examples": 1,
26
+ }
27
+ metric_record = MetricRecord(metrics)
28
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
29
+ return Message(content=content, reply_to=msg)
20
30
 
21
31
 
22
- # Flower ClientApp
23
- app = ClientApp(client_fn=client_fn)
32
+ @app.evaluate()
33
+ def evaluate(msg: Message, context: Context):
34
+ """Evaluate the model on local data."""
35
+
36
+ # The model is the global arrays
37
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
38
+
39
+ # Return reply Message
40
+ metrics = {
41
+ "random_metric": np.random.rand(3).tolist(),
42
+ "num-examples": 1,
43
+ }
44
+ metric_record = MetricRecord(metrics)
45
+ content = RecordDict({"metrics": metric_record})
46
+ return Message(content=content, reply_to=msg)
@@ -1,55 +1,80 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import torch
4
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
5
+ from flwr.clientapp import ClientApp
4
6
 
5
- from flwr.client import ClientApp, NumPyClient
6
- from flwr.common import Context
7
- from $import_name.task import Net, get_weights, load_data, set_weights, test, train
8
-
9
-
10
- # Define Flower Client and client_fn
11
- class FlowerClient(NumPyClient):
12
- def __init__(self, net, trainloader, valloader, local_epochs):
13
- self.net = net
14
- self.trainloader = trainloader
15
- self.valloader = valloader
16
- self.local_epochs = local_epochs
17
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
- self.net.to(self.device)
19
-
20
- def fit(self, parameters, config):
21
- set_weights(self.net, parameters)
22
- train_loss = train(
23
- self.net,
24
- self.trainloader,
25
- self.local_epochs,
26
- self.device,
27
- )
28
- return (
29
- get_weights(self.net),
30
- len(self.trainloader.dataset),
31
- {"train_loss": train_loss},
32
- )
33
-
34
- def evaluate(self, parameters, config):
35
- set_weights(self.net, parameters)
36
- loss, accuracy = test(self.net, self.valloader, self.device)
37
- return loss, len(self.valloader.dataset), {"accuracy": accuracy}
38
-
39
-
40
- def client_fn(context: Context):
41
- # Load model and data
42
- net = Net()
7
+ from $import_name.task import Net, load_data
8
+ from $import_name.task import test as test_fn
9
+ from $import_name.task import train as train_fn
10
+
11
+ # Flower ClientApp
12
+ app = ClientApp()
13
+
14
+
15
+ @app.train()
16
+ def train(msg: Message, context: Context):
17
+ """Train the model on local data."""
18
+
19
+ # Load the model and initialize it with the received weights
20
+ model = Net()
21
+ model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
22
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
+ model.to(device)
24
+
25
+ # Load the data
43
26
  partition_id = context.node_config["partition-id"]
44
27
  num_partitions = context.node_config["num-partitions"]
45
- trainloader, valloader = load_data(partition_id, num_partitions)
46
- local_epochs = context.run_config["local-epochs"]
28
+ trainloader, _ = load_data(partition_id, num_partitions)
47
29
 
48
- # Return Client instance
49
- return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
30
+ # Call the training function
31
+ train_loss = train_fn(
32
+ model,
33
+ trainloader,
34
+ context.run_config["local-epochs"],
35
+ msg.content["config"]["lr"],
36
+ device,
37
+ )
50
38
 
39
+ # Construct and return reply Message
40
+ model_record = ArrayRecord(model.state_dict())
41
+ metrics = {
42
+ "train_loss": train_loss,
43
+ "num-examples": len(trainloader.dataset),
44
+ }
45
+ metric_record = MetricRecord(metrics)
46
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
47
+ return Message(content=content, reply_to=msg)
51
48
 
52
- # Flower ClientApp
53
- app = ClientApp(
54
- client_fn,
55
- )
49
+
50
+ @app.evaluate()
51
+ def evaluate(msg: Message, context: Context):
52
+ """Evaluate the model on local data."""
53
+
54
+ # Load the model and initialize it with the received weights
55
+ model = Net()
56
+ model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
57
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
+ model.to(device)
59
+
60
+ # Load the data
61
+ partition_id = context.node_config["partition-id"]
62
+ num_partitions = context.node_config["num-partitions"]
63
+ _, valloader = load_data(partition_id, num_partitions)
64
+
65
+ # Call the evaluation function
66
+ eval_loss, eval_acc = test_fn(
67
+ model,
68
+ valloader,
69
+ device,
70
+ )
71
+
72
+ # Construct and return reply Message
73
+ metrics = {
74
+ "eval_loss": eval_loss,
75
+ "eval_acc": eval_acc,
76
+ "num-examples": len(valloader.dataset),
77
+ }
78
+ metric_record = MetricRecord(metrics)
79
+ content = RecordDict({"metrics": metric_record})
80
+ return Message(content=content, reply_to=msg)
@@ -0,0 +1,55 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import torch
4
+
5
+ from flwr.client import ClientApp, NumPyClient
6
+ from flwr.common import Context
7
+ from $import_name.task import Net, get_weights, load_data, set_weights, test, train
8
+
9
+
10
+ # Define Flower Client and client_fn
11
+ class FlowerClient(NumPyClient):
12
+ def __init__(self, net, trainloader, valloader, local_epochs):
13
+ self.net = net
14
+ self.trainloader = trainloader
15
+ self.valloader = valloader
16
+ self.local_epochs = local_epochs
17
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
+ self.net.to(self.device)
19
+
20
+ def fit(self, parameters, config):
21
+ set_weights(self.net, parameters)
22
+ train_loss = train(
23
+ self.net,
24
+ self.trainloader,
25
+ self.local_epochs,
26
+ self.device,
27
+ )
28
+ return (
29
+ get_weights(self.net),
30
+ len(self.trainloader.dataset),
31
+ {"train_loss": train_loss},
32
+ )
33
+
34
+ def evaluate(self, parameters, config):
35
+ set_weights(self.net, parameters)
36
+ loss, accuracy = test(self.net, self.valloader, self.device)
37
+ return loss, len(self.valloader.dataset), {"accuracy": accuracy}
38
+
39
+
40
+ def client_fn(context: Context):
41
+ # Load model and data
42
+ net = Net()
43
+ partition_id = context.node_config["partition-id"]
44
+ num_partitions = context.node_config["num-partitions"]
45
+ trainloader, valloader = load_data(partition_id, num_partitions)
46
+ local_epochs = context.run_config["local-epochs"]
47
+
48
+ # Return Client instance
49
+ return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
50
+
51
+
52
+ # Flower ClientApp
53
+ app = ClientApp(
54
+ client_fn,
55
+ )
@@ -2,10 +2,16 @@
2
2
 
3
3
  import warnings
4
4
 
5
- from sklearn.metrics import log_loss
5
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
6
+ from flwr.clientapp import ClientApp
7
+ from sklearn.metrics import (
8
+ accuracy_score,
9
+ f1_score,
10
+ log_loss,
11
+ precision_score,
12
+ recall_score,
13
+ )
6
14
 
7
- from flwr.client import ClientApp, NumPyClient
8
- from flwr.common import Context
9
15
  from $import_name.task import (
10
16
  get_model,
11
17
  get_model_params,
@@ -14,39 +20,52 @@ from $import_name.task import (
14
20
  set_model_params,
15
21
  )
16
22
 
23
+ # Flower ClientApp
24
+ app = ClientApp()
17
25
 
18
- class FlowerClient(NumPyClient):
19
- def __init__(self, model, X_train, X_test, y_train, y_test):
20
- self.model = model
21
- self.X_train = X_train
22
- self.X_test = X_test
23
- self.y_train = y_train
24
- self.y_test = y_test
25
26
 
26
- def fit(self, parameters, config):
27
- set_model_params(self.model, parameters)
27
+ @app.train()
28
+ def train(msg: Message, context: Context):
29
+ """Train the model on local data."""
28
30
 
29
- # Ignore convergence failure due to low local epochs
30
- with warnings.catch_warnings():
31
- warnings.simplefilter("ignore")
32
- self.model.fit(self.X_train, self.y_train)
31
+ # Create LogisticRegression Model
32
+ penalty = context.run_config["penalty"]
33
+ local_epochs = context.run_config["local-epochs"]
34
+ model = get_model(penalty, local_epochs)
35
+ # Setting initial parameters, akin to model.compile for keras models
36
+ set_initial_params(model)
33
37
 
34
- return get_model_params(self.model), len(self.X_train), {}
38
+ # Apply received pararameters
39
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
40
+ set_model_params(model, ndarrays)
35
41
 
36
- def evaluate(self, parameters, config):
37
- set_model_params(self.model, parameters)
42
+ # Load the data
43
+ partition_id = context.node_config["partition-id"]
44
+ num_partitions = context.node_config["num-partitions"]
45
+ X_train, _, y_train, _ = load_data(partition_id, num_partitions)
38
46
 
39
- loss = log_loss(self.y_test, self.model.predict_proba(self.X_test))
40
- accuracy = self.model.score(self.X_test, self.y_test)
47
+ # Ignore convergence failure due to low local epochs
48
+ with warnings.catch_warnings():
49
+ warnings.simplefilter("ignore")
50
+ # Train the model on local data
51
+ model.fit(X_train, y_train)
41
52
 
42
- return loss, len(self.X_test), {"accuracy": accuracy}
53
+ # Let's compute train loss
54
+ y_train_pred_proba = model.predict_proba(X_train)
55
+ train_logloss = log_loss(y_train, y_train_pred_proba)
43
56
 
57
+ # Construct and return reply Message
58
+ ndarrays = get_model_params(model)
59
+ model_record = ArrayRecord(ndarrays)
60
+ metrics = {"num-examples": len(X_train), "train_logloss": train_logloss}
61
+ metric_record = MetricRecord(metrics)
62
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
63
+ return Message(content=content, reply_to=msg)
44
64
 
45
- def client_fn(context: Context):
46
- partition_id = context.node_config["partition-id"]
47
- num_partitions = context.node_config["num-partitions"]
48
65
 
49
- X_train, X_test, y_train, y_test = load_data(partition_id, num_partitions)
66
+ @app.evaluate()
67
+ def evaluate(msg: Message, context: Context):
68
+ """Evaluate the model on test data."""
50
69
 
51
70
  # Create LogisticRegression Model
52
71
  penalty = context.run_config["penalty"]
@@ -56,8 +75,34 @@ def client_fn(context: Context):
56
75
  # Setting initial parameters, akin to model.compile for keras models
57
76
  set_initial_params(model)
58
77
 
59
- return FlowerClient(model, X_train, X_test, y_train, y_test).to_client()
78
+ # Apply received pararameters
79
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
80
+ set_model_params(model, ndarrays)
60
81
 
61
-
62
- # Flower ClientApp
63
- app = ClientApp(client_fn=client_fn)
82
+ # Load the data
83
+ partition_id = context.node_config["partition-id"]
84
+ num_partitions = context.node_config["num-partitions"]
85
+ _, X_test, _, y_test = load_data(partition_id, num_partitions)
86
+
87
+ # Evaluate the model on local data
88
+ y_train_pred = model.predict(X_test)
89
+ y_train_pred_proba = model.predict_proba(X_test)
90
+
91
+ accuracy = accuracy_score(y_test, y_train_pred)
92
+ loss = log_loss(y_test, y_train_pred_proba)
93
+ precision = precision_score(y_test, y_train_pred, average="macro", zero_division=0)
94
+ recall = recall_score(y_test, y_train_pred, average="macro", zero_division=0)
95
+ f1 = f1_score(y_test, y_train_pred, average="macro", zero_division=0)
96
+
97
+ # Construct and return reply Message
98
+ metrics = {
99
+ "num-examples": len(X_test),
100
+ "test_logloss": loss,
101
+ "accuracy": accuracy,
102
+ "precision": precision,
103
+ "recall": recall,
104
+ "f1": f1,
105
+ }
106
+ metric_record = MetricRecord(metrics)
107
+ content = RecordDict({"metrics": metric_record})
108
+ return Message(content=content, reply_to=msg)