flwr 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +19 -0
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/auth_plugin.py +4 -5
- flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
- flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
- flwr/cli/build.py +60 -18
- flwr/cli/cli_account_auth_interceptor.py +24 -7
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +52 -9
- flwr/cli/login/login.py +7 -4
- flwr/cli/ls.py +170 -130
- flwr/cli/new/new.py +33 -50
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +10 -5
- flwr/cli/run/run.py +77 -30
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +25 -7
- flwr/cli/supernode/ls.py +16 -8
- flwr/cli/supernode/register.py +9 -4
- flwr/cli/supernode/unregister.py +5 -3
- flwr/cli/utils.py +376 -16
- flwr/client/__init__.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +6 -7
- flwr/client/grpc_rere_client/connection.py +10 -11
- flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
- flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +12 -14
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/clientapp/utils.py +3 -3
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +5 -2
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +19 -0
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +38 -21
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +18 -30
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +5 -4
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +52 -37
- flwr/compat/client/app.py +38 -37
- flwr/compat/client/grpc_client/connection.py +11 -11
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +71 -52
- flwr/proto/control_pb2.pyi +277 -111
- flwr/proto/control_pb2_grpc.py +249 -40
- flwr/proto/control_pb2_grpc.pyi +185 -52
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +14 -4
- flwr/proto/fab_pb2.pyi +59 -31
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +14 -4
- flwr/proto/fleet_pb2.pyi +137 -61
- flwr/proto/fleet_pb2_grpc.py +189 -48
- flwr/proto/fleet_pb2_grpc.pyi +175 -61
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +50 -25
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +38 -17
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +15 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
- flwr/server/superlink/linkstate/linkstate.py +59 -43
- flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
- flwr/server/superlink/linkstate/utils.py +6 -6
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +10 -11
- flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
- flwr/simulation/run_simulation.py +43 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +31 -1
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -2
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +1 -2
- flwr/supercore/object_store/sqlite_object_store.py +8 -7
- flwr/supercore/primitives/asymmetric.py +1 -1
- flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
- flwr/supercore/sqlite_mixin.py +37 -34
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/superlink/auth_plugin/auth_plugin.py +6 -9
- flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +5 -6
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +102 -18
- flwr/supernode/cli/flower_supernode.py +58 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +41 -22
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +158 -42
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.23.0.dist-info/RECORD +0 -439
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -17,11 +17,10 @@
|
|
|
17
17
|
Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
|
+
|
|
20
21
|
from abc import ABC
|
|
21
|
-
from collections import OrderedDict
|
|
22
22
|
from collections.abc import Iterable
|
|
23
23
|
from logging import INFO, WARNING
|
|
24
|
-
from typing import Optional
|
|
25
24
|
|
|
26
25
|
from flwr.common import Array, ArrayRecord, ConfigRecord, Message, MetricRecord, log
|
|
27
26
|
from flwr.common.differential_privacy import (
|
|
@@ -112,12 +111,12 @@ class DifferentialPrivacyFixedClippingBase(Strategy, ABC):
|
|
|
112
111
|
)
|
|
113
112
|
|
|
114
113
|
return ArrayRecord(
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
114
|
+
{
|
|
115
|
+
k: Array(v)
|
|
116
|
+
for k, v in zip(
|
|
117
|
+
aggregated_arrays.keys(), aggregated_ndarrays, strict=True
|
|
118
|
+
)
|
|
119
|
+
}
|
|
121
120
|
)
|
|
122
121
|
|
|
123
122
|
def configure_evaluate(
|
|
@@ -130,7 +129,7 @@ class DifferentialPrivacyFixedClippingBase(Strategy, ABC):
|
|
|
130
129
|
self,
|
|
131
130
|
server_round: int,
|
|
132
131
|
replies: Iterable[Message],
|
|
133
|
-
) ->
|
|
132
|
+
) -> MetricRecord | None:
|
|
134
133
|
"""Aggregate MetricRecords in the received Messages."""
|
|
135
134
|
return self.strategy.aggregate_evaluate(server_round, replies)
|
|
136
135
|
|
|
@@ -199,7 +198,7 @@ class DifferentialPrivacyServerSideFixedClipping(DifferentialPrivacyFixedClippin
|
|
|
199
198
|
self,
|
|
200
199
|
server_round: int,
|
|
201
200
|
replies: Iterable[Message],
|
|
202
|
-
) -> tuple[
|
|
201
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
203
202
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
204
203
|
if not validate_replies(replies, self.num_sampled_clients):
|
|
205
204
|
return None, None
|
|
@@ -217,9 +216,7 @@ class DifferentialPrivacyServerSideFixedClipping(DifferentialPrivacyFixedClippin
|
|
|
217
216
|
)
|
|
218
217
|
# Replace content while preserving keys
|
|
219
218
|
reply.content[arr_name] = ArrayRecord(
|
|
220
|
-
|
|
221
|
-
{k: Array(v) for k, v in zip(record.keys(), reply_ndarrays)}
|
|
222
|
-
)
|
|
219
|
+
dict(zip(record.keys(), map(Array, reply_ndarrays), strict=True))
|
|
223
220
|
)
|
|
224
221
|
log(
|
|
225
222
|
INFO,
|
|
@@ -302,7 +299,7 @@ class DifferentialPrivacyClientSideFixedClipping(DifferentialPrivacyFixedClippin
|
|
|
302
299
|
self,
|
|
303
300
|
server_round: int,
|
|
304
301
|
replies: Iterable[Message],
|
|
305
|
-
) -> tuple[
|
|
302
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
306
303
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
307
304
|
if not validate_replies(replies, self.num_sampled_clients):
|
|
308
305
|
return None, None
|
|
@@ -19,9 +19,8 @@ Adaptive Federated Optimization using Adagrad.
|
|
|
19
19
|
Paper: arxiv.org/abs/2003.00295
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
|
-
|
|
23
|
-
from collections.abc import Iterable
|
|
24
|
-
from typing import Callable, Optional
|
|
22
|
+
|
|
23
|
+
from collections.abc import Callable, Iterable
|
|
25
24
|
|
|
26
25
|
import numpy as np
|
|
27
26
|
|
|
@@ -90,12 +89,12 @@ class FedAdagrad(FedOpt):
|
|
|
90
89
|
weighted_by_key: str = "num-examples",
|
|
91
90
|
arrayrecord_key: str = "arrays",
|
|
92
91
|
configrecord_key: str = "config",
|
|
93
|
-
train_metrics_aggr_fn:
|
|
94
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
95
|
-
|
|
96
|
-
evaluate_metrics_aggr_fn:
|
|
97
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
98
|
-
|
|
92
|
+
train_metrics_aggr_fn: (
|
|
93
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
94
|
+
) = None,
|
|
95
|
+
evaluate_metrics_aggr_fn: (
|
|
96
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
97
|
+
) = None,
|
|
99
98
|
eta: float = 1e-1,
|
|
100
99
|
eta_l: float = 1e-1,
|
|
101
100
|
tau: float = 1e-3,
|
|
@@ -122,7 +121,7 @@ class FedAdagrad(FedOpt):
|
|
|
122
121
|
self,
|
|
123
122
|
server_round: int,
|
|
124
123
|
replies: Iterable[Message],
|
|
125
|
-
) -> tuple[
|
|
124
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
126
125
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
127
126
|
aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
|
|
128
127
|
server_round, replies
|
|
@@ -154,6 +153,6 @@ class FedAdagrad(FedOpt):
|
|
|
154
153
|
}
|
|
155
154
|
|
|
156
155
|
return (
|
|
157
|
-
ArrayRecord(
|
|
156
|
+
ArrayRecord({k: Array(v) for k, v in new_arrays.items()}),
|
|
158
157
|
aggregated_metrics,
|
|
159
158
|
)
|
|
@@ -19,9 +19,8 @@
|
|
|
19
19
|
Paper: arxiv.org/abs/2003.00295
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
|
-
|
|
23
|
-
from collections.abc import Iterable
|
|
24
|
-
from typing import Callable, Optional
|
|
22
|
+
|
|
23
|
+
from collections.abc import Callable, Iterable
|
|
25
24
|
|
|
26
25
|
import numpy as np
|
|
27
26
|
|
|
@@ -94,12 +93,12 @@ class FedAdam(FedOpt):
|
|
|
94
93
|
weighted_by_key: str = "num-examples",
|
|
95
94
|
arrayrecord_key: str = "arrays",
|
|
96
95
|
configrecord_key: str = "config",
|
|
97
|
-
train_metrics_aggr_fn:
|
|
98
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
99
|
-
|
|
100
|
-
evaluate_metrics_aggr_fn:
|
|
101
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
102
|
-
|
|
96
|
+
train_metrics_aggr_fn: (
|
|
97
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
98
|
+
) = None,
|
|
99
|
+
evaluate_metrics_aggr_fn: (
|
|
100
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
101
|
+
) = None,
|
|
103
102
|
eta: float = 1e-1,
|
|
104
103
|
eta_l: float = 1e-1,
|
|
105
104
|
beta_1: float = 0.9,
|
|
@@ -128,7 +127,7 @@ class FedAdam(FedOpt):
|
|
|
128
127
|
self,
|
|
129
128
|
server_round: int,
|
|
130
129
|
replies: Iterable[Message],
|
|
131
|
-
) -> tuple[
|
|
130
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
132
131
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
133
132
|
aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
|
|
134
133
|
server_round, replies
|
|
@@ -173,6 +172,6 @@ class FedAdam(FedOpt):
|
|
|
173
172
|
}
|
|
174
173
|
|
|
175
174
|
return (
|
|
176
|
-
ArrayRecord(
|
|
175
|
+
ArrayRecord({k: Array(v) for k, v in new_arrays.items()}),
|
|
177
176
|
aggregated_metrics,
|
|
178
177
|
)
|
|
@@ -15,9 +15,8 @@
|
|
|
15
15
|
"""Flower message-based FedAvg strategy."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from collections.abc import Iterable
|
|
18
|
+
from collections.abc import Callable, Iterable
|
|
19
19
|
from logging import INFO, WARNING
|
|
20
|
-
from typing import Callable, Optional
|
|
21
20
|
|
|
22
21
|
from flwr.common import (
|
|
23
22
|
ArrayRecord,
|
|
@@ -91,12 +90,12 @@ class FedAvg(Strategy):
|
|
|
91
90
|
weighted_by_key: str = "num-examples",
|
|
92
91
|
arrayrecord_key: str = "arrays",
|
|
93
92
|
configrecord_key: str = "config",
|
|
94
|
-
train_metrics_aggr_fn:
|
|
95
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
96
|
-
|
|
97
|
-
evaluate_metrics_aggr_fn:
|
|
98
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
99
|
-
|
|
93
|
+
train_metrics_aggr_fn: (
|
|
94
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
95
|
+
) = None,
|
|
96
|
+
evaluate_metrics_aggr_fn: (
|
|
97
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
98
|
+
) = None,
|
|
100
99
|
) -> None:
|
|
101
100
|
self.fraction_train = fraction_train
|
|
102
101
|
self.fraction_evaluate = fraction_evaluate
|
|
@@ -251,7 +250,7 @@ class FedAvg(Strategy):
|
|
|
251
250
|
self,
|
|
252
251
|
server_round: int,
|
|
253
252
|
replies: Iterable[Message],
|
|
254
|
-
) -> tuple[
|
|
253
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
255
254
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
256
255
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
257
256
|
|
|
@@ -304,7 +303,7 @@ class FedAvg(Strategy):
|
|
|
304
303
|
self,
|
|
305
304
|
server_round: int,
|
|
306
305
|
replies: Iterable[Message],
|
|
307
|
-
) ->
|
|
306
|
+
) -> MetricRecord | None:
|
|
308
307
|
"""Aggregate MetricRecords in the received Messages."""
|
|
309
308
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=False)
|
|
310
309
|
|
|
@@ -18,10 +18,8 @@ Paper: arxiv.org/pdf/1909.06335.pdf
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from collections import
|
|
22
|
-
from collections.abc import Iterable
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
23
22
|
from logging import INFO
|
|
24
|
-
from typing import Callable, Optional
|
|
25
23
|
|
|
26
24
|
from flwr.common import (
|
|
27
25
|
Array,
|
|
@@ -93,12 +91,12 @@ class FedAvgM(FedAvg):
|
|
|
93
91
|
weighted_by_key: str = "num-examples",
|
|
94
92
|
arrayrecord_key: str = "arrays",
|
|
95
93
|
configrecord_key: str = "config",
|
|
96
|
-
train_metrics_aggr_fn:
|
|
97
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
98
|
-
|
|
99
|
-
evaluate_metrics_aggr_fn:
|
|
100
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
101
|
-
|
|
94
|
+
train_metrics_aggr_fn: (
|
|
95
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
96
|
+
) = None,
|
|
97
|
+
evaluate_metrics_aggr_fn: (
|
|
98
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
99
|
+
) = None,
|
|
102
100
|
server_learning_rate: float = 1.0,
|
|
103
101
|
server_momentum: float = 0.0,
|
|
104
102
|
) -> None:
|
|
@@ -119,8 +117,8 @@ class FedAvgM(FedAvg):
|
|
|
119
117
|
self.server_opt: bool = (self.server_momentum != 0.0) or (
|
|
120
118
|
self.server_learning_rate != 1.0
|
|
121
119
|
)
|
|
122
|
-
self.current_arrays:
|
|
123
|
-
self.momentum_vector:
|
|
120
|
+
self.current_arrays: ArrayRecord | None = None
|
|
121
|
+
self.momentum_vector: NDArrays | None = None
|
|
124
122
|
|
|
125
123
|
def summary(self) -> None:
|
|
126
124
|
"""Log summary configuration of the strategy."""
|
|
@@ -143,7 +141,7 @@ class FedAvgM(FedAvg):
|
|
|
143
141
|
self,
|
|
144
142
|
server_round: int,
|
|
145
143
|
replies: Iterable[Message],
|
|
146
|
-
) -> tuple[
|
|
144
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
147
145
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
148
146
|
# Call FedAvg aggregate_train to perform validation and aggregation
|
|
149
147
|
aggregated_arrays, aggregated_metrics = super().aggregate_train(
|
|
@@ -168,7 +166,8 @@ class FedAvgM(FedAvg):
|
|
|
168
166
|
|
|
169
167
|
# Remember that updates are the opposite of gradients
|
|
170
168
|
pseudo_gradient = [
|
|
171
|
-
old - new
|
|
169
|
+
old - new
|
|
170
|
+
for new, old in zip(aggregated_ndarrays, ndarrays, strict=True)
|
|
172
171
|
]
|
|
173
172
|
if self.server_momentum > 0.0:
|
|
174
173
|
if self.momentum_vector is None:
|
|
@@ -177,7 +176,9 @@ class FedAvgM(FedAvg):
|
|
|
177
176
|
else:
|
|
178
177
|
self.momentum_vector = [
|
|
179
178
|
self.server_momentum * mv + pg
|
|
180
|
-
for mv, pg in zip(
|
|
179
|
+
for mv, pg in zip(
|
|
180
|
+
self.momentum_vector, pseudo_gradient, strict=True
|
|
181
|
+
)
|
|
181
182
|
]
|
|
182
183
|
|
|
183
184
|
# No nesterov for now
|
|
@@ -186,10 +187,10 @@ class FedAvgM(FedAvg):
|
|
|
186
187
|
# SGD and convert back to ArrayRecord
|
|
187
188
|
updated_array_list = [
|
|
188
189
|
Array(old - self.server_learning_rate * pg)
|
|
189
|
-
for old, pg in zip(ndarrays, pseudo_gradient)
|
|
190
|
+
for old, pg in zip(ndarrays, pseudo_gradient, strict=True)
|
|
190
191
|
]
|
|
191
192
|
aggregated_arrays = ArrayRecord(
|
|
192
|
-
|
|
193
|
+
dict(zip(array_keys, updated_array_list, strict=True))
|
|
193
194
|
)
|
|
194
195
|
|
|
195
196
|
# Update current weights
|
|
@@ -19,7 +19,7 @@ Paper: arxiv.org/pdf/1803.01498v1.pdf
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
from collections.abc import Iterable
|
|
22
|
-
from typing import
|
|
22
|
+
from typing import cast
|
|
23
23
|
|
|
24
24
|
import numpy as np
|
|
25
25
|
|
|
@@ -72,7 +72,7 @@ class FedMedian(FedAvg):
|
|
|
72
72
|
self,
|
|
73
73
|
server_round: int,
|
|
74
74
|
replies: Iterable[Message],
|
|
75
|
-
) -> tuple[
|
|
75
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
76
76
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
77
77
|
# Call FedAvg aggregate_train to perform validation and aggregation
|
|
78
78
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
@@ -17,9 +17,8 @@
|
|
|
17
17
|
Paper: arxiv.org/abs/2003.00295
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
|
-
from collections.abc import Iterable
|
|
20
|
+
from collections.abc import Callable, Iterable
|
|
21
21
|
from logging import INFO
|
|
22
|
-
from typing import Callable, Optional
|
|
23
22
|
|
|
24
23
|
import numpy as np
|
|
25
24
|
|
|
@@ -101,12 +100,12 @@ class FedOpt(FedAvg):
|
|
|
101
100
|
weighted_by_key: str = "num-examples",
|
|
102
101
|
arrayrecord_key: str = "arrays",
|
|
103
102
|
configrecord_key: str = "config",
|
|
104
|
-
train_metrics_aggr_fn:
|
|
105
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
106
|
-
|
|
107
|
-
evaluate_metrics_aggr_fn:
|
|
108
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
109
|
-
|
|
103
|
+
train_metrics_aggr_fn: (
|
|
104
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
105
|
+
) = None,
|
|
106
|
+
evaluate_metrics_aggr_fn: (
|
|
107
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
108
|
+
) = None,
|
|
110
109
|
eta: float = 1e-1,
|
|
111
110
|
eta_l: float = 1e-1,
|
|
112
111
|
beta_1: float = 0.0,
|
|
@@ -125,14 +124,14 @@ class FedOpt(FedAvg):
|
|
|
125
124
|
train_metrics_aggr_fn=train_metrics_aggr_fn,
|
|
126
125
|
evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
|
|
127
126
|
)
|
|
128
|
-
self.current_arrays:
|
|
127
|
+
self.current_arrays: dict[str, NDArray] | None = None
|
|
129
128
|
self.eta = eta
|
|
130
129
|
self.eta_l = eta_l
|
|
131
130
|
self.tau = tau
|
|
132
131
|
self.beta_1 = beta_1
|
|
133
132
|
self.beta_2 = beta_2
|
|
134
|
-
self.m_t:
|
|
135
|
-
self.v_t:
|
|
133
|
+
self.m_t: dict[str, NDArray] | None = None
|
|
134
|
+
self.v_t: dict[str, NDArray] | None = None
|
|
136
135
|
|
|
137
136
|
def summary(self) -> None:
|
|
138
137
|
"""Log summary configuration of the strategy."""
|
|
@@ -18,9 +18,8 @@ Paper: arxiv.org/abs/1812.06127
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from collections.abc import Iterable
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
22
22
|
from logging import INFO, WARN
|
|
23
|
-
from typing import Callable, Optional
|
|
24
23
|
|
|
25
24
|
from flwr.common import (
|
|
26
25
|
ArrayRecord,
|
|
@@ -130,12 +129,12 @@ class FedProx(FedAvg):
|
|
|
130
129
|
weighted_by_key: str = "num-examples",
|
|
131
130
|
arrayrecord_key: str = "arrays",
|
|
132
131
|
configrecord_key: str = "config",
|
|
133
|
-
train_metrics_aggr_fn:
|
|
134
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
135
|
-
|
|
136
|
-
evaluate_metrics_aggr_fn:
|
|
137
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
138
|
-
|
|
132
|
+
train_metrics_aggr_fn: (
|
|
133
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
134
|
+
) = None,
|
|
135
|
+
evaluate_metrics_aggr_fn: (
|
|
136
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
137
|
+
) = None,
|
|
139
138
|
proximal_mu: float = 0.0,
|
|
140
139
|
) -> None:
|
|
141
140
|
super().__init__(
|
|
@@ -18,9 +18,9 @@ Paper: arxiv.org/abs/1803.01498
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from collections.abc import Iterable
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
22
22
|
from logging import INFO
|
|
23
|
-
from typing import
|
|
23
|
+
from typing import cast
|
|
24
24
|
|
|
25
25
|
import numpy as np
|
|
26
26
|
|
|
@@ -83,12 +83,12 @@ class FedTrimmedAvg(FedAvg):
|
|
|
83
83
|
weighted_by_key: str = "num-examples",
|
|
84
84
|
arrayrecord_key: str = "arrays",
|
|
85
85
|
configrecord_key: str = "config",
|
|
86
|
-
train_metrics_aggr_fn:
|
|
87
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
88
|
-
|
|
89
|
-
evaluate_metrics_aggr_fn:
|
|
90
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
91
|
-
|
|
86
|
+
train_metrics_aggr_fn: (
|
|
87
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
88
|
+
) = None,
|
|
89
|
+
evaluate_metrics_aggr_fn: (
|
|
90
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
91
|
+
) = None,
|
|
92
92
|
beta: float = 0.2,
|
|
93
93
|
) -> None:
|
|
94
94
|
super().__init__(
|
|
@@ -115,7 +115,7 @@ class FedTrimmedAvg(FedAvg):
|
|
|
115
115
|
self,
|
|
116
116
|
server_round: int,
|
|
117
117
|
replies: Iterable[Message],
|
|
118
|
-
) -> tuple[
|
|
118
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
119
119
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
120
120
|
# Call FedAvg aggregate_train to perform validation and aggregation
|
|
121
121
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower message-based FedXgbBagging strategy."""
|
|
16
16
|
from collections.abc import Iterable
|
|
17
|
-
from typing import
|
|
17
|
+
from typing import cast
|
|
18
18
|
|
|
19
19
|
import numpy as np
|
|
20
20
|
|
|
@@ -65,7 +65,7 @@ class FedXgbBagging(FedAvg):
|
|
|
65
65
|
average using the provided weight factor key.
|
|
66
66
|
"""
|
|
67
67
|
|
|
68
|
-
current_bst:
|
|
68
|
+
current_bst: bytes | None = None
|
|
69
69
|
|
|
70
70
|
def _ensure_single_array(self, arrays: ArrayRecord) -> None:
|
|
71
71
|
"""Check that ensures there's only one Array in the ArrayRecord."""
|
|
@@ -89,7 +89,7 @@ class FedXgbBagging(FedAvg):
|
|
|
89
89
|
self,
|
|
90
90
|
server_round: int,
|
|
91
91
|
replies: Iterable[Message],
|
|
92
|
-
) -> tuple[
|
|
92
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
93
93
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
94
94
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
95
95
|
|
|
@@ -15,9 +15,9 @@
|
|
|
15
15
|
"""Flower message-based FedXgbCyclic strategy."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from collections.abc import Iterable
|
|
18
|
+
from collections.abc import Callable, Iterable
|
|
19
19
|
from logging import INFO
|
|
20
|
-
from typing import
|
|
20
|
+
from typing import cast
|
|
21
21
|
|
|
22
22
|
from flwr.common import (
|
|
23
23
|
ArrayRecord,
|
|
@@ -78,12 +78,12 @@ class FedXgbCyclic(FedAvg):
|
|
|
78
78
|
weighted_by_key: str = "num-examples",
|
|
79
79
|
arrayrecord_key: str = "arrays",
|
|
80
80
|
configrecord_key: str = "config",
|
|
81
|
-
train_metrics_aggr_fn:
|
|
82
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
83
|
-
|
|
84
|
-
evaluate_metrics_aggr_fn:
|
|
85
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
86
|
-
|
|
81
|
+
train_metrics_aggr_fn: (
|
|
82
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
83
|
+
) = None,
|
|
84
|
+
evaluate_metrics_aggr_fn: (
|
|
85
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
86
|
+
) = None,
|
|
87
87
|
) -> None:
|
|
88
88
|
super().__init__(
|
|
89
89
|
fraction_train=fraction_train,
|
|
@@ -184,7 +184,7 @@ class FedXgbCyclic(FedAvg):
|
|
|
184
184
|
self,
|
|
185
185
|
server_round: int,
|
|
186
186
|
replies: Iterable[Message],
|
|
187
|
-
) -> tuple[
|
|
187
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
188
188
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
189
189
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
190
190
|
|
|
@@ -18,9 +18,7 @@ Paper: arxiv.org/abs/2003.00295
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from collections import
|
|
22
|
-
from collections.abc import Iterable
|
|
23
|
-
from typing import Callable, Optional
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
24
22
|
|
|
25
23
|
import numpy as np
|
|
26
24
|
|
|
@@ -95,12 +93,12 @@ class FedYogi(FedOpt):
|
|
|
95
93
|
weighted_by_key: str = "num-examples",
|
|
96
94
|
arrayrecord_key: str = "arrays",
|
|
97
95
|
configrecord_key: str = "config",
|
|
98
|
-
train_metrics_aggr_fn:
|
|
99
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
100
|
-
|
|
101
|
-
evaluate_metrics_aggr_fn:
|
|
102
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
103
|
-
|
|
96
|
+
train_metrics_aggr_fn: (
|
|
97
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
98
|
+
) = None,
|
|
99
|
+
evaluate_metrics_aggr_fn: (
|
|
100
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
101
|
+
) = None,
|
|
104
102
|
eta: float = 1e-2,
|
|
105
103
|
eta_l: float = 0.0316,
|
|
106
104
|
beta_1: float = 0.9,
|
|
@@ -129,7 +127,7 @@ class FedYogi(FedOpt):
|
|
|
129
127
|
self,
|
|
130
128
|
server_round: int,
|
|
131
129
|
replies: Iterable[Message],
|
|
132
|
-
) -> tuple[
|
|
130
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
133
131
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
134
132
|
aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
|
|
135
133
|
server_round, replies
|
|
@@ -165,6 +163,6 @@ class FedYogi(FedOpt):
|
|
|
165
163
|
}
|
|
166
164
|
|
|
167
165
|
return (
|
|
168
|
-
ArrayRecord(
|
|
166
|
+
ArrayRecord({k: Array(v) for k, v in new_arrays.items()}),
|
|
169
167
|
aggregated_metrics,
|
|
170
168
|
)
|
flwr/serverapp/strategy/krum.py
CHANGED
|
@@ -20,8 +20,8 @@ Paper: proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-P
|
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
22
|
|
|
23
|
+
from collections.abc import Callable
|
|
23
24
|
from logging import INFO
|
|
24
|
-
from typing import Callable, Optional
|
|
25
25
|
|
|
26
26
|
from flwr.common import MetricRecord, RecordDict, log
|
|
27
27
|
|
|
@@ -83,12 +83,12 @@ class Krum(MultiKrum):
|
|
|
83
83
|
weighted_by_key: str = "num-examples",
|
|
84
84
|
arrayrecord_key: str = "arrays",
|
|
85
85
|
configrecord_key: str = "config",
|
|
86
|
-
train_metrics_aggr_fn:
|
|
87
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
88
|
-
|
|
89
|
-
evaluate_metrics_aggr_fn:
|
|
90
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
91
|
-
|
|
86
|
+
train_metrics_aggr_fn: (
|
|
87
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
88
|
+
) = None,
|
|
89
|
+
evaluate_metrics_aggr_fn: (
|
|
90
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
91
|
+
) = None,
|
|
92
92
|
) -> None:
|
|
93
93
|
super().__init__(
|
|
94
94
|
fraction_train=fraction_train,
|
|
@@ -20,9 +20,9 @@ Paper: proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-P
|
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
from collections.abc import Iterable
|
|
23
|
+
from collections.abc import Callable, Iterable
|
|
24
24
|
from logging import INFO
|
|
25
|
-
from typing import
|
|
25
|
+
from typing import cast
|
|
26
26
|
|
|
27
27
|
import numpy as np
|
|
28
28
|
|
|
@@ -95,12 +95,12 @@ class MultiKrum(FedAvg):
|
|
|
95
95
|
weighted_by_key: str = "num-examples",
|
|
96
96
|
arrayrecord_key: str = "arrays",
|
|
97
97
|
configrecord_key: str = "config",
|
|
98
|
-
train_metrics_aggr_fn:
|
|
99
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
100
|
-
|
|
101
|
-
evaluate_metrics_aggr_fn:
|
|
102
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
103
|
-
|
|
98
|
+
train_metrics_aggr_fn: (
|
|
99
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
100
|
+
) = None,
|
|
101
|
+
evaluate_metrics_aggr_fn: (
|
|
102
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
103
|
+
) = None,
|
|
104
104
|
) -> None:
|
|
105
105
|
super().__init__(
|
|
106
106
|
fraction_train=fraction_train,
|
|
@@ -128,7 +128,7 @@ class MultiKrum(FedAvg):
|
|
|
128
128
|
self,
|
|
129
129
|
server_round: int,
|
|
130
130
|
replies: Iterable[Message],
|
|
131
|
-
) -> tuple[
|
|
131
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
132
132
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
133
133
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
134
134
|
|