flwr-nightly 1.9.0.dev20240531__py3-none-any.whl → 1.10.0.dev20240619__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic; see the package's registry page for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +4 -15
- flwr/cli/config_utils.py +64 -7
- flwr/cli/install.py +211 -0
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +39 -2
- flwr/cli/utils.py +14 -0
- flwr/client/__init__.py +1 -0
- flwr/client/app.py +153 -103
- flwr/client/client_app.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +94 -0
- flwr/client/grpc_client/connection.py +5 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +9 -5
- flwr/client/grpc_rere_client/grpc_adapter.py +133 -0
- flwr/client/mod/__init__.py +4 -4
- flwr/client/rest_client/connection.py +10 -3
- flwr/client/supernode/app.py +155 -31
- flwr/common/__init__.py +12 -12
- flwr/common/config.py +71 -0
- flwr/common/constant.py +15 -0
- flwr/common/object_ref.py +52 -14
- flwr/common/record/__init__.py +1 -1
- flwr/common/telemetry.py +4 -0
- flwr/common/typing.py +9 -0
- flwr/proto/driver_pb2.py +20 -19
- flwr/proto/driver_pb2_grpc.py +35 -0
- flwr/proto/driver_pb2_grpc.pyi +14 -0
- flwr/proto/exec_pb2.py +34 -0
- flwr/proto/exec_pb2.pyi +55 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +28 -33
- flwr/proto/fleet_pb2.pyi +0 -42
- flwr/proto/fleet_pb2_grpc.py +7 -6
- flwr/proto/fleet_pb2_grpc.pyi +5 -4
- flwr/proto/run_pb2.py +30 -0
- flwr/proto/run_pb2.pyi +52 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/server/__init__.py +2 -6
- flwr/server/app.py +94 -214
- flwr/server/run_serverapp.py +33 -7
- flwr/server/server_app.py +2 -2
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +7 -0
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +4 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -6
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +3 -1
- flwr/server/superlink/state/in_memory_state.py +8 -5
- flwr/server/superlink/state/sqlite_state.py +6 -3
- flwr/server/superlink/state/state.py +5 -4
- flwr/simulation/__init__.py +4 -1
- flwr/simulation/run_simulation.py +22 -0
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +178 -0
- flwr/superexec/exec_grpc.py +51 -0
- flwr/superexec/exec_servicer.py +65 -0
- flwr/superexec/executor.py +54 -0
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/METADATA +1 -1
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/RECORD +80 -56
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/entry_points.txt +1 -2
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/WHEEL +0 -0
flwr/client/app.py
CHANGED
|
@@ -14,10 +14,12 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower client app."""
|
|
16
16
|
|
|
17
|
+
import signal
|
|
17
18
|
import sys
|
|
18
19
|
import time
|
|
20
|
+
from dataclasses import dataclass
|
|
19
21
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
20
|
-
from typing import Callable, ContextManager, Optional, Tuple, Type, Union
|
|
22
|
+
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
|
|
21
23
|
|
|
22
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
23
25
|
from grpc import RpcError
|
|
@@ -29,6 +31,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
|
|
|
29
31
|
from flwr.common.address import parse_address
|
|
30
32
|
from flwr.common.constant import (
|
|
31
33
|
MISSING_EXTRA_REST,
|
|
34
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
32
35
|
TRANSPORT_TYPE_GRPC_BIDI,
|
|
33
36
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
34
37
|
TRANSPORT_TYPE_REST,
|
|
@@ -37,8 +40,9 @@ from flwr.common.constant import (
|
|
|
37
40
|
)
|
|
38
41
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
39
42
|
from flwr.common.message import Error
|
|
40
|
-
from flwr.common.retry_invoker import RetryInvoker, exponential
|
|
43
|
+
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
|
|
41
44
|
|
|
45
|
+
from .grpc_adapter_client.connection import grpc_adapter
|
|
42
46
|
from .grpc_client.connection import grpc_connection
|
|
43
47
|
from .grpc_rere_client.connection import grpc_request_response
|
|
44
48
|
from .message_handler.message_handler import handle_control_message
|
|
@@ -175,7 +179,7 @@ def start_client(
|
|
|
175
179
|
def _start_client_internal(
|
|
176
180
|
*,
|
|
177
181
|
server_address: str,
|
|
178
|
-
load_client_app_fn: Optional[Callable[[], ClientApp]] = None,
|
|
182
|
+
load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
|
|
179
183
|
client_fn: Optional[ClientFn] = None,
|
|
180
184
|
client: Optional[Client] = None,
|
|
181
185
|
grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
@@ -250,7 +254,7 @@ def _start_client_internal(
|
|
|
250
254
|
|
|
251
255
|
client_fn = single_client_factory
|
|
252
256
|
|
|
253
|
-
def _load_client_app() -> ClientApp:
|
|
257
|
+
def _load_client_app(_1: str, _2: str) -> ClientApp:
|
|
254
258
|
return ClientApp(client_fn=client_fn)
|
|
255
259
|
|
|
256
260
|
load_client_app_fn = _load_client_app
|
|
@@ -263,6 +267,29 @@ def _start_client_internal(
|
|
|
263
267
|
transport, server_address
|
|
264
268
|
)
|
|
265
269
|
|
|
270
|
+
app_state_tracker = _AppStateTracker()
|
|
271
|
+
|
|
272
|
+
def _on_sucess(retry_state: RetryState) -> None:
|
|
273
|
+
app_state_tracker.is_connected = True
|
|
274
|
+
if retry_state.tries > 1:
|
|
275
|
+
log(
|
|
276
|
+
INFO,
|
|
277
|
+
"Connection successful after %.2f seconds and %s tries.",
|
|
278
|
+
retry_state.elapsed_time,
|
|
279
|
+
retry_state.tries,
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
def _on_backoff(retry_state: RetryState) -> None:
|
|
283
|
+
app_state_tracker.is_connected = False
|
|
284
|
+
if retry_state.tries == 1:
|
|
285
|
+
log(WARN, "Connection attempt failed, retrying...")
|
|
286
|
+
else:
|
|
287
|
+
log(
|
|
288
|
+
DEBUG,
|
|
289
|
+
"Connection attempt failed, retrying in %.2f seconds",
|
|
290
|
+
retry_state.actual_wait,
|
|
291
|
+
)
|
|
292
|
+
|
|
266
293
|
retry_invoker = RetryInvoker(
|
|
267
294
|
wait_gen_factory=exponential,
|
|
268
295
|
recoverable_exceptions=connection_error_type,
|
|
@@ -278,30 +305,15 @@ def _start_client_internal(
|
|
|
278
305
|
if retry_state.tries > 1
|
|
279
306
|
else None
|
|
280
307
|
),
|
|
281
|
-
on_success=
|
|
282
|
-
|
|
283
|
-
INFO,
|
|
284
|
-
"Connection successful after %.2f seconds and %s tries.",
|
|
285
|
-
retry_state.elapsed_time,
|
|
286
|
-
retry_state.tries,
|
|
287
|
-
)
|
|
288
|
-
if retry_state.tries > 1
|
|
289
|
-
else None
|
|
290
|
-
),
|
|
291
|
-
on_backoff=lambda retry_state: (
|
|
292
|
-
log(WARN, "Connection attempt failed, retrying...")
|
|
293
|
-
if retry_state.tries == 1
|
|
294
|
-
else log(
|
|
295
|
-
DEBUG,
|
|
296
|
-
"Connection attempt failed, retrying in %.2f seconds",
|
|
297
|
-
retry_state.actual_wait,
|
|
298
|
-
)
|
|
299
|
-
),
|
|
308
|
+
on_success=_on_sucess,
|
|
309
|
+
on_backoff=_on_backoff,
|
|
300
310
|
)
|
|
301
311
|
|
|
302
312
|
node_state = NodeState()
|
|
313
|
+
# run_id -> (fab_id, fab_version)
|
|
314
|
+
run_info: Dict[int, Tuple[str, str]] = {}
|
|
303
315
|
|
|
304
|
-
while
|
|
316
|
+
while not app_state_tracker.interrupt:
|
|
305
317
|
sleep_duration: int = 0
|
|
306
318
|
with connection(
|
|
307
319
|
address,
|
|
@@ -311,106 +323,125 @@ def _start_client_internal(
|
|
|
311
323
|
root_certificates,
|
|
312
324
|
authentication_keys,
|
|
313
325
|
) as conn:
|
|
314
|
-
# pylint: disable-next=W0612
|
|
315
326
|
receive, send, create_node, delete_node, get_run = conn
|
|
316
327
|
|
|
317
328
|
# Register node
|
|
318
329
|
if create_node is not None:
|
|
319
330
|
create_node() # pylint: disable=not-callable
|
|
320
331
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
332
|
+
app_state_tracker.register_signal_handler()
|
|
333
|
+
while not app_state_tracker.interrupt:
|
|
334
|
+
try:
|
|
335
|
+
# Receive
|
|
336
|
+
message = receive()
|
|
337
|
+
if message is None:
|
|
338
|
+
time.sleep(3) # Wait for 3s before asking again
|
|
339
|
+
continue
|
|
340
|
+
|
|
341
|
+
log(INFO, "")
|
|
342
|
+
if len(message.metadata.group_id) > 0:
|
|
343
|
+
log(
|
|
344
|
+
INFO,
|
|
345
|
+
"[RUN %s, ROUND %s]",
|
|
346
|
+
message.metadata.run_id,
|
|
347
|
+
message.metadata.group_id,
|
|
348
|
+
)
|
|
330
349
|
log(
|
|
331
350
|
INFO,
|
|
332
|
-
"
|
|
333
|
-
message.metadata.
|
|
334
|
-
message.metadata.
|
|
351
|
+
"Received: %s message %s",
|
|
352
|
+
message.metadata.message_type,
|
|
353
|
+
message.metadata.message_id,
|
|
335
354
|
)
|
|
336
|
-
log(
|
|
337
|
-
INFO,
|
|
338
|
-
"Received: %s message %s",
|
|
339
|
-
message.metadata.message_type,
|
|
340
|
-
message.metadata.message_id,
|
|
341
|
-
)
|
|
342
|
-
|
|
343
|
-
# Handle control message
|
|
344
|
-
out_message, sleep_duration = handle_control_message(message)
|
|
345
|
-
if out_message:
|
|
346
|
-
send(out_message)
|
|
347
|
-
break
|
|
348
|
-
|
|
349
|
-
# Register context for this run
|
|
350
|
-
node_state.register_context(run_id=message.metadata.run_id)
|
|
351
|
-
|
|
352
|
-
# Retrieve context for this run
|
|
353
|
-
context = node_state.retrieve_context(run_id=message.metadata.run_id)
|
|
354
355
|
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
356
|
+
# Handle control message
|
|
357
|
+
out_message, sleep_duration = handle_control_message(message)
|
|
358
|
+
if out_message:
|
|
359
|
+
send(out_message)
|
|
360
|
+
break
|
|
361
|
+
|
|
362
|
+
# Get run info
|
|
363
|
+
run_id = message.metadata.run_id
|
|
364
|
+
if run_id not in run_info:
|
|
365
|
+
if get_run is not None:
|
|
366
|
+
run_info[run_id] = get_run(run_id)
|
|
367
|
+
# If get_run is None, i.e., in grpc-bidi mode
|
|
368
|
+
else:
|
|
369
|
+
run_info[run_id] = ("", "")
|
|
370
|
+
|
|
371
|
+
# Register context for this run
|
|
372
|
+
node_state.register_context(run_id=run_id)
|
|
373
|
+
|
|
374
|
+
# Retrieve context for this run
|
|
375
|
+
context = node_state.retrieve_context(run_id=run_id)
|
|
376
|
+
|
|
377
|
+
# Create an error reply message that will never be used to prevent
|
|
378
|
+
# the used-before-assignment linting error
|
|
379
|
+
reply_message = message.create_error_reply(
|
|
380
|
+
error=Error(code=ErrorCode.UNKNOWN, reason="Unknown")
|
|
381
|
+
)
|
|
360
382
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
383
|
+
# Handle app loading and task message
|
|
384
|
+
try:
|
|
385
|
+
# Load ClientApp instance
|
|
386
|
+
client_app: ClientApp = load_client_app_fn(*run_info[run_id])
|
|
387
|
+
|
|
388
|
+
# Execute ClientApp
|
|
389
|
+
reply_message = client_app(message=message, context=context)
|
|
390
|
+
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
391
|
+
|
|
392
|
+
# Legacy grpc-bidi
|
|
393
|
+
if transport in ["grpc-bidi", None]:
|
|
394
|
+
log(ERROR, "Client raised an exception.", exc_info=ex)
|
|
395
|
+
# Raise exception, crash process
|
|
396
|
+
raise ex
|
|
397
|
+
|
|
398
|
+
# Don't update/change NodeState
|
|
399
|
+
|
|
400
|
+
e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
|
|
401
|
+
# Ex fmt: "<class 'ZeroDivisionError'>:<'division by zero'>"
|
|
402
|
+
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
|
|
403
|
+
exc_entity = "ClientApp"
|
|
404
|
+
if isinstance(ex, LoadClientAppError):
|
|
405
|
+
reason = (
|
|
406
|
+
"An exception was raised when attempting to load "
|
|
407
|
+
"`ClientApp`"
|
|
408
|
+
)
|
|
409
|
+
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
|
|
410
|
+
exc_entity = "SuperNode"
|
|
411
|
+
|
|
412
|
+
if not app_state_tracker.interrupt:
|
|
413
|
+
log(
|
|
414
|
+
ERROR, "%s raised an exception", exc_entity, exc_info=ex
|
|
415
|
+
)
|
|
416
|
+
|
|
417
|
+
# Create error message
|
|
418
|
+
reply_message = message.create_error_reply(
|
|
419
|
+
error=Error(code=e_code, reason=reason)
|
|
420
|
+
)
|
|
421
|
+
else:
|
|
422
|
+
# No exception, update node state
|
|
423
|
+
node_state.update_context(
|
|
424
|
+
run_id=run_id,
|
|
425
|
+
context=context,
|
|
386
426
|
)
|
|
387
|
-
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
|
|
388
|
-
exc_entity = "SuperNode"
|
|
389
|
-
|
|
390
|
-
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
|
391
427
|
|
|
392
|
-
#
|
|
393
|
-
reply_message
|
|
394
|
-
|
|
395
|
-
)
|
|
396
|
-
else:
|
|
397
|
-
# No exception, update node state
|
|
398
|
-
node_state.update_context(
|
|
399
|
-
run_id=message.metadata.run_id,
|
|
400
|
-
context=context,
|
|
401
|
-
)
|
|
428
|
+
# Send
|
|
429
|
+
send(reply_message)
|
|
430
|
+
log(INFO, "Sent reply")
|
|
402
431
|
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
432
|
+
except StopIteration:
|
|
433
|
+
sleep_duration = 0
|
|
434
|
+
break
|
|
406
435
|
|
|
407
436
|
# Unregister node
|
|
408
|
-
if delete_node is not None:
|
|
437
|
+
if delete_node is not None and app_state_tracker.is_connected:
|
|
409
438
|
delete_node() # pylint: disable=not-callable
|
|
410
439
|
|
|
411
440
|
if sleep_duration == 0:
|
|
412
441
|
log(INFO, "Disconnect and shut down")
|
|
442
|
+
del app_state_tracker
|
|
413
443
|
break
|
|
444
|
+
|
|
414
445
|
# Sleep and reconnect afterwards
|
|
415
446
|
log(
|
|
416
447
|
INFO,
|
|
@@ -571,6 +602,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
|
571
602
|
connection, error_type = http_request_response, RequestsConnectionError
|
|
572
603
|
elif transport == TRANSPORT_TYPE_GRPC_RERE:
|
|
573
604
|
connection, error_type = grpc_request_response, RpcError
|
|
605
|
+
elif transport == TRANSPORT_TYPE_GRPC_ADAPTER:
|
|
606
|
+
connection, error_type = grpc_adapter, RpcError
|
|
574
607
|
elif transport == TRANSPORT_TYPE_GRPC_BIDI:
|
|
575
608
|
connection, error_type = grpc_connection, RpcError
|
|
576
609
|
else:
|
|
@@ -579,3 +612,20 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
|
579
612
|
)
|
|
580
613
|
|
|
581
614
|
return connection, address, error_type
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
@dataclass
class _AppStateTracker:
    """Mutable client run state, shared with process exit-signal handlers.

    Attributes
    ----------
    interrupt : bool
        Becomes True once an installed SIGINT/SIGTERM handler has fired.
    is_connected : bool
        Connection flag maintained by the caller (False until set otherwise).
    """

    interrupt: bool = False
    is_connected: bool = False

    def register_signal_handler(self) -> None:
        """Register handlers for exit signals."""
        tracker = self

        def _handle_exit_signal(sig, frame):  # type: ignore
            # pylint: disable=unused-argument
            tracker.interrupt = True
            # Unwind whatever the main thread is executing right now; the
            # client loop is expected to catch StopIteration and shut down.
            # NOTE(review): if the signal lands while a generator frame is
            # running, PEP 479 turns this StopIteration into a RuntimeError.
            raise StopIteration from None

        # Same handler object for both signals, SIGINT installed first.
        for exit_signal in (signal.SIGINT, signal.SIGTERM):
            signal.signal(exit_signal, _handle_exit_signal)
|
flwr/client/client_app.py
CHANGED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Client-side part of the GrpcAdapter transport layer."""
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Contextmanager for a GrpcAdapter channel to the Flower server."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from contextlib import contextmanager
|
|
19
|
+
from logging import ERROR
|
|
20
|
+
from typing import Callable, Iterator, Optional, Tuple, Union
|
|
21
|
+
|
|
22
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
23
|
+
|
|
24
|
+
from flwr.client.grpc_rere_client.connection import grpc_request_response
|
|
25
|
+
from flwr.client.grpc_rere_client.grpc_adapter import GrpcAdapter
|
|
26
|
+
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
27
|
+
from flwr.common.logger import log
|
|
28
|
+
from flwr.common.message import Message
|
|
29
|
+
from flwr.common.retry_invoker import RetryInvoker
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@contextmanager
def grpc_adapter(  # pylint: disable=R0913
    server_address: str,
    insecure: bool,
    retry_invoker: RetryInvoker,
    max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    root_certificates: Optional[Union[bytes, str]] = None,
    authentication_keys: Optional[
        Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
    ] = None,
) -> Iterator[
    Tuple[
        Callable[[], Optional[Message]],
        Callable[[Message], None],
        Optional[Callable[[], None]],
        Optional[Callable[[], None]],
        Optional[Callable[[int], Tuple[str, str]]],
    ]
]:
    """Primitives for request/response-based interaction with a server via GrpcAdapter.

    Thin wrapper around `grpc_request_response` that uses `GrpcAdapter` as the
    adapter class and never passes authentication keys.

    Parameters
    ----------
    server_address : str
        The address of the server. Forwarded unchanged to
        `grpc_request_response`.
    insecure : bool
        Starts an insecure gRPC connection when True. Enables HTTPS connection
        when False, using system certificates if `root_certificates` is None.
    retry_invoker : RetryInvoker
        `RetryInvoker` object forwarded unchanged to `grpc_request_response`;
        it governs how failed calls are retried.
    max_message_length : int (default: GRPC_MAX_MESSAGE_LENGTH)
        Forwarded unchanged to `grpc_request_response`.
    root_certificates : Optional[Union[bytes, str]] (default: None)
        PEM-encoded root certificates as bytes, or a path to a certificate
        file as a string. If provided, a secure connection using the
        certificates will be established to an SSL-enabled Flower server.
    authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
        Client authentication is not supported for this transport type. If
        keys are supplied, an error is logged and the connection is opened
        without authentication.

    Yields
    ------
    receive : Callable
    send : Callable
    create_node : Optional[Callable]
    delete_node : Optional[Callable]
    get_run : Optional[Callable]
        Given a run ID, returns a `(str, str)` tuple.

    The tuple is the one produced by `grpc_request_response`, yielded as-is.
    """
    if authentication_keys is not None:
        # Best-effort: report the unsupported option but keep connecting.
        log(ERROR, "Client authentication is not supported for this transport type.")
    with grpc_request_response(
        server_address=server_address,
        insecure=insecure,
        retry_invoker=retry_invoker,
        max_message_length=max_message_length,
        root_certificates=root_certificates,
        authentication_keys=None,  # Authentication is not supported
        adapter_cls=GrpcAdapter,
    ) as conn:
        yield conn
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
import uuid
|
|
19
19
|
from contextlib import contextmanager
|
|
20
|
-
from logging import DEBUG
|
|
20
|
+
from logging import DEBUG, ERROR
|
|
21
21
|
from pathlib import Path
|
|
22
22
|
from queue import Queue
|
|
23
23
|
from typing import Callable, Iterator, Optional, Tuple, Union, cast
|
|
@@ -101,6 +101,8 @@ def grpc_connection( # pylint: disable=R0913, R0915
|
|
|
101
101
|
The PEM-encoded root certificates as a byte string or a path string.
|
|
102
102
|
If provided, a secure connection using the certificates will be
|
|
103
103
|
established to an SSL-enabled Flower server.
|
|
104
|
+
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
|
|
105
|
+
Client authentication is not supported for this transport type.
|
|
104
106
|
|
|
105
107
|
Returns
|
|
106
108
|
-------
|
|
@@ -123,6 +125,8 @@ def grpc_connection( # pylint: disable=R0913, R0915
|
|
|
123
125
|
"""
|
|
124
126
|
if isinstance(root_certificates, str):
|
|
125
127
|
root_certificates = Path(root_certificates).read_bytes()
|
|
128
|
+
if authentication_keys is not None:
|
|
129
|
+
log(ERROR, "Client authentication is not supported for this transport type.")
|
|
126
130
|
|
|
127
131
|
channel = create_channel(
|
|
128
132
|
server_address=server_address,
|
|
@@ -31,11 +31,11 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
|
31
31
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
32
32
|
CreateNodeRequest,
|
|
33
33
|
DeleteNodeRequest,
|
|
34
|
-
GetRunRequest,
|
|
35
34
|
PingRequest,
|
|
36
35
|
PullTaskInsRequest,
|
|
37
36
|
PushTaskResRequest,
|
|
38
37
|
)
|
|
38
|
+
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
|
39
39
|
|
|
40
40
|
_PUBLIC_KEY_HEADER = "public-key"
|
|
41
41
|
_AUTH_TOKEN_HEADER = "auth-token"
|
|
@@ -44,8 +44,6 @@ from flwr.common.serde import message_from_taskins, message_to_taskres
|
|
|
44
44
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
45
45
|
CreateNodeRequest,
|
|
46
46
|
DeleteNodeRequest,
|
|
47
|
-
GetRunRequest,
|
|
48
|
-
GetRunResponse,
|
|
49
47
|
PingRequest,
|
|
50
48
|
PingResponse,
|
|
51
49
|
PullTaskInsRequest,
|
|
@@ -53,9 +51,11 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
53
51
|
)
|
|
54
52
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
55
53
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
54
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
56
55
|
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
57
56
|
|
|
58
57
|
from .client_interceptor import AuthenticateClientInterceptor
|
|
58
|
+
from .grpc_adapter import GrpcAdapter
|
|
59
59
|
|
|
60
60
|
|
|
61
61
|
def on_channel_state_change(channel_connectivity: str) -> None:
|
|
@@ -73,7 +73,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
73
73
|
authentication_keys: Optional[
|
|
74
74
|
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
75
75
|
] = None,
|
|
76
|
-
adapter_cls: Optional[Type[FleetStub]] = None,
|
|
76
|
+
adapter_cls: Optional[Union[Type[FleetStub], Type[GrpcAdapter]]] = None,
|
|
77
77
|
) -> Iterator[
|
|
78
78
|
Tuple[
|
|
79
79
|
Callable[[], Optional[Message]],
|
|
@@ -107,6 +107,11 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
107
107
|
Path of the root certificate. If provided, a secure
|
|
108
108
|
connection using the certificates will be established to an SSL-enabled
|
|
109
109
|
Flower server. Bytes won't work for the REST API.
|
|
110
|
+
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
|
|
111
|
+
Tuple containing the elliptic curve private key and public key for
|
|
112
|
+
authentication from the cryptography library.
|
|
113
|
+
Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
|
|
114
|
+
Used to establish an authenticated connection with the server.
|
|
110
115
|
|
|
111
116
|
Returns
|
|
112
117
|
-------
|
|
@@ -114,6 +119,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
114
119
|
send : Callable
|
|
115
120
|
create_node : Optional[Callable]
|
|
116
121
|
delete_node : Optional[Callable]
|
|
122
|
+
get_run : Optional[Callable]
|
|
117
123
|
"""
|
|
118
124
|
if isinstance(root_certificates, str):
|
|
119
125
|
root_certificates = Path(root_certificates).read_bytes()
|
|
@@ -193,8 +199,6 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
|
|
|
193
199
|
|
|
194
200
|
# Stop the ping-loop thread
|
|
195
201
|
ping_stop_event.set()
|
|
196
|
-
if ping_thread is not None:
|
|
197
|
-
ping_thread.join()
|
|
198
202
|
|
|
199
203
|
# Call FleetAPI
|
|
200
204
|
delete_node_request = DeleteNodeRequest(node=node)
|