flwr 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +34 -1
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +94 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
- flwr/cli/build.py +166 -53
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +54 -11
- flwr/cli/login/login.py +41 -27
- flwr/cli/ls.py +177 -133
- flwr/cli/new/new.py +175 -40
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +12 -7
- flwr/cli/run/run.py +82 -31
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +27 -9
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +268 -0
- flwr/cli/supernode/register.py +190 -0
- flwr/cli/supernode/unregister.py +140 -0
- flwr/cli/utils.py +464 -81
- flwr/client/__init__.py +2 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +12 -15
- flwr/client/grpc_rere_client/connection.py +68 -41
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +94 -51
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +1 -2
- flwr/{client → clientapp}/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/{client/clientapp → clientapp}/utils.py +4 -4
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +56 -13
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +39 -10
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +48 -31
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +6 -6
- flwr/common/record/arrayrecord.py +18 -21
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +9 -6
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +59 -43
- flwr/compat/client/app.py +39 -38
- flwr/compat/client/grpc_client/connection.py +13 -13
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +72 -40
- flwr/proto/control_pb2.pyi +319 -87
- flwr/proto/control_pb2_grpc.py +339 -28
- flwr/proto/control_pb2_grpc.pyi +209 -37
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +24 -10
- flwr/proto/fab_pb2.pyi +68 -20
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +45 -27
- flwr/proto/fleet_pb2.pyi +186 -70
- flwr/proto/fleet_pb2_grpc.py +277 -66
- flwr/proto/fleet_pb2_grpc.pyi +201 -55
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +16 -4
- flwr/proto/node_pb2.pyi +77 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +173 -127
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
- flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
- flwr/server/superlink/fleet/vce/vce_api.py +32 -13
- flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
- flwr/server/superlink/linkstate/linkstate.py +161 -62
- flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
- flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
- flwr/server/superlink/linkstate/utils.py +9 -60
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +11 -12
- flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
- flwr/simulation/run_simulation.py +44 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +52 -0
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -6
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +27 -8
- flwr/supercore/object_store/sqlite_object_store.py +253 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
- flwr/supercore/sqlite_mixin.py +159 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +88 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +18 -17
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +239 -63
- flwr/supernode/cli/flower_supernode.py +74 -26
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +43 -24
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +175 -51
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.22.0.dist-info/RECORD +0 -428
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -18,10 +18,9 @@ Paper: arxiv.org/abs/1802.07927
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from collections import
|
|
22
|
-
from collections.abc import Iterable
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
23
22
|
from logging import INFO, WARN
|
|
24
|
-
from typing import
|
|
23
|
+
from typing import cast
|
|
25
24
|
|
|
26
25
|
import numpy as np
|
|
27
26
|
|
|
@@ -104,15 +103,15 @@ class Bulyan(FedAvg):
|
|
|
104
103
|
weighted_by_key: str = "num-examples",
|
|
105
104
|
arrayrecord_key: str = "arrays",
|
|
106
105
|
configrecord_key: str = "config",
|
|
107
|
-
train_metrics_aggr_fn:
|
|
108
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
109
|
-
|
|
110
|
-
evaluate_metrics_aggr_fn:
|
|
111
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
112
|
-
|
|
113
|
-
selection_rule:
|
|
114
|
-
Callable[[list[RecordDict], int, int], list[RecordDict]]
|
|
115
|
-
|
|
106
|
+
train_metrics_aggr_fn: (
|
|
107
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
108
|
+
) = None,
|
|
109
|
+
evaluate_metrics_aggr_fn: (
|
|
110
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
111
|
+
) = None,
|
|
112
|
+
selection_rule: (
|
|
113
|
+
Callable[[list[RecordDict], int, int], list[RecordDict]] | None
|
|
114
|
+
) = None,
|
|
116
115
|
) -> None:
|
|
117
116
|
super().__init__(
|
|
118
117
|
fraction_train=fraction_train,
|
|
@@ -140,7 +139,7 @@ class Bulyan(FedAvg):
|
|
|
140
139
|
self,
|
|
141
140
|
server_round: int,
|
|
142
141
|
replies: Iterable[Message],
|
|
143
|
-
) -> tuple[
|
|
142
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
144
143
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
145
144
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
146
145
|
|
|
@@ -175,7 +174,9 @@ class Bulyan(FedAvg):
|
|
|
175
174
|
]
|
|
176
175
|
|
|
177
176
|
# Compute median
|
|
178
|
-
median_ndarrays = [
|
|
177
|
+
median_ndarrays = [
|
|
178
|
+
np.median(arr, axis=0) for arr in zip(*selected_ndarrays, strict=True)
|
|
179
|
+
]
|
|
179
180
|
|
|
180
181
|
# Aggregate the beta closest weights element-wise
|
|
181
182
|
aggregated_ndarrays = aggregate_n_closest_weights(
|
|
@@ -184,7 +185,7 @@ class Bulyan(FedAvg):
|
|
|
184
185
|
|
|
185
186
|
# Convert to ArrayRecord
|
|
186
187
|
arrays = ArrayRecord(
|
|
187
|
-
|
|
188
|
+
dict(zip(array_keys, map(Array, aggregated_ndarrays), strict=True))
|
|
188
189
|
)
|
|
189
190
|
|
|
190
191
|
# Aggregate MetricRecords
|
|
@@ -19,10 +19,8 @@ Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
|
|
|
19
19
|
|
|
20
20
|
import math
|
|
21
21
|
from abc import ABC
|
|
22
|
-
from collections import OrderedDict
|
|
23
22
|
from collections.abc import Iterable
|
|
24
23
|
from logging import INFO
|
|
25
|
-
from typing import Optional
|
|
26
24
|
|
|
27
25
|
import numpy as np
|
|
28
26
|
|
|
@@ -53,7 +51,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
|
|
|
53
51
|
initial_clipping_norm: float = 0.1,
|
|
54
52
|
target_clipped_quantile: float = 0.5,
|
|
55
53
|
clip_norm_lr: float = 0.2,
|
|
56
|
-
clipped_count_stddev:
|
|
54
|
+
clipped_count_stddev: float | None = None,
|
|
57
55
|
) -> None:
|
|
58
56
|
super().__init__()
|
|
59
57
|
|
|
@@ -96,7 +94,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
|
|
|
96
94
|
add_gaussian_noise_inplace(nds, stdv)
|
|
97
95
|
log(INFO, "aggregate_fit: central DP noise with %.4f stdev added", stdv)
|
|
98
96
|
return ArrayRecord(
|
|
99
|
-
|
|
97
|
+
{k: Array(v) for k, v in zip(aggregated.keys(), nds, strict=True)}
|
|
100
98
|
)
|
|
101
99
|
|
|
102
100
|
def _noisy_fraction(self, count: int, total: int) -> float:
|
|
@@ -115,7 +113,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
|
|
|
115
113
|
|
|
116
114
|
def aggregate_evaluate(
|
|
117
115
|
self, server_round: int, replies: Iterable[Message]
|
|
118
|
-
) ->
|
|
116
|
+
) -> MetricRecord | None:
|
|
119
117
|
"""Aggregate MetricRecords in the received Messages."""
|
|
120
118
|
return self.strategy.aggregate_evaluate(server_round, replies)
|
|
121
119
|
|
|
@@ -136,7 +134,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
|
|
|
136
134
|
initial_clipping_norm: float = 0.1,
|
|
137
135
|
target_clipped_quantile: float = 0.5,
|
|
138
136
|
clip_norm_lr: float = 0.2,
|
|
139
|
-
clipped_count_stddev:
|
|
137
|
+
clipped_count_stddev: float | None = None,
|
|
140
138
|
) -> None:
|
|
141
139
|
super().__init__(
|
|
142
140
|
strategy,
|
|
@@ -171,7 +169,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
|
|
|
171
169
|
|
|
172
170
|
def aggregate_train(
|
|
173
171
|
self, server_round: int, replies: Iterable[Message]
|
|
174
|
-
) -> tuple[
|
|
172
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
175
173
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
176
174
|
if not validate_replies(replies, self.num_sampled_clients):
|
|
177
175
|
return None, None
|
|
@@ -184,16 +182,19 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
|
|
|
184
182
|
for arr_name, record in reply.content.array_records.items():
|
|
185
183
|
reply_nd = record.to_numpy_ndarrays()
|
|
186
184
|
model_update = [
|
|
187
|
-
np.subtract(x, y)
|
|
185
|
+
np.subtract(x, y)
|
|
186
|
+
for (x, y) in zip(reply_nd, current_nd, strict=True)
|
|
188
187
|
]
|
|
189
188
|
norm_bit = adaptive_clip_inputs_inplace(
|
|
190
189
|
model_update, self.clipping_norm
|
|
191
190
|
)
|
|
192
191
|
clipped_indicator_count += int(norm_bit)
|
|
193
192
|
# reconstruct array using clipped contribution from current round
|
|
194
|
-
restored = [
|
|
193
|
+
restored = [
|
|
194
|
+
c + u for c, u in zip(current_nd, model_update, strict=True)
|
|
195
|
+
]
|
|
195
196
|
reply.content[arr_name] = ArrayRecord(
|
|
196
|
-
|
|
197
|
+
{k: Array(v) for k, v in zip(record.keys(), restored, strict=True)}
|
|
197
198
|
)
|
|
198
199
|
log(
|
|
199
200
|
INFO,
|
|
@@ -287,7 +288,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
|
|
|
287
288
|
|
|
288
289
|
def aggregate_train(
|
|
289
290
|
self, server_round: int, replies: Iterable[Message]
|
|
290
|
-
) -> tuple[
|
|
291
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
291
292
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
292
293
|
if not validate_replies(replies, self.num_sampled_clients):
|
|
293
294
|
return None, None
|
|
@@ -17,11 +17,10 @@
|
|
|
17
17
|
Papers: https://arxiv.org/abs/1712.07557, https://arxiv.org/abs/1710.06963
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
|
+
|
|
20
21
|
from abc import ABC
|
|
21
|
-
from collections import OrderedDict
|
|
22
22
|
from collections.abc import Iterable
|
|
23
23
|
from logging import INFO, WARNING
|
|
24
|
-
from typing import Optional
|
|
25
24
|
|
|
26
25
|
from flwr.common import Array, ArrayRecord, ConfigRecord, Message, MetricRecord, log
|
|
27
26
|
from flwr.common.differential_privacy import (
|
|
@@ -112,12 +111,12 @@ class DifferentialPrivacyFixedClippingBase(Strategy, ABC):
|
|
|
112
111
|
)
|
|
113
112
|
|
|
114
113
|
return ArrayRecord(
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
114
|
+
{
|
|
115
|
+
k: Array(v)
|
|
116
|
+
for k, v in zip(
|
|
117
|
+
aggregated_arrays.keys(), aggregated_ndarrays, strict=True
|
|
118
|
+
)
|
|
119
|
+
}
|
|
121
120
|
)
|
|
122
121
|
|
|
123
122
|
def configure_evaluate(
|
|
@@ -130,7 +129,7 @@ class DifferentialPrivacyFixedClippingBase(Strategy, ABC):
|
|
|
130
129
|
self,
|
|
131
130
|
server_round: int,
|
|
132
131
|
replies: Iterable[Message],
|
|
133
|
-
) ->
|
|
132
|
+
) -> MetricRecord | None:
|
|
134
133
|
"""Aggregate MetricRecords in the received Messages."""
|
|
135
134
|
return self.strategy.aggregate_evaluate(server_round, replies)
|
|
136
135
|
|
|
@@ -199,7 +198,7 @@ class DifferentialPrivacyServerSideFixedClipping(DifferentialPrivacyFixedClippin
|
|
|
199
198
|
self,
|
|
200
199
|
server_round: int,
|
|
201
200
|
replies: Iterable[Message],
|
|
202
|
-
) -> tuple[
|
|
201
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
203
202
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
204
203
|
if not validate_replies(replies, self.num_sampled_clients):
|
|
205
204
|
return None, None
|
|
@@ -217,9 +216,7 @@ class DifferentialPrivacyServerSideFixedClipping(DifferentialPrivacyFixedClippin
|
|
|
217
216
|
)
|
|
218
217
|
# Replace content while preserving keys
|
|
219
218
|
reply.content[arr_name] = ArrayRecord(
|
|
220
|
-
|
|
221
|
-
{k: Array(v) for k, v in zip(record.keys(), reply_ndarrays)}
|
|
222
|
-
)
|
|
219
|
+
dict(zip(record.keys(), map(Array, reply_ndarrays), strict=True))
|
|
223
220
|
)
|
|
224
221
|
log(
|
|
225
222
|
INFO,
|
|
@@ -302,7 +299,7 @@ class DifferentialPrivacyClientSideFixedClipping(DifferentialPrivacyFixedClippin
|
|
|
302
299
|
self,
|
|
303
300
|
server_round: int,
|
|
304
301
|
replies: Iterable[Message],
|
|
305
|
-
) -> tuple[
|
|
302
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
306
303
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
307
304
|
if not validate_replies(replies, self.num_sampled_clients):
|
|
308
305
|
return None, None
|
|
@@ -19,9 +19,8 @@ Adaptive Federated Optimization using Adagrad.
|
|
|
19
19
|
Paper: arxiv.org/abs/2003.00295
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
|
-
|
|
23
|
-
from collections.abc import Iterable
|
|
24
|
-
from typing import Callable, Optional
|
|
22
|
+
|
|
23
|
+
from collections.abc import Callable, Iterable
|
|
25
24
|
|
|
26
25
|
import numpy as np
|
|
27
26
|
|
|
@@ -90,12 +89,12 @@ class FedAdagrad(FedOpt):
|
|
|
90
89
|
weighted_by_key: str = "num-examples",
|
|
91
90
|
arrayrecord_key: str = "arrays",
|
|
92
91
|
configrecord_key: str = "config",
|
|
93
|
-
train_metrics_aggr_fn:
|
|
94
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
95
|
-
|
|
96
|
-
evaluate_metrics_aggr_fn:
|
|
97
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
98
|
-
|
|
92
|
+
train_metrics_aggr_fn: (
|
|
93
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
94
|
+
) = None,
|
|
95
|
+
evaluate_metrics_aggr_fn: (
|
|
96
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
97
|
+
) = None,
|
|
99
98
|
eta: float = 1e-1,
|
|
100
99
|
eta_l: float = 1e-1,
|
|
101
100
|
tau: float = 1e-3,
|
|
@@ -122,7 +121,7 @@ class FedAdagrad(FedOpt):
|
|
|
122
121
|
self,
|
|
123
122
|
server_round: int,
|
|
124
123
|
replies: Iterable[Message],
|
|
125
|
-
) -> tuple[
|
|
124
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
126
125
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
127
126
|
aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
|
|
128
127
|
server_round, replies
|
|
@@ -154,6 +153,6 @@ class FedAdagrad(FedOpt):
|
|
|
154
153
|
}
|
|
155
154
|
|
|
156
155
|
return (
|
|
157
|
-
ArrayRecord(
|
|
156
|
+
ArrayRecord({k: Array(v) for k, v in new_arrays.items()}),
|
|
158
157
|
aggregated_metrics,
|
|
159
158
|
)
|
|
@@ -19,9 +19,8 @@
|
|
|
19
19
|
Paper: arxiv.org/abs/2003.00295
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
|
-
|
|
23
|
-
from collections.abc import Iterable
|
|
24
|
-
from typing import Callable, Optional
|
|
22
|
+
|
|
23
|
+
from collections.abc import Callable, Iterable
|
|
25
24
|
|
|
26
25
|
import numpy as np
|
|
27
26
|
|
|
@@ -94,12 +93,12 @@ class FedAdam(FedOpt):
|
|
|
94
93
|
weighted_by_key: str = "num-examples",
|
|
95
94
|
arrayrecord_key: str = "arrays",
|
|
96
95
|
configrecord_key: str = "config",
|
|
97
|
-
train_metrics_aggr_fn:
|
|
98
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
99
|
-
|
|
100
|
-
evaluate_metrics_aggr_fn:
|
|
101
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
102
|
-
|
|
96
|
+
train_metrics_aggr_fn: (
|
|
97
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
98
|
+
) = None,
|
|
99
|
+
evaluate_metrics_aggr_fn: (
|
|
100
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
101
|
+
) = None,
|
|
103
102
|
eta: float = 1e-1,
|
|
104
103
|
eta_l: float = 1e-1,
|
|
105
104
|
beta_1: float = 0.9,
|
|
@@ -128,7 +127,7 @@ class FedAdam(FedOpt):
|
|
|
128
127
|
self,
|
|
129
128
|
server_round: int,
|
|
130
129
|
replies: Iterable[Message],
|
|
131
|
-
) -> tuple[
|
|
130
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
132
131
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
133
132
|
aggregated_arrayrecord, aggregated_metrics = super().aggregate_train(
|
|
134
133
|
server_round, replies
|
|
@@ -173,6 +172,6 @@ class FedAdam(FedOpt):
|
|
|
173
172
|
}
|
|
174
173
|
|
|
175
174
|
return (
|
|
176
|
-
ArrayRecord(
|
|
175
|
+
ArrayRecord({k: Array(v) for k, v in new_arrays.items()}),
|
|
177
176
|
aggregated_metrics,
|
|
178
177
|
)
|
|
@@ -15,9 +15,8 @@
|
|
|
15
15
|
"""Flower message-based FedAvg strategy."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from collections.abc import Iterable
|
|
18
|
+
from collections.abc import Callable, Iterable
|
|
19
19
|
from logging import INFO, WARNING
|
|
20
|
-
from typing import Callable, Optional
|
|
21
20
|
|
|
22
21
|
from flwr.common import (
|
|
23
22
|
ArrayRecord,
|
|
@@ -91,12 +90,12 @@ class FedAvg(Strategy):
|
|
|
91
90
|
weighted_by_key: str = "num-examples",
|
|
92
91
|
arrayrecord_key: str = "arrays",
|
|
93
92
|
configrecord_key: str = "config",
|
|
94
|
-
train_metrics_aggr_fn:
|
|
95
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
96
|
-
|
|
97
|
-
evaluate_metrics_aggr_fn:
|
|
98
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
99
|
-
|
|
93
|
+
train_metrics_aggr_fn: (
|
|
94
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
95
|
+
) = None,
|
|
96
|
+
evaluate_metrics_aggr_fn: (
|
|
97
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
98
|
+
) = None,
|
|
100
99
|
) -> None:
|
|
101
100
|
self.fraction_train = fraction_train
|
|
102
101
|
self.fraction_evaluate = fraction_evaluate
|
|
@@ -251,7 +250,7 @@ class FedAvg(Strategy):
|
|
|
251
250
|
self,
|
|
252
251
|
server_round: int,
|
|
253
252
|
replies: Iterable[Message],
|
|
254
|
-
) -> tuple[
|
|
253
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
255
254
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
256
255
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
257
256
|
|
|
@@ -304,7 +303,7 @@ class FedAvg(Strategy):
|
|
|
304
303
|
self,
|
|
305
304
|
server_round: int,
|
|
306
305
|
replies: Iterable[Message],
|
|
307
|
-
) ->
|
|
306
|
+
) -> MetricRecord | None:
|
|
308
307
|
"""Aggregate MetricRecords in the received Messages."""
|
|
309
308
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=False)
|
|
310
309
|
|
|
@@ -18,10 +18,8 @@ Paper: arxiv.org/pdf/1909.06335.pdf
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from collections import
|
|
22
|
-
from collections.abc import Iterable
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
23
22
|
from logging import INFO
|
|
24
|
-
from typing import Callable, Optional
|
|
25
23
|
|
|
26
24
|
from flwr.common import (
|
|
27
25
|
Array,
|
|
@@ -93,12 +91,12 @@ class FedAvgM(FedAvg):
|
|
|
93
91
|
weighted_by_key: str = "num-examples",
|
|
94
92
|
arrayrecord_key: str = "arrays",
|
|
95
93
|
configrecord_key: str = "config",
|
|
96
|
-
train_metrics_aggr_fn:
|
|
97
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
98
|
-
|
|
99
|
-
evaluate_metrics_aggr_fn:
|
|
100
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
101
|
-
|
|
94
|
+
train_metrics_aggr_fn: (
|
|
95
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
96
|
+
) = None,
|
|
97
|
+
evaluate_metrics_aggr_fn: (
|
|
98
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
99
|
+
) = None,
|
|
102
100
|
server_learning_rate: float = 1.0,
|
|
103
101
|
server_momentum: float = 0.0,
|
|
104
102
|
) -> None:
|
|
@@ -119,8 +117,8 @@ class FedAvgM(FedAvg):
|
|
|
119
117
|
self.server_opt: bool = (self.server_momentum != 0.0) or (
|
|
120
118
|
self.server_learning_rate != 1.0
|
|
121
119
|
)
|
|
122
|
-
self.current_arrays:
|
|
123
|
-
self.momentum_vector:
|
|
120
|
+
self.current_arrays: ArrayRecord | None = None
|
|
121
|
+
self.momentum_vector: NDArrays | None = None
|
|
124
122
|
|
|
125
123
|
def summary(self) -> None:
|
|
126
124
|
"""Log summary configuration of the strategy."""
|
|
@@ -143,7 +141,7 @@ class FedAvgM(FedAvg):
|
|
|
143
141
|
self,
|
|
144
142
|
server_round: int,
|
|
145
143
|
replies: Iterable[Message],
|
|
146
|
-
) -> tuple[
|
|
144
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
147
145
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
148
146
|
# Call FedAvg aggregate_train to perform validation and aggregation
|
|
149
147
|
aggregated_arrays, aggregated_metrics = super().aggregate_train(
|
|
@@ -168,7 +166,8 @@ class FedAvgM(FedAvg):
|
|
|
168
166
|
|
|
169
167
|
# Remember that updates are the opposite of gradients
|
|
170
168
|
pseudo_gradient = [
|
|
171
|
-
old - new
|
|
169
|
+
old - new
|
|
170
|
+
for new, old in zip(aggregated_ndarrays, ndarrays, strict=True)
|
|
172
171
|
]
|
|
173
172
|
if self.server_momentum > 0.0:
|
|
174
173
|
if self.momentum_vector is None:
|
|
@@ -177,7 +176,9 @@ class FedAvgM(FedAvg):
|
|
|
177
176
|
else:
|
|
178
177
|
self.momentum_vector = [
|
|
179
178
|
self.server_momentum * mv + pg
|
|
180
|
-
for mv, pg in zip(
|
|
179
|
+
for mv, pg in zip(
|
|
180
|
+
self.momentum_vector, pseudo_gradient, strict=True
|
|
181
|
+
)
|
|
181
182
|
]
|
|
182
183
|
|
|
183
184
|
# No nesterov for now
|
|
@@ -186,10 +187,10 @@ class FedAvgM(FedAvg):
|
|
|
186
187
|
# SGD and convert back to ArrayRecord
|
|
187
188
|
updated_array_list = [
|
|
188
189
|
Array(old - self.server_learning_rate * pg)
|
|
189
|
-
for old, pg in zip(ndarrays, pseudo_gradient)
|
|
190
|
+
for old, pg in zip(ndarrays, pseudo_gradient, strict=True)
|
|
190
191
|
]
|
|
191
192
|
aggregated_arrays = ArrayRecord(
|
|
192
|
-
|
|
193
|
+
dict(zip(array_keys, updated_array_list, strict=True))
|
|
193
194
|
)
|
|
194
195
|
|
|
195
196
|
# Update current weights
|
|
@@ -19,7 +19,7 @@ Paper: arxiv.org/pdf/1803.01498v1.pdf
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
from collections.abc import Iterable
|
|
22
|
-
from typing import
|
|
22
|
+
from typing import cast
|
|
23
23
|
|
|
24
24
|
import numpy as np
|
|
25
25
|
|
|
@@ -72,7 +72,7 @@ class FedMedian(FedAvg):
|
|
|
72
72
|
self,
|
|
73
73
|
server_round: int,
|
|
74
74
|
replies: Iterable[Message],
|
|
75
|
-
) -> tuple[
|
|
75
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
76
76
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
77
77
|
# Call FedAvg aggregate_train to perform validation and aggregation
|
|
78
78
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
@@ -17,9 +17,8 @@
|
|
|
17
17
|
Paper: arxiv.org/abs/2003.00295
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
|
-
from collections.abc import Iterable
|
|
20
|
+
from collections.abc import Callable, Iterable
|
|
21
21
|
from logging import INFO
|
|
22
|
-
from typing import Callable, Optional
|
|
23
22
|
|
|
24
23
|
import numpy as np
|
|
25
24
|
|
|
@@ -101,12 +100,12 @@ class FedOpt(FedAvg):
|
|
|
101
100
|
weighted_by_key: str = "num-examples",
|
|
102
101
|
arrayrecord_key: str = "arrays",
|
|
103
102
|
configrecord_key: str = "config",
|
|
104
|
-
train_metrics_aggr_fn:
|
|
105
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
106
|
-
|
|
107
|
-
evaluate_metrics_aggr_fn:
|
|
108
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
109
|
-
|
|
103
|
+
train_metrics_aggr_fn: (
|
|
104
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
105
|
+
) = None,
|
|
106
|
+
evaluate_metrics_aggr_fn: (
|
|
107
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
108
|
+
) = None,
|
|
110
109
|
eta: float = 1e-1,
|
|
111
110
|
eta_l: float = 1e-1,
|
|
112
111
|
beta_1: float = 0.0,
|
|
@@ -125,14 +124,14 @@ class FedOpt(FedAvg):
|
|
|
125
124
|
train_metrics_aggr_fn=train_metrics_aggr_fn,
|
|
126
125
|
evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
|
|
127
126
|
)
|
|
128
|
-
self.current_arrays:
|
|
127
|
+
self.current_arrays: dict[str, NDArray] | None = None
|
|
129
128
|
self.eta = eta
|
|
130
129
|
self.eta_l = eta_l
|
|
131
130
|
self.tau = tau
|
|
132
131
|
self.beta_1 = beta_1
|
|
133
132
|
self.beta_2 = beta_2
|
|
134
|
-
self.m_t:
|
|
135
|
-
self.v_t:
|
|
133
|
+
self.m_t: dict[str, NDArray] | None = None
|
|
134
|
+
self.v_t: dict[str, NDArray] | None = None
|
|
136
135
|
|
|
137
136
|
def summary(self) -> None:
|
|
138
137
|
"""Log summary configuration of the strategy."""
|
|
@@ -18,9 +18,8 @@ Paper: arxiv.org/abs/1812.06127
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from collections.abc import Iterable
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
22
22
|
from logging import INFO, WARN
|
|
23
|
-
from typing import Callable, Optional
|
|
24
23
|
|
|
25
24
|
from flwr.common import (
|
|
26
25
|
ArrayRecord,
|
|
@@ -130,12 +129,12 @@ class FedProx(FedAvg):
|
|
|
130
129
|
weighted_by_key: str = "num-examples",
|
|
131
130
|
arrayrecord_key: str = "arrays",
|
|
132
131
|
configrecord_key: str = "config",
|
|
133
|
-
train_metrics_aggr_fn:
|
|
134
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
135
|
-
|
|
136
|
-
evaluate_metrics_aggr_fn:
|
|
137
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
138
|
-
|
|
132
|
+
train_metrics_aggr_fn: (
|
|
133
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
134
|
+
) = None,
|
|
135
|
+
evaluate_metrics_aggr_fn: (
|
|
136
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
137
|
+
) = None,
|
|
139
138
|
proximal_mu: float = 0.0,
|
|
140
139
|
) -> None:
|
|
141
140
|
super().__init__(
|
|
@@ -18,9 +18,9 @@ Paper: arxiv.org/abs/1803.01498
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from collections.abc import Iterable
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
22
22
|
from logging import INFO
|
|
23
|
-
from typing import
|
|
23
|
+
from typing import cast
|
|
24
24
|
|
|
25
25
|
import numpy as np
|
|
26
26
|
|
|
@@ -83,12 +83,12 @@ class FedTrimmedAvg(FedAvg):
|
|
|
83
83
|
weighted_by_key: str = "num-examples",
|
|
84
84
|
arrayrecord_key: str = "arrays",
|
|
85
85
|
configrecord_key: str = "config",
|
|
86
|
-
train_metrics_aggr_fn:
|
|
87
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
88
|
-
|
|
89
|
-
evaluate_metrics_aggr_fn:
|
|
90
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
91
|
-
|
|
86
|
+
train_metrics_aggr_fn: (
|
|
87
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
88
|
+
) = None,
|
|
89
|
+
evaluate_metrics_aggr_fn: (
|
|
90
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
91
|
+
) = None,
|
|
92
92
|
beta: float = 0.2,
|
|
93
93
|
) -> None:
|
|
94
94
|
super().__init__(
|
|
@@ -115,7 +115,7 @@ class FedTrimmedAvg(FedAvg):
|
|
|
115
115
|
self,
|
|
116
116
|
server_round: int,
|
|
117
117
|
replies: Iterable[Message],
|
|
118
|
-
) -> tuple[
|
|
118
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
119
119
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
120
120
|
# Call FedAvg aggregate_train to perform validation and aggregation
|
|
121
121
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower message-based FedXgbBagging strategy."""
|
|
16
16
|
from collections.abc import Iterable
|
|
17
|
-
from typing import
|
|
17
|
+
from typing import cast
|
|
18
18
|
|
|
19
19
|
import numpy as np
|
|
20
20
|
|
|
@@ -65,7 +65,7 @@ class FedXgbBagging(FedAvg):
|
|
|
65
65
|
average using the provided weight factor key.
|
|
66
66
|
"""
|
|
67
67
|
|
|
68
|
-
current_bst:
|
|
68
|
+
current_bst: bytes | None = None
|
|
69
69
|
|
|
70
70
|
def _ensure_single_array(self, arrays: ArrayRecord) -> None:
|
|
71
71
|
"""Check that ensures there's only one Array in the ArrayRecord."""
|
|
@@ -89,7 +89,7 @@ class FedXgbBagging(FedAvg):
|
|
|
89
89
|
self,
|
|
90
90
|
server_round: int,
|
|
91
91
|
replies: Iterable[Message],
|
|
92
|
-
) -> tuple[
|
|
92
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
93
93
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
94
94
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
95
95
|
|
|
@@ -15,9 +15,9 @@
|
|
|
15
15
|
"""Flower message-based FedXgbCyclic strategy."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from collections.abc import Iterable
|
|
18
|
+
from collections.abc import Callable, Iterable
|
|
19
19
|
from logging import INFO
|
|
20
|
-
from typing import
|
|
20
|
+
from typing import cast
|
|
21
21
|
|
|
22
22
|
from flwr.common import (
|
|
23
23
|
ArrayRecord,
|
|
@@ -78,12 +78,12 @@ class FedXgbCyclic(FedAvg):
|
|
|
78
78
|
weighted_by_key: str = "num-examples",
|
|
79
79
|
arrayrecord_key: str = "arrays",
|
|
80
80
|
configrecord_key: str = "config",
|
|
81
|
-
train_metrics_aggr_fn:
|
|
82
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
83
|
-
|
|
84
|
-
evaluate_metrics_aggr_fn:
|
|
85
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
86
|
-
|
|
81
|
+
train_metrics_aggr_fn: (
|
|
82
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
83
|
+
) = None,
|
|
84
|
+
evaluate_metrics_aggr_fn: (
|
|
85
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
86
|
+
) = None,
|
|
87
87
|
) -> None:
|
|
88
88
|
super().__init__(
|
|
89
89
|
fraction_train=fraction_train,
|
|
@@ -184,7 +184,7 @@ class FedXgbCyclic(FedAvg):
|
|
|
184
184
|
self,
|
|
185
185
|
server_round: int,
|
|
186
186
|
replies: Iterable[Message],
|
|
187
|
-
) -> tuple[
|
|
187
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
188
188
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
189
189
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
190
190
|
|