flwr-nightly 1.22.0.dev20250910__py3-none-any.whl → 1.22.0.dev20250911__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (23)
  1. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  2. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  3. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  4. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  5. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  6. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  7. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  8. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  9. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  10. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  11. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  12. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  13. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  14. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  15. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  16. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  17. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +3 -3
  18. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  19. flwr/serverapp/strategy/fedavg.py +66 -62
  20. {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250911.dist-info}/METADATA +1 -1
  21. {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250911.dist-info}/RECORD +23 -23
  22. {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250911.dist-info}/WHEEL +0 -0
  23. {flwr_nightly-1.22.0.dev20250910.dist-info → flwr_nightly-1.22.0.dev20250911.dist-info}/entry_points.txt +0 -0
@@ -1,57 +1,82 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.client import NumPyClient, ClientApp
4
- from flwr.common import Context
3
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
4
+ from flwr.clientapp import ClientApp
5
5
 
6
6
  from $import_name.task import load_data, load_model
7
7
 
8
+ # Flower ClientApp
9
+ app = ClientApp()
8
10
 
9
- # Define Flower Client and client_fn
10
- class FlowerClient(NumPyClient):
11
- def __init__(
12
- self, model, data, epochs, batch_size, verbose
13
- ):
14
- self.model = model
15
- self.x_train, self.y_train, self.x_test, self.y_test = data
16
- self.epochs = epochs
17
- self.batch_size = batch_size
18
- self.verbose = verbose
19
-
20
- def fit(self, parameters, config):
21
- self.model.set_weights(parameters)
22
- self.model.fit(
23
- self.x_train,
24
- self.y_train,
25
- epochs=self.epochs,
26
- batch_size=self.batch_size,
27
- verbose=self.verbose,
28
- )
29
- return self.model.get_weights(), len(self.x_train), {}
30
-
31
- def evaluate(self, parameters, config):
32
- self.model.set_weights(parameters)
33
- loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
34
- return loss, len(self.x_test), {"accuracy": accuracy}
35
-
36
-
37
- def client_fn(context: Context):
38
- # Load model and data
39
- net = load_model()
40
11
 
41
- partition_id = context.node_config["partition-id"]
42
- num_partitions = context.node_config["num-partitions"]
43
- data = load_data(partition_id, num_partitions)
12
+ @app.train()
13
+ def train(msg: Message, context: Context):
14
+ """Train the model on local data."""
15
+
16
+ # Load the model and initialize it with the received weights
17
+ model = load_model()
18
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
19
+ model.set_weights(ndarrays)
20
+
21
+ # Read from config
44
22
  epochs = context.run_config["local-epochs"]
45
23
  batch_size = context.run_config["batch-size"]
46
24
  verbose = context.run_config.get("verbose")
47
25
 
48
- # Return Client instance
49
- return FlowerClient(
50
- net, data, epochs, batch_size, verbose
51
- ).to_client()
26
+ # Load the data
27
+ partition_id = context.node_config["partition-id"]
28
+ num_partitions = context.node_config["num-partitions"]
29
+ x_train, y_train, _, _ = load_data(partition_id, num_partitions)
52
30
 
31
+ # Train the model on local data
32
+ history = model.fit(
33
+ x_train,
34
+ y_train,
35
+ epochs=epochs,
36
+ batch_size=batch_size,
37
+ verbose=verbose,
38
+ )
53
39
 
54
- # Flower ClientApp
55
- app = ClientApp(
56
- client_fn=client_fn,
57
- )
40
+ # Get final training loss and accuracy
41
+ train_loss = history.history["loss"][-1] if "loss" in history.history else None
42
+ train_acc = history.history.get("accuracy")
43
+ train_acc = train_acc[-1] if train_acc is not None else None
44
+
45
+ # Construct and return reply Message
46
+ model_record = ArrayRecord(model.get_weights())
47
+ metrics = {"num-examples": len(x_train)}
48
+ if train_loss is not None:
49
+ metrics["train_loss"] = train_loss
50
+ if train_acc is not None:
51
+ metrics["train_acc"] = train_acc
52
+ metric_record = MetricRecord(metrics)
53
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
54
+ return Message(content=content, reply_to=msg)
55
+
56
+
57
+ @app.evaluate()
58
+ def evaluate(msg: Message, context: Context):
59
+ """Evaluate the model on local data."""
60
+
61
+ # Load the model and initialize it with the received weights
62
+ model = load_model()
63
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
64
+ model.set_weights(ndarrays)
65
+
66
+ # Load the data
67
+ partition_id = context.node_config["partition-id"]
68
+ num_partitions = context.node_config["num-partitions"]
69
+ _, _, x_test, y_test = load_data(partition_id, num_partitions)
70
+
71
+ # Evaluate the model on local data
72
+ loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
73
+
74
+ # Construct and return reply Message
75
+ metrics = {
76
+ "eval_loss": loss,
77
+ "eval_acc": accuracy,
78
+ "num-examples": len(x_test),
79
+ }
80
+ metric_record = MetricRecord(metrics)
81
+ content = RecordDict({"metrics": metric_record})
82
+ return Message(content=content, reply_to=msg)
@@ -1,17 +1,22 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ import torch
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
6
7
  from transformers import AutoModelForSequenceClassification
7
8
 
8
- from $import_name.task import get_weights
9
+ # Create ServerApp
10
+ app = ServerApp()
11
+
9
12
 
13
+ @app.main()
14
+ def main(grid: Grid, context: Context) -> None:
15
+ """Main entry point for the ServerApp."""
10
16
 
11
- def server_fn(context: Context):
12
17
  # Read from config
13
18
  num_rounds = context.run_config["num-server-rounds"]
14
- fraction_fit = context.run_config["fraction-fit"]
19
+ fraction_train = context.run_config["fraction-train"]
15
20
 
16
21
  # Initialize global model
17
22
  model_name = context.run_config["model-name"]
@@ -19,20 +24,19 @@ def server_fn(context: Context):
19
24
  net = AutoModelForSequenceClassification.from_pretrained(
20
25
  model_name, num_labels=num_labels
21
26
  )
27
+ arrays = ArrayRecord(net.state_dict())
22
28
 
23
- weights = get_weights(net)
24
- initial_parameters = ndarrays_to_parameters(weights)
29
+ # Initialize FedAvg strategy
30
+ strategy = FedAvg(fraction_train=fraction_train)
25
31
 
26
- # Define strategy
27
- strategy = FedAvg(
28
- fraction_fit=fraction_fit,
29
- fraction_evaluate=1.0,
30
- initial_parameters=initial_parameters,
32
+ # Start strategy, run FedAvg for `num_rounds`
33
+ result = strategy.start(
34
+ grid=grid,
35
+ initial_arrays=arrays,
36
+ num_rounds=num_rounds,
31
37
  )
32
- config = ServerConfig(num_rounds=num_rounds)
33
-
34
- return ServerAppComponents(strategy=strategy, config=config)
35
38
 
36
-
37
- # Create ServerApp
38
- app = ServerApp(server_fn=server_fn)
39
+ # Save final model to disk
40
+ print("\nSaving final model to disk...")
41
+ state_dict = result.arrays.to_torch_state_dict()
42
+ torch.save(state_dict, "final_model.pt")
@@ -1,26 +1,39 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ import numpy as np
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
7
+
6
8
  from $import_name.task import get_params, load_model
7
9
 
10
+ # Create ServerApp
11
+ app = ServerApp()
12
+
13
+
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
8
17
 
9
- def server_fn(context: Context):
10
18
  # Read from config
11
19
  num_rounds = context.run_config["num-server-rounds"]
12
20
  input_dim = context.run_config["input-dim"]
13
21
 
14
- # Initialize global model
15
- params = get_params(load_model((input_dim,)))
16
- initial_parameters = ndarrays_to_parameters(params)
17
-
18
- # Define strategy
19
- strategy = FedAvg(initial_parameters=initial_parameters)
20
- config = ServerConfig(num_rounds=num_rounds)
22
+ # Load global model
23
+ model = load_model((input_dim,))
24
+ arrays = ArrayRecord(get_params(model))
21
25
 
22
- return ServerAppComponents(strategy=strategy, config=config)
26
+ # Initialize FedAvg strategy
27
+ strategy = FedAvg()
23
28
 
29
+ # Start strategy, run FedAvg for `num_rounds`
30
+ result = strategy.start(
31
+ grid=grid,
32
+ initial_arrays=arrays,
33
+ num_rounds=num_rounds,
34
+ )
24
35
 
25
- # Create ServerApp
26
- app = ServerApp(server_fn=server_fn)
36
+ # Save final model to disk
37
+ print("\nSaving final model to disk...")
38
+ ndarrays = result.arrays.to_numpy_ndarrays()
39
+ np.savez("final_model.npz", *ndarrays)
@@ -1,31 +1,41 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import MLP, get_params
3
+ from flwr.app import ArrayRecord, Context
4
+ from flwr.serverapp import Grid, ServerApp
5
+ from flwr.serverapp.strategy import FedAvg
7
6
 
7
+ from $import_name.task import MLP, get_params, set_params
8
8
 
9
- def server_fn(context: Context):
9
+ # Create ServerApp
10
+ app = ServerApp()
11
+
12
+
13
+ @app.main()
14
+ def main(grid: Grid, context: Context) -> None:
15
+ """Main entry point for the ServerApp."""
10
16
  # Read from config
11
17
  num_rounds = context.run_config["num-server-rounds"]
12
-
13
- num_classes = 10
14
18
  num_layers = context.run_config["num-layers"]
15
19
  input_dim = context.run_config["input-dim"]
16
20
  hidden_dim = context.run_config["hidden-dim"]
17
21
 
18
22
  # Initialize global model
19
- model = MLP(num_layers, input_dim, hidden_dim, num_classes)
23
+ model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
20
24
  params = get_params(model)
21
- initial_parameters = ndarrays_to_parameters(params)
22
-
23
- # Define strategy
24
- strategy = FedAvg(initial_parameters=initial_parameters)
25
- config = ServerConfig(num_rounds=num_rounds)
26
-
27
- return ServerAppComponents(strategy=strategy, config=config)
28
-
29
-
30
- # Create ServerApp
31
- app = ServerApp(server_fn=server_fn)
25
+ arrays = ArrayRecord(params)
26
+
27
+ # Initialize FedAvg strategy
28
+ strategy = FedAvg()
29
+
30
+ # Start strategy, run FedAvg for `num_rounds`
31
+ result = strategy.start(
32
+ grid=grid,
33
+ initial_arrays=arrays,
34
+ num_rounds=num_rounds,
35
+ )
36
+
37
+ # Save final model to disk
38
+ print("\nSaving final model to disk...")
39
+ ndarrays = result.arrays.to_numpy_ndarrays()
40
+ set_params(model, ndarrays)
41
+ model.save_weights("final_model.npz")
@@ -1,25 +1,38 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import get_dummy_model
7
-
3
+ import numpy as np
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
8
7
 
9
- def server_fn(context: Context):
10
- # Read from config
11
- num_rounds = context.run_config["num-server-rounds"]
8
+ from $import_name.task import get_dummy_model
12
9
 
13
- # Initial model
14
- model = get_dummy_model()
15
- dummy_parameters = ndarrays_to_parameters([model])
10
+ # Create ServerApp
11
+ app = ServerApp()
16
12
 
17
- # Define strategy
18
- strategy = FedAvg(initial_parameters=dummy_parameters)
19
- config = ServerConfig(num_rounds=num_rounds)
20
13
 
21
- return ServerAppComponents(strategy=strategy, config=config)
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
22
17
 
18
+ # Read run config
19
+ num_rounds: int = context.run_config["num-server-rounds"]
23
20
 
24
- # Create ServerApp
25
- app = ServerApp(server_fn=server_fn)
21
+ # Load global model
22
+ model = get_dummy_model()
23
+ arrays = ArrayRecord(model)
24
+
25
+ # Initialize FedAvg strategy
26
+ strategy = FedAvg()
27
+
28
+ # Start strategy, run FedAvg for `num_rounds`
29
+ result = strategy.start(
30
+ grid=grid,
31
+ initial_arrays=arrays,
32
+ num_rounds=num_rounds,
33
+ )
34
+
35
+ # Save final model to disk
36
+ print("\nSaving final model to disk...")
37
+ ndarrays = result.arrays.to_numpy_ndarrays()
38
+ np.savez("final_model", *ndarrays)
@@ -1,36 +1,44 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
6
- from $import_name.task import get_model, get_model_params, set_initial_params
3
+ import joblib
4
+ from flwr.app import ArrayRecord, Context
5
+ from flwr.serverapp import Grid, ServerApp
6
+ from flwr.serverapp.strategy import FedAvg
7
7
 
8
+ from $import_name.task import get_model, get_model_params, set_initial_params, set_model_params
8
9
 
9
- def server_fn(context: Context):
10
- # Read from config
11
- num_rounds = context.run_config["num-server-rounds"]
10
+ # Create ServerApp
11
+ app = ServerApp()
12
+
13
+
14
+ @app.main()
15
+ def main(grid: Grid, context: Context) -> None:
16
+ """Main entry point for the ServerApp."""
17
+
18
+ # Read run config
19
+ num_rounds: int = context.run_config["num-server-rounds"]
12
20
 
13
21
  # Create LogisticRegression Model
14
22
  penalty = context.run_config["penalty"]
15
23
  local_epochs = context.run_config["local-epochs"]
16
24
  model = get_model(penalty, local_epochs)
17
-
18
25
  # Setting initial parameters, akin to model.compile for keras models
19
26
  set_initial_params(model)
27
+ # Construct ArrayRecord representation
28
+ arrays = ArrayRecord(get_model_params(model))
20
29
 
21
- initial_parameters = ndarrays_to_parameters(get_model_params(model))
30
+ # Initialize FedAvg strategy
31
+ strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
22
32
 
23
- # Define strategy
24
- strategy = FedAvg(
25
- fraction_fit=1.0,
26
- fraction_evaluate=1.0,
27
- min_available_clients=2,
28
- initial_parameters=initial_parameters,
33
+ # Start strategy, run FedAvg for `num_rounds`
34
+ result = strategy.start(
35
+ grid=grid,
36
+ initial_arrays=arrays,
37
+ num_rounds=num_rounds,
29
38
  )
30
- config = ServerConfig(num_rounds=num_rounds)
31
-
32
- return ServerAppComponents(strategy=strategy, config=config)
33
39
 
34
-
35
- # Create ServerApp
36
- app = ServerApp(server_fn=server_fn)
40
+ # Save final model parameters
41
+ print("\nSaving final model to disk...")
42
+ ndarrays = result.arrays.to_numpy_ndarrays()
43
+ set_model_params(model, ndarrays)
44
+ joblib.dump(model, "logreg_model.pkl")
@@ -1,29 +1,38 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context, ndarrays_to_parameters
4
- from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
- from flwr.server.strategy import FedAvg
3
+ from flwr.app import ArrayRecord, Context
4
+ from flwr.serverapp import Grid, ServerApp
5
+ from flwr.serverapp.strategy import FedAvg
6
6
 
7
7
  from $import_name.task import load_model
8
8
 
9
+ # Create ServerApp
10
+ app = ServerApp()
9
11
 
10
- def server_fn(context: Context):
11
- # Read from config
12
- num_rounds = context.run_config["num-server-rounds"]
13
12
 
14
- # Get parameters to initialize global model
15
- parameters = ndarrays_to_parameters(load_model().get_weights())
13
+ @app.main()
14
+ def main(grid: Grid, context: Context) -> None:
15
+ """Main entry point for the ServerApp."""
16
16
 
17
- # Define strategy
18
- strategy = strategy = FedAvg(
19
- fraction_fit=1.0,
20
- fraction_evaluate=1.0,
21
- min_available_clients=2,
22
- initial_parameters=parameters,
23
- )
24
- config = ServerConfig(num_rounds=num_rounds)
17
+ # Read run config
18
+ num_rounds: int = context.run_config["num-server-rounds"]
25
19
 
26
- return ServerAppComponents(strategy=strategy, config=config)
20
+ # Load global model
21
+ model = load_model()
22
+ arrays = ArrayRecord(model.get_weights())
27
23
 
28
- # Create ServerApp
29
- app = ServerApp(server_fn=server_fn)
24
+ # Initialize FedAvg strategy
25
+ strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
26
+
27
+ # Start strategy, run FedAvg for `num_rounds`
28
+ result = strategy.start(
29
+ grid=grid,
30
+ initial_arrays=arrays,
31
+ num_rounds=num_rounds,
32
+ )
33
+
34
+ # Save final model to disk
35
+ print("\nSaving final model to disk...")
36
+ ndarrays = result.arrays.to_numpy_ndarrays()
37
+ model.set_weights(ndarrays)
38
+ model.save("final_model.keras")
@@ -1,7 +1,6 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import warnings
4
- from collections import OrderedDict
5
4
 
6
5
  import torch
7
6
  import transformers
@@ -62,17 +61,24 @@ def load_data(partition_id: int, num_partitions: int, model_name: str):
62
61
  return trainloader, testloader
63
62
 
64
63
 
65
- def train(net, trainloader, epochs, device):
64
+ def train(net, trainloader, num_steps, device):
66
65
  optimizer = AdamW(net.parameters(), lr=5e-5)
67
66
  net.train()
68
- for _ in range(epochs):
69
- for batch in trainloader:
70
- batch = {k: v.to(device) for k, v in batch.items()}
71
- outputs = net(**batch)
72
- loss = outputs.loss
73
- loss.backward()
74
- optimizer.step()
75
- optimizer.zero_grad()
67
+ running_loss = 0.0
68
+ step_cnt = 0
69
+ for batch in trainloader:
70
+ batch = {k: v.to(device) for k, v in batch.items()}
71
+ outputs = net(**batch)
72
+ loss = outputs.loss
73
+ loss.backward()
74
+ optimizer.step()
75
+ optimizer.zero_grad()
76
+ running_loss += loss.item()
77
+ step_cnt += 1
78
+ if step_cnt >= num_steps:
79
+ break
80
+ avg_trainloss = running_loss / step_cnt
81
+ return avg_trainloss
76
82
 
77
83
 
78
84
  def test(net, testloader, device):
@@ -90,13 +96,3 @@ def test(net, testloader, device):
90
96
  loss /= len(testloader.dataset)
91
97
  accuracy = metric.compute()["accuracy"]
92
98
  return loss, accuracy
93
-
94
-
95
- def get_weights(net):
96
- return [val.cpu().numpy() for _, val in net.state_dict().items()]
97
-
98
-
99
- def set_weights(net, parameters):
100
- params_dict = zip(net.state_dict().keys(), parameters)
101
- state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
102
- net.load_state_dict(state_dict, strict=True)
@@ -31,7 +31,7 @@ def loss_fn(params, X, y):
31
31
  def train(params, grad_fn, X, y):
32
32
  loss = 1_000_000
33
33
  num_examples = X.shape[0]
34
- for epochs in range(50):
34
+ for _ in range(50):
35
35
  grads = grad_fn(params, X, y)
36
36
  params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
37
37
  loss = loss_fn(params, X, y)
@@ -4,4 +4,4 @@ import numpy as np
4
4
 
5
5
 
6
6
  def get_dummy_model():
7
- return np.ones((1, 1))
7
+ return [np.ones((1, 1))]
@@ -3,10 +3,9 @@
3
3
  import os
4
4
 
5
5
  import keras
6
- from keras import layers
7
6
  from flwr_datasets import FederatedDataset
8
7
  from flwr_datasets.partitioner import IidPartitioner
9
-
8
+ from keras import layers
10
9
 
11
10
  # Make TensorFlow log less verbose
12
11
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -16,7 +16,7 @@ license = "Apache-2.0"
16
16
  dependencies = [
17
17
  "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets>=0.5.0",
19
- "torch==2.7.1",
19
+ "torch>=2.7.1",
20
20
  "transformers>=4.30.0,<5.0",
21
21
  "evaluate>=0.4.0,<1.0",
22
22
  "datasets>=2.0.0, <3.0",
@@ -38,8 +38,8 @@ clientapp = "$import_name.client_app:app"
38
38
  # Custom config values accessible via `context.run_config`
39
39
  [tool.flwr.app.config]
40
40
  num-server-rounds = 3
41
- fraction-fit = 0.5
42
- local-epochs = 1
41
+ fraction-train = 0.5
42
+ local-steps = 5
43
43
  model-name = "prajjwal1/bert-tiny" # Set a larger model if you have access to more GPU resources
44
44
  num-labels = 2
45
45
 
@@ -16,7 +16,7 @@ license = "Apache-2.0"
16
16
  dependencies = [
17
17
  "flwr[simulation]>=1.22.0",
18
18
  "flwr-datasets[vision]>=0.5.0",
19
- "mlx==0.26.5",
19
+ "mlx==0.29.0",
20
20
  ]
21
21
 
22
22
  [tool.hatch.build.targets.wheel]