flwr 1.14.0__py3-none-any.whl → 1.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/auth_plugin/__init__.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
- flwr/cli/cli_user_auth_interceptor.py +6 -2
- flwr/cli/config_utils.py +24 -147
- flwr/cli/constant.py +27 -0
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +18 -3
- flwr/cli/login/login.py +43 -8
- flwr/cli/ls.py +14 -5
- flwr/cli/new/templates/app/README.md.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/run/run.py +21 -11
- flwr/cli/stop.py +13 -4
- flwr/cli/utils.py +54 -40
- flwr/client/app.py +36 -48
- flwr/client/clientapp/app.py +19 -25
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/grpc_client/connection.py +1 -12
- flwr/client/grpc_rere_client/client_interceptor.py +19 -119
- flwr/client/grpc_rere_client/connection.py +46 -36
- flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
- flwr/client/message_handler/task_handler.py +0 -17
- flwr/client/rest_client/connection.py +34 -26
- flwr/client/supernode/app.py +18 -72
- flwr/common/args.py +25 -47
- flwr/common/auth_plugin/auth_plugin.py +34 -23
- flwr/common/config.py +166 -16
- flwr/common/constant.py +24 -9
- flwr/common/differential_privacy.py +2 -1
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +32 -30
- flwr/common/grpc.py +167 -4
- flwr/common/logger.py +26 -7
- flwr/common/object_ref.py +0 -14
- flwr/common/record/recordset.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/serde.py +6 -4
- flwr/common/typing.py +20 -0
- flwr/proto/clientappio_pb2.py +1 -1
- flwr/proto/error_pb2.py +1 -1
- flwr/proto/exec_pb2.py +13 -25
- flwr/proto/exec_pb2.pyi +27 -54
- flwr/proto/fab_pb2.py +1 -1
- flwr/proto/fleet_pb2.py +31 -31
- flwr/proto/fleet_pb2.pyi +23 -23
- flwr/proto/fleet_pb2_grpc.py +30 -30
- flwr/proto/fleet_pb2_grpc.pyi +20 -20
- flwr/proto/grpcadapter_pb2.py +1 -1
- flwr/proto/log_pb2.py +1 -1
- flwr/proto/message_pb2.py +1 -1
- flwr/proto/node_pb2.py +3 -3
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/recordset_pb2.py +1 -1
- flwr/proto/run_pb2.py +1 -1
- flwr/proto/serverappio_pb2.py +24 -25
- flwr/proto/serverappio_pb2.pyi +26 -32
- flwr/proto/serverappio_pb2_grpc.py +28 -28
- flwr/proto/serverappio_pb2_grpc.pyi +16 -16
- flwr/proto/simulationio_pb2.py +1 -1
- flwr/proto/task_pb2.py +1 -1
- flwr/proto/transport_pb2.py +1 -1
- flwr/server/app.py +116 -128
- flwr/server/compat/app_utils.py +0 -1
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +32 -27
- flwr/server/driver/inmemory_driver.py +2 -1
- flwr/server/serverapp/app.py +12 -10
- flwr/server/superlink/driver/serverappio_grpc.py +1 -1
- flwr/server/superlink/driver/serverappio_servicer.py +74 -48
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -24
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +110 -168
- flwr/server/superlink/fleet/message_handler/message_handler.py +37 -24
- flwr/server/superlink/fleet/rest_rere/rest_api.py +16 -18
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +45 -75
- flwr/server/superlink/linkstate/linkstate.py +17 -38
- flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -145
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +4 -6
- flwr/simulation/legacy_app.py +4 -2
- flwr/simulation/run_simulation.py +1 -1
- flwr/simulation/simulationio_connection.py +2 -1
- flwr/superexec/exec_grpc.py +1 -1
- flwr/superexec/exec_servicer.py +23 -2
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/METADATA +8 -8
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/RECORD +103 -97
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/LICENSE +0 -0
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/WHEEL +0 -0
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/entry_points.txt +0 -0
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import random
|
|
19
|
-
import sys
|
|
20
19
|
import threading
|
|
21
20
|
from collections.abc import Iterator
|
|
22
21
|
from contextlib import contextmanager
|
|
@@ -26,22 +25,22 @@ from typing import Callable, Optional, TypeVar, Union
|
|
|
26
25
|
|
|
27
26
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
28
27
|
from google.protobuf.message import Message as GrpcMessage
|
|
28
|
+
from requests.exceptions import ConnectionError as RequestsConnectionError
|
|
29
29
|
|
|
30
30
|
from flwr.client.heartbeat import start_ping_loop
|
|
31
31
|
from flwr.client.message_handler.message_handler import validate_out_message
|
|
32
|
-
from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
|
|
33
32
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
34
33
|
from flwr.common.constant import (
|
|
35
|
-
MISSING_EXTRA_REST,
|
|
36
34
|
PING_BASE_MULTIPLIER,
|
|
37
35
|
PING_CALL_TIMEOUT,
|
|
38
36
|
PING_DEFAULT_INTERVAL,
|
|
39
37
|
PING_RANDOM_RANGE,
|
|
40
38
|
)
|
|
39
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
|
41
40
|
from flwr.common.logger import log
|
|
42
41
|
from flwr.common.message import Message, Metadata
|
|
43
42
|
from flwr.common.retry_invoker import RetryInvoker
|
|
44
|
-
from flwr.common.serde import
|
|
43
|
+
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
|
45
44
|
from flwr.common.typing import Fab, Run
|
|
46
45
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
47
46
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
@@ -51,25 +50,26 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
51
50
|
DeleteNodeResponse,
|
|
52
51
|
PingRequest,
|
|
53
52
|
PingResponse,
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
53
|
+
PullMessagesRequest,
|
|
54
|
+
PullMessagesResponse,
|
|
55
|
+
PushMessagesRequest,
|
|
56
|
+
PushMessagesResponse,
|
|
58
57
|
)
|
|
59
58
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
60
59
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
61
|
-
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
62
60
|
|
|
63
61
|
try:
|
|
64
62
|
import requests
|
|
65
63
|
except ModuleNotFoundError:
|
|
66
|
-
|
|
64
|
+
flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
|
|
67
65
|
|
|
68
66
|
|
|
69
67
|
PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
|
|
70
68
|
PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
|
|
71
69
|
PATH_PULL_TASK_INS: str = "api/v0/fleet/pull-task-ins"
|
|
70
|
+
PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
|
|
72
71
|
PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
|
|
72
|
+
PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
|
|
73
73
|
PATH_PING: str = "api/v0/fleet/ping"
|
|
74
74
|
PATH_GET_RUN: str = "/api/v0/fleet/get-run"
|
|
75
75
|
PATH_GET_FAB: str = "/api/v0/fleet/get-fab"
|
|
@@ -286,29 +286,28 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
286
286
|
log(ERROR, "Node instance missing")
|
|
287
287
|
return None
|
|
288
288
|
|
|
289
|
-
# Request instructions (
|
|
290
|
-
req =
|
|
289
|
+
# Request instructions (message) from server
|
|
290
|
+
req = PullMessagesRequest(node=node)
|
|
291
291
|
|
|
292
292
|
# Send the request
|
|
293
|
-
res = _request(req,
|
|
293
|
+
res = _request(req, PullMessagesResponse, PATH_PULL_MESSAGES)
|
|
294
294
|
if res is None:
|
|
295
295
|
return None
|
|
296
296
|
|
|
297
|
-
# Get the current
|
|
298
|
-
|
|
297
|
+
# Get the current Messages
|
|
298
|
+
message_proto = None if len(res.messages_list) == 0 else res.messages_list[0]
|
|
299
299
|
|
|
300
|
-
# Discard the current
|
|
301
|
-
if
|
|
302
|
-
|
|
303
|
-
and validate_task_ins(task_ins)
|
|
300
|
+
# Discard the current message if not valid
|
|
301
|
+
if message_proto is not None and not (
|
|
302
|
+
message_proto.metadata.dst_node_id == node.node_id
|
|
304
303
|
):
|
|
305
|
-
|
|
304
|
+
message_proto = None
|
|
306
305
|
|
|
307
306
|
# Return the Message if available
|
|
308
307
|
nonlocal metadata
|
|
309
308
|
message = None
|
|
310
|
-
if
|
|
311
|
-
message =
|
|
309
|
+
if message_proto is not None:
|
|
310
|
+
message = message_from_proto(message_proto)
|
|
312
311
|
metadata = copy(message.metadata)
|
|
313
312
|
log(INFO, "[Node] POST /%s: success", PATH_PULL_TASK_INS)
|
|
314
313
|
return message
|
|
@@ -332,14 +331,14 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
332
331
|
return
|
|
333
332
|
metadata = None
|
|
334
333
|
|
|
335
|
-
#
|
|
336
|
-
|
|
334
|
+
# Serialize ProtoBuf to bytes
|
|
335
|
+
message_proto = message_to_proto(message=message)
|
|
337
336
|
|
|
338
337
|
# Serialize ProtoBuf to bytes
|
|
339
|
-
req =
|
|
338
|
+
req = PushMessagesRequest(node=node, messages_list=[message_proto])
|
|
340
339
|
|
|
341
340
|
# Send the request
|
|
342
|
-
res = _request(req,
|
|
341
|
+
res = _request(req, PushMessagesResponse, PATH_PUSH_MESSAGES)
|
|
343
342
|
if res is None:
|
|
344
343
|
return
|
|
345
344
|
|
|
@@ -380,3 +379,12 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
380
379
|
yield (receive, send, create_node, delete_node, get_run, get_fab)
|
|
381
380
|
except Exception as exc: # pylint: disable=broad-except
|
|
382
381
|
log(ERROR, exc)
|
|
382
|
+
# Cleanup
|
|
383
|
+
finally:
|
|
384
|
+
try:
|
|
385
|
+
if node is not None:
|
|
386
|
+
# Disable retrying
|
|
387
|
+
retry_invoker.max_tries = 1
|
|
388
|
+
delete_node()
|
|
389
|
+
except RequestsConnectionError:
|
|
390
|
+
pass
|
flwr/client/supernode/app.py
CHANGED
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
|
-
import sys
|
|
20
19
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
21
20
|
from pathlib import Path
|
|
22
21
|
from typing import Optional
|
|
@@ -40,8 +39,9 @@ from flwr.common.constant import (
|
|
|
40
39
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
41
40
|
TRANSPORT_TYPE_REST,
|
|
42
41
|
)
|
|
42
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
|
43
43
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
44
|
-
from flwr.common.logger import log
|
|
44
|
+
from flwr.common.logger import log
|
|
45
45
|
|
|
46
46
|
from ..app import start_client_internal
|
|
47
47
|
from ..clientapp.utils import get_load_client_app_fn
|
|
@@ -50,7 +50,6 @@ from ..clientapp.utils import get_load_client_app_fn
|
|
|
50
50
|
def run_supernode() -> None:
|
|
51
51
|
"""Run Flower SuperNode."""
|
|
52
52
|
args = _parse_args_run_supernode().parse_args()
|
|
53
|
-
_warn_deprecated_server_arg(args)
|
|
54
53
|
|
|
55
54
|
log(INFO, "Starting Flower SuperNode")
|
|
56
55
|
|
|
@@ -64,17 +63,6 @@ def run_supernode() -> None:
|
|
|
64
63
|
"Ignoring `--flwr-dir`.",
|
|
65
64
|
)
|
|
66
65
|
|
|
67
|
-
# Exit if unsupported argument is passed by the user
|
|
68
|
-
if args.app is not None:
|
|
69
|
-
log(
|
|
70
|
-
ERROR,
|
|
71
|
-
"The `app` argument is deprecated. The SuperNode now automatically "
|
|
72
|
-
"uses the ClientApp delivered from the SuperLink. Providing the app "
|
|
73
|
-
"directory manually is no longer supported. Please remove the `app` "
|
|
74
|
-
"argument from your command.",
|
|
75
|
-
)
|
|
76
|
-
sys.exit(1)
|
|
77
|
-
|
|
78
66
|
root_certificates = try_obtain_root_certificates(args, args.superlink)
|
|
79
67
|
load_fn = get_load_client_app_fn(
|
|
80
68
|
default_app_ref="",
|
|
@@ -86,6 +74,12 @@ def run_supernode() -> None:
|
|
|
86
74
|
|
|
87
75
|
log(DEBUG, "Isolation mode: %s", args.isolation)
|
|
88
76
|
|
|
77
|
+
# Register handlers for graceful shutdown
|
|
78
|
+
register_exit_handlers(
|
|
79
|
+
event_type=EventType.RUN_SUPERNODE_LEAVE,
|
|
80
|
+
exit_message="SuperNode terminated gracefully.",
|
|
81
|
+
)
|
|
82
|
+
|
|
89
83
|
start_client_internal(
|
|
90
84
|
server_address=args.superlink,
|
|
91
85
|
load_client_app_fn=load_fn,
|
|
@@ -103,11 +97,6 @@ def run_supernode() -> None:
|
|
|
103
97
|
clientappio_api_address=args.clientappio_api_address,
|
|
104
98
|
)
|
|
105
99
|
|
|
106
|
-
# Graceful shutdown
|
|
107
|
-
register_exit_handlers(
|
|
108
|
-
event_type=EventType.RUN_SUPERNODE_LEAVE,
|
|
109
|
-
)
|
|
110
|
-
|
|
111
100
|
|
|
112
101
|
def run_client_app() -> None:
|
|
113
102
|
"""Run Flower client app."""
|
|
@@ -119,43 +108,11 @@ def run_client_app() -> None:
|
|
|
119
108
|
register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
|
|
120
109
|
|
|
121
110
|
|
|
122
|
-
def _warn_deprecated_server_arg(args: argparse.Namespace) -> None:
|
|
123
|
-
"""Warn about the deprecated argument `--server`."""
|
|
124
|
-
if args.server != FLEET_API_GRPC_RERE_DEFAULT_ADDRESS:
|
|
125
|
-
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
126
|
-
warn_deprecated_feature(warn)
|
|
127
|
-
|
|
128
|
-
if args.superlink != FLEET_API_GRPC_RERE_DEFAULT_ADDRESS:
|
|
129
|
-
# if `--superlink` also passed, then
|
|
130
|
-
# warn user that this argument overrides what was passed with `--server`
|
|
131
|
-
log(
|
|
132
|
-
WARN,
|
|
133
|
-
"Both `--server` and `--superlink` were passed. "
|
|
134
|
-
"`--server` will be ignored. Connecting to the Superlink Fleet API "
|
|
135
|
-
"at %s.",
|
|
136
|
-
args.superlink,
|
|
137
|
-
)
|
|
138
|
-
else:
|
|
139
|
-
args.superlink = args.server
|
|
140
|
-
|
|
141
|
-
|
|
142
111
|
def _parse_args_run_supernode() -> argparse.ArgumentParser:
|
|
143
112
|
"""Parse flower-supernode command line arguments."""
|
|
144
113
|
parser = argparse.ArgumentParser(
|
|
145
114
|
description="Start a Flower SuperNode",
|
|
146
115
|
)
|
|
147
|
-
|
|
148
|
-
parser.add_argument(
|
|
149
|
-
"app",
|
|
150
|
-
nargs="?",
|
|
151
|
-
default=None,
|
|
152
|
-
help=(
|
|
153
|
-
"(REMOVED) This argument is removed. The SuperNode now automatically "
|
|
154
|
-
"uses the ClientApp delivered from the SuperLink, so there is no need to "
|
|
155
|
-
"provide the app directory manually. This argument will be removed in a "
|
|
156
|
-
"future version."
|
|
157
|
-
),
|
|
158
|
-
)
|
|
159
116
|
_parse_args_common(parser)
|
|
160
117
|
parser.add_argument(
|
|
161
118
|
"--flwr-dir",
|
|
@@ -228,15 +185,12 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
228
185
|
help="Specifies the path to the PEM-encoded root certificate file for "
|
|
229
186
|
"establishing secure HTTPS connections.",
|
|
230
187
|
)
|
|
231
|
-
parser.add_argument(
|
|
232
|
-
"--server",
|
|
233
|
-
default=FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
234
|
-
help="Server address",
|
|
235
|
-
)
|
|
236
188
|
parser.add_argument(
|
|
237
189
|
"--superlink",
|
|
238
190
|
default=FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
239
|
-
help="SuperLink Fleet API
|
|
191
|
+
help="SuperLink Fleet API address (IPv4, IPv6, or a domain name). If using the "
|
|
192
|
+
"REST (experimental) transport, ensure your address is in the form "
|
|
193
|
+
"`http://...` or `https://...` when TLS is enabled.",
|
|
240
194
|
)
|
|
241
195
|
parser.add_argument(
|
|
242
196
|
"--max-retries",
|
|
@@ -280,11 +234,7 @@ def _try_setup_client_authentication(
|
|
|
280
234
|
return None
|
|
281
235
|
|
|
282
236
|
if not args.auth_supernode_private_key or not args.auth_supernode_public_key:
|
|
283
|
-
|
|
284
|
-
"Authentication requires file paths to both "
|
|
285
|
-
"'--auth-supernode-private-key' and '--auth-supernode-public-key'"
|
|
286
|
-
"to be provided (providing only one of them is not sufficient)."
|
|
287
|
-
)
|
|
237
|
+
flwr_exit(ExitCode.SUPERNODE_NODE_AUTH_KEYS_REQUIRED)
|
|
288
238
|
|
|
289
239
|
try:
|
|
290
240
|
ssh_private_key = load_ssh_private_key(
|
|
@@ -294,11 +244,9 @@ def _try_setup_client_authentication(
|
|
|
294
244
|
if not isinstance(ssh_private_key, ec.EllipticCurvePrivateKey):
|
|
295
245
|
raise ValueError()
|
|
296
246
|
except (ValueError, UnsupportedAlgorithm):
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
"
|
|
300
|
-
"curve private and public key pair. Please ensure that the file "
|
|
301
|
-
"path points to a valid private key file and try again."
|
|
247
|
+
flwr_exit(
|
|
248
|
+
ExitCode.SUPERNODE_NODE_AUTH_KEYS_INVALID,
|
|
249
|
+
"Unable to parse the private key file.",
|
|
302
250
|
)
|
|
303
251
|
|
|
304
252
|
try:
|
|
@@ -308,11 +256,9 @@ def _try_setup_client_authentication(
|
|
|
308
256
|
if not isinstance(ssh_public_key, ec.EllipticCurvePublicKey):
|
|
309
257
|
raise ValueError()
|
|
310
258
|
except (ValueError, UnsupportedAlgorithm):
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
"
|
|
314
|
-
"curve private and public key pair. Please ensure that the file "
|
|
315
|
-
"path points to a valid public key file and try again."
|
|
259
|
+
flwr_exit(
|
|
260
|
+
ExitCode.SUPERNODE_NODE_AUTH_KEYS_INVALID,
|
|
261
|
+
"Unable to parse the public key file.",
|
|
316
262
|
)
|
|
317
263
|
|
|
318
264
|
return (
|
flwr/common/args.py
CHANGED
|
@@ -20,13 +20,9 @@ import sys
|
|
|
20
20
|
from logging import DEBUG, ERROR, WARN
|
|
21
21
|
from os.path import isfile
|
|
22
22
|
from pathlib import Path
|
|
23
|
-
from typing import Optional
|
|
23
|
+
from typing import Optional, Union
|
|
24
24
|
|
|
25
|
-
from flwr.common.constant import
|
|
26
|
-
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
27
|
-
TRANSPORT_TYPE_GRPC_RERE,
|
|
28
|
-
TRANSPORT_TYPE_REST,
|
|
29
|
-
)
|
|
25
|
+
from flwr.common.constant import TRANSPORT_TYPE_REST
|
|
30
26
|
from flwr.common.logger import log
|
|
31
27
|
|
|
32
28
|
|
|
@@ -55,9 +51,9 @@ def add_args_flwr_app_common(parser: argparse.ArgumentParser) -> None:
|
|
|
55
51
|
def try_obtain_root_certificates(
|
|
56
52
|
args: argparse.Namespace,
|
|
57
53
|
grpc_server_address: str,
|
|
58
|
-
) -> Optional[bytes]:
|
|
54
|
+
) -> Optional[Union[bytes, str]]:
|
|
59
55
|
"""Validate and return the root certificates."""
|
|
60
|
-
root_cert_path = args.root_certificates
|
|
56
|
+
root_cert_path: Optional[str] = args.root_certificates
|
|
61
57
|
if args.insecure:
|
|
62
58
|
if root_cert_path is not None:
|
|
63
59
|
sys.exit(
|
|
@@ -93,56 +89,38 @@ def try_obtain_root_certificates(
|
|
|
93
89
|
grpc_server_address,
|
|
94
90
|
root_cert_path,
|
|
95
91
|
)
|
|
92
|
+
if args.transport == TRANSPORT_TYPE_REST:
|
|
93
|
+
return root_cert_path
|
|
96
94
|
return root_certificates
|
|
97
95
|
|
|
98
96
|
|
|
99
97
|
def try_obtain_server_certificates(
|
|
100
98
|
args: argparse.Namespace,
|
|
101
|
-
transport_type: str,
|
|
102
99
|
) -> Optional[tuple[bytes, bytes, bytes]]:
|
|
103
100
|
"""Validate and return the CA cert, server cert, and server private key."""
|
|
104
101
|
if args.insecure:
|
|
105
102
|
log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
|
|
106
103
|
return None
|
|
107
104
|
# Check if certificates are provided
|
|
108
|
-
if
|
|
109
|
-
if
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
)
|
|
128
|
-
if transport_type == TRANSPORT_TYPE_REST:
|
|
129
|
-
if args.ssl_certfile and args.ssl_keyfile:
|
|
130
|
-
if not isfile(args.ssl_certfile):
|
|
131
|
-
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
|
|
132
|
-
if not isfile(args.ssl_keyfile):
|
|
133
|
-
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
|
|
134
|
-
certificates = (
|
|
135
|
-
b"",
|
|
136
|
-
Path(args.ssl_certfile).read_bytes(), # server certificate
|
|
137
|
-
Path(args.ssl_keyfile).read_bytes(), # server private key
|
|
138
|
-
)
|
|
139
|
-
return certificates
|
|
140
|
-
if args.ssl_certfile or args.ssl_keyfile:
|
|
141
|
-
sys.exit(
|
|
142
|
-
"You need to provide valid file paths to `--ssl-certfile` "
|
|
143
|
-
"and `--ssl-keyfile` to create a secure connection "
|
|
144
|
-
"in Fleet API server (REST, experimental)."
|
|
145
|
-
)
|
|
105
|
+
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
|
|
106
|
+
if not isfile(args.ssl_ca_certfile):
|
|
107
|
+
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
|
|
108
|
+
if not isfile(args.ssl_certfile):
|
|
109
|
+
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
|
|
110
|
+
if not isfile(args.ssl_keyfile):
|
|
111
|
+
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
|
|
112
|
+
certificates = (
|
|
113
|
+
Path(args.ssl_ca_certfile).read_bytes(), # CA certificate
|
|
114
|
+
Path(args.ssl_certfile).read_bytes(), # server certificate
|
|
115
|
+
Path(args.ssl_keyfile).read_bytes(), # server private key
|
|
116
|
+
)
|
|
117
|
+
return certificates
|
|
118
|
+
if args.ssl_certfile or args.ssl_keyfile or args.ssl_ca_certfile:
|
|
119
|
+
sys.exit(
|
|
120
|
+
"You need to provide valid file paths to `--ssl-certfile`, "
|
|
121
|
+
"`--ssl-keyfile`, and `—-ssl-ca-certfile` to create a secure "
|
|
122
|
+
"connection in Fleet API server (gRPC-rere)."
|
|
123
|
+
)
|
|
146
124
|
log(
|
|
147
125
|
ERROR,
|
|
148
126
|
"Certificates are required unless running in insecure mode. "
|
|
@@ -18,26 +18,32 @@
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
19
|
from collections.abc import Sequence
|
|
20
20
|
from pathlib import Path
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import Optional, Union
|
|
22
22
|
|
|
23
23
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
24
24
|
|
|
25
|
+
from ..typing import UserAuthCredentials, UserAuthLoginDetails
|
|
26
|
+
|
|
25
27
|
|
|
26
28
|
class ExecAuthPlugin(ABC):
|
|
27
29
|
"""Abstract Flower Auth Plugin class for ExecServicer.
|
|
28
30
|
|
|
29
31
|
Parameters
|
|
30
32
|
----------
|
|
31
|
-
|
|
32
|
-
|
|
33
|
+
user_auth_config_path : Path
|
|
34
|
+
Path to the YAML file containing the authentication configuration.
|
|
33
35
|
"""
|
|
34
36
|
|
|
35
37
|
@abstractmethod
|
|
36
|
-
def __init__(
|
|
38
|
+
def __init__(
|
|
39
|
+
self,
|
|
40
|
+
user_auth_config_path: Path,
|
|
41
|
+
verify_tls_cert: bool,
|
|
42
|
+
):
|
|
37
43
|
"""Abstract constructor."""
|
|
38
44
|
|
|
39
45
|
@abstractmethod
|
|
40
|
-
def get_login_details(self) ->
|
|
46
|
+
def get_login_details(self) -> Optional[UserAuthLoginDetails]:
|
|
41
47
|
"""Get the login details."""
|
|
42
48
|
|
|
43
49
|
@abstractmethod
|
|
@@ -47,7 +53,7 @@ class ExecAuthPlugin(ABC):
|
|
|
47
53
|
"""Validate authentication tokens in the provided metadata."""
|
|
48
54
|
|
|
49
55
|
@abstractmethod
|
|
50
|
-
def get_auth_tokens(self,
|
|
56
|
+
def get_auth_tokens(self, device_code: str) -> Optional[UserAuthCredentials]:
|
|
51
57
|
"""Get authentication tokens."""
|
|
52
58
|
|
|
53
59
|
@abstractmethod
|
|
@@ -62,50 +68,55 @@ class CliAuthPlugin(ABC):
|
|
|
62
68
|
|
|
63
69
|
Parameters
|
|
64
70
|
----------
|
|
65
|
-
|
|
66
|
-
|
|
71
|
+
credentials_path : Path
|
|
72
|
+
Path to the user's authentication credentials file.
|
|
67
73
|
"""
|
|
68
74
|
|
|
69
75
|
@staticmethod
|
|
70
76
|
@abstractmethod
|
|
71
77
|
def login(
|
|
72
|
-
login_details:
|
|
78
|
+
login_details: UserAuthLoginDetails,
|
|
73
79
|
exec_stub: ExecStub,
|
|
74
|
-
) ->
|
|
75
|
-
"""Authenticate the user
|
|
80
|
+
) -> UserAuthCredentials:
|
|
81
|
+
"""Authenticate the user and retrieve authentication credentials.
|
|
76
82
|
|
|
77
83
|
Parameters
|
|
78
84
|
----------
|
|
79
|
-
login_details :
|
|
80
|
-
|
|
85
|
+
login_details : UserAuthLoginDetails
|
|
86
|
+
An object containing the user's login details.
|
|
81
87
|
exec_stub : ExecStub
|
|
82
|
-
|
|
88
|
+
A stub for executing RPC calls to the server.
|
|
83
89
|
|
|
84
90
|
Returns
|
|
85
91
|
-------
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
in JSON format.
|
|
92
|
+
UserAuthCredentials
|
|
93
|
+
The authentication credentials obtained after login.
|
|
89
94
|
"""
|
|
90
95
|
|
|
91
96
|
@abstractmethod
|
|
92
|
-
def __init__(self,
|
|
97
|
+
def __init__(self, credentials_path: Path):
|
|
93
98
|
"""Abstract constructor."""
|
|
94
99
|
|
|
95
100
|
@abstractmethod
|
|
96
|
-
def store_tokens(self,
|
|
97
|
-
"""Store authentication tokens
|
|
101
|
+
def store_tokens(self, credentials: UserAuthCredentials) -> None:
|
|
102
|
+
"""Store authentication tokens to the `credentials_path`.
|
|
98
103
|
|
|
99
|
-
The
|
|
100
|
-
at `
|
|
104
|
+
The credentials, including tokens, will be saved as a JSON file
|
|
105
|
+
at `credentials_path`.
|
|
101
106
|
"""
|
|
102
107
|
|
|
103
108
|
@abstractmethod
|
|
104
109
|
def load_tokens(self) -> None:
|
|
105
|
-
"""Load authentication tokens from the
|
|
110
|
+
"""Load authentication tokens from the `credentials_path`."""
|
|
106
111
|
|
|
107
112
|
@abstractmethod
|
|
108
113
|
def write_tokens_to_metadata(
|
|
109
114
|
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
|
|
110
115
|
) -> Sequence[tuple[str, Union[str, bytes]]]:
|
|
111
116
|
"""Write authentication tokens to the provided metadata."""
|
|
117
|
+
|
|
118
|
+
@abstractmethod
|
|
119
|
+
def read_tokens_from_metadata(
|
|
120
|
+
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
|
|
121
|
+
) -> Optional[UserAuthCredentials]:
|
|
122
|
+
"""Read authentication tokens from the provided metadata."""
|