flwr 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +19 -0
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/auth_plugin.py +4 -5
- flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
- flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
- flwr/cli/build.py +60 -18
- flwr/cli/cli_account_auth_interceptor.py +24 -7
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +52 -9
- flwr/cli/login/login.py +7 -4
- flwr/cli/ls.py +170 -130
- flwr/cli/new/new.py +33 -50
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +10 -5
- flwr/cli/run/run.py +77 -30
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +25 -7
- flwr/cli/supernode/ls.py +16 -8
- flwr/cli/supernode/register.py +9 -4
- flwr/cli/supernode/unregister.py +5 -3
- flwr/cli/utils.py +376 -16
- flwr/client/__init__.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +6 -7
- flwr/client/grpc_rere_client/connection.py +10 -11
- flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
- flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +12 -14
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/clientapp/utils.py +3 -3
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +5 -2
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +19 -0
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +38 -21
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +18 -30
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +5 -4
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +52 -37
- flwr/compat/client/app.py +38 -37
- flwr/compat/client/grpc_client/connection.py +11 -11
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +71 -52
- flwr/proto/control_pb2.pyi +277 -111
- flwr/proto/control_pb2_grpc.py +249 -40
- flwr/proto/control_pb2_grpc.pyi +185 -52
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +14 -4
- flwr/proto/fab_pb2.pyi +59 -31
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +14 -4
- flwr/proto/fleet_pb2.pyi +137 -61
- flwr/proto/fleet_pb2_grpc.py +189 -48
- flwr/proto/fleet_pb2_grpc.pyi +175 -61
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +50 -25
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +38 -17
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +15 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
- flwr/server/superlink/linkstate/linkstate.py +59 -43
- flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
- flwr/server/superlink/linkstate/utils.py +6 -6
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +10 -11
- flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
- flwr/simulation/run_simulation.py +43 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +31 -1
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -2
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +1 -2
- flwr/supercore/object_store/sqlite_object_store.py +8 -7
- flwr/supercore/primitives/asymmetric.py +1 -1
- flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
- flwr/supercore/sqlite_mixin.py +37 -34
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/superlink/auth_plugin/auth_plugin.py +6 -9
- flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +5 -6
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +102 -18
- flwr/supernode/cli/flower_supernode.py +58 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +41 -22
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +158 -42
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.23.0.dist-info/RECORD +0 -439
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -20,7 +20,6 @@ Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
|
|
|
20
20
|
|
|
21
21
|
import math
|
|
22
22
|
from logging import INFO, WARNING
|
|
23
|
-
from typing import Optional, Union
|
|
24
23
|
|
|
25
24
|
import numpy as np
|
|
26
25
|
|
|
@@ -97,7 +96,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
97
96
|
initial_clipping_norm: float = 0.1,
|
|
98
97
|
target_clipped_quantile: float = 0.5,
|
|
99
98
|
clip_norm_lr: float = 0.2,
|
|
100
|
-
clipped_count_stddev:
|
|
99
|
+
clipped_count_stddev: float | None = None,
|
|
101
100
|
) -> None:
|
|
102
101
|
super().__init__()
|
|
103
102
|
|
|
@@ -148,9 +147,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
148
147
|
rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"
|
|
149
148
|
return rep
|
|
150
149
|
|
|
151
|
-
def initialize_parameters(
|
|
152
|
-
self, client_manager: ClientManager
|
|
153
|
-
) -> Optional[Parameters]:
|
|
150
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
154
151
|
"""Initialize global model parameters using given strategy."""
|
|
155
152
|
return self.strategy.initialize_parameters(client_manager)
|
|
156
153
|
|
|
@@ -173,8 +170,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
173
170
|
self,
|
|
174
171
|
server_round: int,
|
|
175
172
|
results: list[tuple[ClientProxy, FitRes]],
|
|
176
|
-
failures: list[
|
|
177
|
-
) -> tuple[
|
|
173
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
174
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
178
175
|
"""Aggregate training results and update clip norms."""
|
|
179
176
|
if failures:
|
|
180
177
|
return None, {}
|
|
@@ -192,7 +189,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
192
189
|
param = parameters_to_ndarrays(res.parameters)
|
|
193
190
|
# Compute and clip update
|
|
194
191
|
model_update = [
|
|
195
|
-
np.subtract(x, y)
|
|
192
|
+
np.subtract(x, y)
|
|
193
|
+
for (x, y) in zip(param, self.current_round_params, strict=True)
|
|
196
194
|
]
|
|
197
195
|
|
|
198
196
|
norm_bit = adaptive_clip_inputs_inplace(model_update, self.clipping_norm)
|
|
@@ -246,14 +244,14 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
246
244
|
self,
|
|
247
245
|
server_round: int,
|
|
248
246
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
249
|
-
failures: list[
|
|
250
|
-
) -> tuple[
|
|
247
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
248
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
251
249
|
"""Aggregate evaluation losses using the given strategy."""
|
|
252
250
|
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
|
253
251
|
|
|
254
252
|
def evaluate(
|
|
255
253
|
self, server_round: int, parameters: Parameters
|
|
256
|
-
) ->
|
|
254
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
257
255
|
"""Evaluate model parameters using an evaluation function from the strategy."""
|
|
258
256
|
return self.strategy.evaluate(server_round, parameters)
|
|
259
257
|
|
|
@@ -316,7 +314,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
316
314
|
initial_clipping_norm: float = 0.1,
|
|
317
315
|
target_clipped_quantile: float = 0.5,
|
|
318
316
|
clip_norm_lr: float = 0.2,
|
|
319
|
-
clipped_count_stddev:
|
|
317
|
+
clipped_count_stddev: float | None = None,
|
|
320
318
|
) -> None:
|
|
321
319
|
super().__init__()
|
|
322
320
|
|
|
@@ -364,9 +362,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
364
362
|
rep = "Differential Privacy Strategy Wrapper (Client-Side Adaptive Clipping)"
|
|
365
363
|
return rep
|
|
366
364
|
|
|
367
|
-
def initialize_parameters(
|
|
368
|
-
self, client_manager: ClientManager
|
|
369
|
-
) -> Optional[Parameters]:
|
|
365
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
370
366
|
"""Initialize global model parameters using given strategy."""
|
|
371
367
|
return self.strategy.initialize_parameters(client_manager)
|
|
372
368
|
|
|
@@ -395,8 +391,8 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
395
391
|
self,
|
|
396
392
|
server_round: int,
|
|
397
393
|
results: list[tuple[ClientProxy, FitRes]],
|
|
398
|
-
failures: list[
|
|
399
|
-
) -> tuple[
|
|
394
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
395
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
400
396
|
"""Aggregate training results and update clip norms."""
|
|
401
397
|
if failures:
|
|
402
398
|
return None, {}
|
|
@@ -458,13 +454,13 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
458
454
|
self,
|
|
459
455
|
server_round: int,
|
|
460
456
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
461
|
-
failures: list[
|
|
462
|
-
) -> tuple[
|
|
457
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
458
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
463
459
|
"""Aggregate evaluation losses using the given strategy."""
|
|
464
460
|
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
|
465
461
|
|
|
466
462
|
def evaluate(
|
|
467
463
|
self, server_round: int, parameters: Parameters
|
|
468
|
-
) ->
|
|
464
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
469
465
|
"""Evaluate model parameters using an evaluation function from the strategy."""
|
|
470
466
|
return self.strategy.evaluate(server_round, parameters)
|
|
@@ -19,7 +19,6 @@ Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
from logging import INFO, WARNING
|
|
22
|
-
from typing import Optional, Union
|
|
23
22
|
|
|
24
23
|
from flwr.common import (
|
|
25
24
|
EvaluateIns,
|
|
@@ -109,9 +108,7 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
|
109
108
|
rep = "Differential Privacy Strategy Wrapper (Server-Side Fixed Clipping)"
|
|
110
109
|
return rep
|
|
111
110
|
|
|
112
|
-
def initialize_parameters(
|
|
113
|
-
self, client_manager: ClientManager
|
|
114
|
-
) -> Optional[Parameters]:
|
|
111
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
115
112
|
"""Initialize global model parameters using given strategy."""
|
|
116
113
|
return self.strategy.initialize_parameters(client_manager)
|
|
117
114
|
|
|
@@ -134,8 +131,8 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
|
134
131
|
self,
|
|
135
132
|
server_round: int,
|
|
136
133
|
results: list[tuple[ClientProxy, FitRes]],
|
|
137
|
-
failures: list[
|
|
138
|
-
) -> tuple[
|
|
134
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
135
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
139
136
|
"""Compute the updates, clip, and pass them for aggregation.
|
|
140
137
|
|
|
141
138
|
Afterward, add noise to the aggregated parameters.
|
|
@@ -192,14 +189,14 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
|
192
189
|
self,
|
|
193
190
|
server_round: int,
|
|
194
191
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
195
|
-
failures: list[
|
|
196
|
-
) -> tuple[
|
|
192
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
193
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
197
194
|
"""Aggregate evaluation losses using the given strategy."""
|
|
198
195
|
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
|
199
196
|
|
|
200
197
|
def evaluate(
|
|
201
198
|
self, server_round: int, parameters: Parameters
|
|
202
|
-
) ->
|
|
199
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
203
200
|
"""Evaluate model parameters using an evaluation function from the strategy."""
|
|
204
201
|
return self.strategy.evaluate(server_round, parameters)
|
|
205
202
|
|
|
@@ -277,9 +274,7 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
|
|
277
274
|
rep = "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)"
|
|
278
275
|
return rep
|
|
279
276
|
|
|
280
|
-
def initialize_parameters(
|
|
281
|
-
self, client_manager: ClientManager
|
|
282
|
-
) -> Optional[Parameters]:
|
|
277
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
283
278
|
"""Initialize global model parameters using given strategy."""
|
|
284
279
|
return self.strategy.initialize_parameters(client_manager)
|
|
285
280
|
|
|
@@ -308,8 +303,8 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
|
|
308
303
|
self,
|
|
309
304
|
server_round: int,
|
|
310
305
|
results: list[tuple[ClientProxy, FitRes]],
|
|
311
|
-
failures: list[
|
|
312
|
-
) -> tuple[
|
|
306
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
307
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
313
308
|
"""Add noise to the aggregated parameters."""
|
|
314
309
|
if failures:
|
|
315
310
|
return None, {}
|
|
@@ -349,13 +344,13 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
|
|
349
344
|
self,
|
|
350
345
|
server_round: int,
|
|
351
346
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
352
|
-
failures: list[
|
|
353
|
-
) -> tuple[
|
|
347
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
348
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
354
349
|
"""Aggregate evaluation losses using the given strategy."""
|
|
355
350
|
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
|
356
351
|
|
|
357
352
|
def evaluate(
|
|
358
353
|
self, server_round: int, parameters: Parameters
|
|
359
|
-
) ->
|
|
354
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
360
355
|
"""Evaluate model parameters using an evaluation function from the strategy."""
|
|
361
356
|
return self.strategy.evaluate(server_round, parameters)
|
|
@@ -19,7 +19,6 @@ Paper: arxiv.org/pdf/1905.03871.pdf
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
import math
|
|
22
|
-
from typing import Optional, Union
|
|
23
22
|
|
|
24
23
|
import numpy as np
|
|
25
24
|
|
|
@@ -49,7 +48,7 @@ class DPFedAvgAdaptive(DPFedAvgFixed):
|
|
|
49
48
|
server_side_noising: bool = True,
|
|
50
49
|
clip_norm_lr: float = 0.2,
|
|
51
50
|
clip_norm_target_quantile: float = 0.5,
|
|
52
|
-
clip_count_stddev:
|
|
51
|
+
clip_count_stddev: float | None = None,
|
|
53
52
|
) -> None:
|
|
54
53
|
warn_deprecated_feature("`DPFedAvgAdaptive` wrapper")
|
|
55
54
|
super().__init__(
|
|
@@ -119,8 +118,8 @@ class DPFedAvgAdaptive(DPFedAvgFixed):
|
|
|
119
118
|
self,
|
|
120
119
|
server_round: int,
|
|
121
120
|
results: list[tuple[ClientProxy, FitRes]],
|
|
122
|
-
failures: list[
|
|
123
|
-
) -> tuple[
|
|
121
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
122
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
124
123
|
"""Aggregate training results as in DPFedAvgFixed and update clip norms."""
|
|
125
124
|
if failures:
|
|
126
125
|
return None, {}
|
|
@@ -18,8 +18,6 @@ Paper: arxiv.org/pdf/1710.06963.pdf
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from typing import Optional, Union
|
|
22
|
-
|
|
23
21
|
from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
|
|
24
22
|
from flwr.common.dp import add_gaussian_noise
|
|
25
23
|
from flwr.common.logger import warn_deprecated_feature
|
|
@@ -72,9 +70,7 @@ class DPFedAvgFixed(Strategy):
|
|
|
72
70
|
self.noise_multiplier * self.clip_norm / (self.num_sampled_clients ** (0.5))
|
|
73
71
|
)
|
|
74
72
|
|
|
75
|
-
def initialize_parameters(
|
|
76
|
-
self, client_manager: ClientManager
|
|
77
|
-
) -> Optional[Parameters]:
|
|
73
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
78
74
|
"""Initialize global model parameters using given strategy."""
|
|
79
75
|
return self.strategy.initialize_parameters(client_manager)
|
|
80
76
|
|
|
@@ -149,8 +145,8 @@ class DPFedAvgFixed(Strategy):
|
|
|
149
145
|
self,
|
|
150
146
|
server_round: int,
|
|
151
147
|
results: list[tuple[ClientProxy, FitRes]],
|
|
152
|
-
failures: list[
|
|
153
|
-
) -> tuple[
|
|
148
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
149
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
154
150
|
"""Aggregate training results using unweighted aggregation."""
|
|
155
151
|
if failures:
|
|
156
152
|
return None, {}
|
|
@@ -170,13 +166,13 @@ class DPFedAvgFixed(Strategy):
|
|
|
170
166
|
self,
|
|
171
167
|
server_round: int,
|
|
172
168
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
173
|
-
failures: list[
|
|
174
|
-
) -> tuple[
|
|
169
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
170
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
175
171
|
"""Aggregate evaluation losses using the given strategy."""
|
|
176
172
|
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
|
177
173
|
|
|
178
174
|
def evaluate(
|
|
179
175
|
self, server_round: int, parameters: Parameters
|
|
180
|
-
) ->
|
|
176
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
181
177
|
"""Evaluate model parameters using an evaluation function from the strategy."""
|
|
182
178
|
return self.strategy.evaluate(server_round, parameters)
|
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
"""Fault-tolerant variant of FedAvg strategy."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from collections.abc import Callable
|
|
18
19
|
from logging import WARNING
|
|
19
|
-
from typing import Callable, Optional, Union
|
|
20
20
|
|
|
21
21
|
from flwr.common import (
|
|
22
22
|
EvaluateRes,
|
|
@@ -47,19 +47,20 @@ class FaultTolerantFedAvg(FedAvg):
|
|
|
47
47
|
min_fit_clients: int = 1,
|
|
48
48
|
min_evaluate_clients: int = 1,
|
|
49
49
|
min_available_clients: int = 1,
|
|
50
|
-
evaluate_fn:
|
|
50
|
+
evaluate_fn: (
|
|
51
51
|
Callable[
|
|
52
52
|
[int, NDArrays, dict[str, Scalar]],
|
|
53
|
-
|
|
53
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
54
54
|
]
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
55
|
+
| None
|
|
56
|
+
) = None,
|
|
57
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
58
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
58
59
|
min_completion_rate_fit: float = 0.5,
|
|
59
60
|
min_completion_rate_evaluate: float = 0.5,
|
|
60
|
-
initial_parameters:
|
|
61
|
-
fit_metrics_aggregation_fn:
|
|
62
|
-
evaluate_metrics_aggregation_fn:
|
|
61
|
+
initial_parameters: Parameters | None = None,
|
|
62
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
63
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
63
64
|
) -> None:
|
|
64
65
|
super().__init__(
|
|
65
66
|
fraction_fit=fraction_fit,
|
|
@@ -86,8 +87,8 @@ class FaultTolerantFedAvg(FedAvg):
|
|
|
86
87
|
self,
|
|
87
88
|
server_round: int,
|
|
88
89
|
results: list[tuple[ClientProxy, FitRes]],
|
|
89
|
-
failures: list[
|
|
90
|
-
) -> tuple[
|
|
90
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
91
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
91
92
|
"""Aggregate fit results using weighted average."""
|
|
92
93
|
if not results:
|
|
93
94
|
return None, {}
|
|
@@ -118,8 +119,8 @@ class FaultTolerantFedAvg(FedAvg):
|
|
|
118
119
|
self,
|
|
119
120
|
server_round: int,
|
|
120
121
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
121
|
-
failures: list[
|
|
122
|
-
) -> tuple[
|
|
122
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
123
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
123
124
|
"""Aggregate evaluation losses using weighted average."""
|
|
124
125
|
if not results:
|
|
125
126
|
return None, {}
|
|
@@ -20,7 +20,7 @@ Paper: arxiv.org/abs/2003.00295
|
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
from
|
|
23
|
+
from collections.abc import Callable
|
|
24
24
|
|
|
25
25
|
import numpy as np
|
|
26
26
|
|
|
@@ -87,16 +87,17 @@ class FedAdagrad(FedOpt):
|
|
|
87
87
|
min_fit_clients: int = 2,
|
|
88
88
|
min_evaluate_clients: int = 2,
|
|
89
89
|
min_available_clients: int = 2,
|
|
90
|
-
evaluate_fn:
|
|
90
|
+
evaluate_fn: (
|
|
91
91
|
Callable[
|
|
92
92
|
[int, NDArrays, dict[str, Scalar]],
|
|
93
|
-
|
|
93
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
94
94
|
]
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
95
|
+
| None
|
|
96
|
+
) = None,
|
|
97
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
98
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
99
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
100
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
100
101
|
accept_failures: bool = True,
|
|
101
102
|
initial_parameters: Parameters,
|
|
102
103
|
eta: float = 1e-1,
|
|
@@ -132,8 +133,8 @@ class FedAdagrad(FedOpt):
|
|
|
132
133
|
self,
|
|
133
134
|
server_round: int,
|
|
134
135
|
results: list[tuple[ClientProxy, FitRes]],
|
|
135
|
-
failures: list[
|
|
136
|
-
) -> tuple[
|
|
136
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
137
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
137
138
|
"""Aggregate fit results using weighted average."""
|
|
138
139
|
fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
|
|
139
140
|
server_round=server_round, results=results, failures=failures
|
|
@@ -145,7 +146,8 @@ class FedAdagrad(FedOpt):
|
|
|
145
146
|
|
|
146
147
|
# Adagrad
|
|
147
148
|
delta_t: NDArrays = [
|
|
148
|
-
x - y
|
|
149
|
+
x - y
|
|
150
|
+
for x, y in zip(fedavg_weights_aggregate, self.current_weights, strict=True)
|
|
149
151
|
]
|
|
150
152
|
|
|
151
153
|
# m_t
|
|
@@ -153,17 +155,19 @@ class FedAdagrad(FedOpt):
|
|
|
153
155
|
self.m_t = [np.zeros_like(x) for x in delta_t]
|
|
154
156
|
self.m_t = [
|
|
155
157
|
np.multiply(self.beta_1, x) + (1 - self.beta_1) * y
|
|
156
|
-
for x, y in zip(self.m_t, delta_t)
|
|
158
|
+
for x, y in zip(self.m_t, delta_t, strict=True)
|
|
157
159
|
]
|
|
158
160
|
|
|
159
161
|
# v_t
|
|
160
162
|
if not self.v_t:
|
|
161
163
|
self.v_t = [np.zeros_like(x) for x in delta_t]
|
|
162
|
-
self.v_t = [
|
|
164
|
+
self.v_t = [
|
|
165
|
+
x + np.multiply(y, y) for x, y in zip(self.v_t, delta_t, strict=True)
|
|
166
|
+
]
|
|
163
167
|
|
|
164
168
|
new_weights = [
|
|
165
169
|
x + self.eta * y / (np.sqrt(z) + self.tau)
|
|
166
|
-
for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
|
|
170
|
+
for x, y, z in zip(self.current_weights, self.m_t, self.v_t, strict=True)
|
|
167
171
|
]
|
|
168
172
|
|
|
169
173
|
self.current_weights = new_weights
|
flwr/server/strategy/fedadam.py
CHANGED
|
@@ -20,7 +20,7 @@ Paper: arxiv.org/abs/2003.00295
|
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
from
|
|
23
|
+
from collections.abc import Callable
|
|
24
24
|
|
|
25
25
|
import numpy as np
|
|
26
26
|
|
|
@@ -91,18 +91,19 @@ class FedAdam(FedOpt):
|
|
|
91
91
|
min_fit_clients: int = 2,
|
|
92
92
|
min_evaluate_clients: int = 2,
|
|
93
93
|
min_available_clients: int = 2,
|
|
94
|
-
evaluate_fn:
|
|
94
|
+
evaluate_fn: (
|
|
95
95
|
Callable[
|
|
96
96
|
[int, NDArrays, dict[str, Scalar]],
|
|
97
|
-
|
|
97
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
98
98
|
]
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
99
|
+
| None
|
|
100
|
+
) = None,
|
|
101
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
102
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
102
103
|
accept_failures: bool = True,
|
|
103
104
|
initial_parameters: Parameters,
|
|
104
|
-
fit_metrics_aggregation_fn:
|
|
105
|
-
evaluate_metrics_aggregation_fn:
|
|
105
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
106
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
106
107
|
eta: float = 1e-1,
|
|
107
108
|
eta_l: float = 1e-1,
|
|
108
109
|
beta_1: float = 0.9,
|
|
@@ -138,8 +139,8 @@ class FedAdam(FedOpt):
|
|
|
138
139
|
self,
|
|
139
140
|
server_round: int,
|
|
140
141
|
results: list[tuple[ClientProxy, FitRes]],
|
|
141
|
-
failures: list[
|
|
142
|
-
) -> tuple[
|
|
142
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
143
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
143
144
|
"""Aggregate fit results using weighted average."""
|
|
144
145
|
fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
|
|
145
146
|
server_round=server_round, results=results, failures=failures
|
|
@@ -151,7 +152,8 @@ class FedAdam(FedOpt):
|
|
|
151
152
|
|
|
152
153
|
# Adam
|
|
153
154
|
delta_t: NDArrays = [
|
|
154
|
-
x - y
|
|
155
|
+
x - y
|
|
156
|
+
for x, y in zip(fedavg_weights_aggregate, self.current_weights, strict=True)
|
|
155
157
|
]
|
|
156
158
|
|
|
157
159
|
# m_t
|
|
@@ -159,7 +161,7 @@ class FedAdam(FedOpt):
|
|
|
159
161
|
self.m_t = [np.zeros_like(x) for x in delta_t]
|
|
160
162
|
self.m_t = [
|
|
161
163
|
np.multiply(self.beta_1, x) + (1 - self.beta_1) * y
|
|
162
|
-
for x, y in zip(self.m_t, delta_t)
|
|
164
|
+
for x, y in zip(self.m_t, delta_t, strict=True)
|
|
163
165
|
]
|
|
164
166
|
|
|
165
167
|
# v_t
|
|
@@ -167,7 +169,7 @@ class FedAdam(FedOpt):
|
|
|
167
169
|
self.v_t = [np.zeros_like(x) for x in delta_t]
|
|
168
170
|
self.v_t = [
|
|
169
171
|
self.beta_2 * x + (1 - self.beta_2) * np.multiply(y, y)
|
|
170
|
-
for x, y in zip(self.v_t, delta_t)
|
|
172
|
+
for x, y in zip(self.v_t, delta_t, strict=True)
|
|
171
173
|
]
|
|
172
174
|
|
|
173
175
|
# Compute the bias-corrected learning rate, `eta_norm` for improving convergence
|
|
@@ -182,7 +184,7 @@ class FedAdam(FedOpt):
|
|
|
182
184
|
|
|
183
185
|
new_weights = [
|
|
184
186
|
x + eta_norm * y / (np.sqrt(z) + self.tau)
|
|
185
|
-
for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
|
|
187
|
+
for x, y, z in zip(self.current_weights, self.m_t, self.v_t, strict=True)
|
|
186
188
|
]
|
|
187
189
|
|
|
188
190
|
self.current_weights = new_weights
|
flwr/server/strategy/fedavg.py
CHANGED
|
@@ -18,8 +18,8 @@ Paper: arxiv.org/abs/1602.05629
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
+
from collections.abc import Callable
|
|
21
22
|
from logging import WARNING
|
|
22
|
-
from typing import Callable, Optional, Union
|
|
23
23
|
|
|
24
24
|
from flwr.common import (
|
|
25
25
|
EvaluateIns,
|
|
@@ -97,18 +97,19 @@ class FedAvg(Strategy):
|
|
|
97
97
|
min_fit_clients: int = 2,
|
|
98
98
|
min_evaluate_clients: int = 2,
|
|
99
99
|
min_available_clients: int = 2,
|
|
100
|
-
evaluate_fn:
|
|
100
|
+
evaluate_fn: (
|
|
101
101
|
Callable[
|
|
102
102
|
[int, NDArrays, dict[str, Scalar]],
|
|
103
|
-
|
|
103
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
104
104
|
]
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
105
|
+
| None
|
|
106
|
+
) = None,
|
|
107
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
108
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
108
109
|
accept_failures: bool = True,
|
|
109
|
-
initial_parameters:
|
|
110
|
-
fit_metrics_aggregation_fn:
|
|
111
|
-
evaluate_metrics_aggregation_fn:
|
|
110
|
+
initial_parameters: Parameters | None = None,
|
|
111
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
112
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
112
113
|
inplace: bool = True,
|
|
113
114
|
) -> None:
|
|
114
115
|
super().__init__()
|
|
@@ -148,9 +149,7 @@ class FedAvg(Strategy):
|
|
|
148
149
|
num_clients = int(num_available_clients * self.fraction_evaluate)
|
|
149
150
|
return max(num_clients, self.min_evaluate_clients), self.min_available_clients
|
|
150
151
|
|
|
151
|
-
def initialize_parameters(
|
|
152
|
-
self, client_manager: ClientManager
|
|
153
|
-
) -> Optional[Parameters]:
|
|
152
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
154
153
|
"""Initialize global model parameters."""
|
|
155
154
|
initial_parameters = self.initial_parameters
|
|
156
155
|
self.initial_parameters = None # Don't keep initial parameters in memory
|
|
@@ -158,7 +157,7 @@ class FedAvg(Strategy):
|
|
|
158
157
|
|
|
159
158
|
def evaluate(
|
|
160
159
|
self, server_round: int, parameters: Parameters
|
|
161
|
-
) ->
|
|
160
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
162
161
|
"""Evaluate model parameters using an evaluation function."""
|
|
163
162
|
if self.evaluate_fn is None:
|
|
164
163
|
# No evaluation function provided
|
|
@@ -221,8 +220,8 @@ class FedAvg(Strategy):
|
|
|
221
220
|
self,
|
|
222
221
|
server_round: int,
|
|
223
222
|
results: list[tuple[ClientProxy, FitRes]],
|
|
224
|
-
failures: list[
|
|
225
|
-
) -> tuple[
|
|
223
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
224
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
226
225
|
"""Aggregate fit results using weighted average."""
|
|
227
226
|
if not results:
|
|
228
227
|
return None, {}
|
|
@@ -257,8 +256,8 @@ class FedAvg(Strategy):
|
|
|
257
256
|
self,
|
|
258
257
|
server_round: int,
|
|
259
258
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
260
|
-
failures: list[
|
|
261
|
-
) -> tuple[
|
|
259
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
260
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
262
261
|
"""Aggregate evaluation losses using weighted average."""
|
|
263
262
|
if not results:
|
|
264
263
|
return None, {}
|
|
@@ -18,7 +18,8 @@ Paper: arxiv.org/abs/1602.05629
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from
|
|
21
|
+
from collections.abc import Callable
|
|
22
|
+
from typing import cast
|
|
22
23
|
|
|
23
24
|
import numpy as np
|
|
24
25
|
|
|
@@ -79,16 +80,17 @@ class FedAvgAndroid(Strategy):
|
|
|
79
80
|
min_fit_clients: int = 2,
|
|
80
81
|
min_evaluate_clients: int = 2,
|
|
81
82
|
min_available_clients: int = 2,
|
|
82
|
-
evaluate_fn:
|
|
83
|
+
evaluate_fn: (
|
|
83
84
|
Callable[
|
|
84
85
|
[int, NDArrays, dict[str, Scalar]],
|
|
85
|
-
|
|
86
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
86
87
|
]
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
88
|
+
| None
|
|
89
|
+
) = None,
|
|
90
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
91
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
90
92
|
accept_failures: bool = True,
|
|
91
|
-
initial_parameters:
|
|
93
|
+
initial_parameters: Parameters | None = None,
|
|
92
94
|
) -> None:
|
|
93
95
|
super().__init__()
|
|
94
96
|
self.min_fit_clients = min_fit_clients
|
|
@@ -117,9 +119,7 @@ class FedAvgAndroid(Strategy):
|
|
|
117
119
|
num_clients = int(num_available_clients * self.fraction_evaluate)
|
|
118
120
|
return max(num_clients, self.min_evaluate_clients), self.min_available_clients
|
|
119
121
|
|
|
120
|
-
def initialize_parameters(
|
|
121
|
-
self, client_manager: ClientManager
|
|
122
|
-
) -> Optional[Parameters]:
|
|
122
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
123
123
|
"""Initialize global model parameters."""
|
|
124
124
|
initial_parameters = self.initial_parameters
|
|
125
125
|
self.initial_parameters = None # Don't keep initial parameters in memory
|
|
@@ -127,7 +127,7 @@ class FedAvgAndroid(Strategy):
|
|
|
127
127
|
|
|
128
128
|
def evaluate(
|
|
129
129
|
self, server_round: int, parameters: Parameters
|
|
130
|
-
) ->
|
|
130
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
131
131
|
"""Evaluate model parameters using an evaluation function."""
|
|
132
132
|
if self.evaluate_fn is None:
|
|
133
133
|
# No evaluation function provided
|
|
@@ -190,8 +190,8 @@ class FedAvgAndroid(Strategy):
|
|
|
190
190
|
self,
|
|
191
191
|
server_round: int,
|
|
192
192
|
results: list[tuple[ClientProxy, FitRes]],
|
|
193
|
-
failures: list[
|
|
194
|
-
) -> tuple[
|
|
193
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
194
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
195
195
|
"""Aggregate fit results using weighted average."""
|
|
196
196
|
if not results:
|
|
197
197
|
return None, {}
|
|
@@ -209,8 +209,8 @@ class FedAvgAndroid(Strategy):
|
|
|
209
209
|
self,
|
|
210
210
|
server_round: int,
|
|
211
211
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
212
|
-
failures: list[
|
|
213
|
-
) -> tuple[
|
|
212
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
213
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
214
214
|
"""Aggregate evaluation losses using weighted average."""
|
|
215
215
|
if not results:
|
|
216
216
|
return None, {}
|