flwr 1.14.0__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/auth_plugin/__init__.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
- flwr/cli/cli_user_auth_interceptor.py +6 -2
- flwr/cli/config_utils.py +24 -147
- flwr/cli/constant.py +27 -0
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +18 -3
- flwr/cli/login/login.py +43 -8
- flwr/cli/ls.py +14 -5
- flwr/cli/new/templates/app/README.md.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/run/run.py +21 -11
- flwr/cli/stop.py +13 -4
- flwr/cli/utils.py +54 -40
- flwr/client/app.py +36 -48
- flwr/client/clientapp/app.py +19 -25
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/grpc_client/connection.py +1 -12
- flwr/client/grpc_rere_client/client_interceptor.py +19 -119
- flwr/client/grpc_rere_client/connection.py +46 -36
- flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
- flwr/client/message_handler/task_handler.py +0 -17
- flwr/client/rest_client/connection.py +34 -26
- flwr/client/supernode/app.py +18 -72
- flwr/common/args.py +25 -47
- flwr/common/auth_plugin/auth_plugin.py +34 -23
- flwr/common/config.py +166 -16
- flwr/common/constant.py +22 -9
- flwr/common/differential_privacy.py +2 -1
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +24 -10
- flwr/common/grpc.py +167 -4
- flwr/common/logger.py +26 -7
- flwr/common/record/recordset.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/serde.py +6 -4
- flwr/common/typing.py +20 -0
- flwr/proto/clientappio_pb2.py +1 -1
- flwr/proto/error_pb2.py +1 -1
- flwr/proto/exec_pb2.py +13 -25
- flwr/proto/exec_pb2.pyi +27 -54
- flwr/proto/fab_pb2.py +1 -1
- flwr/proto/fleet_pb2.py +31 -31
- flwr/proto/fleet_pb2.pyi +23 -23
- flwr/proto/fleet_pb2_grpc.py +30 -30
- flwr/proto/fleet_pb2_grpc.pyi +20 -20
- flwr/proto/grpcadapter_pb2.py +1 -1
- flwr/proto/log_pb2.py +1 -1
- flwr/proto/message_pb2.py +1 -1
- flwr/proto/node_pb2.py +3 -3
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/recordset_pb2.py +1 -1
- flwr/proto/run_pb2.py +1 -1
- flwr/proto/serverappio_pb2.py +24 -25
- flwr/proto/serverappio_pb2.pyi +26 -32
- flwr/proto/serverappio_pb2_grpc.py +28 -28
- flwr/proto/serverappio_pb2_grpc.pyi +16 -16
- flwr/proto/simulationio_pb2.py +1 -1
- flwr/proto/task_pb2.py +1 -1
- flwr/proto/transport_pb2.py +1 -1
- flwr/server/app.py +116 -128
- flwr/server/compat/app_utils.py +0 -1
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +32 -27
- flwr/server/driver/inmemory_driver.py +2 -1
- flwr/server/serverapp/app.py +12 -10
- flwr/server/superlink/driver/serverappio_grpc.py +1 -1
- flwr/server/superlink/driver/serverappio_servicer.py +74 -48
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -24
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +97 -168
- flwr/server/superlink/fleet/message_handler/message_handler.py +37 -24
- flwr/server/superlink/fleet/rest_rere/rest_api.py +16 -18
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +45 -75
- flwr/server/superlink/linkstate/linkstate.py +17 -38
- flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -145
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +4 -6
- flwr/simulation/legacy_app.py +4 -2
- flwr/simulation/run_simulation.py +1 -1
- flwr/simulation/simulationio_connection.py +2 -1
- flwr/superexec/exec_grpc.py +1 -1
- flwr/superexec/exec_servicer.py +23 -2
- {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/METADATA +8 -8
- {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/RECORD +102 -96
- {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/LICENSE +0 -0
- {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/WHEEL +0 -0
- {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/entry_points.txt +0 -0
flwr/client/clientapp/app.py
CHANGED
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
|
-
import sys
|
|
20
19
|
import time
|
|
21
20
|
from logging import DEBUG, ERROR, INFO
|
|
22
21
|
from typing import Optional
|
|
@@ -29,7 +28,8 @@ from flwr.common import Context, Message
|
|
|
29
28
|
from flwr.common.args import add_args_flwr_app_common
|
|
30
29
|
from flwr.common.config import get_flwr_dir
|
|
31
30
|
from flwr.common.constant import CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS, ErrorCode
|
|
32
|
-
from flwr.common.
|
|
31
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
|
32
|
+
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
33
33
|
from flwr.common.logger import log
|
|
34
34
|
from flwr.common.message import Error
|
|
35
35
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
@@ -61,18 +61,16 @@ def flwr_clientapp() -> None:
|
|
|
61
61
|
"""Run process-isolated Flower ClientApp."""
|
|
62
62
|
args = _parse_args_run_flwr_clientapp().parse_args()
|
|
63
63
|
if not args.insecure:
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
"flwr-clientapp does not support TLS yet.
|
|
67
|
-
"Please use the '--insecure' flag.",
|
|
64
|
+
flwr_exit(
|
|
65
|
+
ExitCode.COMMON_TLS_NOT_SUPPORTED,
|
|
66
|
+
"flwr-clientapp does not support TLS yet.",
|
|
68
67
|
)
|
|
69
|
-
sys.exit(1)
|
|
70
68
|
|
|
71
|
-
log(INFO, "
|
|
69
|
+
log(INFO, "Start `flwr-clientapp` process")
|
|
72
70
|
log(
|
|
73
71
|
DEBUG,
|
|
74
|
-
"
|
|
75
|
-
"with token %s",
|
|
72
|
+
"`flwr-clientapp` will attempt to connect to SuperNode's "
|
|
73
|
+
"ClientAppIo API at %s with token %s",
|
|
76
74
|
args.clientappio_api_address,
|
|
77
75
|
args.token,
|
|
78
76
|
)
|
|
@@ -85,11 +83,6 @@ def flwr_clientapp() -> None:
|
|
|
85
83
|
)
|
|
86
84
|
|
|
87
85
|
|
|
88
|
-
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
89
|
-
"""Log channel connectivity."""
|
|
90
|
-
log(DEBUG, channel_connectivity)
|
|
91
|
-
|
|
92
|
-
|
|
93
86
|
def run_clientapp( # pylint: disable=R0914
|
|
94
87
|
clientappio_api_address: str,
|
|
95
88
|
run_once: bool,
|
|
@@ -118,11 +111,11 @@ def run_clientapp( # pylint: disable=R0914
|
|
|
118
111
|
time.sleep(1)
|
|
119
112
|
|
|
120
113
|
# Pull Message, Context, Run and (optional) FAB from SuperNode
|
|
121
|
-
message, context, run, fab =
|
|
114
|
+
message, context, run, fab = pull_clientappinputs(stub=stub, token=token)
|
|
122
115
|
|
|
123
116
|
# Install FAB, if provided
|
|
124
117
|
if fab:
|
|
125
|
-
log(DEBUG, "
|
|
118
|
+
log(DEBUG, "[flwr-clientapp] Start FAB installation.")
|
|
126
119
|
install_from_fab(fab.content, flwr_dir=flwr_dir_, skip_prompt=True)
|
|
127
120
|
|
|
128
121
|
load_client_app_fn = get_load_client_app_fn(
|
|
@@ -134,6 +127,7 @@ def run_clientapp( # pylint: disable=R0914
|
|
|
134
127
|
|
|
135
128
|
try:
|
|
136
129
|
# Load ClientApp
|
|
130
|
+
log(DEBUG, "[flwr-clientapp] Start `ClientApp` Loading.")
|
|
137
131
|
client_app: ClientApp = load_client_app_fn(
|
|
138
132
|
run.fab_id, run.fab_version, fab.hash_str if fab else ""
|
|
139
133
|
)
|
|
@@ -162,7 +156,7 @@ def run_clientapp( # pylint: disable=R0914
|
|
|
162
156
|
)
|
|
163
157
|
|
|
164
158
|
# Push Message and Context to SuperNode
|
|
165
|
-
_ =
|
|
159
|
+
_ = push_clientappoutputs(
|
|
166
160
|
stub=stub, token=token, message=reply_message, context=context
|
|
167
161
|
)
|
|
168
162
|
|
|
@@ -185,7 +179,7 @@ def run_clientapp( # pylint: disable=R0914
|
|
|
185
179
|
|
|
186
180
|
def get_token(stub: grpc.Channel) -> Optional[int]:
|
|
187
181
|
"""Get a token from SuperNode."""
|
|
188
|
-
log(DEBUG, "
|
|
182
|
+
log(DEBUG, "[flwr-clientapp] Request token")
|
|
189
183
|
try:
|
|
190
184
|
res: GetTokenResponse = stub.GetToken(GetTokenRequest())
|
|
191
185
|
log(DEBUG, "[GetToken] Received token: %s", res.token)
|
|
@@ -198,11 +192,11 @@ def get_token(stub: grpc.Channel) -> Optional[int]:
|
|
|
198
192
|
return None
|
|
199
193
|
|
|
200
194
|
|
|
201
|
-
def
|
|
195
|
+
def pull_clientappinputs(
|
|
202
196
|
stub: grpc.Channel, token: int
|
|
203
197
|
) -> tuple[Message, Context, Run, Optional[Fab]]:
|
|
204
|
-
"""Pull
|
|
205
|
-
log(INFO, "
|
|
198
|
+
"""Pull ClientAppInputs from SuperNode."""
|
|
199
|
+
log(INFO, "[flwr-clientapp] Pull `ClientAppInputs` for token %s", token)
|
|
206
200
|
try:
|
|
207
201
|
res: PullClientAppInputsResponse = stub.PullClientAppInputs(
|
|
208
202
|
PullClientAppInputsRequest(token=token)
|
|
@@ -217,11 +211,11 @@ def pull_message(
|
|
|
217
211
|
raise e
|
|
218
212
|
|
|
219
213
|
|
|
220
|
-
def
|
|
214
|
+
def push_clientappoutputs(
|
|
221
215
|
stub: grpc.Channel, token: int, message: Message, context: Context
|
|
222
216
|
) -> PushClientAppOutputsResponse:
|
|
223
|
-
"""Push
|
|
224
|
-
log(INFO, "
|
|
217
|
+
"""Push ClientAppOutputs to SuperNode."""
|
|
218
|
+
log(INFO, "[flwr-clientapp] Push `ClientAppOutputs` for token %s", token)
|
|
225
219
|
proto_message = message_to_proto(message)
|
|
226
220
|
proto_context = context_to_proto(context)
|
|
227
221
|
|
flwr/client/clientapp/utils.py
CHANGED
|
@@ -66,7 +66,7 @@ def get_load_client_app_fn(
|
|
|
66
66
|
# `fab_hash` is not required since the app is loaded from `runtime_app_dir`.
|
|
67
67
|
elif app_path is not None:
|
|
68
68
|
config = get_project_config(runtime_app_dir)
|
|
69
|
-
|
|
69
|
+
this_fab_id, this_fab_version = get_metadata_from_config(config)
|
|
70
70
|
|
|
71
71
|
if this_fab_version != fab_version or this_fab_id != fab_id:
|
|
72
72
|
raise LoadClientAppError(
|
|
@@ -36,7 +36,7 @@ from flwr.common import (
|
|
|
36
36
|
from flwr.common import recordset_compat as compat
|
|
37
37
|
from flwr.common import serde
|
|
38
38
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
39
|
-
from flwr.common.grpc import create_channel
|
|
39
|
+
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
40
40
|
from flwr.common.logger import log
|
|
41
41
|
from flwr.common.retry_invoker import RetryInvoker
|
|
42
42
|
from flwr.common.typing import Fab, Run
|
|
@@ -47,17 +47,6 @@ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
|
47
47
|
)
|
|
48
48
|
from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
|
|
49
49
|
|
|
50
|
-
# The following flags can be uncommented for debugging. Other possible values:
|
|
51
|
-
# https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
|
|
52
|
-
# import os
|
|
53
|
-
# os.environ["GRPC_VERBOSITY"] = "debug"
|
|
54
|
-
# os.environ["GRPC_TRACE"] = "tcp,http"
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
58
|
-
"""Log channel connectivity."""
|
|
59
|
-
log(DEBUG, channel_connectivity)
|
|
60
|
-
|
|
61
50
|
|
|
62
51
|
@contextmanager
|
|
63
52
|
def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-arguments
|
|
@@ -15,67 +15,18 @@
|
|
|
15
15
|
"""Flower client interceptor."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import
|
|
19
|
-
import collections
|
|
20
|
-
from collections.abc import Sequence
|
|
21
|
-
from logging import WARNING
|
|
22
|
-
from typing import Any, Callable, Optional, Union
|
|
18
|
+
from typing import Any, Callable
|
|
23
19
|
|
|
24
20
|
import grpc
|
|
25
21
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
22
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
26
23
|
|
|
27
|
-
from flwr.common
|
|
24
|
+
from flwr.common import now
|
|
25
|
+
from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
|
|
28
26
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
29
|
-
bytes_to_public_key,
|
|
30
|
-
compute_hmac,
|
|
31
|
-
generate_shared_key,
|
|
32
27
|
public_key_to_bytes,
|
|
28
|
+
sign_message,
|
|
33
29
|
)
|
|
34
|
-
from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
|
|
35
|
-
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
36
|
-
CreateNodeRequest,
|
|
37
|
-
DeleteNodeRequest,
|
|
38
|
-
PingRequest,
|
|
39
|
-
PullTaskInsRequest,
|
|
40
|
-
PushTaskResRequest,
|
|
41
|
-
)
|
|
42
|
-
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
|
43
|
-
|
|
44
|
-
_PUBLIC_KEY_HEADER = "public-key"
|
|
45
|
-
_AUTH_TOKEN_HEADER = "auth-token"
|
|
46
|
-
|
|
47
|
-
Request = Union[
|
|
48
|
-
CreateNodeRequest,
|
|
49
|
-
DeleteNodeRequest,
|
|
50
|
-
PullTaskInsRequest,
|
|
51
|
-
PushTaskResRequest,
|
|
52
|
-
GetRunRequest,
|
|
53
|
-
PingRequest,
|
|
54
|
-
GetFabRequest,
|
|
55
|
-
]
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def _get_value_from_tuples(
|
|
59
|
-
key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
|
|
60
|
-
) -> bytes:
|
|
61
|
-
value = next((value for key, value in tuples if key == key_string), "")
|
|
62
|
-
if isinstance(value, str):
|
|
63
|
-
return value.encode()
|
|
64
|
-
|
|
65
|
-
return value
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
class _ClientCallDetails(
|
|
69
|
-
collections.namedtuple(
|
|
70
|
-
"_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
|
|
71
|
-
),
|
|
72
|
-
grpc.ClientCallDetails, # type: ignore
|
|
73
|
-
):
|
|
74
|
-
"""Details for each client call.
|
|
75
|
-
|
|
76
|
-
The class will be passed on as the first argument in continuation function.
|
|
77
|
-
In our case, `AuthenticateClientInterceptor` adds new metadata to the construct.
|
|
78
|
-
"""
|
|
79
30
|
|
|
80
31
|
|
|
81
32
|
class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
|
|
@@ -87,84 +38,33 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
|
|
|
87
38
|
public_key: ec.EllipticCurvePublicKey,
|
|
88
39
|
):
|
|
89
40
|
self.private_key = private_key
|
|
90
|
-
self.
|
|
91
|
-
self.shared_secret: Optional[bytes] = None
|
|
92
|
-
self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
|
|
93
|
-
self.encoded_public_key = base64.urlsafe_b64encode(
|
|
94
|
-
public_key_to_bytes(self.public_key)
|
|
95
|
-
)
|
|
41
|
+
self.public_key_bytes = public_key_to_bytes(public_key)
|
|
96
42
|
|
|
97
43
|
def intercept_unary_unary(
|
|
98
44
|
self,
|
|
99
45
|
continuation: Callable[[Any, Any], Any],
|
|
100
46
|
client_call_details: grpc.ClientCallDetails,
|
|
101
|
-
request:
|
|
47
|
+
request: GrpcMessage,
|
|
102
48
|
) -> grpc.Call:
|
|
103
49
|
"""Flower client interceptor.
|
|
104
50
|
|
|
105
51
|
Intercept unary call from client and add necessary authentication header in the
|
|
106
52
|
RPC metadata.
|
|
107
53
|
"""
|
|
108
|
-
metadata = []
|
|
109
|
-
postprocess = False
|
|
110
|
-
if client_call_details.metadata is not None:
|
|
111
|
-
metadata = list(client_call_details.metadata)
|
|
112
|
-
|
|
113
|
-
# Always add the public key header
|
|
114
|
-
metadata.append(
|
|
115
|
-
(
|
|
116
|
-
_PUBLIC_KEY_HEADER,
|
|
117
|
-
self.encoded_public_key,
|
|
118
|
-
)
|
|
119
|
-
)
|
|
120
|
-
|
|
121
|
-
if isinstance(request, CreateNodeRequest):
|
|
122
|
-
postprocess = True
|
|
123
|
-
elif isinstance(
|
|
124
|
-
request,
|
|
125
|
-
(
|
|
126
|
-
DeleteNodeRequest,
|
|
127
|
-
PullTaskInsRequest,
|
|
128
|
-
PushTaskResRequest,
|
|
129
|
-
GetRunRequest,
|
|
130
|
-
PingRequest,
|
|
131
|
-
GetFabRequest,
|
|
132
|
-
),
|
|
133
|
-
):
|
|
134
|
-
if self.shared_secret is None:
|
|
135
|
-
raise RuntimeError("Failure to compute hmac")
|
|
136
|
-
|
|
137
|
-
message_bytes = request.SerializeToString(deterministic=True)
|
|
138
|
-
metadata.append(
|
|
139
|
-
(
|
|
140
|
-
_AUTH_TOKEN_HEADER,
|
|
141
|
-
base64.urlsafe_b64encode(
|
|
142
|
-
compute_hmac(self.shared_secret, message_bytes)
|
|
143
|
-
),
|
|
144
|
-
)
|
|
145
|
-
)
|
|
54
|
+
metadata = list(client_call_details.metadata or [])
|
|
146
55
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
client_call_details.timeout,
|
|
150
|
-
metadata,
|
|
151
|
-
client_call_details.credentials,
|
|
152
|
-
)
|
|
56
|
+
# Add the public key
|
|
57
|
+
metadata.append((PUBLIC_KEY_HEADER, self.public_key_bytes))
|
|
153
58
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
_get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
|
|
158
|
-
)
|
|
59
|
+
# Add timestamp
|
|
60
|
+
timestamp = now().isoformat()
|
|
61
|
+
metadata.append((TIMESTAMP_HEADER, timestamp))
|
|
159
62
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
log(WARNING, "Can't get server public key, SuperLink may be offline")
|
|
63
|
+
# Sign and add the signature
|
|
64
|
+
signature = sign_message(self.private_key, timestamp.encode("ascii"))
|
|
65
|
+
metadata.append((SIGNATURE_HEADER, signature))
|
|
164
66
|
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
self.private_key, self.server_public_key
|
|
168
|
-
)
|
|
67
|
+
# Overwrite the metadata
|
|
68
|
+
details = client_call_details._replace(metadata=metadata)
|
|
169
69
|
|
|
170
|
-
return
|
|
70
|
+
return continuation(details, request)
|
|
@@ -20,7 +20,7 @@ import threading
|
|
|
20
20
|
from collections.abc import Iterator, Sequence
|
|
21
21
|
from contextlib import contextmanager
|
|
22
22
|
from copy import copy
|
|
23
|
-
from logging import
|
|
23
|
+
from logging import ERROR
|
|
24
24
|
from pathlib import Path
|
|
25
25
|
from typing import Callable, Optional, Union, cast
|
|
26
26
|
|
|
@@ -29,7 +29,6 @@ from cryptography.hazmat.primitives.asymmetric import ec
|
|
|
29
29
|
|
|
30
30
|
from flwr.client.heartbeat import start_ping_loop
|
|
31
31
|
from flwr.client.message_handler.message_handler import validate_out_message
|
|
32
|
-
from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
|
|
33
32
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
34
33
|
from flwr.common.constant import (
|
|
35
34
|
PING_BASE_MULTIPLIER,
|
|
@@ -37,11 +36,14 @@ from flwr.common.constant import (
|
|
|
37
36
|
PING_DEFAULT_INTERVAL,
|
|
38
37
|
PING_RANDOM_RANGE,
|
|
39
38
|
)
|
|
40
|
-
from flwr.common.grpc import create_channel
|
|
39
|
+
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
41
40
|
from flwr.common.logger import log
|
|
42
41
|
from flwr.common.message import Message, Metadata
|
|
43
42
|
from flwr.common.retry_invoker import RetryInvoker
|
|
44
|
-
from flwr.common.
|
|
43
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
44
|
+
generate_key_pairs,
|
|
45
|
+
)
|
|
46
|
+
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
|
45
47
|
from flwr.common.typing import Fab, Run, RunNotRunningException
|
|
46
48
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
47
49
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
@@ -49,23 +51,18 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
49
51
|
DeleteNodeRequest,
|
|
50
52
|
PingRequest,
|
|
51
53
|
PingResponse,
|
|
52
|
-
|
|
53
|
-
|
|
54
|
+
PullMessagesRequest,
|
|
55
|
+
PullMessagesResponse,
|
|
56
|
+
PushMessagesRequest,
|
|
54
57
|
)
|
|
55
58
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
56
59
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
57
60
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
58
|
-
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
59
61
|
|
|
60
62
|
from .client_interceptor import AuthenticateClientInterceptor
|
|
61
63
|
from .grpc_adapter import GrpcAdapter
|
|
62
64
|
|
|
63
65
|
|
|
64
|
-
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
65
|
-
"""Log channel connectivity."""
|
|
66
|
-
log(DEBUG, channel_connectivity)
|
|
67
|
-
|
|
68
|
-
|
|
69
66
|
@contextmanager
|
|
70
67
|
def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
71
68
|
server_address: str,
|
|
@@ -131,12 +128,14 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
131
128
|
if isinstance(root_certificates, str):
|
|
132
129
|
root_certificates = Path(root_certificates).read_bytes()
|
|
133
130
|
|
|
134
|
-
|
|
135
|
-
if authentication_keys is
|
|
136
|
-
|
|
137
|
-
authentication_keys[0], authentication_keys[1]
|
|
138
|
-
)
|
|
131
|
+
# Automatic node auth: generate keys if user didn't provide any
|
|
132
|
+
if authentication_keys is None:
|
|
133
|
+
authentication_keys = generate_key_pairs()
|
|
139
134
|
|
|
135
|
+
# Always configure auth interceptor, with either user-provided or generated keys
|
|
136
|
+
interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
|
|
137
|
+
AuthenticateClientInterceptor(*authentication_keys),
|
|
138
|
+
]
|
|
140
139
|
channel = create_channel(
|
|
141
140
|
server_address=server_address,
|
|
142
141
|
insecure=insecure,
|
|
@@ -227,28 +226,31 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
227
226
|
node = None
|
|
228
227
|
|
|
229
228
|
def receive() -> Optional[Message]:
|
|
230
|
-
"""Receive next
|
|
229
|
+
"""Receive next message from server."""
|
|
231
230
|
# Get Node
|
|
232
231
|
if node is None:
|
|
233
232
|
log(ERROR, "Node instance missing")
|
|
234
233
|
return None
|
|
235
234
|
|
|
236
|
-
# Request instructions (
|
|
237
|
-
request =
|
|
238
|
-
response = retry_invoker.invoke(
|
|
235
|
+
# Request instructions (message) from server
|
|
236
|
+
request = PullMessagesRequest(node=node)
|
|
237
|
+
response: PullMessagesResponse = retry_invoker.invoke(
|
|
238
|
+
stub.PullMessages, request=request
|
|
239
|
+
)
|
|
239
240
|
|
|
240
|
-
# Get the current
|
|
241
|
-
|
|
241
|
+
# Get the current Messages
|
|
242
|
+
message_proto = (
|
|
243
|
+
None if len(response.messages_list) == 0 else response.messages_list[0]
|
|
244
|
+
)
|
|
242
245
|
|
|
243
|
-
# Discard the current
|
|
244
|
-
if
|
|
245
|
-
|
|
246
|
-
and validate_task_ins(task_ins)
|
|
246
|
+
# Discard the current message if not valid
|
|
247
|
+
if message_proto is not None and not (
|
|
248
|
+
message_proto.metadata.dst_node_id == node.node_id
|
|
247
249
|
):
|
|
248
|
-
|
|
250
|
+
message_proto = None
|
|
249
251
|
|
|
250
252
|
# Construct the Message
|
|
251
|
-
in_message =
|
|
253
|
+
in_message = message_from_proto(message_proto) if message_proto else None
|
|
252
254
|
|
|
253
255
|
# Remember `metadata` of the in message
|
|
254
256
|
nonlocal metadata
|
|
@@ -258,7 +260,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
258
260
|
return in_message
|
|
259
261
|
|
|
260
262
|
def send(message: Message) -> None:
|
|
261
|
-
"""Send
|
|
263
|
+
"""Send message reply to server."""
|
|
262
264
|
# Get Node
|
|
263
265
|
if node is None:
|
|
264
266
|
log(ERROR, "Node instance missing")
|
|
@@ -275,12 +277,10 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
275
277
|
log(ERROR, "Invalid out message")
|
|
276
278
|
return
|
|
277
279
|
|
|
278
|
-
#
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
request = PushTaskResRequest(node=node, task_res_list=[task_res])
|
|
283
|
-
_ = retry_invoker.invoke(stub.PushTaskRes, request)
|
|
280
|
+
# Serialize Message
|
|
281
|
+
message_proto = message_to_proto(message=message)
|
|
282
|
+
request = PushMessagesRequest(node=node, messages_list=[message_proto])
|
|
283
|
+
_ = retry_invoker.invoke(stub.PushMessages, request)
|
|
284
284
|
|
|
285
285
|
# Cleanup
|
|
286
286
|
metadata = None
|
|
@@ -311,3 +311,13 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
311
311
|
yield (receive, send, create_node, delete_node, get_run, get_fab)
|
|
312
312
|
except Exception as exc: # pylint: disable=broad-except
|
|
313
313
|
log(ERROR, exc)
|
|
314
|
+
# Cleanup
|
|
315
|
+
finally:
|
|
316
|
+
try:
|
|
317
|
+
if node is not None:
|
|
318
|
+
# Disable retrying
|
|
319
|
+
retry_invoker.max_tries = 1
|
|
320
|
+
delete_node()
|
|
321
|
+
except grpc.RpcError:
|
|
322
|
+
pass
|
|
323
|
+
channel.close()
|
|
@@ -40,10 +40,10 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
40
40
|
DeleteNodeResponse,
|
|
41
41
|
PingRequest,
|
|
42
42
|
PingResponse,
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
43
|
+
PullMessagesRequest,
|
|
44
|
+
PullMessagesResponse,
|
|
45
|
+
PushMessagesRequest,
|
|
46
|
+
PushMessagesResponse,
|
|
47
47
|
)
|
|
48
48
|
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
49
49
|
from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
|
|
@@ -126,17 +126,17 @@ class GrpcAdapter:
|
|
|
126
126
|
"""."""
|
|
127
127
|
return self._send_and_receive(request, PingResponse, **kwargs)
|
|
128
128
|
|
|
129
|
-
def
|
|
130
|
-
self, request:
|
|
131
|
-
) ->
|
|
129
|
+
def PullMessages( # pylint: disable=C0103
|
|
130
|
+
self, request: PullMessagesRequest, **kwargs: Any
|
|
131
|
+
) -> PullMessagesResponse:
|
|
132
132
|
"""."""
|
|
133
|
-
return self._send_and_receive(request,
|
|
133
|
+
return self._send_and_receive(request, PullMessagesResponse, **kwargs)
|
|
134
134
|
|
|
135
|
-
def
|
|
136
|
-
self, request:
|
|
137
|
-
) ->
|
|
135
|
+
def PushMessages( # pylint: disable=C0103
|
|
136
|
+
self, request: PushMessagesRequest, **kwargs: Any
|
|
137
|
+
) -> PushMessagesResponse:
|
|
138
138
|
"""."""
|
|
139
|
-
return self._send_and_receive(request,
|
|
139
|
+
return self._send_and_receive(request, PushMessagesResponse, **kwargs)
|
|
140
140
|
|
|
141
141
|
def GetRun( # pylint: disable=C0103
|
|
142
142
|
self, request: GetRunRequest, **kwargs: Any
|
|
@@ -15,9 +15,6 @@
|
|
|
15
15
|
"""Task handling."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import Optional
|
|
19
|
-
|
|
20
|
-
from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
|
|
21
18
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
22
19
|
|
|
23
20
|
|
|
@@ -38,17 +35,3 @@ def validate_task_ins(task_ins: TaskIns) -> bool:
|
|
|
38
35
|
if not (task_ins.HasField("task") and task_ins.task.HasField("recordset")):
|
|
39
36
|
return False
|
|
40
37
|
return True
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
def get_task_ins(
|
|
44
|
-
pull_task_ins_response: PullTaskInsResponse,
|
|
45
|
-
) -> Optional[TaskIns]:
|
|
46
|
-
"""Get the first TaskIns, if available."""
|
|
47
|
-
# Extract a single ServerMessage from the response, if possible
|
|
48
|
-
if len(pull_task_ins_response.task_ins_list) == 0:
|
|
49
|
-
return None
|
|
50
|
-
|
|
51
|
-
# Only evaluate the first message
|
|
52
|
-
task_ins: TaskIns = pull_task_ins_response.task_ins_list[0]
|
|
53
|
-
|
|
54
|
-
return task_ins
|