flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240509__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +18 -46
- flwr/cli/new/new.py +44 -18
- flwr/cli/new/templates/app/code/client.hf.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
- flwr/cli/new/templates/app/code/server.hf.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
- flwr/cli/new/templates/app/code/task.hf.py.tpl +87 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +31 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +17 -93
- flwr/client/grpc_client/connection.py +6 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +17 -2
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +5 -1
- flwr/client/supernode/__init__.py +2 -0
- flwr/client/supernode/app.py +181 -7
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +17 -5
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/server/__init__.py +0 -2
- flwr/server/app.py +118 -2
- flwr/server/compat/app.py +5 -56
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -72
- flwr/server/driver/__init__.py +3 -0
- flwr/server/driver/driver.py +12 -242
- flwr/server/driver/grpc_driver.py +315 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +18 -4
- flwr/server/server.py +2 -5
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +9 -6
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +76 -8
- flwr/server/superlink/state/sqlite_state.py +116 -11
- flwr/server/superlink/state/state.py +35 -3
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +14 -9
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/RECORD +70 -55
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/entry_points.txt +1 -1
- flwr/server/driver/abc_driver.py +0 -140
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/WHEEL +0 -0
flwr/client/app.py
CHANGED
|
@@ -14,13 +14,12 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower client app."""
|
|
16
16
|
|
|
17
|
-
import argparse
|
|
18
17
|
import sys
|
|
19
18
|
import time
|
|
20
19
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
21
|
-
from pathlib import Path
|
|
22
20
|
from typing import Callable, ContextManager, Optional, Tuple, Type, Union
|
|
23
21
|
|
|
22
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
23
|
from grpc import RpcError
|
|
25
24
|
|
|
26
25
|
from flwr.client.client import Client
|
|
@@ -36,10 +35,8 @@ from flwr.common.constant import (
|
|
|
36
35
|
TRANSPORT_TYPES,
|
|
37
36
|
ErrorCode,
|
|
38
37
|
)
|
|
39
|
-
from flwr.common.exit_handlers import register_exit_handlers
|
|
40
38
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
41
39
|
from flwr.common.message import Error
|
|
42
|
-
from flwr.common.object_ref import load_app, validate
|
|
43
40
|
from flwr.common.retry_invoker import RetryInvoker, exponential
|
|
44
41
|
|
|
45
42
|
from .grpc_client.connection import grpc_connection
|
|
@@ -47,94 +44,6 @@ from .grpc_rere_client.connection import grpc_request_response
|
|
|
47
44
|
from .message_handler.message_handler import handle_control_message
|
|
48
45
|
from .node_state import NodeState
|
|
49
46
|
from .numpy_client import NumPyClient
|
|
50
|
-
from .supernode.app import parse_args_run_client_app
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
def run_client_app() -> None:
|
|
54
|
-
"""Run Flower client app."""
|
|
55
|
-
log(INFO, "Long-running Flower client starting")
|
|
56
|
-
|
|
57
|
-
event(EventType.RUN_CLIENT_APP_ENTER)
|
|
58
|
-
|
|
59
|
-
args = _parse_args_run_client_app().parse_args()
|
|
60
|
-
|
|
61
|
-
# Obtain certificates
|
|
62
|
-
if args.insecure:
|
|
63
|
-
if args.root_certificates is not None:
|
|
64
|
-
sys.exit(
|
|
65
|
-
"Conflicting options: The '--insecure' flag disables HTTPS, "
|
|
66
|
-
"but '--root-certificates' was also specified. Please remove "
|
|
67
|
-
"the '--root-certificates' option when running in insecure mode, "
|
|
68
|
-
"or omit '--insecure' to use HTTPS."
|
|
69
|
-
)
|
|
70
|
-
log(
|
|
71
|
-
WARN,
|
|
72
|
-
"Option `--insecure` was set. "
|
|
73
|
-
"Starting insecure HTTP client connected to %s.",
|
|
74
|
-
args.server,
|
|
75
|
-
)
|
|
76
|
-
root_certificates = None
|
|
77
|
-
else:
|
|
78
|
-
# Load the certificates if provided, or load the system certificates
|
|
79
|
-
cert_path = args.root_certificates
|
|
80
|
-
if cert_path is None:
|
|
81
|
-
root_certificates = None
|
|
82
|
-
else:
|
|
83
|
-
root_certificates = Path(cert_path).read_bytes()
|
|
84
|
-
log(
|
|
85
|
-
DEBUG,
|
|
86
|
-
"Starting secure HTTPS client connected to %s "
|
|
87
|
-
"with the following certificates: %s.",
|
|
88
|
-
args.server,
|
|
89
|
-
cert_path,
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
log(
|
|
93
|
-
DEBUG,
|
|
94
|
-
"Flower will load ClientApp `%s`",
|
|
95
|
-
getattr(args, "client-app"),
|
|
96
|
-
)
|
|
97
|
-
|
|
98
|
-
client_app_dir = args.dir
|
|
99
|
-
if client_app_dir is not None:
|
|
100
|
-
sys.path.insert(0, client_app_dir)
|
|
101
|
-
|
|
102
|
-
app_ref: str = getattr(args, "client-app")
|
|
103
|
-
valid, error_msg = validate(app_ref)
|
|
104
|
-
if not valid and error_msg:
|
|
105
|
-
raise LoadClientAppError(error_msg) from None
|
|
106
|
-
|
|
107
|
-
def _load() -> ClientApp:
|
|
108
|
-
client_app = load_app(app_ref, LoadClientAppError)
|
|
109
|
-
|
|
110
|
-
if not isinstance(client_app, ClientApp):
|
|
111
|
-
raise LoadClientAppError(
|
|
112
|
-
f"Attribute {app_ref} is not of type {ClientApp}",
|
|
113
|
-
) from None
|
|
114
|
-
|
|
115
|
-
return client_app
|
|
116
|
-
|
|
117
|
-
_start_client_internal(
|
|
118
|
-
server_address=args.server,
|
|
119
|
-
load_client_app_fn=_load,
|
|
120
|
-
transport="rest" if args.rest else "grpc-rere",
|
|
121
|
-
root_certificates=root_certificates,
|
|
122
|
-
insecure=args.insecure,
|
|
123
|
-
max_retries=args.max_retries,
|
|
124
|
-
max_wait_time=args.max_wait_time,
|
|
125
|
-
)
|
|
126
|
-
register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
def _parse_args_run_client_app() -> argparse.ArgumentParser:
|
|
130
|
-
"""Parse flower-client-app command line arguments."""
|
|
131
|
-
parser = argparse.ArgumentParser(
|
|
132
|
-
description="Start a Flower client app",
|
|
133
|
-
)
|
|
134
|
-
|
|
135
|
-
parse_args_run_client_app(parser=parser)
|
|
136
|
-
|
|
137
|
-
return parser
|
|
138
47
|
|
|
139
48
|
|
|
140
49
|
def _check_actionable_client(
|
|
@@ -165,6 +74,9 @@ def start_client(
|
|
|
165
74
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
166
75
|
insecure: Optional[bool] = None,
|
|
167
76
|
transport: Optional[str] = None,
|
|
77
|
+
authentication_keys: Optional[
|
|
78
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
79
|
+
] = None,
|
|
168
80
|
max_retries: Optional[int] = None,
|
|
169
81
|
max_wait_time: Optional[float] = None,
|
|
170
82
|
) -> None:
|
|
@@ -249,6 +161,7 @@ def start_client(
|
|
|
249
161
|
root_certificates=root_certificates,
|
|
250
162
|
insecure=insecure,
|
|
251
163
|
transport=transport,
|
|
164
|
+
authentication_keys=authentication_keys,
|
|
252
165
|
max_retries=max_retries,
|
|
253
166
|
max_wait_time=max_wait_time,
|
|
254
167
|
)
|
|
@@ -269,6 +182,9 @@ def _start_client_internal(
|
|
|
269
182
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
270
183
|
insecure: Optional[bool] = None,
|
|
271
184
|
transport: Optional[str] = None,
|
|
185
|
+
authentication_keys: Optional[
|
|
186
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
187
|
+
] = None,
|
|
272
188
|
max_retries: Optional[int] = None,
|
|
273
189
|
max_wait_time: Optional[float] = None,
|
|
274
190
|
) -> None:
|
|
@@ -393,6 +309,7 @@ def _start_client_internal(
|
|
|
393
309
|
retry_invoker,
|
|
394
310
|
grpc_max_message_length,
|
|
395
311
|
root_certificates,
|
|
312
|
+
authentication_keys,
|
|
396
313
|
) as conn:
|
|
397
314
|
# pylint: disable-next=W0612
|
|
398
315
|
receive, send, create_node, delete_node, get_run = conn
|
|
@@ -606,7 +523,14 @@ def start_numpy_client(
|
|
|
606
523
|
|
|
607
524
|
def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
608
525
|
Callable[
|
|
609
|
-
[
|
|
526
|
+
[
|
|
527
|
+
str,
|
|
528
|
+
bool,
|
|
529
|
+
RetryInvoker,
|
|
530
|
+
int,
|
|
531
|
+
Union[bytes, str, None],
|
|
532
|
+
Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]],
|
|
533
|
+
],
|
|
610
534
|
ContextManager[
|
|
611
535
|
Tuple[
|
|
612
536
|
Callable[[], Optional[Message]],
|
|
flwr/client/grpc_client/connection.py
CHANGED
|
@@ -22,6 +22,8 @@ from pathlib import Path
|
|
|
22
22
|
from queue import Queue
|
|
23
23
|
from typing import Callable, Iterator, Optional, Tuple, Union, cast
|
|
24
24
|
|
|
25
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
26
|
+
|
|
25
27
|
from flwr.common import (
|
|
26
28
|
DEFAULT_TTL,
|
|
27
29
|
GRPC_MAX_MESSAGE_LENGTH,
|
|
@@ -56,12 +58,15 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
|
56
58
|
|
|
57
59
|
|
|
58
60
|
@contextmanager
|
|
59
|
-
def grpc_connection( # pylint: disable=R0915
|
|
61
|
+
def grpc_connection( # pylint: disable=R0913, R0915
|
|
60
62
|
server_address: str,
|
|
61
63
|
insecure: bool,
|
|
62
64
|
retry_invoker: RetryInvoker, # pylint: disable=unused-argument
|
|
63
65
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
64
66
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
67
|
+
authentication_keys: Optional[ # pylint: disable=unused-argument
|
|
68
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
69
|
+
] = None,
|
|
65
70
|
) -> Iterator[
|
|
66
71
|
Tuple[
|
|
67
72
|
Callable[[], Optional[Message]],
|
|
flwr/client/grpc_rere_client/client_interceptor.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower client interceptor."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import base64
|
|
19
|
+
import collections
|
|
20
|
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
21
|
+
|
|
22
|
+
import grpc
|
|
23
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
|
+
|
|
25
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
26
|
+
bytes_to_public_key,
|
|
27
|
+
compute_hmac,
|
|
28
|
+
generate_shared_key,
|
|
29
|
+
public_key_to_bytes,
|
|
30
|
+
)
|
|
31
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
32
|
+
CreateNodeRequest,
|
|
33
|
+
DeleteNodeRequest,
|
|
34
|
+
GetRunRequest,
|
|
35
|
+
PingRequest,
|
|
36
|
+
PullTaskInsRequest,
|
|
37
|
+
PushTaskResRequest,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
_PUBLIC_KEY_HEADER = "public-key"
|
|
41
|
+
_AUTH_TOKEN_HEADER = "auth-token"
|
|
42
|
+
|
|
43
|
+
Request = Union[
|
|
44
|
+
CreateNodeRequest,
|
|
45
|
+
DeleteNodeRequest,
|
|
46
|
+
PullTaskInsRequest,
|
|
47
|
+
PushTaskResRequest,
|
|
48
|
+
GetRunRequest,
|
|
49
|
+
PingRequest,
|
|
50
|
+
]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _get_value_from_tuples(
|
|
54
|
+
key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
|
|
55
|
+
) -> bytes:
|
|
56
|
+
value = next((value for key, value in tuples if key == key_string), "")
|
|
57
|
+
if isinstance(value, str):
|
|
58
|
+
return value.encode()
|
|
59
|
+
|
|
60
|
+
return value
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class _ClientCallDetails(
|
|
64
|
+
collections.namedtuple(
|
|
65
|
+
"_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
|
|
66
|
+
),
|
|
67
|
+
grpc.ClientCallDetails, # type: ignore
|
|
68
|
+
):
|
|
69
|
+
"""Details for each client call.
|
|
70
|
+
|
|
71
|
+
The class will be passed on as the first argument in continuation function.
|
|
72
|
+
In our case, `AuthenticateClientInterceptor` adds new metadata to the construct.
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
|
|
77
|
+
"""Client interceptor for client authentication."""
|
|
78
|
+
|
|
79
|
+
def __init__(
|
|
80
|
+
self,
|
|
81
|
+
private_key: ec.EllipticCurvePrivateKey,
|
|
82
|
+
public_key: ec.EllipticCurvePublicKey,
|
|
83
|
+
):
|
|
84
|
+
self.private_key = private_key
|
|
85
|
+
self.public_key = public_key
|
|
86
|
+
self.shared_secret: Optional[bytes] = None
|
|
87
|
+
self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
|
|
88
|
+
self.encoded_public_key = base64.urlsafe_b64encode(
|
|
89
|
+
public_key_to_bytes(self.public_key)
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
def intercept_unary_unary(
|
|
93
|
+
self,
|
|
94
|
+
continuation: Callable[[Any, Any], Any],
|
|
95
|
+
client_call_details: grpc.ClientCallDetails,
|
|
96
|
+
request: Request,
|
|
97
|
+
) -> grpc.Call:
|
|
98
|
+
"""Flower client interceptor.
|
|
99
|
+
|
|
100
|
+
Intercept unary call from client and add necessary authentication header in the
|
|
101
|
+
RPC metadata.
|
|
102
|
+
"""
|
|
103
|
+
metadata = []
|
|
104
|
+
postprocess = False
|
|
105
|
+
if client_call_details.metadata is not None:
|
|
106
|
+
metadata = list(client_call_details.metadata)
|
|
107
|
+
|
|
108
|
+
# Always add the public key header
|
|
109
|
+
metadata.append(
|
|
110
|
+
(
|
|
111
|
+
_PUBLIC_KEY_HEADER,
|
|
112
|
+
self.encoded_public_key,
|
|
113
|
+
)
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
if isinstance(request, CreateNodeRequest):
|
|
117
|
+
postprocess = True
|
|
118
|
+
elif isinstance(
|
|
119
|
+
request,
|
|
120
|
+
(
|
|
121
|
+
DeleteNodeRequest,
|
|
122
|
+
PullTaskInsRequest,
|
|
123
|
+
PushTaskResRequest,
|
|
124
|
+
GetRunRequest,
|
|
125
|
+
PingRequest,
|
|
126
|
+
),
|
|
127
|
+
):
|
|
128
|
+
if self.shared_secret is None:
|
|
129
|
+
raise RuntimeError("Failure to compute hmac")
|
|
130
|
+
|
|
131
|
+
metadata.append(
|
|
132
|
+
(
|
|
133
|
+
_AUTH_TOKEN_HEADER,
|
|
134
|
+
base64.urlsafe_b64encode(
|
|
135
|
+
compute_hmac(
|
|
136
|
+
self.shared_secret, request.SerializeToString(True)
|
|
137
|
+
)
|
|
138
|
+
),
|
|
139
|
+
)
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
client_call_details = _ClientCallDetails(
|
|
143
|
+
client_call_details.method,
|
|
144
|
+
client_call_details.timeout,
|
|
145
|
+
metadata,
|
|
146
|
+
client_call_details.credentials,
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
response = continuation(client_call_details, request)
|
|
150
|
+
if postprocess:
|
|
151
|
+
server_public_key_bytes = base64.urlsafe_b64decode(
|
|
152
|
+
_get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
|
|
153
|
+
)
|
|
154
|
+
self.server_public_key = bytes_to_public_key(server_public_key_bytes)
|
|
155
|
+
self.shared_secret = generate_shared_key(
|
|
156
|
+
self.private_key, self.server_public_key
|
|
157
|
+
)
|
|
158
|
+
return response
|
|
flwr/client/grpc_rere_client/connection.py
CHANGED
|
@@ -21,7 +21,10 @@ from contextlib import contextmanager
|
|
|
21
21
|
from copy import copy
|
|
22
22
|
from logging import DEBUG, ERROR
|
|
23
23
|
from pathlib import Path
|
|
24
|
-
from typing import Callable, Iterator, Optional, Tuple, Union, cast
|
|
24
|
+
from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast
|
|
25
|
+
|
|
26
|
+
import grpc
|
|
27
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
25
28
|
|
|
26
29
|
from flwr.client.heartbeat import start_ping_loop
|
|
27
30
|
from flwr.client.message_handler.message_handler import validate_out_message
|
|
@@ -52,6 +55,8 @@ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
|
52
55
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
53
56
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
54
57
|
|
|
58
|
+
from .client_interceptor import AuthenticateClientInterceptor
|
|
59
|
+
|
|
55
60
|
|
|
56
61
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
57
62
|
"""Log channel connectivity."""
|
|
@@ -59,12 +64,15 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
|
59
64
|
|
|
60
65
|
|
|
61
66
|
@contextmanager
|
|
62
|
-
def grpc_request_response( # pylint: disable=R0914, R0915
|
|
67
|
+
def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
63
68
|
server_address: str,
|
|
64
69
|
insecure: bool,
|
|
65
70
|
retry_invoker: RetryInvoker,
|
|
66
71
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
|
67
72
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
73
|
+
authentication_keys: Optional[
|
|
74
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
75
|
+
] = None,
|
|
68
76
|
) -> Iterator[
|
|
69
77
|
Tuple[
|
|
70
78
|
Callable[[], Optional[Message]],
|
|
@@ -109,11 +117,18 @@ def grpc_request_response( # pylint: disable=R0914, R0915
|
|
|
109
117
|
if isinstance(root_certificates, str):
|
|
110
118
|
root_certificates = Path(root_certificates).read_bytes()
|
|
111
119
|
|
|
120
|
+
interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None
|
|
121
|
+
if authentication_keys is not None:
|
|
122
|
+
interceptors = AuthenticateClientInterceptor(
|
|
123
|
+
authentication_keys[0], authentication_keys[1]
|
|
124
|
+
)
|
|
125
|
+
|
|
112
126
|
channel = create_channel(
|
|
113
127
|
server_address=server_address,
|
|
114
128
|
insecure=insecure,
|
|
115
129
|
root_certificates=root_certificates,
|
|
116
130
|
max_message_length=max_message_length,
|
|
131
|
+
interceptors=interceptors,
|
|
117
132
|
)
|
|
118
133
|
channel.subscribe(on_channel_state_change)
|
|
119
134
|
|
|
flwr/client/mod/centraldp_mods.py
CHANGED
|
@@ -82,7 +82,9 @@ def fixedclipping_mod(
|
|
|
82
82
|
clipping_norm,
|
|
83
83
|
)
|
|
84
84
|
|
|
85
|
-
log(
|
|
85
|
+
log(
|
|
86
|
+
INFO, "fixedclipping_mod: parameters are clipped by value: %.4f.", clipping_norm
|
|
87
|
+
)
|
|
86
88
|
|
|
87
89
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
88
90
|
out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
|
|
@@ -146,7 +148,7 @@ def adaptiveclipping_mod(
|
|
|
146
148
|
)
|
|
147
149
|
log(
|
|
148
150
|
INFO,
|
|
149
|
-
"adaptiveclipping_mod: parameters are clipped by value:
|
|
151
|
+
"adaptiveclipping_mod: parameters are clipped by value: %.4f.",
|
|
150
152
|
clipping_norm,
|
|
151
153
|
)
|
|
152
154
|
|
flwr/client/mod/localdp_mod.py
CHANGED
|
@@ -128,7 +128,9 @@ class LocalDpMod:
|
|
|
128
128
|
self.clipping_norm,
|
|
129
129
|
)
|
|
130
130
|
log(
|
|
131
|
-
INFO,
|
|
131
|
+
INFO,
|
|
132
|
+
"LocalDpMod: parameters are clipped by value: %.4f.",
|
|
133
|
+
self.clipping_norm,
|
|
132
134
|
)
|
|
133
135
|
|
|
134
136
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
@@ -137,11 +139,15 @@ class LocalDpMod:
|
|
|
137
139
|
add_localdp_gaussian_noise_to_params(
|
|
138
140
|
fit_res.parameters, self.sensitivity, self.epsilon, self.delta
|
|
139
141
|
)
|
|
142
|
+
|
|
143
|
+
noise_value_sd = (
|
|
144
|
+
self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon
|
|
145
|
+
)
|
|
140
146
|
log(
|
|
141
147
|
INFO,
|
|
142
148
|
"LocalDpMod: local DP noise with "
|
|
143
|
-
"standard deviation:
|
|
144
|
-
|
|
149
|
+
"standard deviation: %.4f added to parameters.",
|
|
150
|
+
noise_value_sd,
|
|
145
151
|
)
|
|
146
152
|
|
|
147
153
|
out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
|
|
flwr/client/rest_client/connection.py
CHANGED
|
@@ -23,6 +23,7 @@ from copy import copy
|
|
|
23
23
|
from logging import ERROR, INFO, WARN
|
|
24
24
|
from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union
|
|
25
25
|
|
|
26
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
26
27
|
from google.protobuf.message import Message as GrpcMessage
|
|
27
28
|
|
|
28
29
|
from flwr.client.heartbeat import start_ping_loop
|
|
@@ -74,7 +75,7 @@ T = TypeVar("T", bound=GrpcMessage)
|
|
|
74
75
|
|
|
75
76
|
|
|
76
77
|
@contextmanager
|
|
77
|
-
def http_request_response( # pylint: disable=R0914, R0915
|
|
78
|
+
def http_request_response( # pylint: disable=R0913, R0914, R0915
|
|
78
79
|
server_address: str,
|
|
79
80
|
insecure: bool, # pylint: disable=unused-argument
|
|
80
81
|
retry_invoker: RetryInvoker,
|
|
@@ -82,6 +83,9 @@ def http_request_response( # pylint: disable=R0914, R0915
|
|
|
82
83
|
root_certificates: Optional[
|
|
83
84
|
Union[bytes, str]
|
|
84
85
|
] = None, # pylint: disable=unused-argument
|
|
86
|
+
authentication_keys: Optional[ # pylint: disable=unused-argument
|
|
87
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
88
|
+
] = None,
|
|
85
89
|
) -> Iterator[
|
|
86
90
|
Tuple[
|
|
87
91
|
Callable[[], Optional[Message]],
|