flwr 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +34 -1
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +94 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
- flwr/cli/build.py +166 -53
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +54 -11
- flwr/cli/login/login.py +41 -27
- flwr/cli/ls.py +177 -133
- flwr/cli/new/new.py +175 -40
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +12 -7
- flwr/cli/run/run.py +82 -31
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +27 -9
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +268 -0
- flwr/cli/supernode/register.py +190 -0
- flwr/cli/supernode/unregister.py +140 -0
- flwr/cli/utils.py +464 -81
- flwr/client/__init__.py +2 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +12 -15
- flwr/client/grpc_rere_client/connection.py +68 -41
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +94 -51
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +1 -2
- flwr/{client → clientapp}/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/{client/clientapp → clientapp}/utils.py +4 -4
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +56 -13
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +39 -10
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +48 -31
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +6 -6
- flwr/common/record/arrayrecord.py +18 -21
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +9 -6
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +59 -43
- flwr/compat/client/app.py +39 -38
- flwr/compat/client/grpc_client/connection.py +13 -13
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +72 -40
- flwr/proto/control_pb2.pyi +319 -87
- flwr/proto/control_pb2_grpc.py +339 -28
- flwr/proto/control_pb2_grpc.pyi +209 -37
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +24 -10
- flwr/proto/fab_pb2.pyi +68 -20
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +45 -27
- flwr/proto/fleet_pb2.pyi +186 -70
- flwr/proto/fleet_pb2_grpc.py +277 -66
- flwr/proto/fleet_pb2_grpc.pyi +201 -55
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +16 -4
- flwr/proto/node_pb2.pyi +77 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +173 -127
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
- flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
- flwr/server/superlink/fleet/vce/vce_api.py +32 -13
- flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
- flwr/server/superlink/linkstate/linkstate.py +161 -62
- flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
- flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
- flwr/server/superlink/linkstate/utils.py +9 -60
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +11 -12
- flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
- flwr/simulation/run_simulation.py +44 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +52 -0
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -6
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +27 -8
- flwr/supercore/object_store/sqlite_object_store.py +253 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
- flwr/supercore/sqlite_mixin.py +159 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +88 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +18 -17
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +239 -63
- flwr/supernode/cli/flower_supernode.py +74 -26
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +43 -24
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +175 -51
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.22.0.dist-info/RECORD +0 -428
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
flwr/server/serverapp/app.py
CHANGED
|
@@ -19,7 +19,8 @@ import argparse
|
|
|
19
19
|
from logging import DEBUG, ERROR, INFO
|
|
20
20
|
from pathlib import Path
|
|
21
21
|
from queue import Queue
|
|
22
|
-
|
|
22
|
+
|
|
23
|
+
import grpc
|
|
23
24
|
|
|
24
25
|
from flwr.app.exception import AppExitException
|
|
25
26
|
from flwr.cli.config_utils import get_fab_metadata
|
|
@@ -38,8 +39,7 @@ from flwr.common.constant import (
|
|
|
38
39
|
Status,
|
|
39
40
|
SubStatus,
|
|
40
41
|
)
|
|
41
|
-
from flwr.common.exit import ExitCode,
|
|
42
|
-
from flwr.common.heartbeat import HeartbeatSender, get_grpc_app_heartbeat_fn
|
|
42
|
+
from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
|
|
43
43
|
from flwr.common.logger import (
|
|
44
44
|
log,
|
|
45
45
|
mirror_output_to_queue,
|
|
@@ -66,6 +66,7 @@ from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
|
|
|
66
66
|
from flwr.server.grid.grpc_grid import GrpcGrid
|
|
67
67
|
from flwr.server.run_serverapp import run as run_
|
|
68
68
|
from flwr.supercore.app_utils import start_parent_process_monitor
|
|
69
|
+
from flwr.supercore.heartbeat import HeartbeatSender, make_app_heartbeat_fn_grpc
|
|
69
70
|
from flwr.supercore.superexec.plugin import ServerAppExecPlugin
|
|
70
71
|
from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
|
|
71
72
|
|
|
@@ -73,7 +74,7 @@ from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
|
|
|
73
74
|
def flwr_serverapp() -> None:
|
|
74
75
|
"""Run process-isolated Flower ServerApp."""
|
|
75
76
|
# Capture stdout/stderr
|
|
76
|
-
log_queue: Queue[
|
|
77
|
+
log_queue: Queue[str | None] = Queue()
|
|
77
78
|
mirror_output_to_queue(log_queue)
|
|
78
79
|
|
|
79
80
|
args = _parse_args_run_flwr_serverapp().parse_args()
|
|
@@ -120,21 +121,22 @@ def flwr_serverapp() -> None:
|
|
|
120
121
|
|
|
121
122
|
def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
122
123
|
serverappio_api_address: str,
|
|
123
|
-
log_queue: Queue[
|
|
124
|
+
log_queue: Queue[str | None],
|
|
124
125
|
token: str,
|
|
125
|
-
flwr_dir:
|
|
126
|
-
certificates:
|
|
127
|
-
parent_pid:
|
|
126
|
+
flwr_dir: str | None = None,
|
|
127
|
+
certificates: bytes | None = None,
|
|
128
|
+
parent_pid: int | None = None,
|
|
128
129
|
) -> None:
|
|
129
130
|
"""Run Flower ServerApp process."""
|
|
130
131
|
# Monitor the main process in case of SIGKILL
|
|
131
132
|
if parent_pid is not None:
|
|
132
133
|
start_parent_process_monitor(parent_pid)
|
|
133
134
|
|
|
134
|
-
#
|
|
135
|
+
# Initialize variables for exit handler
|
|
135
136
|
flwr_dir_ = get_flwr_dir(flwr_dir)
|
|
136
137
|
log_uploader = None
|
|
137
138
|
hash_run_id = None
|
|
139
|
+
run = None
|
|
138
140
|
run_status = None
|
|
139
141
|
heartbeat_sender = None
|
|
140
142
|
grid = None
|
|
@@ -143,7 +145,7 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
143
145
|
|
|
144
146
|
def on_exit() -> None:
|
|
145
147
|
# Stop heartbeat sender
|
|
146
|
-
if heartbeat_sender:
|
|
148
|
+
if heartbeat_sender and heartbeat_sender.is_running:
|
|
147
149
|
heartbeat_sender.stop()
|
|
148
150
|
|
|
149
151
|
# Stop log uploader for this run and upload final logs
|
|
@@ -151,7 +153,7 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
151
153
|
stop_log_uploader(log_queue, log_uploader)
|
|
152
154
|
|
|
153
155
|
# Update run status
|
|
154
|
-
if run_status and grid:
|
|
156
|
+
if run and run_status and grid:
|
|
155
157
|
run_status_proto = run_status_to_proto(run_status)
|
|
156
158
|
grid._stub.UpdateRunStatus(
|
|
157
159
|
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
|
@@ -161,7 +163,12 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
161
163
|
if grid:
|
|
162
164
|
grid.close()
|
|
163
165
|
|
|
164
|
-
|
|
166
|
+
# Register signal handlers for graceful shutdown
|
|
167
|
+
register_signal_handlers(
|
|
168
|
+
event_type=EventType.FLWR_SERVERAPP_RUN_LEAVE,
|
|
169
|
+
exit_message="Run stopped by user.",
|
|
170
|
+
exit_handlers=[on_exit],
|
|
171
|
+
)
|
|
165
172
|
|
|
166
173
|
try:
|
|
167
174
|
# Initialize the GrpcGrid
|
|
@@ -171,9 +178,14 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
171
178
|
)
|
|
172
179
|
|
|
173
180
|
# Pull ServerAppInputs from LinkState
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
181
|
+
try:
|
|
182
|
+
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
|
|
183
|
+
req = PullAppInputsRequest(token=token)
|
|
184
|
+
res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
|
|
185
|
+
except grpc.RpcError as ex:
|
|
186
|
+
if ex.code() == grpc.StatusCode.FAILED_PRECONDITION:
|
|
187
|
+
raise RuntimeError("Failed to start the run.") from ex
|
|
188
|
+
raise
|
|
177
189
|
context = context_from_proto(res.context)
|
|
178
190
|
run = run_from_proto(res.run)
|
|
179
191
|
fab = fab_from_proto(res.fab)
|
|
@@ -214,25 +226,15 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
214
226
|
app_path,
|
|
215
227
|
)
|
|
216
228
|
|
|
217
|
-
# Change status to Running
|
|
218
|
-
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
|
|
219
|
-
grid._stub.UpdateRunStatus(
|
|
220
|
-
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
|
221
|
-
)
|
|
222
|
-
|
|
223
229
|
event(
|
|
224
230
|
EventType.FLWR_SERVERAPP_RUN_ENTER,
|
|
225
231
|
event_details={"run-id-hash": hash_run_id},
|
|
226
232
|
)
|
|
227
233
|
|
|
228
234
|
# Set up heartbeat sender
|
|
229
|
-
|
|
230
|
-
grid._stub,
|
|
231
|
-
run.run_id,
|
|
232
|
-
failure_message="Heartbeat failed unexpectedly. The SuperLink could "
|
|
233
|
-
"not find the provided run ID, or the run status is invalid.",
|
|
235
|
+
heartbeat_sender = HeartbeatSender(
|
|
236
|
+
make_app_heartbeat_fn_grpc(grid._stub, token)
|
|
234
237
|
)
|
|
235
|
-
heartbeat_sender = HeartbeatSender(heartbeat_fn)
|
|
236
238
|
heartbeat_sender.start()
|
|
237
239
|
|
|
238
240
|
# Load and run the ServerApp with the Grid
|
|
@@ -256,11 +258,15 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
256
258
|
# Raised when the run is already stopped by the user
|
|
257
259
|
except RunNotRunningException:
|
|
258
260
|
log(INFO, "")
|
|
259
|
-
log(INFO, "Run ID %s stopped.", run.run_id)
|
|
261
|
+
log(INFO, "Run ID %s stopped.", run.run_id) # type: ignore[union-attr]
|
|
260
262
|
log(INFO, "")
|
|
261
263
|
run_status = None
|
|
262
264
|
# No need to update the exit code since this is expected behavior
|
|
263
265
|
|
|
266
|
+
except RuntimeError:
|
|
267
|
+
log(ERROR, "Failed to start run.")
|
|
268
|
+
exit_code = ExitCode.SERVERAPP_RUN_START_REJECTED
|
|
269
|
+
|
|
264
270
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
265
271
|
exc_entity = "ServerApp"
|
|
266
272
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from dataclasses import dataclass
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
from .client_manager import ClientManager
|
|
22
21
|
from .server import Server
|
|
@@ -46,7 +45,7 @@ class ServerAppComponents: # pylint: disable=too-many-instance-attributes
|
|
|
46
45
|
will be used.
|
|
47
46
|
"""
|
|
48
47
|
|
|
49
|
-
server:
|
|
50
|
-
config:
|
|
51
|
-
strategy:
|
|
52
|
-
client_manager:
|
|
48
|
+
server: Server | None = None
|
|
49
|
+
config: ServerConfig | None = None
|
|
50
|
+
strategy: Strategy | None = None
|
|
51
|
+
client_manager: ClientManager | None = None
|
|
@@ -15,8 +15,9 @@
|
|
|
15
15
|
"""Aggregation functions for strategy implementations."""
|
|
16
16
|
# mypy: disallow_untyped_calls=False
|
|
17
17
|
|
|
18
|
+
from collections.abc import Callable
|
|
18
19
|
from functools import partial, reduce
|
|
19
|
-
from typing import Any
|
|
20
|
+
from typing import Any
|
|
20
21
|
|
|
21
22
|
import numpy as np
|
|
22
23
|
|
|
@@ -37,7 +38,7 @@ def aggregate(results: list[tuple[NDArrays, int]]) -> NDArrays:
|
|
|
37
38
|
# Compute average weights of each layer
|
|
38
39
|
weights_prime: NDArrays = [
|
|
39
40
|
reduce(np.add, layer_updates) / num_examples_total
|
|
40
|
-
for layer_updates in zip(*weighted_weights)
|
|
41
|
+
for layer_updates in zip(*weighted_weights, strict=True)
|
|
41
42
|
]
|
|
42
43
|
return weights_prime
|
|
43
44
|
|
|
@@ -53,7 +54,7 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
|
|
|
53
54
|
)
|
|
54
55
|
|
|
55
56
|
def _try_inplace(
|
|
56
|
-
x: NDArray, y:
|
|
57
|
+
x: NDArray, y: NDArray | np.float64, np_binary_op: np.ufunc
|
|
57
58
|
) -> NDArray:
|
|
58
59
|
return ( # type: ignore[no-any-return]
|
|
59
60
|
np_binary_op(x, y, out=x)
|
|
@@ -75,7 +76,7 @@ def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
|
|
|
75
76
|
)
|
|
76
77
|
params = [
|
|
77
78
|
reduce(partial(_try_inplace, np_binary_op=np.add), layer_updates)
|
|
78
|
-
for layer_updates in zip(params, res)
|
|
79
|
+
for layer_updates in zip(params, res, strict=True)
|
|
79
80
|
]
|
|
80
81
|
|
|
81
82
|
return params
|
|
@@ -88,7 +89,7 @@ def aggregate_median(results: list[tuple[NDArrays, int]]) -> NDArrays:
|
|
|
88
89
|
|
|
89
90
|
# Compute median weight of each layer
|
|
90
91
|
median_w: NDArrays = [
|
|
91
|
-
np.median(np.asarray(layer), axis=0) for layer in zip(*weights)
|
|
92
|
+
np.median(np.asarray(layer), axis=0) for layer in zip(*weights, strict=True)
|
|
92
93
|
]
|
|
93
94
|
return median_w
|
|
94
95
|
|
|
@@ -235,7 +236,7 @@ def aggregate_qffl(
|
|
|
235
236
|
for j in range(1, len(deltas)):
|
|
236
237
|
tmp += scaled_deltas[j][i]
|
|
237
238
|
updates.append(tmp)
|
|
238
|
-
new_parameters = [(u - v) * 1.0 for u, v in zip(parameters, updates)]
|
|
239
|
+
new_parameters = [(u - v) * 1.0 for u, v in zip(parameters, updates, strict=True)]
|
|
239
240
|
return new_parameters
|
|
240
241
|
|
|
241
242
|
|
|
@@ -287,7 +288,7 @@ def aggregate_trimmed_avg(
|
|
|
287
288
|
|
|
288
289
|
trimmed_w: NDArrays = [
|
|
289
290
|
_trim_mean(np.asarray(layer), proportiontocut=proportiontocut)
|
|
290
|
-
for layer in zip(*weights)
|
|
291
|
+
for layer in zip(*weights, strict=True)
|
|
291
292
|
]
|
|
292
293
|
|
|
293
294
|
return trimmed_w
|
|
@@ -299,7 +300,7 @@ def _check_weights_equality(weights1: NDArrays, weights2: NDArrays) -> bool:
|
|
|
299
300
|
return False
|
|
300
301
|
return all(
|
|
301
302
|
np.array_equal(layer_weights1, layer_weights2)
|
|
302
|
-
for layer_weights1, layer_weights2 in zip(weights1, weights2)
|
|
303
|
+
for layer_weights1, layer_weights2 in zip(weights1, weights2, strict=True)
|
|
303
304
|
)
|
|
304
305
|
|
|
305
306
|
|
flwr/server/strategy/bulyan.py
CHANGED
|
@@ -18,8 +18,9 @@ Paper: arxiv.org/abs/1802.07927
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
+
from collections.abc import Callable
|
|
21
22
|
from logging import WARNING
|
|
22
|
-
from typing import Any
|
|
23
|
+
from typing import Any
|
|
23
24
|
|
|
24
25
|
from flwr.common import (
|
|
25
26
|
FitRes,
|
|
@@ -84,18 +85,19 @@ class Bulyan(FedAvg):
|
|
|
84
85
|
min_evaluate_clients: int = 2,
|
|
85
86
|
min_available_clients: int = 2,
|
|
86
87
|
num_malicious_clients: int = 0,
|
|
87
|
-
evaluate_fn:
|
|
88
|
+
evaluate_fn: (
|
|
88
89
|
Callable[
|
|
89
90
|
[int, NDArrays, dict[str, Scalar]],
|
|
90
|
-
|
|
91
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
91
92
|
]
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
93
|
+
| None
|
|
94
|
+
) = None,
|
|
95
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
96
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
95
97
|
accept_failures: bool = True,
|
|
96
|
-
initial_parameters:
|
|
97
|
-
fit_metrics_aggregation_fn:
|
|
98
|
-
evaluate_metrics_aggregation_fn:
|
|
98
|
+
initial_parameters: Parameters | None = None,
|
|
99
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
100
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
99
101
|
first_aggregation_rule: Callable = aggregate_krum, # type: ignore
|
|
100
102
|
**aggregation_rule_kwargs: Any,
|
|
101
103
|
) -> None:
|
|
@@ -126,8 +128,8 @@ class Bulyan(FedAvg):
|
|
|
126
128
|
self,
|
|
127
129
|
server_round: int,
|
|
128
130
|
results: list[tuple[ClientProxy, FitRes]],
|
|
129
|
-
failures: list[
|
|
130
|
-
) -> tuple[
|
|
131
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
132
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
131
133
|
"""Aggregate fit results using Bulyan."""
|
|
132
134
|
if not results:
|
|
133
135
|
return None, {}
|
|
@@ -20,7 +20,6 @@ Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
|
|
|
20
20
|
|
|
21
21
|
import math
|
|
22
22
|
from logging import INFO, WARNING
|
|
23
|
-
from typing import Optional, Union
|
|
24
23
|
|
|
25
24
|
import numpy as np
|
|
26
25
|
|
|
@@ -97,7 +96,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
97
96
|
initial_clipping_norm: float = 0.1,
|
|
98
97
|
target_clipped_quantile: float = 0.5,
|
|
99
98
|
clip_norm_lr: float = 0.2,
|
|
100
|
-
clipped_count_stddev:
|
|
99
|
+
clipped_count_stddev: float | None = None,
|
|
101
100
|
) -> None:
|
|
102
101
|
super().__init__()
|
|
103
102
|
|
|
@@ -148,9 +147,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
148
147
|
rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"
|
|
149
148
|
return rep
|
|
150
149
|
|
|
151
|
-
def initialize_parameters(
|
|
152
|
-
self, client_manager: ClientManager
|
|
153
|
-
) -> Optional[Parameters]:
|
|
150
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
154
151
|
"""Initialize global model parameters using given strategy."""
|
|
155
152
|
return self.strategy.initialize_parameters(client_manager)
|
|
156
153
|
|
|
@@ -173,8 +170,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
173
170
|
self,
|
|
174
171
|
server_round: int,
|
|
175
172
|
results: list[tuple[ClientProxy, FitRes]],
|
|
176
|
-
failures: list[
|
|
177
|
-
) -> tuple[
|
|
173
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
174
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
178
175
|
"""Aggregate training results and update clip norms."""
|
|
179
176
|
if failures:
|
|
180
177
|
return None, {}
|
|
@@ -192,7 +189,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
192
189
|
param = parameters_to_ndarrays(res.parameters)
|
|
193
190
|
# Compute and clip update
|
|
194
191
|
model_update = [
|
|
195
|
-
np.subtract(x, y)
|
|
192
|
+
np.subtract(x, y)
|
|
193
|
+
for (x, y) in zip(param, self.current_round_params, strict=True)
|
|
196
194
|
]
|
|
197
195
|
|
|
198
196
|
norm_bit = adaptive_clip_inputs_inplace(model_update, self.clipping_norm)
|
|
@@ -246,14 +244,14 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
246
244
|
self,
|
|
247
245
|
server_round: int,
|
|
248
246
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
249
|
-
failures: list[
|
|
250
|
-
) -> tuple[
|
|
247
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
248
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
251
249
|
"""Aggregate evaluation losses using the given strategy."""
|
|
252
250
|
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
|
253
251
|
|
|
254
252
|
def evaluate(
|
|
255
253
|
self, server_round: int, parameters: Parameters
|
|
256
|
-
) ->
|
|
254
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
257
255
|
"""Evaluate model parameters using an evaluation function from the strategy."""
|
|
258
256
|
return self.strategy.evaluate(server_round, parameters)
|
|
259
257
|
|
|
@@ -316,7 +314,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
316
314
|
initial_clipping_norm: float = 0.1,
|
|
317
315
|
target_clipped_quantile: float = 0.5,
|
|
318
316
|
clip_norm_lr: float = 0.2,
|
|
319
|
-
clipped_count_stddev:
|
|
317
|
+
clipped_count_stddev: float | None = None,
|
|
320
318
|
) -> None:
|
|
321
319
|
super().__init__()
|
|
322
320
|
|
|
@@ -364,9 +362,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
364
362
|
rep = "Differential Privacy Strategy Wrapper (Client-Side Adaptive Clipping)"
|
|
365
363
|
return rep
|
|
366
364
|
|
|
367
|
-
def initialize_parameters(
|
|
368
|
-
self, client_manager: ClientManager
|
|
369
|
-
) -> Optional[Parameters]:
|
|
365
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
370
366
|
"""Initialize global model parameters using given strategy."""
|
|
371
367
|
return self.strategy.initialize_parameters(client_manager)
|
|
372
368
|
|
|
@@ -395,8 +391,8 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
395
391
|
self,
|
|
396
392
|
server_round: int,
|
|
397
393
|
results: list[tuple[ClientProxy, FitRes]],
|
|
398
|
-
failures: list[
|
|
399
|
-
) -> tuple[
|
|
394
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
395
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
400
396
|
"""Aggregate training results and update clip norms."""
|
|
401
397
|
if failures:
|
|
402
398
|
return None, {}
|
|
@@ -458,13 +454,13 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
458
454
|
self,
|
|
459
455
|
server_round: int,
|
|
460
456
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
461
|
-
failures: list[
|
|
462
|
-
) -> tuple[
|
|
457
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
458
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
463
459
|
"""Aggregate evaluation losses using the given strategy."""
|
|
464
460
|
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
|
465
461
|
|
|
466
462
|
def evaluate(
|
|
467
463
|
self, server_round: int, parameters: Parameters
|
|
468
|
-
) ->
|
|
464
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
469
465
|
"""Evaluate model parameters using an evaluation function from the strategy."""
|
|
470
466
|
return self.strategy.evaluate(server_round, parameters)
|
|
@@ -19,7 +19,6 @@ Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
from logging import INFO, WARNING
|
|
22
|
-
from typing import Optional, Union
|
|
23
22
|
|
|
24
23
|
from flwr.common import (
|
|
25
24
|
EvaluateIns,
|
|
@@ -109,9 +108,7 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
|
109
108
|
rep = "Differential Privacy Strategy Wrapper (Server-Side Fixed Clipping)"
|
|
110
109
|
return rep
|
|
111
110
|
|
|
112
|
-
def initialize_parameters(
|
|
113
|
-
self, client_manager: ClientManager
|
|
114
|
-
) -> Optional[Parameters]:
|
|
111
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
115
112
|
"""Initialize global model parameters using given strategy."""
|
|
116
113
|
return self.strategy.initialize_parameters(client_manager)
|
|
117
114
|
|
|
@@ -134,8 +131,8 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
|
134
131
|
self,
|
|
135
132
|
server_round: int,
|
|
136
133
|
results: list[tuple[ClientProxy, FitRes]],
|
|
137
|
-
failures: list[
|
|
138
|
-
) -> tuple[
|
|
134
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
135
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
139
136
|
"""Compute the updates, clip, and pass them for aggregation.
|
|
140
137
|
|
|
141
138
|
Afterward, add noise to the aggregated parameters.
|
|
@@ -192,14 +189,14 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
|
192
189
|
self,
|
|
193
190
|
server_round: int,
|
|
194
191
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
195
|
-
failures: list[
|
|
196
|
-
) -> tuple[
|
|
192
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
193
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
197
194
|
"""Aggregate evaluation losses using the given strategy."""
|
|
198
195
|
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
|
199
196
|
|
|
200
197
|
def evaluate(
|
|
201
198
|
self, server_round: int, parameters: Parameters
|
|
202
|
-
) ->
|
|
199
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
203
200
|
"""Evaluate model parameters using an evaluation function from the strategy."""
|
|
204
201
|
return self.strategy.evaluate(server_round, parameters)
|
|
205
202
|
|
|
@@ -277,9 +274,7 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
|
|
277
274
|
rep = "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)"
|
|
278
275
|
return rep
|
|
279
276
|
|
|
280
|
-
def initialize_parameters(
|
|
281
|
-
self, client_manager: ClientManager
|
|
282
|
-
) -> Optional[Parameters]:
|
|
277
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
283
278
|
"""Initialize global model parameters using given strategy."""
|
|
284
279
|
return self.strategy.initialize_parameters(client_manager)
|
|
285
280
|
|
|
@@ -308,8 +303,8 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
|
|
308
303
|
self,
|
|
309
304
|
server_round: int,
|
|
310
305
|
results: list[tuple[ClientProxy, FitRes]],
|
|
311
|
-
failures: list[
|
|
312
|
-
) -> tuple[
|
|
306
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
307
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
313
308
|
"""Add noise to the aggregated parameters."""
|
|
314
309
|
if failures:
|
|
315
310
|
return None, {}
|
|
@@ -349,13 +344,13 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
|
|
349
344
|
self,
|
|
350
345
|
server_round: int,
|
|
351
346
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
352
|
-
failures: list[
|
|
353
|
-
) -> tuple[
|
|
347
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
348
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
354
349
|
"""Aggregate evaluation losses using the given strategy."""
|
|
355
350
|
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
|
356
351
|
|
|
357
352
|
def evaluate(
|
|
358
353
|
self, server_round: int, parameters: Parameters
|
|
359
|
-
) ->
|
|
354
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
360
355
|
"""Evaluate model parameters using an evaluation function from the strategy."""
|
|
361
356
|
return self.strategy.evaluate(server_round, parameters)
|
|
@@ -19,7 +19,6 @@ Paper: arxiv.org/pdf/1905.03871.pdf
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
import math
|
|
22
|
-
from typing import Optional, Union
|
|
23
22
|
|
|
24
23
|
import numpy as np
|
|
25
24
|
|
|
@@ -49,7 +48,7 @@ class DPFedAvgAdaptive(DPFedAvgFixed):
|
|
|
49
48
|
server_side_noising: bool = True,
|
|
50
49
|
clip_norm_lr: float = 0.2,
|
|
51
50
|
clip_norm_target_quantile: float = 0.5,
|
|
52
|
-
clip_count_stddev:
|
|
51
|
+
clip_count_stddev: float | None = None,
|
|
53
52
|
) -> None:
|
|
54
53
|
warn_deprecated_feature("`DPFedAvgAdaptive` wrapper")
|
|
55
54
|
super().__init__(
|
|
@@ -119,8 +118,8 @@ class DPFedAvgAdaptive(DPFedAvgFixed):
|
|
|
119
118
|
self,
|
|
120
119
|
server_round: int,
|
|
121
120
|
results: list[tuple[ClientProxy, FitRes]],
|
|
122
|
-
failures: list[
|
|
123
|
-
) -> tuple[
|
|
121
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
122
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
124
123
|
"""Aggregate training results as in DPFedAvgFixed and update clip norms."""
|
|
125
124
|
if failures:
|
|
126
125
|
return None, {}
|
|
@@ -18,8 +18,6 @@ Paper: arxiv.org/pdf/1710.06963.pdf
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from typing import Optional, Union
|
|
22
|
-
|
|
23
21
|
from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
|
|
24
22
|
from flwr.common.dp import add_gaussian_noise
|
|
25
23
|
from flwr.common.logger import warn_deprecated_feature
|
|
@@ -72,9 +70,7 @@ class DPFedAvgFixed(Strategy):
|
|
|
72
70
|
self.noise_multiplier * self.clip_norm / (self.num_sampled_clients ** (0.5))
|
|
73
71
|
)
|
|
74
72
|
|
|
75
|
-
def initialize_parameters(
|
|
76
|
-
self, client_manager: ClientManager
|
|
77
|
-
) -> Optional[Parameters]:
|
|
73
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
78
74
|
"""Initialize global model parameters using given strategy."""
|
|
79
75
|
return self.strategy.initialize_parameters(client_manager)
|
|
80
76
|
|
|
@@ -149,8 +145,8 @@ class DPFedAvgFixed(Strategy):
|
|
|
149
145
|
self,
|
|
150
146
|
server_round: int,
|
|
151
147
|
results: list[tuple[ClientProxy, FitRes]],
|
|
152
|
-
failures: list[
|
|
153
|
-
) -> tuple[
|
|
148
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
149
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
154
150
|
"""Aggregate training results using unweighted aggregation."""
|
|
155
151
|
if failures:
|
|
156
152
|
return None, {}
|
|
@@ -170,13 +166,13 @@ class DPFedAvgFixed(Strategy):
|
|
|
170
166
|
self,
|
|
171
167
|
server_round: int,
|
|
172
168
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
173
|
-
failures: list[
|
|
174
|
-
) -> tuple[
|
|
169
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
170
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
175
171
|
"""Aggregate evaluation losses using the given strategy."""
|
|
176
172
|
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
|
177
173
|
|
|
178
174
|
def evaluate(
|
|
179
175
|
self, server_round: int, parameters: Parameters
|
|
180
|
-
) ->
|
|
176
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
181
177
|
"""Evaluate model parameters using an evaluation function from the strategy."""
|
|
182
178
|
return self.strategy.evaluate(server_round, parameters)
|
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
"""Fault-tolerant variant of FedAvg strategy."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from collections.abc import Callable
|
|
18
19
|
from logging import WARNING
|
|
19
|
-
from typing import Callable, Optional, Union
|
|
20
20
|
|
|
21
21
|
from flwr.common import (
|
|
22
22
|
EvaluateRes,
|
|
@@ -47,19 +47,20 @@ class FaultTolerantFedAvg(FedAvg):
|
|
|
47
47
|
min_fit_clients: int = 1,
|
|
48
48
|
min_evaluate_clients: int = 1,
|
|
49
49
|
min_available_clients: int = 1,
|
|
50
|
-
evaluate_fn:
|
|
50
|
+
evaluate_fn: (
|
|
51
51
|
Callable[
|
|
52
52
|
[int, NDArrays, dict[str, Scalar]],
|
|
53
|
-
|
|
53
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
54
54
|
]
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
55
|
+
| None
|
|
56
|
+
) = None,
|
|
57
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
58
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
58
59
|
min_completion_rate_fit: float = 0.5,
|
|
59
60
|
min_completion_rate_evaluate: float = 0.5,
|
|
60
|
-
initial_parameters:
|
|
61
|
-
fit_metrics_aggregation_fn:
|
|
62
|
-
evaluate_metrics_aggregation_fn:
|
|
61
|
+
initial_parameters: Parameters | None = None,
|
|
62
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
63
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
63
64
|
) -> None:
|
|
64
65
|
super().__init__(
|
|
65
66
|
fraction_fit=fraction_fit,
|
|
@@ -86,8 +87,8 @@ class FaultTolerantFedAvg(FedAvg):
|
|
|
86
87
|
self,
|
|
87
88
|
server_round: int,
|
|
88
89
|
results: list[tuple[ClientProxy, FitRes]],
|
|
89
|
-
failures: list[
|
|
90
|
-
) -> tuple[
|
|
90
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
91
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
91
92
|
"""Aggregate fit results using weighted average."""
|
|
92
93
|
if not results:
|
|
93
94
|
return None, {}
|
|
@@ -118,8 +119,8 @@ class FaultTolerantFedAvg(FedAvg):
|
|
|
118
119
|
self,
|
|
119
120
|
server_round: int,
|
|
120
121
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
121
|
-
failures: list[
|
|
122
|
-
) -> tuple[
|
|
122
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
123
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
123
124
|
"""Aggregate evaluation losses using weighted average."""
|
|
124
125
|
if not results:
|
|
125
126
|
return None, {}
|