flwr-nightly 1.11.0.dev20240804__py3-none-any.whl → 1.11.0.dev20240806__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic.
- flwr/cli/config_utils.py +27 -8
- flwr/cli/new/new.py +10 -9
- flwr/cli/new/templates/app/README.md.tpl +1 -1
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +9 -8
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +5 -8
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +7 -6
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +4 -5
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -20
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +16 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/client/grpc_rere_client/grpc_adapter.py +7 -0
- flwr/proto/driver_pb2.py +22 -21
- flwr/proto/driver_pb2.pyi +7 -1
- flwr/proto/driver_pb2_grpc.py +35 -0
- flwr/proto/driver_pb2_grpc.pyi +14 -0
- flwr/proto/fleet_pb2.py +28 -27
- flwr/proto/fleet_pb2_grpc.py +35 -0
- flwr/proto/fleet_pb2_grpc.pyi +14 -0
- flwr/proto/run_pb2.py +8 -8
- flwr/proto/run_pb2.pyi +4 -1
- flwr/server/superlink/driver/driver_servicer.py +7 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +7 -0
- {flwr_nightly-1.11.0.dev20240804.dist-info → flwr_nightly-1.11.0.dev20240806.dist-info}/METADATA +1 -1
- {flwr_nightly-1.11.0.dev20240804.dist-info → flwr_nightly-1.11.0.dev20240806.dist-info}/RECORD +47 -47
- {flwr_nightly-1.11.0.dev20240804.dist-info → flwr_nightly-1.11.0.dev20240806.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240804.dist-info → flwr_nightly-1.11.0.dev20240806.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.11.0.dev20240804.dist-info → flwr_nightly-1.11.0.dev20240806.dist-info}/entry_points.txt +0 -0
flwr/cli/config_utils.py
CHANGED

@@ -25,8 +25,8 @@ from flwr.common import object_ref
 from flwr.common.typing import UserConfigValue


-def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
-    """Extract the fab_id and the fab_version from a FAB file or path.
+def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]:
+    """Extract the config from a FAB file or path.

     Parameters
     ----------
@@ -36,8 +36,8 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:

     Returns
     -------
-    Tuple[str, str]
-        The `fab_version` and `fab_id` of the given Flower App Bundle.
+    Dict[str, Any]
+        The `config` of the given Flower App Bundle.
     """
     fab_file_archive: Union[Path, IO[bytes]]
     if isinstance(fab_file, bytes):
@@ -59,10 +59,29 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
     if not is_valid:
         raise ValueError(errors)

-    return (
-        conf["project"]["version"],
-        f"{conf['tool']['flwr']['app']['publisher']}/{conf['project']['name']}",
-    )
+    return conf
+
+
+def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
+    """Extract the fab_id and the fab_version from a FAB file or path.
+
+    Parameters
+    ----------
+    fab_file : Union[Path, bytes]
+        The Flower App Bundle file to validate and extract the metadata from.
+        It can either be a path to the file or the file itself as bytes.
+
+    Returns
+    -------
+    Tuple[str, str]
+        The `fab_version` and `fab_id` of the given Flower App Bundle.
+    """
+    conf = get_fab_config(fab_file)
+
+    return (
+        conf["project"]["version"],
+        f"{conf['tool']['flwr']['app']['publisher']}/{conf['project']['name']}",
+    )


 def load_and_validate(
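This refactor separates reading the whole bundled config from reading just the identifiers. A minimal usage sketch, assuming a built Flower App Bundle exists at the hypothetical path ./myapp.fab (note these helpers live in the CLI package, not the public API):

    from pathlib import Path

    from flwr.cli.config_utils import get_fab_config, get_fab_metadata

    # Full pyproject-derived config stored inside the bundle
    config = get_fab_config(Path("./myapp.fab"))

    # Convenience wrapper returning only (fab_version, fab_id)
    fab_version, fab_id = get_fab_metadata(Path("./myapp.fab"))
    print(fab_id, fab_version)
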
flwr/cli/new/new.py
CHANGED

@@ -34,13 +34,13 @@ from ..utils import (
 class MlFramework(str, Enum):
     """Available frameworks."""

-    NUMPY = "NumPy"
     PYTORCH = "PyTorch"
     TENSORFLOW = "TensorFlow"
-
+    SKLEARN = "sklearn"
     HUGGINGFACE = "HuggingFace"
+    JAX = "JAX"
     MLX = "MLX"
-
+    NUMPY = "NumPy"
     FLOWERTUNE = "FlowerTune"


@@ -135,20 +135,20 @@ def new(
     username = prompt_text("Please provide your Flower username")

     if framework is not None:
-        framework_str = str(framework.value)
+        framework_str_upper = str(framework.value)
     else:
         framework_value = prompt_options(
             "Please select ML framework by typing in the number",
-
+            [mlf.value for mlf in MlFramework],
         )
         selected_value = [
             name
             for name, value in vars(MlFramework).items()
             if value == framework_value
         ]
-        framework_str = selected_value[0]
+        framework_str_upper = selected_value[0]

-    framework_str = framework_str.lower()
+    framework_str = framework_str_upper.lower()

     llm_challenge_str = None
     if framework_str == "flowertune":
@@ -173,9 +173,10 @@ def new(
     )

     context = {
-        "
-        "package_name": package_name,
+        "framework_str": framework_str_upper,
         "import_name": import_name.replace("-", "_"),
+        "package_name": package_name,
+        "project_name": project_name,
         "username": username,
     }

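The keys placed in this context dict are exactly the $-placeholders the templates below rely on ($project_name, $import_name, $framework_str, and so on). A minimal sketch of the substitution step, using Python's string.Template as a stand-in for whatever rendering helper the CLI actually uses, with placeholder values:

    from string import Template

    context = {
        "framework_str": "PyTorch",   # display casing, e.g. MlFramework.PYTORCH.value
        "import_name": "my_app",
        "package_name": "my-app",
        "project_name": "my-app",
        "username": "flwrlabs",       # placeholder username
    }

    tpl = Template('"""$project_name: A Flower / $framework_str app."""')
    print(tpl.substitute(context))    # -> """my-app: A Flower / PyTorch app."""
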

flwr/cli/new/templates/app/code/__init__.py.tpl
CHANGED

@@ -1 +1 @@
-"""$project_name."""
+"""$project_name: A Flower / $framework_str app."""

flwr/cli/new/templates/app/code/client.pytorch.py.tpl
CHANGED

@@ -1,11 +1,11 @@
-"""$project_name: A Flower / PyTorch app."""
+"""$project_name: A Flower / $framework_str app."""

+import torch
 from flwr.client import NumPyClient, ClientApp
 from flwr.common import Context

 from $import_name.task import (
     Net,
-    DEVICE,
     load_data,
     get_weights,
     set_weights,
@@ -21,27 +21,28 @@ class FlowerClient(NumPyClient):
         self.trainloader = trainloader
         self.valloader = valloader
         self.local_epochs = local_epochs
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.net.to(self.device)

     def fit(self, parameters, config):
         set_weights(self.net, parameters)
-        results = train(
+        train_loss = train(
             self.net,
             self.trainloader,
-            self.valloader,
             self.local_epochs,
-            DEVICE,
+            self.device,
         )
-        return get_weights(self.net), len(self.trainloader.dataset), results
+        return get_weights(self.net), len(self.trainloader.dataset), {"train_loss": train_loss}

     def evaluate(self, parameters, config):
         set_weights(self.net, parameters)
-        loss, accuracy = test(self.net, self.valloader)
+        loss, accuracy = test(self.net, self.valloader, self.device)
         return loss, len(self.valloader.dataset), {"accuracy": accuracy}


 def client_fn(context: Context):
     # Load model and data
-    net = Net().to(DEVICE)
+    net = Net()
     partition_id = context.node_config["partition-id"]
     num_partitions = context.node_config["num-partitions"]
     trainloader, valloader = load_data(partition_id, num_partitions)
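Because fit() now reports the scalar returned by train() under the "train_loss" key, the server side can aggregate it if desired. A minimal sketch, assuming the standard fit_metrics_aggregation_fn hook on FedAvg (not part of this diff):

    from flwr.server.strategy import FedAvg

    def weighted_train_loss(metrics):
        # metrics arrives as a list of (num_examples, {"train_loss": ...}) tuples, one per client
        total_examples = sum(num_examples for num_examples, _ in metrics)
        weighted = sum(num_examples * m["train_loss"] for num_examples, m in metrics)
        return {"train_loss": weighted / total_examples}

    strategy = FedAvg(fit_metrics_aggregation_fn=weighted_train_loss)
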

flwr/cli/new/templates/app/code/client.tensorflow.py.tpl
CHANGED

@@ -1,4 +1,4 @@
-"""$project_name: A Flower / TensorFlow app."""
+"""$project_name: A Flower / $framework_str app."""

 from flwr.client import NumPyClient, ClientApp
 from flwr.common import Context
@@ -9,13 +9,10 @@ from $import_name.task import load_data, load_model
 # Define Flower Client and client_fn
 class FlowerClient(NumPyClient):
     def __init__(
-        self, model, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
+        self, model, data, epochs, batch_size, verbose
     ):
         self.model = model
-        self.x_train = x_train
-        self.y_train = y_train
-        self.x_test = x_test
-        self.y_test = y_test
+        self.x_train, self.y_train, self.x_test, self.y_test = data
         self.epochs = epochs
         self.batch_size = batch_size
         self.verbose = verbose
@@ -46,14 +43,14 @@ def client_fn(context: Context):

     partition_id = context.node_config["partition-id"]
     num_partitions = context.node_config["num-partitions"]
-    x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)
+    data = load_data(partition_id, num_partitions)
     epochs = context.run_config["local-epochs"]
     batch_size = context.run_config["batch-size"]
     verbose = context.run_config.get("verbose")

     # Return Client instance
     return FlowerClient(
-        net, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
+        net, data, epochs, batch_size, verbose
     ).to_client()



flwr/cli/new/templates/app/code/server.pytorch.py.tpl
CHANGED

@@ -1,4 +1,4 @@
-"""$project_name: A Flower / PyTorch app."""
+"""$project_name: A Flower / $framework_str app."""

 from flwr.common import Context, ndarrays_to_parameters
 from flwr.server import ServerApp, ServerAppComponents, ServerConfig
@@ -7,17 +7,18 @@ from flwr.server.strategy import FedAvg
 from $import_name.task import Net, get_weights


-# Initialize model parameters
-ndarrays = get_weights(Net())
-parameters = ndarrays_to_parameters(ndarrays)
-
 def server_fn(context: Context):
     # Read from config
     num_rounds = context.run_config["num-server-rounds"]
+    fraction_fit = context.run_config["fraction-fit"]
+
+    # Initialize model parameters
+    ndarrays = get_weights(Net())
+    parameters = ndarrays_to_parameters(ndarrays)

     # Define strategy
     strategy = FedAvg(
-        fraction_fit=1.0,
+        fraction_fit=fraction_fit,
         fraction_evaluate=1.0,
         min_available_clients=2,
         initial_parameters=parameters,
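For context, server_fn is only shown as a fragment here; the values it builds feed the usual ServerApp plumbing that the rest of the template (unchanged by this diff) provides. A minimal sketch of that surrounding structure, assuming the standard ServerAppComponents/ServerApp pattern of this Flower release line:

    from flwr.common import Context
    from flwr.server import ServerApp, ServerAppComponents, ServerConfig
    from flwr.server.strategy import FedAvg

    def server_fn(context: Context):
        num_rounds = context.run_config["num-server-rounds"]
        strategy = FedAvg(fraction_fit=context.run_config["fraction-fit"])
        # The strategy and the round count are handed back to the ServerApp together
        return ServerAppComponents(strategy=strategy, config=ServerConfig(num_rounds=num_rounds))

    app = ServerApp(server_fn=server_fn)
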

flwr/cli/new/templates/app/code/server.tensorflow.py.tpl
CHANGED

@@ -1,4 +1,4 @@
-"""$project_name: A Flower / TensorFlow app."""
+"""$project_name: A Flower / $framework_str app."""

 from flwr.common import Context, ndarrays_to_parameters
 from flwr.server import ServerApp, ServerAppComponents, ServerConfig
@@ -6,15 +6,14 @@ from flwr.server.strategy import FedAvg

 from $import_name.task import load_model

-# Define config
-config = ServerConfig(num_rounds=3)
-
-parameters = ndarrays_to_parameters(load_model().get_weights())

 def server_fn(context: Context):
     # Read from config
     num_rounds = context.run_config["num-server-rounds"]

+    # Get parameters to initialize global model
+    parameters = ndarrays_to_parameters(load_model().get_weights())
+
     # Define strategy
     strategy = strategy = FedAvg(
         fraction_fit=1.0,

flwr/cli/new/templates/app/code/task.mlx.py.tpl
CHANGED

@@ -1,4 +1,4 @@
-"""$project_name: A Flower / MLX app."""
+"""$project_name: A Flower / $framework_str app."""

 import mlx.core as mx
 import mlx.nn as nn
@@ -56,6 +56,7 @@ def load_data(partition_id: int, num_partitions: int):
     fds = FederatedDataset(
         dataset="ylecun/mnist",
         partitioners={"train": partitioner},
+        trust_remote_code=True,
     )
     partition = fds.load_partition(partition_id)
     partition_splits = partition.train_test_split(test_size=0.2, seed=42)

flwr/cli/new/templates/app/code/task.pytorch.py.tpl
CHANGED

@@ -1,4 +1,4 @@
-"""$project_name: A Flower / PyTorch app."""
+"""$project_name: A Flower / $framework_str app."""

 from collections import OrderedDict

@@ -11,9 +11,6 @@ from flwr_datasets import FederatedDataset
 from flwr_datasets.partitioner import IidPartitioner


-DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-
 class Net(nn.Module):
     """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

@@ -66,44 +63,41 @@ def load_data(partition_id: int, num_partitions: int):
     return trainloader, testloader


-def train(net, trainloader, valloader, epochs, device):
+def train(net, trainloader, epochs, device):
     """Train the model on the training set."""
     net.to(device)  # move model to GPU if available
     criterion = torch.nn.CrossEntropyLoss().to(device)
-    optimizer = torch.optim.SGD(net.parameters(), lr=0.
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
     net.train()
+    running_loss = 0.0
     for _ in range(epochs):
         for batch in trainloader:
             images = batch["img"]
             labels = batch["label"]
             optimizer.zero_grad()
-            criterion(net(images.to(device)), labels.to(device)).backward()
+            loss = criterion(net(images.to(device)), labels.to(device))
+            loss.backward()
             optimizer.step()
+            running_loss += loss.item()

-
-
-
-    results = {
-        "train_loss": train_loss,
-        "train_accuracy": train_acc,
-        "val_loss": val_loss,
-        "val_accuracy": val_acc,
-    }
-    return results
+    avg_trainloss = running_loss / len(trainloader)
+    return avg_trainloss


-def test(net, testloader):
+def test(net, testloader, device):
     """Validate the model on the test set."""
+    net.to(device)
     criterion = torch.nn.CrossEntropyLoss()
     correct, loss = 0, 0.0
     with torch.no_grad():
         for batch in testloader:
-            images = batch["img"].to(DEVICE)
-            labels = batch["label"].to(DEVICE)
+            images = batch["img"].to(device)
+            labels = batch["label"].to(device)
             outputs = net(images)
             loss += criterion(outputs, labels).item()
             correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
     accuracy = correct / len(testloader.dataset)
+    loss = loss / len(testloader)
     return loss, accuracy



flwr/cli/new/templates/app/code/task.tensorflow.py.tpl
CHANGED

@@ -1,8 +1,9 @@
-"""$project_name: A Flower / TensorFlow app."""
+"""$project_name: A Flower / $framework_str app."""

 import os

-import tensorflow as tf
+import keras
+from keras import layers
 from flwr_datasets import FederatedDataset
 from flwr_datasets.partitioner import IidPartitioner

@@ -12,8 +13,19 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


 def load_model():
-    #
-    model =
+    # Define a simple CNN for CIFAR-10 and set Adam optimizer
+    model = keras.Sequential(
+        [
+            keras.Input(shape=(32, 32, 3)),
+            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Flatten(),
+            layers.Dropout(0.5),
+            layers.Dense(10, activation="softmax"),
+        ]
+    )
     model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
     return model


flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl
CHANGED

@@ -8,8 +8,8 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.
-    "flwr-datasets>=0.
+    "flwr[simulation]>=1.10.0",
+    "flwr-datasets>=0.3.0",
     "torch==2.2.1",
     "transformers>=4.30.0,<5.0",
     "evaluate>=0.4.0,<1.0",

flwr/cli/new/templates/app/pyproject.mlx.toml.tpl
CHANGED

@@ -8,9 +8,9 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.
-    "flwr-datasets[vision]>=0.
-    "mlx==0.
+    "flwr[simulation]>=1.10.0",
+    "flwr-datasets[vision]>=0.3.0",
+    "mlx==0.16.1",
     "numpy==1.24.4",
 ]


flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
CHANGED

@@ -8,8 +8,8 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.
-    "flwr-datasets[vision]>=0.
+    "flwr[simulation]>=1.10.0",
+    "flwr-datasets[vision]>=0.3.0",
     "torch==2.2.1",
     "torchvision==0.17.1",
 ]
@@ -26,6 +26,7 @@ clientapp = "$import_name.client_app:app"

 [tool.flwr.app.config]
 num-server-rounds = 3
+fraction-fit = 0.5
 local-epochs = 1

 [tool.flwr.federations]

flwr/client/grpc_rere_client/grpc_adapter.py
CHANGED

@@ -28,6 +28,7 @@ from flwr.common.constant import (
     GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY,
 )
 from flwr.common.version import package_version
+from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse  # pylint: disable=E0611
 from flwr.proto.fleet_pb2 import (  # pylint: disable=E0611
     CreateNodeRequest,
     CreateNodeResponse,
@@ -131,3 +132,9 @@ class GrpcAdapter:
     ) -> GetRunResponse:
         """."""
         return self._send_and_receive(request, GetRunResponse, **kwargs)
+
+    def GetFab(  # pylint: disable=C0103
+        self, request: GetFabRequest, **kwargs: Any
+    ) -> GetFabResponse:
+        """."""
+        return self._send_and_receive(request, GetFabResponse, **kwargs)
flwr/proto/driver_pb2.py
CHANGED

@@ -15,10 +15,11 @@ _sym_db = _symbol_database.Default()
 from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
 from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2
 from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
+from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2


-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x1a\x66lwr/proto/transport.proto\"\
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc7\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -27,24 +28,24 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   DESCRIPTOR._options = None
   _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None
   _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001'
-  _globals['_CREATERUNREQUEST']._serialized_start=
-  _globals['_CREATERUNREQUEST']._serialized_end=
-  _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=
-  _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=
-  _globals['_CREATERUNRESPONSE']._serialized_start=
-  _globals['_CREATERUNRESPONSE']._serialized_end=
-  _globals['_GETNODESREQUEST']._serialized_start=
-  _globals['_GETNODESREQUEST']._serialized_end=
-  _globals['_GETNODESRESPONSE']._serialized_start=
-  _globals['_GETNODESRESPONSE']._serialized_end=
-  _globals['_PUSHTASKINSREQUEST']._serialized_start=
-  _globals['_PUSHTASKINSREQUEST']._serialized_end=
-  _globals['_PUSHTASKINSRESPONSE']._serialized_start=
-  _globals['_PUSHTASKINSRESPONSE']._serialized_end=
-  _globals['_PULLTASKRESREQUEST']._serialized_start=
-  _globals['_PULLTASKRESREQUEST']._serialized_end=
-  _globals['_PULLTASKRESRESPONSE']._serialized_start=
-  _globals['_PULLTASKRESRESPONSE']._serialized_end=
-  _globals['_DRIVER']._serialized_start=
-  _globals['_DRIVER']._serialized_end=
+  _globals['_CREATERUNREQUEST']._serialized_start=158
+  _globals['_CREATERUNREQUEST']._serialized_end=393
+  _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=320
+  _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=393
+  _globals['_CREATERUNRESPONSE']._serialized_start=395
+  _globals['_CREATERUNRESPONSE']._serialized_end=430
+  _globals['_GETNODESREQUEST']._serialized_start=432
+  _globals['_GETNODESREQUEST']._serialized_end=465
+  _globals['_GETNODESRESPONSE']._serialized_start=467
+  _globals['_GETNODESRESPONSE']._serialized_end=518
+  _globals['_PUSHTASKINSREQUEST']._serialized_start=520
+  _globals['_PUSHTASKINSREQUEST']._serialized_end=584
+  _globals['_PUSHTASKINSRESPONSE']._serialized_start=586
+  _globals['_PUSHTASKINSRESPONSE']._serialized_end=625
+  _globals['_PULLTASKRESREQUEST']._serialized_start=627
+  _globals['_PULLTASKRESREQUEST']._serialized_end=697
+  _globals['_PULLTASKRESRESPONSE']._serialized_start=699
+  _globals['_PULLTASKRESRESPONSE']._serialized_end=764
+  _globals['_DRIVER']._serialized_start=767
+  _globals['_DRIVER']._serialized_end=1222
 # @@protoc_insertion_point(module_scope)