flwr-nightly 1.15.0.dev20250104__py3-none-any.whl → 1.15.0.dev20250123__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/cli_user_auth_interceptor.py +6 -2
- flwr/cli/config_utils.py +23 -146
- flwr/cli/constant.py +27 -0
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +17 -2
- flwr/cli/login/login.py +20 -5
- flwr/cli/ls.py +10 -2
- flwr/cli/run/run.py +20 -10
- flwr/cli/stop.py +9 -1
- flwr/cli/utils.py +4 -4
- flwr/client/app.py +36 -48
- flwr/client/clientapp/app.py +4 -6
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/grpc_client/connection.py +0 -6
- flwr/client/grpc_rere_client/client_interceptor.py +19 -119
- flwr/client/grpc_rere_client/connection.py +34 -24
- flwr/client/grpc_rere_client/grpc_adapter.py +16 -0
- flwr/client/rest_client/connection.py +34 -26
- flwr/client/supernode/app.py +14 -20
- flwr/common/auth_plugin/auth_plugin.py +34 -23
- flwr/common/config.py +152 -15
- flwr/common/constant.py +11 -8
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +24 -10
- flwr/common/grpc.py +161 -3
- flwr/common/logger.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/serde.py +6 -4
- flwr/common/typing.py +20 -0
- flwr/proto/clientappio_pb2.py +13 -3
- flwr/proto/clientappio_pb2_grpc.py +63 -12
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/exec_pb2.py +27 -29
- flwr/proto/exec_pb2.pyi +27 -54
- flwr/proto/exec_pb2_grpc.py +105 -24
- flwr/proto/fab_pb2.py +13 -3
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fleet_pb2.py +54 -31
- flwr/proto/fleet_pb2.pyi +84 -0
- flwr/proto/fleet_pb2_grpc.py +207 -28
- flwr/proto/fleet_pb2_grpc.pyi +26 -0
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/recordset_pb2.py +18 -8
- flwr/proto/recordset_pb2_grpc.py +20 -0
- flwr/proto/run_pb2.py +16 -6
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/serverappio_pb2.py +32 -14
- flwr/proto/serverappio_pb2.pyi +56 -0
- flwr/proto/serverappio_pb2_grpc.py +261 -44
- flwr/proto/serverappio_pb2_grpc.pyi +20 -0
- flwr/proto/simulationio_pb2.py +13 -3
- flwr/proto/simulationio_pb2_grpc.py +105 -24
- flwr/proto/task_pb2.py +13 -3
- flwr/proto/task_pb2_grpc.py +20 -0
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/server/app.py +87 -38
- flwr/server/compat/app_utils.py +0 -1
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +5 -2
- flwr/server/driver/inmemory_driver.py +2 -1
- flwr/server/serverapp/app.py +5 -6
- flwr/server/superlink/driver/serverappio_grpc.py +1 -1
- flwr/server/superlink/driver/serverappio_servicer.py +132 -14
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +38 -0
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +95 -168
- flwr/server/superlink/fleet/message_handler/message_handler.py +66 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -3
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +40 -48
- flwr/server/superlink/linkstate/linkstate.py +15 -22
- flwr/server/superlink/linkstate/sqlite_linkstate.py +80 -99
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +4 -6
- flwr/simulation/legacy_app.py +4 -2
- flwr/simulation/run_simulation.py +1 -1
- flwr/superexec/exec_grpc.py +1 -1
- flwr/superexec/exec_servicer.py +23 -2
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/METADATA +7 -7
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/RECORD +98 -94
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/entry_points.txt +0 -0
flwr/client/app.py
CHANGED
@@ -15,13 +15,14 @@
|
|
15
15
|
"""Flower client app."""
|
16
16
|
|
17
17
|
|
18
|
-
import
|
19
|
-
import
|
18
|
+
import multiprocessing
|
19
|
+
import os
|
20
20
|
import sys
|
21
|
+
import threading
|
21
22
|
import time
|
22
23
|
from contextlib import AbstractContextManager
|
23
|
-
from dataclasses import dataclass
|
24
24
|
from logging import ERROR, INFO, WARN
|
25
|
+
from os import urandom
|
25
26
|
from pathlib import Path
|
26
27
|
from typing import Callable, Optional, Union, cast
|
27
28
|
|
@@ -33,6 +34,7 @@ from flwr.cli.config_utils import get_fab_metadata
|
|
33
34
|
from flwr.cli.install import install_from_fab
|
34
35
|
from flwr.client.client import Client
|
35
36
|
from flwr.client.client_app import ClientApp, LoadClientAppError
|
37
|
+
from flwr.client.clientapp.app import flwr_clientapp
|
36
38
|
from flwr.client.nodestate.nodestate_factory import NodeStateFactory
|
37
39
|
from flwr.client.typing import ClientFnExt
|
38
40
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
|
@@ -43,7 +45,6 @@ from flwr.common.constant import (
|
|
43
45
|
ISOLATION_MODE_PROCESS,
|
44
46
|
ISOLATION_MODE_SUBPROCESS,
|
45
47
|
MAX_RETRY_DELAY,
|
46
|
-
MISSING_EXTRA_REST,
|
47
48
|
RUN_ID_NUM_BYTES,
|
48
49
|
SERVER_OCTET,
|
49
50
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
@@ -53,13 +54,13 @@ from flwr.common.constant import (
|
|
53
54
|
TRANSPORT_TYPES,
|
54
55
|
ErrorCode,
|
55
56
|
)
|
57
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
58
|
+
from flwr.common.grpc import generic_create_grpc_server
|
56
59
|
from flwr.common.logger import log, warn_deprecated_feature
|
57
60
|
from flwr.common.message import Error
|
58
61
|
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
59
62
|
from flwr.common.typing import Fab, Run, RunNotRunningException, UserConfig
|
60
63
|
from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
|
61
|
-
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
62
|
-
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
|
63
64
|
|
64
65
|
from .clientapp.clientappio_servicer import ClientAppInputs, ClientAppIoServicer
|
65
66
|
from .grpc_adapter_client.connection import grpc_adapter
|
@@ -345,10 +346,7 @@ def start_client_internal(
|
|
345
346
|
transport, server_address
|
346
347
|
)
|
347
348
|
|
348
|
-
app_state_tracker = _AppStateTracker()
|
349
|
-
|
350
349
|
def _on_sucess(retry_state: RetryState) -> None:
|
351
|
-
app_state_tracker.is_connected = True
|
352
350
|
if retry_state.tries > 1:
|
353
351
|
log(
|
354
352
|
INFO,
|
@@ -358,7 +356,6 @@ def start_client_internal(
|
|
358
356
|
)
|
359
357
|
|
360
358
|
def _on_backoff(retry_state: RetryState) -> None:
|
361
|
-
app_state_tracker.is_connected = False
|
362
359
|
if retry_state.tries == 1:
|
363
360
|
log(WARN, "Connection attempt failed, retrying...")
|
364
361
|
else:
|
@@ -391,10 +388,11 @@ def start_client_internal(
|
|
391
388
|
run_info_store: Optional[DeprecatedRunInfoStore] = None
|
392
389
|
state_factory = NodeStateFactory()
|
393
390
|
state = state_factory.state()
|
391
|
+
mp_spawn_context = multiprocessing.get_context("spawn")
|
394
392
|
|
395
393
|
runs: dict[int, Run] = {}
|
396
394
|
|
397
|
-
while
|
395
|
+
while True:
|
398
396
|
sleep_duration: int = 0
|
399
397
|
with connection(
|
400
398
|
address,
|
@@ -433,9 +431,8 @@ def start_client_internal(
|
|
433
431
|
node_config=node_config,
|
434
432
|
)
|
435
433
|
|
436
|
-
app_state_tracker.register_signal_handler()
|
437
434
|
# pylint: disable=too-many-nested-blocks
|
438
|
-
while
|
435
|
+
while True:
|
439
436
|
try:
|
440
437
|
# Receive
|
441
438
|
message = receive()
|
@@ -513,7 +510,7 @@ def start_client_internal(
|
|
513
510
|
# Docker container.
|
514
511
|
|
515
512
|
# Generate SuperNode token
|
516
|
-
token
|
513
|
+
token = int.from_bytes(urandom(RUN_ID_NUM_BYTES), "little")
|
517
514
|
|
518
515
|
# Mode 1: SuperNode starts ClientApp as subprocess
|
519
516
|
start_subprocess = isolation == ISOLATION_MODE_SUBPROCESS
|
@@ -549,12 +546,13 @@ def start_client_internal(
|
|
549
546
|
]
|
550
547
|
command.append("--insecure")
|
551
548
|
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
check=True,
|
549
|
+
proc = mp_spawn_context.Process(
|
550
|
+
target=_run_flwr_clientapp,
|
551
|
+
args=(command, os.getpid()),
|
552
|
+
daemon=True,
|
557
553
|
)
|
554
|
+
proc.start()
|
555
|
+
proc.join()
|
558
556
|
else:
|
559
557
|
# Wait for output to become available
|
560
558
|
while not clientappio_servicer.has_outputs():
|
@@ -592,10 +590,7 @@ def start_client_internal(
|
|
592
590
|
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
|
593
591
|
exc_entity = "SuperNode"
|
594
592
|
|
595
|
-
|
596
|
-
log(
|
597
|
-
ERROR, "%s raised an exception", exc_entity, exc_info=ex
|
598
|
-
)
|
593
|
+
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
599
594
|
|
600
595
|
# Create error message
|
601
596
|
reply_message = message.create_error_reply(
|
@@ -621,19 +616,14 @@ def start_client_internal(
|
|
621
616
|
run_id,
|
622
617
|
)
|
623
618
|
log(INFO, "")
|
624
|
-
|
625
|
-
except StopIteration:
|
626
|
-
sleep_duration = 0
|
627
|
-
break
|
628
619
|
# pylint: enable=too-many-nested-blocks
|
629
620
|
|
630
621
|
# Unregister node
|
631
|
-
if delete_node is not None
|
622
|
+
if delete_node is not None:
|
632
623
|
delete_node() # pylint: disable=not-callable
|
633
624
|
|
634
625
|
if sleep_duration == 0:
|
635
626
|
log(INFO, "Disconnect and shut down")
|
636
|
-
del app_state_tracker
|
637
627
|
break
|
638
628
|
|
639
629
|
# Sleep and reconnect afterwards
|
@@ -773,7 +763,10 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
|
|
773
763
|
# Parse IP address
|
774
764
|
parsed_address = parse_address(server_address)
|
775
765
|
if not parsed_address:
|
776
|
-
|
766
|
+
flwr_exit(
|
767
|
+
ExitCode.COMMON_ADDRESS_INVALID,
|
768
|
+
f"SuperLink address ({server_address}) cannot be parsed.",
|
769
|
+
)
|
777
770
|
host, port, is_v6 = parsed_address
|
778
771
|
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
779
772
|
|
@@ -788,12 +781,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
|
|
788
781
|
|
789
782
|
from .rest_client.connection import http_request_response
|
790
783
|
except ModuleNotFoundError:
|
791
|
-
|
784
|
+
flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
|
792
785
|
if server_address[:4] != "http":
|
793
|
-
|
794
|
-
"When using the REST API, please provide `https://` or "
|
795
|
-
"`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
|
796
|
-
)
|
786
|
+
flwr_exit(ExitCode.SUPERNODE_REST_ADDRESS_INVALID)
|
797
787
|
connection, error_type = http_request_response, RequestsConnectionError
|
798
788
|
elif transport == TRANSPORT_TYPE_GRPC_RERE:
|
799
789
|
connection, error_type = grpc_request_response, RpcError
|
@@ -809,21 +799,19 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
|
|
809
799
|
return connection, address, error_type
|
810
800
|
|
811
801
|
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
802
|
+
def _run_flwr_clientapp(args: list[str], main_pid: int) -> None:
|
803
|
+
# Monitor the main process in case of SIGKILL
|
804
|
+
def main_process_monitor() -> None:
|
805
|
+
while True:
|
806
|
+
time.sleep(1)
|
807
|
+
if os.getppid() != main_pid:
|
808
|
+
os.kill(os.getpid(), 9)
|
819
809
|
|
820
|
-
|
821
|
-
# pylint: disable=unused-argument
|
822
|
-
self.interrupt = True
|
823
|
-
raise StopIteration from None
|
810
|
+
threading.Thread(target=main_process_monitor, daemon=True).start()
|
824
811
|
|
825
|
-
|
826
|
-
|
812
|
+
# Run the command
|
813
|
+
sys.argv = args
|
814
|
+
flwr_clientapp()
|
827
815
|
|
828
816
|
|
829
817
|
def run_clientappio_api_grpc(
|
flwr/client/clientapp/app.py
CHANGED
@@ -16,7 +16,6 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import argparse
|
19
|
-
import sys
|
20
19
|
import time
|
21
20
|
from logging import DEBUG, ERROR, INFO
|
22
21
|
from typing import Optional
|
@@ -29,6 +28,7 @@ from flwr.common import Context, Message
|
|
29
28
|
from flwr.common.args import add_args_flwr_app_common
|
30
29
|
from flwr.common.config import get_flwr_dir
|
31
30
|
from flwr.common.constant import CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS, ErrorCode
|
31
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
32
32
|
from flwr.common.grpc import create_channel
|
33
33
|
from flwr.common.logger import log
|
34
34
|
from flwr.common.message import Error
|
@@ -61,12 +61,10 @@ def flwr_clientapp() -> None:
|
|
61
61
|
"""Run process-isolated Flower ClientApp."""
|
62
62
|
args = _parse_args_run_flwr_clientapp().parse_args()
|
63
63
|
if not args.insecure:
|
64
|
-
|
65
|
-
|
66
|
-
"flwr-clientapp does not support TLS yet.
|
67
|
-
"Please use the '--insecure' flag.",
|
64
|
+
flwr_exit(
|
65
|
+
ExitCode.COMMON_TLS_NOT_SUPPORTED,
|
66
|
+
"flwr-clientapp does not support TLS yet.",
|
68
67
|
)
|
69
|
-
sys.exit(1)
|
70
68
|
|
71
69
|
log(INFO, "Starting Flower ClientApp")
|
72
70
|
log(
|
flwr/client/clientapp/utils.py
CHANGED
@@ -66,7 +66,7 @@ def get_load_client_app_fn(
|
|
66
66
|
# `fab_hash` is not required since the app is loaded from `runtime_app_dir`.
|
67
67
|
elif app_path is not None:
|
68
68
|
config = get_project_config(runtime_app_dir)
|
69
|
-
|
69
|
+
this_fab_id, this_fab_version = get_metadata_from_config(config)
|
70
70
|
|
71
71
|
if this_fab_version != fab_version or this_fab_id != fab_id:
|
72
72
|
raise LoadClientAppError(
|
@@ -47,12 +47,6 @@ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
47
47
|
)
|
48
48
|
from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
|
49
49
|
|
50
|
-
# The following flags can be uncommented for debugging. Other possible values:
|
51
|
-
# https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
|
52
|
-
# import os
|
53
|
-
# os.environ["GRPC_VERBOSITY"] = "debug"
|
54
|
-
# os.environ["GRPC_TRACE"] = "tcp,http"
|
55
|
-
|
56
50
|
|
57
51
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
58
52
|
"""Log channel connectivity."""
|
@@ -15,67 +15,18 @@
|
|
15
15
|
"""Flower client interceptor."""
|
16
16
|
|
17
17
|
|
18
|
-
import
|
19
|
-
import collections
|
20
|
-
from collections.abc import Sequence
|
21
|
-
from logging import WARNING
|
22
|
-
from typing import Any, Callable, Optional, Union
|
18
|
+
from typing import Any, Callable
|
23
19
|
|
24
20
|
import grpc
|
25
21
|
from cryptography.hazmat.primitives.asymmetric import ec
|
22
|
+
from google.protobuf.message import Message as GrpcMessage
|
26
23
|
|
27
|
-
from flwr.common
|
24
|
+
from flwr.common import now
|
25
|
+
from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
|
28
26
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
29
|
-
bytes_to_public_key,
|
30
|
-
compute_hmac,
|
31
|
-
generate_shared_key,
|
32
27
|
public_key_to_bytes,
|
28
|
+
sign_message,
|
33
29
|
)
|
34
|
-
from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
|
35
|
-
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
36
|
-
CreateNodeRequest,
|
37
|
-
DeleteNodeRequest,
|
38
|
-
PingRequest,
|
39
|
-
PullTaskInsRequest,
|
40
|
-
PushTaskResRequest,
|
41
|
-
)
|
42
|
-
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
43
|
-
|
44
|
-
_PUBLIC_KEY_HEADER = "public-key"
|
45
|
-
_AUTH_TOKEN_HEADER = "auth-token"
|
46
|
-
|
47
|
-
Request = Union[
|
48
|
-
CreateNodeRequest,
|
49
|
-
DeleteNodeRequest,
|
50
|
-
PullTaskInsRequest,
|
51
|
-
PushTaskResRequest,
|
52
|
-
GetRunRequest,
|
53
|
-
PingRequest,
|
54
|
-
GetFabRequest,
|
55
|
-
]
|
56
|
-
|
57
|
-
|
58
|
-
def _get_value_from_tuples(
|
59
|
-
key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
|
60
|
-
) -> bytes:
|
61
|
-
value = next((value for key, value in tuples if key == key_string), "")
|
62
|
-
if isinstance(value, str):
|
63
|
-
return value.encode()
|
64
|
-
|
65
|
-
return value
|
66
|
-
|
67
|
-
|
68
|
-
class _ClientCallDetails(
|
69
|
-
collections.namedtuple(
|
70
|
-
"_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
|
71
|
-
),
|
72
|
-
grpc.ClientCallDetails, # type: ignore
|
73
|
-
):
|
74
|
-
"""Details for each client call.
|
75
|
-
|
76
|
-
The class will be passed on as the first argument in continuation function.
|
77
|
-
In our case, `AuthenticateClientInterceptor` adds new metadata to the construct.
|
78
|
-
"""
|
79
30
|
|
80
31
|
|
81
32
|
class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
|
@@ -87,84 +38,33 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
|
|
87
38
|
public_key: ec.EllipticCurvePublicKey,
|
88
39
|
):
|
89
40
|
self.private_key = private_key
|
90
|
-
self.
|
91
|
-
self.shared_secret: Optional[bytes] = None
|
92
|
-
self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
|
93
|
-
self.encoded_public_key = base64.urlsafe_b64encode(
|
94
|
-
public_key_to_bytes(self.public_key)
|
95
|
-
)
|
41
|
+
self.public_key_bytes = public_key_to_bytes(public_key)
|
96
42
|
|
97
43
|
def intercept_unary_unary(
|
98
44
|
self,
|
99
45
|
continuation: Callable[[Any, Any], Any],
|
100
46
|
client_call_details: grpc.ClientCallDetails,
|
101
|
-
request:
|
47
|
+
request: GrpcMessage,
|
102
48
|
) -> grpc.Call:
|
103
49
|
"""Flower client interceptor.
|
104
50
|
|
105
51
|
Intercept unary call from client and add necessary authentication header in the
|
106
52
|
RPC metadata.
|
107
53
|
"""
|
108
|
-
metadata = []
|
109
|
-
postprocess = False
|
110
|
-
if client_call_details.metadata is not None:
|
111
|
-
metadata = list(client_call_details.metadata)
|
112
|
-
|
113
|
-
# Always add the public key header
|
114
|
-
metadata.append(
|
115
|
-
(
|
116
|
-
_PUBLIC_KEY_HEADER,
|
117
|
-
self.encoded_public_key,
|
118
|
-
)
|
119
|
-
)
|
120
|
-
|
121
|
-
if isinstance(request, CreateNodeRequest):
|
122
|
-
postprocess = True
|
123
|
-
elif isinstance(
|
124
|
-
request,
|
125
|
-
(
|
126
|
-
DeleteNodeRequest,
|
127
|
-
PullTaskInsRequest,
|
128
|
-
PushTaskResRequest,
|
129
|
-
GetRunRequest,
|
130
|
-
PingRequest,
|
131
|
-
GetFabRequest,
|
132
|
-
),
|
133
|
-
):
|
134
|
-
if self.shared_secret is None:
|
135
|
-
raise RuntimeError("Failure to compute hmac")
|
136
|
-
|
137
|
-
message_bytes = request.SerializeToString(deterministic=True)
|
138
|
-
metadata.append(
|
139
|
-
(
|
140
|
-
_AUTH_TOKEN_HEADER,
|
141
|
-
base64.urlsafe_b64encode(
|
142
|
-
compute_hmac(self.shared_secret, message_bytes)
|
143
|
-
),
|
144
|
-
)
|
145
|
-
)
|
54
|
+
metadata = list(client_call_details.metadata or [])
|
146
55
|
|
147
|
-
|
148
|
-
|
149
|
-
client_call_details.timeout,
|
150
|
-
metadata,
|
151
|
-
client_call_details.credentials,
|
152
|
-
)
|
56
|
+
# Add the public key
|
57
|
+
metadata.append((PUBLIC_KEY_HEADER, self.public_key_bytes))
|
153
58
|
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
_get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
|
158
|
-
)
|
59
|
+
# Add timestamp
|
60
|
+
timestamp = now().isoformat()
|
61
|
+
metadata.append((TIMESTAMP_HEADER, timestamp))
|
159
62
|
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
log(WARNING, "Can't get server public key, SuperLink may be offline")
|
63
|
+
# Sign and add the signature
|
64
|
+
signature = sign_message(self.private_key, timestamp.encode("ascii"))
|
65
|
+
metadata.append((SIGNATURE_HEADER, signature))
|
164
66
|
|
165
|
-
|
166
|
-
|
167
|
-
self.private_key, self.server_public_key
|
168
|
-
)
|
67
|
+
# Overwrite the metadata
|
68
|
+
details = client_call_details._replace(metadata=metadata)
|
169
69
|
|
170
|
-
return
|
70
|
+
return continuation(details, request)
|
@@ -29,7 +29,6 @@ from cryptography.hazmat.primitives.asymmetric import ec
|
|
29
29
|
|
30
30
|
from flwr.client.heartbeat import start_ping_loop
|
31
31
|
from flwr.client.message_handler.message_handler import validate_out_message
|
32
|
-
from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
|
33
32
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
34
33
|
from flwr.common.constant import (
|
35
34
|
PING_BASE_MULTIPLIER,
|
@@ -41,7 +40,7 @@ from flwr.common.grpc import create_channel
|
|
41
40
|
from flwr.common.logger import log
|
42
41
|
from flwr.common.message import Message, Metadata
|
43
42
|
from flwr.common.retry_invoker import RetryInvoker
|
44
|
-
from flwr.common.serde import
|
43
|
+
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
45
44
|
from flwr.common.typing import Fab, Run, RunNotRunningException
|
46
45
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
47
46
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
@@ -49,13 +48,13 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
49
48
|
DeleteNodeRequest,
|
50
49
|
PingRequest,
|
51
50
|
PingResponse,
|
52
|
-
|
53
|
-
|
51
|
+
PullMessagesRequest,
|
52
|
+
PullMessagesResponse,
|
53
|
+
PushMessagesRequest,
|
54
54
|
)
|
55
55
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
56
56
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
57
57
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
58
|
-
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
59
58
|
|
60
59
|
from .client_interceptor import AuthenticateClientInterceptor
|
61
60
|
from .grpc_adapter import GrpcAdapter
|
@@ -227,28 +226,31 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
227
226
|
node = None
|
228
227
|
|
229
228
|
def receive() -> Optional[Message]:
|
230
|
-
"""Receive next
|
229
|
+
"""Receive next message from server."""
|
231
230
|
# Get Node
|
232
231
|
if node is None:
|
233
232
|
log(ERROR, "Node instance missing")
|
234
233
|
return None
|
235
234
|
|
236
|
-
# Request instructions (
|
237
|
-
request =
|
238
|
-
response = retry_invoker.invoke(
|
235
|
+
# Request instructions (message) from server
|
236
|
+
request = PullMessagesRequest(node=node)
|
237
|
+
response: PullMessagesResponse = retry_invoker.invoke(
|
238
|
+
stub.PullMessages, request=request
|
239
|
+
)
|
239
240
|
|
240
|
-
# Get the current
|
241
|
-
|
241
|
+
# Get the current Messages
|
242
|
+
message_proto = (
|
243
|
+
None if len(response.messages_list) == 0 else response.messages_list[0]
|
244
|
+
)
|
242
245
|
|
243
|
-
# Discard the current
|
244
|
-
if
|
245
|
-
|
246
|
-
and validate_task_ins(task_ins)
|
246
|
+
# Discard the current message if not valid
|
247
|
+
if message_proto is not None and not (
|
248
|
+
message_proto.metadata.dst_node_id == node.node_id
|
247
249
|
):
|
248
|
-
|
250
|
+
message_proto = None
|
249
251
|
|
250
252
|
# Construct the Message
|
251
|
-
in_message =
|
253
|
+
in_message = message_from_proto(message_proto) if message_proto else None
|
252
254
|
|
253
255
|
# Remember `metadata` of the in message
|
254
256
|
nonlocal metadata
|
@@ -258,7 +260,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
258
260
|
return in_message
|
259
261
|
|
260
262
|
def send(message: Message) -> None:
|
261
|
-
"""Send
|
263
|
+
"""Send message reply to server."""
|
262
264
|
# Get Node
|
263
265
|
if node is None:
|
264
266
|
log(ERROR, "Node instance missing")
|
@@ -275,12 +277,10 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
275
277
|
log(ERROR, "Invalid out message")
|
276
278
|
return
|
277
279
|
|
278
|
-
#
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
request = PushTaskResRequest(node=node, task_res_list=[task_res])
|
283
|
-
_ = retry_invoker.invoke(stub.PushTaskRes, request)
|
280
|
+
# Serialize Message
|
281
|
+
message_proto = message_to_proto(message=message)
|
282
|
+
request = PushMessagesRequest(node=node, messages_list=[message_proto])
|
283
|
+
_ = retry_invoker.invoke(stub.PushMessages, request)
|
284
284
|
|
285
285
|
# Cleanup
|
286
286
|
metadata = None
|
@@ -311,3 +311,13 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
311
311
|
yield (receive, send, create_node, delete_node, get_run, get_fab)
|
312
312
|
except Exception as exc: # pylint: disable=broad-except
|
313
313
|
log(ERROR, exc)
|
314
|
+
# Cleanup
|
315
|
+
finally:
|
316
|
+
try:
|
317
|
+
if node is not None:
|
318
|
+
# Disable retrying
|
319
|
+
retry_invoker.max_tries = 1
|
320
|
+
delete_node()
|
321
|
+
except grpc.RpcError:
|
322
|
+
pass
|
323
|
+
channel.close()
|
@@ -40,8 +40,12 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
40
40
|
DeleteNodeResponse,
|
41
41
|
PingRequest,
|
42
42
|
PingResponse,
|
43
|
+
PullMessagesRequest,
|
44
|
+
PullMessagesResponse,
|
43
45
|
PullTaskInsRequest,
|
44
46
|
PullTaskInsResponse,
|
47
|
+
PushMessagesRequest,
|
48
|
+
PushMessagesResponse,
|
45
49
|
PushTaskResRequest,
|
46
50
|
PushTaskResResponse,
|
47
51
|
)
|
@@ -132,12 +136,24 @@ class GrpcAdapter:
|
|
132
136
|
"""."""
|
133
137
|
return self._send_and_receive(request, PullTaskInsResponse, **kwargs)
|
134
138
|
|
139
|
+
def PullMessages( # pylint: disable=C0103
|
140
|
+
self, request: PullMessagesRequest, **kwargs: Any
|
141
|
+
) -> PullMessagesResponse:
|
142
|
+
"""."""
|
143
|
+
return self._send_and_receive(request, PullMessagesResponse, **kwargs)
|
144
|
+
|
135
145
|
def PushTaskRes( # pylint: disable=C0103
|
136
146
|
self, request: PushTaskResRequest, **kwargs: Any
|
137
147
|
) -> PushTaskResResponse:
|
138
148
|
"""."""
|
139
149
|
return self._send_and_receive(request, PushTaskResResponse, **kwargs)
|
140
150
|
|
151
|
+
def PushMessages( # pylint: disable=C0103
|
152
|
+
self, request: PushMessagesRequest, **kwargs: Any
|
153
|
+
) -> PushMessagesResponse:
|
154
|
+
"""."""
|
155
|
+
return self._send_and_receive(request, PushMessagesResponse, **kwargs)
|
156
|
+
|
141
157
|
def GetRun( # pylint: disable=C0103
|
142
158
|
self, request: GetRunRequest, **kwargs: Any
|
143
159
|
) -> GetRunResponse:
|