flwr-nightly 1.8.0.dev20240314 (py3-none-any.whl) → 1.11.0.dev20240813 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. (The "Click here for more details" link on the original page is not preserved in this text export.)
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
|
flwr/client/grpc_rere_client/connection.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,32 +15,52 @@
|
|
|
15
15
|
"""Contextmanager for a gRPC request-response channel to the Flower server."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import random
|
|
19
|
+
import threading
|
|
18
20
|
from contextlib import contextmanager
|
|
19
21
|
from copy import copy
|
|
20
22
|
from logging import DEBUG, ERROR
|
|
21
23
|
from pathlib import Path
|
|
22
|
-
from typing import Callable,
|
|
24
|
+
from typing import Callable, Iterator, Optional, Sequence, Tuple, Type, Union, cast
|
|
23
25
|
|
|
26
|
+
import grpc
|
|
27
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
28
|
+
|
|
29
|
+
from flwr.client.heartbeat import start_ping_loop
|
|
24
30
|
from flwr.client.message_handler.message_handler import validate_out_message
|
|
25
31
|
from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
|
|
26
32
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
33
|
+
from flwr.common.constant import (
|
|
34
|
+
PING_BASE_MULTIPLIER,
|
|
35
|
+
PING_CALL_TIMEOUT,
|
|
36
|
+
PING_DEFAULT_INTERVAL,
|
|
37
|
+
PING_RANDOM_RANGE,
|
|
38
|
+
)
|
|
27
39
|
from flwr.common.grpc import create_channel
|
|
28
|
-
from flwr.common.logger import log
|
|
40
|
+
from flwr.common.logger import log
|
|
29
41
|
from flwr.common.message import Message, Metadata
|
|
30
42
|
from flwr.common.retry_invoker import RetryInvoker
|
|
31
|
-
from flwr.common.serde import
|
|
43
|
+
from flwr.common.serde import (
|
|
44
|
+
message_from_taskins,
|
|
45
|
+
message_to_taskres,
|
|
46
|
+
user_config_from_proto,
|
|
47
|
+
)
|
|
48
|
+
from flwr.common.typing import Fab, Run
|
|
32
49
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
33
50
|
CreateNodeRequest,
|
|
34
51
|
DeleteNodeRequest,
|
|
52
|
+
PingRequest,
|
|
53
|
+
PingResponse,
|
|
35
54
|
PullTaskInsRequest,
|
|
36
55
|
PushTaskResRequest,
|
|
37
56
|
)
|
|
38
57
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
39
58
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
59
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
40
60
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
41
61
|
|
|
42
|
-
|
|
43
|
-
|
|
62
|
+
from .client_interceptor import AuthenticateClientInterceptor
|
|
63
|
+
from .grpc_adapter import GrpcAdapter
|
|
44
64
|
|
|
45
65
|
|
|
46
66
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
@@ -49,18 +69,24 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
|
49
69
|
|
|
50
70
|
|
|
51
71
|
@contextmanager
|
|
52
|
-
def grpc_request_response(
|
|
72
|
+
def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
53
73
|
server_address: str,
|
|
54
74
|
insecure: bool,
|
|
55
75
|
retry_invoker: RetryInvoker,
|
|
56
76
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
|
57
77
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
78
|
+
authentication_keys: Optional[
|
|
79
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
80
|
+
] = None,
|
|
81
|
+
adapter_cls: Optional[Union[Type[FleetStub], Type[GrpcAdapter]]] = None,
|
|
58
82
|
) -> Iterator[
|
|
59
83
|
Tuple[
|
|
60
84
|
Callable[[], Optional[Message]],
|
|
61
85
|
Callable[[Message], None],
|
|
86
|
+
Optional[Callable[[], Optional[int]]],
|
|
62
87
|
Optional[Callable[[], None]],
|
|
63
|
-
Optional[Callable[[],
|
|
88
|
+
Optional[Callable[[int], Run]],
|
|
89
|
+
Optional[Callable[[str], Fab]],
|
|
64
90
|
]
|
|
65
91
|
]:
|
|
66
92
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -87,6 +113,11 @@ def grpc_request_response(
|
|
|
87
113
|
Path of the root certificate. If provided, a secure
|
|
88
114
|
connection using the certificates will be established to an SSL-enabled
|
|
89
115
|
Flower server. Bytes won't work for the REST API.
|
|
116
|
+
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
|
|
117
|
+
Tuple containing the elliptic curve private key and public key for
|
|
118
|
+
authentication from the cryptography library.
|
|
119
|
+
Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
|
|
120
|
+
Used to establish an authenticated connection with the server.
|
|
90
121
|
|
|
91
122
|
Returns
|
|
92
123
|
-------
|
|
@@ -94,60 +125,101 @@ def grpc_request_response(
|
|
|
94
125
|
send : Callable
|
|
95
126
|
create_node : Optional[Callable]
|
|
96
127
|
delete_node : Optional[Callable]
|
|
128
|
+
get_run : Optional[Callable]
|
|
97
129
|
"""
|
|
98
|
-
warn_experimental_feature("`grpc-rere`")
|
|
99
|
-
|
|
100
130
|
if isinstance(root_certificates, str):
|
|
101
131
|
root_certificates = Path(root_certificates).read_bytes()
|
|
102
132
|
|
|
133
|
+
interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None
|
|
134
|
+
if authentication_keys is not None:
|
|
135
|
+
interceptors = AuthenticateClientInterceptor(
|
|
136
|
+
authentication_keys[0], authentication_keys[1]
|
|
137
|
+
)
|
|
138
|
+
|
|
103
139
|
channel = create_channel(
|
|
104
140
|
server_address=server_address,
|
|
105
141
|
insecure=insecure,
|
|
106
142
|
root_certificates=root_certificates,
|
|
107
143
|
max_message_length=max_message_length,
|
|
144
|
+
interceptors=interceptors,
|
|
108
145
|
)
|
|
109
146
|
channel.subscribe(on_channel_state_change)
|
|
110
|
-
stub = FleetStub(channel)
|
|
111
|
-
|
|
112
|
-
# Necessary state to validate messages to be sent
|
|
113
|
-
state: Dict[str, Optional[Metadata]] = {KEY_METADATA: None}
|
|
114
147
|
|
|
115
|
-
#
|
|
116
|
-
|
|
148
|
+
# Shared variables for inner functions
|
|
149
|
+
if adapter_cls is None:
|
|
150
|
+
adapter_cls = FleetStub
|
|
151
|
+
stub = adapter_cls(channel)
|
|
152
|
+
metadata: Optional[Metadata] = None
|
|
153
|
+
node: Optional[Node] = None
|
|
154
|
+
ping_thread: Optional[threading.Thread] = None
|
|
155
|
+
ping_stop_event = threading.Event()
|
|
117
156
|
|
|
118
157
|
###########################################################################
|
|
119
|
-
# receive/send functions
|
|
158
|
+
# ping/create_node/delete_node/receive/send/get_run functions
|
|
120
159
|
###########################################################################
|
|
121
160
|
|
|
122
|
-
def
|
|
161
|
+
def ping() -> None:
|
|
162
|
+
# Get Node
|
|
163
|
+
if node is None:
|
|
164
|
+
log(ERROR, "Node instance missing")
|
|
165
|
+
return
|
|
166
|
+
|
|
167
|
+
# Construct the ping request
|
|
168
|
+
req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL)
|
|
169
|
+
|
|
170
|
+
# Call FleetAPI
|
|
171
|
+
res: PingResponse = stub.Ping(req, timeout=PING_CALL_TIMEOUT)
|
|
172
|
+
|
|
173
|
+
# Check if success
|
|
174
|
+
if not res.success:
|
|
175
|
+
raise RuntimeError("Ping failed unexpectedly.")
|
|
176
|
+
|
|
177
|
+
# Wait
|
|
178
|
+
rd = random.uniform(*PING_RANDOM_RANGE)
|
|
179
|
+
next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
|
|
180
|
+
next_interval *= PING_BASE_MULTIPLIER + rd
|
|
181
|
+
if not ping_stop_event.is_set():
|
|
182
|
+
ping_stop_event.wait(next_interval)
|
|
183
|
+
|
|
184
|
+
def create_node() -> Optional[int]:
|
|
123
185
|
"""Set create_node."""
|
|
124
|
-
|
|
186
|
+
# Call FleetAPI
|
|
187
|
+
create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
|
|
125
188
|
create_node_response = retry_invoker.invoke(
|
|
126
189
|
stub.CreateNode,
|
|
127
190
|
request=create_node_request,
|
|
128
191
|
)
|
|
129
|
-
|
|
192
|
+
|
|
193
|
+
# Remember the node and the ping-loop thread
|
|
194
|
+
nonlocal node, ping_thread
|
|
195
|
+
node = cast(Node, create_node_response.node)
|
|
196
|
+
ping_thread = start_ping_loop(ping, ping_stop_event)
|
|
197
|
+
return node.node_id
|
|
130
198
|
|
|
131
199
|
def delete_node() -> None:
|
|
132
200
|
"""Set delete_node."""
|
|
133
201
|
# Get Node
|
|
134
|
-
|
|
202
|
+
nonlocal node
|
|
203
|
+
if node is None:
|
|
135
204
|
log(ERROR, "Node instance missing")
|
|
136
205
|
return
|
|
137
|
-
node: Node = cast(Node, node_store[KEY_NODE])
|
|
138
206
|
|
|
207
|
+
# Stop the ping-loop thread
|
|
208
|
+
ping_stop_event.set()
|
|
209
|
+
|
|
210
|
+
# Call FleetAPI
|
|
139
211
|
delete_node_request = DeleteNodeRequest(node=node)
|
|
140
212
|
retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)
|
|
141
213
|
|
|
142
|
-
|
|
214
|
+
# Cleanup
|
|
215
|
+
node = None
|
|
143
216
|
|
|
144
217
|
def receive() -> Optional[Message]:
|
|
145
218
|
"""Receive next task from server."""
|
|
146
219
|
# Get Node
|
|
147
|
-
if
|
|
220
|
+
if node is None:
|
|
148
221
|
log(ERROR, "Node instance missing")
|
|
149
222
|
return None
|
|
150
|
-
node: Node = cast(Node, node_store[KEY_NODE])
|
|
151
223
|
|
|
152
224
|
# Request instructions (task) from server
|
|
153
225
|
request = PullTaskInsRequest(node=node)
|
|
@@ -167,7 +239,8 @@ def grpc_request_response(
|
|
|
167
239
|
in_message = message_from_taskins(task_ins) if task_ins else None
|
|
168
240
|
|
|
169
241
|
# Remember `metadata` of the in message
|
|
170
|
-
|
|
242
|
+
nonlocal metadata
|
|
243
|
+
metadata = copy(in_message.metadata) if in_message else None
|
|
171
244
|
|
|
172
245
|
# Return the message if available
|
|
173
246
|
return in_message
|
|
@@ -175,18 +248,18 @@ def grpc_request_response(
|
|
|
175
248
|
def send(message: Message) -> None:
|
|
176
249
|
"""Send task result back to server."""
|
|
177
250
|
# Get Node
|
|
178
|
-
if
|
|
251
|
+
if node is None:
|
|
179
252
|
log(ERROR, "Node instance missing")
|
|
180
253
|
return
|
|
181
254
|
|
|
182
|
-
# Get incoming message
|
|
183
|
-
|
|
184
|
-
if
|
|
255
|
+
# Get the metadata of the incoming message
|
|
256
|
+
nonlocal metadata
|
|
257
|
+
if metadata is None:
|
|
185
258
|
log(ERROR, "No current message")
|
|
186
259
|
return
|
|
187
260
|
|
|
188
261
|
# Validate out message
|
|
189
|
-
if not validate_out_message(message,
|
|
262
|
+
if not validate_out_message(message, metadata):
|
|
190
263
|
log(ERROR, "Invalid out message")
|
|
191
264
|
return
|
|
192
265
|
|
|
@@ -197,10 +270,31 @@ def grpc_request_response(
|
|
|
197
270
|
request = PushTaskResRequest(task_res_list=[task_res])
|
|
198
271
|
_ = retry_invoker.invoke(stub.PushTaskRes, request)
|
|
199
272
|
|
|
200
|
-
|
|
273
|
+
# Cleanup
|
|
274
|
+
metadata = None
|
|
275
|
+
|
|
276
|
+
def get_run(run_id: int) -> Run:
|
|
277
|
+
# Call FleetAPI
|
|
278
|
+
get_run_request = GetRunRequest(run_id=run_id)
|
|
279
|
+
get_run_response: GetRunResponse = retry_invoker.invoke(
|
|
280
|
+
stub.GetRun,
|
|
281
|
+
request=get_run_request,
|
|
282
|
+
)
|
|
283
|
+
|
|
284
|
+
# Return fab_id and fab_version
|
|
285
|
+
return Run(
|
|
286
|
+
run_id,
|
|
287
|
+
get_run_response.run.fab_id,
|
|
288
|
+
get_run_response.run.fab_version,
|
|
289
|
+
user_config_from_proto(get_run_response.run.override_config),
|
|
290
|
+
)
|
|
291
|
+
|
|
292
|
+
def get_fab(fab_hash: str) -> Fab:
|
|
293
|
+
# Call FleetAPI
|
|
294
|
+
raise NotImplementedError
|
|
201
295
|
|
|
202
296
|
try:
|
|
203
297
|
# Yield methods
|
|
204
|
-
yield (receive, send, create_node, delete_node)
|
|
298
|
+
yield (receive, send, create_node, delete_node, get_run, get_fab)
|
|
205
299
|
except Exception as exc: # pylint: disable=broad-except
|
|
206
300
|
log(ERROR, exc)
|
|
flwr/client/grpc_rere_client/grpc_adapter.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""GrpcAdapter implementation."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import sys
|
|
19
|
+
from logging import DEBUG
|
|
20
|
+
from typing import Any, Type, TypeVar, cast
|
|
21
|
+
|
|
22
|
+
import grpc
|
|
23
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
24
|
+
|
|
25
|
+
from flwr.common import log
|
|
26
|
+
from flwr.common.constant import (
|
|
27
|
+
GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY,
|
|
28
|
+
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY,
|
|
29
|
+
)
|
|
30
|
+
from flwr.common.version import package_version
|
|
31
|
+
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
32
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
33
|
+
CreateNodeRequest,
|
|
34
|
+
CreateNodeResponse,
|
|
35
|
+
DeleteNodeRequest,
|
|
36
|
+
DeleteNodeResponse,
|
|
37
|
+
PingRequest,
|
|
38
|
+
PingResponse,
|
|
39
|
+
PullTaskInsRequest,
|
|
40
|
+
PullTaskInsResponse,
|
|
41
|
+
PushTaskResRequest,
|
|
42
|
+
PushTaskResResponse,
|
|
43
|
+
)
|
|
44
|
+
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
45
|
+
from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
|
|
46
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
47
|
+
|
|
48
|
+
T = TypeVar("T", bound=GrpcMessage)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class GrpcAdapter:
|
|
52
|
+
"""Adapter class to send and receive gRPC messages via the ``GrpcAdapterStub``.
|
|
53
|
+
|
|
54
|
+
This class utilizes the ``GrpcAdapterStub`` to send and receive gRPC messages
|
|
55
|
+
which are defined and used by the Fleet API, as defined in ``fleet.proto``.
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
def __init__(self, channel: grpc.Channel) -> None:
|
|
59
|
+
self.stub = GrpcAdapterStub(channel)
|
|
60
|
+
|
|
61
|
+
def _send_and_receive(
|
|
62
|
+
self, request: GrpcMessage, response_type: Type[T], **kwargs: Any
|
|
63
|
+
) -> T:
|
|
64
|
+
# Serialize request
|
|
65
|
+
container_req = MessageContainer(
|
|
66
|
+
metadata={GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version},
|
|
67
|
+
grpc_message_name=request.__class__.__qualname__,
|
|
68
|
+
grpc_message_content=request.SerializeToString(),
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# Send via the stub
|
|
72
|
+
container_res = cast(
|
|
73
|
+
MessageContainer, self.stub.SendReceive(container_req, **kwargs)
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
# Handle control message
|
|
77
|
+
should_exit = (
|
|
78
|
+
container_res.metadata.get(GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY, "false")
|
|
79
|
+
== "true"
|
|
80
|
+
)
|
|
81
|
+
if should_exit:
|
|
82
|
+
log(
|
|
83
|
+
DEBUG,
|
|
84
|
+
'Received shutdown signal: exit flag is set to ``"true"``. Exiting...',
|
|
85
|
+
)
|
|
86
|
+
sys.exit(0)
|
|
87
|
+
|
|
88
|
+
# Check the grpc_message_name of the response
|
|
89
|
+
if container_res.grpc_message_name != response_type.__qualname__:
|
|
90
|
+
raise ValueError(
|
|
91
|
+
f"Invalid grpc_message_name. Expected {response_type.__qualname__}"
|
|
92
|
+
f", but got {container_res.grpc_message_name}."
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
# Deserialize response
|
|
96
|
+
response = response_type()
|
|
97
|
+
response.ParseFromString(container_res.grpc_message_content)
|
|
98
|
+
return response
|
|
99
|
+
|
|
100
|
+
def CreateNode( # pylint: disable=C0103
|
|
101
|
+
self, request: CreateNodeRequest, **kwargs: Any
|
|
102
|
+
) -> CreateNodeResponse:
|
|
103
|
+
"""."""
|
|
104
|
+
return self._send_and_receive(request, CreateNodeResponse, **kwargs)
|
|
105
|
+
|
|
106
|
+
def DeleteNode( # pylint: disable=C0103
|
|
107
|
+
self, request: DeleteNodeRequest, **kwargs: Any
|
|
108
|
+
) -> DeleteNodeResponse:
|
|
109
|
+
"""."""
|
|
110
|
+
return self._send_and_receive(request, DeleteNodeResponse, **kwargs)
|
|
111
|
+
|
|
112
|
+
def Ping( # pylint: disable=C0103
|
|
113
|
+
self, request: PingRequest, **kwargs: Any
|
|
114
|
+
) -> PingResponse:
|
|
115
|
+
"""."""
|
|
116
|
+
return self._send_and_receive(request, PingResponse, **kwargs)
|
|
117
|
+
|
|
118
|
+
def PullTaskIns( # pylint: disable=C0103
|
|
119
|
+
self, request: PullTaskInsRequest, **kwargs: Any
|
|
120
|
+
) -> PullTaskInsResponse:
|
|
121
|
+
"""."""
|
|
122
|
+
return self._send_and_receive(request, PullTaskInsResponse, **kwargs)
|
|
123
|
+
|
|
124
|
+
def PushTaskRes( # pylint: disable=C0103
|
|
125
|
+
self, request: PushTaskResRequest, **kwargs: Any
|
|
126
|
+
) -> PushTaskResResponse:
|
|
127
|
+
"""."""
|
|
128
|
+
return self._send_and_receive(request, PushTaskResResponse, **kwargs)
|
|
129
|
+
|
|
130
|
+
def GetRun( # pylint: disable=C0103
|
|
131
|
+
self, request: GetRunRequest, **kwargs: Any
|
|
132
|
+
) -> GetRunResponse:
|
|
133
|
+
"""."""
|
|
134
|
+
return self._send_and_receive(request, GetRunResponse, **kwargs)
|
|
135
|
+
|
|
136
|
+
def GetFab( # pylint: disable=C0103
|
|
137
|
+
self, request: GetFabRequest, **kwargs: Any
|
|
138
|
+
) -> GetFabResponse:
|
|
139
|
+
"""."""
|
|
140
|
+
return self._send_and_receive(request, GetFabResponse, **kwargs)
|
flwr/client/heartbeat.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Heartbeat utility functions."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import threading
|
|
19
|
+
from typing import Callable
|
|
20
|
+
|
|
21
|
+
import grpc
|
|
22
|
+
|
|
23
|
+
from flwr.common.constant import PING_CALL_TIMEOUT
|
|
24
|
+
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _ping_loop(ping_fn: Callable[[], None], stop_event: threading.Event) -> None:
|
|
28
|
+
def wait_fn(wait_time: float) -> None:
|
|
29
|
+
if not stop_event.is_set():
|
|
30
|
+
stop_event.wait(wait_time)
|
|
31
|
+
|
|
32
|
+
def on_backoff(state: RetryState) -> None:
|
|
33
|
+
err = state.exception
|
|
34
|
+
if not isinstance(err, grpc.RpcError):
|
|
35
|
+
return
|
|
36
|
+
status_code = err.code()
|
|
37
|
+
# If ping call timeout is triggered
|
|
38
|
+
if status_code == grpc.StatusCode.DEADLINE_EXCEEDED:
|
|
39
|
+
# Avoid long wait time.
|
|
40
|
+
if state.actual_wait is None:
|
|
41
|
+
return
|
|
42
|
+
state.actual_wait = max(state.actual_wait - PING_CALL_TIMEOUT, 0.0)
|
|
43
|
+
|
|
44
|
+
def wrapped_ping() -> None:
|
|
45
|
+
if not stop_event.is_set():
|
|
46
|
+
ping_fn()
|
|
47
|
+
|
|
48
|
+
retrier = RetryInvoker(
|
|
49
|
+
exponential,
|
|
50
|
+
grpc.RpcError,
|
|
51
|
+
max_tries=None,
|
|
52
|
+
max_time=None,
|
|
53
|
+
on_backoff=on_backoff,
|
|
54
|
+
wait_function=wait_fn,
|
|
55
|
+
)
|
|
56
|
+
while not stop_event.is_set():
|
|
57
|
+
retrier.invoke(wrapped_ping)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def start_ping_loop(
    ping_fn: Callable[[], None], stop_event: threading.Event
) -> threading.Thread:
    """Run the ping loop on a background daemon thread and return that thread.

    Parameters
    ----------
    ping_fn : Callable[[], None]
        Zero-argument callable handed to the ping loop.
    stop_event : threading.Event
        Event handed to the ping loop; setting it terminates the loop.

    Returns
    -------
    threading.Thread
        The daemon thread, already started, that executes `_ping_loop`.
    """
    loop_args = (ping_fn, stop_event)
    ping_thread = threading.Thread(target=_ping_loop, args=loop_args, daemon=True)
    ping_thread.start()
    return ping_thread
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2022 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Client-side message handler."""
|
|
16
16
|
|
|
17
|
-
|
|
18
17
|
from logging import WARN
|
|
19
18
|
from typing import Optional, Tuple, cast
|
|
20
19
|
|
|
@@ -25,7 +24,7 @@ from flwr.client.client import (
|
|
|
25
24
|
maybe_call_get_properties,
|
|
26
25
|
)
|
|
27
26
|
from flwr.client.numpy_client import NumPyClient
|
|
28
|
-
from flwr.client.typing import
|
|
27
|
+
from flwr.client.typing import ClientFnExt
|
|
29
28
|
from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
|
|
30
29
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
31
30
|
from flwr.common.recordset_compat import (
|
|
@@ -81,7 +80,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
|
|
|
81
80
|
reason = cast(int, disconnect_msg.disconnect_res.reason)
|
|
82
81
|
recordset = RecordSet()
|
|
83
82
|
recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
|
|
84
|
-
out_message = message.create_reply(recordset
|
|
83
|
+
out_message = message.create_reply(recordset)
|
|
85
84
|
# Return TaskRes and sleep duration
|
|
86
85
|
return out_message, sleep_duration
|
|
87
86
|
|
|
@@ -90,10 +89,10 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
|
|
|
90
89
|
|
|
91
90
|
|
|
92
91
|
def handle_legacy_message_from_msgtype(
|
|
93
|
-
client_fn:
|
|
92
|
+
client_fn: ClientFnExt, message: Message, context: Context
|
|
94
93
|
) -> Message:
|
|
95
94
|
"""Handle legacy message in the inner most mod."""
|
|
96
|
-
client = client_fn(
|
|
95
|
+
client = client_fn(context)
|
|
97
96
|
|
|
98
97
|
# Check if NumPyClient is returend
|
|
99
98
|
if isinstance(client, NumPyClient):
|
|
@@ -143,7 +142,7 @@ def handle_legacy_message_from_msgtype(
|
|
|
143
142
|
raise ValueError(f"Invalid message type: {message_type}")
|
|
144
143
|
|
|
145
144
|
# Return Message
|
|
146
|
-
return message.create_reply(out_recordset
|
|
145
|
+
return message.create_reply(out_recordset)
|
|
147
146
|
|
|
148
147
|
|
|
149
148
|
def _reconnect(
|
|
@@ -172,6 +171,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
|
|
|
172
171
|
and out_meta.reply_to_message == in_meta.message_id
|
|
173
172
|
and out_meta.group_id == in_meta.group_id
|
|
174
173
|
and out_meta.message_type == in_meta.message_type
|
|
174
|
+
and out_meta.created_at > in_meta.created_at
|
|
175
175
|
):
|
|
176
176
|
return True
|
|
177
177
|
return False
|
flwr/client/mod/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Mods."""
|
|
15
|
+
"""Flower Built-in Mods."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
|
|
@@ -22,12 +22,12 @@ from .secure_aggregation import secagg_mod, secaggplus_mod
|
|
|
22
22
|
from .utils import make_ffn
|
|
23
23
|
|
|
24
24
|
__all__ = [
|
|
25
|
+
"LocalDpMod",
|
|
25
26
|
"adaptiveclipping_mod",
|
|
26
27
|
"fixedclipping_mod",
|
|
27
|
-
"LocalDpMod",
|
|
28
28
|
"make_ffn",
|
|
29
|
-
"secagg_mod",
|
|
30
|
-
"secaggplus_mod",
|
|
31
29
|
"message_size_mod",
|
|
32
30
|
"parameters_size_mod",
|
|
31
|
+
"secagg_mod",
|
|
32
|
+
"secaggplus_mod",
|
|
33
33
|
]
|
|
@@ -82,7 +82,9 @@ def fixedclipping_mod(
|
|
|
82
82
|
clipping_norm,
|
|
83
83
|
)
|
|
84
84
|
|
|
85
|
-
log(
|
|
85
|
+
log(
|
|
86
|
+
INFO, "fixedclipping_mod: parameters are clipped by value: %.4f.", clipping_norm
|
|
87
|
+
)
|
|
86
88
|
|
|
87
89
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
88
90
|
out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
|
|
@@ -146,7 +148,7 @@ def adaptiveclipping_mod(
|
|
|
146
148
|
)
|
|
147
149
|
log(
|
|
148
150
|
INFO,
|
|
149
|
-
"adaptiveclipping_mod: parameters are clipped by value:
|
|
151
|
+
"adaptiveclipping_mod: parameters are clipped by value: %.4f.",
|
|
150
152
|
clipping_norm,
|
|
151
153
|
)
|
|
152
154
|
|
flwr/client/mod/comms_mods.py
CHANGED
|
@@ -29,7 +29,7 @@ def message_size_mod(
|
|
|
29
29
|
) -> Message:
|
|
30
30
|
"""Message size mod.
|
|
31
31
|
|
|
32
|
-
This mod logs the size in
|
|
32
|
+
This mod logs the size in bytes of the message being transmited.
|
|
33
33
|
"""
|
|
34
34
|
message_size_in_bytes = 0
|
|
35
35
|
|
|
@@ -42,7 +42,7 @@ def message_size_mod(
|
|
|
42
42
|
for m_record in msg.content.metrics_records.values():
|
|
43
43
|
message_size_in_bytes += m_record.count_bytes()
|
|
44
44
|
|
|
45
|
-
log(INFO, "Message size: %i
|
|
45
|
+
log(INFO, "Message size: %i bytes", message_size_in_bytes)
|
|
46
46
|
|
|
47
47
|
return call_next(msg, ctxt)
|
|
48
48
|
|
|
@@ -53,7 +53,7 @@ def parameters_size_mod(
|
|
|
53
53
|
"""Parameters size mod.
|
|
54
54
|
|
|
55
55
|
This mod logs the number of parameters transmitted in the message as well as their
|
|
56
|
-
size in
|
|
56
|
+
size in bytes.
|
|
57
57
|
"""
|
|
58
58
|
model_size_stats = {}
|
|
59
59
|
parameters_size_in_bytes = 0
|
|
@@ -74,6 +74,6 @@ def parameters_size_mod(
|
|
|
74
74
|
if model_size_stats:
|
|
75
75
|
log(INFO, model_size_stats)
|
|
76
76
|
|
|
77
|
-
log(INFO, "Total parameters
|
|
77
|
+
log(INFO, "Total parameters transmitted: %i bytes", parameters_size_in_bytes)
|
|
78
78
|
|
|
79
79
|
return call_next(msg, ctxt)
|
flwr/client/mod/localdp_mod.py
CHANGED
|
@@ -128,7 +128,9 @@ class LocalDpMod:
|
|
|
128
128
|
self.clipping_norm,
|
|
129
129
|
)
|
|
130
130
|
log(
|
|
131
|
-
INFO,
|
|
131
|
+
INFO,
|
|
132
|
+
"LocalDpMod: parameters are clipped by value: %.4f.",
|
|
133
|
+
self.clipping_norm,
|
|
132
134
|
)
|
|
133
135
|
|
|
134
136
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
@@ -137,11 +139,14 @@ class LocalDpMod:
|
|
|
137
139
|
add_localdp_gaussian_noise_to_params(
|
|
138
140
|
fit_res.parameters, self.sensitivity, self.epsilon, self.delta
|
|
139
141
|
)
|
|
142
|
+
|
|
143
|
+
noise_value_sd = (
|
|
144
|
+
self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon
|
|
145
|
+
)
|
|
140
146
|
log(
|
|
141
147
|
INFO,
|
|
142
|
-
"LocalDpMod: local DP noise with "
|
|
143
|
-
|
|
144
|
-
self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon,
|
|
148
|
+
"LocalDpMod: local DP noise with %.4f stedv added to parameters",
|
|
149
|
+
noise_value_sd,
|
|
145
150
|
)
|
|
146
151
|
|
|
147
152
|
out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
|