flwr 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +4 -1
- flwr/app/__init__.py +28 -0
- flwr/app/exception.py +31 -0
- flwr/cli/app.py +2 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +3 -3
- flwr/cli/constant.py +25 -8
- flwr/cli/log.py +9 -9
- flwr/cli/login/login.py +3 -3
- flwr/cli/ls.py +5 -5
- flwr/cli/new/new.py +15 -2
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +111 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/run/run.py +9 -13
- flwr/cli/stop.py +7 -4
- flwr/cli/utils.py +36 -8
- flwr/client/grpc_rere_client/connection.py +1 -12
- flwr/client/rest_client/connection.py +3 -0
- flwr/clientapp/__init__.py +10 -0
- flwr/clientapp/mod/__init__.py +29 -0
- flwr/clientapp/mod/centraldp_mods.py +248 -0
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/common/args.py +20 -6
- flwr/common/auth_plugin/__init__.py +4 -4
- flwr/common/auth_plugin/auth_plugin.py +7 -7
- flwr/common/constant.py +26 -4
- flwr/common/event_log_plugin/event_log_plugin.py +1 -1
- flwr/common/exit/__init__.py +4 -0
- flwr/common/exit/exit.py +8 -1
- flwr/common/exit/exit_code.py +30 -7
- flwr/common/exit/exit_handler.py +62 -0
- flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
- flwr/common/grpc.py +0 -11
- flwr/common/inflatable_utils.py +1 -1
- flwr/common/logger.py +1 -1
- flwr/common/record/typeddict.py +12 -0
- flwr/common/retry_invoker.py +30 -11
- flwr/common/telemetry.py +4 -0
- flwr/compat/server/app.py +2 -2
- flwr/proto/appio_pb2.py +25 -17
- flwr/proto/appio_pb2.pyi +46 -2
- flwr/proto/clientappio_pb2.py +3 -11
- flwr/proto/clientappio_pb2.pyi +0 -47
- flwr/proto/clientappio_pb2_grpc.py +19 -20
- flwr/proto/clientappio_pb2_grpc.pyi +10 -11
- flwr/proto/control_pb2.py +66 -0
- flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +24 -0
- flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +88 -54
- flwr/proto/control_pb2_grpc.pyi +106 -0
- flwr/proto/serverappio_pb2.py +2 -2
- flwr/proto/serverappio_pb2_grpc.py +68 -0
- flwr/proto/serverappio_pb2_grpc.pyi +26 -0
- flwr/proto/simulationio_pb2.py +4 -11
- flwr/proto/simulationio_pb2.pyi +0 -58
- flwr/proto/simulationio_pb2_grpc.py +129 -27
- flwr/proto/simulationio_pb2_grpc.pyi +52 -13
- flwr/server/app.py +142 -152
- flwr/server/grid/grpc_grid.py +3 -0
- flwr/server/grid/inmemory_grid.py +1 -0
- flwr/server/serverapp/app.py +157 -146
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +61 -6
- flwr/server/superlink/simulation/simulationio_servicer.py +97 -21
- flwr/serverapp/__init__.py +12 -0
- flwr/serverapp/exception.py +38 -0
- flwr/serverapp/strategy/__init__.py +64 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +374 -0
- flwr/serverapp/strategy/fedadagrad.py +159 -0
- flwr/serverapp/strategy/fedadam.py +178 -0
- flwr/serverapp/strategy/fedavg.py +320 -0
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedopt.py +218 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +170 -0
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/result.py +105 -0
- flwr/serverapp/strategy/strategy.py +285 -0
- flwr/serverapp/strategy/strategy_utils.py +299 -0
- flwr/simulation/app.py +161 -164
- flwr/simulation/run_simulation.py +25 -30
- flwr/supercore/app_utils.py +58 -0
- flwr/{supernode/scheduler → supercore/cli}/__init__.py +3 -3
- flwr/supercore/cli/flower_superexec.py +166 -0
- flwr/supercore/constant.py +19 -0
- flwr/supercore/{scheduler → corestate}/__init__.py +3 -3
- flwr/supercore/corestate/corestate.py +81 -0
- flwr/supercore/grpc_health/__init__.py +3 -0
- flwr/supercore/grpc_health/health_server.py +53 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +2 -2
- flwr/{superexec → supercore/superexec}/__init__.py +1 -1
- flwr/supercore/superexec/plugin/__init__.py +28 -0
- flwr/{supernode/scheduler/simple_clientapp_scheduler_plugin.py → supercore/superexec/plugin/base_exec_plugin.py} +10 -6
- flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
- flwr/supercore/{scheduler/plugin.py → superexec/plugin/exec_plugin.py} +15 -5
- flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
- flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
- flwr/supercore/superexec/run_superexec.py +199 -0
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/superlink/servicer/__init__.py +15 -0
- flwr/superlink/servicer/control/__init__.py +22 -0
- flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +7 -7
- flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +27 -29
- flwr/{superexec/exec_license_interceptor.py → superlink/servicer/control/control_license_interceptor.py} +6 -6
- flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +127 -31
- flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +10 -10
- flwr/supernode/cli/flower_supernode.py +3 -0
- flwr/supernode/cli/flwr_clientapp.py +18 -21
- flwr/supernode/nodestate/in_memory_nodestate.py +2 -2
- flwr/supernode/nodestate/nodestate.py +3 -59
- flwr/supernode/runtime/run_clientapp.py +39 -102
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -17
- flwr/supernode/start_client_internal.py +35 -76
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/METADATA +9 -18
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/RECORD +176 -128
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/entry_points.txt +1 -0
- flwr/proto/exec_pb2.py +0 -62
- flwr/proto/exec_pb2_grpc.pyi +0 -93
- flwr/superexec/app.py +0 -45
- flwr/superexec/deployment.py +0 -191
- flwr/superexec/executor.py +0 -100
- flwr/superexec/simulation.py +0 -129
- {flwr-1.20.0.dist-info → flwr-1.22.0.dist-info}/WHEEL +0 -0
|
@@ -1,53 +1,48 @@
|
|
|
1
1
|
"""$project_name: A Flower / FlowerTune app."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from collections.abc import Iterable
|
|
4
4
|
from logging import INFO, WARN
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Optional
|
|
6
6
|
|
|
7
|
-
from flwr.
|
|
8
|
-
from flwr.
|
|
9
|
-
from flwr.
|
|
10
|
-
from flwr.
|
|
7
|
+
from flwr.app import ArrayRecord, ConfigRecord, Message, MetricRecord
|
|
8
|
+
from flwr.common import log
|
|
9
|
+
from flwr.serverapp import Grid
|
|
10
|
+
from flwr.serverapp.strategy import FedAvg
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class FlowerTuneLlm(FedAvg):
|
|
14
14
|
"""Customised FedAvg strategy implementation.
|
|
15
|
-
|
|
15
|
+
|
|
16
16
|
This class behaves just like FedAvg but also tracks the communication
|
|
17
|
-
costs associated with `
|
|
17
|
+
costs associated with `train` over FL rounds.
|
|
18
18
|
"""
|
|
19
19
|
def __init__(self, **kwargs):
|
|
20
20
|
super().__init__(**kwargs)
|
|
21
21
|
self.comm_tracker = CommunicationTracker()
|
|
22
22
|
|
|
23
|
-
def
|
|
24
|
-
|
|
25
|
-
):
|
|
23
|
+
def configure_train(
|
|
24
|
+
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
|
|
25
|
+
) -> Iterable[Message]:
|
|
26
26
|
"""Configure the next round of training."""
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
#
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
# Test communication costs
|
|
43
|
-
fit_res_list = [fit_res for _, fit_res in results]
|
|
44
|
-
self.comm_tracker.track(fit_res_list)
|
|
45
|
-
|
|
46
|
-
parameters_aggregated, metrics_aggregated = super().aggregate_fit(
|
|
47
|
-
server_round, results, failures
|
|
48
|
-
)
|
|
27
|
+
messages = super().configure_train(server_round, arrays, config, grid)
|
|
28
|
+
|
|
29
|
+
# Track communication costs
|
|
30
|
+
self.comm_tracker.track(messages)
|
|
31
|
+
|
|
32
|
+
return messages
|
|
33
|
+
|
|
34
|
+
def aggregate_train(
|
|
35
|
+
self,
|
|
36
|
+
server_round: int,
|
|
37
|
+
replies: Iterable[Message],
|
|
38
|
+
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
|
|
39
|
+
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
40
|
+
# Track communication costs
|
|
41
|
+
self.comm_tracker.track(replies)
|
|
49
42
|
|
|
50
|
-
|
|
43
|
+
arrays, metrics = super().aggregate_train(server_round, replies)
|
|
44
|
+
|
|
45
|
+
return arrays, metrics
|
|
51
46
|
|
|
52
47
|
|
|
53
48
|
class CommunicationTracker:
|
|
@@ -55,16 +50,16 @@ class CommunicationTracker:
|
|
|
55
50
|
def __init__(self):
|
|
56
51
|
self.curr_comm_cost = 0.0
|
|
57
52
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
53
|
+
def track(self, messages: Iterable[Message]):
|
|
54
|
+
comm_cost = (
|
|
55
|
+
sum(
|
|
56
|
+
record.count_bytes()
|
|
57
|
+
for msg in messages
|
|
58
|
+
if msg.has_content()
|
|
59
|
+
for record in msg.content.array_records.values()
|
|
60
|
+
)
|
|
61
|
+
/ 1024**2
|
|
62
|
+
)
|
|
68
63
|
|
|
69
64
|
self.curr_comm_cost += comm_cost
|
|
70
65
|
log(
|
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""$project_name: A Flower Baseline."""
|
|
2
2
|
|
|
3
|
-
from collections import OrderedDict
|
|
4
|
-
|
|
5
3
|
import torch
|
|
6
4
|
import torch.nn.functional as F
|
|
7
5
|
from torch import nn
|
|
@@ -66,15 +64,3 @@ def test(net, testloader, device):
|
|
|
66
64
|
accuracy = correct / len(testloader.dataset)
|
|
67
65
|
loss = loss / len(testloader)
|
|
68
66
|
return loss, accuracy
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def get_weights(net):
|
|
72
|
-
"""Extract model parameters as numpy arrays from state_dict."""
|
|
73
|
-
return [val.cpu().numpy() for _, val in net.state_dict().items()]
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def set_weights(net, parameters):
|
|
77
|
-
"""Apply parameters to an existing model."""
|
|
78
|
-
params_dict = zip(net.state_dict().keys(), parameters)
|
|
79
|
-
state_dict = OrderedDict({k: torch.from_numpy(v) for k, v in params_dict})
|
|
80
|
-
net.load_state_dict(state_dict, strict=True)
|
|
@@ -1,45 +1,43 @@
|
|
|
1
1
|
"""$project_name: A Flower Baseline."""
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
from flwr.
|
|
5
|
-
from flwr.
|
|
3
|
+
import torch
|
|
4
|
+
from flwr.app import ArrayRecord, Context
|
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
|
6
7
|
|
|
7
|
-
from $import_name.model import Net
|
|
8
|
+
from $import_name.model import Net
|
|
8
9
|
|
|
10
|
+
# Create ServerApp
|
|
11
|
+
app = ServerApp()
|
|
9
12
|
|
|
10
|
-
# Define metric aggregation function
|
|
11
|
-
def weighted_average(metrics: list[tuple[int, Metrics]]) -> Metrics:
|
|
12
|
-
"""Do weighted average of accuracy metric."""
|
|
13
|
-
# Multiply accuracy of each client by number of examples used
|
|
14
|
-
accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics]
|
|
15
|
-
examples = [num_examples for num_examples, _ in metrics]
|
|
16
|
-
|
|
17
|
-
# Aggregate and return custom metric (weighted average)
|
|
18
|
-
return {"accuracy": sum(accuracies) / sum(examples)}
|
|
19
13
|
|
|
14
|
+
@app.main()
|
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
|
16
|
+
"""Main entry point for the ServerApp."""
|
|
20
17
|
|
|
21
|
-
def server_fn(context: Context):
|
|
22
|
-
"""Construct components that set the ServerApp behaviour."""
|
|
23
18
|
# Read from config
|
|
24
19
|
num_rounds = context.run_config["num-server-rounds"]
|
|
25
|
-
|
|
20
|
+
fraction_train = context.run_config["fraction-train"]
|
|
26
21
|
|
|
27
|
-
#
|
|
28
|
-
|
|
29
|
-
|
|
22
|
+
# Load global model
|
|
23
|
+
global_model = Net()
|
|
24
|
+
arrays = ArrayRecord(global_model.state_dict())
|
|
30
25
|
|
|
31
|
-
#
|
|
26
|
+
# Initialize FedAvg strategy
|
|
32
27
|
strategy = FedAvg(
|
|
33
|
-
|
|
28
|
+
fraction_train=fraction_train,
|
|
34
29
|
fraction_evaluate=1.0,
|
|
35
|
-
|
|
36
|
-
initial_parameters=parameters,
|
|
37
|
-
evaluate_metrics_aggregation_fn=weighted_average,
|
|
30
|
+
min_available_nodes=2,
|
|
38
31
|
)
|
|
39
|
-
config = ServerConfig(num_rounds=int(num_rounds))
|
|
40
|
-
|
|
41
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
|
42
32
|
|
|
33
|
+
# Start strategy, run FedAvg for `num_rounds`
|
|
34
|
+
result = strategy.start(
|
|
35
|
+
grid=grid,
|
|
36
|
+
initial_arrays=arrays,
|
|
37
|
+
num_rounds=num_rounds,
|
|
38
|
+
)
|
|
43
39
|
|
|
44
|
-
#
|
|
45
|
-
|
|
40
|
+
# Save final model to disk
|
|
41
|
+
print("\nSaving final model to disk...")
|
|
42
|
+
state_dict = result.arrays.to_torch_state_dict()
|
|
43
|
+
torch.save(state_dict, "final_model.pt")
|
|
@@ -1,17 +1,22 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
from flwr.
|
|
5
|
-
from flwr.
|
|
3
|
+
import torch
|
|
4
|
+
from flwr.app import ArrayRecord, Context
|
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
|
6
7
|
from transformers import AutoModelForSequenceClassification
|
|
7
8
|
|
|
8
|
-
|
|
9
|
+
# Create ServerApp
|
|
10
|
+
app = ServerApp()
|
|
11
|
+
|
|
9
12
|
|
|
13
|
+
@app.main()
|
|
14
|
+
def main(grid: Grid, context: Context) -> None:
|
|
15
|
+
"""Main entry point for the ServerApp."""
|
|
10
16
|
|
|
11
|
-
def server_fn(context: Context):
|
|
12
17
|
# Read from config
|
|
13
18
|
num_rounds = context.run_config["num-server-rounds"]
|
|
14
|
-
|
|
19
|
+
fraction_train = context.run_config["fraction-train"]
|
|
15
20
|
|
|
16
21
|
# Initialize global model
|
|
17
22
|
model_name = context.run_config["model-name"]
|
|
@@ -19,20 +24,19 @@ def server_fn(context: Context):
|
|
|
19
24
|
net = AutoModelForSequenceClassification.from_pretrained(
|
|
20
25
|
model_name, num_labels=num_labels
|
|
21
26
|
)
|
|
27
|
+
arrays = ArrayRecord(net.state_dict())
|
|
22
28
|
|
|
23
|
-
|
|
24
|
-
|
|
29
|
+
# Initialize FedAvg strategy
|
|
30
|
+
strategy = FedAvg(fraction_train=fraction_train)
|
|
25
31
|
|
|
26
|
-
#
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
32
|
+
# Start strategy, run FedAvg for `num_rounds`
|
|
33
|
+
result = strategy.start(
|
|
34
|
+
grid=grid,
|
|
35
|
+
initial_arrays=arrays,
|
|
36
|
+
num_rounds=num_rounds,
|
|
31
37
|
)
|
|
32
|
-
config = ServerConfig(num_rounds=num_rounds)
|
|
33
|
-
|
|
34
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
|
35
38
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
+
# Save final model to disk
|
|
40
|
+
print("\nSaving final model to disk...")
|
|
41
|
+
state_dict = result.arrays.to_torch_state_dict()
|
|
42
|
+
torch.save(state_dict, "final_model.pt")
|
|
@@ -1,26 +1,39 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
from flwr.
|
|
5
|
-
from flwr.
|
|
3
|
+
import numpy as np
|
|
4
|
+
from flwr.app import ArrayRecord, Context
|
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
|
7
|
+
|
|
6
8
|
from $import_name.task import get_params, load_model
|
|
7
9
|
|
|
10
|
+
# Create ServerApp
|
|
11
|
+
app = ServerApp()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@app.main()
|
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
|
16
|
+
"""Main entry point for the ServerApp."""
|
|
8
17
|
|
|
9
|
-
def server_fn(context: Context):
|
|
10
18
|
# Read from config
|
|
11
19
|
num_rounds = context.run_config["num-server-rounds"]
|
|
12
20
|
input_dim = context.run_config["input-dim"]
|
|
13
21
|
|
|
14
|
-
#
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
# Define strategy
|
|
19
|
-
strategy = FedAvg(initial_parameters=initial_parameters)
|
|
20
|
-
config = ServerConfig(num_rounds=num_rounds)
|
|
22
|
+
# Load global model
|
|
23
|
+
model = load_model((input_dim,))
|
|
24
|
+
arrays = ArrayRecord(get_params(model))
|
|
21
25
|
|
|
22
|
-
|
|
26
|
+
# Initialize FedAvg strategy
|
|
27
|
+
strategy = FedAvg()
|
|
23
28
|
|
|
29
|
+
# Start strategy, run FedAvg for `num_rounds`
|
|
30
|
+
result = strategy.start(
|
|
31
|
+
grid=grid,
|
|
32
|
+
initial_arrays=arrays,
|
|
33
|
+
num_rounds=num_rounds,
|
|
34
|
+
)
|
|
24
35
|
|
|
25
|
-
#
|
|
26
|
-
|
|
36
|
+
# Save final model to disk
|
|
37
|
+
print("\nSaving final model to disk...")
|
|
38
|
+
ndarrays = result.arrays.to_numpy_ndarrays()
|
|
39
|
+
np.savez("final_model.npz", *ndarrays)
|
|
@@ -1,31 +1,41 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
from flwr.
|
|
4
|
-
from flwr.
|
|
5
|
-
from flwr.
|
|
6
|
-
from $import_name.task import MLP, get_params
|
|
3
|
+
from flwr.app import ArrayRecord, Context
|
|
4
|
+
from flwr.serverapp import Grid, ServerApp
|
|
5
|
+
from flwr.serverapp.strategy import FedAvg
|
|
7
6
|
|
|
7
|
+
from $import_name.task import MLP, get_params, set_params
|
|
8
8
|
|
|
9
|
-
|
|
9
|
+
# Create ServerApp
|
|
10
|
+
app = ServerApp()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@app.main()
|
|
14
|
+
def main(grid: Grid, context: Context) -> None:
|
|
15
|
+
"""Main entry point for the ServerApp."""
|
|
10
16
|
# Read from config
|
|
11
17
|
num_rounds = context.run_config["num-server-rounds"]
|
|
12
|
-
|
|
13
|
-
num_classes = 10
|
|
14
18
|
num_layers = context.run_config["num-layers"]
|
|
15
19
|
input_dim = context.run_config["input-dim"]
|
|
16
20
|
hidden_dim = context.run_config["hidden-dim"]
|
|
17
21
|
|
|
18
22
|
# Initialize global model
|
|
19
|
-
model = MLP(num_layers, input_dim, hidden_dim,
|
|
23
|
+
model = MLP(num_layers, input_dim, hidden_dim, output_dim=10)
|
|
20
24
|
params = get_params(model)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
#
|
|
24
|
-
strategy = FedAvg(
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
25
|
+
arrays = ArrayRecord(params)
|
|
26
|
+
|
|
27
|
+
# Initialize FedAvg strategy
|
|
28
|
+
strategy = FedAvg()
|
|
29
|
+
|
|
30
|
+
# Start strategy, run FedAvg for `num_rounds`
|
|
31
|
+
result = strategy.start(
|
|
32
|
+
grid=grid,
|
|
33
|
+
initial_arrays=arrays,
|
|
34
|
+
num_rounds=num_rounds,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
# Save final model to disk
|
|
38
|
+
print("\nSaving final model to disk...")
|
|
39
|
+
ndarrays = result.arrays.to_numpy_ndarrays()
|
|
40
|
+
set_params(model, ndarrays)
|
|
41
|
+
model.save_weights("final_model.npz")
|
|
@@ -1,25 +1,38 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
from flwr.
|
|
5
|
-
from flwr.
|
|
6
|
-
from
|
|
7
|
-
|
|
3
|
+
import numpy as np
|
|
4
|
+
from flwr.app import ArrayRecord, Context
|
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
|
8
7
|
|
|
9
|
-
|
|
10
|
-
# Read from config
|
|
11
|
-
num_rounds = context.run_config["num-server-rounds"]
|
|
8
|
+
from $import_name.task import get_dummy_model
|
|
12
9
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
dummy_parameters = ndarrays_to_parameters([model])
|
|
10
|
+
# Create ServerApp
|
|
11
|
+
app = ServerApp()
|
|
16
12
|
|
|
17
|
-
# Define strategy
|
|
18
|
-
strategy = FedAvg(initial_parameters=dummy_parameters)
|
|
19
|
-
config = ServerConfig(num_rounds=num_rounds)
|
|
20
13
|
|
|
21
|
-
|
|
14
|
+
@app.main()
|
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
|
16
|
+
"""Main entry point for the ServerApp."""
|
|
22
17
|
|
|
18
|
+
# Read run config
|
|
19
|
+
num_rounds: int = context.run_config["num-server-rounds"]
|
|
23
20
|
|
|
24
|
-
#
|
|
25
|
-
|
|
21
|
+
# Load global model
|
|
22
|
+
model = get_dummy_model()
|
|
23
|
+
arrays = ArrayRecord(model)
|
|
24
|
+
|
|
25
|
+
# Initialize FedAvg strategy
|
|
26
|
+
strategy = FedAvg()
|
|
27
|
+
|
|
28
|
+
# Start strategy, run FedAvg for `num_rounds`
|
|
29
|
+
result = strategy.start(
|
|
30
|
+
grid=grid,
|
|
31
|
+
initial_arrays=arrays,
|
|
32
|
+
num_rounds=num_rounds,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# Save final model to disk
|
|
36
|
+
print("\nSaving final model to disk...")
|
|
37
|
+
ndarrays = result.arrays.to_numpy_ndarrays()
|
|
38
|
+
np.savez("final_model", *ndarrays)
|
|
@@ -1,31 +1,41 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
from flwr.
|
|
5
|
-
from flwr.
|
|
6
|
-
from
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
def server_fn(context: Context):
|
|
10
|
-
# Read from config
|
|
11
|
-
num_rounds = context.run_config["num-server-rounds"]
|
|
12
|
-
fraction_fit = context.run_config["fraction-fit"]
|
|
13
|
-
|
|
14
|
-
# Initialize model parameters
|
|
15
|
-
ndarrays = get_weights(Net())
|
|
16
|
-
parameters = ndarrays_to_parameters(ndarrays)
|
|
17
|
-
|
|
18
|
-
# Define strategy
|
|
19
|
-
strategy = FedAvg(
|
|
20
|
-
fraction_fit=fraction_fit,
|
|
21
|
-
fraction_evaluate=1.0,
|
|
22
|
-
min_available_clients=2,
|
|
23
|
-
initial_parameters=parameters,
|
|
24
|
-
)
|
|
25
|
-
config = ServerConfig(num_rounds=num_rounds)
|
|
26
|
-
|
|
27
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
|
3
|
+
import torch
|
|
4
|
+
from flwr.app import ArrayRecord, ConfigRecord, Context
|
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
|
28
7
|
|
|
8
|
+
from $import_name.task import Net
|
|
29
9
|
|
|
30
10
|
# Create ServerApp
|
|
31
|
-
app = ServerApp(
|
|
11
|
+
app = ServerApp()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@app.main()
|
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
|
16
|
+
"""Main entry point for the ServerApp."""
|
|
17
|
+
|
|
18
|
+
# Read run config
|
|
19
|
+
fraction_train: float = context.run_config["fraction-train"]
|
|
20
|
+
num_rounds: int = context.run_config["num-server-rounds"]
|
|
21
|
+
lr: float = context.run_config["lr"]
|
|
22
|
+
|
|
23
|
+
# Load global model
|
|
24
|
+
global_model = Net()
|
|
25
|
+
arrays = ArrayRecord(global_model.state_dict())
|
|
26
|
+
|
|
27
|
+
# Initialize FedAvg strategy
|
|
28
|
+
strategy = FedAvg(fraction_train=fraction_train)
|
|
29
|
+
|
|
30
|
+
# Start strategy, run FedAvg for `num_rounds`
|
|
31
|
+
result = strategy.start(
|
|
32
|
+
grid=grid,
|
|
33
|
+
initial_arrays=arrays,
|
|
34
|
+
train_config=ConfigRecord({"lr": lr}),
|
|
35
|
+
num_rounds=num_rounds,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
# Save final model to disk
|
|
39
|
+
print("\nSaving final model to disk...")
|
|
40
|
+
state_dict = result.arrays.to_torch_state_dict()
|
|
41
|
+
torch.save(state_dict, "final_model.pt")
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
|
+
from flwr.server.strategy import FedAvg
|
|
6
|
+
from $import_name.task import Net, get_weights
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def server_fn(context: Context):
|
|
10
|
+
# Read from config
|
|
11
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
12
|
+
fraction_fit = context.run_config["fraction-fit"]
|
|
13
|
+
|
|
14
|
+
# Initialize model parameters
|
|
15
|
+
ndarrays = get_weights(Net())
|
|
16
|
+
parameters = ndarrays_to_parameters(ndarrays)
|
|
17
|
+
|
|
18
|
+
# Define strategy
|
|
19
|
+
strategy = FedAvg(
|
|
20
|
+
fraction_fit=fraction_fit,
|
|
21
|
+
fraction_evaluate=1.0,
|
|
22
|
+
min_available_clients=2,
|
|
23
|
+
initial_parameters=parameters,
|
|
24
|
+
)
|
|
25
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
26
|
+
|
|
27
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Create ServerApp
|
|
31
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,36 +1,44 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
from flwr.
|
|
5
|
-
from flwr.
|
|
6
|
-
from
|
|
3
|
+
import joblib
|
|
4
|
+
from flwr.app import ArrayRecord, Context
|
|
5
|
+
from flwr.serverapp import Grid, ServerApp
|
|
6
|
+
from flwr.serverapp.strategy import FedAvg
|
|
7
7
|
|
|
8
|
+
from $import_name.task import get_model, get_model_params, set_initial_params, set_model_params
|
|
8
9
|
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
10
|
+
# Create ServerApp
|
|
11
|
+
app = ServerApp()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@app.main()
|
|
15
|
+
def main(grid: Grid, context: Context) -> None:
|
|
16
|
+
"""Main entry point for the ServerApp."""
|
|
17
|
+
|
|
18
|
+
# Read run config
|
|
19
|
+
num_rounds: int = context.run_config["num-server-rounds"]
|
|
12
20
|
|
|
13
21
|
# Create LogisticRegression Model
|
|
14
22
|
penalty = context.run_config["penalty"]
|
|
15
23
|
local_epochs = context.run_config["local-epochs"]
|
|
16
24
|
model = get_model(penalty, local_epochs)
|
|
17
|
-
|
|
18
25
|
# Setting initial parameters, akin to model.compile for keras models
|
|
19
26
|
set_initial_params(model)
|
|
27
|
+
# Construct ArrayRecord representation
|
|
28
|
+
arrays = ArrayRecord(get_model_params(model))
|
|
20
29
|
|
|
21
|
-
|
|
30
|
+
# Initialize FedAvg strategy
|
|
31
|
+
strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
|
|
22
32
|
|
|
23
|
-
#
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
initial_parameters=initial_parameters,
|
|
33
|
+
# Start strategy, run FedAvg for `num_rounds`
|
|
34
|
+
result = strategy.start(
|
|
35
|
+
grid=grid,
|
|
36
|
+
initial_arrays=arrays,
|
|
37
|
+
num_rounds=num_rounds,
|
|
29
38
|
)
|
|
30
|
-
config = ServerConfig(num_rounds=num_rounds)
|
|
31
|
-
|
|
32
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
|
33
39
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
40
|
+
# Save final model parameters
|
|
41
|
+
print("\nSaving final model to disk...")
|
|
42
|
+
ndarrays = result.arrays.to_numpy_ndarrays()
|
|
43
|
+
set_model_params(model, ndarrays)
|
|
44
|
+
joblib.dump(model, "logreg_model.pkl")
|
|
@@ -1,29 +1,38 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
from flwr.
|
|
4
|
-
from flwr.
|
|
5
|
-
from flwr.
|
|
3
|
+
from flwr.app import ArrayRecord, Context
|
|
4
|
+
from flwr.serverapp import Grid, ServerApp
|
|
5
|
+
from flwr.serverapp.strategy import FedAvg
|
|
6
6
|
|
|
7
7
|
from $import_name.task import load_model
|
|
8
8
|
|
|
9
|
+
# Create ServerApp
|
|
10
|
+
app = ServerApp()
|
|
9
11
|
|
|
10
|
-
def server_fn(context: Context):
|
|
11
|
-
# Read from config
|
|
12
|
-
num_rounds = context.run_config["num-server-rounds"]
|
|
13
12
|
|
|
14
|
-
|
|
15
|
-
|
|
13
|
+
@app.main()
|
|
14
|
+
def main(grid: Grid, context: Context) -> None:
|
|
15
|
+
"""Main entry point for the ServerApp."""
|
|
16
16
|
|
|
17
|
-
#
|
|
18
|
-
|
|
19
|
-
fraction_fit=1.0,
|
|
20
|
-
fraction_evaluate=1.0,
|
|
21
|
-
min_available_clients=2,
|
|
22
|
-
initial_parameters=parameters,
|
|
23
|
-
)
|
|
24
|
-
config = ServerConfig(num_rounds=num_rounds)
|
|
17
|
+
# Read run config
|
|
18
|
+
num_rounds: int = context.run_config["num-server-rounds"]
|
|
25
19
|
|
|
26
|
-
|
|
20
|
+
# Load global model
|
|
21
|
+
model = load_model()
|
|
22
|
+
arrays = ArrayRecord(model.get_weights())
|
|
27
23
|
|
|
28
|
-
#
|
|
29
|
-
|
|
24
|
+
# Initialize FedAvg strategy
|
|
25
|
+
strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
|
|
26
|
+
|
|
27
|
+
# Start strategy, run FedAvg for `num_rounds`
|
|
28
|
+
result = strategy.start(
|
|
29
|
+
grid=grid,
|
|
30
|
+
initial_arrays=arrays,
|
|
31
|
+
num_rounds=num_rounds,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Save final model to disk
|
|
35
|
+
print("\nSaving final model to disk...")
|
|
36
|
+
ndarrays = result.arrays.to_numpy_ndarrays()
|
|
37
|
+
model.set_weights(ndarrays)
|
|
38
|
+
model.save("final_model.keras")
|