flwr-nightly 1.10.0.dev20240722__py3-none-any.whl → 1.11.0.dev20240805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


Files changed (66)
  1. flwr/cli/config_utils.py +40 -23
  2. flwr/cli/new/new.py +7 -6
  3. flwr/cli/new/templates/app/README.md.tpl +1 -1
  4. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  5. flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +8 -6
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +1 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +29 -11
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +1 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -13
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +3 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +20 -13
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +20 -17
  13. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
  14. flwr/cli/new/templates/app/code/{server.hf.py.tpl → server.huggingface.py.tpl} +3 -2
  15. flwr/cli/new/templates/app/code/server.jax.py.tpl +3 -2
  16. flwr/cli/new/templates/app/code/server.mlx.py.tpl +3 -2
  17. flwr/cli/new/templates/app/code/server.numpy.py.tpl +3 -2
  18. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +8 -7
  19. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +3 -2
  20. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +5 -6
  21. flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
  22. flwr/cli/new/templates/app/code/task.jax.py.tpl +2 -2
  23. flwr/cli/new/templates/app/code/task.mlx.py.tpl +15 -2
  24. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +26 -21
  25. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -5
  26. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +3 -3
  27. flwr/cli/new/templates/app/{pyproject.hf.toml.tpl → pyproject.huggingface.toml.tpl} +4 -4
  28. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +4 -4
  29. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +11 -11
  30. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +4 -4
  31. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -6
  32. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +5 -5
  33. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +8 -8
  34. flwr/cli/run/run.py +31 -27
  35. flwr/client/grpc_rere_client/grpc_adapter.py +7 -0
  36. flwr/client/supernode/app.py +12 -43
  37. flwr/common/config.py +6 -1
  38. flwr/common/object_ref.py +84 -21
  39. flwr/proto/driver_pb2.py +22 -21
  40. flwr/proto/driver_pb2.pyi +7 -1
  41. flwr/proto/driver_pb2_grpc.py +35 -0
  42. flwr/proto/driver_pb2_grpc.pyi +14 -0
  43. flwr/proto/exec_pb2.py +16 -12
  44. flwr/proto/exec_pb2.pyi +20 -1
  45. flwr/proto/fleet_pb2.py +28 -27
  46. flwr/proto/fleet_pb2_grpc.py +35 -0
  47. flwr/proto/fleet_pb2_grpc.pyi +14 -0
  48. flwr/proto/run_pb2.py +8 -8
  49. flwr/proto/run_pb2.pyi +4 -1
  50. flwr/server/run_serverapp.py +0 -3
  51. flwr/server/superlink/driver/driver_servicer.py +7 -0
  52. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +7 -0
  53. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  54. flwr/server/superlink/fleet/vce/vce_api.py +4 -4
  55. flwr/simulation/__init__.py +1 -1
  56. flwr/simulation/run_simulation.py +32 -4
  57. flwr/superexec/app.py +4 -5
  58. flwr/superexec/deployment.py +1 -2
  59. flwr/superexec/exec_servicer.py +3 -1
  60. flwr/superexec/executor.py +3 -0
  61. flwr/superexec/simulation.py +54 -12
  62. {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/METADATA +1 -1
  63. {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/RECORD +66 -66
  64. {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/LICENSE +0 -0
  65. {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/WHEEL +0 -0
  66. {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/entry_points.txt +0 -0
flwr/cli/new/templates/app/code/server.tensorflow.py.tpl CHANGED
@@ -1,4 +1,4 @@
-"""$project_name: A Flower / TensorFlow app."""
+"""$project_name: A Flower / $framework_str app."""
 
 from flwr.common import Context, ndarrays_to_parameters
 from flwr.server import ServerApp, ServerAppComponents, ServerConfig
@@ -6,14 +6,13 @@ from flwr.server.strategy import FedAvg
 
 from $import_name.task import load_model
 
-# Define config
-config = ServerConfig(num_rounds=3)
-
-parameters = ndarrays_to_parameters(load_model().get_weights())
 
 def server_fn(context: Context):
     # Read from config
-    num_rounds = int(context.run_config["num-server-rounds"])
+    num_rounds = context.run_config["num-server-rounds"]
+
+    # Get parameters to initialize global model
+    parameters = ndarrays_to_parameters(load_model().get_weights())
 
     # Define strategy
     strategy = strategy = FedAvg(
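
The template above now reads the round count as a typed value from the run config and builds the initial parameters inside server_fn instead of at import time. A minimal sketch of how the reworked function fits together, assuming the FedAvg arguments and ServerApp wiring that the truncated hunk does not show ($project_name and $import_name are template placeholders):

# Sketch only; `initial_parameters` and the ServerApp wiring below are assumptions
# about the parts of the template that the hunk above does not show.
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from $import_name.task import load_model


def server_fn(context: Context):
    # `num-server-rounds` is a plain TOML integer now, so no int() cast is needed
    num_rounds = context.run_config["num-server-rounds"]

    # Build the initial global model parameters per run, not at module import time
    parameters = ndarrays_to_parameters(load_model().get_weights())

    strategy = FedAvg(initial_parameters=parameters)
    config = ServerConfig(num_rounds=num_rounds)
    return ServerAppComponents(strategy=strategy, config=config)


app = ServerApp(server_fn=server_fn)
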
flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} CHANGED
@@ -1,4 +1,4 @@
-"""$project_name: A Flower / HuggingFace Transformers app."""
+"""$project_name: A Flower / $framework_str app."""
 
 import warnings
 from collections import OrderedDict
@@ -10,15 +10,27 @@ from torch.utils.data import DataLoader
 from transformers import AutoTokenizer, DataCollatorWithPadding
 
 from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
+
 
 warnings.filterwarnings("ignore", category=UserWarning)
 DEVICE = torch.device("cpu")
 CHECKPOINT = "distilbert-base-uncased"  # transformer model checkpoint
 
 
+fds = None  # Cache FederatedDataset
+
+
 def load_data(partition_id: int, num_partitions: int):
     """Load IMDB data (training and eval)"""
-    fds = FederatedDataset(dataset="imdb", partitioners={"train": num_partitions})
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="stanfordnlp/imdb",
+            partitioners={"train": partitioner},
+        )
     partition = fds.load_partition(partition_id)
     # Divide data: 80% train, 20% test
     partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
flwr/cli/new/templates/app/code/task.jax.py.tpl CHANGED
@@ -1,4 +1,4 @@
-"""$project_name: A Flower / JAX app."""
+"""$project_name: A Flower / $framework_str app."""
 
 import jax
 import jax.numpy as jnp
@@ -33,7 +33,7 @@ def train(params, grad_fn, X, y):
     num_examples = X.shape[0]
     for epochs in range(50):
         grads = grad_fn(params, X, y)
-        params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
+        params = jax.tree_map(lambda p, g: p - 0.05 * g, params, grads)
     loss = loss_fn(params, X, y)
     return params, loss, num_examples
 
flwr/cli/new/templates/app/code/task.mlx.py.tpl CHANGED
@@ -1,14 +1,16 @@
-"""$project_name: A Flower / MLX app."""
+"""$project_name: A Flower / $framework_str app."""
 
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from datasets.utils.logging import disable_progress_bar
 from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
 
 
 disable_progress_bar()
 
+
 class MLP(nn.Module):
     """A simple MLP."""
 
@@ -43,8 +45,19 @@ def batch_iterate(batch_size, X, y):
         yield X[ids], y[ids]
 
 
+fds = None  # Cache FederatedDataset
+
+
 def load_data(partition_id: int, num_partitions: int):
-    fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="ylecun/mnist",
+            partitioners={"train": partitioner},
+            trust_remote_code=True,
+        )
     partition = fds.load_partition(partition_id)
     partition_splits = partition.train_test_split(test_size=0.2, seed=42)
 
flwr/cli/new/templates/app/code/task.pytorch.py.tpl CHANGED
@@ -1,4 +1,4 @@
-"""$project_name: A Flower / PyTorch app."""
+"""$project_name: A Flower / $framework_str app."""
 
 from collections import OrderedDict
 
@@ -6,11 +6,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
-from torchvision.datasets import CIFAR10
 from torchvision.transforms import Compose, Normalize, ToTensor
 from flwr_datasets import FederatedDataset
-
-DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+from flwr_datasets.partitioner import IidPartitioner
 
 
 class Net(nn.Module):
@@ -34,9 +32,19 @@ class Net(nn.Module):
         return self.fc3(x)
 
 
+fds = None  # Cache FederatedDataset
+
+
 def load_data(partition_id: int, num_partitions: int):
     """Load partition CIFAR10 data."""
-    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="uoft-cs/cifar10",
+            partitioners={"train": partitioner},
+        )
     partition = fds.load_partition(partition_id)
     # Divide data on each node: 80% train, 20% test
     partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
@@ -55,44 +63,41 @@ def load_data(partition_id: int, num_partitions: int):
     return trainloader, testloader
 
 
-def train(net, trainloader, valloader, epochs, device):
+def train(net, trainloader, epochs, device):
     """Train the model on the training set."""
     net.to(device)  # move model to GPU if available
     criterion = torch.nn.CrossEntropyLoss().to(device)
-    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
     net.train()
+    running_loss = 0.0
     for _ in range(epochs):
         for batch in trainloader:
             images = batch["img"]
             labels = batch["label"]
            optimizer.zero_grad()
-            criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
+            loss = criterion(net(images.to(device)), labels.to(device))
+            loss.backward()
            optimizer.step()
+            running_loss += loss.item()
 
-    train_loss, train_acc = test(net, trainloader)
-    val_loss, val_acc = test(net, valloader)
-
-    results = {
-        "train_loss": train_loss,
-        "train_accuracy": train_acc,
-        "val_loss": val_loss,
-        "val_accuracy": val_acc,
-    }
-    return results
+    avg_trainloss = running_loss / len(trainloader)
+    return avg_trainloss
 
 
-def test(net, testloader):
+def test(net, testloader, device):
     """Validate the model on the test set."""
+    net.to(device)
     criterion = torch.nn.CrossEntropyLoss()
     correct, loss = 0, 0.0
     with torch.no_grad():
         for batch in testloader:
-            images = batch["img"].to(DEVICE)
-            labels = batch["label"].to(DEVICE)
+            images = batch["img"].to(device)
+            labels = batch["label"].to(device)
             outputs = net(images)
             loss += criterion(outputs, labels).item()
             correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
     accuracy = correct / len(testloader.dataset)
+    loss = loss / len(testloader)
     return loss, accuracy
 
 
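
Because train() no longer takes a valloader and now returns the average training loss, and test() now takes an explicit device, callers of this task module change as well. A rough usage sketch; the node_config keys ("partition-id", "num-partitions") and the helper function below are assumptions based on the companion client template, not part of this hunk:

# Hypothetical caller of the reworked task API shown above.
import torch
from flwr.common import Context

from $import_name.task import Net, load_data, test, train


def run_local_round(context: Context, local_epochs: int = 1):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = Net()

    # Partition identity comes from the node config (assumed key names)
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, testloader = load_data(partition_id, num_partitions)

    # New signatures: no valloader, avg train loss returned; test() takes a device
    avg_trainloss = train(net, trainloader, local_epochs, device)
    loss, accuracy = test(net, testloader, device)
    return avg_trainloss, loss, accuracy
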
flwr/cli/new/templates/app/code/task.tensorflow.py.tpl CHANGED
@@ -1,24 +1,48 @@
-"""$project_name: A Flower / TensorFlow app."""
+"""$project_name: A Flower / $framework_str app."""
 
 import os
 
-import tensorflow as tf
+import keras
+from keras import layers
 from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
 
 
 # Make TensorFlow log less verbose
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 
+
 def load_model():
-    # Load model and data (MobileNetV2, CIFAR-10)
-    model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
+    # Define a simple CNN for CIFAR-10 and set Adam optimizer
+    model = keras.Sequential(
+        [
+            keras.Input(shape=(32, 32, 3)),
+            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Flatten(),
+            layers.Dropout(0.5),
+            layers.Dense(10, activation="softmax"),
+        ]
+    )
     model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
     return model
 
 
+fds = None  # Cache FederatedDataset
+
+
 def load_data(partition_id, num_partitions):
     # Download and partition dataset
-    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="uoft-cs/cifar10",
+            partitioners={"train": partitioner},
+        )
     partition = fds.load_partition(partition_id, "train")
     partition.set_format("numpy")
 
flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl CHANGED
@@ -30,10 +30,10 @@ serverapp = "$import_name.app:server"
 clientapp = "$import_name.app:client"
 
 [tool.flwr.app.config]
-num-server-rounds = "3"
+num-server-rounds = 3
 
 [tool.flwr.federations]
-default = "localhost"
+default = "local-simulation"
 
-[tool.flwr.federations.localhost]
+[tool.flwr.federations.local-simulation]
 options.num-supernodes = 10
flwr/cli/new/templates/app/{pyproject.hf.toml.tpl → pyproject.huggingface.toml.tpl} CHANGED
@@ -8,8 +8,8 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.9.0,<2.0",
-    "flwr-datasets>=0.0.2,<1.0.0",
+    "flwr[simulation]>=1.10.0",
+    "flwr-datasets>=0.3.0",
     "torch==2.2.1",
     "transformers>=4.30.0,<5.0",
     "evaluate>=0.4.0,<1.0",
@@ -28,8 +28,8 @@ serverapp = "$import_name.server_app:app"
 clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
-num-server-rounds = "3"
-local-epochs = "1"
+num-server-rounds = 3
+local-epochs = 1
 
 [tool.flwr.federations]
 default = "localhost"
flwr/cli/new/templates/app/pyproject.jax.toml.tpl CHANGED
@@ -8,7 +8,7 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.9.0,<2.0",
+    "flwr[simulation]>=1.10.0",
     "jax==0.4.13",
     "jaxlib==0.4.13",
     "scikit-learn==1.3.2",
@@ -25,10 +25,10 @@ serverapp = "$import_name.server_app:app"
 clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
-num-server-rounds = "3"
+num-server-rounds = 3
 
 [tool.flwr.federations]
-default = "localhost"
+default = "local-simulation"
 
-[tool.flwr.federations.localhost]
+[tool.flwr.federations.local-simulation]
 options.num-supernodes = 10
flwr/cli/new/templates/app/pyproject.mlx.toml.tpl CHANGED
@@ -8,9 +8,9 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.9.0,<2.0",
-    "flwr-datasets[vision]>=0.0.2,<1.0.0",
-    "mlx==0.10.0",
+    "flwr[simulation]>=1.10.0",
+    "flwr-datasets[vision]>=0.3.0",
+    "mlx==0.16.1",
     "numpy==1.24.4",
 ]
 
@@ -25,15 +25,15 @@ serverapp = "$import_name.server_app:app"
 clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
-num-server-rounds = "3"
-local-epochs = "1"
-num-layers = "2"
-hidden-dim = "32"
-batch-size = "256"
-lr = "0.1"
+num-server-rounds = 3
+local-epochs = 1
+num-layers = 2
+hidden-dim = 32
+batch-size = 256
+lr = 0.1
 
 [tool.flwr.federations]
-default = "localhost"
+default = "local-simulation"
 
-[tool.flwr.federations.localhost]
+[tool.flwr.federations.local-simulation]
 options.num-supernodes = 10
flwr/cli/new/templates/app/pyproject.numpy.toml.tpl CHANGED
@@ -8,7 +8,7 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.9.0,<2.0",
+    "flwr[simulation]>=1.10.0",
     "numpy>=1.21.0",
 ]
 
@@ -23,10 +23,10 @@ serverapp = "$import_name.server_app:app"
 clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
-num-server-rounds = "3"
+num-server-rounds = 3
 
 [tool.flwr.federations]
-default = "localhost"
+default = "local-simulation"
 
-[tool.flwr.federations.localhost]
+[tool.flwr.federations.local-simulation]
 options.num-supernodes = 10
flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl CHANGED
@@ -8,8 +8,8 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.9.0,<2.0",
-    "flwr-datasets[vision]>=0.0.2,<1.0.0",
+    "flwr[simulation]>=1.10.0",
+    "flwr-datasets[vision]>=0.3.0",
     "torch==2.2.1",
     "torchvision==0.17.1",
 ]
@@ -25,11 +25,12 @@ serverapp = "$import_name.server_app:app"
 clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
-num-server-rounds = "3"
-local-epochs = "1"
+num-server-rounds = 3
+fraction-fit = 0.5
+local-epochs = 1
 
 [tool.flwr.federations]
-default = "localhost"
+default = "local-simulation"
 
-[tool.flwr.federations.localhost]
+[tool.flwr.federations.local-simulation]
 options.num-supernodes = 10
flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl CHANGED
@@ -8,8 +8,8 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.9.0,<2.0",
-    "flwr-datasets[vision]>=0.0.2,<1.0.0",
+    "flwr[simulation]>=1.10.0",
+    "flwr-datasets[vision]>=0.3.0",
     "scikit-learn>=1.1.1",
 ]
 
@@ -24,10 +24,10 @@ serverapp = "$import_name.server_app:app"
 clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
-num-server-rounds = "3"
+num-server-rounds = 3
 
 [tool.flwr.federations]
-default = "localhost"
+default = "local-simulation"
 
-[tool.flwr.federations.localhost]
+[tool.flwr.federations.local-simulation]
 options.num-supernodes = 10
flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl CHANGED
@@ -8,8 +8,8 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.9.0,<2.0",
-    "flwr-datasets[vision]>=0.0.2,<1.0.0",
+    "flwr[simulation]>=1.10.0",
+    "flwr-datasets[vision]>=0.3.0",
     "tensorflow>=2.11.1",
 ]
 
@@ -24,13 +24,13 @@ serverapp = "$import_name.server_app:app"
 clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
-num-server-rounds = "3"
-local-epochs = "1"
-batch-size = "32"
-verbose = ""  # Empty string means False
+num-server-rounds = 3
+local-epochs = 1
+batch-size = 32
+verbose = false
 
 [tool.flwr.federations]
-default = "localhost"
+default = "local-simulation"
 
-[tool.flwr.federations.localhost]
+[tool.flwr.federations.local-simulation]
 options.num-supernodes = 10
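
Across these pyproject templates, run-config values switch from quoted strings to native TOML types (integers, floats, booleans), so apps can read them from context.run_config without casting or the old empty-string-means-False convention. A small illustration; the helper function itself is hypothetical, and the key names mirror the TensorFlow template above:

from flwr.common import Context


def read_run_config(context: Context):
    # Values arrive already typed because pyproject.toml now uses native TOML types
    num_rounds = context.run_config["num-server-rounds"]  # int, e.g. 3
    batch_size = context.run_config["batch-size"]         # int, e.g. 32
    verbose = context.run_config["verbose"]               # bool; previously "" meant False
    return num_rounds, batch_size, verbose
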
flwr/cli/run/run.py CHANGED
@@ -25,7 +25,7 @@ from typing_extensions import Annotated
 
 from flwr.cli.build import build
 from flwr.cli.config_utils import load_and_validate
-from flwr.common.config import parse_config_args
+from flwr.common.config import flatten_dict, parse_config_args
 from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
 from flwr.common.logger import log
 from flwr.common.serde import user_config_to_proto
@@ -35,27 +35,30 @@ from flwr.proto.exec_pb2_grpc import ExecStub
 
 # pylint: disable-next=too-many-locals
 def run(
-    directory: Annotated[
+    app_dir: Annotated[
         Path,
-        typer.Argument(help="Path of the Flower project to run"),
+        typer.Argument(help="Path of the Flower project to run."),
     ] = Path("."),
-    federation_name: Annotated[
+    federation: Annotated[
         Optional[str],
-        typer.Argument(help="Name of the federation to run the app on"),
+        typer.Argument(help="Name of the federation to run the app on."),
     ] = None,
     config_overrides: Annotated[
         Optional[List[str]],
         typer.Option(
             "--run-config",
             "-c",
-            help="Override configuration key-value pairs",
+            help="Override configuration key-value pairs, should be of the format:\n\n"
+            "`--run-config key1=value1,key2=value2 --run-config key3=value3`\n\n"
+            "Note that `key1`, `key2`, and `key3` in this example need to exist "
+            "inside the `pyproject.toml` in order to be properly overriden.",
         ),
     ] = None,
 ) -> None:
     """Run Flower project."""
     typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
 
-    pyproject_path = directory / "pyproject.toml" if directory else None
+    pyproject_path = app_dir / "pyproject.toml" if app_dir else None
     config, errors, warnings = load_and_validate(path=pyproject_path)
 
     if config is None:
@@ -78,11 +81,9 @@ def run(
 
     typer.secho("Success", fg=typer.colors.GREEN)
 
-    federation_name = federation_name or config["tool"]["flwr"]["federations"].get(
-        "default"
-    )
+    federation = federation or config["tool"]["flwr"]["federations"].get("default")
 
-    if federation_name is None:
+    if federation is None:
         typer.secho(
             "❌ No federation name was provided and the project's `pyproject.toml` "
             "doesn't declare a default federation (with a SuperExec address or an "
@@ -93,13 +94,13 @@ def run(
         raise typer.Exit(code=1)
 
     # Validate the federation exists in the configuration
-    federation = config["tool"]["flwr"]["federations"].get(federation_name)
-    if federation is None:
+    federation_config = config["tool"]["flwr"]["federations"].get(federation)
+    if federation_config is None:
         available_feds = {
             fed for fed in config["tool"]["flwr"]["federations"] if fed != "default"
         }
         typer.secho(
-            f"❌ There is no `{federation_name}` federation declared in the "
+            f"❌ There is no `{federation}` federation declared in "
            "`pyproject.toml`.\n The following federations were found:\n\n"
            + "\n".join(available_feds),
            fg=typer.colors.RED,
@@ -107,15 +108,15 @@
         )
         raise typer.Exit(code=1)
 
-    if "address" in federation:
-        _run_with_superexec(federation, directory, config_overrides)
+    if "address" in federation_config:
+        _run_with_superexec(federation_config, app_dir, config_overrides)
     else:
-        _run_without_superexec(directory, federation, federation_name, config_overrides)
+        _run_without_superexec(app_dir, federation_config, federation, config_overrides)
 
 
 def _run_with_superexec(
-    federation: Dict[str, str],
-    directory: Optional[Path],
+    federation_config: Dict[str, Any],
+    app_dir: Optional[Path],
     config_overrides: Optional[List[str]],
 ) -> None:
 
@@ -123,8 +124,8 @@ def _run_with_superexec(
         """Log channel connectivity."""
         log(DEBUG, channel_connectivity)
 
-    insecure_str = federation.get("insecure")
-    if root_certificates := federation.get("root-certificates"):
+    insecure_str = federation_config.get("insecure")
+    if root_certificates := federation_config.get("root-certificates"):
         root_certificates_bytes = Path(root_certificates).read_bytes()
         if insecure := bool(insecure_str):
             typer.secho(
@@ -152,7 +153,7 @@ def _run_with_superexec(
         raise typer.Exit(code=1)
 
     channel = create_channel(
-        server_address=federation["address"],
+        server_address=federation_config["address"],
         insecure=insecure,
         root_certificates=root_certificates_bytes,
         max_message_length=GRPC_MAX_MESSAGE_LENGTH,
@@ -161,13 +162,16 @@ def _run_with_superexec(
     channel.subscribe(on_channel_state_change)
     stub = ExecStub(channel)
 
-    fab_path = build(directory)
+    fab_path = build(app_dir)
 
     req = StartRunRequest(
         fab_file=Path(fab_path).read_bytes(),
         override_config=user_config_to_proto(
             parse_config_args(config_overrides, separator=",")
         ),
+        federation_config=user_config_to_proto(
+            flatten_dict(federation_config.get("options"))
+        ),
     )
     res = stub.StartRun(req)
     typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
@@ -175,18 +179,18 @@
 
 def _run_without_superexec(
     app_path: Optional[Path],
-    federation: Dict[str, Any],
-    federation_name: str,
+    federation_config: Dict[str, Any],
+    federation: str,
     config_overrides: Optional[List[str]],
 ) -> None:
     try:
-        num_supernodes = federation["options"]["num-supernodes"]
+        num_supernodes = federation_config["options"]["num-supernodes"]
     except KeyError as err:
         typer.secho(
             "❌ The project's `pyproject.toml` needs to declare the number of"
             " SuperNodes in the simulation. To simulate 10 SuperNodes,"
             " use the following notation:\n\n"
-            f"[tool.flwr.federations.{federation_name}]\n"
+            f"[tool.flwr.federations.{federation}]\n"
            "options.num-supernodes = 10\n",
            fg=typer.colors.RED,
            bold=True,
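
With the positional arguments renamed to app_dir and federation and the expanded --run-config help text, an invocation would look roughly like this (illustrative values; per the help text, overridden keys must already exist in the project's pyproject.toml):

flwr run ./my-app local-simulation --run-config num-server-rounds=5,local-epochs=2
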
flwr/client/grpc_rere_client/grpc_adapter.py CHANGED
@@ -28,6 +28,7 @@ from flwr.common.constant import (
     GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY,
 )
 from flwr.common.version import package_version
+from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse  # pylint: disable=E0611
 from flwr.proto.fleet_pb2 import (  # pylint: disable=E0611
     CreateNodeRequest,
     CreateNodeResponse,
@@ -131,3 +132,9 @@ class GrpcAdapter:
     ) -> GetRunResponse:
         """."""
         return self._send_and_receive(request, GetRunResponse, **kwargs)
+
+    def GetFab(  # pylint: disable=C0103
+        self, request: GetFabRequest, **kwargs: Any
+    ) -> GetFabResponse:
+        """."""
+        return self._send_and_receive(request, GetFabResponse, **kwargs)