flwr-nightly 1.10.0.dev20240722__py3-none-any.whl → 1.10.0.dev20240723__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/config_utils.py +13 -15
- flwr/cli/new/new.py +1 -1
- flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +7 -5
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +28 -10
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +7 -5
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +2 -2
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +17 -7
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +20 -17
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
- flwr/cli/new/templates/app/code/{server.hf.py.tpl → server.huggingface.py.tpl} +2 -1
- flwr/cli/new/templates/app/code/server.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +2 -1
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +2 -1
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +1 -1
- flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +13 -1
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +13 -2
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/{pyproject.hf.toml.tpl → pyproject.huggingface.toml.tpl} +2 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +6 -6
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +4 -4
- flwr/cli/run/run.py +31 -27
- flwr/client/supernode/app.py +12 -43
- flwr/common/config.py +6 -1
- flwr/common/object_ref.py +84 -21
- flwr/proto/exec_pb2.py +16 -12
- flwr/proto/exec_pb2.pyi +20 -1
- flwr/server/run_serverapp.py +0 -3
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +4 -4
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/run_simulation.py +32 -4
- flwr/superexec/app.py +4 -5
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_servicer.py +3 -1
- flwr/superexec/executor.py +3 -0
- flwr/superexec/simulation.py +28 -6
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/METADATA +1 -1
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/RECORD +49 -49
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240722.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/entry_points.txt +0 -0
flwr/cli/config_utils.py
CHANGED
|
@@ -77,6 +77,9 @@ def load_and_validate(
|
|
|
77
77
|
A tuple with the optional config in case it exists and is valid
|
|
78
78
|
and associated errors and warnings.
|
|
79
79
|
"""
|
|
80
|
+
if path is None:
|
|
81
|
+
path = Path.cwd() / "pyproject.toml"
|
|
82
|
+
|
|
80
83
|
config = load(path)
|
|
81
84
|
|
|
82
85
|
if config is None:
|
|
@@ -86,7 +89,7 @@ def load_and_validate(
|
|
|
86
89
|
]
|
|
87
90
|
return (None, errors, [])
|
|
88
91
|
|
|
89
|
-
is_valid, errors, warnings = validate(config, check_module)
|
|
92
|
+
is_valid, errors, warnings = validate(config, check_module, path.parent)
|
|
90
93
|
|
|
91
94
|
if not is_valid:
|
|
92
95
|
return (None, errors, warnings)
|
|
@@ -94,14 +97,8 @@ def load_and_validate(
|
|
|
94
97
|
return (config, errors, warnings)
|
|
95
98
|
|
|
96
99
|
|
|
97
|
-
def load(
|
|
100
|
+
def load(toml_path: Path) -> Optional[Dict[str, Any]]:
|
|
98
101
|
"""Load pyproject.toml and return as dict."""
|
|
99
|
-
if path is None:
|
|
100
|
-
cur_dir = Path.cwd()
|
|
101
|
-
toml_path = cur_dir / "pyproject.toml"
|
|
102
|
-
else:
|
|
103
|
-
toml_path = path
|
|
104
|
-
|
|
105
102
|
if not toml_path.is_file():
|
|
106
103
|
return None
|
|
107
104
|
|
|
@@ -167,7 +164,9 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
|
|
|
167
164
|
|
|
168
165
|
|
|
169
166
|
def validate(
|
|
170
|
-
config: Dict[str, Any],
|
|
167
|
+
config: Dict[str, Any],
|
|
168
|
+
check_module: bool = True,
|
|
169
|
+
project_dir: Optional[Union[str, Path]] = None,
|
|
171
170
|
) -> Tuple[bool, List[str], List[str]]:
|
|
172
171
|
"""Validate pyproject.toml."""
|
|
173
172
|
is_valid, errors, warnings = validate_fields(config)
|
|
@@ -176,16 +175,15 @@ def validate(
|
|
|
176
175
|
return False, errors, warnings
|
|
177
176
|
|
|
178
177
|
# Validate serverapp
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
178
|
+
serverapp_ref = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
|
179
|
+
is_valid, reason = object_ref.validate(serverapp_ref, check_module, project_dir)
|
|
180
|
+
|
|
182
181
|
if not is_valid and isinstance(reason, str):
|
|
183
182
|
return False, [reason], []
|
|
184
183
|
|
|
185
184
|
# Validate clientapp
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
)
|
|
185
|
+
clientapp_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
|
186
|
+
is_valid, reason = object_ref.validate(clientapp_ref, check_module, project_dir)
|
|
189
187
|
|
|
190
188
|
if not is_valid and isinstance(reason, str):
|
|
191
189
|
return False, [reason], []
|
flwr/cli/new/new.py
CHANGED
|
@@ -17,10 +17,11 @@ from $import_name.task import (
|
|
|
17
17
|
|
|
18
18
|
# Flower client
|
|
19
19
|
class FlowerClient(NumPyClient):
|
|
20
|
-
def __init__(self, net, trainloader, testloader):
|
|
20
|
+
def __init__(self, net, trainloader, testloader, local_epochs):
|
|
21
21
|
self.net = net
|
|
22
22
|
self.trainloader = trainloader
|
|
23
23
|
self.testloader = testloader
|
|
24
|
+
self.local_epochs = local_epochs
|
|
24
25
|
|
|
25
26
|
def get_parameters(self, config):
|
|
26
27
|
return get_weights(self.net)
|
|
@@ -33,7 +34,7 @@ class FlowerClient(NumPyClient):
|
|
|
33
34
|
train(
|
|
34
35
|
self.net,
|
|
35
36
|
self.trainloader,
|
|
36
|
-
epochs=
|
|
37
|
+
epochs=self.local_epochs,
|
|
37
38
|
)
|
|
38
39
|
return self.get_parameters(config={}), len(self.trainloader), {}
|
|
39
40
|
|
|
@@ -49,12 +50,13 @@ def client_fn(context: Context):
|
|
|
49
50
|
CHECKPOINT, num_labels=2
|
|
50
51
|
).to(DEVICE)
|
|
51
52
|
|
|
52
|
-
partition_id =
|
|
53
|
-
num_partitions =
|
|
53
|
+
partition_id = context.node_config["partition-id"]
|
|
54
|
+
num_partitions = context.node_config["num-partitions"]
|
|
54
55
|
trainloader, valloader = load_data(partition_id, num_partitions)
|
|
56
|
+
local_epochs = context.run_config["local-epochs"]
|
|
55
57
|
|
|
56
58
|
# Return Client instance
|
|
57
|
-
return FlowerClient(net, trainloader, valloader).to_client()
|
|
59
|
+
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
|
|
58
60
|
|
|
59
61
|
|
|
60
62
|
# Flower ClientApp
|
|
@@ -19,13 +19,22 @@ from $import_name.task import (
|
|
|
19
19
|
|
|
20
20
|
# Define Flower Client and client_fn
|
|
21
21
|
class FlowerClient(NumPyClient):
|
|
22
|
-
def __init__(
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
data,
|
|
25
|
+
num_layers,
|
|
26
|
+
hidden_dim,
|
|
27
|
+
num_classes,
|
|
28
|
+
batch_size,
|
|
29
|
+
learning_rate,
|
|
30
|
+
num_epochs,
|
|
31
|
+
):
|
|
32
|
+
self.num_layers = num_layers
|
|
33
|
+
self.hidden_dim = hidden_dim
|
|
34
|
+
self.num_classes = num_classes
|
|
35
|
+
self.batch_size = batch_size
|
|
36
|
+
self.learning_rate = learning_rate
|
|
37
|
+
self.num_epochs = num_epochs
|
|
29
38
|
|
|
30
39
|
self.train_images, self.train_labels, self.test_images, self.test_labels = data
|
|
31
40
|
self.model = MLP(
|
|
@@ -61,12 +70,21 @@ class FlowerClient(NumPyClient):
|
|
|
61
70
|
|
|
62
71
|
|
|
63
72
|
def client_fn(context: Context):
|
|
64
|
-
partition_id =
|
|
65
|
-
num_partitions =
|
|
73
|
+
partition_id = context.node_config["partition-id"]
|
|
74
|
+
num_partitions = context.node_config["num-partitions"]
|
|
66
75
|
data = load_data(partition_id, num_partitions)
|
|
67
76
|
|
|
77
|
+
num_layers = context.run_config["num-layers"]
|
|
78
|
+
hidden_dim = context.run_config["hidden-dim"]
|
|
79
|
+
num_classes = 10
|
|
80
|
+
batch_size = context.run_config["batch-size"]
|
|
81
|
+
learning_rate = context.run_config["lr"]
|
|
82
|
+
num_epochs = context.run_config["local-epochs"]
|
|
83
|
+
|
|
68
84
|
# Return Client instance
|
|
69
|
-
return FlowerClient(
|
|
85
|
+
return FlowerClient(
|
|
86
|
+
data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs
|
|
87
|
+
).to_client()
|
|
70
88
|
|
|
71
89
|
|
|
72
90
|
# Flower ClientApp
|
|
@@ -16,10 +16,11 @@ from $import_name.task import (
|
|
|
16
16
|
|
|
17
17
|
# Define Flower Client and client_fn
|
|
18
18
|
class FlowerClient(NumPyClient):
|
|
19
|
-
def __init__(self, net, trainloader, valloader):
|
|
19
|
+
def __init__(self, net, trainloader, valloader, local_epochs):
|
|
20
20
|
self.net = net
|
|
21
21
|
self.trainloader = trainloader
|
|
22
22
|
self.valloader = valloader
|
|
23
|
+
self.local_epochs = local_epochs
|
|
23
24
|
|
|
24
25
|
def fit(self, parameters, config):
|
|
25
26
|
set_weights(self.net, parameters)
|
|
@@ -27,7 +28,7 @@ class FlowerClient(NumPyClient):
|
|
|
27
28
|
self.net,
|
|
28
29
|
self.trainloader,
|
|
29
30
|
self.valloader,
|
|
30
|
-
|
|
31
|
+
self.local_epochs,
|
|
31
32
|
DEVICE,
|
|
32
33
|
)
|
|
33
34
|
return get_weights(self.net), len(self.trainloader.dataset), results
|
|
@@ -41,12 +42,13 @@ class FlowerClient(NumPyClient):
|
|
|
41
42
|
def client_fn(context: Context):
|
|
42
43
|
# Load model and data
|
|
43
44
|
net = Net().to(DEVICE)
|
|
44
|
-
partition_id =
|
|
45
|
-
num_partitions =
|
|
45
|
+
partition_id = context.node_config["partition-id"]
|
|
46
|
+
num_partitions = context.node_config["num-partitions"]
|
|
46
47
|
trainloader, valloader = load_data(partition_id, num_partitions)
|
|
48
|
+
local_epochs = context.run_config["local-epochs"]
|
|
47
49
|
|
|
48
50
|
# Return Client instance
|
|
49
|
-
return FlowerClient(net, trainloader, valloader).to_client()
|
|
51
|
+
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
|
|
50
52
|
|
|
51
53
|
|
|
52
54
|
# Flower ClientApp
|
|
@@ -69,8 +69,8 @@ class FlowerClient(NumPyClient):
|
|
|
69
69
|
|
|
70
70
|
|
|
71
71
|
def client_fn(context: Context):
|
|
72
|
-
partition_id =
|
|
73
|
-
num_partitions =
|
|
72
|
+
partition_id = context.node_config["partition-id"]
|
|
73
|
+
num_partitions = context.node_config["num-partitions"]
|
|
74
74
|
fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
|
|
75
75
|
dataset = fds.load_partition(partition_id, "train").with_format("numpy")
|
|
76
76
|
|
|
@@ -8,12 +8,17 @@ from $import_name.task import load_data, load_model
|
|
|
8
8
|
|
|
9
9
|
# Define Flower Client and client_fn
|
|
10
10
|
class FlowerClient(NumPyClient):
|
|
11
|
-
def __init__(
|
|
11
|
+
def __init__(
|
|
12
|
+
self, model, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
|
|
13
|
+
):
|
|
12
14
|
self.model = model
|
|
13
15
|
self.x_train = x_train
|
|
14
16
|
self.y_train = y_train
|
|
15
17
|
self.x_test = x_test
|
|
16
18
|
self.y_test = y_test
|
|
19
|
+
self.epochs = epochs
|
|
20
|
+
self.batch_size = batch_size
|
|
21
|
+
self.verbose = verbose
|
|
17
22
|
|
|
18
23
|
def get_parameters(self, config):
|
|
19
24
|
return self.model.get_weights()
|
|
@@ -23,9 +28,9 @@ class FlowerClient(NumPyClient):
|
|
|
23
28
|
self.model.fit(
|
|
24
29
|
self.x_train,
|
|
25
30
|
self.y_train,
|
|
26
|
-
epochs=
|
|
27
|
-
batch_size=
|
|
28
|
-
verbose=
|
|
31
|
+
epochs=self.epochs,
|
|
32
|
+
batch_size=self.batch_size,
|
|
33
|
+
verbose=self.verbose,
|
|
29
34
|
)
|
|
30
35
|
return self.model.get_weights(), len(self.x_train), {}
|
|
31
36
|
|
|
@@ -39,12 +44,17 @@ def client_fn(context: Context):
|
|
|
39
44
|
# Load model and data
|
|
40
45
|
net = load_model()
|
|
41
46
|
|
|
42
|
-
partition_id =
|
|
43
|
-
num_partitions =
|
|
47
|
+
partition_id = context.node_config["partition-id"]
|
|
48
|
+
num_partitions = context.node_config["num-partitions"]
|
|
44
49
|
x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)
|
|
50
|
+
epochs = context.run_config["local-epochs"]
|
|
51
|
+
batch_size = context.run_config["batch-size"]
|
|
52
|
+
verbose = context.run_config.get("verbose")
|
|
45
53
|
|
|
46
54
|
# Return Client instance
|
|
47
|
-
return FlowerClient(
|
|
55
|
+
return FlowerClient(
|
|
56
|
+
net, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
|
|
57
|
+
).to_client()
|
|
48
58
|
|
|
49
59
|
|
|
50
60
|
# Flower ClientApp
|
|
@@ -9,8 +9,8 @@ from hydra import compose, initialize
|
|
|
9
9
|
from hydra.utils import instantiate
|
|
10
10
|
|
|
11
11
|
from flwr.client import ClientApp
|
|
12
|
-
from flwr.common import ndarrays_to_parameters
|
|
13
|
-
from flwr.server import ServerApp, ServerConfig
|
|
12
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
13
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
14
14
|
|
|
15
15
|
from $import_name.client_app import gen_client_fn, get_parameters
|
|
16
16
|
from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
|
|
@@ -67,20 +67,23 @@ init_model = get_model(cfg.model)
|
|
|
67
67
|
init_model_parameters = get_parameters(init_model)
|
|
68
68
|
init_model_parameters = ndarrays_to_parameters(init_model_parameters)
|
|
69
69
|
|
|
70
|
-
|
|
71
|
-
#
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
)
|
|
70
|
+
def server_fn(context: Context):
|
|
71
|
+
# Instantiate strategy according to config. Here we pass other arguments
|
|
72
|
+
# that are only defined at runtime.
|
|
73
|
+
strategy = instantiate(
|
|
74
|
+
cfg.strategy,
|
|
75
|
+
on_fit_config_fn=get_on_fit_config(),
|
|
76
|
+
fit_metrics_aggregation_fn=fit_weighted_average,
|
|
77
|
+
initial_parameters=init_model_parameters,
|
|
78
|
+
evaluate_fn=get_evaluate_fn(
|
|
79
|
+
cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
|
|
80
|
+
),
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
config = ServerConfig(num_rounds=cfg_static.num_rounds)
|
|
84
|
+
|
|
85
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
86
|
+
|
|
81
87
|
|
|
82
88
|
# ServerApp for Flower Next
|
|
83
|
-
server = ServerApp(
|
|
84
|
-
config=ServerConfig(num_rounds=cfg_static.num_rounds),
|
|
85
|
-
strategy=strategy,
|
|
86
|
-
)
|
|
89
|
+
server = ServerApp(server_fn=server_fn)
|
|
@@ -10,6 +10,7 @@ from transformers import TrainingArguments
|
|
|
10
10
|
from trl import SFTTrainer
|
|
11
11
|
|
|
12
12
|
from flwr.client import NumPyClient
|
|
13
|
+
from flwr.common import Context
|
|
13
14
|
from flwr.common.typing import NDArrays, Scalar
|
|
14
15
|
from $import_name.dataset import reformat
|
|
15
16
|
from $import_name.models import cosine_annealing, get_model
|
|
@@ -102,13 +103,14 @@ def gen_client_fn(
|
|
|
102
103
|
model_cfg: DictConfig,
|
|
103
104
|
train_cfg: DictConfig,
|
|
104
105
|
save_path: str,
|
|
105
|
-
) -> Callable[[
|
|
106
|
+
) -> Callable[[Context], FlowerClient]: # pylint: disable=too-many-arguments
|
|
106
107
|
"""Generate the client function that creates the Flower Clients."""
|
|
107
108
|
|
|
108
|
-
def client_fn(
|
|
109
|
+
def client_fn(context: Context) -> FlowerClient:
|
|
109
110
|
"""Create a Flower client representing a single organization."""
|
|
110
111
|
# Let's get the partition corresponding to the i-th client
|
|
111
|
-
|
|
112
|
+
partition_id = context.node_config["partition-id"]
|
|
113
|
+
client_trainset = fds.load_partition(partition_id, "train")
|
|
112
114
|
client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")
|
|
113
115
|
|
|
114
116
|
return FlowerClient(
|
|
@@ -7,7 +7,7 @@ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
|
7
7
|
|
|
8
8
|
def server_fn(context: Context):
|
|
9
9
|
# Read from config
|
|
10
|
-
num_rounds =
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
11
|
|
|
12
12
|
# Define strategy
|
|
13
13
|
strategy = FedAvg(
|
|
@@ -18,5 +18,6 @@ def server_fn(context: Context):
|
|
|
18
18
|
|
|
19
19
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
20
20
|
|
|
21
|
+
|
|
21
22
|
# Create ServerApp
|
|
22
23
|
app = ServerApp(server_fn=server_fn)
|
|
@@ -7,7 +7,7 @@ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
|
7
7
|
|
|
8
8
|
def server_fn(context: Context):
|
|
9
9
|
# Read from config
|
|
10
|
-
num_rounds =
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
11
|
|
|
12
12
|
# Define strategy
|
|
13
13
|
strategy = FedAvg()
|
|
@@ -15,5 +15,6 @@ def server_fn(context: Context):
|
|
|
15
15
|
|
|
16
16
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
17
17
|
|
|
18
|
+
|
|
18
19
|
# Create ServerApp
|
|
19
20
|
app = ServerApp(server_fn=server_fn)
|
|
@@ -7,7 +7,7 @@ from flwr.server.strategy import FedAvg
|
|
|
7
7
|
|
|
8
8
|
def server_fn(context: Context):
|
|
9
9
|
# Read from config
|
|
10
|
-
num_rounds =
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
11
|
|
|
12
12
|
# Define strategy
|
|
13
13
|
strategy = FedAvg()
|
|
@@ -15,5 +15,6 @@ def server_fn(context: Context):
|
|
|
15
15
|
|
|
16
16
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
17
17
|
|
|
18
|
+
|
|
18
19
|
# Create ServerApp
|
|
19
20
|
app = ServerApp(server_fn=server_fn)
|
|
@@ -7,7 +7,7 @@ from flwr.server.strategy import FedAvg
|
|
|
7
7
|
|
|
8
8
|
def server_fn(context: Context):
|
|
9
9
|
# Read from config
|
|
10
|
-
num_rounds =
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
11
|
|
|
12
12
|
# Define strategy
|
|
13
13
|
strategy = FedAvg()
|
|
@@ -15,5 +15,6 @@ def server_fn(context: Context):
|
|
|
15
15
|
|
|
16
16
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
17
17
|
|
|
18
|
+
|
|
18
19
|
# Create ServerApp
|
|
19
20
|
app = ServerApp(server_fn=server_fn)
|
|
@@ -13,7 +13,7 @@ parameters = ndarrays_to_parameters(ndarrays)
|
|
|
13
13
|
|
|
14
14
|
def server_fn(context: Context):
|
|
15
15
|
# Read from config
|
|
16
|
-
num_rounds =
|
|
16
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
17
17
|
|
|
18
18
|
# Define strategy
|
|
19
19
|
strategy = FedAvg(
|
|
@@ -7,7 +7,7 @@ from flwr.server.strategy import FedAvg
|
|
|
7
7
|
|
|
8
8
|
def server_fn(context: Context):
|
|
9
9
|
# Read from config
|
|
10
|
-
num_rounds =
|
|
10
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
11
11
|
|
|
12
12
|
# Define strategy
|
|
13
13
|
strategy = FedAvg(
|
|
@@ -19,5 +19,6 @@ def server_fn(context: Context):
|
|
|
19
19
|
|
|
20
20
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
21
21
|
|
|
22
|
+
|
|
22
23
|
# Create ServerApp
|
|
23
24
|
app = ServerApp(server_fn=server_fn)
|
|
@@ -13,7 +13,7 @@ parameters = ndarrays_to_parameters(load_model().get_weights())
|
|
|
13
13
|
|
|
14
14
|
def server_fn(context: Context):
|
|
15
15
|
# Read from config
|
|
16
|
-
num_rounds =
|
|
16
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
17
17
|
|
|
18
18
|
# Define strategy
|
|
19
19
|
strategy = strategy = FedAvg(
|
|
@@ -10,15 +10,27 @@ from torch.utils.data import DataLoader
|
|
|
10
10
|
from transformers import AutoTokenizer, DataCollatorWithPadding
|
|
11
11
|
|
|
12
12
|
from flwr_datasets import FederatedDataset
|
|
13
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
14
|
+
|
|
13
15
|
|
|
14
16
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
15
17
|
DEVICE = torch.device("cpu")
|
|
16
18
|
CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint
|
|
17
19
|
|
|
18
20
|
|
|
21
|
+
fds = None # Cache FederatedDataset
|
|
22
|
+
|
|
23
|
+
|
|
19
24
|
def load_data(partition_id: int, num_partitions: int):
|
|
20
25
|
"""Load IMDB data (training and eval)"""
|
|
21
|
-
|
|
26
|
+
# Only initialize `FederatedDataset` once
|
|
27
|
+
global fds
|
|
28
|
+
if fds is None:
|
|
29
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
30
|
+
fds = FederatedDataset(
|
|
31
|
+
dataset="stanfordnlp/imdb",
|
|
32
|
+
partitioners={"train": partitioner},
|
|
33
|
+
)
|
|
22
34
|
partition = fds.load_partition(partition_id)
|
|
23
35
|
# Divide data: 80% train, 20% test
|
|
24
36
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
|
@@ -5,10 +5,12 @@ import mlx.nn as nn
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from datasets.utils.logging import disable_progress_bar
|
|
7
7
|
from flwr_datasets import FederatedDataset
|
|
8
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
disable_progress_bar()
|
|
11
12
|
|
|
13
|
+
|
|
12
14
|
class MLP(nn.Module):
|
|
13
15
|
"""A simple MLP."""
|
|
14
16
|
|
|
@@ -43,8 +45,19 @@ def batch_iterate(batch_size, X, y):
|
|
|
43
45
|
yield X[ids], y[ids]
|
|
44
46
|
|
|
45
47
|
|
|
48
|
+
fds = None # Cache FederatedDataset
|
|
49
|
+
|
|
50
|
+
|
|
46
51
|
def load_data(partition_id: int, num_partitions: int):
|
|
47
|
-
|
|
52
|
+
# Only initialize `FederatedDataset` once
|
|
53
|
+
global fds
|
|
54
|
+
if fds is None:
|
|
55
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
56
|
+
fds = FederatedDataset(
|
|
57
|
+
dataset="ylecun/mnist",
|
|
58
|
+
partitioners={"train": partitioner},
|
|
59
|
+
trust_remote_code=True,
|
|
60
|
+
)
|
|
48
61
|
partition = fds.load_partition(partition_id)
|
|
49
62
|
partition_splits = partition.train_test_split(test_size=0.2, seed=42)
|
|
50
63
|
|
|
@@ -6,9 +6,10 @@ import torch
|
|
|
6
6
|
import torch.nn as nn
|
|
7
7
|
import torch.nn.functional as F
|
|
8
8
|
from torch.utils.data import DataLoader
|
|
9
|
-
from torchvision.datasets import CIFAR10
|
|
10
9
|
from torchvision.transforms import Compose, Normalize, ToTensor
|
|
11
10
|
from flwr_datasets import FederatedDataset
|
|
11
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
12
|
+
|
|
12
13
|
|
|
13
14
|
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
14
15
|
|
|
@@ -34,9 +35,19 @@ class Net(nn.Module):
|
|
|
34
35
|
return self.fc3(x)
|
|
35
36
|
|
|
36
37
|
|
|
38
|
+
fds = None # Cache FederatedDataset
|
|
39
|
+
|
|
40
|
+
|
|
37
41
|
def load_data(partition_id: int, num_partitions: int):
|
|
38
42
|
"""Load partition CIFAR10 data."""
|
|
39
|
-
|
|
43
|
+
# Only initialize `FederatedDataset` once
|
|
44
|
+
global fds
|
|
45
|
+
if fds is None:
|
|
46
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
47
|
+
fds = FederatedDataset(
|
|
48
|
+
dataset="uoft-cs/cifar10",
|
|
49
|
+
partitioners={"train": partitioner},
|
|
50
|
+
)
|
|
40
51
|
partition = fds.load_partition(partition_id)
|
|
41
52
|
# Divide data on each node: 80% train, 20% test
|
|
42
53
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
|
@@ -4,11 +4,13 @@ import os
|
|
|
4
4
|
|
|
5
5
|
import tensorflow as tf
|
|
6
6
|
from flwr_datasets import FederatedDataset
|
|
7
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
# Make TensorFlow log less verbose
|
|
10
11
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
11
12
|
|
|
13
|
+
|
|
12
14
|
def load_model():
|
|
13
15
|
# Load model and data (MobileNetV2, CIFAR-10)
|
|
14
16
|
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
|
|
@@ -16,9 +18,19 @@ def load_model():
|
|
|
16
18
|
return model
|
|
17
19
|
|
|
18
20
|
|
|
21
|
+
fds = None # Cache FederatedDataset
|
|
22
|
+
|
|
23
|
+
|
|
19
24
|
def load_data(partition_id, num_partitions):
|
|
20
25
|
# Download and partition dataset
|
|
21
|
-
|
|
26
|
+
# Only initialize `FederatedDataset` once
|
|
27
|
+
global fds
|
|
28
|
+
if fds is None:
|
|
29
|
+
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
30
|
+
fds = FederatedDataset(
|
|
31
|
+
dataset="uoft-cs/cifar10",
|
|
32
|
+
partitioners={"train": partitioner},
|
|
33
|
+
)
|
|
22
34
|
partition = fds.load_partition(partition_id, "train")
|
|
23
35
|
partition.set_format("numpy")
|
|
24
36
|
|
|
@@ -28,8 +28,8 @@ serverapp = "$import_name.server_app:app"
|
|
|
28
28
|
clientapp = "$import_name.client_app:app"
|
|
29
29
|
|
|
30
30
|
[tool.flwr.app.config]
|
|
31
|
-
num-server-rounds =
|
|
32
|
-
local-epochs =
|
|
31
|
+
num-server-rounds = 3
|
|
32
|
+
local-epochs = 1
|
|
33
33
|
|
|
34
34
|
[tool.flwr.federations]
|
|
35
35
|
default = "localhost"
|
|
@@ -25,12 +25,12 @@ serverapp = "$import_name.server_app:app"
|
|
|
25
25
|
clientapp = "$import_name.client_app:app"
|
|
26
26
|
|
|
27
27
|
[tool.flwr.app.config]
|
|
28
|
-
num-server-rounds =
|
|
29
|
-
local-epochs =
|
|
30
|
-
num-layers =
|
|
31
|
-
hidden-dim =
|
|
32
|
-
batch-size =
|
|
33
|
-
lr =
|
|
28
|
+
num-server-rounds = 3
|
|
29
|
+
local-epochs = 1
|
|
30
|
+
num-layers = 2
|
|
31
|
+
hidden-dim = 32
|
|
32
|
+
batch-size = 256
|
|
33
|
+
lr = 0.1
|
|
34
34
|
|
|
35
35
|
[tool.flwr.federations]
|
|
36
36
|
default = "localhost"
|
|
@@ -25,8 +25,8 @@ serverapp = "$import_name.server_app:app"
|
|
|
25
25
|
clientapp = "$import_name.client_app:app"
|
|
26
26
|
|
|
27
27
|
[tool.flwr.app.config]
|
|
28
|
-
num-server-rounds =
|
|
29
|
-
local-epochs =
|
|
28
|
+
num-server-rounds = 3
|
|
29
|
+
local-epochs = 1
|
|
30
30
|
|
|
31
31
|
[tool.flwr.federations]
|
|
32
32
|
default = "localhost"
|