flwr 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +34 -1
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +94 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
- flwr/cli/build.py +166 -53
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +54 -11
- flwr/cli/login/login.py +41 -27
- flwr/cli/ls.py +177 -133
- flwr/cli/new/new.py +175 -40
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +12 -7
- flwr/cli/run/run.py +82 -31
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +27 -9
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +268 -0
- flwr/cli/supernode/register.py +190 -0
- flwr/cli/supernode/unregister.py +140 -0
- flwr/cli/utils.py +464 -81
- flwr/client/__init__.py +2 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +12 -15
- flwr/client/grpc_rere_client/connection.py +68 -41
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +94 -51
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +1 -2
- flwr/{client → clientapp}/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/{client/clientapp → clientapp}/utils.py +4 -4
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +56 -13
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +39 -10
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +48 -31
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +6 -6
- flwr/common/record/arrayrecord.py +18 -21
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +9 -6
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +59 -43
- flwr/compat/client/app.py +39 -38
- flwr/compat/client/grpc_client/connection.py +13 -13
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +72 -40
- flwr/proto/control_pb2.pyi +319 -87
- flwr/proto/control_pb2_grpc.py +339 -28
- flwr/proto/control_pb2_grpc.pyi +209 -37
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +24 -10
- flwr/proto/fab_pb2.pyi +68 -20
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +45 -27
- flwr/proto/fleet_pb2.pyi +186 -70
- flwr/proto/fleet_pb2_grpc.py +277 -66
- flwr/proto/fleet_pb2_grpc.pyi +201 -55
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +16 -4
- flwr/proto/node_pb2.pyi +77 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +173 -127
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
- flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
- flwr/server/superlink/fleet/vce/vce_api.py +32 -13
- flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
- flwr/server/superlink/linkstate/linkstate.py +161 -62
- flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
- flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
- flwr/server/superlink/linkstate/utils.py +9 -60
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +11 -12
- flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
- flwr/simulation/run_simulation.py +44 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +52 -0
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -6
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +27 -8
- flwr/supercore/object_store/sqlite_object_store.py +253 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
- flwr/supercore/sqlite_mixin.py +159 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +88 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +18 -17
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +239 -63
- flwr/supernode/cli/flower_supernode.py +74 -26
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +43 -24
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +175 -51
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.22.0.dist-info/RECORD +0 -428
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -20,7 +20,7 @@ Paper: arxiv.org/abs/2003.00295
|
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
from
|
|
23
|
+
from collections.abc import Callable
|
|
24
24
|
|
|
25
25
|
import numpy as np
|
|
26
26
|
|
|
@@ -87,16 +87,17 @@ class FedAdagrad(FedOpt):
|
|
|
87
87
|
min_fit_clients: int = 2,
|
|
88
88
|
min_evaluate_clients: int = 2,
|
|
89
89
|
min_available_clients: int = 2,
|
|
90
|
-
evaluate_fn:
|
|
90
|
+
evaluate_fn: (
|
|
91
91
|
Callable[
|
|
92
92
|
[int, NDArrays, dict[str, Scalar]],
|
|
93
|
-
|
|
93
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
94
94
|
]
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
95
|
+
| None
|
|
96
|
+
) = None,
|
|
97
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
98
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
99
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
100
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
100
101
|
accept_failures: bool = True,
|
|
101
102
|
initial_parameters: Parameters,
|
|
102
103
|
eta: float = 1e-1,
|
|
@@ -132,8 +133,8 @@ class FedAdagrad(FedOpt):
|
|
|
132
133
|
self,
|
|
133
134
|
server_round: int,
|
|
134
135
|
results: list[tuple[ClientProxy, FitRes]],
|
|
135
|
-
failures: list[
|
|
136
|
-
) -> tuple[
|
|
136
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
137
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
137
138
|
"""Aggregate fit results using weighted average."""
|
|
138
139
|
fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
|
|
139
140
|
server_round=server_round, results=results, failures=failures
|
|
@@ -145,7 +146,8 @@ class FedAdagrad(FedOpt):
|
|
|
145
146
|
|
|
146
147
|
# Adagrad
|
|
147
148
|
delta_t: NDArrays = [
|
|
148
|
-
x - y
|
|
149
|
+
x - y
|
|
150
|
+
for x, y in zip(fedavg_weights_aggregate, self.current_weights, strict=True)
|
|
149
151
|
]
|
|
150
152
|
|
|
151
153
|
# m_t
|
|
@@ -153,17 +155,19 @@ class FedAdagrad(FedOpt):
|
|
|
153
155
|
self.m_t = [np.zeros_like(x) for x in delta_t]
|
|
154
156
|
self.m_t = [
|
|
155
157
|
np.multiply(self.beta_1, x) + (1 - self.beta_1) * y
|
|
156
|
-
for x, y in zip(self.m_t, delta_t)
|
|
158
|
+
for x, y in zip(self.m_t, delta_t, strict=True)
|
|
157
159
|
]
|
|
158
160
|
|
|
159
161
|
# v_t
|
|
160
162
|
if not self.v_t:
|
|
161
163
|
self.v_t = [np.zeros_like(x) for x in delta_t]
|
|
162
|
-
self.v_t = [
|
|
164
|
+
self.v_t = [
|
|
165
|
+
x + np.multiply(y, y) for x, y in zip(self.v_t, delta_t, strict=True)
|
|
166
|
+
]
|
|
163
167
|
|
|
164
168
|
new_weights = [
|
|
165
169
|
x + self.eta * y / (np.sqrt(z) + self.tau)
|
|
166
|
-
for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
|
|
170
|
+
for x, y, z in zip(self.current_weights, self.m_t, self.v_t, strict=True)
|
|
167
171
|
]
|
|
168
172
|
|
|
169
173
|
self.current_weights = new_weights
|
flwr/server/strategy/fedadam.py
CHANGED
|
@@ -20,7 +20,7 @@ Paper: arxiv.org/abs/2003.00295
|
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
from
|
|
23
|
+
from collections.abc import Callable
|
|
24
24
|
|
|
25
25
|
import numpy as np
|
|
26
26
|
|
|
@@ -91,18 +91,19 @@ class FedAdam(FedOpt):
|
|
|
91
91
|
min_fit_clients: int = 2,
|
|
92
92
|
min_evaluate_clients: int = 2,
|
|
93
93
|
min_available_clients: int = 2,
|
|
94
|
-
evaluate_fn:
|
|
94
|
+
evaluate_fn: (
|
|
95
95
|
Callable[
|
|
96
96
|
[int, NDArrays, dict[str, Scalar]],
|
|
97
|
-
|
|
97
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
98
98
|
]
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
99
|
+
| None
|
|
100
|
+
) = None,
|
|
101
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
102
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
102
103
|
accept_failures: bool = True,
|
|
103
104
|
initial_parameters: Parameters,
|
|
104
|
-
fit_metrics_aggregation_fn:
|
|
105
|
-
evaluate_metrics_aggregation_fn:
|
|
105
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
106
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
106
107
|
eta: float = 1e-1,
|
|
107
108
|
eta_l: float = 1e-1,
|
|
108
109
|
beta_1: float = 0.9,
|
|
@@ -138,8 +139,8 @@ class FedAdam(FedOpt):
|
|
|
138
139
|
self,
|
|
139
140
|
server_round: int,
|
|
140
141
|
results: list[tuple[ClientProxy, FitRes]],
|
|
141
|
-
failures: list[
|
|
142
|
-
) -> tuple[
|
|
142
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
143
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
143
144
|
"""Aggregate fit results using weighted average."""
|
|
144
145
|
fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
|
|
145
146
|
server_round=server_round, results=results, failures=failures
|
|
@@ -151,7 +152,8 @@ class FedAdam(FedOpt):
|
|
|
151
152
|
|
|
152
153
|
# Adam
|
|
153
154
|
delta_t: NDArrays = [
|
|
154
|
-
x - y
|
|
155
|
+
x - y
|
|
156
|
+
for x, y in zip(fedavg_weights_aggregate, self.current_weights, strict=True)
|
|
155
157
|
]
|
|
156
158
|
|
|
157
159
|
# m_t
|
|
@@ -159,7 +161,7 @@ class FedAdam(FedOpt):
|
|
|
159
161
|
self.m_t = [np.zeros_like(x) for x in delta_t]
|
|
160
162
|
self.m_t = [
|
|
161
163
|
np.multiply(self.beta_1, x) + (1 - self.beta_1) * y
|
|
162
|
-
for x, y in zip(self.m_t, delta_t)
|
|
164
|
+
for x, y in zip(self.m_t, delta_t, strict=True)
|
|
163
165
|
]
|
|
164
166
|
|
|
165
167
|
# v_t
|
|
@@ -167,7 +169,7 @@ class FedAdam(FedOpt):
|
|
|
167
169
|
self.v_t = [np.zeros_like(x) for x in delta_t]
|
|
168
170
|
self.v_t = [
|
|
169
171
|
self.beta_2 * x + (1 - self.beta_2) * np.multiply(y, y)
|
|
170
|
-
for x, y in zip(self.v_t, delta_t)
|
|
172
|
+
for x, y in zip(self.v_t, delta_t, strict=True)
|
|
171
173
|
]
|
|
172
174
|
|
|
173
175
|
# Compute the bias-corrected learning rate, `eta_norm` for improving convergence
|
|
@@ -182,7 +184,7 @@ class FedAdam(FedOpt):
|
|
|
182
184
|
|
|
183
185
|
new_weights = [
|
|
184
186
|
x + eta_norm * y / (np.sqrt(z) + self.tau)
|
|
185
|
-
for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
|
|
187
|
+
for x, y, z in zip(self.current_weights, self.m_t, self.v_t, strict=True)
|
|
186
188
|
]
|
|
187
189
|
|
|
188
190
|
self.current_weights = new_weights
|
flwr/server/strategy/fedavg.py
CHANGED
|
@@ -18,8 +18,8 @@ Paper: arxiv.org/abs/1602.05629
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
+
from collections.abc import Callable
|
|
21
22
|
from logging import WARNING
|
|
22
|
-
from typing import Callable, Optional, Union
|
|
23
23
|
|
|
24
24
|
from flwr.common import (
|
|
25
25
|
EvaluateIns,
|
|
@@ -97,18 +97,19 @@ class FedAvg(Strategy):
|
|
|
97
97
|
min_fit_clients: int = 2,
|
|
98
98
|
min_evaluate_clients: int = 2,
|
|
99
99
|
min_available_clients: int = 2,
|
|
100
|
-
evaluate_fn:
|
|
100
|
+
evaluate_fn: (
|
|
101
101
|
Callable[
|
|
102
102
|
[int, NDArrays, dict[str, Scalar]],
|
|
103
|
-
|
|
103
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
104
104
|
]
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
105
|
+
| None
|
|
106
|
+
) = None,
|
|
107
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
108
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
108
109
|
accept_failures: bool = True,
|
|
109
|
-
initial_parameters:
|
|
110
|
-
fit_metrics_aggregation_fn:
|
|
111
|
-
evaluate_metrics_aggregation_fn:
|
|
110
|
+
initial_parameters: Parameters | None = None,
|
|
111
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
112
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
112
113
|
inplace: bool = True,
|
|
113
114
|
) -> None:
|
|
114
115
|
super().__init__()
|
|
@@ -148,9 +149,7 @@ class FedAvg(Strategy):
|
|
|
148
149
|
num_clients = int(num_available_clients * self.fraction_evaluate)
|
|
149
150
|
return max(num_clients, self.min_evaluate_clients), self.min_available_clients
|
|
150
151
|
|
|
151
|
-
def initialize_parameters(
|
|
152
|
-
self, client_manager: ClientManager
|
|
153
|
-
) -> Optional[Parameters]:
|
|
152
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
154
153
|
"""Initialize global model parameters."""
|
|
155
154
|
initial_parameters = self.initial_parameters
|
|
156
155
|
self.initial_parameters = None # Don't keep initial parameters in memory
|
|
@@ -158,7 +157,7 @@ class FedAvg(Strategy):
|
|
|
158
157
|
|
|
159
158
|
def evaluate(
|
|
160
159
|
self, server_round: int, parameters: Parameters
|
|
161
|
-
) ->
|
|
160
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
162
161
|
"""Evaluate model parameters using an evaluation function."""
|
|
163
162
|
if self.evaluate_fn is None:
|
|
164
163
|
# No evaluation function provided
|
|
@@ -221,8 +220,8 @@ class FedAvg(Strategy):
|
|
|
221
220
|
self,
|
|
222
221
|
server_round: int,
|
|
223
222
|
results: list[tuple[ClientProxy, FitRes]],
|
|
224
|
-
failures: list[
|
|
225
|
-
) -> tuple[
|
|
223
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
224
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
226
225
|
"""Aggregate fit results using weighted average."""
|
|
227
226
|
if not results:
|
|
228
227
|
return None, {}
|
|
@@ -257,8 +256,8 @@ class FedAvg(Strategy):
|
|
|
257
256
|
self,
|
|
258
257
|
server_round: int,
|
|
259
258
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
260
|
-
failures: list[
|
|
261
|
-
) -> tuple[
|
|
259
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
260
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
262
261
|
"""Aggregate evaluation losses using weighted average."""
|
|
263
262
|
if not results:
|
|
264
263
|
return None, {}
|
|
@@ -18,7 +18,8 @@ Paper: arxiv.org/abs/1602.05629
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from
|
|
21
|
+
from collections.abc import Callable
|
|
22
|
+
from typing import cast
|
|
22
23
|
|
|
23
24
|
import numpy as np
|
|
24
25
|
|
|
@@ -79,16 +80,17 @@ class FedAvgAndroid(Strategy):
|
|
|
79
80
|
min_fit_clients: int = 2,
|
|
80
81
|
min_evaluate_clients: int = 2,
|
|
81
82
|
min_available_clients: int = 2,
|
|
82
|
-
evaluate_fn:
|
|
83
|
+
evaluate_fn: (
|
|
83
84
|
Callable[
|
|
84
85
|
[int, NDArrays, dict[str, Scalar]],
|
|
85
|
-
|
|
86
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
86
87
|
]
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
88
|
+
| None
|
|
89
|
+
) = None,
|
|
90
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
91
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
90
92
|
accept_failures: bool = True,
|
|
91
|
-
initial_parameters:
|
|
93
|
+
initial_parameters: Parameters | None = None,
|
|
92
94
|
) -> None:
|
|
93
95
|
super().__init__()
|
|
94
96
|
self.min_fit_clients = min_fit_clients
|
|
@@ -117,9 +119,7 @@ class FedAvgAndroid(Strategy):
|
|
|
117
119
|
num_clients = int(num_available_clients * self.fraction_evaluate)
|
|
118
120
|
return max(num_clients, self.min_evaluate_clients), self.min_available_clients
|
|
119
121
|
|
|
120
|
-
def initialize_parameters(
|
|
121
|
-
self, client_manager: ClientManager
|
|
122
|
-
) -> Optional[Parameters]:
|
|
122
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
123
123
|
"""Initialize global model parameters."""
|
|
124
124
|
initial_parameters = self.initial_parameters
|
|
125
125
|
self.initial_parameters = None # Don't keep initial parameters in memory
|
|
@@ -127,7 +127,7 @@ class FedAvgAndroid(Strategy):
|
|
|
127
127
|
|
|
128
128
|
def evaluate(
|
|
129
129
|
self, server_round: int, parameters: Parameters
|
|
130
|
-
) ->
|
|
130
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
131
131
|
"""Evaluate model parameters using an evaluation function."""
|
|
132
132
|
if self.evaluate_fn is None:
|
|
133
133
|
# No evaluation function provided
|
|
@@ -190,8 +190,8 @@ class FedAvgAndroid(Strategy):
|
|
|
190
190
|
self,
|
|
191
191
|
server_round: int,
|
|
192
192
|
results: list[tuple[ClientProxy, FitRes]],
|
|
193
|
-
failures: list[
|
|
194
|
-
) -> tuple[
|
|
193
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
194
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
195
195
|
"""Aggregate fit results using weighted average."""
|
|
196
196
|
if not results:
|
|
197
197
|
return None, {}
|
|
@@ -209,8 +209,8 @@ class FedAvgAndroid(Strategy):
|
|
|
209
209
|
self,
|
|
210
210
|
server_round: int,
|
|
211
211
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
212
|
-
failures: list[
|
|
213
|
-
) -> tuple[
|
|
212
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
213
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
214
214
|
"""Aggregate evaluation losses using weighted average."""
|
|
215
215
|
if not results:
|
|
216
216
|
return None, {}
|
flwr/server/strategy/fedavgm.py
CHANGED
|
@@ -18,8 +18,8 @@ Paper: arxiv.org/pdf/1909.06335.pdf
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
+
from collections.abc import Callable
|
|
21
22
|
from logging import WARNING
|
|
22
|
-
from typing import Callable, Optional, Union
|
|
23
23
|
|
|
24
24
|
from flwr.common import (
|
|
25
25
|
FitRes,
|
|
@@ -82,18 +82,19 @@ class FedAvgM(FedAvg):
|
|
|
82
82
|
min_fit_clients: int = 2,
|
|
83
83
|
min_evaluate_clients: int = 2,
|
|
84
84
|
min_available_clients: int = 2,
|
|
85
|
-
evaluate_fn:
|
|
85
|
+
evaluate_fn: (
|
|
86
86
|
Callable[
|
|
87
87
|
[int, NDArrays, dict[str, Scalar]],
|
|
88
|
-
|
|
88
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
89
89
|
]
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
90
|
+
| None
|
|
91
|
+
) = None,
|
|
92
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
93
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
93
94
|
accept_failures: bool = True,
|
|
94
|
-
initial_parameters:
|
|
95
|
-
fit_metrics_aggregation_fn:
|
|
96
|
-
evaluate_metrics_aggregation_fn:
|
|
95
|
+
initial_parameters: Parameters | None = None,
|
|
96
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
97
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
97
98
|
server_learning_rate: float = 1.0,
|
|
98
99
|
server_momentum: float = 0.0,
|
|
99
100
|
) -> None:
|
|
@@ -116,16 +117,14 @@ class FedAvgM(FedAvg):
|
|
|
116
117
|
self.server_opt: bool = (self.server_momentum != 0.0) or (
|
|
117
118
|
self.server_learning_rate != 1.0
|
|
118
119
|
)
|
|
119
|
-
self.momentum_vector:
|
|
120
|
+
self.momentum_vector: NDArrays | None = None
|
|
120
121
|
|
|
121
122
|
def __repr__(self) -> str:
|
|
122
123
|
"""Compute a string representation of the strategy."""
|
|
123
124
|
rep = f"FedAvgM(accept_failures={self.accept_failures})"
|
|
124
125
|
return rep
|
|
125
126
|
|
|
126
|
-
def initialize_parameters(
|
|
127
|
-
self, client_manager: ClientManager
|
|
128
|
-
) -> Optional[Parameters]:
|
|
127
|
+
def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None:
|
|
129
128
|
"""Initialize global model parameters."""
|
|
130
129
|
return self.initial_parameters
|
|
131
130
|
|
|
@@ -133,8 +132,8 @@ class FedAvgM(FedAvg):
|
|
|
133
132
|
self,
|
|
134
133
|
server_round: int,
|
|
135
134
|
results: list[tuple[ClientProxy, FitRes]],
|
|
136
|
-
failures: list[
|
|
137
|
-
) -> tuple[
|
|
135
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
136
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
138
137
|
"""Aggregate fit results using weighted average."""
|
|
139
138
|
if not results:
|
|
140
139
|
return None, {}
|
|
@@ -161,7 +160,9 @@ class FedAvgM(FedAvg):
|
|
|
161
160
|
pseudo_gradient: NDArrays = [
|
|
162
161
|
x - y
|
|
163
162
|
for x, y in zip(
|
|
164
|
-
parameters_to_ndarrays(self.initial_parameters),
|
|
163
|
+
parameters_to_ndarrays(self.initial_parameters),
|
|
164
|
+
fedavg_result,
|
|
165
|
+
strict=True,
|
|
165
166
|
)
|
|
166
167
|
]
|
|
167
168
|
if self.server_momentum > 0.0:
|
|
@@ -171,7 +172,9 @@ class FedAvgM(FedAvg):
|
|
|
171
172
|
), "Momentum should have been created on round 1."
|
|
172
173
|
self.momentum_vector = [
|
|
173
174
|
self.server_momentum * x + y
|
|
174
|
-
for x, y in zip(
|
|
175
|
+
for x, y in zip(
|
|
176
|
+
self.momentum_vector, pseudo_gradient, strict=True
|
|
177
|
+
)
|
|
175
178
|
]
|
|
176
179
|
else:
|
|
177
180
|
self.momentum_vector = pseudo_gradient
|
|
@@ -182,7 +185,7 @@ class FedAvgM(FedAvg):
|
|
|
182
185
|
# SGD
|
|
183
186
|
fedavg_result = [
|
|
184
187
|
x - self.server_learning_rate * y
|
|
185
|
-
for x, y in zip(initial_weights, pseudo_gradient)
|
|
188
|
+
for x, y in zip(initial_weights, pseudo_gradient, strict=True)
|
|
186
189
|
]
|
|
187
190
|
# Update current weights
|
|
188
191
|
self.initial_parameters = ndarrays_to_parameters(fedavg_result)
|
|
@@ -19,7 +19,6 @@ Paper: arxiv.org/pdf/1803.01498v1.pdf
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
from logging import WARNING
|
|
22
|
-
from typing import Optional, Union
|
|
23
22
|
|
|
24
23
|
from flwr.common import (
|
|
25
24
|
FitRes,
|
|
@@ -47,8 +46,8 @@ class FedMedian(FedAvg):
|
|
|
47
46
|
self,
|
|
48
47
|
server_round: int,
|
|
49
48
|
results: list[tuple[ClientProxy, FitRes]],
|
|
50
|
-
failures: list[
|
|
51
|
-
) -> tuple[
|
|
49
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
50
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
52
51
|
"""Aggregate fit results using median."""
|
|
53
52
|
if not results:
|
|
54
53
|
return None, {}
|
flwr/server/strategy/fedopt.py
CHANGED
|
@@ -18,7 +18,7 @@ Paper: arxiv.org/abs/2003.00295
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from
|
|
21
|
+
from collections.abc import Callable
|
|
22
22
|
|
|
23
23
|
from flwr.common import (
|
|
24
24
|
MetricsAggregationFn,
|
|
@@ -84,18 +84,19 @@ class FedOpt(FedAvg):
|
|
|
84
84
|
min_fit_clients: int = 2,
|
|
85
85
|
min_evaluate_clients: int = 2,
|
|
86
86
|
min_available_clients: int = 2,
|
|
87
|
-
evaluate_fn:
|
|
87
|
+
evaluate_fn: (
|
|
88
88
|
Callable[
|
|
89
89
|
[int, NDArrays, dict[str, Scalar]],
|
|
90
|
-
|
|
90
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
91
91
|
]
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
92
|
+
| None
|
|
93
|
+
) = None,
|
|
94
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
95
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
95
96
|
accept_failures: bool = True,
|
|
96
97
|
initial_parameters: Parameters,
|
|
97
|
-
fit_metrics_aggregation_fn:
|
|
98
|
-
evaluate_metrics_aggregation_fn:
|
|
98
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
99
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
99
100
|
eta: float = 1e-1,
|
|
100
101
|
eta_l: float = 1e-1,
|
|
101
102
|
beta_1: float = 0.0,
|
|
@@ -122,8 +123,8 @@ class FedOpt(FedAvg):
|
|
|
122
123
|
self.tau = tau
|
|
123
124
|
self.beta_1 = beta_1
|
|
124
125
|
self.beta_2 = beta_2
|
|
125
|
-
self.m_t:
|
|
126
|
-
self.v_t:
|
|
126
|
+
self.m_t: NDArrays | None = None
|
|
127
|
+
self.v_t: NDArrays | None = None
|
|
127
128
|
|
|
128
129
|
def __repr__(self) -> str:
|
|
129
130
|
"""Compute a string representation of the strategy."""
|
flwr/server/strategy/fedprox.py
CHANGED
|
@@ -18,7 +18,7 @@ Paper: arxiv.org/abs/1812.06127
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from
|
|
21
|
+
from collections.abc import Callable
|
|
22
22
|
|
|
23
23
|
from flwr.common import FitIns, MetricsAggregationFn, NDArrays, Parameters, Scalar
|
|
24
24
|
from flwr.server.client_manager import ClientManager
|
|
@@ -111,18 +111,19 @@ class FedProx(FedAvg):
|
|
|
111
111
|
min_fit_clients: int = 2,
|
|
112
112
|
min_evaluate_clients: int = 2,
|
|
113
113
|
min_available_clients: int = 2,
|
|
114
|
-
evaluate_fn:
|
|
114
|
+
evaluate_fn: (
|
|
115
115
|
Callable[
|
|
116
116
|
[int, NDArrays, dict[str, Scalar]],
|
|
117
|
-
|
|
117
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
118
118
|
]
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
119
|
+
| None
|
|
120
|
+
) = None,
|
|
121
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
122
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
122
123
|
accept_failures: bool = True,
|
|
123
|
-
initial_parameters:
|
|
124
|
-
fit_metrics_aggregation_fn:
|
|
125
|
-
evaluate_metrics_aggregation_fn:
|
|
124
|
+
initial_parameters: Parameters | None = None,
|
|
125
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
126
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
126
127
|
proximal_mu: float,
|
|
127
128
|
) -> None:
|
|
128
129
|
super().__init__(
|
|
@@ -16,8 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
Paper: arxiv.org/abs/1803.01498
|
|
18
18
|
"""
|
|
19
|
+
from collections.abc import Callable
|
|
19
20
|
from logging import WARNING
|
|
20
|
-
from typing import Callable, Optional, Union
|
|
21
21
|
|
|
22
22
|
from flwr.common import (
|
|
23
23
|
FitRes,
|
|
@@ -76,18 +76,19 @@ class FedTrimmedAvg(FedAvg):
|
|
|
76
76
|
min_fit_clients: int = 2,
|
|
77
77
|
min_evaluate_clients: int = 2,
|
|
78
78
|
min_available_clients: int = 2,
|
|
79
|
-
evaluate_fn:
|
|
79
|
+
evaluate_fn: (
|
|
80
80
|
Callable[
|
|
81
81
|
[int, NDArrays, dict[str, Scalar]],
|
|
82
|
-
|
|
82
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
83
83
|
]
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
84
|
+
| None
|
|
85
|
+
) = None,
|
|
86
|
+
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
87
|
+
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
|
|
87
88
|
accept_failures: bool = True,
|
|
88
|
-
initial_parameters:
|
|
89
|
-
fit_metrics_aggregation_fn:
|
|
90
|
-
evaluate_metrics_aggregation_fn:
|
|
89
|
+
initial_parameters: Parameters | None = None,
|
|
90
|
+
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
91
|
+
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
|
|
91
92
|
beta: float = 0.2,
|
|
92
93
|
) -> None:
|
|
93
94
|
super().__init__(
|
|
@@ -115,8 +116,8 @@ class FedTrimmedAvg(FedAvg):
|
|
|
115
116
|
self,
|
|
116
117
|
server_round: int,
|
|
117
118
|
results: list[tuple[ClientProxy, FitRes]],
|
|
118
|
-
failures: list[
|
|
119
|
-
) -> tuple[
|
|
119
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
120
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
120
121
|
"""Aggregate fit results using trimmed average."""
|
|
121
122
|
if not results:
|
|
122
123
|
return None, {}
|
|
@@ -16,8 +16,9 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import json
|
|
19
|
+
from collections.abc import Callable
|
|
19
20
|
from logging import WARNING
|
|
20
|
-
from typing import Any,
|
|
21
|
+
from typing import Any, cast
|
|
21
22
|
|
|
22
23
|
from flwr.common import EvaluateRes, FitRes, Parameters, Scalar
|
|
23
24
|
from flwr.common.logger import log
|
|
@@ -32,16 +33,17 @@ class FedXgbBagging(FedAvg):
|
|
|
32
33
|
# pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
|
|
33
34
|
def __init__(
|
|
34
35
|
self,
|
|
35
|
-
evaluate_function:
|
|
36
|
+
evaluate_function: (
|
|
36
37
|
Callable[
|
|
37
38
|
[int, Parameters, dict[str, Scalar]],
|
|
38
|
-
|
|
39
|
+
tuple[float, dict[str, Scalar]] | None,
|
|
39
40
|
]
|
|
40
|
-
|
|
41
|
+
| None
|
|
42
|
+
) = None,
|
|
41
43
|
**kwargs: Any,
|
|
42
44
|
):
|
|
43
45
|
self.evaluate_function = evaluate_function
|
|
44
|
-
self.global_model:
|
|
46
|
+
self.global_model: bytes | None = None
|
|
45
47
|
super().__init__(**kwargs)
|
|
46
48
|
|
|
47
49
|
def __repr__(self) -> str:
|
|
@@ -53,8 +55,8 @@ class FedXgbBagging(FedAvg):
|
|
|
53
55
|
self,
|
|
54
56
|
server_round: int,
|
|
55
57
|
results: list[tuple[ClientProxy, FitRes]],
|
|
56
|
-
failures: list[
|
|
57
|
-
) -> tuple[
|
|
58
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException],
|
|
59
|
+
) -> tuple[Parameters | None, dict[str, Scalar]]:
|
|
58
60
|
"""Aggregate fit results using bagging."""
|
|
59
61
|
if not results:
|
|
60
62
|
return None, {}
|
|
@@ -80,8 +82,8 @@ class FedXgbBagging(FedAvg):
|
|
|
80
82
|
self,
|
|
81
83
|
server_round: int,
|
|
82
84
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
83
|
-
failures: list[
|
|
84
|
-
) -> tuple[
|
|
85
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
86
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
85
87
|
"""Aggregate evaluation metrics using average."""
|
|
86
88
|
if not results:
|
|
87
89
|
return None, {}
|
|
@@ -101,7 +103,7 @@ class FedXgbBagging(FedAvg):
|
|
|
101
103
|
|
|
102
104
|
def evaluate(
|
|
103
105
|
self, server_round: int, parameters: Parameters
|
|
104
|
-
) ->
|
|
106
|
+
) -> tuple[float, dict[str, Scalar]] | None:
|
|
105
107
|
"""Evaluate model parameters using an evaluation function."""
|
|
106
108
|
if self.evaluate_function is None:
|
|
107
109
|
# No evaluation function provided
|
|
@@ -114,7 +116,7 @@ class FedXgbBagging(FedAvg):
|
|
|
114
116
|
|
|
115
117
|
|
|
116
118
|
def aggregate(
|
|
117
|
-
bst_prev_org:
|
|
119
|
+
bst_prev_org: bytes | None,
|
|
118
120
|
bst_curr_org: bytes,
|
|
119
121
|
) -> bytes:
|
|
120
122
|
"""Conduct bagging aggregation for given trees."""
|