flwr 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +19 -0
- flwr/cli/{new/templates → app_cmd}/__init__.py +9 -1
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +262 -0
- flwr/cli/auth_plugin/auth_plugin.py +4 -5
- flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
- flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
- flwr/cli/build.py +60 -18
- flwr/cli/cli_account_auth_interceptor.py +24 -7
- flwr/cli/config_utils.py +101 -13
- flwr/cli/{new/templates/app/code/flwr_tune → federation}/__init__.py +10 -1
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +318 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +52 -9
- flwr/cli/login/login.py +7 -4
- flwr/cli/ls.py +211 -130
- flwr/cli/new/new.py +123 -331
- flwr/cli/pull.py +10 -5
- flwr/cli/run/run.py +71 -29
- flwr/cli/run_utils.py +148 -0
- flwr/cli/stop.py +26 -8
- flwr/cli/supernode/ls.py +25 -12
- flwr/cli/supernode/register.py +9 -4
- flwr/cli/supernode/unregister.py +5 -3
- flwr/cli/utils.py +239 -16
- flwr/client/__init__.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +8 -9
- flwr/client/grpc_rere_client/connection.py +16 -14
- flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
- flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +18 -18
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/clientapp/utils.py +3 -3
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +5 -2
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +19 -0
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +38 -21
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +18 -30
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +11 -4
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +58 -37
- flwr/compat/client/app.py +38 -37
- flwr/compat/client/grpc_client/connection.py +11 -11
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +71 -52
- flwr/proto/control_pb2.pyi +277 -111
- flwr/proto/control_pb2_grpc.py +249 -40
- flwr/proto/control_pb2_grpc.pyi +185 -52
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +14 -4
- flwr/proto/fab_pb2.pyi +59 -31
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +24 -14
- flwr/proto/fleet_pb2.pyi +141 -61
- flwr/proto/fleet_pb2_grpc.py +189 -48
- flwr/proto/fleet_pb2_grpc.pyi +175 -61
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +50 -25
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +158 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +39 -17
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +75 -30
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +15 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +148 -149
- flwr/server/superlink/linkstate/linkstate.py +91 -43
- flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +502 -436
- flwr/server/superlink/linkstate/utils.py +6 -6
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +10 -11
- flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
- flwr/simulation/run_simulation.py +43 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +34 -1
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -2
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +1 -2
- flwr/supercore/object_store/sqlite_object_store.py +8 -7
- flwr/supercore/primitives/asymmetric.py +1 -1
- flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
- flwr/supercore/sqlite_mixin.py +37 -34
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/supercore/utils.py +190 -0
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/superlink/auth_plugin/auth_plugin.py +6 -9
- flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
- flwr/{cli/new/templates/app → superlink/federation}/__init__.py +10 -1
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +7 -6
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +190 -23
- flwr/supernode/cli/flower_supernode.py +58 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +121 -49
- flwr/supernode/nodestate/nodestate.py +52 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +41 -22
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +46 -10
- flwr/supernode/start_client_internal.py +165 -46
- {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/METADATA +9 -11
- flwr-1.25.0.dist-info/RECORD +393 -0
- flwr/cli/new/templates/app/.gitignore.tpl +0 -163
- flwr/cli/new/templates/app/LICENSE.tpl +0 -202
- flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
- flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
- flwr/cli/new/templates/app/README.md.tpl +0 -37
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.py +0 -15
- flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
- flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
- flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
- flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -98
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.23.0.dist-info/RECORD +0 -439
- {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/WHEEL +0 -0
- {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/entry_points.txt +0 -0
flwr/server/server.py
CHANGED
|
@@ -19,7 +19,6 @@ import concurrent.futures
|
|
|
19
19
|
import io
|
|
20
20
|
import timeit
|
|
21
21
|
from logging import INFO, WARN
|
|
22
|
-
from typing import Optional, Union
|
|
23
22
|
|
|
24
23
|
from flwr.common import (
|
|
25
24
|
Code,
|
|
@@ -43,15 +42,15 @@ from .server_config import ServerConfig
|
|
|
43
42
|
|
|
44
43
|
FitResultsAndFailures = tuple[
|
|
45
44
|
list[tuple[ClientProxy, FitRes]],
|
|
46
|
-
list[
|
|
45
|
+
list[tuple[ClientProxy, FitRes] | BaseException],
|
|
47
46
|
]
|
|
48
47
|
EvaluateResultsAndFailures = tuple[
|
|
49
48
|
list[tuple[ClientProxy, EvaluateRes]],
|
|
50
|
-
list[
|
|
49
|
+
list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
51
50
|
]
|
|
52
51
|
ReconnectResultsAndFailures = tuple[
|
|
53
52
|
list[tuple[ClientProxy, DisconnectRes]],
|
|
54
|
-
list[
|
|
53
|
+
list[tuple[ClientProxy, DisconnectRes] | BaseException],
|
|
55
54
|
]
|
|
56
55
|
|
|
57
56
|
|
|
@@ -62,16 +61,16 @@ class Server:
|
|
|
62
61
|
self,
|
|
63
62
|
*,
|
|
64
63
|
client_manager: ClientManager,
|
|
65
|
-
strategy:
|
|
64
|
+
strategy: Strategy | None = None,
|
|
66
65
|
) -> None:
|
|
67
66
|
self._client_manager: ClientManager = client_manager
|
|
68
67
|
self.parameters: Parameters = Parameters(
|
|
69
68
|
tensors=[], tensor_type="numpy.ndarray"
|
|
70
69
|
)
|
|
71
70
|
self.strategy: Strategy = strategy if strategy is not None else FedAvg()
|
|
72
|
-
self.max_workers:
|
|
71
|
+
self.max_workers: int | None = None
|
|
73
72
|
|
|
74
|
-
def set_max_workers(self, max_workers:
|
|
73
|
+
def set_max_workers(self, max_workers: int | None) -> None:
|
|
75
74
|
"""Set the max_workers used by ThreadPoolExecutor."""
|
|
76
75
|
self.max_workers = max_workers
|
|
77
76
|
|
|
@@ -84,7 +83,7 @@ class Server:
|
|
|
84
83
|
return self._client_manager
|
|
85
84
|
|
|
86
85
|
# pylint: disable=too-many-locals
|
|
87
|
-
def fit(self, num_rounds: int, timeout:
|
|
86
|
+
def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]:
|
|
88
87
|
"""Run federated averaging for a number of rounds."""
|
|
89
88
|
history = History()
|
|
90
89
|
|
|
@@ -161,10 +160,8 @@ class Server:
|
|
|
161
160
|
def evaluate_round(
|
|
162
161
|
self,
|
|
163
162
|
server_round: int,
|
|
164
|
-
timeout:
|
|
165
|
-
) ->
|
|
166
|
-
tuple[Optional[float], dict[str, Scalar], EvaluateResultsAndFailures]
|
|
167
|
-
]:
|
|
163
|
+
timeout: float | None,
|
|
164
|
+
) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None:
|
|
168
165
|
"""Validate current global model on a number of clients."""
|
|
169
166
|
# Get clients and their respective instructions from strategy
|
|
170
167
|
client_instructions = self.strategy.configure_evaluate(
|
|
@@ -198,7 +195,7 @@ class Server:
|
|
|
198
195
|
|
|
199
196
|
# Aggregate the evaluation results
|
|
200
197
|
aggregated_result: tuple[
|
|
201
|
-
|
|
198
|
+
float | None,
|
|
202
199
|
dict[str, Scalar],
|
|
203
200
|
] = self.strategy.aggregate_evaluate(server_round, results, failures)
|
|
204
201
|
|
|
@@ -208,10 +205,8 @@ class Server:
|
|
|
208
205
|
def fit_round(
|
|
209
206
|
self,
|
|
210
207
|
server_round: int,
|
|
211
|
-
timeout:
|
|
212
|
-
) ->
|
|
213
|
-
tuple[Optional[Parameters], dict[str, Scalar], FitResultsAndFailures]
|
|
214
|
-
]:
|
|
208
|
+
timeout: float | None,
|
|
209
|
+
) -> tuple[Parameters | None, dict[str, Scalar], FitResultsAndFailures] | None:
|
|
215
210
|
"""Perform a single round of federated averaging."""
|
|
216
211
|
# Get clients and their respective instructions from strategy
|
|
217
212
|
client_instructions = self.strategy.configure_fit(
|
|
@@ -246,14 +241,14 @@ class Server:
|
|
|
246
241
|
|
|
247
242
|
# Aggregate training results
|
|
248
243
|
aggregated_result: tuple[
|
|
249
|
-
|
|
244
|
+
Parameters | None,
|
|
250
245
|
dict[str, Scalar],
|
|
251
246
|
] = self.strategy.aggregate_fit(server_round, results, failures)
|
|
252
247
|
|
|
253
248
|
parameters_aggregated, metrics_aggregated = aggregated_result
|
|
254
249
|
return parameters_aggregated, metrics_aggregated, (results, failures)
|
|
255
250
|
|
|
256
|
-
def disconnect_all_clients(self, timeout:
|
|
251
|
+
def disconnect_all_clients(self, timeout: float | None) -> None:
|
|
257
252
|
"""Send shutdown signal to all clients."""
|
|
258
253
|
all_clients = self._client_manager.all()
|
|
259
254
|
clients = [all_clients[k] for k in all_clients.keys()]
|
|
@@ -266,11 +261,11 @@ class Server:
|
|
|
266
261
|
)
|
|
267
262
|
|
|
268
263
|
def _get_initial_parameters(
|
|
269
|
-
self, server_round: int, timeout:
|
|
264
|
+
self, server_round: int, timeout: float | None
|
|
270
265
|
) -> Parameters:
|
|
271
266
|
"""Get initial parameters from one of the available clients."""
|
|
272
267
|
# Server-side parameter initialization
|
|
273
|
-
parameters:
|
|
268
|
+
parameters: Parameters | None = self.strategy.initialize_parameters(
|
|
274
269
|
client_manager=self._client_manager
|
|
275
270
|
)
|
|
276
271
|
if parameters is not None:
|
|
@@ -297,8 +292,8 @@ class Server:
|
|
|
297
292
|
|
|
298
293
|
def reconnect_clients(
|
|
299
294
|
client_instructions: list[tuple[ClientProxy, ReconnectIns]],
|
|
300
|
-
max_workers:
|
|
301
|
-
timeout:
|
|
295
|
+
max_workers: int | None,
|
|
296
|
+
timeout: float | None,
|
|
302
297
|
) -> ReconnectResultsAndFailures:
|
|
303
298
|
"""Instruct clients to disconnect and never reconnect."""
|
|
304
299
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
@@ -313,7 +308,7 @@ def reconnect_clients(
|
|
|
313
308
|
|
|
314
309
|
# Gather results
|
|
315
310
|
results: list[tuple[ClientProxy, DisconnectRes]] = []
|
|
316
|
-
failures: list[
|
|
311
|
+
failures: list[tuple[ClientProxy, DisconnectRes] | BaseException] = []
|
|
317
312
|
for future in finished_fs:
|
|
318
313
|
failure = future.exception()
|
|
319
314
|
if failure is not None:
|
|
@@ -327,7 +322,7 @@ def reconnect_clients(
|
|
|
327
322
|
def reconnect_client(
|
|
328
323
|
client: ClientProxy,
|
|
329
324
|
reconnect: ReconnectIns,
|
|
330
|
-
timeout:
|
|
325
|
+
timeout: float | None,
|
|
331
326
|
) -> tuple[ClientProxy, DisconnectRes]:
|
|
332
327
|
"""Instruct client to disconnect and (optionally) reconnect later."""
|
|
333
328
|
disconnect = client.reconnect(
|
|
@@ -340,8 +335,8 @@ def reconnect_client(
|
|
|
340
335
|
|
|
341
336
|
def fit_clients(
|
|
342
337
|
client_instructions: list[tuple[ClientProxy, FitIns]],
|
|
343
|
-
max_workers:
|
|
344
|
-
timeout:
|
|
338
|
+
max_workers: int | None,
|
|
339
|
+
timeout: float | None,
|
|
345
340
|
group_id: int,
|
|
346
341
|
) -> FitResultsAndFailures:
|
|
347
342
|
"""Refine parameters concurrently on all selected clients."""
|
|
@@ -357,7 +352,7 @@ def fit_clients(
|
|
|
357
352
|
|
|
358
353
|
# Gather results
|
|
359
354
|
results: list[tuple[ClientProxy, FitRes]] = []
|
|
360
|
-
failures: list[
|
|
355
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException] = []
|
|
361
356
|
for future in finished_fs:
|
|
362
357
|
_handle_finished_future_after_fit(
|
|
363
358
|
future=future, results=results, failures=failures
|
|
@@ -366,7 +361,7 @@ def fit_clients(
|
|
|
366
361
|
|
|
367
362
|
|
|
368
363
|
def fit_client(
|
|
369
|
-
client: ClientProxy, ins: FitIns, timeout:
|
|
364
|
+
client: ClientProxy, ins: FitIns, timeout: float | None, group_id: int
|
|
370
365
|
) -> tuple[ClientProxy, FitRes]:
|
|
371
366
|
"""Refine parameters on a single client."""
|
|
372
367
|
fit_res = client.fit(ins, timeout=timeout, group_id=group_id)
|
|
@@ -376,7 +371,7 @@ def fit_client(
|
|
|
376
371
|
def _handle_finished_future_after_fit(
|
|
377
372
|
future: concurrent.futures.Future, # type: ignore
|
|
378
373
|
results: list[tuple[ClientProxy, FitRes]],
|
|
379
|
-
failures: list[
|
|
374
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
380
375
|
) -> None:
|
|
381
376
|
"""Convert finished future into either a result or a failure."""
|
|
382
377
|
# Check if there was an exception
|
|
@@ -400,8 +395,8 @@ def _handle_finished_future_after_fit(
|
|
|
400
395
|
|
|
401
396
|
def evaluate_clients(
|
|
402
397
|
client_instructions: list[tuple[ClientProxy, EvaluateIns]],
|
|
403
|
-
max_workers:
|
|
404
|
-
timeout:
|
|
398
|
+
max_workers: int | None,
|
|
399
|
+
timeout: float | None,
|
|
405
400
|
group_id: int,
|
|
406
401
|
) -> EvaluateResultsAndFailures:
|
|
407
402
|
"""Evaluate parameters concurrently on all selected clients."""
|
|
@@ -417,7 +412,7 @@ def evaluate_clients(
|
|
|
417
412
|
|
|
418
413
|
# Gather results
|
|
419
414
|
results: list[tuple[ClientProxy, EvaluateRes]] = []
|
|
420
|
-
failures: list[
|
|
415
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException] = []
|
|
421
416
|
for future in finished_fs:
|
|
422
417
|
_handle_finished_future_after_evaluate(
|
|
423
418
|
future=future, results=results, failures=failures
|
|
@@ -428,7 +423,7 @@ def evaluate_clients(
|
|
|
428
423
|
def evaluate_client(
|
|
429
424
|
client: ClientProxy,
|
|
430
425
|
ins: EvaluateIns,
|
|
431
|
-
timeout:
|
|
426
|
+
timeout: float | None,
|
|
432
427
|
group_id: int,
|
|
433
428
|
) -> tuple[ClientProxy, EvaluateRes]:
|
|
434
429
|
"""Evaluate parameters on a single client."""
|
|
@@ -439,7 +434,7 @@ def evaluate_client(
|
|
|
439
434
|
def _handle_finished_future_after_evaluate(
|
|
440
435
|
future: concurrent.futures.Future, # type: ignore
|
|
441
436
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
442
|
-
failures: list[
|
|
437
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
443
438
|
) -> None:
|
|
444
439
|
"""Convert finished future into either a result or a failure."""
|
|
445
440
|
# Check if there was an exception
|
|
@@ -462,10 +457,10 @@ def _handle_finished_future_after_evaluate(
|
|
|
462
457
|
|
|
463
458
|
|
|
464
459
|
def init_defaults(
|
|
465
|
-
server:
|
|
466
|
-
config:
|
|
467
|
-
strategy:
|
|
468
|
-
client_manager:
|
|
460
|
+
server: Server | None,
|
|
461
|
+
config: ServerConfig | None,
|
|
462
|
+
strategy: Strategy | None,
|
|
463
|
+
client_manager: ClientManager | None,
|
|
469
464
|
) -> tuple[Server, ServerConfig]:
|
|
470
465
|
"""Create server instance if none was given."""
|
|
471
466
|
if server is None:
|
flwr/server/server_app.py
CHANGED
|
@@ -16,9 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import inspect
|
|
19
|
-
from collections.abc import Iterator
|
|
19
|
+
from collections.abc import Callable, Iterator
|
|
20
20
|
from contextlib import contextmanager
|
|
21
|
-
from typing import Callable, Optional
|
|
22
21
|
|
|
23
22
|
from flwr.common import Context
|
|
24
23
|
from flwr.common.logger import warn_deprecated_feature_with_example
|
|
@@ -118,11 +117,11 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
|
118
117
|
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
119
118
|
def __init__(
|
|
120
119
|
self,
|
|
121
|
-
server:
|
|
122
|
-
config:
|
|
123
|
-
strategy:
|
|
124
|
-
client_manager:
|
|
125
|
-
server_fn:
|
|
120
|
+
server: Server | None = None,
|
|
121
|
+
config: ServerConfig | None = None,
|
|
122
|
+
strategy: Strategy | None = None,
|
|
123
|
+
client_manager: ClientManager | None = None,
|
|
124
|
+
server_fn: ServerFn | None = None,
|
|
126
125
|
) -> None:
|
|
127
126
|
if any([server, config, strategy, client_manager]):
|
|
128
127
|
warn_deprecated_feature_with_example(
|
|
@@ -148,7 +147,7 @@ class ServerApp: # pylint: disable=too-many-instance-attributes
|
|
|
148
147
|
self._strategy = strategy
|
|
149
148
|
self._client_manager = client_manager
|
|
150
149
|
self._server_fn = server_fn
|
|
151
|
-
self._main:
|
|
150
|
+
self._main: ServerAppCallable | None = None
|
|
152
151
|
self._lifespan = _empty_lifespan
|
|
153
152
|
|
|
154
153
|
def __call__(self, grid: Grid, context: Context) -> None:
|
flwr/server/server_config.py
CHANGED
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
|
|
22
21
|
@dataclass
|
|
@@ -28,7 +27,7 @@ class ServerConfig:
|
|
|
28
27
|
"""
|
|
29
28
|
|
|
30
29
|
num_rounds: int = 1
|
|
31
|
-
round_timeout:
|
|
30
|
+
round_timeout: float | None = None
|
|
32
31
|
|
|
33
32
|
def __repr__(self) -> str:
|
|
34
33
|
"""Return the string representation of the ServerConfig."""
|
flwr/server/serverapp/app.py
CHANGED
|
@@ -19,7 +19,8 @@ import argparse
|
|
|
19
19
|
from logging import DEBUG, ERROR, INFO
|
|
20
20
|
from pathlib import Path
|
|
21
21
|
from queue import Queue
|
|
22
|
-
|
|
22
|
+
|
|
23
|
+
import grpc
|
|
23
24
|
|
|
24
25
|
from flwr.app.exception import AppExitException
|
|
25
26
|
from flwr.cli.config_utils import get_fab_metadata
|
|
@@ -38,8 +39,7 @@ from flwr.common.constant import (
|
|
|
38
39
|
Status,
|
|
39
40
|
SubStatus,
|
|
40
41
|
)
|
|
41
|
-
from flwr.common.exit import ExitCode,
|
|
42
|
-
from flwr.common.heartbeat import HeartbeatSender, get_grpc_app_heartbeat_fn
|
|
42
|
+
from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
|
|
43
43
|
from flwr.common.logger import (
|
|
44
44
|
log,
|
|
45
45
|
mirror_output_to_queue,
|
|
@@ -66,6 +66,7 @@ from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
|
|
|
66
66
|
from flwr.server.grid.grpc_grid import GrpcGrid
|
|
67
67
|
from flwr.server.run_serverapp import run as run_
|
|
68
68
|
from flwr.supercore.app_utils import start_parent_process_monitor
|
|
69
|
+
from flwr.supercore.heartbeat import HeartbeatSender, make_app_heartbeat_fn_grpc
|
|
69
70
|
from flwr.supercore.superexec.plugin import ServerAppExecPlugin
|
|
70
71
|
from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
|
|
71
72
|
|
|
@@ -73,7 +74,7 @@ from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
|
|
|
73
74
|
def flwr_serverapp() -> None:
|
|
74
75
|
"""Run process-isolated Flower ServerApp."""
|
|
75
76
|
# Capture stdout/stderr
|
|
76
|
-
log_queue: Queue[
|
|
77
|
+
log_queue: Queue[str | None] = Queue()
|
|
77
78
|
mirror_output_to_queue(log_queue)
|
|
78
79
|
|
|
79
80
|
args = _parse_args_run_flwr_serverapp().parse_args()
|
|
@@ -120,21 +121,22 @@ def flwr_serverapp() -> None:
|
|
|
120
121
|
|
|
121
122
|
def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
122
123
|
serverappio_api_address: str,
|
|
123
|
-
log_queue: Queue[
|
|
124
|
+
log_queue: Queue[str | None],
|
|
124
125
|
token: str,
|
|
125
|
-
flwr_dir:
|
|
126
|
-
certificates:
|
|
127
|
-
parent_pid:
|
|
126
|
+
flwr_dir: str | None = None,
|
|
127
|
+
certificates: bytes | None = None,
|
|
128
|
+
parent_pid: int | None = None,
|
|
128
129
|
) -> None:
|
|
129
130
|
"""Run Flower ServerApp process."""
|
|
130
131
|
# Monitor the main process in case of SIGKILL
|
|
131
132
|
if parent_pid is not None:
|
|
132
133
|
start_parent_process_monitor(parent_pid)
|
|
133
134
|
|
|
134
|
-
#
|
|
135
|
+
# Initialize variables for exit handler
|
|
135
136
|
flwr_dir_ = get_flwr_dir(flwr_dir)
|
|
136
137
|
log_uploader = None
|
|
137
138
|
hash_run_id = None
|
|
139
|
+
run = None
|
|
138
140
|
run_status = None
|
|
139
141
|
heartbeat_sender = None
|
|
140
142
|
grid = None
|
|
@@ -143,7 +145,7 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
143
145
|
|
|
144
146
|
def on_exit() -> None:
|
|
145
147
|
# Stop heartbeat sender
|
|
146
|
-
if heartbeat_sender:
|
|
148
|
+
if heartbeat_sender and heartbeat_sender.is_running:
|
|
147
149
|
heartbeat_sender.stop()
|
|
148
150
|
|
|
149
151
|
# Stop log uploader for this run and upload final logs
|
|
@@ -151,7 +153,7 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
151
153
|
stop_log_uploader(log_queue, log_uploader)
|
|
152
154
|
|
|
153
155
|
# Update run status
|
|
154
|
-
if run_status and grid:
|
|
156
|
+
if run and run_status and grid:
|
|
155
157
|
run_status_proto = run_status_to_proto(run_status)
|
|
156
158
|
grid._stub.UpdateRunStatus(
|
|
157
159
|
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
|
@@ -161,7 +163,12 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
161
163
|
if grid:
|
|
162
164
|
grid.close()
|
|
163
165
|
|
|
164
|
-
|
|
166
|
+
# Register signal handlers for graceful shutdown
|
|
167
|
+
register_signal_handlers(
|
|
168
|
+
event_type=EventType.FLWR_SERVERAPP_RUN_LEAVE,
|
|
169
|
+
exit_message="Run stopped by user.",
|
|
170
|
+
exit_handlers=[on_exit],
|
|
171
|
+
)
|
|
165
172
|
|
|
166
173
|
try:
|
|
167
174
|
# Initialize the GrpcGrid
|
|
@@ -171,9 +178,14 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
171
178
|
)
|
|
172
179
|
|
|
173
180
|
# Pull ServerAppInputs from LinkState
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
181
|
+
try:
|
|
182
|
+
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
|
|
183
|
+
req = PullAppInputsRequest(token=token)
|
|
184
|
+
res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
|
|
185
|
+
except grpc.RpcError as ex:
|
|
186
|
+
if ex.code() == grpc.StatusCode.FAILED_PRECONDITION:
|
|
187
|
+
raise RuntimeError("Failed to start the run.") from ex
|
|
188
|
+
raise
|
|
177
189
|
context = context_from_proto(res.context)
|
|
178
190
|
run = run_from_proto(res.run)
|
|
179
191
|
fab = fab_from_proto(res.fab)
|
|
@@ -214,25 +226,15 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
214
226
|
app_path,
|
|
215
227
|
)
|
|
216
228
|
|
|
217
|
-
# Change status to Running
|
|
218
|
-
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
|
|
219
|
-
grid._stub.UpdateRunStatus(
|
|
220
|
-
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
|
221
|
-
)
|
|
222
|
-
|
|
223
229
|
event(
|
|
224
230
|
EventType.FLWR_SERVERAPP_RUN_ENTER,
|
|
225
231
|
event_details={"run-id-hash": hash_run_id},
|
|
226
232
|
)
|
|
227
233
|
|
|
228
234
|
# Set up heartbeat sender
|
|
229
|
-
|
|
230
|
-
grid._stub,
|
|
231
|
-
run.run_id,
|
|
232
|
-
failure_message="Heartbeat failed unexpectedly. The SuperLink could "
|
|
233
|
-
"not find the provided run ID, or the run status is invalid.",
|
|
235
|
+
heartbeat_sender = HeartbeatSender(
|
|
236
|
+
make_app_heartbeat_fn_grpc(grid._stub, token)
|
|
234
237
|
)
|
|
235
|
-
heartbeat_sender = HeartbeatSender(heartbeat_fn)
|
|
236
238
|
heartbeat_sender.start()
|
|
237
239
|
|
|
238
240
|
# Load and run the ServerApp with the Grid
|
|
@@ -256,11 +258,15 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
256
258
|
# Raised when the run is already stopped by the user
|
|
257
259
|
except RunNotRunningException:
|
|
258
260
|
log(INFO, "")
|
|
259
|
-
log(INFO, "Run ID %s stopped.", run.run_id)
|
|
261
|
+
log(INFO, "Run ID %s stopped.", run.run_id) # type: ignore[union-attr]
|
|
260
262
|
log(INFO, "")
|
|
261
263
|
run_status = None
|
|
262
264
|
# No need to update the exit code since this is expected behavior
|
|
263
265
|
|
|
266
|
+
except RuntimeError:
|
|
267
|
+
log(ERROR, "Failed to start run.")
|
|
268
|
+
exit_code = ExitCode.SERVERAPP_RUN_START_REJECTED
|
|
269
|
+
|
|
264
270
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
265
271
|
exc_entity = "ServerApp"
|
|
266
272
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
from .client_manager import ClientManager
|
|
22
21
|
from .server import Server
|
|
@@ -46,7 +45,7 @@ class ServerAppComponents: # pylint: disable=too-many-instance-attributes
|
|
|
46
45
|
will be used.
|
|
47
46
|
"""
|
|
48
47
|
|
|
49
|
-
server:
|
|
50
|
-
config:
|
|
51
|
-
strategy:
|
|
52
|
-
client_manager:
|
|
48
|
+
server: Server | None = None
|
|
49
|
+
config: ServerConfig | None = None
|
|
50
|
+
strategy: Strategy | None = None
|
|
51
|
+
client_manager: ClientManager | None = None
|
|
@@ -15,8 +15,9 @@
|
|
|
15
15
|
"""Aggregation functions for strategy implementations."""
|
|
16
16
|
# mypy: disallow_untyped_calls=False
|
|
17
17
|
|
|
18
|
+
from collections.abc import Callable
|
|
18
19
|
from functools import partial, reduce
|
|
19
|
-
from typing import Any
|
|
20
|
+
from typing import Any
|
|
20
21
|
|
|
21
22
|
import numpy as np
|
|
22
23
|
|
|
@@ -37,7 +38,7 @@ def aggregate(results: list[tuple[NDArrays, int]]) -> NDArrays:
|
|
|
37
38
|
# Compute average weights of each layer
|
|
38
39
|
weights_prime: NDArrays = [
|
|
39
40
|
reduce(np.add, layer_updates) / num_examples_total
|
|
40
|
-
for layer_updates in zip(*weighted_weights)
|
|
41
|
+
for layer_updates in zip(*weighted_weights, strict=True)
|
|
41
42
|
]
|
|
42
43
|
return weights_prime
|
|
43
44
|
|
|
@@ -53,7 +54,7 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
|
|
|
53
54
|
)
|
|
54
55
|
|
|
55
56
|
def _try_inplace(
|
|
56
|
-
x: NDArray, y:
|
|
57
|
+
x: NDArray, y: NDArray | np.float64, np_binary_op: np.ufunc
|
|
57
58
|
) -> NDArray:
|
|
58
59
|
return ( # type: ignore[no-any-return]
|
|
59
60
|
np_binary_op(x, y, out=x)
|
|
@@ -75,7 +76,7 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
|
|
|
75
76
|
)
|
|
76
77
|
params = [
|
|
77
78
|
reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates)
|
|
78
|
-
for layer_updates in zip(params, res)
|
|
79
|
+
for layer_updates in zip(params, res, strict=True)
|
|
79
80
|
]
|
|
80
81
|
|
|
81
82
|
return params
|
|
@@ -88,7 +89,7 @@ def aggregate_median(results: list[tuple[NDArrays, int]]) -> NDArrays:
|
|
|
88
89
|
|
|
89
90
|
# Compute median weight of each layer
|
|
90
91
|
median_w: NDArrays = [
|
|
91
|
-
np.median(np.asarray(layer), axis=0) for layer in zip(*weights)
|
|
92
|
+
np.median(np.asarray(layer), axis=0) for layer in zip(*weights, strict=True)
|
|
92
93
|
]
|
|
93
94
|
return median_w
|
|
94
95
|
|
|
@@ -235,7 +236,7 @@ def aggregate_qffl(
|
|
|
235
236
|
for j in range(1, len(deltas)):
|
|
236
237
|
tmp += scaled_deltas[j][i]
|
|
237
238
|
updates.append(tmp)
|
|
238
|
-
new_parameters = [(u - v) * 1.0 for u, v in zip(parameters, updates)]
|
|
239
|
+
new_parameters = [(u - v) * 1.0 for u, v in zip(parameters, updates, strict=True)]
|
|
239
240
|
return new_parameters
|
|
240
241
|
|
|
241
242
|
|
|
@@ -287,7 +288,7 @@ def aggregate_trimmed_avg(
|
|
|
287
288
|
|
|
288
289
|
trimmed_w: NDArrays = [
|
|
289
290
|
_trim_mean(np.asarray(layer), proportiontocut=proportiontocut)
|
|
290
|
-
for layer in zip(*weights)
|
|
291
|
+
for layer in zip(*weights, strict=True)
|
|
291
292
|
]
|
|
292
293
|
|
|
293
294
|
return trimmed_w
|
|
@@ -299,7 +300,7 @@ def _check_weights_equality(weights1: NDArrays, weights2: NDArrays) -> bool:
|
|
|
299
300
|
return False
|
|
300
301
|
return all(
|
|
301
302
|
np.array_equal(layer_weights1, layer_weights2)
|
|
302
|
-
for layer_weights1, layer_weights2 in zip(weights1, weights2)
|
|
303
|
+
for layer_weights1, layer_weights2 in zip(weights1, weights2, strict=True)
|
|
303
304
|
)
|
|
304
305
|
|
|
305
306
|
|
flwr/server/strategy/bulyan.py
CHANGED
|
@@ -18,8 +18,9 @@ Paper: arxiv.org/abs/1802.07927
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
+
from collections.abc import Callable
|
|
21
22
|
from logging import WARNING
|
|
22
|
-
from typing import Any
|
|
23
|
+
from typing import Any
|
|
23
24
|
|
|
24
25
|
from flwr.common import (
|
|
25
26
|
FitRes,
|
|
@@ -84,18 +85,19 @@ class Bulyan(FedAvg):
|
|
|
84
85
|
min_evaluate_clients: int = 2,
|
|
85
86
|
min_available_clients: int = 2,
|
|
86
87
|
num_malicious_clients: int = 0,
|
|
87
|
-
evaluate_fn:
|
|
88
|
+
evaluate_fn: (
|
|
88
89
|
Callable[
|
|
89
90
|
[int, NDArrays, dict[str, Scalar]],
|
|
90
|
-
|
|
91
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
91
92
|
]
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
93
|
+
| None
|
|
94
|
+
) = None,
|
|
95
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
96
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
95
97
|
accept_failures: bool = True,
|
|
96
|
-
initial_parameters:
|
|
97
|
-
fit_metrics_aggregation_fn:
|
|
98
|
-
evaluate_metrics_aggregation_fn:
|
|
98
|
+
initial_parameters: Parameters | None = None,
|
|
99
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
100
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
99
101
|
first_aggregation_rule: Callable = aggregate_krum, # type: ignore
|
|
100
102
|
**aggregation_rule_kwargs: Any,
|
|
101
103
|
) -> None:
|
|
@@ -126,8 +128,8 @@ class Bulyan(FedAvg):
|
|
|
126
128
|
self,
|
|
127
129
|
server_round: int,
|
|
128
130
|
results: list[tuple[ClientProxy, FitRes]],
|
|
129
|
-
failures: list[
|
|
130
|
-
) -> tuple[
|
|
131
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
132
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
131
133
|
"""Aggregate fit results using Bulyan."""
|
|
132
134
|
if not results:
|
|
133
135
|
return None, {}
|