flwr-nightly 1.10.0.dev20240722__py3-none-any.whl → 1.11.0.dev20240805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/config_utils.py +40 -23
- flwr/cli/new/new.py +7 -6
- flwr/cli/new/templates/app/README.md.tpl +1 -1
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +8 -6
- flwr/cli/new/templates/app/code/client.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +29 -11
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -13
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +3 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +20 -13
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +20 -17
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
- flwr/cli/new/templates/app/code/{server.hf.py.tpl → server.huggingface.py.tpl} +3 -2
- flwr/cli/new/templates/app/code/server.jax.py.tpl +3 -2
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +3 -2
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +8 -7
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +3 -2
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +5 -6
- flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
- flwr/cli/new/templates/app/code/task.jax.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +15 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +26 -21
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -5
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +3 -3
- flwr/cli/new/templates/app/{pyproject.hf.toml.tpl → pyproject.huggingface.toml.tpl} +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +11 -11
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -6
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +5 -5
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +8 -8
- flwr/cli/run/run.py +31 -27
- flwr/client/grpc_rere_client/grpc_adapter.py +7 -0
- flwr/client/supernode/app.py +12 -43
- flwr/common/config.py +6 -1
- flwr/common/object_ref.py +84 -21
- flwr/proto/driver_pb2.py +22 -21
- flwr/proto/driver_pb2.pyi +7 -1
- flwr/proto/driver_pb2_grpc.py +35 -0
- flwr/proto/driver_pb2_grpc.pyi +14 -0
- flwr/proto/exec_pb2.py +16 -12
- flwr/proto/exec_pb2.pyi +20 -1
- flwr/proto/fleet_pb2.py +28 -27
- flwr/proto/fleet_pb2_grpc.py +35 -0
- flwr/proto/fleet_pb2_grpc.pyi +14 -0
- flwr/proto/run_pb2.py +8 -8
- flwr/proto/run_pb2.pyi +4 -1
- flwr/server/run_serverapp.py +0 -3
- flwr/server/superlink/driver/driver_servicer.py +7 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +7 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +4 -4
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/run_simulation.py +32 -4
- flwr/superexec/app.py +4 -5
- flwr/superexec/deployment.py +1 -2
- flwr/superexec/exec_servicer.py +3 -1
- flwr/superexec/executor.py +3 -0
- flwr/superexec/simulation.py +54 -12
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/METADATA +1 -1
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/RECORD +66 -66
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.11.0.dev20240805.dist-info}/entry_points.txt +0 -0
flwr/cli/config_utils.py
CHANGED
|
@@ -25,8 +25,8 @@ from flwr.common import object_ref
|
|
|
25
25
|
from flwr.common.typing import UserConfigValue
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
def
|
|
29
|
-
"""Extract the
|
|
28
|
+
def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]:
|
|
29
|
+
"""Extract the config from a FAB file or path.
|
|
30
30
|
|
|
31
31
|
Parameters
|
|
32
32
|
----------
|
|
@@ -36,8 +36,8 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
|
|
|
36
36
|
|
|
37
37
|
Returns
|
|
38
38
|
-------
|
|
39
|
-
|
|
40
|
-
The `
|
|
39
|
+
Dict[str, Any]
|
|
40
|
+
The `config` of the given Flower App Bundle.
|
|
41
41
|
"""
|
|
42
42
|
fab_file_archive: Union[Path, IO[bytes]]
|
|
43
43
|
if isinstance(fab_file, bytes):
|
|
@@ -59,10 +59,29 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
|
|
|
59
59
|
if not is_valid:
|
|
60
60
|
raise ValueError(errors)
|
|
61
61
|
|
|
62
|
-
return
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
62
|
+
return conf
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
|
|
66
|
+
"""Extract the fab_id and the fab_version from a FAB file or path.
|
|
67
|
+
|
|
68
|
+
Parameters
|
|
69
|
+
----------
|
|
70
|
+
fab_file : Union[Path, bytes]
|
|
71
|
+
The Flower App Bundle file to validate and extract the metadata from.
|
|
72
|
+
It can either be a path to the file or the file itself as bytes.
|
|
73
|
+
|
|
74
|
+
Returns
|
|
75
|
+
-------
|
|
76
|
+
Tuple[str, str]
|
|
77
|
+
The `fab_version` and `fab_id` of the given Flower App Bundle.
|
|
78
|
+
"""
|
|
79
|
+
conf = get_fab_config(fab_file)
|
|
80
|
+
|
|
81
|
+
return (
|
|
82
|
+
conf["project"]["version"],
|
|
83
|
+
f"{conf['tool']['flwr']['app']['publisher']}/{conf['project']['name']}",
|
|
84
|
+
)
|
|
66
85
|
|
|
67
86
|
|
|
68
87
|
def load_and_validate(
|
|
@@ -77,6 +96,9 @@ def load_and_validate(
|
|
|
77
96
|
A tuple with the optional config in case it exists and is valid
|
|
78
97
|
and associated errors and warnings.
|
|
79
98
|
"""
|
|
99
|
+
if path is None:
|
|
100
|
+
path = Path.cwd() / "pyproject.toml"
|
|
101
|
+
|
|
80
102
|
config = load(path)
|
|
81
103
|
|
|
82
104
|
if config is None:
|
|
@@ -86,7 +108,7 @@ def load_and_validate(
|
|
|
86
108
|
]
|
|
87
109
|
return (None, errors, [])
|
|
88
110
|
|
|
89
|
-
is_valid, errors, warnings = validate(config, check_module)
|
|
111
|
+
is_valid, errors, warnings = validate(config, check_module, path.parent)
|
|
90
112
|
|
|
91
113
|
if not is_valid:
|
|
92
114
|
return (None, errors, warnings)
|
|
@@ -94,14 +116,8 @@ def load_and_validate(
|
|
|
94
116
|
return (config, errors, warnings)
|
|
95
117
|
|
|
96
118
|
|
|
97
|
-
def load(
|
|
119
|
+
def load(toml_path: Path) -> Optional[Dict[str, Any]]:
|
|
98
120
|
"""Load pyproject.toml and return as dict."""
|
|
99
|
-
if path is None:
|
|
100
|
-
cur_dir = Path.cwd()
|
|
101
|
-
toml_path = cur_dir / "pyproject.toml"
|
|
102
|
-
else:
|
|
103
|
-
toml_path = path
|
|
104
|
-
|
|
105
121
|
if not toml_path.is_file():
|
|
106
122
|
return None
|
|
107
123
|
|
|
@@ -167,7 +183,9 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
|
|
|
167
183
|
|
|
168
184
|
|
|
169
185
|
def validate(
|
|
170
|
-
config: Dict[str, Any],
|
|
186
|
+
config: Dict[str, Any],
|
|
187
|
+
check_module: bool = True,
|
|
188
|
+
project_dir: Optional[Union[str, Path]] = None,
|
|
171
189
|
) -> Tuple[bool, List[str], List[str]]:
|
|
172
190
|
"""Validate pyproject.toml."""
|
|
173
191
|
is_valid, errors, warnings = validate_fields(config)
|
|
@@ -176,16 +194,15 @@ def validate(
|
|
|
176
194
|
return False, errors, warnings
|
|
177
195
|
|
|
178
196
|
# Validate serverapp
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
197
|
+
serverapp_ref = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
|
198
|
+
is_valid, reason = object_ref.validate(serverapp_ref, check_module, project_dir)
|
|
199
|
+
|
|
182
200
|
if not is_valid and isinstance(reason, str):
|
|
183
201
|
return False, [reason], []
|
|
184
202
|
|
|
185
203
|
# Validate clientapp
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
)
|
|
204
|
+
clientapp_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
|
205
|
+
is_valid, reason = object_ref.validate(clientapp_ref, check_module, project_dir)
|
|
189
206
|
|
|
190
207
|
if not is_valid and isinstance(reason, str):
|
|
191
208
|
return False, [reason], []
|
flwr/cli/new/new.py
CHANGED
|
@@ -38,7 +38,7 @@ class MlFramework(str, Enum):
|
|
|
38
38
|
PYTORCH = "PyTorch"
|
|
39
39
|
TENSORFLOW = "TensorFlow"
|
|
40
40
|
JAX = "JAX"
|
|
41
|
-
HUGGINGFACE = "
|
|
41
|
+
HUGGINGFACE = "HuggingFace"
|
|
42
42
|
MLX = "MLX"
|
|
43
43
|
SKLEARN = "sklearn"
|
|
44
44
|
FLOWERTUNE = "FlowerTune"
|
|
@@ -135,7 +135,7 @@ def new(
|
|
|
135
135
|
username = prompt_text("Please provide your Flower username")
|
|
136
136
|
|
|
137
137
|
if framework is not None:
|
|
138
|
-
|
|
138
|
+
framework_str_upper = str(framework.value)
|
|
139
139
|
else:
|
|
140
140
|
framework_value = prompt_options(
|
|
141
141
|
"Please select ML framework by typing in the number",
|
|
@@ -146,9 +146,9 @@ def new(
|
|
|
146
146
|
for name, value in vars(MlFramework).items()
|
|
147
147
|
if value == framework_value
|
|
148
148
|
]
|
|
149
|
-
|
|
149
|
+
framework_str_upper = selected_value[0]
|
|
150
150
|
|
|
151
|
-
framework_str =
|
|
151
|
+
framework_str = framework_str_upper.lower()
|
|
152
152
|
|
|
153
153
|
llm_challenge_str = None
|
|
154
154
|
if framework_str == "flowertune":
|
|
@@ -173,9 +173,10 @@ def new(
|
|
|
173
173
|
)
|
|
174
174
|
|
|
175
175
|
context = {
|
|
176
|
-
"
|
|
177
|
-
"package_name": package_name,
|
|
176
|
+
"framework_str": framework_str_upper,
|
|
178
177
|
"import_name": import_name.replace("-", "_"),
|
|
178
|
+
"package_name": package_name,
|
|
179
|
+
"project_name": project_name,
|
|
179
180
|
"username": username,
|
|
180
181
|
}
|
|
181
182
|
|
|
@@ -1 +1 @@
|
|
|
1
|
-
"""$project_name."""
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.client import ClientApp, NumPyClient
|
|
4
4
|
from flwr.common import Context
|
|
@@ -17,10 +17,11 @@ from $import_name.task import (
|
|
|
17
17
|
|
|
18
18
|
# Flower client
|
|
19
19
|
class FlowerClient(NumPyClient):
|
|
20
|
-
def __init__(self, net, trainloader, testloader):
|
|
20
|
+
def __init__(self, net, trainloader, testloader, local_epochs):
|
|
21
21
|
self.net = net
|
|
22
22
|
self.trainloader = trainloader
|
|
23
23
|
self.testloader = testloader
|
|
24
|
+
self.local_epochs = local_epochs
|
|
24
25
|
|
|
25
26
|
def get_parameters(self, config):
|
|
26
27
|
return get_weights(self.net)
|
|
@@ -33,7 +34,7 @@ class FlowerClient(NumPyClient):
|
|
|
33
34
|
train(
|
|
34
35
|
self.net,
|
|
35
36
|
self.trainloader,
|
|
36
|
-
epochs=
|
|
37
|
+
epochs=self.local_epochs,
|
|
37
38
|
)
|
|
38
39
|
return self.get_parameters(config={}), len(self.trainloader), {}
|
|
39
40
|
|
|
@@ -49,12 +50,13 @@ def client_fn(context: Context):
|
|
|
49
50
|
CHECKPOINT, num_labels=2
|
|
50
51
|
).to(DEVICE)
|
|
51
52
|
|
|
52
|
-
partition_id =
|
|
53
|
-
num_partitions =
|
|
53
|
+
partition_id = context.node_config["partition-id"]
|
|
54
|
+
num_partitions = context.node_config["num-partitions"]
|
|
54
55
|
trainloader, valloader = load_data(partition_id, num_partitions)
|
|
56
|
+
local_epochs = context.run_config["local-epochs"]
|
|
55
57
|
|
|
56
58
|
# Return Client instance
|
|
57
|
-
return FlowerClient(net, trainloader, valloader).to_client()
|
|
59
|
+
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
|
|
58
60
|
|
|
59
61
|
|
|
60
62
|
# Flower ClientApp
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
import mlx.core as mx
|
|
4
4
|
import mlx.nn as nn
|
|
@@ -19,13 +19,22 @@ from $import_name.task import (
|
|
|
19
19
|
|
|
20
20
|
# Define Flower Client and client_fn
|
|
21
21
|
class FlowerClient(NumPyClient):
|
|
22
|
-
def __init__(
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
data,
|
|
25
|
+
num_layers,
|
|
26
|
+
hidden_dim,
|
|
27
|
+
num_classes,
|
|
28
|
+
batch_size,
|
|
29
|
+
learning_rate,
|
|
30
|
+
num_epochs,
|
|
31
|
+
):
|
|
32
|
+
self.num_layers = num_layers
|
|
33
|
+
self.hidden_dim = hidden_dim
|
|
34
|
+
self.num_classes = num_classes
|
|
35
|
+
self.batch_size = batch_size
|
|
36
|
+
self.learning_rate = learning_rate
|
|
37
|
+
self.num_epochs = num_epochs
|
|
29
38
|
|
|
30
39
|
self.train_images, self.train_labels, self.test_images, self.test_labels = data
|
|
31
40
|
self.model = MLP(
|
|
@@ -61,12 +70,21 @@ class FlowerClient(NumPyClient):
|
|
|
61
70
|
|
|
62
71
|
|
|
63
72
|
def client_fn(context: Context):
|
|
64
|
-
partition_id =
|
|
65
|
-
num_partitions =
|
|
73
|
+
partition_id = context.node_config["partition-id"]
|
|
74
|
+
num_partitions = context.node_config["num-partitions"]
|
|
66
75
|
data = load_data(partition_id, num_partitions)
|
|
67
76
|
|
|
77
|
+
num_layers = context.run_config["num-layers"]
|
|
78
|
+
hidden_dim = context.run_config["hidden-dim"]
|
|
79
|
+
num_classes = 10
|
|
80
|
+
batch_size = context.run_config["batch-size"]
|
|
81
|
+
learning_rate = context.run_config["lr"]
|
|
82
|
+
num_epochs = context.run_config["local-epochs"]
|
|
83
|
+
|
|
68
84
|
# Return Client instance
|
|
69
|
-
return FlowerClient(
|
|
85
|
+
return FlowerClient(
|
|
86
|
+
data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs
|
|
87
|
+
).to_client()
|
|
70
88
|
|
|
71
89
|
|
|
72
90
|
# Flower ClientApp
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
+
import torch
|
|
3
4
|
from flwr.client import NumPyClient, ClientApp
|
|
4
5
|
from flwr.common import Context
|
|
5
6
|
|
|
6
7
|
from $import_name.task import (
|
|
7
8
|
Net,
|
|
8
|
-
DEVICE,
|
|
9
9
|
load_data,
|
|
10
10
|
get_weights,
|
|
11
11
|
set_weights,
|
|
@@ -16,37 +16,40 @@ from $import_name.task import (
|
|
|
16
16
|
|
|
17
17
|
# Define Flower Client and client_fn
|
|
18
18
|
class FlowerClient(NumPyClient):
|
|
19
|
-
def __init__(self, net, trainloader, valloader):
|
|
19
|
+
def __init__(self, net, trainloader, valloader, local_epochs):
|
|
20
20
|
self.net = net
|
|
21
21
|
self.trainloader = trainloader
|
|
22
22
|
self.valloader = valloader
|
|
23
|
+
self.local_epochs = local_epochs
|
|
24
|
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
25
|
+
self.net.to(self.device)
|
|
23
26
|
|
|
24
27
|
def fit(self, parameters, config):
|
|
25
28
|
set_weights(self.net, parameters)
|
|
26
|
-
|
|
29
|
+
train_loss = train(
|
|
27
30
|
self.net,
|
|
28
31
|
self.trainloader,
|
|
29
|
-
self.
|
|
30
|
-
|
|
31
|
-
DEVICE,
|
|
32
|
+
self.local_epochs,
|
|
33
|
+
self.device,
|
|
32
34
|
)
|
|
33
|
-
return get_weights(self.net), len(self.trainloader.dataset),
|
|
35
|
+
return get_weights(self.net), len(self.trainloader.dataset), {"train_loss": train_loss}
|
|
34
36
|
|
|
35
37
|
def evaluate(self, parameters, config):
|
|
36
38
|
set_weights(self.net, parameters)
|
|
37
|
-
loss, accuracy = test(self.net, self.valloader)
|
|
39
|
+
loss, accuracy = test(self.net, self.valloader, self.device)
|
|
38
40
|
return loss, len(self.valloader.dataset), {"accuracy": accuracy}
|
|
39
41
|
|
|
40
42
|
|
|
41
43
|
def client_fn(context: Context):
|
|
42
44
|
# Load model and data
|
|
43
|
-
net = Net()
|
|
44
|
-
partition_id =
|
|
45
|
-
num_partitions =
|
|
45
|
+
net = Net()
|
|
46
|
+
partition_id = context.node_config["partition-id"]
|
|
47
|
+
num_partitions = context.node_config["num-partitions"]
|
|
46
48
|
trainloader, valloader = load_data(partition_id, num_partitions)
|
|
49
|
+
local_epochs = context.run_config["local-epochs"]
|
|
47
50
|
|
|
48
51
|
# Return Client instance
|
|
49
|
-
return FlowerClient(net, trainloader, valloader).to_client()
|
|
52
|
+
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
|
|
50
53
|
|
|
51
54
|
|
|
52
55
|
# Flower ClientApp
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
import warnings
|
|
4
4
|
|
|
@@ -69,8 +69,8 @@ class FlowerClient(NumPyClient):
|
|
|
69
69
|
|
|
70
70
|
|
|
71
71
|
def client_fn(context: Context):
|
|
72
|
-
partition_id =
|
|
73
|
-
num_partitions =
|
|
72
|
+
partition_id = context.node_config["partition-id"]
|
|
73
|
+
num_partitions = context.node_config["num-partitions"]
|
|
74
74
|
fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
|
|
75
75
|
dataset = fds.load_partition(partition_id, "train").with_format("numpy")
|
|
76
76
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.client import NumPyClient, ClientApp
|
|
4
4
|
from flwr.common import Context
|
|
@@ -8,12 +8,14 @@ from $import_name.task import load_data, load_model
|
|
|
8
8
|
|
|
9
9
|
# Define Flower Client and client_fn
|
|
10
10
|
class FlowerClient(NumPyClient):
|
|
11
|
-
def __init__(
|
|
11
|
+
def __init__(
|
|
12
|
+
self, model, data, epochs, batch_size, verbose
|
|
13
|
+
):
|
|
12
14
|
self.model = model
|
|
13
|
-
self.x_train =
|
|
14
|
-
self.
|
|
15
|
-
self.
|
|
16
|
-
self.
|
|
15
|
+
self.x_train, self.y_train, self.x_test, self.y_test = data
|
|
16
|
+
self.epochs = epochs
|
|
17
|
+
self.batch_size = batch_size
|
|
18
|
+
self.verbose = verbose
|
|
17
19
|
|
|
18
20
|
def get_parameters(self, config):
|
|
19
21
|
return self.model.get_weights()
|
|
@@ -23,9 +25,9 @@ class FlowerClient(NumPyClient):
|
|
|
23
25
|
self.model.fit(
|
|
24
26
|
self.x_train,
|
|
25
27
|
self.y_train,
|
|
26
|
-
epochs=
|
|
27
|
-
batch_size=
|
|
28
|
-
verbose=
|
|
28
|
+
epochs=self.epochs,
|
|
29
|
+
batch_size=self.batch_size,
|
|
30
|
+
verbose=self.verbose,
|
|
29
31
|
)
|
|
30
32
|
return self.model.get_weights(), len(self.x_train), {}
|
|
31
33
|
|
|
@@ -39,12 +41,17 @@ def client_fn(context: Context):
|
|
|
39
41
|
# Load model and data
|
|
40
42
|
net = load_model()
|
|
41
43
|
|
|
42
|
-
partition_id =
|
|
43
|
-
num_partitions =
|
|
44
|
-
|
|
44
|
+
partition_id = context.node_config["partition-id"]
|
|
45
|
+
num_partitions = context.node_config["num-partitions"]
|
|
46
|
+
data = load_data(partition_id, num_partitions)
|
|
47
|
+
epochs = context.run_config["local-epochs"]
|
|
48
|
+
batch_size = context.run_config["batch-size"]
|
|
49
|
+
verbose = context.run_config.get("verbose")
|
|
45
50
|
|
|
46
51
|
# Return Client instance
|
|
47
|
-
return FlowerClient(
|
|
52
|
+
return FlowerClient(
|
|
53
|
+
net, data, epochs, batch_size, verbose
|
|
54
|
+
).to_client()
|
|
48
55
|
|
|
49
56
|
|
|
50
57
|
# Flower ClientApp
|
|
@@ -9,8 +9,8 @@ from hydra import compose, initialize
|
|
|
9
9
|
from hydra.utils import instantiate
|
|
10
10
|
|
|
11
11
|
from flwr.client import ClientApp
|
|
12
|
-
from flwr.common import ndarrays_to_parameters
|
|
13
|
-
from flwr.server import ServerApp, ServerConfig
|
|
12
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
13
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
14
14
|
|
|
15
15
|
from $import_name.client_app import gen_client_fn, get_parameters
|
|
16
16
|
from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
|
|
@@ -67,20 +67,23 @@ init_model = get_model(cfg.model)
|
|
|
67
67
|
init_model_parameters = get_parameters(init_model)
|
|
68
68
|
init_model_parameters = ndarrays_to_parameters(init_model_parameters)
|
|
69
69
|
|
|
70
|
-
|
|
71
|
-
#
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
)
|
|
70
|
+
def server_fn(context: Context):
|
|
71
|
+
# Instantiate strategy according to config. Here we pass other arguments
|
|
72
|
+
# that are only defined at runtime.
|
|
73
|
+
strategy = instantiate(
|
|
74
|
+
cfg.strategy,
|
|
75
|
+
on_fit_config_fn=get_on_fit_config(),
|
|
76
|
+
fit_metrics_aggregation_fn=fit_weighted_average,
|
|
77
|
+
initial_parameters=init_model_parameters,
|
|
78
|
+
evaluate_fn=get_evaluate_fn(
|
|
79
|
+
cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
|
|
80
|
+
),
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
config = ServerConfig(num_rounds=cfg_static.num_rounds)
|
|
84
|
+
|
|
85
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
86
|
+
|
|
81
87
|
|
|
82
88
|
# ServerApp for Flower Next
|
|
83
|
-
server = ServerApp(
|
|
84
|
-
config=ServerConfig(num_rounds=cfg_static.num_rounds),
|
|
85
|
-
strategy=strategy,
|
|
86
|
-
)
|
|
89
|
+
server = ServerApp(server_fn=server_fn)
|
|
@@ -10,6 +10,7 @@ from transformers import TrainingArguments
|
|
|
10
10
|
from trl import SFTTrainer
|
|
11
11
|
|
|
12
12
|
from flwr.client import NumPyClient
|
|
13
|
+
from flwr.common import Context
|
|
13
14
|
from flwr.common.typing import NDArrays, Scalar
|
|
14
15
|
from $import_name.dataset import reformat
|
|
15
16
|
from $import_name.models import cosine_annealing, get_model
|
|
@@ -102,13 +103,14 @@ def gen_client_fn(
|
|
|
102
103
|
model_cfg: DictConfig,
|
|
103
104
|
train_cfg: DictConfig,
|
|
104
105
|
save_path: str,
|
|
105
|
-
) -> Callable[[
|
|
106
|
+
) -> Callable[[Context], FlowerClient]: # pylint: disable=too-many-arguments
|
|
106
107
|
"""Generate the client function that creates the Flower Clients."""
|
|
107
108
|
|
|
108
|
-
def client_fn(
|
|
109
|
+
def client_fn(context: Context) -> FlowerClient:
|
|
109
110
|
"""Create a Flower client representing a single organization."""
|
|
110
111
|
# Let's get the partition corresponding to the i-th client
|
|
111
|
-
|
|
112
|
+
partition_id = context.node_config["partition-id"]
|
|
113
|
+
client_trainset = fds.load_partition(partition_id, "train")
|
|
112
114
|
client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")
|
|
113
115
|
|
|
114
116
|
return FlowerClient(
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.common import Context
|
|
4
4
|
from flwr.server.strategy import FedAvg
|
|
@@ -7,7 +7,7 @@ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
|
7
7
|
|
|
8
8
|
def server_fn(context: Context):
|
|
9
9
|
# Read from config
|
|
10
|
-
num_rounds =
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
11
|
|
|
12
12
|
# Define strategy
|
|
13
13
|
strategy = FedAvg(
|
|
@@ -18,5 +18,6 @@ def server_fn(context: Context):
|
|
|
18
18
|
|
|
19
19
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
20
20
|
|
|
21
|
+
|
|
21
22
|
# Create ServerApp
|
|
22
23
|
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.common import Context
|
|
4
4
|
from flwr.server.strategy import FedAvg
|
|
@@ -7,7 +7,7 @@ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
|
7
7
|
|
|
8
8
|
def server_fn(context: Context):
|
|
9
9
|
# Read from config
|
|
10
|
-
num_rounds =
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
11
|
|
|
12
12
|
# Define strategy
|
|
13
13
|
strategy = FedAvg()
|
|
@@ -15,5 +15,6 @@ def server_fn(context: Context):
|
|
|
15
15
|
|
|
16
16
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
17
17
|
|
|
18
|
+
|
|
18
19
|
# Create ServerApp
|
|
19
20
|
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.common import Context
|
|
4
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
@@ -7,7 +7,7 @@ from flwr.server.strategy import FedAvg
|
|
|
7
7
|
|
|
8
8
|
def server_fn(context: Context):
|
|
9
9
|
# Read from config
|
|
10
|
-
num_rounds =
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
11
|
|
|
12
12
|
# Define strategy
|
|
13
13
|
strategy = FedAvg()
|
|
@@ -15,5 +15,6 @@ def server_fn(context: Context):
|
|
|
15
15
|
|
|
16
16
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
17
17
|
|
|
18
|
+
|
|
18
19
|
# Create ServerApp
|
|
19
20
|
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.common import Context
|
|
4
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
@@ -7,7 +7,7 @@ from flwr.server.strategy import FedAvg
|
|
|
7
7
|
|
|
8
8
|
def server_fn(context: Context):
|
|
9
9
|
# Read from config
|
|
10
|
-
num_rounds =
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
11
|
|
|
12
12
|
# Define strategy
|
|
13
13
|
strategy = FedAvg()
|
|
@@ -15,5 +15,6 @@ def server_fn(context: Context):
|
|
|
15
15
|
|
|
16
16
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
17
17
|
|
|
18
|
+
|
|
18
19
|
# Create ServerApp
|
|
19
20
|
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.common import Context, ndarrays_to_parameters
|
|
4
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
@@ -7,17 +7,18 @@ from flwr.server.strategy import FedAvg
|
|
|
7
7
|
from $import_name.task import Net, get_weights
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
# Initialize model parameters
|
|
11
|
-
ndarrays = get_weights(Net())
|
|
12
|
-
parameters = ndarrays_to_parameters(ndarrays)
|
|
13
|
-
|
|
14
10
|
def server_fn(context: Context):
|
|
15
11
|
# Read from config
|
|
16
|
-
num_rounds =
|
|
12
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
13
|
+
fraction_fit = context.run_config["fraction-fit"]
|
|
14
|
+
|
|
15
|
+
# Initialize model parameters
|
|
16
|
+
ndarrays = get_weights(Net())
|
|
17
|
+
parameters = ndarrays_to_parameters(ndarrays)
|
|
17
18
|
|
|
18
19
|
# Define strategy
|
|
19
20
|
strategy = FedAvg(
|
|
20
|
-
fraction_fit=
|
|
21
|
+
fraction_fit=fraction_fit,
|
|
21
22
|
fraction_evaluate=1.0,
|
|
22
23
|
min_available_clients=2,
|
|
23
24
|
initial_parameters=parameters,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.common import Context
|
|
4
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
@@ -7,7 +7,7 @@ from flwr.server.strategy import FedAvg
|
|
|
7
7
|
|
|
8
8
|
def server_fn(context: Context):
|
|
9
9
|
# Read from config
|
|
10
|
-
num_rounds =
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
11
|
|
|
12
12
|
# Define strategy
|
|
13
13
|
strategy = FedAvg(
|
|
@@ -19,5 +19,6 @@ def server_fn(context: Context):
|
|
|
19
19
|
|
|
20
20
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
21
21
|
|
|
22
|
+
|
|
22
23
|
# Create ServerApp
|
|
23
24
|
app = ServerApp(server_fn=server_fn)
|