flwr-nightly 1.9.0.dev20240417__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +19 -14
- flwr/cli/new/new.py +51 -22
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +42 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +26 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +3 -1
- flwr/client/app.py +20 -142
- flwr/client/grpc_client/connection.py +8 -2
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +33 -4
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +92 -169
- flwr/client/supernode/__init__.py +24 -0
- flwr/client/supernode/app.py +281 -0
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/common/telemetry.py +4 -0
- flwr/server/app.py +116 -6
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -70
- flwr/server/driver/__init__.py +2 -1
- flwr/server/driver/driver.py +12 -139
- flwr/server/driver/grpc_driver.py +199 -13
- flwr/server/run_serverapp.py +18 -4
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +4 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +89 -12
- flwr/server/superlink/state/sqlite_state.py +133 -16
- flwr/server/superlink/state/state.py +56 -6
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +10 -7
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +66 -52
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +2 -1
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240417.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
flwr/client/app.py
CHANGED
|
@@ -14,13 +14,12 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower client app."""
|
|
16
16
|
|
|
17
|
-
import argparse
|
|
18
17
|
import sys
|
|
19
18
|
import time
|
|
20
19
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
21
|
-
from pathlib import Path
|
|
22
20
|
from typing import Callable, ContextManager, Optional, Tuple, Type, Union
|
|
23
21
|
|
|
22
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
23
|
from grpc import RpcError
|
|
25
24
|
|
|
26
25
|
from flwr.client.client import Client
|
|
@@ -36,10 +35,8 @@ from flwr.common.constant import (
|
|
|
36
35
|
TRANSPORT_TYPES,
|
|
37
36
|
ErrorCode,
|
|
38
37
|
)
|
|
39
|
-
from flwr.common.exit_handlers import register_exit_handlers
|
|
40
38
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
41
39
|
from flwr.common.message import Error
|
|
42
|
-
from flwr.common.object_ref import load_app, validate
|
|
43
40
|
from flwr.common.retry_invoker import RetryInvoker, exponential
|
|
44
41
|
|
|
45
42
|
from .grpc_client.connection import grpc_connection
|
|
@@ -49,142 +46,6 @@ from .node_state import NodeState
|
|
|
49
46
|
from .numpy_client import NumPyClient
|
|
50
47
|
|
|
51
48
|
|
|
52
|
-
def run_client_app() -> None:
|
|
53
|
-
"""Run Flower client app."""
|
|
54
|
-
event(EventType.RUN_CLIENT_APP_ENTER)
|
|
55
|
-
|
|
56
|
-
log(INFO, "Long-running Flower client starting")
|
|
57
|
-
|
|
58
|
-
args = _parse_args_run_client_app().parse_args()
|
|
59
|
-
|
|
60
|
-
# Obtain certificates
|
|
61
|
-
if args.insecure:
|
|
62
|
-
if args.root_certificates is not None:
|
|
63
|
-
sys.exit(
|
|
64
|
-
"Conflicting options: The '--insecure' flag disables HTTPS, "
|
|
65
|
-
"but '--root-certificates' was also specified. Please remove "
|
|
66
|
-
"the '--root-certificates' option when running in insecure mode, "
|
|
67
|
-
"or omit '--insecure' to use HTTPS."
|
|
68
|
-
)
|
|
69
|
-
log(
|
|
70
|
-
WARN,
|
|
71
|
-
"Option `--insecure` was set. "
|
|
72
|
-
"Starting insecure HTTP client connected to %s.",
|
|
73
|
-
args.server,
|
|
74
|
-
)
|
|
75
|
-
root_certificates = None
|
|
76
|
-
else:
|
|
77
|
-
# Load the certificates if provided, or load the system certificates
|
|
78
|
-
cert_path = args.root_certificates
|
|
79
|
-
if cert_path is None:
|
|
80
|
-
root_certificates = None
|
|
81
|
-
else:
|
|
82
|
-
root_certificates = Path(cert_path).read_bytes()
|
|
83
|
-
log(
|
|
84
|
-
DEBUG,
|
|
85
|
-
"Starting secure HTTPS client connected to %s "
|
|
86
|
-
"with the following certificates: %s.",
|
|
87
|
-
args.server,
|
|
88
|
-
cert_path,
|
|
89
|
-
)
|
|
90
|
-
|
|
91
|
-
log(
|
|
92
|
-
DEBUG,
|
|
93
|
-
"Flower will load ClientApp `%s`",
|
|
94
|
-
getattr(args, "client-app"),
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
client_app_dir = args.dir
|
|
98
|
-
if client_app_dir is not None:
|
|
99
|
-
sys.path.insert(0, client_app_dir)
|
|
100
|
-
|
|
101
|
-
app_ref: str = getattr(args, "client-app")
|
|
102
|
-
valid, error_msg = validate(app_ref)
|
|
103
|
-
if not valid and error_msg:
|
|
104
|
-
raise LoadClientAppError(error_msg) from None
|
|
105
|
-
|
|
106
|
-
def _load() -> ClientApp:
|
|
107
|
-
client_app = load_app(app_ref, LoadClientAppError)
|
|
108
|
-
|
|
109
|
-
if not isinstance(client_app, ClientApp):
|
|
110
|
-
raise LoadClientAppError(
|
|
111
|
-
f"Attribute {app_ref} is not of type {ClientApp}",
|
|
112
|
-
) from None
|
|
113
|
-
|
|
114
|
-
return client_app
|
|
115
|
-
|
|
116
|
-
_start_client_internal(
|
|
117
|
-
server_address=args.server,
|
|
118
|
-
load_client_app_fn=_load,
|
|
119
|
-
transport="rest" if args.rest else "grpc-rere",
|
|
120
|
-
root_certificates=root_certificates,
|
|
121
|
-
insecure=args.insecure,
|
|
122
|
-
max_retries=args.max_retries,
|
|
123
|
-
max_wait_time=args.max_wait_time,
|
|
124
|
-
)
|
|
125
|
-
register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
def _parse_args_run_client_app() -> argparse.ArgumentParser:
|
|
129
|
-
"""Parse flower-client-app command line arguments."""
|
|
130
|
-
parser = argparse.ArgumentParser(
|
|
131
|
-
description="Start a Flower client app",
|
|
132
|
-
)
|
|
133
|
-
|
|
134
|
-
parser.add_argument(
|
|
135
|
-
"client-app",
|
|
136
|
-
help="For example: `client:app` or `project.package.module:wrapper.app`",
|
|
137
|
-
)
|
|
138
|
-
parser.add_argument(
|
|
139
|
-
"--insecure",
|
|
140
|
-
action="store_true",
|
|
141
|
-
help="Run the client without HTTPS. By default, the client runs with "
|
|
142
|
-
"HTTPS enabled. Use this flag only if you understand the risks.",
|
|
143
|
-
)
|
|
144
|
-
parser.add_argument(
|
|
145
|
-
"--rest",
|
|
146
|
-
action="store_true",
|
|
147
|
-
help="Use REST as a transport layer for the client.",
|
|
148
|
-
)
|
|
149
|
-
parser.add_argument(
|
|
150
|
-
"--root-certificates",
|
|
151
|
-
metavar="ROOT_CERT",
|
|
152
|
-
type=str,
|
|
153
|
-
help="Specifies the path to the PEM-encoded root certificate file for "
|
|
154
|
-
"establishing secure HTTPS connections.",
|
|
155
|
-
)
|
|
156
|
-
parser.add_argument(
|
|
157
|
-
"--server",
|
|
158
|
-
default="0.0.0.0:9092",
|
|
159
|
-
help="Server address",
|
|
160
|
-
)
|
|
161
|
-
parser.add_argument(
|
|
162
|
-
"--max-retries",
|
|
163
|
-
type=int,
|
|
164
|
-
default=None,
|
|
165
|
-
help="The maximum number of times the client will try to connect to the"
|
|
166
|
-
"server before giving up in case of a connection error. By default,"
|
|
167
|
-
"it is set to None, meaning there is no limit to the number of tries.",
|
|
168
|
-
)
|
|
169
|
-
parser.add_argument(
|
|
170
|
-
"--max-wait-time",
|
|
171
|
-
type=float,
|
|
172
|
-
default=None,
|
|
173
|
-
help="The maximum duration before the client stops trying to"
|
|
174
|
-
"connect to the server in case of connection error. By default, it"
|
|
175
|
-
"is set to None, meaning there is no limit to the total time.",
|
|
176
|
-
)
|
|
177
|
-
parser.add_argument(
|
|
178
|
-
"--dir",
|
|
179
|
-
default="",
|
|
180
|
-
help="Add specified directory to the PYTHONPATH and load Flower "
|
|
181
|
-
"app from there."
|
|
182
|
-
" Default: current working directory.",
|
|
183
|
-
)
|
|
184
|
-
|
|
185
|
-
return parser
|
|
186
|
-
|
|
187
|
-
|
|
188
49
|
def _check_actionable_client(
|
|
189
50
|
client: Optional[Client], client_fn: Optional[ClientFn]
|
|
190
51
|
) -> None:
|
|
@@ -213,6 +74,9 @@ def start_client(
|
|
|
213
74
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
214
75
|
insecure: Optional[bool] = None,
|
|
215
76
|
transport: Optional[str] = None,
|
|
77
|
+
authentication_keys: Optional[
|
|
78
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
79
|
+
] = None,
|
|
216
80
|
max_retries: Optional[int] = None,
|
|
217
81
|
max_wait_time: Optional[float] = None,
|
|
218
82
|
) -> None:
|
|
@@ -297,6 +161,7 @@ def start_client(
|
|
|
297
161
|
root_certificates=root_certificates,
|
|
298
162
|
insecure=insecure,
|
|
299
163
|
transport=transport,
|
|
164
|
+
authentication_keys=authentication_keys,
|
|
300
165
|
max_retries=max_retries,
|
|
301
166
|
max_wait_time=max_wait_time,
|
|
302
167
|
)
|
|
@@ -317,6 +182,9 @@ def _start_client_internal(
|
|
|
317
182
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
318
183
|
insecure: Optional[bool] = None,
|
|
319
184
|
transport: Optional[str] = None,
|
|
185
|
+
authentication_keys: Optional[
|
|
186
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
187
|
+
] = None,
|
|
320
188
|
max_retries: Optional[int] = None,
|
|
321
189
|
max_wait_time: Optional[float] = None,
|
|
322
190
|
) -> None:
|
|
@@ -441,8 +309,10 @@ def _start_client_internal(
|
|
|
441
309
|
retry_invoker,
|
|
442
310
|
grpc_max_message_length,
|
|
443
311
|
root_certificates,
|
|
312
|
+
authentication_keys,
|
|
444
313
|
) as conn:
|
|
445
|
-
|
|
314
|
+
# pylint: disable-next=W0612
|
|
315
|
+
receive, send, create_node, delete_node, get_run = conn
|
|
446
316
|
|
|
447
317
|
# Register node
|
|
448
318
|
if create_node is not None:
|
|
@@ -653,13 +523,21 @@ def start_numpy_client(
|
|
|
653
523
|
|
|
654
524
|
def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
655
525
|
Callable[
|
|
656
|
-
[
|
|
526
|
+
[
|
|
527
|
+
str,
|
|
528
|
+
bool,
|
|
529
|
+
RetryInvoker,
|
|
530
|
+
int,
|
|
531
|
+
Union[bytes, str, None],
|
|
532
|
+
Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]],
|
|
533
|
+
],
|
|
657
534
|
ContextManager[
|
|
658
535
|
Tuple[
|
|
659
536
|
Callable[[], Optional[Message]],
|
|
660
537
|
Callable[[Message], None],
|
|
661
538
|
Optional[Callable[[], None]],
|
|
662
539
|
Optional[Callable[[], None]],
|
|
540
|
+
Optional[Callable[[int], Tuple[str, str]]],
|
|
663
541
|
]
|
|
664
542
|
],
|
|
665
543
|
],
|
|
@@ -22,6 +22,8 @@ from pathlib import Path
|
|
|
22
22
|
from queue import Queue
|
|
23
23
|
from typing import Callable, Iterator, Optional, Tuple, Union, cast
|
|
24
24
|
|
|
25
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
26
|
+
|
|
25
27
|
from flwr.common import (
|
|
26
28
|
DEFAULT_TTL,
|
|
27
29
|
GRPC_MAX_MESSAGE_LENGTH,
|
|
@@ -56,18 +58,22 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
|
56
58
|
|
|
57
59
|
|
|
58
60
|
@contextmanager
|
|
59
|
-
def grpc_connection( # pylint: disable=R0915
|
|
61
|
+
def grpc_connection( # pylint: disable=R0913, R0915
|
|
60
62
|
server_address: str,
|
|
61
63
|
insecure: bool,
|
|
62
64
|
retry_invoker: RetryInvoker, # pylint: disable=unused-argument
|
|
63
65
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
64
66
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
67
|
+
authentication_keys: Optional[ # pylint: disable=unused-argument
|
|
68
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
69
|
+
] = None,
|
|
65
70
|
) -> Iterator[
|
|
66
71
|
Tuple[
|
|
67
72
|
Callable[[], Optional[Message]],
|
|
68
73
|
Callable[[Message], None],
|
|
69
74
|
Optional[Callable[[], None]],
|
|
70
75
|
Optional[Callable[[], None]],
|
|
76
|
+
Optional[Callable[[int], Tuple[str, str]]],
|
|
71
77
|
]
|
|
72
78
|
]:
|
|
73
79
|
"""Establish a gRPC connection to a gRPC server.
|
|
@@ -224,7 +230,7 @@ def grpc_connection( # pylint: disable=R0915
|
|
|
224
230
|
|
|
225
231
|
try:
|
|
226
232
|
# Yield methods
|
|
227
|
-
yield (receive, send, None, None)
|
|
233
|
+
yield (receive, send, None, None, None)
|
|
228
234
|
finally:
|
|
229
235
|
# Make sure to have a final
|
|
230
236
|
channel.close()
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower client interceptor."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import base64
|
|
19
|
+
import collections
|
|
20
|
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
21
|
+
|
|
22
|
+
import grpc
|
|
23
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
|
+
|
|
25
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
26
|
+
bytes_to_public_key,
|
|
27
|
+
compute_hmac,
|
|
28
|
+
generate_shared_key,
|
|
29
|
+
public_key_to_bytes,
|
|
30
|
+
)
|
|
31
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
32
|
+
CreateNodeRequest,
|
|
33
|
+
DeleteNodeRequest,
|
|
34
|
+
GetRunRequest,
|
|
35
|
+
PingRequest,
|
|
36
|
+
PullTaskInsRequest,
|
|
37
|
+
PushTaskResRequest,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
_PUBLIC_KEY_HEADER = "public-key"
|
|
41
|
+
_AUTH_TOKEN_HEADER = "auth-token"
|
|
42
|
+
|
|
43
|
+
Request = Union[
|
|
44
|
+
CreateNodeRequest,
|
|
45
|
+
DeleteNodeRequest,
|
|
46
|
+
PullTaskInsRequest,
|
|
47
|
+
PushTaskResRequest,
|
|
48
|
+
GetRunRequest,
|
|
49
|
+
PingRequest,
|
|
50
|
+
]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _get_value_from_tuples(
|
|
54
|
+
key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
|
|
55
|
+
) -> bytes:
|
|
56
|
+
value = next((value for key, value in tuples if key == key_string), "")
|
|
57
|
+
if isinstance(value, str):
|
|
58
|
+
return value.encode()
|
|
59
|
+
|
|
60
|
+
return value
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class _ClientCallDetails(
|
|
64
|
+
collections.namedtuple(
|
|
65
|
+
"_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
|
|
66
|
+
),
|
|
67
|
+
grpc.ClientCallDetails, # type: ignore
|
|
68
|
+
):
|
|
69
|
+
"""Details for each client call.
|
|
70
|
+
|
|
71
|
+
The class will be passed on as the first argument in continuation function.
|
|
72
|
+
In our case, `AuthenticateClientInterceptor` adds new metadata to the construct.
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
|
|
77
|
+
"""Client interceptor for client authentication."""
|
|
78
|
+
|
|
79
|
+
def __init__(
|
|
80
|
+
self,
|
|
81
|
+
private_key: ec.EllipticCurvePrivateKey,
|
|
82
|
+
public_key: ec.EllipticCurvePublicKey,
|
|
83
|
+
):
|
|
84
|
+
self.private_key = private_key
|
|
85
|
+
self.public_key = public_key
|
|
86
|
+
self.shared_secret: Optional[bytes] = None
|
|
87
|
+
self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
|
|
88
|
+
self.encoded_public_key = base64.urlsafe_b64encode(
|
|
89
|
+
public_key_to_bytes(self.public_key)
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
def intercept_unary_unary(
|
|
93
|
+
self,
|
|
94
|
+
continuation: Callable[[Any, Any], Any],
|
|
95
|
+
client_call_details: grpc.ClientCallDetails,
|
|
96
|
+
request: Request,
|
|
97
|
+
) -> grpc.Call:
|
|
98
|
+
"""Flower client interceptor.
|
|
99
|
+
|
|
100
|
+
Intercept unary call from client and add necessary authentication header in the
|
|
101
|
+
RPC metadata.
|
|
102
|
+
"""
|
|
103
|
+
metadata = []
|
|
104
|
+
postprocess = False
|
|
105
|
+
if client_call_details.metadata is not None:
|
|
106
|
+
metadata = list(client_call_details.metadata)
|
|
107
|
+
|
|
108
|
+
# Always add the public key header
|
|
109
|
+
metadata.append(
|
|
110
|
+
(
|
|
111
|
+
_PUBLIC_KEY_HEADER,
|
|
112
|
+
self.encoded_public_key,
|
|
113
|
+
)
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
if isinstance(request, CreateNodeRequest):
|
|
117
|
+
postprocess = True
|
|
118
|
+
elif isinstance(
|
|
119
|
+
request,
|
|
120
|
+
(
|
|
121
|
+
DeleteNodeRequest,
|
|
122
|
+
PullTaskInsRequest,
|
|
123
|
+
PushTaskResRequest,
|
|
124
|
+
GetRunRequest,
|
|
125
|
+
PingRequest,
|
|
126
|
+
),
|
|
127
|
+
):
|
|
128
|
+
if self.shared_secret is None:
|
|
129
|
+
raise RuntimeError("Failure to compute hmac")
|
|
130
|
+
|
|
131
|
+
metadata.append(
|
|
132
|
+
(
|
|
133
|
+
_AUTH_TOKEN_HEADER,
|
|
134
|
+
base64.urlsafe_b64encode(
|
|
135
|
+
compute_hmac(
|
|
136
|
+
self.shared_secret, request.SerializeToString(True)
|
|
137
|
+
)
|
|
138
|
+
),
|
|
139
|
+
)
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
client_call_details = _ClientCallDetails(
|
|
143
|
+
client_call_details.method,
|
|
144
|
+
client_call_details.timeout,
|
|
145
|
+
metadata,
|
|
146
|
+
client_call_details.credentials,
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
response = continuation(client_call_details, request)
|
|
150
|
+
if postprocess:
|
|
151
|
+
server_public_key_bytes = base64.urlsafe_b64decode(
|
|
152
|
+
_get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
|
|
153
|
+
)
|
|
154
|
+
self.server_public_key = bytes_to_public_key(server_public_key_bytes)
|
|
155
|
+
self.shared_secret = generate_shared_key(
|
|
156
|
+
self.private_key, self.server_public_key
|
|
157
|
+
)
|
|
158
|
+
return response
|
|
@@ -21,7 +21,10 @@ from contextlib import contextmanager
|
|
|
21
21
|
from copy import copy
|
|
22
22
|
from logging import DEBUG, ERROR
|
|
23
23
|
from pathlib import Path
|
|
24
|
-
from typing import Callable, Iterator, Optional, Tuple, Union, cast
|
|
24
|
+
from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast
|
|
25
|
+
|
|
26
|
+
import grpc
|
|
27
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
25
28
|
|
|
26
29
|
from flwr.client.heartbeat import start_ping_loop
|
|
27
30
|
from flwr.client.message_handler.message_handler import validate_out_message
|
|
@@ -41,6 +44,8 @@ from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
|
41
44
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
42
45
|
CreateNodeRequest,
|
|
43
46
|
DeleteNodeRequest,
|
|
47
|
+
GetRunRequest,
|
|
48
|
+
GetRunResponse,
|
|
44
49
|
PingRequest,
|
|
45
50
|
PingResponse,
|
|
46
51
|
PullTaskInsRequest,
|
|
@@ -50,6 +55,8 @@ from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
|
50
55
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
51
56
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
52
57
|
|
|
58
|
+
from .client_interceptor import AuthenticateClientInterceptor
|
|
59
|
+
|
|
53
60
|
|
|
54
61
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
55
62
|
"""Log channel connectivity."""
|
|
@@ -57,18 +64,22 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
|
57
64
|
|
|
58
65
|
|
|
59
66
|
@contextmanager
|
|
60
|
-
def grpc_request_response( # pylint: disable=R0914, R0915
|
|
67
|
+
def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
61
68
|
server_address: str,
|
|
62
69
|
insecure: bool,
|
|
63
70
|
retry_invoker: RetryInvoker,
|
|
64
71
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
|
65
72
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
73
|
+
authentication_keys: Optional[
|
|
74
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
75
|
+
] = None,
|
|
66
76
|
) -> Iterator[
|
|
67
77
|
Tuple[
|
|
68
78
|
Callable[[], Optional[Message]],
|
|
69
79
|
Callable[[Message], None],
|
|
70
80
|
Optional[Callable[[], None]],
|
|
71
81
|
Optional[Callable[[], None]],
|
|
82
|
+
Optional[Callable[[int], Tuple[str, str]]],
|
|
72
83
|
]
|
|
73
84
|
]:
|
|
74
85
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -106,11 +117,18 @@ def grpc_request_response( # pylint: disable=R0914, R0915
|
|
|
106
117
|
if isinstance(root_certificates, str):
|
|
107
118
|
root_certificates = Path(root_certificates).read_bytes()
|
|
108
119
|
|
|
120
|
+
interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None
|
|
121
|
+
if authentication_keys is not None:
|
|
122
|
+
interceptors = AuthenticateClientInterceptor(
|
|
123
|
+
authentication_keys[0], authentication_keys[1]
|
|
124
|
+
)
|
|
125
|
+
|
|
109
126
|
channel = create_channel(
|
|
110
127
|
server_address=server_address,
|
|
111
128
|
insecure=insecure,
|
|
112
129
|
root_certificates=root_certificates,
|
|
113
130
|
max_message_length=max_message_length,
|
|
131
|
+
interceptors=interceptors,
|
|
114
132
|
)
|
|
115
133
|
channel.subscribe(on_channel_state_change)
|
|
116
134
|
|
|
@@ -122,7 +140,7 @@ def grpc_request_response( # pylint: disable=R0914, R0915
|
|
|
122
140
|
ping_stop_event = threading.Event()
|
|
123
141
|
|
|
124
142
|
###########################################################################
|
|
125
|
-
# ping/create_node/delete_node/receive/send functions
|
|
143
|
+
# ping/create_node/delete_node/receive/send/get_run functions
|
|
126
144
|
###########################################################################
|
|
127
145
|
|
|
128
146
|
def ping() -> None:
|
|
@@ -241,8 +259,19 @@ def grpc_request_response( # pylint: disable=R0914, R0915
|
|
|
241
259
|
# Cleanup
|
|
242
260
|
metadata = None
|
|
243
261
|
|
|
262
|
+
def get_run(run_id: int) -> Tuple[str, str]:
|
|
263
|
+
# Call FleetAPI
|
|
264
|
+
get_run_request = GetRunRequest(run_id=run_id)
|
|
265
|
+
get_run_response: GetRunResponse = retry_invoker.invoke(
|
|
266
|
+
stub.GetRun,
|
|
267
|
+
request=get_run_request,
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
# Return fab_id and fab_version
|
|
271
|
+
return get_run_response.run.fab_id, get_run_response.run.fab_version
|
|
272
|
+
|
|
244
273
|
try:
|
|
245
274
|
# Yield methods
|
|
246
|
-
yield (receive, send, create_node, delete_node)
|
|
275
|
+
yield (receive, send, create_node, delete_node, get_run)
|
|
247
276
|
except Exception as exc: # pylint: disable=broad-except
|
|
248
277
|
log(ERROR, exc)
|
|
@@ -82,7 +82,9 @@ def fixedclipping_mod(
|
|
|
82
82
|
clipping_norm,
|
|
83
83
|
)
|
|
84
84
|
|
|
85
|
-
log(
|
|
85
|
+
log(
|
|
86
|
+
INFO, "fixedclipping_mod: parameters are clipped by value: %.4f.", clipping_norm
|
|
87
|
+
)
|
|
86
88
|
|
|
87
89
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
88
90
|
out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
|
|
@@ -146,7 +148,7 @@ def adaptiveclipping_mod(
|
|
|
146
148
|
)
|
|
147
149
|
log(
|
|
148
150
|
INFO,
|
|
149
|
-
"adaptiveclipping_mod: parameters are clipped by value:
|
|
151
|
+
"adaptiveclipping_mod: parameters are clipped by value: %.4f.",
|
|
150
152
|
clipping_norm,
|
|
151
153
|
)
|
|
152
154
|
|
flwr/client/mod/localdp_mod.py
CHANGED
|
@@ -128,7 +128,9 @@ class LocalDpMod:
|
|
|
128
128
|
self.clipping_norm,
|
|
129
129
|
)
|
|
130
130
|
log(
|
|
131
|
-
INFO,
|
|
131
|
+
INFO,
|
|
132
|
+
"LocalDpMod: parameters are clipped by value: %.4f.",
|
|
133
|
+
self.clipping_norm,
|
|
132
134
|
)
|
|
133
135
|
|
|
134
136
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
|
@@ -137,11 +139,15 @@ class LocalDpMod:
|
|
|
137
139
|
add_localdp_gaussian_noise_to_params(
|
|
138
140
|
fit_res.parameters, self.sensitivity, self.epsilon, self.delta
|
|
139
141
|
)
|
|
142
|
+
|
|
143
|
+
noise_value_sd = (
|
|
144
|
+
self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon
|
|
145
|
+
)
|
|
140
146
|
log(
|
|
141
147
|
INFO,
|
|
142
148
|
"LocalDpMod: local DP noise with "
|
|
143
|
-
"standard deviation:
|
|
144
|
-
|
|
149
|
+
"standard deviation: %.4f added to parameters.",
|
|
150
|
+
noise_value_sd,
|
|
145
151
|
)
|
|
146
152
|
|
|
147
153
|
out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
|