flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic; see the package registry page for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +18 -46
- flwr/cli/new/new.py +42 -18
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +17 -93
- flwr/client/grpc_client/connection.py +6 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +17 -2
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +5 -1
- flwr/client/supernode/__init__.py +2 -0
- flwr/client/supernode/app.py +181 -7
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +17 -5
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/server/app.py +111 -1
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -72
- flwr/server/driver/__init__.py +3 -0
- flwr/server/driver/driver.py +12 -242
- flwr/server/driver/grpc_driver.py +315 -0
- flwr/server/run_serverapp.py +18 -4
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +76 -8
- flwr/server/superlink/state/sqlite_state.py +116 -11
- flwr/server/superlink/state/state.py +35 -3
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +10 -7
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +63 -52
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +1 -1
- flwr/server/driver/abc_driver.py +0 -140
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower client interceptor."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import base64
|
|
19
|
+
import collections
|
|
20
|
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
21
|
+
|
|
22
|
+
import grpc
|
|
23
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
|
+
|
|
25
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
26
|
+
bytes_to_public_key,
|
|
27
|
+
compute_hmac,
|
|
28
|
+
generate_shared_key,
|
|
29
|
+
public_key_to_bytes,
|
|
30
|
+
)
|
|
31
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
32
|
+
CreateNodeRequest,
|
|
33
|
+
DeleteNodeRequest,
|
|
34
|
+
GetRunRequest,
|
|
35
|
+
PingRequest,
|
|
36
|
+
PullTaskInsRequest,
|
|
37
|
+
PushTaskResRequest,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
_PUBLIC_KEY_HEADER = "public-key"
|
|
41
|
+
_AUTH_TOKEN_HEADER = "auth-token"
|
|
42
|
+
|
|
43
|
+
Request = Union[
|
|
44
|
+
CreateNodeRequest,
|
|
45
|
+
DeleteNodeRequest,
|
|
46
|
+
PullTaskInsRequest,
|
|
47
|
+
PushTaskResRequest,
|
|
48
|
+
GetRunRequest,
|
|
49
|
+
PingRequest,
|
|
50
|
+
]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _get_value_from_tuples(
|
|
54
|
+
key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
|
|
55
|
+
) -> bytes:
|
|
56
|
+
value = next((value for key, value in tuples if key == key_string), "")
|
|
57
|
+
if isinstance(value, str):
|
|
58
|
+
return value.encode()
|
|
59
|
+
|
|
60
|
+
return value
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails", "method timeout metadata credentials"
    ),
    grpc.ClientCallDetails,  # type: ignore
):
    """Details for each client call.

    An instance of this class is passed as the first argument to the
    continuation function. In our case, `AuthenticateClientInterceptor` builds
    one so that it can attach new metadata to the outgoing call.
    """
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor):  # type: ignore
    """Client interceptor for client authentication.

    Every unary-unary call gets the client's public key (URL-safe base64) in the
    ``public-key`` metadata header. After a ``CreateNodeRequest`` completes, the
    server's public key is read from the response's initial metadata, and a
    shared secret is derived from it with ``generate_shared_key``. Requests of
    the other node/task types listed in ``intercept_unary_unary`` additionally
    carry an HMAC of the serialized request in the ``auth-token`` header.
    """

    def __init__(
        self,
        private_key: ec.EllipticCurvePrivateKey,
        public_key: ec.EllipticCurvePublicKey,
    ):
        self.private_key = private_key
        self.public_key = public_key
        # Both are filled in once a CreateNode response has been processed.
        self.shared_secret: Optional[bytes] = None
        self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
        # Encoded once here because it is attached to every outgoing call.
        self.encoded_public_key = base64.urlsafe_b64encode(
            public_key_to_bytes(self.public_key)
        )

    def intercept_unary_unary(
        self,
        continuation: Callable[[Any, Any], Any],
        client_call_details: grpc.ClientCallDetails,
        request: Request,
    ) -> grpc.Call:
        """Flower client interceptor.

        Intercept unary call from client and add necessary authentication header in
        the RPC metadata.

        Parameters
        ----------
        continuation : Callable[[Any, Any], Any]
            Function that proceeds with the RPC given call details and request.
        client_call_details : grpc.ClientCallDetails
            Details of the outgoing call; its metadata is copied and extended.
        request : Request
            The outgoing Fleet API request message.

        Returns
        -------
        grpc.Call
            Whatever ``continuation`` returns.

        Raises
        ------
        RuntimeError
            If an authenticated request is sent before the shared secret has
            been derived (i.e. before a CreateNode call has completed).
        """
        metadata = []
        postprocess = False
        if client_call_details.metadata is not None:
            metadata = list(client_call_details.metadata)

        # Always add the public key header
        metadata.append((_PUBLIC_KEY_HEADER, self.encoded_public_key))

        if isinstance(request, CreateNodeRequest):
            # The server's key arrives in the response; handle it after the call.
            postprocess = True
        elif isinstance(
            request,
            (
                DeleteNodeRequest,
                PullTaskInsRequest,
                PushTaskResRequest,
                GetRunRequest,
                PingRequest,
            ),
        ):
            if self.shared_secret is None:
                raise RuntimeError("Failure to compute hmac")

            # `deterministic` must be passed by keyword: the protobuf Python API
            # declares `SerializeToString(**kwargs)`, and the pure-Python
            # backend rejects a positional argument.
            serialized_request = request.SerializeToString(deterministic=True)
            metadata.append(
                (
                    _AUTH_TOKEN_HEADER,
                    base64.urlsafe_b64encode(
                        compute_hmac(self.shared_secret, serialized_request)
                    ),
                )
            )

        client_call_details = _ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            metadata,
            client_call_details.credentials,
        )

        response = continuation(client_call_details, request)
        if postprocess:
            server_public_key_bytes = base64.urlsafe_b64decode(
                _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
            )
            self.server_public_key = bytes_to_public_key(server_public_key_bytes)
            self.shared_secret = generate_shared_key(
                self.private_key, self.server_public_key
            )
        return response
|
|
@@ -21,7 +21,10 @@ from contextlib import contextmanager
|
|
|
21
21
|
from copy import copy
|
|
22
22
|
from logging import DEBUG, ERROR
|
|
23
23
|
from pathlib import Path
|
|
24
|
-
from typing import Callable, Iterator, Optional, Tuple, Union, cast
|
|
24
|
+
from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast
|
|
25
|
+
|
|
26
|
+
import grpc
|
|
27
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
25
28
|
|
|
26
29
|
from flwr.client.heartbeat import start_ping_loop
|
|
27
30
|
from flwr.client.message_handler.message_handler import validate_out_message
|
|
@@ -52,6 +55,8 @@ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
|
52
55
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
53
56
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
54
57
|
|
|
58
|
+
from .client_interceptor import AuthenticateClientInterceptor
|
|
59
|
+
|
|
55
60
|
|
|
56
61
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
57
62
|
"""Log channel connectivity."""
|
|
@@ -59,12 +64,15 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
|
59
64
|
|
|
60
65
|
|
|
61
66
|
@contextmanager
|
|
62
|
-
def grpc_request_response( # pylint: disable=R0914, R0915
|
|
67
|
+
def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
63
68
|
server_address: str,
|
|
64
69
|
insecure: bool,
|
|
65
70
|
retry_invoker: RetryInvoker,
|
|
66
71
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
|
67
72
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
73
|
+
authentication_keys: Optional[
|
|
74
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
75
|
+
] = None,
|
|
68
76
|
) -> Iterator[
|
|
69
77
|
Tuple[
|
|
70
78
|
Callable[[], Optional[Message]],
|
|
@@ -109,11 +117,18 @@ def grpc_request_response( # pylint: disable=R0914, R0915
|
|
|
109
117
|
if isinstance(root_certificates, str):
|
|
110
118
|
root_certificates = Path(root_certificates).read_bytes()
|
|
111
119
|
|
|
120
|
+
interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None
|
|
121
|
+
if authentication_keys is not None:
|
|
122
|
+
interceptors = AuthenticateClientInterceptor(
|
|
123
|
+
authentication_keys[0], authentication_keys[1]
|
|
124
|
+
)
|
|
125
|
+
|
|
112
126
|
channel = create_channel(
|
|
113
127
|
server_address=server_address,
|
|
114
128
|
insecure=insecure,
|
|
115
129
|
root_certificates=root_certificates,
|
|
116
130
|
max_message_length=max_message_length,
|
|
131
|
+
interceptors=interceptors,
|
|
117
132
|
)
|
|
118
133
|
channel.subscribe(on_channel_state_change)
|
|
119
134
|
|
|
@@ -82,7 +82,9 @@ def fixedclipping_mod(
|
|
|
82
82
|
clipping_norm,
|
|
83
83
|
)
|
|
84
84
|
|
|
85
|
-
log(
|
|
85
|
+
log(
|
|
86
|
+
INFO, "fixedclipping_mod: parameters are clipped by value: %.4f.", clipping_norm
|
|
87
|
+
)
|
|
86
88
|
|
|
87
89
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
88
90
|
out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
|
|
@@ -146,7 +148,7 @@ def adaptiveclipping_mod(
|
|
|
146
148
|
)
|
|
147
149
|
log(
|
|
148
150
|
INFO,
|
|
149
|
-
"adaptiveclipping_mod: parameters are clipped by value:
|
|
151
|
+
"adaptiveclipping_mod: parameters are clipped by value: %.4f.",
|
|
150
152
|
clipping_norm,
|
|
151
153
|
)
|
|
152
154
|
|
flwr/client/mod/localdp_mod.py
CHANGED
|
@@ -128,7 +128,9 @@ class LocalDpMod:
|
|
|
128
128
|
self.clipping_norm,
|
|
129
129
|
)
|
|
130
130
|
log(
|
|
131
|
-
INFO,
|
|
131
|
+
INFO,
|
|
132
|
+
"LocalDpMod: parameters are clipped by value: %.4f.",
|
|
133
|
+
self.clipping_norm,
|
|
132
134
|
)
|
|
133
135
|
|
|
134
136
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
@@ -137,11 +139,15 @@ class LocalDpMod:
|
|
|
137
139
|
add_localdp_gaussian_noise_to_params(
|
|
138
140
|
fit_res.parameters, self.sensitivity, self.epsilon, self.delta
|
|
139
141
|
)
|
|
142
|
+
|
|
143
|
+
noise_value_sd = (
|
|
144
|
+
self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon
|
|
145
|
+
)
|
|
140
146
|
log(
|
|
141
147
|
INFO,
|
|
142
148
|
"LocalDpMod: local DP noise with "
|
|
143
|
-
"standard deviation:
|
|
144
|
-
|
|
149
|
+
"standard deviation: %.4f added to parameters.",
|
|
150
|
+
noise_value_sd,
|
|
145
151
|
)
|
|
146
152
|
|
|
147
153
|
out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
|
|
@@ -23,6 +23,7 @@ from copy import copy
|
|
|
23
23
|
from logging import ERROR, INFO, WARN
|
|
24
24
|
from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union
|
|
25
25
|
|
|
26
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
26
27
|
from google.protobuf.message import Message as GrpcMessage
|
|
27
28
|
|
|
28
29
|
from flwr.client.heartbeat import start_ping_loop
|
|
@@ -74,7 +75,7 @@ T = TypeVar("T", bound=GrpcMessage)
|
|
|
74
75
|
|
|
75
76
|
|
|
76
77
|
@contextmanager
|
|
77
|
-
def http_request_response( # pylint: disable
|
|
78
|
+
def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
78
79
|
server_address: str,
|
|
79
80
|
insecure: bool, # pylint: disable=unused-argument
|
|
80
81
|
retry_invoker: RetryInvoker,
|
|
@@ -82,6 +83,9 @@ def http_request_response( # pylint: disable=R0914, R0915
|
|
|
82
83
|
root_certificates: Optional[
|
|
83
84
|
Union[bytes, str]
|
|
84
85
|
] = None, # pylint: disable=unused-argument
|
|
86
|
+
authentication_keys: Optional[ # pylint: disable=unused-argument
|
|
87
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
88
|
+
] = None,
|
|
85
89
|
) -> Iterator[
|
|
86
90
|
Tuple[
|
|
87
91
|
Callable[[], Optional[Message]],
|
flwr/client/supernode/app.py
CHANGED
|
@@ -15,11 +15,27 @@
|
|
|
15
15
|
"""Flower SuperNode."""
|
|
16
16
|
|
|
17
17
|
import argparse
|
|
18
|
-
|
|
18
|
+
import sys
|
|
19
|
+
from logging import DEBUG, INFO, WARN
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Callable, Optional, Tuple
|
|
19
22
|
|
|
23
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
|
+
from cryptography.hazmat.primitives.serialization import (
|
|
25
|
+
load_ssh_private_key,
|
|
26
|
+
load_ssh_public_key,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
20
30
|
from flwr.common import EventType, event
|
|
21
31
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
22
32
|
from flwr.common.logger import log
|
|
33
|
+
from flwr.common.object_ref import load_app, validate
|
|
34
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
35
|
+
ssh_types_to_elliptic_curve,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
from ..app import _start_client_internal
|
|
23
39
|
|
|
24
40
|
|
|
25
41
|
def run_supernode() -> None:
|
|
@@ -28,12 +44,11 @@ def run_supernode() -> None:
|
|
|
28
44
|
|
|
29
45
|
event(EventType.RUN_SUPERNODE_ENTER)
|
|
30
46
|
|
|
31
|
-
|
|
47
|
+
_ = _parse_args_run_supernode().parse_args()
|
|
32
48
|
|
|
33
49
|
log(
|
|
34
50
|
DEBUG,
|
|
35
|
-
"Flower
|
|
36
|
-
getattr(args, "client-app"),
|
|
51
|
+
"Flower SuperNode starting...",
|
|
37
52
|
)
|
|
38
53
|
|
|
39
54
|
# Graceful shutdown
|
|
@@ -42,23 +57,144 @@ def run_supernode() -> None:
|
|
|
42
57
|
)
|
|
43
58
|
|
|
44
59
|
|
|
60
|
+
def run_client_app() -> None:
    """Run Flower client app.

    Parses the `flower-client-app` command line and resolves the root
    certificates, the ClientApp loader and the optional authentication keys.
    It then calls `_start_client_internal`, and registers the exit handlers
    once that call returns.
    """
    log(INFO, "Long-running Flower client starting")
    event(EventType.RUN_CLIENT_APP_ENTER)

    parsed = _parse_args_run_client_app().parse_args()

    certificates = _get_certificates(parsed)
    log(
        DEBUG,
        "Flower will load ClientApp `%s`",
        getattr(parsed, "client-app"),
    )
    client_app_loader = _get_load_client_app_fn(parsed)
    auth_keys = _try_setup_client_authentication(parsed)

    # `--rest` selects the REST transport; otherwise gRPC request-response.
    transport_name = "rest" if parsed.rest else "grpc-rere"
    _start_client_internal(
        server_address=parsed.server,
        load_client_app_fn=client_app_loader,
        transport=transport_name,
        root_certificates=certificates,
        insecure=parsed.insecure,
        authentication_keys=auth_keys,
        max_retries=parsed.max_retries,
        max_wait_time=parsed.max_wait_time,
    )
    register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
|
|
91
|
+
"""Load certificates if specified in args."""
|
|
92
|
+
# Obtain certificates
|
|
93
|
+
if args.insecure:
|
|
94
|
+
if args.root_certificates is not None:
|
|
95
|
+
sys.exit(
|
|
96
|
+
"Conflicting options: The '--insecure' flag disables HTTPS, "
|
|
97
|
+
"but '--root-certificates' was also specified. Please remove "
|
|
98
|
+
"the '--root-certificates' option when running in insecure mode, "
|
|
99
|
+
"or omit '--insecure' to use HTTPS."
|
|
100
|
+
)
|
|
101
|
+
log(
|
|
102
|
+
WARN,
|
|
103
|
+
"Option `--insecure` was set. "
|
|
104
|
+
"Starting insecure HTTP client connected to %s.",
|
|
105
|
+
args.server,
|
|
106
|
+
)
|
|
107
|
+
root_certificates = None
|
|
108
|
+
else:
|
|
109
|
+
# Load the certificates if provided, or load the system certificates
|
|
110
|
+
cert_path = args.root_certificates
|
|
111
|
+
if cert_path is None:
|
|
112
|
+
root_certificates = None
|
|
113
|
+
else:
|
|
114
|
+
root_certificates = Path(cert_path).read_bytes()
|
|
115
|
+
log(
|
|
116
|
+
DEBUG,
|
|
117
|
+
"Starting secure HTTPS client connected to %s "
|
|
118
|
+
"with the following certificates: %s.",
|
|
119
|
+
args.server,
|
|
120
|
+
cert_path,
|
|
121
|
+
)
|
|
122
|
+
return root_certificates
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _get_load_client_app_fn(
    args: argparse.Namespace,
) -> Callable[[], ClientApp]:
    """Get the load_client_app_fn function.

    If `--dir` is given, it is prepended to `sys.path` first. The `client-app`
    reference is validated immediately, and `LoadClientAppError` is raised if
    validation fails with a message. The returned callable loads the
    referenced object and raises `LoadClientAppError` unless it is a
    `ClientApp`.
    """
    if args.dir is not None:
        sys.path.insert(0, args.dir)

    app_ref: str = getattr(args, "client-app")
    is_valid, validation_error = validate(app_ref)
    if not is_valid and validation_error:
        raise LoadClientAppError(validation_error) from None

    def _load() -> ClientApp:
        loaded_app = load_app(app_ref, LoadClientAppError)
        if isinstance(loaded_app, ClientApp):
            return loaded_app
        raise LoadClientAppError(
            f"Attribute {app_ref} is not of type {ClientApp}",
        ) from None

    return _load
|
|
149
|
+
|
|
150
|
+
|
|
45
151
|
def _parse_args_run_supernode() -> argparse.ArgumentParser:
    """Parse flower-supernode command line arguments.

    The parser has three parts: an optional positional `client-app`
    reference (default: empty string), the options added by
    `_parse_args_common`, and `--flwr-dir`.
    """
    parser = argparse.ArgumentParser(
        description="Start a Flower SuperNode",
    )

    parser.add_argument(
        "client-app",
        nargs="?",
        default="",
        help="For example: `client:app` or `project.package.module:wrapper.app`. "
        "This is optional and serves as the default ClientApp to be loaded when "
        "the ServerApp does not specify `fab_id` and `fab_version`. "
        "If not provided, defaults to an empty string.",
    )
    _parse_args_common(parser)
    parser.add_argument(
        "--flwr-dir",
        default=None,
        # Fixed typo in user-facing help text ("isequal" -> "is equal").
        help="""The path containing installed Flower Apps.
    By default, this value is equal to:

        - `$FLWR_HOME/` if `$FLWR_HOME` is defined
        - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
        - `$HOME/.flwr/` in all other cases
    """,
    )

    return parser
|
|
54
180
|
|
|
55
181
|
|
|
56
|
-
def
|
|
57
|
-
"""Parse command line arguments."""
|
|
182
|
+
def _parse_args_run_client_app() -> argparse.ArgumentParser:
    """Parse flower-client-app command line arguments.

    Returns a parser with a required positional `client-app` reference plus
    the options added by `_parse_args_common`.
    """
    cli = argparse.ArgumentParser(description="Start a Flower client app")

    cli.add_argument(
        "client-app",
        help="For example: `client:app` or `project.package.module:wrapper.app`",
    )
    _parse_args_common(parser=cli)

    return cli
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
62
198
|
parser.add_argument(
|
|
63
199
|
"--insecure",
|
|
64
200
|
action="store_true",
|
|
@@ -105,3 +241,41 @@ def parse_args_run_client_app(parser: argparse.ArgumentParser) -> None:
|
|
|
105
241
|
"app from there."
|
|
106
242
|
" Default: current working directory.",
|
|
107
243
|
)
|
|
244
|
+
parser.add_argument(
|
|
245
|
+
"--authentication-keys",
|
|
246
|
+
nargs=2,
|
|
247
|
+
metavar=("CLIENT_PRIVATE_KEY", "CLIENT_PUBLIC_KEY"),
|
|
248
|
+
type=str,
|
|
249
|
+
help="Provide two file paths: (1) the client's private "
|
|
250
|
+
"key file, and (2) the client's public key file.",
|
|
251
|
+
)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def _try_setup_client_authentication(
    args: argparse.Namespace,
) -> Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
    """Load the client's key pair given via `--authentication-keys`.

    Returns None when the option was not given. Otherwise the first path is
    read as an SSH private key (loaded without a password), and the second as
    an SSH public key. Both are then converted with
    `ssh_types_to_elliptic_curve`.

    The process exits with an explanatory message instead of a traceback when
    either file cannot be read, or does not contain a usable elliptic curve
    key pair.
    """
    if not args.authentication_keys:
        return None

    # `--authentication-keys` is declared with nargs=2: (private, public).
    private_key_path, public_key_path = args.authentication_keys

    try:
        private_key_bytes = Path(private_key_path).read_bytes()
        public_key_bytes = Path(public_key_path).read_bytes()
    except OSError as err:
        sys.exit(
            "Unable to read the key files provided to '--authentication-keys': "
            f"{err}"
        )

    try:
        ssh_private_key = load_ssh_private_key(private_key_bytes, None)
        ssh_public_key = load_ssh_public_key(public_key_bytes)
        client_private_key, client_public_key = ssh_types_to_elliptic_curve(
            ssh_private_key, ssh_public_key
        )
    except (ValueError, TypeError):
        # ValueError: the file content is not a parsable SSH key.
        # TypeError: the keys are not elliptic curve keys (and possibly a
        # password-protected private key, since no password is supplied).
        sys.exit(
            "The file paths provided could not be read as a private and public "
            "key pair. Client authentication requires an elliptic curve public and "
            "private key pair. Please provide the file paths containing elliptic "
            "curve private and public keys to '--authentication-keys'."
        )

    return (
        client_private_key,
        client_public_key,
    )
|
flwr/common/grpc.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import DEBUG
|
|
19
|
-
from typing import Optional
|
|
19
|
+
from typing import Optional, Sequence
|
|
20
20
|
|
|
21
21
|
import grpc
|
|
22
22
|
|
|
@@ -30,6 +30,7 @@ def create_channel(
|
|
|
30
30
|
insecure: bool,
|
|
31
31
|
root_certificates: Optional[bytes] = None,
|
|
32
32
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
33
|
+
interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None,
|
|
33
34
|
) -> grpc.Channel:
|
|
34
35
|
"""Create a gRPC channel, either secure or insecure."""
|
|
35
36
|
# Check for conflicting parameters
|
|
@@ -57,4 +58,7 @@ def create_channel(
|
|
|
57
58
|
)
|
|
58
59
|
log(DEBUG, "Opened secure gRPC connection using certificates")
|
|
59
60
|
|
|
61
|
+
if interceptors is not None:
|
|
62
|
+
channel = grpc.intercept_channel(channel, interceptors)
|
|
63
|
+
|
|
60
64
|
return channel
|
flwr/common/logger.py
CHANGED
|
@@ -82,13 +82,20 @@ class ConsoleHandler(StreamHandler):
|
|
|
82
82
|
return formatter.format(record)
|
|
83
83
|
|
|
84
84
|
|
|
85
|
-
def update_console_handler(
|
|
85
|
+
def update_console_handler(
    level: Optional[int] = None,
    timestamps: Optional[bool] = None,
    colored: Optional[bool] = None,
) -> None:
    """Update the logging handler.

    Applies the given settings to every `ConsoleHandler` attached to the
    Flower logger. Any argument left as ``None`` leaves that setting as it is.
    """
    flower_logger = logging.getLogger(LOGGER_NAME)
    for console_handler in flower_logger.handlers:
        if not isinstance(console_handler, ConsoleHandler):
            continue
        if level is not None:
            console_handler.setLevel(level)
        if timestamps is not None:
            console_handler.timestamps = timestamps
        if colored is not None:
            console_handler.colored = colored
|
|
92
99
|
|
|
93
100
|
|
|
94
101
|
# Configure console logger
|
|
@@ -188,3 +195,29 @@ def warn_deprecated_feature(name: str) -> None:
|
|
|
188
195
|
""",
|
|
189
196
|
name,
|
|
190
197
|
)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def set_logger_propagation(
    child_logger: logging.Logger, value: bool = True
) -> logging.Logger:
    """Set the logger propagation attribute.

    Parameters
    ----------
    child_logger : logging.Logger
        Logger whose ``propagate`` flag is updated in place.
    value : bool
        Boolean setting for propagation. If True, both parent and child logger
        display messages. Otherwise, only the child logger displays a message.
        This False setting prevents duplicate logs in Colab notebooks.
        Reference: https://stackoverflow.com/a/19561320

    Returns
    -------
    logging.Logger
        The same logger object, with the updated propagation setting.
    """
    child_logger.propagate = value
    if child_logger.propagate:
        return child_logger

    # Record the change on the child logger itself when propagation is off.
    child_logger.log(logging.DEBUG, "Logger propagate set to False")
    return child_logger
|