flwr 1.21.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (97)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/new/new.py +9 -7
  3. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  4. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  5. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  10. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  11. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  12. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  13. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  14. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  15. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  16. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  17. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  18. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  19. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  20. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  21. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  22. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  23. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  24. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  25. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  26. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  27. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  28. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  29. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  30. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  31. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  32. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  33. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  34. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  35. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  36. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  37. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  38. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  39. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  40. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  41. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  42. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  43. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
  44. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  45. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  46. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  47. flwr/cli/pull.py +100 -0
  48. flwr/cli/utils.py +17 -0
  49. flwr/clientapp/mod/__init__.py +4 -1
  50. flwr/clientapp/mod/centraldp_mods.py +156 -40
  51. flwr/clientapp/mod/localdp_mod.py +169 -0
  52. flwr/clientapp/typing.py +22 -0
  53. flwr/common/constant.py +3 -0
  54. flwr/common/exit/exit_code.py +4 -0
  55. flwr/common/record/typeddict.py +12 -0
  56. flwr/proto/control_pb2.py +7 -3
  57. flwr/proto/control_pb2.pyi +24 -0
  58. flwr/proto/control_pb2_grpc.py +34 -0
  59. flwr/proto/control_pb2_grpc.pyi +13 -0
  60. flwr/server/app.py +13 -0
  61. flwr/serverapp/strategy/__init__.py +26 -0
  62. flwr/serverapp/strategy/bulyan.py +238 -0
  63. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  64. flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
  65. flwr/serverapp/strategy/fedadagrad.py +0 -3
  66. flwr/serverapp/strategy/fedadam.py +0 -3
  67. flwr/serverapp/strategy/fedavg.py +89 -64
  68. flwr/serverapp/strategy/fedavgm.py +198 -0
  69. flwr/serverapp/strategy/fedmedian.py +105 -0
  70. flwr/serverapp/strategy/fedprox.py +174 -0
  71. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  72. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  73. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  74. flwr/serverapp/strategy/fedyogi.py +0 -3
  75. flwr/serverapp/strategy/krum.py +112 -0
  76. flwr/serverapp/strategy/multikrum.py +247 -0
  77. flwr/serverapp/strategy/qfedavg.py +252 -0
  78. flwr/serverapp/strategy/strategy_utils.py +48 -0
  79. flwr/simulation/app.py +1 -1
  80. flwr/simulation/run_simulation.py +25 -30
  81. flwr/supercore/cli/flower_superexec.py +26 -1
  82. flwr/supercore/constant.py +19 -0
  83. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  84. flwr/supercore/superexec/run_superexec.py +16 -2
  85. flwr/superlink/artifact_provider/__init__.py +22 -0
  86. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  87. flwr/superlink/servicer/control/control_grpc.py +3 -0
  88. flwr/superlink/servicer/control/control_servicer.py +59 -2
  89. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/METADATA +6 -16
  90. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/RECORD +93 -74
  91. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  92. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  93. flwr/serverapp/dp_fixed_clipping.py +0 -352
  94. flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
  95. /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  96. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
  97. {flwr-1.21.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +0 -0
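Most of the template changes below follow one migration: the `NumPyClient`/`client_fn` pattern is replaced by a `ClientApp` whose `@app.train()` and `@app.evaluate()` handlers exchange `Message` objects carrying an `ArrayRecord` and a `MetricRecord` inside a `RecordDict`. As a minimal sketch of that pattern, distilled from the updated NumPy template further down (not itself part of this diff):

```python
import numpy as np
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
from flwr.clientapp import ClientApp

app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
    # The global model arrives as an ArrayRecord under the "arrays" key
    ndarrays = msg.content["arrays"].to_numpy_ndarrays()

    # Stand-in for local training: perturb the received parameters
    updated = [w + np.random.rand(*w.shape) for w in ndarrays]

    # Reply with the updated arrays plus metrics (including "num-examples")
    content = RecordDict(
        {
            "arrays": ArrayRecord(updated),
            "metrics": MetricRecord({"num-examples": 1}),
        }
    )
    return Message(content=content, reply_to=msg)
```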
flwr/cli/app.py CHANGED
@@ -25,6 +25,7 @@ from .log import log
  from .login import login
  from .ls import ls
  from .new import new
+ from .pull import pull
  from .run import run
  from .stop import stop

@@ -46,6 +47,7 @@ app.command()(log)
  app.command()(ls)
  app.command()(stop)
  app.command()(login)
+ app.command()(pull)

  typer_click_object = get_command(app)
flwr/cli/new/new.py CHANGED
@@ -35,15 +35,16 @@ class MlFramework(str, Enum):
      """Available frameworks."""

      PYTORCH = "PyTorch"
-     PYTORCH_MSG_API = "PyTorch (Message API)"
      TENSORFLOW = "TensorFlow"
      SKLEARN = "sklearn"
      HUGGINGFACE = "HuggingFace"
      JAX = "JAX"
      MLX = "MLX"
      NUMPY = "NumPy"
+     XGBOOST = "XGBoost"
      FLOWERTUNE = "FlowerTune"
      BASELINE = "Flower Baseline"
+     PYTORCH_LEGACY_API = "PyTorch (Legacy API, deprecated)"


  class LlmChallengeName(str, Enum):
@@ -155,8 +156,8 @@ def new(
      if framework_str == MlFramework.BASELINE:
          framework_str = "baseline"

-     if framework_str == MlFramework.PYTORCH_MSG_API:
-         framework_str = "pytorch_msg_api"
+     if framework_str == MlFramework.PYTORCH_LEGACY_API:
+         framework_str = "pytorch_legacy_api"

      print(
          typer.style(
@@ -201,7 +202,7 @@ def new(
      }

      # Challenge specific context
-     fraction_fit = "0.2" if llm_challenge_str == "code" else "0.1"
+     fraction_train = "0.2" if llm_challenge_str == "code" else "0.1"
      if llm_challenge_str == "generalnlp":
          challenge_name = "General NLP"
          num_clients = "20"
@@ -220,7 +221,7 @@ def new(
          dataset_name = "flwrlabs/code-alpaca-20k"

      context["llm_challenge_str"] = llm_challenge_str
-     context["fraction_fit"] = fraction_fit
+     context["fraction_train"] = fraction_train
      context["challenge_name"] = challenge_name
      context["num_clients"] = num_clients
      context["dataset_name"] = dataset_name
@@ -247,14 +248,15 @@ def new(
          MlFramework.TENSORFLOW.value,
          MlFramework.SKLEARN.value,
          MlFramework.NUMPY.value,
-         "pytorch_msg_api",
+         MlFramework.XGBOOST.value,
+         "pytorch_legacy_api",
      ]
      if framework_str in frameworks_with_tasks:
          files[f"{import_name}/task.py"] = {
              "template": f"app/code/task.{template_name}.py.tpl"
          }

-     if framework_str == "pytorch_msg_api":
+     if framework_str == "pytorch_legacy_api":
          # Use custom __init__ that better captures name of framework
          files[f"{import_name}/__init__.py"] = {
              "template": f"app/code/__init__.{framework_str}.py.tpl"
flwr/cli/new/templates/app/README.flowertune.md.tpl CHANGED
@@ -26,7 +26,7 @@ pip install -e .
  ## Experimental setup

  The dataset is divided into $num_clients partitions in an IID fashion, a partition is assigned to each ClientApp.
- We randomly sample a fraction ($fraction_fit) of the total nodes to participate in each round, for a total of `200` rounds.
+ We randomly sample a fraction ($fraction_train) of the total nodes to participate in each round, for a total of `200` rounds.
  All settings are defined in `pyproject.toml`.

  > [!IMPORTANT]
flwr/cli/new/templates/app/code/client.baseline.py.tpl CHANGED
@@ -1,58 +1,75 @@
  """$project_name: A Flower Baseline."""

  import torch
- from flwr.client import ClientApp, NumPyClient
- from flwr.common import Context
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
+ from flwr.clientapp import ClientApp

  from $import_name.dataset import load_data
- from $import_name.model import Net, get_weights, set_weights, test, train
-
-
- class FlowerClient(NumPyClient):
-     """A class defining the client."""
-
-     def __init__(self, net, trainloader, valloader, local_epochs):
-         self.net = net
-         self.trainloader = trainloader
-         self.valloader = valloader
-         self.local_epochs = local_epochs
-         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         self.net.to(self.device)
-
-     def fit(self, parameters, config):
-         """Traim model using this client's data."""
-         set_weights(self.net, parameters)
-         train_loss = train(
-             self.net,
-             self.trainloader,
-             self.local_epochs,
-             self.device,
-         )
-         return (
-             get_weights(self.net),
-             len(self.trainloader.dataset),
-             {"train_loss": train_loss},
-         )
-
-     def evaluate(self, parameters, config):
-         """Evaluate model using this client's data."""
-         set_weights(self.net, parameters)
-         loss, accuracy = test(self.net, self.valloader, self.device)
-         return loss, len(self.valloader.dataset), {"accuracy": accuracy}
-
-
- def client_fn(context: Context):
-     """Construct a Client that will be run in a ClientApp."""
-     # Load model and data
-     net = Net()
+ from $import_name.model import Net
+ from $import_name.model import test as test_fn
+ from $import_name.model import train as train_fn
+
+ # Flower ClientApp
+ app = ClientApp()
+
+
+ @app.train()
+ def train(msg: Message, context: Context):
+     """Train the model on local data."""
+
+     # Load the model and initialize it with the received weights
+     model = Net()
+     model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+     # Load the data
      partition_id = int(context.node_config["partition-id"])
      num_partitions = int(context.node_config["num-partitions"])
-     trainloader, valloader = load_data(partition_id, num_partitions)
+     trainloader, _ = load_data(partition_id, num_partitions)
      local_epochs = context.run_config["local-epochs"]

-     # Return Client instance
-     return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
+     # Call the training function
+     train_loss = train_fn(
+         model,
+         trainloader,
+         local_epochs,
+         device,
+     )

+     # Construct and return reply Message
+     model_record = ArrayRecord(model.state_dict())
+     metrics = {
+         "train_loss": train_loss,
+         "num-examples": len(trainloader.dataset),
+     }
+     metric_record = MetricRecord(metrics)
+     content = RecordDict({"arrays": model_record, "metrics": metric_record})
+     return Message(content=content, reply_to=msg)

- # Flower ClientApp
- app = ClientApp(client_fn)
+
+ @app.evaluate()
+ def evaluate(msg: Message, context: Context):
+     """Evaluate the model on local data."""
+
+     # Load the model and initialize it with the received weights
+     model = Net()
+     model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+     # Load the data
+     partition_id = int(context.node_config["partition-id"])
+     num_partitions = int(context.node_config["num-partitions"])
+     _, valloader = load_data(partition_id, num_partitions)
+
+     # Call the evaluation function
+     eval_loss, eval_acc = test_fn(model, valloader, device)
+
+     # Construct and return reply Message
+     metrics = {
+         "eval_loss": eval_loss,
+         "eval_acc": eval_acc,
+         "num-examples": len(valloader.dataset),
+     }
+     metric_record = MetricRecord(metrics)
+     content = RecordDict({"metrics": metric_record})
+     return Message(content=content, reply_to=msg)
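The PyTorch-based templates above and below all rely on the same `ArrayRecord` round trip; a minimal sketch of that conversion in isolation (the `torch.nn.Linear` model is just a stand-in):

```python
import torch
from flwr.app import ArrayRecord

net = torch.nn.Linear(4, 2)  # stand-in model

record = ArrayRecord(net.state_dict())             # wrap weights, as done before sending a reply
net.load_state_dict(record.to_torch_state_dict())  # restore weights, as done on receiving a Message
```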
flwr/cli/new/templates/app/code/client.huggingface.py.tpl CHANGED
@@ -1,41 +1,67 @@
  """$project_name: A Flower / $framework_str app."""

  import torch
- from flwr.client import ClientApp, NumPyClient
- from flwr.common import Context
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
+ from flwr.clientapp import ClientApp
  from transformers import AutoModelForSequenceClassification

- from $import_name.task import get_weights, load_data, set_weights, test, train
+ from $import_name.task import load_data
+ from $import_name.task import test as test_fn
+ from $import_name.task import train as train_fn

+ # Flower ClientApp
+ app = ClientApp()
+
+
+ @app.train()
+ def train(msg: Message, context: Context):
+     """Train the model on local data."""
+
+     # Get this client's dataset partition
+     partition_id = context.node_config["partition-id"]
+     num_partitions = context.node_config["num-partitions"]
+     model_name = context.run_config["model-name"]
+     trainloader, _ = load_data(partition_id, num_partitions, model_name)
+
+     # Load model
+     num_labels = context.run_config["num-labels"]
+     net = AutoModelForSequenceClassification.from_pretrained(
+         model_name, num_labels=num_labels
+     )

- # Flower client
- class FlowerClient(NumPyClient):
-     def __init__(self, net, trainloader, testloader, local_epochs):
-         self.net = net
-         self.trainloader = trainloader
-         self.testloader = testloader
-         self.local_epochs = local_epochs
-         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         self.net.to(self.device)
+     # Initialize it with the received weights
+     net.load_state_dict(msg.content["arrays"].to_torch_state_dict())
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     net.to(device)

-     def fit(self, parameters, config):
-         set_weights(self.net, parameters)
-         train(self.net, self.trainloader, epochs=self.local_epochs, device=self.device)
-         return get_weights(self.net), len(self.trainloader), {}
+     # Train the model on local data
+     train_loss = train_fn(
+         net,
+         trainloader,
+         context.run_config["local-steps"],
+         device,
+     )

-     def evaluate(self, parameters, config):
-         set_weights(self.net, parameters)
-         loss, accuracy = test(self.net, self.testloader, self.device)
-         return float(loss), len(self.testloader), {"accuracy": accuracy}
+     # Construct and return reply Message
+     model_record = ArrayRecord(net.state_dict())
+     metrics = {
+         "train_loss": train_loss,
+         "num-examples": len(trainloader.dataset),
+     }
+     metric_record = MetricRecord(metrics)
+     content = RecordDict({"arrays": model_record, "metrics": metric_record})
+     return Message(content=content, reply_to=msg)


- def client_fn(context: Context):
+ @app.evaluate()
+ def evaluate(msg: Message, context: Context):
+     """Evaluate the model on local data."""

      # Get this client's dataset partition
      partition_id = context.node_config["partition-id"]
      num_partitions = context.node_config["num-partitions"]
      model_name = context.run_config["model-name"]
-     trainloader, valloader = load_data(partition_id, num_partitions, model_name)
+     _, valloader = load_data(partition_id, num_partitions, model_name)

      # Load model
      num_labels = context.run_config["num-labels"]
@@ -43,13 +69,25 @@ def client_fn(context: Context):
          model_name, num_labels=num_labels
      )

-     local_epochs = context.run_config["local-epochs"]
-
-     # Return Client instance
-     return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
+     # Initialize it with the received weights
+     net.load_state_dict(msg.content["arrays"].to_torch_state_dict())
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     net.to(device)

+     # Evaluate the model on local data
+     val_loss, val_accuracy = test_fn(
+         net,
+         valloader,
+         device,
+     )

- # Flower ClientApp
- app = ClientApp(
-     client_fn,
- )
+     # Construct and return reply Message
+     model_record = ArrayRecord(net.state_dict())
+     metrics = {
+         "val_loss": val_loss,
+         "val_accuracy": val_accuracy,
+         "num-examples": len(valloader.dataset),
+     }
+     metric_record = MetricRecord(metrics)
+     content = RecordDict({"arrays": model_record, "metrics": metric_record})
+     return Message(content=content, reply_to=msg)
flwr/cli/new/templates/app/code/client.jax.py.tpl CHANGED
@@ -1,50 +1,71 @@
  """$project_name: A Flower / $framework_str app."""

  import jax
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
+ from flwr.clientapp import ClientApp

- from flwr.client import ClientApp, NumPyClient
- from flwr.common import Context
- from $import_name.task import (
-     evaluation,
-     get_params,
-     load_data,
-     load_model,
-     loss_fn,
-     set_params,
-     train,
- )
-
-
- # Define Flower Client and client_fn
- class FlowerClient(NumPyClient):
-     def __init__(self, input_dim):
-         self.train_x, self.train_y, self.test_x, self.test_y = load_data()
-         self.grad_fn = jax.grad(loss_fn)
-         self.params = load_model((input_dim,))
-
-     def fit(self, parameters, config):
-         set_params(self.params, parameters)
-         self.params, loss, num_examples = train(
-             self.params, self.grad_fn, self.train_x, self.train_y
-         )
-         return get_params(self.params), num_examples, {"loss": float(loss)}
-
-     def evaluate(self, parameters, config):
-         set_params(self.params, parameters)
-         loss, num_examples = evaluation(
-             self.params, self.grad_fn, self.test_x, self.test_y
-         )
-         return float(loss), num_examples, {"loss": float(loss)}
-
-
- def client_fn(context: Context):
+ from $import_name.task import evaluation as evaluation_fn
+ from $import_name.task import get_params, load_data, load_model, loss_fn, set_params
+ from $import_name.task import train as train_fn
+
+ # Flower ClientApp
+ app = ClientApp()
+
+
+ @app.train()
+ def train(msg: Message, context: Context):
+     """Train the model on local data."""
+
+     # Read from config
      input_dim = context.run_config["input-dim"]

-     # Return Client instance
-     return FlowerClient(input_dim).to_client()
+     # Load data and model
+     train_x, train_y, _, _ = load_data()
+     model = load_model((input_dim,))
+     grad_fn = jax.grad(loss_fn)

+     # Set model parameters
+     ndarrays = msg.content["arrays"].to_numpy_ndarrays()
+     set_params(model, ndarrays)

- # Flower ClientApp
- app = ClientApp(
-     client_fn,
- )
+     # Train the model on local data
+     model, loss, num_examples = train_fn(model, grad_fn, train_x, train_y)
+
+     # Construct and return reply Message
+     model_record = ArrayRecord(get_params(model))
+     metrics = {
+         "train_loss": float(loss),
+         "num-examples": num_examples,
+     }
+     metric_record = MetricRecord(metrics)
+     content = RecordDict({"arrays": model_record, "metrics": metric_record})
+     return Message(content=content, reply_to=msg)
+
+
+ @app.evaluate()
+ def evaluate(msg: Message, context: Context):
+     """Evaluate the model on local data."""
+
+     # Read from config
+     input_dim = context.run_config["input-dim"]
+
+     # Load data and model
+     _, _, test_x, test_y = load_data()
+     model = load_model((input_dim,))
+     grad_fn = jax.grad(loss_fn)
+
+     # Set model parameters
+     ndarrays = msg.content["arrays"].to_numpy_ndarrays()
+     set_params(model, ndarrays)
+
+     # Evaluate the model on local data
+     loss, num_examples = evaluation_fn(model, grad_fn, test_x, test_y)
+
+     # Construct and return reply Message
+     metrics = {
+         "test_loss": float(loss),
+         "num-examples": num_examples,
+     }
+     metric_record = MetricRecord(metrics)
+     content = RecordDict({"metrics": metric_record})
+     return Message(content=content, reply_to=msg)
flwr/cli/new/templates/app/code/client.mlx.py.tpl CHANGED
@@ -3,10 +3,9 @@
  import mlx.core as mx
  import mlx.nn as nn
  import mlx.optimizers as optim
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
+ from flwr.clientapp import ClientApp

- from flwr.client import ClientApp, NumPyClient
- from flwr.common import Context
- from flwr.common.config import UserConfig
  from $import_name.task import (
      MLP,
      batch_iterate,
@@ -17,57 +16,87 @@ from $import_name.task import (
      set_params,
  )

+ # Flower ClientApp
+ app = ClientApp()
+
+
+ @app.train()
+ def train(msg: Message, context: Context):
+     """Train the model on local data."""
+
+     # Read config
+     num_layers = context.run_config["num-layers"]
+     input_dim = context.run_config["input-dim"]
+     hidden_dim = context.run_config["hidden-dim"]
+     batch_size = context.run_config["batch-size"]
+     learning_rate = context.run_config["lr"]
+     num_epochs = context.run_config["local-epochs"]

- # Define Flower Client and client_fn
- class FlowerClient(NumPyClient):
-     def __init__(
-         self,
-         data,
-         run_config: UserConfig,
-         num_classes,
-     ):
-         num_layers = run_config["num-layers"]
-         hidden_dim = run_config["hidden-dim"]
-         input_dim = run_config["input-dim"]
-         batch_size = run_config["batch-size"]
-         learning_rate = run_config["lr"]
-         self.num_epochs = run_config["local-epochs"]
-
-         self.train_images, self.train_labels, self.test_images, self.test_labels = data
-         self.model = MLP(num_layers, input_dim, hidden_dim, num_classes)
-         self.optimizer = optim.SGD(learning_rate=learning_rate)
-         self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
-         self.batch_size = batch_size
-
-     def fit(self, parameters, config):
-         set_params(self.model, parameters)
-         for _ in range(self.num_epochs):
-             for X, y in batch_iterate(
-                 self.batch_size, self.train_images, self.train_labels
-             ):
-                 _, grads = self.loss_and_grad_fn(self.model, X, y)
-                 self.optimizer.update(self.model, grads)
-                 mx.eval(self.model.parameters(), self.optimizer.state)
-         return get_params(self.model), len(self.train_images), {}
-
-     def evaluate(self, parameters, config):
-         set_params(self.model, parameters)
-         accuracy = eval_fn(self.model, self.test_images, self.test_labels)
-         loss = loss_fn(self.model, self.test_images, self.test_labels)
-         return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
-
-
- def client_fn(context: Context):
+     # Instantiate model and apply global parameters
+     model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
+     ndarrays = msg.content["arrays"].to_numpy_ndarrays()
+     set_params(model, ndarrays)
+
+     # Define optimizer and loss function
+     optimizer = optim.SGD(learning_rate=learning_rate)
+     loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+
+     # Load data
      partition_id = context.node_config["partition-id"]
      num_partitions = context.node_config["num-partitions"]
-     data = load_data(partition_id, num_partitions)
-     num_classes = 10
+     train_images, train_labels, _, _ = load_data(partition_id, num_partitions)

-     # Return Client instance
-     return FlowerClient(data, context.run_config, num_classes).to_client()
+     # Train the model on local data
+     for _ in range(num_epochs):
+         for X, y in batch_iterate(batch_size, train_images, train_labels):
+             _, grads = loss_and_grad_fn(model, X, y)
+             optimizer.update(model, grads)
+             mx.eval(model.parameters(), optimizer.state)

+     # Compute train accuracy and loss
+     accuracy = eval_fn(model, train_images, train_labels)
+     loss = loss_fn(model, train_images, train_labels)
+     # Construct and return reply Message
+     model_record = ArrayRecord(get_params(model))
+     metrics = {
+         "num-examples": len(train_images),
+         "accuracy": float(accuracy.item()),
+         "loss": float(loss.item()),
+     }
+     metric_record = MetricRecord(metrics)
+     content = RecordDict({"arrays": model_record, "metrics": metric_record})
+     return Message(content=content, reply_to=msg)

- # Flower ClientApp
- app = ClientApp(
-     client_fn,
- )
+
+
+ @app.evaluate()
+ def evaluate(msg: Message, context: Context):
+     """Evaluate the model on local data."""
+
+     # Read config
+     num_layers = context.run_config["num-layers"]
+     input_dim = context.run_config["input-dim"]
+     hidden_dim = context.run_config["hidden-dim"]
+
+     # Instantiate model and apply global parameters
+     model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
+     ndarrays = msg.content["arrays"].to_numpy_ndarrays()
+     set_params(model, ndarrays)
+
+     # Load data
+     partition_id = context.node_config["partition-id"]
+     num_partitions = context.node_config["num-partitions"]
+     _, _, test_images, test_labels = load_data(partition_id, num_partitions)
+
+     # Evaluate the model on local data
+     accuracy = eval_fn(model, test_images, test_labels)
+     loss = loss_fn(model, test_images, test_labels)
+
+     # Construct and return reply Message
+     metrics = {
+         "num-examples": len(test_images),
+         "accuracy": float(accuracy.item()),
+         "loss": float(loss.item()),
+     }
+     metric_record = MetricRecord(metrics)
+     content = RecordDict({"metrics": metric_record})
+     return Message(content=content, reply_to=msg)
flwr/cli/new/templates/app/code/client.numpy.py.tpl CHANGED
@@ -1,23 +1,46 @@
  """$project_name: A Flower / $framework_str app."""

- from flwr.client import ClientApp, NumPyClient
- from flwr.common import Context
- from $import_name.task import get_dummy_model
+ import numpy as np
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
+ from flwr.clientapp import ClientApp

+ # Flower ClientApp
+ app = ClientApp()

- class FlowerClient(NumPyClient):

-     def fit(self, parameters, config):
-         model = get_dummy_model()
-         return [model], 1, {}
+ @app.train()
+ def train(msg: Message, context: Context):
+     """Train the model on local data."""

-     def evaluate(self, parameters, config):
-         return float(0.0), 1, {"accuracy": float(1.0)}
+     # The model is the global arrays
+     ndarrays = msg.content["arrays"].to_numpy_ndarrays()

+     # Simulate local training (here we just add random noise to model parameters)
+     model = [m + np.random.rand(*m.shape) for m in ndarrays]

- def client_fn(context: Context):
-     return FlowerClient().to_client()
+     # Construct and return reply Message
+     model_record = ArrayRecord(model)
+     metrics = {
+         "random_metric": np.random.rand(),
+         "num-examples": 1,
+     }
+     metric_record = MetricRecord(metrics)
+     content = RecordDict({"arrays": model_record, "metrics": metric_record})
+     return Message(content=content, reply_to=msg)


- # Flower ClientApp
- app = ClientApp(client_fn=client_fn)
+ @app.evaluate()
+ def evaluate(msg: Message, context: Context):
+     """Evaluate the model on local data."""
+
+     # The model is the global arrays
+     ndarrays = msg.content["arrays"].to_numpy_ndarrays()
+
+     # Return reply Message
+     metrics = {
+         "random_metric": np.random.rand(3).tolist(),
+         "num-examples": 1,
+     }
+     metric_record = MetricRecord(metrics)
+     content = RecordDict({"metrics": metric_record})
+     return Message(content=content, reply_to=msg)
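Each updated template reports `num-examples` in its `MetricRecord`, which server-side strategies can presumably use to weight client contributions. A plain-NumPy illustration of such example-count weighting (not the library's implementation):

```python
import numpy as np

# Updated parameters and example counts from three hypothetical clients
client_updates = [[np.full(3, 1.0)], [np.full(3, 2.0)], [np.full(3, 3.0)]]
num_examples = [10, 30, 60]

total = sum(num_examples)
weighted_avg = [
    sum((n / total) * update[i] for n, update in zip(num_examples, client_updates))
    for i in range(len(client_updates[0]))
]
print(weighted_avg)  # [array([2.5, 2.5, 2.5])]
```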