flwr 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +19 -0
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/auth_plugin.py +4 -5
- flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
- flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
- flwr/cli/build.py +60 -18
- flwr/cli/cli_account_auth_interceptor.py +24 -7
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +52 -9
- flwr/cli/login/login.py +7 -4
- flwr/cli/ls.py +170 -130
- flwr/cli/new/new.py +33 -50
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +10 -5
- flwr/cli/run/run.py +77 -30
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +25 -7
- flwr/cli/supernode/ls.py +16 -8
- flwr/cli/supernode/register.py +9 -4
- flwr/cli/supernode/unregister.py +5 -3
- flwr/cli/utils.py +376 -16
- flwr/client/__init__.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +6 -7
- flwr/client/grpc_rere_client/connection.py +10 -11
- flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
- flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +12 -14
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/clientapp/utils.py +3 -3
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +5 -2
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +19 -0
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +38 -21
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +18 -30
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +5 -4
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +52 -37
- flwr/compat/client/app.py +38 -37
- flwr/compat/client/grpc_client/connection.py +11 -11
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +71 -52
- flwr/proto/control_pb2.pyi +277 -111
- flwr/proto/control_pb2_grpc.py +249 -40
- flwr/proto/control_pb2_grpc.pyi +185 -52
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +14 -4
- flwr/proto/fab_pb2.pyi +59 -31
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +14 -4
- flwr/proto/fleet_pb2.pyi +137 -61
- flwr/proto/fleet_pb2_grpc.py +189 -48
- flwr/proto/fleet_pb2_grpc.pyi +175 -61
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +50 -25
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +38 -17
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +15 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
- flwr/server/superlink/linkstate/linkstate.py +59 -43
- flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
- flwr/server/superlink/linkstate/utils.py +6 -6
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +10 -11
- flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
- flwr/simulation/run_simulation.py +43 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +31 -1
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -2
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +1 -2
- flwr/supercore/object_store/sqlite_object_store.py +8 -7
- flwr/supercore/primitives/asymmetric.py +1 -1
- flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
- flwr/supercore/sqlite_mixin.py +37 -34
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/superlink/auth_plugin/auth_plugin.py +6 -9
- flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +5 -6
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +102 -18
- flwr/supernode/cli/flower_supernode.py +58 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +41 -22
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +158 -42
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.23.0.dist-info/RECORD +0 -439
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
flwr/server/strategy/fedavgm.py
CHANGED
|
@@ -18,8 +18,8 @@ Paper: arxiv.org/pdf/1909.06335.pdf
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
+
from collections.abc import Callable
|
|
21
22
|
from logging import WARNING
|
|
22
|
-
from typing import Callable, Optional, Union
|
|
23
23
|
|
|
24
24
|
from flwr.common import (
|
|
25
25
|
FitRes,
|
|
@@ -82,18 +82,19 @@ class FedAvgM(FedAvg):
|
|
|
82
82
|
min_fit_clients: int = 2,
|
|
83
83
|
min_evaluate_clients: int = 2,
|
|
84
84
|
min_available_clients: int = 2,
|
|
85
|
-
evaluate_fn:
|
|
85
|
+
evaluate_fn: (
|
|
86
86
|
Callable[
|
|
87
87
|
[int, NDArrays, dict[str, Scalar]],
|
|
88
|
-
|
|
88
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
89
89
|
]
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
90
|
+
| None
|
|
91
|
+
) = None,
|
|
92
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
93
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
93
94
|
accept_failures: bool = True,
|
|
94
|
-
initial_parameters:
|
|
95
|
-
fit_metrics_aggregation_fn:
|
|
96
|
-
evaluate_metrics_aggregation_fn:
|
|
95
|
+
initial_parameters: Parameters | None = None,
|
|
96
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
97
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
97
98
|
server_learning_rate: float = 1.0,
|
|
98
99
|
server_momentum: float = 0.0,
|
|
99
100
|
) -> None:
|
|
@@ -116,16 +117,14 @@ class FedAvgM(FedAvg):
|
|
|
116
117
|
self.server_opt: bool = (self.server_momentum != 0.0) or (
|
|
117
118
|
self.server_learning_rate != 1.0
|
|
118
119
|
)
|
|
119
|
-
self.momentum_vector:
|
|
120
|
+
self.momentum_vector: NDArrays | None = None
|
|
120
121
|
|
|
121
122
|
def __repr__(self) -> str:
|
|
122
123
|
"""Compute a string representation of the strategy."""
|
|
123
124
|
rep = f"FedAvgM(accept_failures={self.accept_failures})"
|
|
124
125
|
return rep
|
|
125
126
|
|
|
126
|
-
def initialize_parameters(
|
|
127
|
-
self, client_manager: ClientManager
|
|
128
|
-
) -> Optional[Parameters]:
|
|
127
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
129
128
|
"""Initialize global model parameters."""
|
|
130
129
|
return self.initial_parameters
|
|
131
130
|
|
|
@@ -133,8 +132,8 @@ class FedAvgM(FedAvg):
|
|
|
133
132
|
self,
|
|
134
133
|
server_round: int,
|
|
135
134
|
results: list[tuple[ClientProxy, FitRes]],
|
|
136
|
-
failures: list[
|
|
137
|
-
) -> tuple[
|
|
135
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
136
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
138
137
|
"""Aggregate fit results using weighted average."""
|
|
139
138
|
if not results:
|
|
140
139
|
return None, {}
|
|
@@ -161,7 +160,9 @@ class FedAvgM(FedAvg):
|
|
|
161
160
|
pseudo_gradient: NDArrays = [
|
|
162
161
|
x - y
|
|
163
162
|
for x, y in zip(
|
|
164
|
-
parameters_to_ndarrays(self.initial_parameters),
|
|
163
|
+
parameters_to_ndarrays(self.initial_parameters),
|
|
164
|
+
fedavg_result,
|
|
165
|
+
strict=True,
|
|
165
166
|
)
|
|
166
167
|
]
|
|
167
168
|
if self.server_momentum > 0.0:
|
|
@@ -171,7 +172,9 @@ class FedAvgM(FedAvg):
|
|
|
171
172
|
), "Momentum should have been created on round 1."
|
|
172
173
|
self.momentum_vector = [
|
|
173
174
|
self.server_momentum * x + y
|
|
174
|
-
for x, y in zip(
|
|
175
|
+
for x, y in zip(
|
|
176
|
+
self.momentum_vector, pseudo_gradient, strict=True
|
|
177
|
+
)
|
|
175
178
|
]
|
|
176
179
|
else:
|
|
177
180
|
self.momentum_vector = pseudo_gradient
|
|
@@ -182,7 +185,7 @@ class FedAvgM(FedAvg):
|
|
|
182
185
|
# SGD
|
|
183
186
|
fedavg_result = [
|
|
184
187
|
x - self.server_learning_rate * y
|
|
185
|
-
for x, y in zip(initial_weights, pseudo_gradient)
|
|
188
|
+
for x, y in zip(initial_weights, pseudo_gradient, strict=True)
|
|
186
189
|
]
|
|
187
190
|
# Update current weights
|
|
188
191
|
self.initial_parameters = ndarrays_to_parameters(fedavg_result)
|
|
@@ -19,7 +19,6 @@ Paper: arxiv.org/pdf/1803.01498v1.pdf
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
from logging import WARNING
|
|
22
|
-
from typing import Optional, Union
|
|
23
22
|
|
|
24
23
|
from flwr.common import (
|
|
25
24
|
FitRes,
|
|
@@ -47,8 +46,8 @@ class FedMedian(FedAvg):
|
|
|
47
46
|
self,
|
|
48
47
|
server_round: int,
|
|
49
48
|
results: list[tuple[ClientProxy, FitRes]],
|
|
50
|
-
failures: list[
|
|
51
|
-
) -> tuple[
|
|
49
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
50
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
52
51
|
"""Aggregate fit results using median."""
|
|
53
52
|
if not results:
|
|
54
53
|
return None, {}
|
flwr/server/strategy/fedopt.py
CHANGED
|
@@ -18,7 +18,7 @@ Paper: arxiv.org/abs/2003.00295
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from
|
|
21
|
+
from collections.abc import Callable
|
|
22
22
|
|
|
23
23
|
from flwr.common import (
|
|
24
24
|
MetricsAggregationFn,
|
|
@@ -84,18 +84,19 @@ class FedOpt(FedAvg):
|
|
|
84
84
|
min_fit_clients: int = 2,
|
|
85
85
|
min_evaluate_clients: int = 2,
|
|
86
86
|
min_available_clients: int = 2,
|
|
87
|
-
evaluate_fn:
|
|
87
|
+
evaluate_fn: (
|
|
88
88
|
Callable[
|
|
89
89
|
[int, NDArrays, dict[str, Scalar]],
|
|
90
|
-
|
|
90
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
91
91
|
]
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
92
|
+
| None
|
|
93
|
+
) = None,
|
|
94
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
95
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
95
96
|
accept_failures: bool = True,
|
|
96
97
|
initial_parameters: Parameters,
|
|
97
|
-
fit_metrics_aggregation_fn:
|
|
98
|
-
evaluate_metrics_aggregation_fn:
|
|
98
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
99
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
99
100
|
eta: float = 1e-1,
|
|
100
101
|
eta_l: float = 1e-1,
|
|
101
102
|
beta_1: float = 0.0,
|
|
@@ -122,8 +123,8 @@ class FedOpt(FedAvg):
|
|
|
122
123
|
self.tau = tau
|
|
123
124
|
self.beta_1 = beta_1
|
|
124
125
|
self.beta_2 = beta_2
|
|
125
|
-
self.m_t:
|
|
126
|
-
self.v_t:
|
|
126
|
+
self.m_t: NDArrays | None = None
|
|
127
|
+
self.v_t: NDArrays | None = None
|
|
127
128
|
|
|
128
129
|
def __repr__(self) -> str:
|
|
129
130
|
"""Compute a string representation of the strategy."""
|
flwr/server/strategy/fedprox.py
CHANGED
|
@@ -18,7 +18,7 @@ Paper: arxiv.org/abs/1812.06127
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from
|
|
21
|
+
from collections.abc import Callable
|
|
22
22
|
|
|
23
23
|
from flwr.common import FitIns, MetricsAggregationFn, NDArrays, Parameters, Scalar
|
|
24
24
|
from flwr.server.client_manager import ClientManager
|
|
@@ -111,18 +111,19 @@ class FedProx(FedAvg):
|
|
|
111
111
|
min_fit_clients: int = 2,
|
|
112
112
|
min_evaluate_clients: int = 2,
|
|
113
113
|
min_available_clients: int = 2,
|
|
114
|
-
evaluate_fn:
|
|
114
|
+
evaluate_fn: (
|
|
115
115
|
Callable[
|
|
116
116
|
[int, NDArrays, dict[str, Scalar]],
|
|
117
|
-
|
|
117
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
118
118
|
]
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
119
|
+
| None
|
|
120
|
+
) = None,
|
|
121
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
122
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
122
123
|
accept_failures: bool = True,
|
|
123
|
-
initial_parameters:
|
|
124
|
-
fit_metrics_aggregation_fn:
|
|
125
|
-
evaluate_metrics_aggregation_fn:
|
|
124
|
+
initial_parameters: Parameters | None = None,
|
|
125
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
126
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
126
127
|
proximal_mu: float,
|
|
127
128
|
) -> None:
|
|
128
129
|
super().__init__(
|
|
@@ -16,8 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
Paper: arxiv.org/abs/1803.01498
|
|
18
18
|
"""
|
|
19
|
+
from collections.abc import Callable
|
|
19
20
|
from logging import WARNING
|
|
20
|
-
from typing import Callable, Optional, Union
|
|
21
21
|
|
|
22
22
|
from flwr.common import (
|
|
23
23
|
FitRes,
|
|
@@ -76,18 +76,19 @@ class FedTrimmedAvg(FedAvg):
|
|
|
76
76
|
min_fit_clients: int = 2,
|
|
77
77
|
min_evaluate_clients: int = 2,
|
|
78
78
|
min_available_clients: int = 2,
|
|
79
|
-
evaluate_fn:
|
|
79
|
+
evaluate_fn: (
|
|
80
80
|
Callable[
|
|
81
81
|
[int, NDArrays, dict[str, Scalar]],
|
|
82
|
-
|
|
82
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
83
83
|
]
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
84
|
+
| None
|
|
85
|
+
) = None,
|
|
86
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
87
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
87
88
|
accept_failures: bool = True,
|
|
88
|
-
initial_parameters:
|
|
89
|
-
fit_metrics_aggregation_fn:
|
|
90
|
-
evaluate_metrics_aggregation_fn:
|
|
89
|
+
initial_parameters: Parameters | None = None,
|
|
90
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
91
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
91
92
|
beta: float = 0.2,
|
|
92
93
|
) -> None:
|
|
93
94
|
super().__init__(
|
|
@@ -115,8 +116,8 @@ class FedTrimmedAvg(FedAvg):
|
|
|
115
116
|
self,
|
|
116
117
|
server_round: int,
|
|
117
118
|
results: list[tuple[ClientProxy, FitRes]],
|
|
118
|
-
failures: list[
|
|
119
|
-
) -> tuple[
|
|
119
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
120
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
120
121
|
"""Aggregate fit results using trimmed average."""
|
|
121
122
|
if not results:
|
|
122
123
|
return None, {}
|
|
@@ -16,8 +16,9 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import json
|
|
19
|
+
from collections.abc import Callable
|
|
19
20
|
from logging import WARNING
|
|
20
|
-
from typing import Any,
|
|
21
|
+
from typing import Any, cast
|
|
21
22
|
|
|
22
23
|
from flwr.common import EvaluateRes, FitRes, Parameters, Scalar
|
|
23
24
|
from flwr.common.logger import log
|
|
@@ -32,16 +33,17 @@ class FedXgbBagging(FedAvg):
|
|
|
32
33
|
# pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
|
|
33
34
|
def __init__(
|
|
34
35
|
self,
|
|
35
|
-
evaluate_function:
|
|
36
|
+
evaluate_function: (
|
|
36
37
|
Callable[
|
|
37
38
|
[int, Parameters, dict[str, Scalar]],
|
|
38
|
-
|
|
39
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
39
40
|
]
|
|
40
|
-
|
|
41
|
+
| None
|
|
42
|
+
) = None,
|
|
41
43
|
**kwargs: Any,
|
|
42
44
|
):
|
|
43
45
|
self.evaluate_function = evaluate_function
|
|
44
|
-
self.global_model:
|
|
46
|
+
self.global_model: bytes | None = None
|
|
45
47
|
super().__init__(**kwargs)
|
|
46
48
|
|
|
47
49
|
def __repr__(self) -> str:
|
|
@@ -53,8 +55,8 @@ class FedXgbBagging(FedAvg):
|
|
|
53
55
|
self,
|
|
54
56
|
server_round: int,
|
|
55
57
|
results: list[tuple[ClientProxy, FitRes]],
|
|
56
|
-
failures: list[
|
|
57
|
-
) -> tuple[
|
|
58
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
59
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
58
60
|
"""Aggregate fit results using bagging."""
|
|
59
61
|
if not results:
|
|
60
62
|
return None, {}
|
|
@@ -80,8 +82,8 @@ class FedXgbBagging(FedAvg):
|
|
|
80
82
|
self,
|
|
81
83
|
server_round: int,
|
|
82
84
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
83
|
-
failures: list[
|
|
84
|
-
) -> tuple[
|
|
85
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
86
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
85
87
|
"""Aggregate evaluation metrics using average."""
|
|
86
88
|
if not results:
|
|
87
89
|
return None, {}
|
|
@@ -101,7 +103,7 @@ class FedXgbBagging(FedAvg):
|
|
|
101
103
|
|
|
102
104
|
def evaluate(
|
|
103
105
|
self, server_round: int, parameters: Parameters
|
|
104
|
-
) ->
|
|
106
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
105
107
|
"""Evaluate model parameters using an evaluation function."""
|
|
106
108
|
if self.evaluate_function is None:
|
|
107
109
|
# No evaluation function provided
|
|
@@ -114,7 +116,7 @@ class FedXgbBagging(FedAvg):
|
|
|
114
116
|
|
|
115
117
|
|
|
116
118
|
def aggregate(
|
|
117
|
-
bst_prev_org:
|
|
119
|
+
bst_prev_org: bytes | None,
|
|
118
120
|
bst_curr_org: bytes,
|
|
119
121
|
) -> bytes:
|
|
120
122
|
"""Conduct bagging aggregation for given trees."""
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import WARNING
|
|
19
|
-
from typing import Any,
|
|
19
|
+
from typing import Any, cast
|
|
20
20
|
|
|
21
21
|
from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
|
|
22
22
|
from flwr.common.logger import log
|
|
@@ -34,7 +34,7 @@ class FedXgbCyclic(FedAvg):
|
|
|
34
34
|
self,
|
|
35
35
|
**kwargs: Any,
|
|
36
36
|
):
|
|
37
|
-
self.global_model:
|
|
37
|
+
self.global_model: bytes | None = None
|
|
38
38
|
super().__init__(**kwargs)
|
|
39
39
|
|
|
40
40
|
def __repr__(self) -> str:
|
|
@@ -46,8 +46,8 @@ class FedXgbCyclic(FedAvg):
|
|
|
46
46
|
self,
|
|
47
47
|
server_round: int,
|
|
48
48
|
results: list[tuple[ClientProxy, FitRes]],
|
|
49
|
-
failures: list[
|
|
50
|
-
) -> tuple[
|
|
49
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
50
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
51
51
|
"""Aggregate fit results using bagging."""
|
|
52
52
|
if not results:
|
|
53
53
|
return None, {}
|
|
@@ -70,8 +70,8 @@ class FedXgbCyclic(FedAvg):
|
|
|
70
70
|
self,
|
|
71
71
|
server_round: int,
|
|
72
72
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
73
|
-
failures: list[
|
|
74
|
-
) -> tuple[
|
|
73
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
74
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
75
75
|
"""Aggregate evaluation metrics using average."""
|
|
76
76
|
if not results:
|
|
77
77
|
return None, {}
|
|
@@ -22,7 +22,7 @@ Paper: arxiv.org/abs/2304.07537
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
from logging import WARNING
|
|
25
|
-
from typing import Any
|
|
25
|
+
from typing import Any
|
|
26
26
|
|
|
27
27
|
from flwr.common import FitRes, Scalar, ndarrays_to_parameters, parameters_to_ndarrays
|
|
28
28
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
@@ -57,7 +57,7 @@ class FedXgbNnAvg(FedAvg):
|
|
|
57
57
|
|
|
58
58
|
def evaluate(
|
|
59
59
|
self, server_round: int, parameters: Any
|
|
60
|
-
) ->
|
|
60
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
61
61
|
"""Evaluate model parameters using an evaluation function."""
|
|
62
62
|
if self.evaluate_fn is None:
|
|
63
63
|
# No evaluation function provided
|
|
@@ -72,8 +72,8 @@ class FedXgbNnAvg(FedAvg):
|
|
|
72
72
|
self,
|
|
73
73
|
server_round: int,
|
|
74
74
|
results: list[tuple[ClientProxy, FitRes]],
|
|
75
|
-
failures: list[
|
|
76
|
-
) -> tuple[
|
|
75
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
76
|
+
) -> tuple[Any | None, dict[str, Scalar]]:
|
|
77
77
|
"""Aggregate fit results using weighted average."""
|
|
78
78
|
if not results:
|
|
79
79
|
return None, {}
|
flwr/server/strategy/fedyogi.py
CHANGED
|
@@ -18,7 +18,7 @@ Paper: arxiv.org/abs/2003.00295
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from
|
|
21
|
+
from collections.abc import Callable
|
|
22
22
|
|
|
23
23
|
import numpy as np
|
|
24
24
|
|
|
@@ -91,18 +91,19 @@ class FedYogi(FedOpt):
|
|
|
91
91
|
min_fit_clients: int = 2,
|
|
92
92
|
min_evaluate_clients: int = 2,
|
|
93
93
|
min_available_clients: int = 2,
|
|
94
|
-
evaluate_fn:
|
|
94
|
+
evaluate_fn: (
|
|
95
95
|
Callable[
|
|
96
96
|
[int, NDArrays, dict[str, Scalar]],
|
|
97
|
-
|
|
97
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
98
98
|
]
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
99
|
+
| None
|
|
100
|
+
) = None,
|
|
101
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
102
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
102
103
|
accept_failures: bool = True,
|
|
103
104
|
initial_parameters: Parameters,
|
|
104
|
-
fit_metrics_aggregation_fn:
|
|
105
|
-
evaluate_metrics_aggregation_fn:
|
|
105
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
106
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
106
107
|
eta: float = 1e-2,
|
|
107
108
|
eta_l: float = 0.0316,
|
|
108
109
|
beta_1: float = 0.9,
|
|
@@ -138,8 +139,8 @@ class FedYogi(FedOpt):
|
|
|
138
139
|
self,
|
|
139
140
|
server_round: int,
|
|
140
141
|
results: list[tuple[ClientProxy, FitRes]],
|
|
141
|
-
failures: list[
|
|
142
|
-
) -> tuple[
|
|
142
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
143
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
143
144
|
"""Aggregate fit results using weighted average."""
|
|
144
145
|
fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
|
|
145
146
|
server_round=server_round, results=results, failures=failures
|
|
@@ -151,7 +152,8 @@ class FedYogi(FedOpt):
|
|
|
151
152
|
|
|
152
153
|
# Yogi
|
|
153
154
|
delta_t: NDArrays = [
|
|
154
|
-
x - y
|
|
155
|
+
x - y
|
|
156
|
+
for x, y in zip(fedavg_weights_aggregate, self.current_weights, strict=True)
|
|
155
157
|
]
|
|
156
158
|
|
|
157
159
|
# m_t
|
|
@@ -159,7 +161,7 @@ class FedYogi(FedOpt):
|
|
|
159
161
|
self.m_t = [np.zeros_like(x) for x in delta_t]
|
|
160
162
|
self.m_t = [
|
|
161
163
|
np.multiply(self.beta_1, x) + (1 - self.beta_1) * y
|
|
162
|
-
for x, y in zip(self.m_t, delta_t)
|
|
164
|
+
for x, y in zip(self.m_t, delta_t, strict=True)
|
|
163
165
|
]
|
|
164
166
|
|
|
165
167
|
# v_t
|
|
@@ -167,12 +169,12 @@ class FedYogi(FedOpt):
|
|
|
167
169
|
self.v_t = [np.zeros_like(x) for x in delta_t]
|
|
168
170
|
self.v_t = [
|
|
169
171
|
x - (1.0 - self.beta_2) * np.multiply(y, y) * np.sign(x - np.multiply(y, y))
|
|
170
|
-
for x, y in zip(self.v_t, delta_t)
|
|
172
|
+
for x, y in zip(self.v_t, delta_t, strict=True)
|
|
171
173
|
]
|
|
172
174
|
|
|
173
175
|
new_weights = [
|
|
174
176
|
x + self.eta * y / (np.sqrt(z) + self.tau)
|
|
175
|
-
for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
|
|
177
|
+
for x, y, z in zip(self.current_weights, self.m_t, self.v_t, strict=True)
|
|
176
178
|
]
|
|
177
179
|
|
|
178
180
|
self.current_weights = new_weights
|
flwr/server/strategy/krum.py
CHANGED
|
@@ -20,8 +20,8 @@ Paper: proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-P
|
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
22
|
|
|
23
|
+
from collections.abc import Callable
|
|
23
24
|
from logging import WARNING
|
|
24
|
-
from typing import Callable, Optional, Union
|
|
25
25
|
|
|
26
26
|
from flwr.common import (
|
|
27
27
|
FitRes,
|
|
@@ -85,18 +85,19 @@ class Krum(FedAvg):
|
|
|
85
85
|
min_available_clients: int = 2,
|
|
86
86
|
num_malicious_clients: int = 0,
|
|
87
87
|
num_clients_to_keep: int = 0,
|
|
88
|
-
evaluate_fn:
|
|
88
|
+
evaluate_fn: (
|
|
89
89
|
Callable[
|
|
90
90
|
[int, NDArrays, dict[str, Scalar]],
|
|
91
|
-
|
|
91
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
92
92
|
]
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
93
|
+
| None
|
|
94
|
+
) = None,
|
|
95
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
96
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
96
97
|
accept_failures: bool = True,
|
|
97
|
-
initial_parameters:
|
|
98
|
-
fit_metrics_aggregation_fn:
|
|
99
|
-
evaluate_metrics_aggregation_fn:
|
|
98
|
+
initial_parameters: Parameters | None = None,
|
|
99
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
100
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
100
101
|
) -> None:
|
|
101
102
|
super().__init__(
|
|
102
103
|
fraction_fit=fraction_fit,
|
|
@@ -124,8 +125,8 @@ class Krum(FedAvg):
|
|
|
124
125
|
self,
|
|
125
126
|
server_round: int,
|
|
126
127
|
results: list[tuple[ClientProxy, FitRes]],
|
|
127
|
-
failures: list[
|
|
128
|
-
) -> tuple[
|
|
128
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
129
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
129
130
|
"""Aggregate fit results using Krum."""
|
|
130
131
|
if not results:
|
|
131
132
|
return None, {}
|
flwr/server/strategy/qfedavg.py
CHANGED
|
@@ -18,8 +18,8 @@ Paper: openreview.net/pdf?id=ByexElSYDr
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
+
from collections.abc import Callable
|
|
21
22
|
from logging import WARNING
|
|
22
|
-
from typing import Callable, Optional, Union
|
|
23
23
|
|
|
24
24
|
import numpy as np
|
|
25
25
|
|
|
@@ -58,18 +58,19 @@ class QFedAvg(FedAvg):
|
|
|
58
58
|
min_fit_clients: int = 1,
|
|
59
59
|
min_evaluate_clients: int = 1,
|
|
60
60
|
min_available_clients: int = 1,
|
|
61
|
-
evaluate_fn:
|
|
61
|
+
evaluate_fn: (
|
|
62
62
|
Callable[
|
|
63
63
|
[int, NDArrays, dict[str, Scalar]],
|
|
64
|
-
|
|
64
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
65
65
|
]
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
66
|
+
| None
|
|
67
|
+
) = None,
|
|
68
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
69
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
69
70
|
accept_failures: bool = True,
|
|
70
|
-
initial_parameters:
|
|
71
|
-
fit_metrics_aggregation_fn:
|
|
72
|
-
evaluate_metrics_aggregation_fn:
|
|
71
|
+
initial_parameters: Parameters | None = None,
|
|
72
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
73
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
73
74
|
) -> None:
|
|
74
75
|
super().__init__(
|
|
75
76
|
fraction_fit=fraction_fit,
|
|
@@ -87,7 +88,7 @@ class QFedAvg(FedAvg):
|
|
|
87
88
|
)
|
|
88
89
|
self.learning_rate = qffl_learning_rate
|
|
89
90
|
self.q_param = q_param
|
|
90
|
-
self.pre_weights:
|
|
91
|
+
self.pre_weights: NDArrays | None = None
|
|
91
92
|
|
|
92
93
|
def __repr__(self) -> str:
|
|
93
94
|
"""Compute a string representation of the strategy."""
|
|
@@ -159,8 +160,8 @@ class QFedAvg(FedAvg):
|
|
|
159
160
|
self,
|
|
160
161
|
server_round: int,
|
|
161
162
|
results: list[tuple[ClientProxy, FitRes]],
|
|
162
|
-
failures: list[
|
|
163
|
-
) -> tuple[
|
|
163
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
164
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
164
165
|
"""Aggregate fit results using weighted average."""
|
|
165
166
|
if not results:
|
|
166
167
|
return None, {}
|
|
@@ -199,7 +200,7 @@ class QFedAvg(FedAvg):
|
|
|
199
200
|
# plug in the weight updates into the gradient
|
|
200
201
|
grads = [
|
|
201
202
|
np.multiply((u - v), 1.0 / self.learning_rate)
|
|
202
|
-
for u, v in zip(weights_before, new_weights)
|
|
203
|
+
for u, v in zip(weights_before, new_weights, strict=True)
|
|
203
204
|
]
|
|
204
205
|
deltas.append(
|
|
205
206
|
[np.float_power(loss + 1e-10, self.q_param) * grad for grad in grads]
|
|
@@ -230,8 +231,8 @@ class QFedAvg(FedAvg):
|
|
|
230
231
|
self,
|
|
231
232
|
server_round: int,
|
|
232
233
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
233
|
-
failures: list[
|
|
234
|
-
) -> tuple[
|
|
234
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
235
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
235
236
|
"""Aggregate evaluation losses using weighted average."""
|
|
236
237
|
if not results:
|
|
237
238
|
return None, {}
|
flwr/server/strategy/strategy.py
CHANGED
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
|
-
from typing import Optional, Union
|
|
20
19
|
|
|
21
20
|
from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
|
|
22
21
|
from flwr.server.client_manager import ClientManager
|
|
@@ -27,9 +26,7 @@ class Strategy(ABC):
|
|
|
27
26
|
"""Abstract base class for server strategy implementations."""
|
|
28
27
|
|
|
29
28
|
@abstractmethod
|
|
30
|
-
def initialize_parameters(
|
|
31
|
-
self, client_manager: ClientManager
|
|
32
|
-
) -> Optional[Parameters]:
|
|
29
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
33
30
|
"""Initialize the (global) model parameters.
|
|
34
31
|
|
|
35
32
|
Parameters
|
|
@@ -73,8 +70,8 @@ class Strategy(ABC):
|
|
|
73
70
|
self,
|
|
74
71
|
server_round: int,
|
|
75
72
|
results: list[tuple[ClientProxy, FitRes]],
|
|
76
|
-
failures: list[
|
|
77
|
-
) -> tuple[
|
|
73
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
74
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
78
75
|
"""Aggregate training results.
|
|
79
76
|
|
|
80
77
|
Parameters
|
|
@@ -135,8 +132,8 @@ class Strategy(ABC):
|
|
|
135
132
|
self,
|
|
136
133
|
server_round: int,
|
|
137
134
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
138
|
-
failures: list[
|
|
139
|
-
) -> tuple[
|
|
135
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
136
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
140
137
|
"""Aggregate evaluation results.
|
|
141
138
|
|
|
142
139
|
Parameters
|
|
@@ -164,7 +161,7 @@ class Strategy(ABC):
|
|
|
164
161
|
@abstractmethod
|
|
165
162
|
def evaluate(
|
|
166
163
|
self, server_round: int, parameters: Parameters
|
|
167
|
-
) ->
|
|
164
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
168
165
|
"""Evaluate the current model parameters.
|
|
169
166
|
|
|
170
167
|
This function can be used to perform centralized (i.e., server-side) evaluation
|
|
@@ -15,8 +15,9 @@
|
|
|
15
15
|
"""Fleet API gRPC adapter servicer."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from collections.abc import Callable
|
|
18
19
|
from logging import DEBUG
|
|
19
|
-
from typing import
|
|
20
|
+
from typing import TypeVar
|
|
20
21
|
|
|
21
22
|
import grpc
|
|
22
23
|
from google.protobuf.message import Message as GrpcMessage
|