flwr-nightly 1.10.0.dev20240722__py3-none-any.whl → 1.11.0.dev20240805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/config_utils.py +40 -23
- flwr/cli/new/new.py +7 -6
- flwr/cli/new/templates/app/README.md.tpl +1 -1
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +8 -6
- flwr/cli/new/templates/app/code/client.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +29 -11
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -13
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +3 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +20 -13
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +20 -17
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
- flwr/cli/new/templates/app/code/{server.hf.py.tpl → server.huggingface.py.tpl} +3 -2
- flwr/cli/new/templates/app/code/server.jax.py.tpl +3 -2
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +3 -2
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +8 -7
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +3 -2
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +5 -6
- flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
- flwr/cli/new/templates/app/code/task.jax.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +15 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +26 -21
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -5
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +3 -3
- flwr/cli/new/templates/app/{pyproject.hf.toml.tpl → pyproject.huggingface.toml.tpl} +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +11 -11
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -6
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +5 -5
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +8 -8
- flwr/cli/run/run.py +31 -27
- flwr/client/grpc_rere_client/grpc_adapter.py +7 -0
- flwr/client/supernode/app.py +12 -43
- flwr/common/config.py +6 -1
- flwr/common/object_ref.py +84 -21
- flwr/proto/driver_pb2.py +22 -21
- flwr/proto/driver_pb2.pyi +7 -1
- flwr/proto/driver_pb2_grpc.py +35 -0
- flwr/proto/driver_pb2_grpc.pyi +14 -0
- flwr/proto/exec_pb2.py +16 -12
- flwr/proto/exec_pb2.pyi +20 -1
- flwr/proto/fleet_pb2.py +28 -27
- flwr/proto/fleet_pb2_grpc.py +35 -0
- flwr/proto/fleet_pb2_grpc.pyi +14 -0
- flwr/proto/run_pb2.py +8 -8
- flwr/proto/run_pb2.pyi +4 -1
- flwr/server/run_serverapp.py +0 -3
- flwr/server/superlink/driver/driver_servicer.py +7 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +7 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +4 -4
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/run_simulation.py +32 -4
- flwr/superexec/app.py +4 -5
- flwr/superexec/deployment.py +1 -2
- flwr/superexec/exec_servicer.py +3 -1
- flwr/superexec/executor.py +3 -0
- flwr/superexec/simulation.py +54 -12
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/METADATA +1 -1
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/RECORD +66 -66
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/entry_points.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.common import Context, ndarrays_to_parameters
|
|
4
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
@@ -6,14 +6,13 @@ from flwr.server.strategy import FedAvg
|
|
|
6
6
|
|
|
7
7
|
from $import_name.task import load_model
|
|
8
8
|
|
|
9
|
-
# Define config
|
|
10
|
-
config = ServerConfig(num_rounds=3)
|
|
11
|
-
|
|
12
|
-
parameters = ndarrays_to_parameters(load_model().get_weights())
|
|
13
9
|
|
|
14
10
|
def server_fn(context: Context):
|
|
15
11
|
# Read from config
|
|
16
|
-
num_rounds =
|
|
12
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
13
|
+
|
|
14
|
+
# Get parameters to initialize global model
|
|
15
|
+
parameters = ndarrays_to_parameters(load_model().get_weights())
|
|
17
16
|
|
|
18
17
|
# Define strategy
|
|
19
18
|
strategy = strategy = FedAvg(
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
import warnings
|
|
4
4
|
from collections import OrderedDict
|
|
@@ -10,15 +10,27 @@ from torch.utils.data import DataLoader
|
|
|
10
10
|
from transformers import AutoTokenizer, DataCollatorWithPadding
|
|
11
11
|
|
|
12
12
|
from flwr_datasets import FederatedDataset
|
|
13
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
14
|
+
|
|
13
15
|
|
|
14
16
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
15
17
|
DEVICE = torch.device("cpu")
|
|
16
18
|
CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint
|
|
17
19
|
|
|
18
20
|
|
|
21
|
+
fds = None # Cache FederatedDataset
|
|
22
|
+
|
|
23
|
+
|
|
19
24
|
def load_data(partition_id: int, num_partitions: int):
|
|
20
25
|
"""Load IMDB data (training and eval)"""
|
|
21
|
-
|
|
26
|
+
# Only initialize `FederatedDataset` once
|
|
27
|
+
global fds
|
|
28
|
+
if fds is None:
|
|
29
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
30
|
+
fds = FederatedDataset(
|
|
31
|
+
dataset="stanfordnlp/imdb",
|
|
32
|
+
partitioners={"train": partitioner},
|
|
33
|
+
)
|
|
22
34
|
partition = fds.load_partition(partition_id)
|
|
23
35
|
# Divide data: 80% train, 20% test
|
|
24
36
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
4
|
import jax.numpy as jnp
|
|
@@ -33,7 +33,7 @@ def train(params, grad_fn, X, y):
|
|
|
33
33
|
num_examples = X.shape[0]
|
|
34
34
|
for epochs in range(50):
|
|
35
35
|
grads = grad_fn(params, X, y)
|
|
36
|
-
params = jax.
|
|
36
|
+
params = jax.tree_map(lambda p, g: p - 0.05 * g, params, grads)
|
|
37
37
|
loss = loss_fn(params, X, y)
|
|
38
38
|
return params, loss, num_examples
|
|
39
39
|
|
|
@@ -1,14 +1,16 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
import mlx.core as mx
|
|
4
4
|
import mlx.nn as nn
|
|
5
5
|
import numpy as np
|
|
6
6
|
from datasets.utils.logging import disable_progress_bar
|
|
7
7
|
from flwr_datasets import FederatedDataset
|
|
8
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
disable_progress_bar()
|
|
11
12
|
|
|
13
|
+
|
|
12
14
|
class MLP(nn.Module):
|
|
13
15
|
"""A simple MLP."""
|
|
14
16
|
|
|
@@ -43,8 +45,19 @@ def batch_iterate(batch_size, X, y):
|
|
|
43
45
|
yield X[ids], y[ids]
|
|
44
46
|
|
|
45
47
|
|
|
48
|
+
fds = None # Cache FederatedDataset
|
|
49
|
+
|
|
50
|
+
|
|
46
51
|
def load_data(partition_id: int, num_partitions: int):
|
|
47
|
-
|
|
52
|
+
# Only initialize `FederatedDataset` once
|
|
53
|
+
global fds
|
|
54
|
+
if fds is None:
|
|
55
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
56
|
+
fds = FederatedDataset(
|
|
57
|
+
dataset="ylecun/mnist",
|
|
58
|
+
partitioners={"train": partitioner},
|
|
59
|
+
trust_remote_code=True,
|
|
60
|
+
)
|
|
48
61
|
partition = fds.load_partition(partition_id)
|
|
49
62
|
partition_splits = partition.train_test_split(test_size=0.2, seed=42)
|
|
50
63
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from collections import OrderedDict
|
|
4
4
|
|
|
@@ -6,11 +6,9 @@ import torch
|
|
|
6
6
|
import torch.nn as nn
|
|
7
7
|
import torch.nn.functional as F
|
|
8
8
|
from torch.utils.data import DataLoader
|
|
9
|
-
from torchvision.datasets import CIFAR10
|
|
10
9
|
from torchvision.transforms import Compose, Normalize, ToTensor
|
|
11
10
|
from flwr_datasets import FederatedDataset
|
|
12
|
-
|
|
13
|
-
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
11
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
14
12
|
|
|
15
13
|
|
|
16
14
|
class Net(nn.Module):
|
|
@@ -34,9 +32,19 @@ class Net(nn.Module):
|
|
|
34
32
|
return self.fc3(x)
|
|
35
33
|
|
|
36
34
|
|
|
35
|
+
fds = None # Cache FederatedDataset
|
|
36
|
+
|
|
37
|
+
|
|
37
38
|
def load_data(partition_id: int, num_partitions: int):
|
|
38
39
|
"""Load partition CIFAR10 data."""
|
|
39
|
-
|
|
40
|
+
# Only initialize `FederatedDataset` once
|
|
41
|
+
global fds
|
|
42
|
+
if fds is None:
|
|
43
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
44
|
+
fds = FederatedDataset(
|
|
45
|
+
dataset="uoft-cs/cifar10",
|
|
46
|
+
partitioners={"train": partitioner},
|
|
47
|
+
)
|
|
40
48
|
partition = fds.load_partition(partition_id)
|
|
41
49
|
# Divide data on each node: 80% train, 20% test
|
|
42
50
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
|
@@ -55,44 +63,41 @@ def load_data(partition_id: int, num_partitions: int):
|
|
|
55
63
|
return trainloader, testloader
|
|
56
64
|
|
|
57
65
|
|
|
58
|
-
def train(net, trainloader,
|
|
66
|
+
def train(net, trainloader, epochs, device):
|
|
59
67
|
"""Train the model on the training set."""
|
|
60
68
|
net.to(device) # move model to GPU if available
|
|
61
69
|
criterion = torch.nn.CrossEntropyLoss().to(device)
|
|
62
|
-
optimizer = torch.optim.SGD(net.parameters(), lr=0.
|
|
70
|
+
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
|
|
63
71
|
net.train()
|
|
72
|
+
running_loss = 0.0
|
|
64
73
|
for _ in range(epochs):
|
|
65
74
|
for batch in trainloader:
|
|
66
75
|
images = batch["img"]
|
|
67
76
|
labels = batch["label"]
|
|
68
77
|
optimizer.zero_grad()
|
|
69
|
-
criterion(net(images.to(
|
|
78
|
+
loss = criterion(net(images.to(device)), labels.to(device))
|
|
79
|
+
loss.backward()
|
|
70
80
|
optimizer.step()
|
|
81
|
+
running_loss += loss.item()
|
|
71
82
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
results = {
|
|
76
|
-
"train_loss": train_loss,
|
|
77
|
-
"train_accuracy": train_acc,
|
|
78
|
-
"val_loss": val_loss,
|
|
79
|
-
"val_accuracy": val_acc,
|
|
80
|
-
}
|
|
81
|
-
return results
|
|
83
|
+
avg_trainloss = running_loss / len(trainloader)
|
|
84
|
+
return avg_trainloss
|
|
82
85
|
|
|
83
86
|
|
|
84
|
-
def test(net, testloader):
|
|
87
|
+
def test(net, testloader, device):
|
|
85
88
|
"""Validate the model on the test set."""
|
|
89
|
+
net.to(device)
|
|
86
90
|
criterion = torch.nn.CrossEntropyLoss()
|
|
87
91
|
correct, loss = 0, 0.0
|
|
88
92
|
with torch.no_grad():
|
|
89
93
|
for batch in testloader:
|
|
90
|
-
images = batch["img"].to(
|
|
91
|
-
labels = batch["label"].to(
|
|
94
|
+
images = batch["img"].to(device)
|
|
95
|
+
labels = batch["label"].to(device)
|
|
92
96
|
outputs = net(images)
|
|
93
97
|
loss += criterion(outputs, labels).item()
|
|
94
98
|
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
|
|
95
99
|
accuracy = correct / len(testloader.dataset)
|
|
100
|
+
loss = loss / len(testloader)
|
|
96
101
|
return loss, accuracy
|
|
97
102
|
|
|
98
103
|
|
|
@@ -1,24 +1,48 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
|
|
5
|
-
import
|
|
5
|
+
import keras
|
|
6
|
+
from keras import layers
|
|
6
7
|
from flwr_datasets import FederatedDataset
|
|
8
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
# Make TensorFlow log less verbose
|
|
10
12
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
11
13
|
|
|
14
|
+
|
|
12
15
|
def load_model():
|
|
13
|
-
#
|
|
14
|
-
model =
|
|
16
|
+
# Define a simple CNN for CIFAR-10 and set Adam optimizer
|
|
17
|
+
model = keras.Sequential(
|
|
18
|
+
[
|
|
19
|
+
keras.Input(shape=(32, 32, 3)),
|
|
20
|
+
layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
|
|
21
|
+
layers.MaxPooling2D(pool_size=(2, 2)),
|
|
22
|
+
layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
|
|
23
|
+
layers.MaxPooling2D(pool_size=(2, 2)),
|
|
24
|
+
layers.Flatten(),
|
|
25
|
+
layers.Dropout(0.5),
|
|
26
|
+
layers.Dense(10, activation="softmax"),
|
|
27
|
+
]
|
|
28
|
+
)
|
|
15
29
|
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
|
|
16
30
|
return model
|
|
17
31
|
|
|
18
32
|
|
|
33
|
+
fds = None # Cache FederatedDataset
|
|
34
|
+
|
|
35
|
+
|
|
19
36
|
def load_data(partition_id, num_partitions):
|
|
20
37
|
# Download and partition dataset
|
|
21
|
-
|
|
38
|
+
# Only initialize `FederatedDataset` once
|
|
39
|
+
global fds
|
|
40
|
+
if fds is None:
|
|
41
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
42
|
+
fds = FederatedDataset(
|
|
43
|
+
dataset="uoft-cs/cifar10",
|
|
44
|
+
partitioners={"train": partitioner},
|
|
45
|
+
)
|
|
22
46
|
partition = fds.load_partition(partition_id, "train")
|
|
23
47
|
partition.set_format("numpy")
|
|
24
48
|
|
|
@@ -30,10 +30,10 @@ serverapp = "$import_name.app:server"
|
|
|
30
30
|
clientapp = "$import_name.app:client"
|
|
31
31
|
|
|
32
32
|
[tool.flwr.app.config]
|
|
33
|
-
num-server-rounds =
|
|
33
|
+
num-server-rounds = 3
|
|
34
34
|
|
|
35
35
|
[tool.flwr.federations]
|
|
36
|
-
default = "
|
|
36
|
+
default = "local-simulation"
|
|
37
37
|
|
|
38
|
-
[tool.flwr.federations.
|
|
38
|
+
[tool.flwr.federations.local-simulation]
|
|
39
39
|
options.num-supernodes = 10
|
|
@@ -8,8 +8,8 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
12
|
-
"flwr-datasets>=0.
|
|
11
|
+
"flwr[simulation]>=1.10.0",
|
|
12
|
+
"flwr-datasets>=0.3.0",
|
|
13
13
|
"torch==2.2.1",
|
|
14
14
|
"transformers>=4.30.0,<5.0",
|
|
15
15
|
"evaluate>=0.4.0,<1.0",
|
|
@@ -28,8 +28,8 @@ serverapp = "$import_name.server_app:app"
|
|
|
28
28
|
clientapp = "$import_name.client_app:app"
|
|
29
29
|
|
|
30
30
|
[tool.flwr.app.config]
|
|
31
|
-
num-server-rounds =
|
|
32
|
-
local-epochs =
|
|
31
|
+
num-server-rounds = 3
|
|
32
|
+
local-epochs = 1
|
|
33
33
|
|
|
34
34
|
[tool.flwr.federations]
|
|
35
35
|
default = "localhost"
|
|
@@ -8,7 +8,7 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
11
|
+
"flwr[simulation]>=1.10.0",
|
|
12
12
|
"jax==0.4.13",
|
|
13
13
|
"jaxlib==0.4.13",
|
|
14
14
|
"scikit-learn==1.3.2",
|
|
@@ -25,10 +25,10 @@ serverapp = "$import_name.server_app:app"
|
|
|
25
25
|
clientapp = "$import_name.client_app:app"
|
|
26
26
|
|
|
27
27
|
[tool.flwr.app.config]
|
|
28
|
-
num-server-rounds =
|
|
28
|
+
num-server-rounds = 3
|
|
29
29
|
|
|
30
30
|
[tool.flwr.federations]
|
|
31
|
-
default = "
|
|
31
|
+
default = "local-simulation"
|
|
32
32
|
|
|
33
|
-
[tool.flwr.federations.
|
|
33
|
+
[tool.flwr.federations.local-simulation]
|
|
34
34
|
options.num-supernodes = 10
|
|
@@ -8,9 +8,9 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
12
|
-
"flwr-datasets[vision]>=0.
|
|
13
|
-
"mlx==0.
|
|
11
|
+
"flwr[simulation]>=1.10.0",
|
|
12
|
+
"flwr-datasets[vision]>=0.3.0",
|
|
13
|
+
"mlx==0.16.1",
|
|
14
14
|
"numpy==1.24.4",
|
|
15
15
|
]
|
|
16
16
|
|
|
@@ -25,15 +25,15 @@ serverapp = "$import_name.server_app:app"
|
|
|
25
25
|
clientapp = "$import_name.client_app:app"
|
|
26
26
|
|
|
27
27
|
[tool.flwr.app.config]
|
|
28
|
-
num-server-rounds =
|
|
29
|
-
local-epochs =
|
|
30
|
-
num-layers =
|
|
31
|
-
hidden-dim =
|
|
32
|
-
batch-size =
|
|
33
|
-
lr =
|
|
28
|
+
num-server-rounds = 3
|
|
29
|
+
local-epochs = 1
|
|
30
|
+
num-layers = 2
|
|
31
|
+
hidden-dim = 32
|
|
32
|
+
batch-size = 256
|
|
33
|
+
lr = 0.1
|
|
34
34
|
|
|
35
35
|
[tool.flwr.federations]
|
|
36
|
-
default = "
|
|
36
|
+
default = "local-simulation"
|
|
37
37
|
|
|
38
|
-
[tool.flwr.federations.
|
|
38
|
+
[tool.flwr.federations.local-simulation]
|
|
39
39
|
options.num-supernodes = 10
|
|
@@ -8,7 +8,7 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
11
|
+
"flwr[simulation]>=1.10.0",
|
|
12
12
|
"numpy>=1.21.0",
|
|
13
13
|
]
|
|
14
14
|
|
|
@@ -23,10 +23,10 @@ serverapp = "$import_name.server_app:app"
|
|
|
23
23
|
clientapp = "$import_name.client_app:app"
|
|
24
24
|
|
|
25
25
|
[tool.flwr.app.config]
|
|
26
|
-
num-server-rounds =
|
|
26
|
+
num-server-rounds = 3
|
|
27
27
|
|
|
28
28
|
[tool.flwr.federations]
|
|
29
|
-
default = "
|
|
29
|
+
default = "local-simulation"
|
|
30
30
|
|
|
31
|
-
[tool.flwr.federations.
|
|
31
|
+
[tool.flwr.federations.local-simulation]
|
|
32
32
|
options.num-supernodes = 10
|
|
@@ -8,8 +8,8 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
12
|
-
"flwr-datasets[vision]>=0.
|
|
11
|
+
"flwr[simulation]>=1.10.0",
|
|
12
|
+
"flwr-datasets[vision]>=0.3.0",
|
|
13
13
|
"torch==2.2.1",
|
|
14
14
|
"torchvision==0.17.1",
|
|
15
15
|
]
|
|
@@ -25,11 +25,12 @@ serverapp = "$import_name.server_app:app"
|
|
|
25
25
|
clientapp = "$import_name.client_app:app"
|
|
26
26
|
|
|
27
27
|
[tool.flwr.app.config]
|
|
28
|
-
num-server-rounds =
|
|
29
|
-
|
|
28
|
+
num-server-rounds = 3
|
|
29
|
+
fraction-fit = 0.5
|
|
30
|
+
local-epochs = 1
|
|
30
31
|
|
|
31
32
|
[tool.flwr.federations]
|
|
32
|
-
default = "
|
|
33
|
+
default = "local-simulation"
|
|
33
34
|
|
|
34
|
-
[tool.flwr.federations.
|
|
35
|
+
[tool.flwr.federations.local-simulation]
|
|
35
36
|
options.num-supernodes = 10
|
|
@@ -8,8 +8,8 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
12
|
-
"flwr-datasets[vision]>=0.
|
|
11
|
+
"flwr[simulation]>=1.10.0",
|
|
12
|
+
"flwr-datasets[vision]>=0.3.0",
|
|
13
13
|
"scikit-learn>=1.1.1",
|
|
14
14
|
]
|
|
15
15
|
|
|
@@ -24,10 +24,10 @@ serverapp = "$import_name.server_app:app"
|
|
|
24
24
|
clientapp = "$import_name.client_app:app"
|
|
25
25
|
|
|
26
26
|
[tool.flwr.app.config]
|
|
27
|
-
num-server-rounds =
|
|
27
|
+
num-server-rounds = 3
|
|
28
28
|
|
|
29
29
|
[tool.flwr.federations]
|
|
30
|
-
default = "
|
|
30
|
+
default = "local-simulation"
|
|
31
31
|
|
|
32
|
-
[tool.flwr.federations.
|
|
32
|
+
[tool.flwr.federations.local-simulation]
|
|
33
33
|
options.num-supernodes = 10
|
|
@@ -8,8 +8,8 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
12
|
-
"flwr-datasets[vision]>=0.
|
|
11
|
+
"flwr[simulation]>=1.10.0",
|
|
12
|
+
"flwr-datasets[vision]>=0.3.0",
|
|
13
13
|
"tensorflow>=2.11.1",
|
|
14
14
|
]
|
|
15
15
|
|
|
@@ -24,13 +24,13 @@ serverapp = "$import_name.server_app:app"
|
|
|
24
24
|
clientapp = "$import_name.client_app:app"
|
|
25
25
|
|
|
26
26
|
[tool.flwr.app.config]
|
|
27
|
-
num-server-rounds =
|
|
28
|
-
local-epochs =
|
|
29
|
-
batch-size =
|
|
30
|
-
verbose =
|
|
27
|
+
num-server-rounds = 3
|
|
28
|
+
local-epochs = 1
|
|
29
|
+
batch-size = 32
|
|
30
|
+
verbose = false
|
|
31
31
|
|
|
32
32
|
[tool.flwr.federations]
|
|
33
|
-
default = "
|
|
33
|
+
default = "local-simulation"
|
|
34
34
|
|
|
35
|
-
[tool.flwr.federations.
|
|
35
|
+
[tool.flwr.federations.local-simulation]
|
|
36
36
|
options.num-supernodes = 10
|
flwr/cli/run/run.py
CHANGED
|
@@ -25,7 +25,7 @@ from typing_extensions import Annotated
|
|
|
25
25
|
|
|
26
26
|
from flwr.cli.build import build
|
|
27
27
|
from flwr.cli.config_utils import load_and_validate
|
|
28
|
-
from flwr.common.config import parse_config_args
|
|
28
|
+
from flwr.common.config import flatten_dict, parse_config_args
|
|
29
29
|
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
|
|
30
30
|
from flwr.common.logger import log
|
|
31
31
|
from flwr.common.serde import user_config_to_proto
|
|
@@ -35,27 +35,30 @@ from flwr.proto.exec_pb2_grpc import ExecStub
|
|
|
35
35
|
|
|
36
36
|
# pylint: disable-next=too-many-locals
|
|
37
37
|
def run(
|
|
38
|
-
|
|
38
|
+
app_dir: Annotated[
|
|
39
39
|
Path,
|
|
40
|
-
typer.Argument(help="Path of the Flower project to run"),
|
|
40
|
+
typer.Argument(help="Path of the Flower project to run."),
|
|
41
41
|
] = Path("."),
|
|
42
|
-
|
|
42
|
+
federation: Annotated[
|
|
43
43
|
Optional[str],
|
|
44
|
-
typer.Argument(help="Name of the federation to run the app on"),
|
|
44
|
+
typer.Argument(help="Name of the federation to run the app on."),
|
|
45
45
|
] = None,
|
|
46
46
|
config_overrides: Annotated[
|
|
47
47
|
Optional[List[str]],
|
|
48
48
|
typer.Option(
|
|
49
49
|
"--run-config",
|
|
50
50
|
"-c",
|
|
51
|
-
help="Override configuration key-value pairs"
|
|
51
|
+
help="Override configuration key-value pairs, should be of the format:\n\n"
|
|
52
|
+
"`--run-config key1=value1,key2=value2 --run-config key3=value3`\n\n"
|
|
53
|
+
"Note that `key1`, `key2`, and `key3` in this example need to exist "
|
|
54
|
+
"inside the `pyproject.toml` in order to be properly overriden.",
|
|
52
55
|
),
|
|
53
56
|
] = None,
|
|
54
57
|
) -> None:
|
|
55
58
|
"""Run Flower project."""
|
|
56
59
|
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
|
|
57
60
|
|
|
58
|
-
pyproject_path =
|
|
61
|
+
pyproject_path = app_dir / "pyproject.toml" if app_dir else None
|
|
59
62
|
config, errors, warnings = load_and_validate(path=pyproject_path)
|
|
60
63
|
|
|
61
64
|
if config is None:
|
|
@@ -78,11 +81,9 @@ def run(
|
|
|
78
81
|
|
|
79
82
|
typer.secho("Success", fg=typer.colors.GREEN)
|
|
80
83
|
|
|
81
|
-
|
|
82
|
-
"default"
|
|
83
|
-
)
|
|
84
|
+
federation = federation or config["tool"]["flwr"]["federations"].get("default")
|
|
84
85
|
|
|
85
|
-
if
|
|
86
|
+
if federation is None:
|
|
86
87
|
typer.secho(
|
|
87
88
|
"❌ No federation name was provided and the project's `pyproject.toml` "
|
|
88
89
|
"doesn't declare a default federation (with a SuperExec address or an "
|
|
@@ -93,13 +94,13 @@ def run(
|
|
|
93
94
|
raise typer.Exit(code=1)
|
|
94
95
|
|
|
95
96
|
# Validate the federation exists in the configuration
|
|
96
|
-
|
|
97
|
-
if
|
|
97
|
+
federation_config = config["tool"]["flwr"]["federations"].get(federation)
|
|
98
|
+
if federation_config is None:
|
|
98
99
|
available_feds = {
|
|
99
100
|
fed for fed in config["tool"]["flwr"]["federations"] if fed != "default"
|
|
100
101
|
}
|
|
101
102
|
typer.secho(
|
|
102
|
-
f"❌ There is no `{
|
|
103
|
+
f"❌ There is no `{federation}` federation declared in "
|
|
103
104
|
"`pyproject.toml`.\n The following federations were found:\n\n"
|
|
104
105
|
+ "\n".join(available_feds),
|
|
105
106
|
fg=typer.colors.RED,
|
|
@@ -107,15 +108,15 @@ def run(
|
|
|
107
108
|
)
|
|
108
109
|
raise typer.Exit(code=1)
|
|
109
110
|
|
|
110
|
-
if "address" in
|
|
111
|
-
_run_with_superexec(
|
|
111
|
+
if "address" in federation_config:
|
|
112
|
+
_run_with_superexec(federation_config, app_dir, config_overrides)
|
|
112
113
|
else:
|
|
113
|
-
_run_without_superexec(
|
|
114
|
+
_run_without_superexec(app_dir, federation_config, federation, config_overrides)
|
|
114
115
|
|
|
115
116
|
|
|
116
117
|
def _run_with_superexec(
|
|
117
|
-
|
|
118
|
-
|
|
118
|
+
federation_config: Dict[str, Any],
|
|
119
|
+
app_dir: Optional[Path],
|
|
119
120
|
config_overrides: Optional[List[str]],
|
|
120
121
|
) -> None:
|
|
121
122
|
|
|
@@ -123,8 +124,8 @@ def _run_with_superexec(
|
|
|
123
124
|
"""Log channel connectivity."""
|
|
124
125
|
log(DEBUG, channel_connectivity)
|
|
125
126
|
|
|
126
|
-
insecure_str =
|
|
127
|
-
if root_certificates :=
|
|
127
|
+
insecure_str = federation_config.get("insecure")
|
|
128
|
+
if root_certificates := federation_config.get("root-certificates"):
|
|
128
129
|
root_certificates_bytes = Path(root_certificates).read_bytes()
|
|
129
130
|
if insecure := bool(insecure_str):
|
|
130
131
|
typer.secho(
|
|
@@ -152,7 +153,7 @@ def _run_with_superexec(
|
|
|
152
153
|
raise typer.Exit(code=1)
|
|
153
154
|
|
|
154
155
|
channel = create_channel(
|
|
155
|
-
server_address=
|
|
156
|
+
server_address=federation_config["address"],
|
|
156
157
|
insecure=insecure,
|
|
157
158
|
root_certificates=root_certificates_bytes,
|
|
158
159
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
@@ -161,13 +162,16 @@ def _run_with_superexec(
|
|
|
161
162
|
channel.subscribe(on_channel_state_change)
|
|
162
163
|
stub = ExecStub(channel)
|
|
163
164
|
|
|
164
|
-
fab_path = build(
|
|
165
|
+
fab_path = build(app_dir)
|
|
165
166
|
|
|
166
167
|
req = StartRunRequest(
|
|
167
168
|
fab_file=Path(fab_path).read_bytes(),
|
|
168
169
|
override_config=user_config_to_proto(
|
|
169
170
|
parse_config_args(config_overrides, separator=",")
|
|
170
171
|
),
|
|
172
|
+
federation_config=user_config_to_proto(
|
|
173
|
+
flatten_dict(federation_config.get("options"))
|
|
174
|
+
),
|
|
171
175
|
)
|
|
172
176
|
res = stub.StartRun(req)
|
|
173
177
|
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
|
|
@@ -175,18 +179,18 @@ def _run_with_superexec(
|
|
|
175
179
|
|
|
176
180
|
def _run_without_superexec(
|
|
177
181
|
app_path: Optional[Path],
|
|
178
|
-
|
|
179
|
-
|
|
182
|
+
federation_config: Dict[str, Any],
|
|
183
|
+
federation: str,
|
|
180
184
|
config_overrides: Optional[List[str]],
|
|
181
185
|
) -> None:
|
|
182
186
|
try:
|
|
183
|
-
num_supernodes =
|
|
187
|
+
num_supernodes = federation_config["options"]["num-supernodes"]
|
|
184
188
|
except KeyError as err:
|
|
185
189
|
typer.secho(
|
|
186
190
|
"❌ The project's `pyproject.toml` needs to declare the number of"
|
|
187
191
|
" SuperNodes in the simulation. To simulate 10 SuperNodes,"
|
|
188
192
|
" use the following notation:\n\n"
|
|
189
|
-
f"[tool.flwr.federations.{
|
|
193
|
+
f"[tool.flwr.federations.{federation}]\n"
|
|
190
194
|
"options.num-supernodes = 10\n",
|
|
191
195
|
fg=typer.colors.RED,
|
|
192
196
|
bold=True,
|
|
@@ -28,6 +28,7 @@ from flwr.common.constant import (
|
|
|
28
28
|
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY,
|
|
29
29
|
)
|
|
30
30
|
from flwr.common.version import package_version
|
|
31
|
+
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
31
32
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
32
33
|
CreateNodeRequest,
|
|
33
34
|
CreateNodeResponse,
|
|
@@ -131,3 +132,9 @@ class GrpcAdapter:
|
|
|
131
132
|
) -> GetRunResponse:
|
|
132
133
|
"""."""
|
|
133
134
|
return self._send_and_receive(request, GetRunResponse, **kwargs)
|
|
135
|
+
|
|
136
|
+
def GetFab( # pylint: disable=C0103
|
|
137
|
+
self, request: GetFabRequest, **kwargs: Any
|
|
138
|
+
) -> GetFabResponse:
|
|
139
|
+
"""."""
|
|
140
|
+
return self._send_and_receive(request, GetFabResponse, **kwargs)
|