flwr 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +19 -0
- flwr/cli/{new/templates → app_cmd}/__init__.py +9 -1
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +262 -0
- flwr/cli/auth_plugin/auth_plugin.py +4 -5
- flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
- flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
- flwr/cli/build.py +60 -18
- flwr/cli/cli_account_auth_interceptor.py +24 -7
- flwr/cli/config_utils.py +101 -13
- flwr/cli/{new/templates/app/code/flwr_tune → federation}/__init__.py +10 -1
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +318 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +52 -9
- flwr/cli/login/login.py +7 -4
- flwr/cli/ls.py +211 -130
- flwr/cli/new/new.py +123 -331
- flwr/cli/pull.py +10 -5
- flwr/cli/run/run.py +71 -29
- flwr/cli/run_utils.py +148 -0
- flwr/cli/stop.py +26 -8
- flwr/cli/supernode/ls.py +25 -12
- flwr/cli/supernode/register.py +9 -4
- flwr/cli/supernode/unregister.py +5 -3
- flwr/cli/utils.py +239 -16
- flwr/client/__init__.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +8 -9
- flwr/client/grpc_rere_client/connection.py +16 -14
- flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
- flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +18 -18
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/clientapp/utils.py +3 -3
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +5 -2
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +19 -0
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +38 -21
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +18 -30
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +11 -4
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +58 -37
- flwr/compat/client/app.py +38 -37
- flwr/compat/client/grpc_client/connection.py +11 -11
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +71 -52
- flwr/proto/control_pb2.pyi +277 -111
- flwr/proto/control_pb2_grpc.py +249 -40
- flwr/proto/control_pb2_grpc.pyi +185 -52
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +14 -4
- flwr/proto/fab_pb2.pyi +59 -31
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +24 -14
- flwr/proto/fleet_pb2.pyi +141 -61
- flwr/proto/fleet_pb2_grpc.py +189 -48
- flwr/proto/fleet_pb2_grpc.pyi +175 -61
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +50 -25
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +158 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +39 -17
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +75 -30
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +15 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +148 -149
- flwr/server/superlink/linkstate/linkstate.py +91 -43
- flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +502 -436
- flwr/server/superlink/linkstate/utils.py +6 -6
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +10 -11
- flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
- flwr/simulation/run_simulation.py +43 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +34 -1
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -2
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +1 -2
- flwr/supercore/object_store/sqlite_object_store.py +8 -7
- flwr/supercore/primitives/asymmetric.py +1 -1
- flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
- flwr/supercore/sqlite_mixin.py +37 -34
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/supercore/utils.py +190 -0
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/superlink/auth_plugin/auth_plugin.py +6 -9
- flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
- flwr/{cli/new/templates/app → superlink/federation}/__init__.py +10 -1
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +7 -6
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +190 -23
- flwr/supernode/cli/flower_supernode.py +58 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +121 -49
- flwr/supernode/nodestate/nodestate.py +52 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +41 -22
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +46 -10
- flwr/supernode/start_client_internal.py +165 -46
- {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/METADATA +9 -11
- flwr-1.25.0.dist-info/RECORD +393 -0
- flwr/cli/new/templates/app/.gitignore.tpl +0 -163
- flwr/cli/new/templates/app/LICENSE.tpl +0 -202
- flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
- flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
- flwr/cli/new/templates/app/README.md.tpl +0 -37
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.py +0 -15
- flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
- flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
- flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
- flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -98
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.23.0.dist-info/RECORD +0 -439
- {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/WHEEL +0 -0
- {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from flwr.app import ArrayRecord, ConfigRecord, Context
|
|
5
|
-
from flwr.serverapp import Grid, ServerApp
|
|
6
|
-
from flwr.serverapp.strategy import FedAvg
|
|
7
|
-
|
|
8
|
-
from $import_name.task import Net
|
|
9
|
-
|
|
10
|
-
# Create ServerApp
|
|
11
|
-
app = ServerApp()
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
@app.main()
|
|
15
|
-
def main(grid: Grid, context: Context) -> None:
|
|
16
|
-
"""Main entry point for the ServerApp."""
|
|
17
|
-
|
|
18
|
-
# Read run config
|
|
19
|
-
fraction_train: float = context.run_config["fraction-train"]
|
|
20
|
-
num_rounds: int = context.run_config["num-server-rounds"]
|
|
21
|
-
lr: float = context.run_config["lr"]
|
|
22
|
-
|
|
23
|
-
# Load global model
|
|
24
|
-
global_model = Net()
|
|
25
|
-
arrays = ArrayRecord(global_model.state_dict())
|
|
26
|
-
|
|
27
|
-
# Initialize FedAvg strategy
|
|
28
|
-
strategy = FedAvg(fraction_train=fraction_train)
|
|
29
|
-
|
|
30
|
-
# Start strategy, run FedAvg for `num_rounds`
|
|
31
|
-
result = strategy.start(
|
|
32
|
-
grid=grid,
|
|
33
|
-
initial_arrays=arrays,
|
|
34
|
-
train_config=ConfigRecord({"lr": lr}),
|
|
35
|
-
num_rounds=num_rounds,
|
|
36
|
-
)
|
|
37
|
-
|
|
38
|
-
# Save final model to disk
|
|
39
|
-
print("\nSaving final model to disk...")
|
|
40
|
-
state_dict = result.arrays.to_torch_state_dict()
|
|
41
|
-
torch.save(state_dict, "final_model.pt")
|
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
from flwr.common import Context, ndarrays_to_parameters
|
|
4
|
-
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
|
-
from flwr.server.strategy import FedAvg
|
|
6
|
-
from $import_name.task import Net, get_weights
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
def server_fn(context: Context):
|
|
10
|
-
# Read from config
|
|
11
|
-
num_rounds = context.run_config["num-server-rounds"]
|
|
12
|
-
fraction_fit = context.run_config["fraction-fit"]
|
|
13
|
-
|
|
14
|
-
# Initialize model parameters
|
|
15
|
-
ndarrays = get_weights(Net())
|
|
16
|
-
parameters = ndarrays_to_parameters(ndarrays)
|
|
17
|
-
|
|
18
|
-
# Define strategy
|
|
19
|
-
strategy = FedAvg(
|
|
20
|
-
fraction_fit=fraction_fit,
|
|
21
|
-
fraction_evaluate=1.0,
|
|
22
|
-
min_available_clients=2,
|
|
23
|
-
initial_parameters=parameters,
|
|
24
|
-
)
|
|
25
|
-
config = ServerConfig(num_rounds=num_rounds)
|
|
26
|
-
|
|
27
|
-
return ServerAppComponents(strategy=strategy, config=config)
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
# Create ServerApp
|
|
31
|
-
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,44 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
import joblib
|
|
4
|
-
from flwr.app import ArrayRecord, Context
|
|
5
|
-
from flwr.serverapp import Grid, ServerApp
|
|
6
|
-
from flwr.serverapp.strategy import FedAvg
|
|
7
|
-
|
|
8
|
-
from $import_name.task import get_model, get_model_params, set_initial_params, set_model_params
|
|
9
|
-
|
|
10
|
-
# Create ServerApp
|
|
11
|
-
app = ServerApp()
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
@app.main()
|
|
15
|
-
def main(grid: Grid, context: Context) -> None:
|
|
16
|
-
"""Main entry point for the ServerApp."""
|
|
17
|
-
|
|
18
|
-
# Read run config
|
|
19
|
-
num_rounds: int = context.run_config["num-server-rounds"]
|
|
20
|
-
|
|
21
|
-
# Create LogisticRegression Model
|
|
22
|
-
penalty = context.run_config["penalty"]
|
|
23
|
-
local_epochs = context.run_config["local-epochs"]
|
|
24
|
-
model = get_model(penalty, local_epochs)
|
|
25
|
-
# Setting initial parameters, akin to model.compile for keras models
|
|
26
|
-
set_initial_params(model)
|
|
27
|
-
# Construct ArrayRecord representation
|
|
28
|
-
arrays = ArrayRecord(get_model_params(model))
|
|
29
|
-
|
|
30
|
-
# Initialize FedAvg strategy
|
|
31
|
-
strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
|
|
32
|
-
|
|
33
|
-
# Start strategy, run FedAvg for `num_rounds`
|
|
34
|
-
result = strategy.start(
|
|
35
|
-
grid=grid,
|
|
36
|
-
initial_arrays=arrays,
|
|
37
|
-
num_rounds=num_rounds,
|
|
38
|
-
)
|
|
39
|
-
|
|
40
|
-
# Save final model parameters
|
|
41
|
-
print("\nSaving final model to disk...")
|
|
42
|
-
ndarrays = result.arrays.to_numpy_ndarrays()
|
|
43
|
-
set_model_params(model, ndarrays)
|
|
44
|
-
joblib.dump(model, "logreg_model.pkl")
|
|
@@ -1,38 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
from flwr.app import ArrayRecord, Context
|
|
4
|
-
from flwr.serverapp import Grid, ServerApp
|
|
5
|
-
from flwr.serverapp.strategy import FedAvg
|
|
6
|
-
|
|
7
|
-
from $import_name.task import load_model
|
|
8
|
-
|
|
9
|
-
# Create ServerApp
|
|
10
|
-
app = ServerApp()
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
@app.main()
|
|
14
|
-
def main(grid: Grid, context: Context) -> None:
|
|
15
|
-
"""Main entry point for the ServerApp."""
|
|
16
|
-
|
|
17
|
-
# Read run config
|
|
18
|
-
num_rounds: int = context.run_config["num-server-rounds"]
|
|
19
|
-
|
|
20
|
-
# Load global model
|
|
21
|
-
model = load_model()
|
|
22
|
-
arrays = ArrayRecord(model.get_weights())
|
|
23
|
-
|
|
24
|
-
# Initialize FedAvg strategy
|
|
25
|
-
strategy = FedAvg(fraction_train=1.0, fraction_evaluate=1.0)
|
|
26
|
-
|
|
27
|
-
# Start strategy, run FedAvg for `num_rounds`
|
|
28
|
-
result = strategy.start(
|
|
29
|
-
grid=grid,
|
|
30
|
-
initial_arrays=arrays,
|
|
31
|
-
num_rounds=num_rounds,
|
|
32
|
-
)
|
|
33
|
-
|
|
34
|
-
# Save final model to disk
|
|
35
|
-
print("\nSaving final model to disk...")
|
|
36
|
-
ndarrays = result.arrays.to_numpy_ndarrays()
|
|
37
|
-
model.set_weights(ndarrays)
|
|
38
|
-
model.save("final_model.keras")
|
|
@@ -1,56 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import xgboost as xgb
|
|
5
|
-
from flwr.app import ArrayRecord, Context
|
|
6
|
-
from flwr.common.config import unflatten_dict
|
|
7
|
-
from flwr.serverapp import Grid, ServerApp
|
|
8
|
-
from flwr.serverapp.strategy import FedXgbBagging
|
|
9
|
-
|
|
10
|
-
from $import_name.task import replace_keys
|
|
11
|
-
|
|
12
|
-
# Create ServerApp
|
|
13
|
-
app = ServerApp()
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
@app.main()
|
|
17
|
-
def main(grid: Grid, context: Context) -> None:
|
|
18
|
-
# Read run config
|
|
19
|
-
num_rounds = context.run_config["num-server-rounds"]
|
|
20
|
-
fraction_train = context.run_config["fraction-train"]
|
|
21
|
-
fraction_evaluate = context.run_config["fraction-evaluate"]
|
|
22
|
-
# Flatted config dict and replace "-" with "_"
|
|
23
|
-
cfg = replace_keys(unflatten_dict(context.run_config))
|
|
24
|
-
params = cfg["params"]
|
|
25
|
-
|
|
26
|
-
# Init global model
|
|
27
|
-
# Init with an empty object; the XGBooster will be created
|
|
28
|
-
# and trained on the client side.
|
|
29
|
-
global_model = b""
|
|
30
|
-
# Note: we store the model as the first item in a list into ArrayRecord,
|
|
31
|
-
# which can be accessed using index ["0"].
|
|
32
|
-
arrays = ArrayRecord([np.frombuffer(global_model, dtype=np.uint8)])
|
|
33
|
-
|
|
34
|
-
# Initialize FedXgbBagging strategy
|
|
35
|
-
strategy = FedXgbBagging(
|
|
36
|
-
fraction_train=fraction_train,
|
|
37
|
-
fraction_evaluate=fraction_evaluate,
|
|
38
|
-
)
|
|
39
|
-
|
|
40
|
-
# Start strategy, run FedXgbBagging for `num_rounds`
|
|
41
|
-
result = strategy.start(
|
|
42
|
-
grid=grid,
|
|
43
|
-
initial_arrays=arrays,
|
|
44
|
-
num_rounds=num_rounds,
|
|
45
|
-
)
|
|
46
|
-
|
|
47
|
-
# Save final model to disk
|
|
48
|
-
bst = xgb.Booster(params=params)
|
|
49
|
-
global_model = bytearray(result.arrays["0"].numpy().tobytes())
|
|
50
|
-
|
|
51
|
-
# Load global model into booster
|
|
52
|
-
bst.load_model(global_model)
|
|
53
|
-
|
|
54
|
-
# Save model
|
|
55
|
-
print("\nSaving final model to disk...")
|
|
56
|
-
bst.save_model("final_model.json")
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower Baseline."""
|
|
@@ -1,98 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
import warnings
|
|
4
|
-
|
|
5
|
-
import torch
|
|
6
|
-
import transformers
|
|
7
|
-
from datasets.utils.logging import disable_progress_bar
|
|
8
|
-
from evaluate import load as load_metric
|
|
9
|
-
from flwr_datasets import FederatedDataset
|
|
10
|
-
from flwr_datasets.partitioner import IidPartitioner
|
|
11
|
-
from torch.optim import AdamW
|
|
12
|
-
from torch.utils.data import DataLoader
|
|
13
|
-
from transformers import AutoTokenizer, DataCollatorWithPadding
|
|
14
|
-
|
|
15
|
-
warnings.filterwarnings("ignore", category=UserWarning)
|
|
16
|
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
17
|
-
disable_progress_bar()
|
|
18
|
-
transformers.logging.set_verbosity_error()
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
fds = None # Cache FederatedDataset
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def load_data(partition_id: int, num_partitions: int, model_name: str):
|
|
25
|
-
"""Load IMDB data (training and eval)"""
|
|
26
|
-
# Only initialize `FederatedDataset` once
|
|
27
|
-
global fds
|
|
28
|
-
if fds is None:
|
|
29
|
-
partitioner = IidPartitioner(num_partitions=num_partitions)
|
|
30
|
-
fds = FederatedDataset(
|
|
31
|
-
dataset="stanfordnlp/imdb",
|
|
32
|
-
partitioners={"train": partitioner},
|
|
33
|
-
)
|
|
34
|
-
partition = fds.load_partition(partition_id)
|
|
35
|
-
# Divide data: 80% train, 20% test
|
|
36
|
-
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
|
37
|
-
|
|
38
|
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
39
|
-
|
|
40
|
-
def tokenize_function(examples):
|
|
41
|
-
return tokenizer(
|
|
42
|
-
examples["text"], truncation=True, add_special_tokens=True, max_length=512
|
|
43
|
-
)
|
|
44
|
-
|
|
45
|
-
partition_train_test = partition_train_test.map(tokenize_function, batched=True)
|
|
46
|
-
partition_train_test = partition_train_test.remove_columns("text")
|
|
47
|
-
partition_train_test = partition_train_test.rename_column("label", "labels")
|
|
48
|
-
|
|
49
|
-
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
|
50
|
-
trainloader = DataLoader(
|
|
51
|
-
partition_train_test["train"],
|
|
52
|
-
shuffle=True,
|
|
53
|
-
batch_size=32,
|
|
54
|
-
collate_fn=data_collator,
|
|
55
|
-
)
|
|
56
|
-
|
|
57
|
-
testloader = DataLoader(
|
|
58
|
-
partition_train_test["test"], batch_size=32, collate_fn=data_collator
|
|
59
|
-
)
|
|
60
|
-
|
|
61
|
-
return trainloader, testloader
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
def train(net, trainloader, num_steps, device):
|
|
65
|
-
optimizer = AdamW(net.parameters(), lr=5e-5)
|
|
66
|
-
net.train()
|
|
67
|
-
running_loss = 0.0
|
|
68
|
-
step_cnt = 0
|
|
69
|
-
for batch in trainloader:
|
|
70
|
-
batch = {k: v.to(device) for k, v in batch.items()}
|
|
71
|
-
outputs = net(**batch)
|
|
72
|
-
loss = outputs.loss
|
|
73
|
-
loss.backward()
|
|
74
|
-
optimizer.step()
|
|
75
|
-
optimizer.zero_grad()
|
|
76
|
-
running_loss += loss.item()
|
|
77
|
-
step_cnt += 1
|
|
78
|
-
if step_cnt >= num_steps:
|
|
79
|
-
break
|
|
80
|
-
avg_trainloss = running_loss / step_cnt
|
|
81
|
-
return avg_trainloss
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def test(net, testloader, device):
|
|
85
|
-
metric = load_metric("accuracy")
|
|
86
|
-
loss = 0
|
|
87
|
-
net.eval()
|
|
88
|
-
for batch in testloader:
|
|
89
|
-
batch = {k: v.to(device) for k, v in batch.items()}
|
|
90
|
-
with torch.no_grad():
|
|
91
|
-
outputs = net(**batch)
|
|
92
|
-
logits = outputs.logits
|
|
93
|
-
loss += outputs.loss.item()
|
|
94
|
-
predictions = torch.argmax(logits, dim=-1)
|
|
95
|
-
metric.add_batch(predictions=predictions, references=batch["labels"])
|
|
96
|
-
loss /= len(testloader.dataset)
|
|
97
|
-
accuracy = metric.compute()["accuracy"]
|
|
98
|
-
return loss, accuracy
|
|
@@ -1,57 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
import jax
|
|
4
|
-
import jax.numpy as jnp
|
|
5
|
-
import numpy as np
|
|
6
|
-
from sklearn.datasets import make_regression
|
|
7
|
-
from sklearn.model_selection import train_test_split
|
|
8
|
-
|
|
9
|
-
key = jax.random.PRNGKey(0)
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def load_data():
|
|
13
|
-
# Load dataset
|
|
14
|
-
X, y = make_regression(n_features=3, random_state=0)
|
|
15
|
-
X, X_test, y, y_test = train_test_split(X, y)
|
|
16
|
-
return X, y, X_test, y_test
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def load_model(model_shape):
|
|
20
|
-
# Extract model parameters
|
|
21
|
-
params = {"b": jax.random.uniform(key), "w": jax.random.uniform(key, model_shape)}
|
|
22
|
-
return params
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
def loss_fn(params, X, y):
|
|
26
|
-
# Return MSE as loss
|
|
27
|
-
err = jnp.dot(X, params["w"]) + params["b"] - y
|
|
28
|
-
return jnp.mean(jnp.square(err))
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def train(params, grad_fn, X, y):
|
|
32
|
-
loss = 1_000_000
|
|
33
|
-
num_examples = X.shape[0]
|
|
34
|
-
for _ in range(50):
|
|
35
|
-
grads = grad_fn(params, X, y)
|
|
36
|
-
params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
|
|
37
|
-
loss = loss_fn(params, X, y)
|
|
38
|
-
return params, loss, num_examples
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
def evaluation(params, grad_fn, X_test, y_test):
|
|
42
|
-
num_examples = X_test.shape[0]
|
|
43
|
-
err_test = loss_fn(params, X_test, y_test)
|
|
44
|
-
loss_test = jnp.mean(jnp.square(err_test))
|
|
45
|
-
return loss_test, num_examples
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def get_params(params):
|
|
49
|
-
parameters = []
|
|
50
|
-
for _, val in params.items():
|
|
51
|
-
parameters.append(np.array(val))
|
|
52
|
-
return parameters
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def set_params(local_params, global_params):
|
|
56
|
-
for key, value in list(zip(local_params.keys(), global_params)):
|
|
57
|
-
local_params[key] = value
|
|
@@ -1,102 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
import mlx.core as mx
|
|
4
|
-
import mlx.nn as nn
|
|
5
|
-
import numpy as np
|
|
6
|
-
from flwr_datasets import FederatedDataset
|
|
7
|
-
from flwr_datasets.partitioner import IidPartitioner
|
|
8
|
-
|
|
9
|
-
from datasets.utils.logging import disable_progress_bar
|
|
10
|
-
|
|
11
|
-
disable_progress_bar()
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class MLP(nn.Module):
|
|
15
|
-
"""A simple MLP."""
|
|
16
|
-
|
|
17
|
-
def __init__(
|
|
18
|
-
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
|
|
19
|
-
):
|
|
20
|
-
super().__init__()
|
|
21
|
-
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
|
|
22
|
-
self.layers = [
|
|
23
|
-
nn.Linear(idim, odim)
|
|
24
|
-
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
|
|
25
|
-
]
|
|
26
|
-
|
|
27
|
-
def __call__(self, x):
|
|
28
|
-
for l in self.layers[:-1]:
|
|
29
|
-
x = mx.maximum(l(x), 0.0)
|
|
30
|
-
return self.layers[-1](x)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def loss_fn(model, X, y):
|
|
34
|
-
return mx.mean(nn.losses.cross_entropy(model(X), y))
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def eval_fn(model, X, y):
|
|
38
|
-
return mx.mean(mx.argmax(model(X), axis=1) == y)
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
def batch_iterate(batch_size, X, y):
    """Yield (X, y) minibatches in a fresh random order on each call."""
    # A new permutation is drawn per epoch so batch composition varies.
    order = mx.array(np.random.permutation(y.size))
    for start in range(0, y.size, batch_size):
        batch_ids = order[start : start + batch_size]
        yield X[batch_ids], y[batch_ids]
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
fds = None  # Cache FederatedDataset — created lazily on first load_data call
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
def load_data(partition_id: int, num_partitions: int):
    """Load one IID MNIST partition and return train/test arrays as mx.array.

    Returns a tuple ``(train_images, train_labels, test_images, test_labels)``.
    """
    # Only initialize `FederatedDataset` once per process.
    global fds
    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(
            dataset="ylecun/mnist",
            partitioners={"train": partitioner},
            trust_remote_code=True,
        )
    partition = fds.load_partition(partition_id)
    splits = partition.train_test_split(test_size=0.2, seed=42)

    for split_name in ("train", "test"):
        splits[split_name].set_format("numpy")

    def _flatten(img):
        # Flatten 28x28 images and scale pixel values into [0, 1].
        return {"img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0}

    train_split = splits["train"].map(_flatten, input_columns="image")
    test_split = splits["test"].map(_flatten, input_columns="image")

    arrays = (
        train_split["img"],
        train_split["label"].astype(np.uint32),
        test_split["img"],
        test_split["label"].astype(np.uint32),
    )
    return tuple(map(mx.array, arrays))
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
def get_params(model):
    """Flatten the model's per-layer parameters into a list of NumPy arrays."""
    flat = []
    for layer in model.parameters()["layers"]:
        for _, tensor in layer.items():
            flat.append(np.array(tensor))
    return flat
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def set_params(model, parameters):
    """Load a flat list of arrays back into the model.

    Assumes `parameters` alternates weight, bias, weight, bias, ... with one
    (weight, bias) pair per Linear layer.
    """
    layers = []
    for i in range(0, len(parameters), 2):
        layers.append(
            {"weight": mx.array(parameters[i]), "bias": mx.array(parameters[i + 1])}
        )
    model.update({"layers": layers})
|
|
@@ -1,98 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
import torch.nn as nn
|
|
5
|
-
import torch.nn.functional as F
|
|
6
|
-
from flwr_datasets import FederatedDataset
|
|
7
|
-
from flwr_datasets.partitioner import IidPartitioner
|
|
8
|
-
from torch.utils.data import DataLoader
|
|
9
|
-
from torchvision.transforms import Compose, Normalize, ToTensor
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class Net(nn.Module):
    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

    # NOTE: attribute names (conv1, pool, conv2, fc1, fc2, fc3) are part of
    # the state_dict layout and must not change.
    def __init__(self):
        super(Net, self).__init__()
        # Two conv + max-pool stages for a 3x32x32 input ...
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # ... followed by a three-layer classifier head (10 classes).
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Convolutional feature extractor.
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        flat = features.view(-1, 16 * 5 * 5)
        # Fully connected classifier; the last layer emits raw logits.
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
fds = None  # Cache FederatedDataset — created lazily on first load_data call

# Convert PIL images to tensors and normalize each channel to [-1, 1].
pytorch_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def apply_transforms(batch):
    """Apply transforms to the partition from FederatedDataset."""
    # Transform images one by one; the module-level pipeline handles
    # tensor conversion and normalization.
    batch["img"] = [pytorch_transforms(image) for image in batch["img"]]
    return batch
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def load_data(partition_id: int, num_partitions: int):
    """Load partition CIFAR10 data.

    Returns ``(trainloader, testloader)`` DataLoaders over an 80/20 split
    of this node's IID partition.
    """
    # Only initialize `FederatedDataset` once per process.
    global fds
    if fds is None:
        fds = FederatedDataset(
            dataset="uoft-cs/cifar10",
            partitioners={"train": IidPartitioner(num_partitions=num_partitions)},
        )
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test (seeded for reproducibility).
    split = partition.train_test_split(test_size=0.2, seed=42).with_transform(
        apply_transforms
    )
    train_loader = DataLoader(split["train"], batch_size=32, shuffle=True)
    test_loader = DataLoader(split["test"], batch_size=32)
    return train_loader, test_loader
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
def train(net, trainloader, epochs, lr, device):
    """Train the model on the training set.

    Returns the average per-batch training loss accumulated across all epochs.
    """
    net.to(device)  # move model to GPU if available
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    net.train()
    total_loss = 0.0
    for _ in range(epochs):
        for batch in trainloader:
            inputs = batch["img"].to(device)
            targets = batch["label"].to(device)
            optimizer.zero_grad()
            batch_loss = criterion(net(inputs), targets)
            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()
    return total_loss / len(trainloader)
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def test(net, testloader, device):
|
|
85
|
-
"""Validate the model on the test set."""
|
|
86
|
-
net.to(device)
|
|
87
|
-
criterion = torch.nn.CrossEntropyLoss()
|
|
88
|
-
correct, loss = 0, 0.0
|
|
89
|
-
with torch.no_grad():
|
|
90
|
-
for batch in testloader:
|
|
91
|
-
images = batch["img"].to(device)
|
|
92
|
-
labels = batch["label"].to(device)
|
|
93
|
-
outputs = net(images)
|
|
94
|
-
loss += criterion(outputs, labels).item()
|
|
95
|
-
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
|
|
96
|
-
accuracy = correct / len(testloader.dataset)
|
|
97
|
-
loss = loss / len(testloader)
|
|
98
|
-
return loss, accuracy
|
|
@@ -1,111 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
from collections import OrderedDict
|
|
4
|
-
|
|
5
|
-
import torch
|
|
6
|
-
import torch.nn as nn
|
|
7
|
-
import torch.nn.functional as F
|
|
8
|
-
from flwr_datasets import FederatedDataset
|
|
9
|
-
from flwr_datasets.partitioner import IidPartitioner
|
|
10
|
-
from torch.utils.data import DataLoader
|
|
11
|
-
from torchvision.transforms import Compose, Normalize, ToTensor
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class Net(nn.Module):
    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

    # Attribute names define the state_dict keys consumed by
    # get_weights/set_weights — keep them stable.
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 3x32x32 -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)  # halves spatial dims, reused twice
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # 10-class logits

    def forward(self, x):
        out = self.pool(F.relu(self.conv1(x)))
        out = self.pool(F.relu(self.conv2(out)))
        out = out.view(-1, 16 * 5 * 5)  # flatten conv features
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
fds = None  # Cache FederatedDataset — created lazily on first load_data call
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def load_data(partition_id: int, num_partitions: int):
    """Load partition CIFAR10 data.

    Returns ``(trainloader, testloader)`` DataLoaders over an 80/20 split
    of this node's IID partition.
    """
    # Only initialize `FederatedDataset` once per process.
    global fds
    if fds is None:
        partitioner = IidPartitioner(num_partitions=num_partitions)
        fds = FederatedDataset(
            dataset="uoft-cs/cifar10",
            partitioners={"train": partitioner},
        )
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test (seeded for reproducibility).
    split = partition.train_test_split(test_size=0.2, seed=42)
    # Convert PIL images to tensors and normalize each channel to [-1, 1].
    normalize = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    def apply_transforms(batch):
        """Apply transforms to the partition from FederatedDataset."""
        batch["img"] = [normalize(img) for img in batch["img"]]
        return batch

    split = split.with_transform(apply_transforms)
    train_loader = DataLoader(split["train"], batch_size=32, shuffle=True)
    test_loader = DataLoader(split["test"], batch_size=32)
    return train_loader, test_loader
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def train(net, trainloader, epochs, device, lr=0.01):
    """Train the model on the training set.

    Args:
        net: Model to train (moved to ``device`` in place).
        trainloader: Iterable of batches with "img" and "label" tensors.
        epochs: Number of passes over the training data.
        device: Device to run training on.
        lr: Adam learning rate. Defaults to 0.01, the previously hard-coded
            value, so existing callers are unaffected; the sibling template
            already exposes ``lr`` as a parameter.

    Returns:
        Average per-batch training loss accumulated across all epochs.
    """
    net.to(device)  # move model to GPU if available
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    net.train()
    running_loss = 0.0
    for _ in range(epochs):
        for batch in trainloader:
            images = batch["img"]
            labels = batch["label"]
            optimizer.zero_grad()
            loss = criterion(net(images.to(device)), labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

    avg_trainloss = running_loss / len(trainloader)
    return avg_trainloss
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
def test(net, testloader, device):
|
|
88
|
-
"""Validate the model on the test set."""
|
|
89
|
-
net.to(device)
|
|
90
|
-
criterion = torch.nn.CrossEntropyLoss()
|
|
91
|
-
correct, loss = 0, 0.0
|
|
92
|
-
with torch.no_grad():
|
|
93
|
-
for batch in testloader:
|
|
94
|
-
images = batch["img"].to(device)
|
|
95
|
-
labels = batch["label"].to(device)
|
|
96
|
-
outputs = net(images)
|
|
97
|
-
loss += criterion(outputs, labels).item()
|
|
98
|
-
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
|
|
99
|
-
accuracy = correct / len(testloader.dataset)
|
|
100
|
-
loss = loss / len(testloader)
|
|
101
|
-
return loss, accuracy
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
def get_weights(net):
    """Return all model parameters as a list of NumPy ndarrays (CPU copies)."""
    state = net.state_dict()
    return [tensor.cpu().numpy() for tensor in state.values()]
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def set_weights(net, parameters):
    """Load a flat list of ndarrays into the model, keyed by state_dict order."""
    keys = net.state_dict().keys()
    # Rebuild an ordered mapping from state_dict keys to tensors.
    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
    net.load_state_dict(state_dict, strict=True)
|