flwr 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +19 -0
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/auth_plugin.py +4 -5
- flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
- flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
- flwr/cli/build.py +60 -18
- flwr/cli/cli_account_auth_interceptor.py +24 -7
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +52 -9
- flwr/cli/login/login.py +7 -4
- flwr/cli/ls.py +170 -130
- flwr/cli/new/new.py +33 -50
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +10 -5
- flwr/cli/run/run.py +77 -30
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +25 -7
- flwr/cli/supernode/ls.py +16 -8
- flwr/cli/supernode/register.py +9 -4
- flwr/cli/supernode/unregister.py +5 -3
- flwr/cli/utils.py +376 -16
- flwr/client/__init__.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +6 -7
- flwr/client/grpc_rere_client/connection.py +10 -11
- flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
- flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +12 -14
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/clientapp/utils.py +3 -3
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +5 -2
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +19 -0
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +38 -21
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +18 -30
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +5 -4
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +52 -37
- flwr/compat/client/app.py +38 -37
- flwr/compat/client/grpc_client/connection.py +11 -11
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +71 -52
- flwr/proto/control_pb2.pyi +277 -111
- flwr/proto/control_pb2_grpc.py +249 -40
- flwr/proto/control_pb2_grpc.pyi +185 -52
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +14 -4
- flwr/proto/fab_pb2.pyi +59 -31
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +14 -4
- flwr/proto/fleet_pb2.pyi +137 -61
- flwr/proto/fleet_pb2_grpc.py +189 -48
- flwr/proto/fleet_pb2_grpc.pyi +175 -61
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +50 -25
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +38 -17
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +15 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
- flwr/server/superlink/linkstate/linkstate.py +59 -43
- flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
- flwr/server/superlink/linkstate/utils.py +6 -6
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +10 -11
- flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
- flwr/simulation/run_simulation.py +43 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +31 -1
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -2
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +1 -2
- flwr/supercore/object_store/sqlite_object_store.py +8 -7
- flwr/supercore/primitives/asymmetric.py +1 -1
- flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
- flwr/supercore/sqlite_mixin.py +37 -34
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/superlink/auth_plugin/auth_plugin.py +6 -9
- flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +5 -6
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +102 -18
- flwr/supernode/cli/flower_supernode.py +58 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +41 -22
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +158 -42
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.23.0.dist-info/RECORD +0 -439
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from os import urandom
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
|
|
22
21
|
from flwr.common.constant import (
|
|
@@ -51,7 +50,8 @@ VALID_RUN_SUB_STATUSES = {
|
|
|
51
50
|
}
|
|
52
51
|
MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
53
52
|
"Error: Message Unavailable - The requested message could not be found in the "
|
|
54
|
-
"database. It may have expired due to its TTL
|
|
53
|
+
"database. It may have expired due to its TTL, been deleted because the "
|
|
54
|
+
"destination SuperNode was removed from the federation, or never existed."
|
|
55
55
|
)
|
|
56
56
|
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
57
57
|
"Error: Reply Message Unavailable - The reply message has expired."
|
|
@@ -63,7 +63,7 @@ NODE_UNAVAILABLE_ERROR_REASON = (
|
|
|
63
63
|
|
|
64
64
|
|
|
65
65
|
def generate_rand_int_from_bytes(
|
|
66
|
-
num_bytes: int, exclude:
|
|
66
|
+
num_bytes: int, exclude: list[int] | None = None
|
|
67
67
|
) -> int:
|
|
68
68
|
"""Generate a random unsigned integer from `num_bytes` bytes.
|
|
69
69
|
|
|
@@ -257,7 +257,7 @@ def message_ttl_has_expired(message_metadata: Metadata, current_time: float) ->
|
|
|
257
257
|
def verify_message_ids(
|
|
258
258
|
inquired_message_ids: set[str],
|
|
259
259
|
found_message_ins_dict: dict[str, Message],
|
|
260
|
-
current_time:
|
|
260
|
+
current_time: float | None = None,
|
|
261
261
|
update_set: bool = True,
|
|
262
262
|
) -> dict[str, Message]:
|
|
263
263
|
"""Verify found Messages and generate error Messages for invalid ones.
|
|
@@ -300,7 +300,7 @@ def verify_found_message_replies(
|
|
|
300
300
|
inquired_message_ids: set[str],
|
|
301
301
|
found_message_ins_dict: dict[str, Message],
|
|
302
302
|
found_message_res_list: list[Message],
|
|
303
|
-
current_time:
|
|
303
|
+
current_time: float | None = None,
|
|
304
304
|
update_set: bool = True,
|
|
305
305
|
) -> dict[str, Message]:
|
|
306
306
|
"""Verify found Message replies and generate error Message for invalid ones.
|
|
@@ -345,7 +345,7 @@ def check_node_availability_for_in_message(
|
|
|
345
345
|
inquired_in_message_ids: set[str],
|
|
346
346
|
found_in_message_dict: dict[str, Message],
|
|
347
347
|
node_id_to_online_until: dict[int, float],
|
|
348
|
-
current_time:
|
|
348
|
+
current_time: float | None = None,
|
|
349
349
|
update_set: bool = True,
|
|
350
350
|
) -> dict[str, Message]:
|
|
351
351
|
"""Check node availability for given Message and generate error reply Message if
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import INFO
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
import grpc
|
|
22
21
|
|
|
@@ -38,7 +37,7 @@ def run_serverappio_api_grpc(
|
|
|
38
37
|
state_factory: LinkStateFactory,
|
|
39
38
|
ffs_factory: FfsFactory,
|
|
40
39
|
objectstore_factory: ObjectStoreFactory,
|
|
41
|
-
certificates:
|
|
40
|
+
certificates: tuple[bytes, bytes, bytes] | None,
|
|
42
41
|
) -> grpc.Server:
|
|
43
42
|
"""Run ServerAppIo API (gRPC, request-response)."""
|
|
44
43
|
# Create ServerAppIo API gRPC server
|
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
19
|
from logging import DEBUG, ERROR, INFO
|
|
20
|
-
from typing import Optional
|
|
21
20
|
|
|
22
21
|
import grpc
|
|
23
22
|
|
|
@@ -91,7 +90,6 @@ from flwr.server.superlink.utils import abort_if
|
|
|
91
90
|
from flwr.server.utils.validator import validate_message
|
|
92
91
|
from flwr.supercore.ffs import Ffs, FfsFactory
|
|
93
92
|
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory
|
|
94
|
-
from flwr.supercore.object_store.utils import store_mapping_and_register_objects
|
|
95
93
|
|
|
96
94
|
|
|
97
95
|
class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
@@ -141,6 +139,13 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
141
139
|
# Attempt to create a token for the provided run ID
|
|
142
140
|
token = state.create_token(request.run_id)
|
|
143
141
|
|
|
142
|
+
# Transition the run to STARTING if token creation was successful
|
|
143
|
+
if token:
|
|
144
|
+
state.update_run_status(
|
|
145
|
+
run_id=request.run_id,
|
|
146
|
+
new_status=RunStatus(Status.STARTING, "", ""),
|
|
147
|
+
)
|
|
148
|
+
|
|
144
149
|
# Return the token
|
|
145
150
|
return RequestTokenResponse(token=token or "")
|
|
146
151
|
|
|
@@ -192,8 +197,11 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
192
197
|
request_name="PushMessages",
|
|
193
198
|
detail="`messages_list` must not be empty",
|
|
194
199
|
)
|
|
195
|
-
message_ids: list[
|
|
196
|
-
|
|
200
|
+
message_ids: list[str | None] = []
|
|
201
|
+
objects_to_push: set[str] = set()
|
|
202
|
+
for message_proto, object_tree in zip(
|
|
203
|
+
request.messages_list, request.message_object_trees, strict=True
|
|
204
|
+
):
|
|
197
205
|
message = message_from_proto(message_proto=message_proto)
|
|
198
206
|
validation_errors = validate_message(message, is_reply_message=False)
|
|
199
207
|
_raise_if(
|
|
@@ -206,13 +214,12 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
206
214
|
request_name="PushMessages",
|
|
207
215
|
detail="`Message.metadata` has mismatched `run_id`",
|
|
208
216
|
)
|
|
209
|
-
# Store
|
|
210
|
-
|
|
217
|
+
# Store objects
|
|
218
|
+
objects_to_push |= set(store.preregister(request.run_id, object_tree))
|
|
219
|
+
# Store message
|
|
220
|
+
message_id: str | None = state.store_message_ins(message=message)
|
|
211
221
|
message_ids.append(message_id)
|
|
212
222
|
|
|
213
|
-
# Store Message object to descendants mapping and preregister objects
|
|
214
|
-
objects_to_push = store_mapping_and_register_objects(store, request=request)
|
|
215
|
-
|
|
216
223
|
return PushAppMessagesResponse(
|
|
217
224
|
message_ids=[
|
|
218
225
|
str(message_id) if message_id else "" for message_id in message_ids
|
|
@@ -345,8 +352,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
345
352
|
if result := ffs.get(run.fab_hash):
|
|
346
353
|
fab = Fab(run.fab_hash, result[0], result[1])
|
|
347
354
|
if run and fab and serverapp_ctxt:
|
|
348
|
-
# Update run status to
|
|
349
|
-
if state.update_run_status(run_id, RunStatus(Status.
|
|
355
|
+
# Update run status to RUNNING
|
|
356
|
+
if state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")):
|
|
350
357
|
log(INFO, "Starting run %d", run_id)
|
|
351
358
|
return PullAppInputsResponse(
|
|
352
359
|
context=context_to_proto(serverapp_ctxt),
|
|
@@ -355,8 +362,12 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
355
362
|
)
|
|
356
363
|
|
|
357
364
|
# Raise an exception if the Run or Fab is not found,
|
|
358
|
-
# or if the status cannot be updated to
|
|
359
|
-
|
|
365
|
+
# or if the status cannot be updated to RUNNING
|
|
366
|
+
context.abort(
|
|
367
|
+
grpc.StatusCode.FAILED_PRECONDITION,
|
|
368
|
+
f"Failed to start run {run_id}",
|
|
369
|
+
)
|
|
370
|
+
raise RuntimeError("Unreachable code") # for mypy
|
|
360
371
|
|
|
361
372
|
def PushAppOutputs(
|
|
362
373
|
self, request: PushAppOutputsRequest, context: grpc.ServicerContext
|
|
@@ -441,20 +452,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
441
452
|
def SendAppHeartbeat(
|
|
442
453
|
self, request: SendAppHeartbeatRequest, context: grpc.ServicerContext
|
|
443
454
|
) -> SendAppHeartbeatResponse:
|
|
444
|
-
"""Handle a heartbeat from
|
|
455
|
+
"""Handle a heartbeat from an app process."""
|
|
445
456
|
log(DEBUG, "ServerAppIoServicer.SendAppHeartbeat")
|
|
446
457
|
|
|
447
458
|
# Init state
|
|
448
459
|
state = self.state_factory.state()
|
|
449
460
|
|
|
450
461
|
# Acknowledge the heartbeat
|
|
451
|
-
|
|
452
|
-
# starting or running status.
|
|
453
|
-
success = state.acknowledge_app_heartbeat(
|
|
454
|
-
run_id=request.run_id,
|
|
455
|
-
heartbeat_interval=request.heartbeat_interval,
|
|
456
|
-
)
|
|
457
|
-
|
|
462
|
+
success = state.acknowledge_app_heartbeat(request.token)
|
|
458
463
|
return SendAppHeartbeatResponse(success=success)
|
|
459
464
|
|
|
460
465
|
def PushObject(
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import INFO
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
import grpc
|
|
22
21
|
|
|
@@ -36,7 +35,7 @@ def run_simulationio_api_grpc(
|
|
|
36
35
|
address: str,
|
|
37
36
|
state_factory: LinkStateFactory,
|
|
38
37
|
ffs_factory: FfsFactory,
|
|
39
|
-
certificates:
|
|
38
|
+
certificates: tuple[bytes, bytes, bytes] | None,
|
|
40
39
|
) -> grpc.Server:
|
|
41
40
|
"""Run SimulationIo API (gRPC, request-response)."""
|
|
42
41
|
# Create SimulationIo API gRPC server
|
|
@@ -110,6 +110,13 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
110
110
|
# Attempt to create a token for the provided run ID
|
|
111
111
|
token = state.create_token(request.run_id)
|
|
112
112
|
|
|
113
|
+
# Transition the run to STARTING if token creation was successful
|
|
114
|
+
if token:
|
|
115
|
+
state.update_run_status(
|
|
116
|
+
run_id=request.run_id,
|
|
117
|
+
new_status=RunStatus(Status.STARTING, "", ""),
|
|
118
|
+
)
|
|
119
|
+
|
|
113
120
|
# Return the token
|
|
114
121
|
return RequestTokenResponse(token=token or "")
|
|
115
122
|
|
|
@@ -152,8 +159,8 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
152
159
|
if result := ffs.get(run.fab_hash):
|
|
153
160
|
fab = Fab(run.fab_hash, result[0], result[1])
|
|
154
161
|
if run and fab and serverapp_ctxt:
|
|
155
|
-
# Update run status to
|
|
156
|
-
if state.update_run_status(run_id, RunStatus(Status.
|
|
162
|
+
# Update run status to RUNNING
|
|
163
|
+
if state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")):
|
|
157
164
|
log(INFO, "Starting run %d", run_id)
|
|
158
165
|
return PullAppInputsResponse(
|
|
159
166
|
context=context_to_proto(serverapp_ctxt),
|
|
@@ -162,8 +169,12 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
162
169
|
)
|
|
163
170
|
|
|
164
171
|
# Raise an exception if the Run or Fab is not found,
|
|
165
|
-
# or if the status cannot be updated to
|
|
166
|
-
|
|
172
|
+
# or if the status cannot be updated to RUNNING
|
|
173
|
+
context.abort(
|
|
174
|
+
grpc.StatusCode.FAILED_PRECONDITION,
|
|
175
|
+
f"Failed to start run {run_id}",
|
|
176
|
+
)
|
|
177
|
+
raise RuntimeError("Unreachable code") # for mypy
|
|
167
178
|
|
|
168
179
|
def PushAppOutputs(
|
|
169
180
|
self, request: PushAppOutputsRequest, context: ServicerContext
|
|
@@ -257,20 +268,14 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
257
268
|
def SendAppHeartbeat(
|
|
258
269
|
self, request: SendAppHeartbeatRequest, context: grpc.ServicerContext
|
|
259
270
|
) -> SendAppHeartbeatResponse:
|
|
260
|
-
"""Handle a heartbeat from
|
|
261
|
-
log(DEBUG, "
|
|
271
|
+
"""Handle a heartbeat from an app process."""
|
|
272
|
+
log(DEBUG, "SimulationIoServicer.SendAppHeartbeat")
|
|
262
273
|
|
|
263
274
|
# Init state
|
|
264
275
|
state = self.state_factory.state()
|
|
265
276
|
|
|
266
277
|
# Acknowledge the heartbeat
|
|
267
|
-
|
|
268
|
-
# starting or running status.
|
|
269
|
-
success = state.acknowledge_app_heartbeat(
|
|
270
|
-
run_id=request.run_id,
|
|
271
|
-
heartbeat_interval=request.heartbeat_interval,
|
|
272
|
-
)
|
|
273
|
-
|
|
278
|
+
success = state.acknowledge_app_heartbeat(request.token)
|
|
274
279
|
return SendAppHeartbeatResponse(success=success)
|
|
275
280
|
|
|
276
281
|
def _verify_token(self, token: str, context: grpc.ServicerContext) -> int:
|
flwr/server/superlink/utils.py
CHANGED
|
@@ -15,8 +15,6 @@
|
|
|
15
15
|
"""SuperLink utilities."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import Optional, Union
|
|
19
|
-
|
|
20
18
|
import grpc
|
|
21
19
|
|
|
22
20
|
from flwr.common.constant import Status, SubStatus
|
|
@@ -36,8 +34,8 @@ def check_abort(
|
|
|
36
34
|
run_id: int,
|
|
37
35
|
abort_status_list: list[str],
|
|
38
36
|
state: LinkState,
|
|
39
|
-
store:
|
|
40
|
-
) ->
|
|
37
|
+
store: ObjectStore | None = None,
|
|
38
|
+
) -> str | None:
|
|
41
39
|
"""Check if the status of the provided `run_id` is in `abort_status_list`."""
|
|
42
40
|
run_status: RunStatus = state.get_run_status({run_id})[run_id]
|
|
43
41
|
|
|
@@ -54,7 +52,7 @@ def check_abort(
|
|
|
54
52
|
return None
|
|
55
53
|
|
|
56
54
|
|
|
57
|
-
def abort_grpc_context(msg:
|
|
55
|
+
def abort_grpc_context(msg: str | None, context: grpc.ServicerContext) -> None:
|
|
58
56
|
"""Abort context with statuscode PERMISSION_DENIED if `msg` is not None."""
|
|
59
57
|
if msg is not None:
|
|
60
58
|
context.abort(grpc.StatusCode.PERMISSION_DENIED, msg)
|
|
@@ -64,7 +62,7 @@ def abort_if(
|
|
|
64
62
|
run_id: int,
|
|
65
63
|
abort_status_list: list[str],
|
|
66
64
|
state: LinkState,
|
|
67
|
-
store:
|
|
65
|
+
store: ObjectStore | None,
|
|
68
66
|
context: grpc.ServicerContext,
|
|
69
67
|
) -> None:
|
|
70
68
|
"""Abort context if status of the provided `run_id` is in `abort_status_list`."""
|
flwr/server/typing.py
CHANGED
flwr/server/utils/tensorboard.py
CHANGED
|
@@ -16,20 +16,16 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
|
+
from collections.abc import Callable
|
|
19
20
|
from datetime import datetime
|
|
20
21
|
from logging import WARN
|
|
21
|
-
from typing import
|
|
22
|
+
from typing import cast
|
|
22
23
|
|
|
23
24
|
from flwr.common import EvaluateRes, Scalar
|
|
24
25
|
from flwr.common.logger import log
|
|
25
26
|
from flwr.server.client_proxy import ClientProxy
|
|
26
27
|
from flwr.server.strategy import Strategy
|
|
27
28
|
|
|
28
|
-
try:
|
|
29
|
-
import tensorflow as TF
|
|
30
|
-
except ModuleNotFoundError:
|
|
31
|
-
TF = None
|
|
32
|
-
|
|
33
29
|
MISSING_EXTRA_TF = """
|
|
34
30
|
Extra dependency required for using tensorboard are missing.
|
|
35
31
|
The program will continue without tensorboard.
|
|
@@ -59,6 +55,17 @@ def tensorboard(logdir: str) -> Callable[[Strategy], Strategy]:
|
|
|
59
55
|
# Variant 2
|
|
60
56
|
strategy = tensorboard(logdir=LOGDIR)(FedAvg)()
|
|
61
57
|
"""
|
|
58
|
+
log(
|
|
59
|
+
WARN,
|
|
60
|
+
"The `tensorboard` function is deprecated and will be removed "
|
|
61
|
+
"in a future release.",
|
|
62
|
+
)
|
|
63
|
+
# Lazy import of TensorFlow to avoid slow import times
|
|
64
|
+
try:
|
|
65
|
+
import tensorflow as TF # pylint: disable=import-outside-toplevel
|
|
66
|
+
except ModuleNotFoundError:
|
|
67
|
+
TF = None # pylint: disable=invalid-name
|
|
68
|
+
|
|
62
69
|
print(
|
|
63
70
|
"\n\t\033[32mStart TensorBoard with the following parameters"
|
|
64
71
|
f"\n\t$ tensorboard --logdir {logdir}\033[39m\n"
|
|
@@ -93,8 +100,8 @@ def tensorboard(logdir: str) -> Callable[[Strategy], Strategy]:
|
|
|
93
100
|
self,
|
|
94
101
|
server_round: int,
|
|
95
102
|
results: list[tuple[ClientProxy, EvaluateRes]],
|
|
96
|
-
failures: list[
|
|
97
|
-
) -> tuple[
|
|
103
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
|
|
104
|
+
) -> tuple[float | None, dict[str, Scalar]]:
|
|
98
105
|
"""Hooks into aggregate_evaluate for TensorBoard logging purpose."""
|
|
99
106
|
# Execute decorated function and extract results for logging
|
|
100
107
|
# They will be returned at the end of this function but also
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import io
|
|
19
19
|
import timeit
|
|
20
20
|
from logging import INFO, WARN
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import cast
|
|
22
22
|
|
|
23
23
|
import flwr.common.recorddict_compat as compat
|
|
24
24
|
from flwr.common import (
|
|
@@ -47,8 +47,8 @@ class DefaultWorkflow:
|
|
|
47
47
|
|
|
48
48
|
def __init__(
|
|
49
49
|
self,
|
|
50
|
-
fit_workflow:
|
|
51
|
-
evaluate_workflow:
|
|
50
|
+
fit_workflow: Workflow | None = None,
|
|
51
|
+
evaluate_workflow: Workflow | None = None,
|
|
52
52
|
) -> None:
|
|
53
53
|
if fit_workflow is None:
|
|
54
54
|
fit_workflow = default_fit_workflow
|
|
@@ -275,7 +275,7 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
|
|
|
275
275
|
|
|
276
276
|
# Aggregate training results
|
|
277
277
|
results: list[tuple[ClientProxy, FitRes]] = []
|
|
278
|
-
failures: list[
|
|
278
|
+
failures: list[tuple[ClientProxy, FitRes] | BaseException] = []
|
|
279
279
|
for msg in messages:
|
|
280
280
|
if msg.has_content():
|
|
281
281
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
@@ -357,7 +357,7 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
|
|
357
357
|
|
|
358
358
|
# Aggregate the evaluation results
|
|
359
359
|
results: list[tuple[ClientProxy, EvaluateRes]] = []
|
|
360
|
-
failures: list[
|
|
360
|
+
failures: list[tuple[ClientProxy, EvaluateRes] | BaseException] = []
|
|
361
361
|
for msg in messages:
|
|
362
362
|
if msg.has_content():
|
|
363
363
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
@@ -15,8 +15,6 @@
|
|
|
15
15
|
"""Workflow for the SecAgg protocol."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import Optional, Union
|
|
19
|
-
|
|
20
18
|
from .secaggplus_workflow import SecAggPlusWorkflow
|
|
21
19
|
|
|
22
20
|
|
|
@@ -94,13 +92,13 @@ class SecAggWorkflow(SecAggPlusWorkflow):
|
|
|
94
92
|
|
|
95
93
|
def __init__( # pylint: disable=R0913
|
|
96
94
|
self,
|
|
97
|
-
reconstruction_threshold:
|
|
95
|
+
reconstruction_threshold: int | float,
|
|
98
96
|
*,
|
|
99
97
|
max_weight: float = 1000.0,
|
|
100
98
|
clipping_range: float = 8.0,
|
|
101
99
|
quantization_range: int = 4194304,
|
|
102
100
|
modulus_range: int = 4294967296,
|
|
103
|
-
timeout:
|
|
101
|
+
timeout: float | None = None,
|
|
104
102
|
) -> None:
|
|
105
103
|
super().__init__(
|
|
106
104
|
num_shares=1.0,
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import random
|
|
19
19
|
from dataclasses import dataclass, field
|
|
20
20
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import cast
|
|
22
22
|
|
|
23
23
|
import flwr.common.recorddict_compat as compat
|
|
24
24
|
from flwr.common import (
|
|
@@ -169,14 +169,14 @@ class SecAggPlusWorkflow:
|
|
|
169
169
|
|
|
170
170
|
def __init__( # pylint: disable=R0913
|
|
171
171
|
self,
|
|
172
|
-
num_shares:
|
|
173
|
-
reconstruction_threshold:
|
|
172
|
+
num_shares: int | float,
|
|
173
|
+
reconstruction_threshold: int | float,
|
|
174
174
|
*,
|
|
175
175
|
max_weight: float = 1000.0,
|
|
176
176
|
clipping_range: float = 8.0,
|
|
177
177
|
quantization_range: int = 4194304,
|
|
178
178
|
modulus_range: int = 4294967296,
|
|
179
|
-
timeout:
|
|
179
|
+
timeout: float | None = None,
|
|
180
180
|
) -> None:
|
|
181
181
|
self.num_shares = num_shares
|
|
182
182
|
self.reconstruction_threshold = reconstruction_threshold
|
|
@@ -211,7 +211,7 @@ class SecAggPlusWorkflow:
|
|
|
211
211
|
|
|
212
212
|
def _check_init_params(self) -> None: # pylint: disable=R0912
|
|
213
213
|
# Check `num_shares`
|
|
214
|
-
if not isinstance(self.num_shares, (int
|
|
214
|
+
if not isinstance(self.num_shares, (int | float)):
|
|
215
215
|
raise TypeError("`num_shares` must be of type int or float.")
|
|
216
216
|
if isinstance(self.num_shares, int):
|
|
217
217
|
if self.num_shares == 1:
|
|
@@ -229,7 +229,7 @@ class SecAggPlusWorkflow:
|
|
|
229
229
|
raise ValueError("`num_shares` as a float must be greater than 0.")
|
|
230
230
|
|
|
231
231
|
# Check `reconstruction_threshold`
|
|
232
|
-
if not isinstance(self.reconstruction_threshold, (int
|
|
232
|
+
if not isinstance(self.reconstruction_threshold, (int | float)):
|
|
233
233
|
raise TypeError("`reconstruction_threshold` must be of type int or float.")
|
|
234
234
|
if isinstance(self.reconstruction_threshold, int):
|
|
235
235
|
if self.reconstruction_threshold == 1:
|
|
@@ -467,7 +467,7 @@ class SecAggPlusWorkflow:
|
|
|
467
467
|
dsts += dst_lst
|
|
468
468
|
ciphertexts += ctxt_lst
|
|
469
469
|
|
|
470
|
-
for src, dst, ciphertext in zip(srcs, dsts, ciphertexts):
|
|
470
|
+
for src, dst, ciphertext in zip(srcs, dsts, ciphertexts, strict=True):
|
|
471
471
|
if dst in fwd_ciphertexts:
|
|
472
472
|
fwd_ciphertexts[dst].append(ciphertext)
|
|
473
473
|
fwd_srcs[dst].append(src)
|
|
@@ -604,7 +604,7 @@ class SecAggPlusWorkflow:
|
|
|
604
604
|
res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
|
|
605
605
|
nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
|
|
606
606
|
shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
|
|
607
|
-
for owner_nid, share in zip(nids, shares):
|
|
607
|
+
for owner_nid, share in zip(nids, shares, strict=True):
|
|
608
608
|
collected_shares_dict[owner_nid].append(share)
|
|
609
609
|
|
|
610
610
|
# Remove masks for every active client after collect_masked_vectors stage
|
|
@@ -18,10 +18,9 @@ Paper: arxiv.org/abs/1802.07927
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from collections import
|
|
22
|
-
from collections.abc import Iterable
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
23
22
|
from logging import INFO, WARN
|
|
24
|
-
from typing import
|
|
23
|
+
from typing import cast
|
|
25
24
|
|
|
26
25
|
import numpy as np
|
|
27
26
|
|
|
@@ -104,15 +103,15 @@ class Bulyan(FedAvg):
|
|
|
104
103
|
weighted_by_key: str = "num-examples",
|
|
105
104
|
arrayrecord_key: str = "arrays",
|
|
106
105
|
configrecord_key: str = "config",
|
|
107
|
-
train_metrics_aggr_fn:
|
|
108
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
109
|
-
|
|
110
|
-
evaluate_metrics_aggr_fn:
|
|
111
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
112
|
-
|
|
113
|
-
selection_rule:
|
|
114
|
-
Callable[[list[RecordDict], int, int], list[RecordDict]]
|
|
115
|
-
|
|
106
|
+
train_metrics_aggr_fn: (
|
|
107
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
108
|
+
) = None,
|
|
109
|
+
evaluate_metrics_aggr_fn: (
|
|
110
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
111
|
+
) = None,
|
|
112
|
+
selection_rule: (
|
|
113
|
+
Callable[[list[RecordDict], int, int], list[RecordDict]] | None
|
|
114
|
+
) = None,
|
|
116
115
|
) -> None:
|
|
117
116
|
super().__init__(
|
|
118
117
|
fraction_train=fraction_train,
|
|
@@ -140,7 +139,7 @@ class Bulyan(FedAvg):
|
|
|
140
139
|
self,
|
|
141
140
|
server_round: int,
|
|
142
141
|
replies: Iterable[Message],
|
|
143
|
-
) -> tuple[
|
|
142
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
144
143
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
145
144
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
146
145
|
|
|
@@ -175,7 +174,9 @@ class Bulyan(FedAvg):
|
|
|
175
174
|
]
|
|
176
175
|
|
|
177
176
|
# Compute median
|
|
178
|
-
median_ndarrays = [
|
|
177
|
+
median_ndarrays = [
|
|
178
|
+
np.median(arr, axis=0) for arr in zip(*selected_ndarrays, strict=True)
|
|
179
|
+
]
|
|
179
180
|
|
|
180
181
|
# Aggregate the beta closest weights element-wise
|
|
181
182
|
aggregated_ndarrays = aggregate_n_closest_weights(
|
|
@@ -184,7 +185,7 @@ class Bulyan(FedAvg):
|
|
|
184
185
|
|
|
185
186
|
# Convert to ArrayRecord
|
|
186
187
|
arrays = ArrayRecord(
|
|
187
|
-
|
|
188
|
+
dict(zip(array_keys, map(Array, aggregated_ndarrays), strict=True))
|
|
188
189
|
)
|
|
189
190
|
|
|
190
191
|
# Aggregate MetricRecords
|
|
@@ -19,10 +19,8 @@ Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
|
|
|
19
19
|
|
|
20
20
|
import math
|
|
21
21
|
from abc import ABC
|
|
22
|
-
from collections import OrderedDict
|
|
23
22
|
from collections.abc import Iterable
|
|
24
23
|
from logging import INFO
|
|
25
|
-
from typing import Optional
|
|
26
24
|
|
|
27
25
|
import numpy as np
|
|
28
26
|
|
|
@@ -53,7 +51,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
|
|
|
53
51
|
initial_clipping_norm: float = 0.1,
|
|
54
52
|
target_clipped_quantile: float = 0.5,
|
|
55
53
|
clip_norm_lr: float = 0.2,
|
|
56
|
-
clipped_count_stddev:
|
|
54
|
+
clipped_count_stddev: float | None = None,
|
|
57
55
|
) -> None:
|
|
58
56
|
super().__init__()
|
|
59
57
|
|
|
@@ -96,7 +94,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
|
|
|
96
94
|
add_gaussian_noise_inplace(nds, stdv)
|
|
97
95
|
log(INFO, "aggregate_fit: central DP noise with %.4f stdev added", stdv)
|
|
98
96
|
return ArrayRecord(
|
|
99
|
-
|
|
97
|
+
{k: Array(v) for k, v in zip(aggregated.keys(), nds, strict=True)}
|
|
100
98
|
)
|
|
101
99
|
|
|
102
100
|
def _noisy_fraction(self, count: int, total: int) -> float:
|
|
@@ -115,7 +113,7 @@ class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
|
|
|
115
113
|
|
|
116
114
|
def aggregate_evaluate(
|
|
117
115
|
self, server_round: int, replies: Iterable[Message]
|
|
118
|
-
) ->
|
|
116
|
+
) -> MetricRecord | None:
|
|
119
117
|
"""Aggregate MetricRecords in the received Messages."""
|
|
120
118
|
return self.strategy.aggregate_evaluate(server_round, replies)
|
|
121
119
|
|
|
@@ -136,7 +134,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
|
|
|
136
134
|
initial_clipping_norm: float = 0.1,
|
|
137
135
|
target_clipped_quantile: float = 0.5,
|
|
138
136
|
clip_norm_lr: float = 0.2,
|
|
139
|
-
clipped_count_stddev:
|
|
137
|
+
clipped_count_stddev: float | None = None,
|
|
140
138
|
) -> None:
|
|
141
139
|
super().__init__(
|
|
142
140
|
strategy,
|
|
@@ -171,7 +169,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
|
|
|
171
169
|
|
|
172
170
|
def aggregate_train(
|
|
173
171
|
self, server_round: int, replies: Iterable[Message]
|
|
174
|
-
) -> tuple[
|
|
172
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
175
173
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
176
174
|
if not validate_replies(replies, self.num_sampled_clients):
|
|
177
175
|
return None, None
|
|
@@ -184,16 +182,19 @@ class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
|
|
|
184
182
|
for arr_name, record in reply.content.array_records.items():
|
|
185
183
|
reply_nd = record.to_numpy_ndarrays()
|
|
186
184
|
model_update = [
|
|
187
|
-
np.subtract(x, y)
|
|
185
|
+
np.subtract(x, y)
|
|
186
|
+
for (x, y) in zip(reply_nd, current_nd, strict=True)
|
|
188
187
|
]
|
|
189
188
|
norm_bit = adaptive_clip_inputs_inplace(
|
|
190
189
|
model_update, self.clipping_norm
|
|
191
190
|
)
|
|
192
191
|
clipped_indicator_count += int(norm_bit)
|
|
193
192
|
# reconstruct array using clipped contribution from current round
|
|
194
|
-
restored = [
|
|
193
|
+
restored = [
|
|
194
|
+
c + u for c, u in zip(current_nd, model_update, strict=True)
|
|
195
|
+
]
|
|
195
196
|
reply.content[arr_name] = ArrayRecord(
|
|
196
|
-
|
|
197
|
+
{k: Array(v) for k, v in zip(record.keys(), restored, strict=True)}
|
|
197
198
|
)
|
|
198
199
|
log(
|
|
199
200
|
INFO,
|
|
@@ -287,7 +288,7 @@ class DifferentialPrivacyClientSideAdaptiveClipping(DifferentialPrivacyAdaptiveB
|
|
|
287
288
|
|
|
288
289
|
def aggregate_train(
|
|
289
290
|
self, server_round: int, replies: Iterable[Message]
|
|
290
|
-
) -> tuple[
|
|
291
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
291
292
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
292
293
|
if not validate_replies(replies, self.num_sampled_clients):
|
|
293
294
|
return None, None
|