flwr-nightly 1.9.0.dev20240531__py3-none-any.whl → 1.10.0.dev20240619__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (80)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +4 -15
  3. flwr/cli/config_utils.py +64 -7
  4. flwr/cli/install.py +211 -0
  5. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  12. flwr/cli/run/run.py +39 -2
  13. flwr/cli/utils.py +14 -0
  14. flwr/client/__init__.py +1 -0
  15. flwr/client/app.py +153 -103
  16. flwr/client/client_app.py +1 -1
  17. flwr/client/grpc_adapter_client/__init__.py +15 -0
  18. flwr/client/grpc_adapter_client/connection.py +94 -0
  19. flwr/client/grpc_client/connection.py +5 -1
  20. flwr/client/grpc_rere_client/client_interceptor.py +1 -1
  21. flwr/client/grpc_rere_client/connection.py +9 -5
  22. flwr/client/grpc_rere_client/grpc_adapter.py +133 -0
  23. flwr/client/mod/__init__.py +4 -4
  24. flwr/client/rest_client/connection.py +10 -3
  25. flwr/client/supernode/app.py +155 -31
  26. flwr/common/__init__.py +12 -12
  27. flwr/common/config.py +71 -0
  28. flwr/common/constant.py +15 -0
  29. flwr/common/object_ref.py +52 -14
  30. flwr/common/record/__init__.py +1 -1
  31. flwr/common/telemetry.py +4 -0
  32. flwr/common/typing.py +9 -0
  33. flwr/proto/driver_pb2.py +20 -19
  34. flwr/proto/driver_pb2_grpc.py +35 -0
  35. flwr/proto/driver_pb2_grpc.pyi +14 -0
  36. flwr/proto/exec_pb2.py +34 -0
  37. flwr/proto/exec_pb2.pyi +55 -0
  38. flwr/proto/exec_pb2_grpc.py +101 -0
  39. flwr/proto/exec_pb2_grpc.pyi +41 -0
  40. flwr/proto/fab_pb2.py +30 -0
  41. flwr/proto/fab_pb2.pyi +56 -0
  42. flwr/proto/fab_pb2_grpc.py +4 -0
  43. flwr/proto/fab_pb2_grpc.pyi +4 -0
  44. flwr/proto/fleet_pb2.py +28 -33
  45. flwr/proto/fleet_pb2.pyi +0 -42
  46. flwr/proto/fleet_pb2_grpc.py +7 -6
  47. flwr/proto/fleet_pb2_grpc.pyi +5 -4
  48. flwr/proto/run_pb2.py +30 -0
  49. flwr/proto/run_pb2.pyi +52 -0
  50. flwr/proto/run_pb2_grpc.py +4 -0
  51. flwr/proto/run_pb2_grpc.pyi +4 -0
  52. flwr/server/__init__.py +2 -6
  53. flwr/server/app.py +94 -214
  54. flwr/server/run_serverapp.py +33 -7
  55. flwr/server/server_app.py +2 -2
  56. flwr/server/strategy/__init__.py +2 -2
  57. flwr/server/superlink/driver/driver_servicer.py +7 -0
  58. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  59. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  60. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +4 -0
  61. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -2
  62. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -2
  63. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -6
  64. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  65. flwr/server/superlink/fleet/vce/vce_api.py +3 -1
  66. flwr/server/superlink/state/in_memory_state.py +8 -5
  67. flwr/server/superlink/state/sqlite_state.py +6 -3
  68. flwr/server/superlink/state/state.py +5 -4
  69. flwr/simulation/__init__.py +4 -1
  70. flwr/simulation/run_simulation.py +22 -0
  71. flwr/superexec/__init__.py +21 -0
  72. flwr/superexec/app.py +178 -0
  73. flwr/superexec/exec_grpc.py +51 -0
  74. flwr/superexec/exec_servicer.py +65 -0
  75. flwr/superexec/executor.py +54 -0
  76. {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/METADATA +1 -1
  77. {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/RECORD +80 -56
  78. {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/entry_points.txt +1 -2
  79. {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/LICENSE +0 -0
  80. {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/WHEEL +0 -0
flwr/client/app.py CHANGED
@@ -14,10 +14,12 @@
14
14
  # ==============================================================================
15
15
  """Flower client app."""
16
16
 
17
+ import signal
17
18
  import sys
18
19
  import time
20
+ from dataclasses import dataclass
19
21
  from logging import DEBUG, ERROR, INFO, WARN
20
- from typing import Callable, ContextManager, Optional, Tuple, Type, Union
22
+ from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
21
23
 
22
24
  from cryptography.hazmat.primitives.asymmetric import ec
23
25
  from grpc import RpcError
@@ -29,6 +31,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
29
31
  from flwr.common.address import parse_address
30
32
  from flwr.common.constant import (
31
33
  MISSING_EXTRA_REST,
34
+ TRANSPORT_TYPE_GRPC_ADAPTER,
32
35
  TRANSPORT_TYPE_GRPC_BIDI,
33
36
  TRANSPORT_TYPE_GRPC_RERE,
34
37
  TRANSPORT_TYPE_REST,
@@ -37,8 +40,9 @@ from flwr.common.constant import (
37
40
  )
38
41
  from flwr.common.logger import log, warn_deprecated_feature
39
42
  from flwr.common.message import Error
40
- from flwr.common.retry_invoker import RetryInvoker, exponential
43
+ from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
41
44
 
45
+ from .grpc_adapter_client.connection import grpc_adapter
42
46
  from .grpc_client.connection import grpc_connection
43
47
  from .grpc_rere_client.connection import grpc_request_response
44
48
  from .message_handler.message_handler import handle_control_message
@@ -175,7 +179,7 @@ def start_client(
175
179
  def _start_client_internal(
176
180
  *,
177
181
  server_address: str,
178
- load_client_app_fn: Optional[Callable[[], ClientApp]] = None,
182
+ load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
179
183
  client_fn: Optional[ClientFn] = None,
180
184
  client: Optional[Client] = None,
181
185
  grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
@@ -250,7 +254,7 @@ def _start_client_internal(
250
254
 
251
255
  client_fn = single_client_factory
252
256
 
253
- def _load_client_app() -> ClientApp:
257
+ def _load_client_app(_1: str, _2: str) -> ClientApp:
254
258
  return ClientApp(client_fn=client_fn)
255
259
 
256
260
  load_client_app_fn = _load_client_app
@@ -263,6 +267,29 @@ def _start_client_internal(
263
267
  transport, server_address
264
268
  )
265
269
 
270
+ app_state_tracker = _AppStateTracker()
271
+
272
+ def _on_sucess(retry_state: RetryState) -> None:
273
+ app_state_tracker.is_connected = True
274
+ if retry_state.tries > 1:
275
+ log(
276
+ INFO,
277
+ "Connection successful after %.2f seconds and %s tries.",
278
+ retry_state.elapsed_time,
279
+ retry_state.tries,
280
+ )
281
+
282
+ def _on_backoff(retry_state: RetryState) -> None:
283
+ app_state_tracker.is_connected = False
284
+ if retry_state.tries == 1:
285
+ log(WARN, "Connection attempt failed, retrying...")
286
+ else:
287
+ log(
288
+ DEBUG,
289
+ "Connection attempt failed, retrying in %.2f seconds",
290
+ retry_state.actual_wait,
291
+ )
292
+
266
293
  retry_invoker = RetryInvoker(
267
294
  wait_gen_factory=exponential,
268
295
  recoverable_exceptions=connection_error_type,
@@ -278,30 +305,15 @@ def _start_client_internal(
278
305
  if retry_state.tries > 1
279
306
  else None
280
307
  ),
281
- on_success=lambda retry_state: (
282
- log(
283
- INFO,
284
- "Connection successful after %.2f seconds and %s tries.",
285
- retry_state.elapsed_time,
286
- retry_state.tries,
287
- )
288
- if retry_state.tries > 1
289
- else None
290
- ),
291
- on_backoff=lambda retry_state: (
292
- log(WARN, "Connection attempt failed, retrying...")
293
- if retry_state.tries == 1
294
- else log(
295
- DEBUG,
296
- "Connection attempt failed, retrying in %.2f seconds",
297
- retry_state.actual_wait,
298
- )
299
- ),
308
+ on_success=_on_sucess,
309
+ on_backoff=_on_backoff,
300
310
  )
301
311
 
302
312
  node_state = NodeState()
313
+ # run_id -> (fab_id, fab_version)
314
+ run_info: Dict[int, Tuple[str, str]] = {}
303
315
 
304
- while True:
316
+ while not app_state_tracker.interrupt:
305
317
  sleep_duration: int = 0
306
318
  with connection(
307
319
  address,
@@ -311,106 +323,125 @@ def _start_client_internal(
311
323
  root_certificates,
312
324
  authentication_keys,
313
325
  ) as conn:
314
- # pylint: disable-next=W0612
315
326
  receive, send, create_node, delete_node, get_run = conn
316
327
 
317
328
  # Register node
318
329
  if create_node is not None:
319
330
  create_node() # pylint: disable=not-callable
320
331
 
321
- while True:
322
- # Receive
323
- message = receive()
324
- if message is None:
325
- time.sleep(3) # Wait for 3s before asking again
326
- continue
327
-
328
- log(INFO, "")
329
- if len(message.metadata.group_id) > 0:
332
+ app_state_tracker.register_signal_handler()
333
+ while not app_state_tracker.interrupt:
334
+ try:
335
+ # Receive
336
+ message = receive()
337
+ if message is None:
338
+ time.sleep(3) # Wait for 3s before asking again
339
+ continue
340
+
341
+ log(INFO, "")
342
+ if len(message.metadata.group_id) > 0:
343
+ log(
344
+ INFO,
345
+ "[RUN %s, ROUND %s]",
346
+ message.metadata.run_id,
347
+ message.metadata.group_id,
348
+ )
330
349
  log(
331
350
  INFO,
332
- "[RUN %s, ROUND %s]",
333
- message.metadata.run_id,
334
- message.metadata.group_id,
351
+ "Received: %s message %s",
352
+ message.metadata.message_type,
353
+ message.metadata.message_id,
335
354
  )
336
- log(
337
- INFO,
338
- "Received: %s message %s",
339
- message.metadata.message_type,
340
- message.metadata.message_id,
341
- )
342
-
343
- # Handle control message
344
- out_message, sleep_duration = handle_control_message(message)
345
- if out_message:
346
- send(out_message)
347
- break
348
-
349
- # Register context for this run
350
- node_state.register_context(run_id=message.metadata.run_id)
351
-
352
- # Retrieve context for this run
353
- context = node_state.retrieve_context(run_id=message.metadata.run_id)
354
355
 
355
- # Create an error reply message that will never be used to prevent
356
- # the used-before-assignment linting error
357
- reply_message = message.create_error_reply(
358
- error=Error(code=ErrorCode.UNKNOWN, reason="Unknown")
359
- )
356
+ # Handle control message
357
+ out_message, sleep_duration = handle_control_message(message)
358
+ if out_message:
359
+ send(out_message)
360
+ break
361
+
362
+ # Get run info
363
+ run_id = message.metadata.run_id
364
+ if run_id not in run_info:
365
+ if get_run is not None:
366
+ run_info[run_id] = get_run(run_id)
367
+ # If get_run is None, i.e., in grpc-bidi mode
368
+ else:
369
+ run_info[run_id] = ("", "")
370
+
371
+ # Register context for this run
372
+ node_state.register_context(run_id=run_id)
373
+
374
+ # Retrieve context for this run
375
+ context = node_state.retrieve_context(run_id=run_id)
376
+
377
+ # Create an error reply message that will never be used to prevent
378
+ # the used-before-assignment linting error
379
+ reply_message = message.create_error_reply(
380
+ error=Error(code=ErrorCode.UNKNOWN, reason="Unknown")
381
+ )
360
382
 
361
- # Handle app loading and task message
362
- try:
363
- # Load ClientApp instance
364
- client_app: ClientApp = load_client_app_fn()
365
-
366
- # Execute ClientApp
367
- reply_message = client_app(message=message, context=context)
368
- except Exception as ex: # pylint: disable=broad-exception-caught
369
-
370
- # Legacy grpc-bidi
371
- if transport in ["grpc-bidi", None]:
372
- log(ERROR, "Client raised an exception.", exc_info=ex)
373
- # Raise exception, crash process
374
- raise ex
375
-
376
- # Don't update/change NodeState
377
-
378
- e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
379
- # Reason example: "<class 'ZeroDivisionError'>:<'division by zero'>"
380
- reason = str(type(ex)) + ":<'" + str(ex) + "'>"
381
- exc_entity = "ClientApp"
382
- if isinstance(ex, LoadClientAppError):
383
- reason = (
384
- "An exception was raised when attempting to load "
385
- "`ClientApp`"
383
+ # Handle app loading and task message
384
+ try:
385
+ # Load ClientApp instance
386
+ client_app: ClientApp = load_client_app_fn(*run_info[run_id])
387
+
388
+ # Execute ClientApp
389
+ reply_message = client_app(message=message, context=context)
390
+ except Exception as ex: # pylint: disable=broad-exception-caught
391
+
392
+ # Legacy grpc-bidi
393
+ if transport in ["grpc-bidi", None]:
394
+ log(ERROR, "Client raised an exception.", exc_info=ex)
395
+ # Raise exception, crash process
396
+ raise ex
397
+
398
+ # Don't update/change NodeState
399
+
400
+ e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
401
+ # Ex fmt: "<class 'ZeroDivisionError'>:<'division by zero'>"
402
+ reason = str(type(ex)) + ":<'" + str(ex) + "'>"
403
+ exc_entity = "ClientApp"
404
+ if isinstance(ex, LoadClientAppError):
405
+ reason = (
406
+ "An exception was raised when attempting to load "
407
+ "`ClientApp`"
408
+ )
409
+ e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
410
+ exc_entity = "SuperNode"
411
+
412
+ if not app_state_tracker.interrupt:
413
+ log(
414
+ ERROR, "%s raised an exception", exc_entity, exc_info=ex
415
+ )
416
+
417
+ # Create error message
418
+ reply_message = message.create_error_reply(
419
+ error=Error(code=e_code, reason=reason)
420
+ )
421
+ else:
422
+ # No exception, update node state
423
+ node_state.update_context(
424
+ run_id=run_id,
425
+ context=context,
386
426
  )
387
- e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
388
- exc_entity = "SuperNode"
389
-
390
- log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
391
427
 
392
- # Create error message
393
- reply_message = message.create_error_reply(
394
- error=Error(code=e_code, reason=reason)
395
- )
396
- else:
397
- # No exception, update node state
398
- node_state.update_context(
399
- run_id=message.metadata.run_id,
400
- context=context,
401
- )
428
+ # Send
429
+ send(reply_message)
430
+ log(INFO, "Sent reply")
402
431
 
403
- # Send
404
- send(reply_message)
405
- log(INFO, "Sent reply")
432
+ except StopIteration:
433
+ sleep_duration = 0
434
+ break
406
435
 
407
436
  # Unregister node
408
- if delete_node is not None:
437
+ if delete_node is not None and app_state_tracker.is_connected:
409
438
  delete_node() # pylint: disable=not-callable
410
439
 
411
440
  if sleep_duration == 0:
412
441
  log(INFO, "Disconnect and shut down")
442
+ del app_state_tracker
413
443
  break
444
+
414
445
  # Sleep and reconnect afterwards
415
446
  log(
416
447
  INFO,
@@ -571,6 +602,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
571
602
  connection, error_type = http_request_response, RequestsConnectionError
572
603
  elif transport == TRANSPORT_TYPE_GRPC_RERE:
573
604
  connection, error_type = grpc_request_response, RpcError
605
+ elif transport == TRANSPORT_TYPE_GRPC_ADAPTER:
606
+ connection, error_type = grpc_adapter, RpcError
574
607
  elif transport == TRANSPORT_TYPE_GRPC_BIDI:
575
608
  connection, error_type = grpc_connection, RpcError
576
609
  else:
@@ -579,3 +612,20 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
579
612
  )
580
613
 
581
614
  return connection, address, error_type
615
+
616
+
617
+ @dataclass
618
+ class _AppStateTracker:
619
+ interrupt: bool = False
620
+ is_connected: bool = False
621
+
622
+ def register_signal_handler(self) -> None:
623
+ """Register handlers for exit signals."""
624
+
625
+ def signal_handler(sig, frame): # type: ignore
626
+ # pylint: disable=unused-argument
627
+ self.interrupt = True
628
+ raise StopIteration from None
629
+
630
+ signal.signal(signal.SIGINT, signal_handler)
631
+ signal.signal(signal.SIGTERM, signal_handler)
flwr/client/client_app.py CHANGED
@@ -221,7 +221,7 @@ def _registration_error(fn_name: str) -> ValueError:
221
221
  >>> def client_fn(cid) -> Client:
222
222
  >>> return FlowerClient().to_client()
223
223
  >>>
224
- >>> app = ClientApp()
224
+ >>> app = ClientApp(
225
225
  >>> client_fn=client_fn,
226
226
  >>> )
227
227
 
flwr/client/grpc_adapter_client/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Client-side part of the GrpcAdapter transport layer."""
flwr/client/grpc_adapter_client/connection.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Contextmanager for a GrpcAdapter channel to the Flower server."""
16
+
17
+
18
+ from contextlib import contextmanager
19
+ from logging import ERROR
20
+ from typing import Callable, Iterator, Optional, Tuple, Union
21
+
22
+ from cryptography.hazmat.primitives.asymmetric import ec
23
+
24
+ from flwr.client.grpc_rere_client.connection import grpc_request_response
25
+ from flwr.client.grpc_rere_client.grpc_adapter import GrpcAdapter
26
+ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
+ from flwr.common.logger import log
28
+ from flwr.common.message import Message
29
+ from flwr.common.retry_invoker import RetryInvoker
30
+
31
+
32
+ @contextmanager
33
+ def grpc_adapter( # pylint: disable=R0913
34
+ server_address: str,
35
+ insecure: bool,
36
+ retry_invoker: RetryInvoker,
37
+ max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
38
+ root_certificates: Optional[Union[bytes, str]] = None,
39
+ authentication_keys: Optional[ # pylint: disable=unused-argument
40
+ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
41
+ ] = None,
42
+ ) -> Iterator[
43
+ Tuple[
44
+ Callable[[], Optional[Message]],
45
+ Callable[[Message], None],
46
+ Optional[Callable[[], None]],
47
+ Optional[Callable[[], None]],
48
+ Optional[Callable[[int], Tuple[str, str]]],
49
+ ]
50
+ ]:
51
+ """Primitives for request/response-based interaction with a server via GrpcAdapter.
52
+
53
+ Parameters
54
+ ----------
55
+ server_address : str
56
+ The IPv6 address of the server with `http://` or `https://`.
57
+ If the Flower server runs on the same machine
58
+ on port 8080, then `server_address` would be `"http://[::]:8080"`.
59
+ insecure : bool
60
+ Starts an insecure gRPC connection when True. Enables HTTPS connection
61
+ when False, using system certificates if `root_certificates` is None.
62
+ retry_invoker: RetryInvoker
63
+ `RetryInvoker` object that will try to reconnect the client to the server
64
+ after gRPC errors. If None, the client will only try to
65
+ reconnect once after a failure.
66
+ max_message_length : int
67
+ Ignored, only present to preserve API-compatibility.
68
+ root_certificates : Optional[Union[bytes, str]] (default: None)
69
+ Path of the root certificate. If provided, a secure
70
+ connection using the certificates will be established to an SSL-enabled
71
+ Flower server. Bytes won't work for the REST API.
72
+ authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
73
+ Client authentication is not supported for this transport type.
74
+
75
+ Returns
76
+ -------
77
+ receive : Callable
78
+ send : Callable
79
+ create_node : Optional[Callable]
80
+ delete_node : Optional[Callable]
81
+ get_run : Optional[Callable]
82
+ """
83
+ if authentication_keys is not None:
84
+ log(ERROR, "Client authentication is not supported for this transport type.")
85
+ with grpc_request_response(
86
+ server_address=server_address,
87
+ insecure=insecure,
88
+ retry_invoker=retry_invoker,
89
+ max_message_length=max_message_length,
90
+ root_certificates=root_certificates,
91
+ authentication_keys=None, # Authentication is not supported
92
+ adapter_cls=GrpcAdapter,
93
+ ) as conn:
94
+ yield conn
flwr/client/grpc_client/connection.py CHANGED
@@ -17,7 +17,7 @@
17
17
 
18
18
  import uuid
19
19
  from contextlib import contextmanager
20
- from logging import DEBUG
20
+ from logging import DEBUG, ERROR
21
21
  from pathlib import Path
22
22
  from queue import Queue
23
23
  from typing import Callable, Iterator, Optional, Tuple, Union, cast
@@ -101,6 +101,8 @@ def grpc_connection( # pylint: disable=R0913, R0915
101
101
  The PEM-encoded root certificates as a byte string or a path string.
102
102
  If provided, a secure connection using the certificates will be
103
103
  established to an SSL-enabled Flower server.
104
+ authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
105
+ Client authentication is not supported for this transport type.
104
106
 
105
107
  Returns
106
108
  -------
@@ -123,6 +125,8 @@ def grpc_connection( # pylint: disable=R0913, R0915
123
125
  """
124
126
  if isinstance(root_certificates, str):
125
127
  root_certificates = Path(root_certificates).read_bytes()
128
+ if authentication_keys is not None:
129
+ log(ERROR, "Client authentication is not supported for this transport type.")
126
130
 
127
131
  channel = create_channel(
128
132
  server_address=server_address,
flwr/client/grpc_rere_client/client_interceptor.py CHANGED
@@ -31,11 +31,11 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
31
31
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
32
  CreateNodeRequest,
33
33
  DeleteNodeRequest,
34
- GetRunRequest,
35
34
  PingRequest,
36
35
  PullTaskInsRequest,
37
36
  PushTaskResRequest,
38
37
  )
38
+ from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
39
39
 
40
40
  _PUBLIC_KEY_HEADER = "public-key"
41
41
  _AUTH_TOKEN_HEADER = "auth-token"
flwr/client/grpc_rere_client/connection.py CHANGED
@@ -44,8 +44,6 @@ from flwr.common.serde import message_from_taskins, message_to_taskres
44
44
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
45
  CreateNodeRequest,
46
46
  DeleteNodeRequest,
47
- GetRunRequest,
48
- GetRunResponse,
49
47
  PingRequest,
50
48
  PingResponse,
51
49
  PullTaskInsRequest,
@@ -53,9 +51,11 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
53
51
  )
54
52
  from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
55
53
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
54
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
56
55
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
57
56
 
58
57
  from .client_interceptor import AuthenticateClientInterceptor
58
+ from .grpc_adapter import GrpcAdapter
59
59
 
60
60
 
61
61
  def on_channel_state_change(channel_connectivity: str) -> None:
@@ -73,7 +73,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
73
73
  authentication_keys: Optional[
74
74
  Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
75
75
  ] = None,
76
- adapter_cls: Optional[Type[FleetStub]] = None,
76
+ adapter_cls: Optional[Union[Type[FleetStub], Type[GrpcAdapter]]] = None,
77
77
  ) -> Iterator[
78
78
  Tuple[
79
79
  Callable[[], Optional[Message]],
@@ -107,6 +107,11 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
107
107
  Path of the root certificate. If provided, a secure
108
108
  connection using the certificates will be established to an SSL-enabled
109
109
  Flower server. Bytes won't work for the REST API.
110
+ authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
111
+ Tuple containing the elliptic curve private key and public key for
112
+ authentication from the cryptography library.
113
+ Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
114
+ Used to establish an authenticated connection with the server.
110
115
 
111
116
  Returns
112
117
  -------
@@ -114,6 +119,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
114
119
  send : Callable
115
120
  create_node : Optional[Callable]
116
121
  delete_node : Optional[Callable]
122
+ get_run : Optional[Callable]
117
123
  """
118
124
  if isinstance(root_certificates, str):
119
125
  root_certificates = Path(root_certificates).read_bytes()
@@ -193,8 +199,6 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
193
199
 
194
200
  # Stop the ping-loop thread
195
201
  ping_stop_event.set()
196
- if ping_thread is not None:
197
- ping_thread.join()
198
202
 
199
203
  # Call FleetAPI
200
204
  delete_node_request = DeleteNodeRequest(node=node)