flwr-nightly 1.9.0.dev20240520__py3-none-any.whl → 1.10.0.dev20240612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (53)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +4 -19
  3. flwr/cli/config_utils.py +12 -27
  4. flwr/cli/install.py +196 -0
  5. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +7 -1
  6. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +7 -1
  8. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -1
  9. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -1
  10. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +7 -1
  11. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -1
  12. flwr/cli/run/run.py +20 -4
  13. flwr/cli/utils.py +14 -0
  14. flwr/client/__init__.py +1 -0
  15. flwr/client/app.py +135 -97
  16. flwr/client/client_app.py +1 -1
  17. flwr/client/grpc_rere_client/client_interceptor.py +1 -1
  18. flwr/client/grpc_rere_client/connection.py +6 -6
  19. flwr/client/mod/__init__.py +1 -1
  20. flwr/client/rest_client/connection.py +1 -2
  21. flwr/client/supernode/app.py +70 -28
  22. flwr/common/object_ref.py +13 -9
  23. flwr/common/recordset_compat.py +8 -1
  24. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +0 -15
  25. flwr/proto/driver_pb2.py +20 -19
  26. flwr/proto/driver_pb2_grpc.py +35 -0
  27. flwr/proto/driver_pb2_grpc.pyi +14 -0
  28. flwr/proto/fleet_pb2.py +28 -33
  29. flwr/proto/fleet_pb2.pyi +0 -42
  30. flwr/proto/fleet_pb2_grpc.py +7 -6
  31. flwr/proto/fleet_pb2_grpc.pyi +5 -4
  32. flwr/proto/grpcadapter_pb2.py +32 -0
  33. flwr/proto/grpcadapter_pb2.pyi +43 -0
  34. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  35. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  36. flwr/proto/run_pb2.py +30 -0
  37. flwr/proto/run_pb2.pyi +52 -0
  38. flwr/proto/run_pb2_grpc.py +4 -0
  39. flwr/proto/run_pb2_grpc.pyi +4 -0
  40. flwr/server/__init__.py +0 -4
  41. flwr/server/app.py +190 -395
  42. flwr/server/run_serverapp.py +29 -5
  43. flwr/server/server_app.py +2 -2
  44. flwr/server/superlink/driver/driver_servicer.py +7 -0
  45. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -2
  46. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -2
  47. flwr/server/superlink/fleet/message_handler/message_handler.py +5 -3
  48. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  49. {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/METADATA +4 -3
  50. {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/RECORD +53 -44
  51. {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/entry_points.txt +0 -2
  52. {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/LICENSE +0 -0
  53. {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/WHEEL +0 -0
flwr/client/app.py CHANGED
@@ -14,8 +14,10 @@
14
14
  # ==============================================================================
15
15
  """Flower client app."""
16
16
 
17
+ import signal
17
18
  import sys
18
19
  import time
20
+ from dataclasses import dataclass
19
21
  from logging import DEBUG, ERROR, INFO, WARN
20
22
  from typing import Callable, ContextManager, Optional, Tuple, Type, Union
21
23
 
@@ -37,7 +39,7 @@ from flwr.common.constant import (
37
39
  )
38
40
  from flwr.common.logger import log, warn_deprecated_feature
39
41
  from flwr.common.message import Error
40
- from flwr.common.retry_invoker import RetryInvoker, exponential
42
+ from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
41
43
 
42
44
  from .grpc_client.connection import grpc_connection
43
45
  from .grpc_rere_client.connection import grpc_request_response
@@ -263,6 +265,29 @@ def _start_client_internal(
263
265
  transport, server_address
264
266
  )
265
267
 
268
+ app_state_tracker = _AppStateTracker()
269
+
270
+ def _on_sucess(retry_state: RetryState) -> None:
271
+ app_state_tracker.is_connected = True
272
+ if retry_state.tries > 1:
273
+ log(
274
+ INFO,
275
+ "Connection successful after %.2f seconds and %s tries.",
276
+ retry_state.elapsed_time,
277
+ retry_state.tries,
278
+ )
279
+
280
+ def _on_backoff(retry_state: RetryState) -> None:
281
+ app_state_tracker.is_connected = False
282
+ if retry_state.tries == 1:
283
+ log(WARN, "Connection attempt failed, retrying...")
284
+ else:
285
+ log(
286
+ DEBUG,
287
+ "Connection attempt failed, retrying in %.2f seconds",
288
+ retry_state.actual_wait,
289
+ )
290
+
266
291
  retry_invoker = RetryInvoker(
267
292
  wait_gen_factory=exponential,
268
293
  recoverable_exceptions=connection_error_type,
@@ -278,30 +303,13 @@ def _start_client_internal(
278
303
  if retry_state.tries > 1
279
304
  else None
280
305
  ),
281
- on_success=lambda retry_state: (
282
- log(
283
- INFO,
284
- "Connection successful after %.2f seconds and %s tries.",
285
- retry_state.elapsed_time,
286
- retry_state.tries,
287
- )
288
- if retry_state.tries > 1
289
- else None
290
- ),
291
- on_backoff=lambda retry_state: (
292
- log(WARN, "Connection attempt failed, retrying...")
293
- if retry_state.tries == 1
294
- else log(
295
- DEBUG,
296
- "Connection attempt failed, retrying in %.2f seconds",
297
- retry_state.actual_wait,
298
- )
299
- ),
306
+ on_success=_on_sucess,
307
+ on_backoff=_on_backoff,
300
308
  )
301
309
 
302
310
  node_state = NodeState()
303
311
 
304
- while True:
312
+ while not app_state_tracker.interrupt:
305
313
  sleep_duration: int = 0
306
314
  with connection(
307
315
  address,
@@ -318,99 +326,112 @@ def _start_client_internal(
318
326
  if create_node is not None:
319
327
  create_node() # pylint: disable=not-callable
320
328
 
321
- while True:
322
- # Receive
323
- message = receive()
324
- if message is None:
325
- time.sleep(3) # Wait for 3s before asking again
326
- continue
327
-
328
- log(INFO, "")
329
- if len(message.metadata.group_id) > 0:
329
+ app_state_tracker.register_signal_handler()
330
+ while not app_state_tracker.interrupt:
331
+ try:
332
+ # Receive
333
+ message = receive()
334
+ if message is None:
335
+ time.sleep(3) # Wait for 3s before asking again
336
+ continue
337
+
338
+ log(INFO, "")
339
+ if len(message.metadata.group_id) > 0:
340
+ log(
341
+ INFO,
342
+ "[RUN %s, ROUND %s]",
343
+ message.metadata.run_id,
344
+ message.metadata.group_id,
345
+ )
330
346
  log(
331
347
  INFO,
332
- "[RUN %s, ROUND %s]",
333
- message.metadata.run_id,
334
- message.metadata.group_id,
348
+ "Received: %s message %s",
349
+ message.metadata.message_type,
350
+ message.metadata.message_id,
335
351
  )
336
- log(
337
- INFO,
338
- "Received: %s message %s",
339
- message.metadata.message_type,
340
- message.metadata.message_id,
341
- )
342
-
343
- # Handle control message
344
- out_message, sleep_duration = handle_control_message(message)
345
- if out_message:
346
- send(out_message)
347
- break
348
-
349
- # Register context for this run
350
- node_state.register_context(run_id=message.metadata.run_id)
351
-
352
- # Retrieve context for this run
353
- context = node_state.retrieve_context(run_id=message.metadata.run_id)
354
352
 
355
- # Create an error reply message that will never be used to prevent
356
- # the used-before-assignment linting error
357
- reply_message = message.create_error_reply(
358
- error=Error(code=ErrorCode.UNKNOWN, reason="Unknown")
359
- )
353
+ # Handle control message
354
+ out_message, sleep_duration = handle_control_message(message)
355
+ if out_message:
356
+ send(out_message)
357
+ break
360
358
 
361
- # Handle app loading and task message
362
- try:
363
- # Load ClientApp instance
364
- client_app: ClientApp = load_client_app_fn()
365
-
366
- # Execute ClientApp
367
- reply_message = client_app(message=message, context=context)
368
- except Exception as ex: # pylint: disable=broad-exception-caught
369
-
370
- # Legacy grpc-bidi
371
- if transport in ["grpc-bidi", None]:
372
- log(ERROR, "Client raised an exception.", exc_info=ex)
373
- # Raise exception, crash process
374
- raise ex
375
-
376
- # Don't update/change NodeState
377
-
378
- e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
379
- # Reason example: "<class 'ZeroDivisionError'>:<'division by zero'>"
380
- reason = str(type(ex)) + ":<'" + str(ex) + "'>"
381
- exc_entity = "ClientApp"
382
- if isinstance(ex, LoadClientAppError):
383
- reason = (
384
- "An exception was raised when attempting to load "
385
- "`ClientApp`"
386
- )
387
- e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
388
- exc_entity = "SuperNode"
359
+ # Register context for this run
360
+ node_state.register_context(run_id=message.metadata.run_id)
389
361
 
390
- log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
362
+ # Retrieve context for this run
363
+ context = node_state.retrieve_context(
364
+ run_id=message.metadata.run_id
365
+ )
391
366
 
392
- # Create error message
367
+ # Create an error reply message that will never be used to prevent
368
+ # the used-before-assignment linting error
393
369
  reply_message = message.create_error_reply(
394
- error=Error(code=e_code, reason=reason)
395
- )
396
- else:
397
- # No exception, update node state
398
- node_state.update_context(
399
- run_id=message.metadata.run_id,
400
- context=context,
370
+ error=Error(code=ErrorCode.UNKNOWN, reason="Unknown")
401
371
  )
402
372
 
403
- # Send
404
- send(reply_message)
405
- log(INFO, "Sent reply")
373
+ # Handle app loading and task message
374
+ try:
375
+ # Load ClientApp instance
376
+ client_app: ClientApp = load_client_app_fn()
377
+
378
+ # Execute ClientApp
379
+ reply_message = client_app(message=message, context=context)
380
+ except Exception as ex: # pylint: disable=broad-exception-caught
381
+
382
+ # Legacy grpc-bidi
383
+ if transport in ["grpc-bidi", None]:
384
+ log(ERROR, "Client raised an exception.", exc_info=ex)
385
+ # Raise exception, crash process
386
+ raise ex
387
+
388
+ # Don't update/change NodeState
389
+
390
+ e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
391
+ # Ex fmt: "<class 'ZeroDivisionError'>:<'division by zero'>"
392
+ reason = str(type(ex)) + ":<'" + str(ex) + "'>"
393
+ exc_entity = "ClientApp"
394
+ if isinstance(ex, LoadClientAppError):
395
+ reason = (
396
+ "An exception was raised when attempting to load "
397
+ "`ClientApp`"
398
+ )
399
+ e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
400
+ exc_entity = "SuperNode"
401
+
402
+ if not app_state_tracker.interrupt:
403
+ log(
404
+ ERROR, "%s raised an exception", exc_entity, exc_info=ex
405
+ )
406
+
407
+ # Create error message
408
+ reply_message = message.create_error_reply(
409
+ error=Error(code=e_code, reason=reason)
410
+ )
411
+ else:
412
+ # No exception, update node state
413
+ node_state.update_context(
414
+ run_id=message.metadata.run_id,
415
+ context=context,
416
+ )
417
+
418
+ # Send
419
+ send(reply_message)
420
+ log(INFO, "Sent reply")
421
+
422
+ except StopIteration:
423
+ sleep_duration = 0
424
+ break
406
425
 
407
426
  # Unregister node
408
- if delete_node is not None:
427
+ if delete_node is not None and app_state_tracker.is_connected:
409
428
  delete_node() # pylint: disable=not-callable
410
429
 
411
430
  if sleep_duration == 0:
412
431
  log(INFO, "Disconnect and shut down")
432
+ del app_state_tracker
413
433
  break
434
+
414
435
  # Sleep and reconnect afterwards
415
436
  log(
416
437
  INFO,
@@ -579,3 +600,20 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
579
600
  )
580
601
 
581
602
  return connection, address, error_type
603
+
604
+
605
+ @dataclass
606
+ class _AppStateTracker:
607
+ interrupt: bool = False
608
+ is_connected: bool = False
609
+
610
+ def register_signal_handler(self) -> None:
611
+ """Register handlers for exit signals."""
612
+
613
+ def signal_handler(sig, frame): # type: ignore
614
+ # pylint: disable=unused-argument
615
+ self.interrupt = True
616
+ raise StopIteration from None
617
+
618
+ signal.signal(signal.SIGINT, signal_handler)
619
+ signal.signal(signal.SIGTERM, signal_handler)
flwr/client/client_app.py CHANGED
@@ -221,7 +221,7 @@ def _registration_error(fn_name: str) -> ValueError:
221
221
  >>> def client_fn(cid) -> Client:
222
222
  >>> return FlowerClient().to_client()
223
223
  >>>
224
- >>> app = ClientApp()
224
+ >>> app = ClientApp(
225
225
  >>> client_fn=client_fn,
226
226
  >>> )
227
227
 
flwr/client/grpc_rere_client/client_interceptor.py CHANGED
@@ -31,11 +31,11 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
31
31
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
32
  CreateNodeRequest,
33
33
  DeleteNodeRequest,
34
- GetRunRequest,
35
34
  PingRequest,
36
35
  PullTaskInsRequest,
37
36
  PushTaskResRequest,
38
37
  )
38
+ from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
39
39
 
40
40
  _PUBLIC_KEY_HEADER = "public-key"
41
41
  _AUTH_TOKEN_HEADER = "auth-token"
flwr/client/grpc_rere_client/connection.py CHANGED
@@ -21,7 +21,7 @@ from contextlib import contextmanager
21
21
  from copy import copy
22
22
  from logging import DEBUG, ERROR
23
23
  from pathlib import Path
24
- from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast
24
+ from typing import Callable, Iterator, Optional, Sequence, Tuple, Type, Union, cast
25
25
 
26
26
  import grpc
27
27
  from cryptography.hazmat.primitives.asymmetric import ec
@@ -44,8 +44,6 @@ from flwr.common.serde import message_from_taskins, message_to_taskres
44
44
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
45
  CreateNodeRequest,
46
46
  DeleteNodeRequest,
47
- GetRunRequest,
48
- GetRunResponse,
49
47
  PingRequest,
50
48
  PingResponse,
51
49
  PullTaskInsRequest,
@@ -53,6 +51,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
53
51
  )
54
52
  from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
55
53
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
54
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
56
55
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
57
56
 
58
57
  from .client_interceptor import AuthenticateClientInterceptor
@@ -73,6 +72,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
73
72
  authentication_keys: Optional[
74
73
  Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
75
74
  ] = None,
75
+ adapter_cls: Optional[Type[FleetStub]] = None,
76
76
  ) -> Iterator[
77
77
  Tuple[
78
78
  Callable[[], Optional[Message]],
@@ -133,7 +133,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
133
133
  channel.subscribe(on_channel_state_change)
134
134
 
135
135
  # Shared variables for inner functions
136
- stub = FleetStub(channel)
136
+ if adapter_cls is None:
137
+ adapter_cls = FleetStub
138
+ stub = adapter_cls(channel)
137
139
  metadata: Optional[Metadata] = None
138
140
  node: Optional[Node] = None
139
141
  ping_thread: Optional[threading.Thread] = None
@@ -190,8 +192,6 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
190
192
 
191
193
  # Stop the ping-loop thread
192
194
  ping_stop_event.set()
193
- if ping_thread is not None:
194
- ping_thread.join()
195
195
 
196
196
  # Call FleetAPI
197
197
  delete_node_request = DeleteNodeRequest(node=node)
flwr/client/mod/__init__.py CHANGED
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Mods."""
15
+ """Flower Built-in Mods."""
16
16
 
17
17
 
18
18
  from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
flwr/client/rest_client/connection.py CHANGED
@@ -46,8 +46,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
46
46
  CreateNodeResponse,
47
47
  DeleteNodeRequest,
48
48
  DeleteNodeResponse,
49
- GetRunRequest,
50
- GetRunResponse,
51
49
  PingRequest,
52
50
  PingResponse,
53
51
  PullTaskInsRequest,
@@ -56,6 +54,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
56
54
  PushTaskResResponse,
57
55
  )
58
56
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
57
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
59
58
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
60
59
 
61
60
  try:
flwr/client/supernode/app.py CHANGED
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO, WARN
20
20
  from pathlib import Path
21
21
  from typing import Callable, Optional, Tuple
22
22
 
23
+ from cryptography.exceptions import UnsupportedAlgorithm
23
24
  from cryptography.hazmat.primitives.asymmetric import ec
24
25
  from cryptography.hazmat.primitives.serialization import (
25
26
  load_ssh_private_key,
@@ -29,14 +30,13 @@ from cryptography.hazmat.primitives.serialization import (
29
30
  from flwr.client.client_app import ClientApp, LoadClientAppError
30
31
  from flwr.common import EventType, event
31
32
  from flwr.common.exit_handlers import register_exit_handlers
32
- from flwr.common.logger import log
33
+ from flwr.common.logger import log, warn_deprecated_feature
33
34
  from flwr.common.object_ref import load_app, validate
34
- from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
35
- ssh_types_to_elliptic_curve,
36
- )
37
35
 
38
36
  from ..app import _start_client_internal
39
37
 
38
+ ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
39
+
40
40
 
41
41
  def run_supernode() -> None:
42
42
  """Run Flower SuperNode."""
@@ -65,6 +65,23 @@ def run_client_app() -> None:
65
65
 
66
66
  args = _parse_args_run_client_app().parse_args()
67
67
 
68
+ if args.server != ADDRESS_FLEET_API_GRPC_RERE:
69
+ warn = "Passing flag --server is deprecated. Use --superlink instead."
70
+ warn_deprecated_feature(warn)
71
+
72
+ if args.superlink != ADDRESS_FLEET_API_GRPC_RERE:
73
+ # if `--superlink` also passed, then
74
+ # warn user that this argument overrides what was passed with `--server`
75
+ log(
76
+ WARN,
77
+ "Both `--server` and `--superlink` were passed. "
78
+ "`--server` will be ignored. Connecting to the Superlink Fleet API "
79
+ "at %s.",
80
+ args.superlink,
81
+ )
82
+ else:
83
+ args.superlink = args.server
84
+
68
85
  root_certificates = _get_certificates(args)
69
86
  log(
70
87
  DEBUG,
@@ -75,7 +92,7 @@ def run_client_app() -> None:
75
92
  authentication_keys = _try_setup_client_authentication(args)
76
93
 
77
94
  _start_client_internal(
78
- server_address=args.server,
95
+ server_address=args.superlink,
79
96
  load_client_app_fn=load_fn,
80
97
  transport="rest" if args.rest else "grpc-rere",
81
98
  root_certificates=root_certificates,
@@ -102,7 +119,7 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
102
119
  WARN,
103
120
  "Option `--insecure` was set. "
104
121
  "Starting insecure HTTP client connected to %s.",
105
- args.server,
122
+ args.superlink,
106
123
  )
107
124
  root_certificates = None
108
125
  else:
@@ -116,7 +133,7 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
116
133
  DEBUG,
117
134
  "Starting secure HTTPS client connected to %s "
118
135
  "with the following certificates: %s.",
119
- args.server,
136
+ args.superlink,
120
137
  cert_path,
121
138
  )
122
139
  return root_certificates
@@ -215,9 +232,14 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
215
232
  )
216
233
  parser.add_argument(
217
234
  "--server",
218
- default="0.0.0.0:9092",
235
+ default=ADDRESS_FLEET_API_GRPC_RERE,
219
236
  help="Server address",
220
237
  )
238
+ parser.add_argument(
239
+ "--superlink",
240
+ default=ADDRESS_FLEET_API_GRPC_RERE,
241
+ help="SuperLink Fleet API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
242
+ )
221
243
  parser.add_argument(
222
244
  "--max-retries",
223
245
  type=int,
@@ -242,40 +264,60 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
242
264
  " Default: current working directory.",
243
265
  )
244
266
  parser.add_argument(
245
- "--authentication-keys",
246
- nargs=2,
247
- metavar=("CLIENT_PRIVATE_KEY", "CLIENT_PUBLIC_KEY"),
267
+ "--auth-supernode-private-key",
268
+ type=str,
269
+ help="The SuperNode's private key (as a path str) to enable authentication.",
270
+ )
271
+ parser.add_argument(
272
+ "--auth-supernode-public-key",
248
273
  type=str,
249
- help="Provide two file paths: (1) the client's private "
250
- "key file, and (2) the client's public key file.",
274
+ help="The SuperNode's public key (as a path str) to enable authentication.",
251
275
  )
252
276
 
253
277
 
254
278
  def _try_setup_client_authentication(
255
279
  args: argparse.Namespace,
256
280
  ) -> Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
257
- if not args.authentication_keys:
281
+ if not args.auth_supernode_private_key and not args.auth_supernode_public_key:
258
282
  return None
259
283
 
260
- ssh_private_key = load_ssh_private_key(
261
- Path(args.authentication_keys[0]).read_bytes(),
262
- None,
263
- )
264
- ssh_public_key = load_ssh_public_key(Path(args.authentication_keys[1]).read_bytes())
284
+ if not args.auth_supernode_private_key or not args.auth_supernode_public_key:
285
+ sys.exit(
286
+ "Authentication requires file paths to both "
287
+ "'--auth-supernode-private-key' and '--auth-supernode-public-key'"
288
+ "to be provided (providing only one of them is not sufficient)."
289
+ )
290
+
291
+ try:
292
+ ssh_private_key = load_ssh_private_key(
293
+ Path(args.auth_supernode_private_key).read_bytes(),
294
+ None,
295
+ )
296
+ if not isinstance(ssh_private_key, ec.EllipticCurvePrivateKey):
297
+ raise ValueError()
298
+ except (ValueError, UnsupportedAlgorithm):
299
+ sys.exit(
300
+ "Error: Unable to parse the private key file in "
301
+ "'--auth-supernode-private-key'. Authentication requires elliptic "
302
+ "curve private and public key pair. Please ensure that the file "
303
+ "path points to a valid private key file and try again."
304
+ )
265
305
 
266
306
  try:
267
- client_private_key, client_public_key = ssh_types_to_elliptic_curve(
268
- ssh_private_key, ssh_public_key
307
+ ssh_public_key = load_ssh_public_key(
308
+ Path(args.auth_supernode_public_key).read_bytes()
269
309
  )
270
- except TypeError:
310
+ if not isinstance(ssh_public_key, ec.EllipticCurvePublicKey):
311
+ raise ValueError()
312
+ except (ValueError, UnsupportedAlgorithm):
271
313
  sys.exit(
272
- "The file paths provided could not be read as a private and public "
273
- "key pair. Client authentication requires an elliptic curve public and "
274
- "private key pair. Please provide the file paths containing elliptic "
275
- "curve private and public keys to '--authentication-keys'."
314
+ "Error: Unable to parse the public key file in "
315
+ "'--auth-supernode-public-key'. Authentication requires elliptic "
316
+ "curve private and public key pair. Please ensure that the file "
317
+ "path points to a valid public key file and try again."
276
318
  )
277
319
 
278
320
  return (
279
- client_private_key,
280
- client_public_key,
321
+ ssh_private_key,
322
+ ssh_public_key,
281
323
  )
flwr/common/object_ref.py CHANGED
@@ -30,6 +30,7 @@ attribute.
30
30
 
31
31
  def validate(
32
32
  module_attribute_str: str,
33
+ check_module: bool = True,
33
34
  ) -> Tuple[bool, Optional[str]]:
34
35
  """Validate object reference.
35
36
 
@@ -56,15 +57,18 @@ def validate(
56
57
  f"Missing attribute in {module_attribute_str}{OBJECT_REF_HELP_STR}",
57
58
  )
58
59
 
59
- # Load module
60
- module = find_spec(module_str)
61
- if module and module.origin:
62
- if not _find_attribute_in_module(module.origin, attributes_str):
63
- return (
64
- False,
65
- f"Unable to find attribute {attributes_str} in module {module_str}"
66
- f"{OBJECT_REF_HELP_STR}",
67
- )
60
+ if check_module:
61
+ # Load module
62
+ module = find_spec(module_str)
63
+ if module and module.origin:
64
+ if not _find_attribute_in_module(module.origin, attributes_str):
65
+ return (
66
+ False,
67
+ f"Unable to find attribute {attributes_str} in module {module_str}"
68
+ f"{OBJECT_REF_HELP_STR}",
69
+ )
70
+ return (True, None)
71
+ else:
68
72
  return (True, None)
69
73
 
70
74
  return (
flwr/common/recordset_compat.py CHANGED
@@ -35,6 +35,8 @@ from .typing import (
35
35
  Status,
36
36
  )
37
37
 
38
+ EMPTY_TENSOR_KEY = "_empty"
39
+
38
40
 
39
41
  def parametersrecord_to_parameters(
40
42
  record: ParametersRecord, keep_input: bool
@@ -59,7 +61,8 @@ def parametersrecord_to_parameters(
59
61
  parameters = Parameters(tensors=[], tensor_type="")
60
62
 
61
63
  for key in list(record.keys()):
62
- parameters.tensors.append(record[key].data)
64
+ if key != EMPTY_TENSOR_KEY:
65
+ parameters.tensors.append(record[key].data)
63
66
 
64
67
  if not parameters.tensor_type:
65
68
  # Setting from first array in record. Recall the warning in the docstrings
@@ -103,6 +106,10 @@ def parameters_to_parametersrecord(
103
106
  data=tensor, dtype="", stype=tensor_type, shape=[]
104
107
  )
105
108
 
109
+ if num_arrays == 0:
110
+ ordered_dict[EMPTY_TENSOR_KEY] = Array(
111
+ data=b"", dtype="", stype=tensor_type, shape=[]
112
+ )
106
113
  return ParametersRecord(ordered_dict, keep_input=keep_input)
107
114
 
108
115
 
flwr/common/secure_aggregation/crypto/symmetric_encryption.py CHANGED
@@ -117,18 +117,3 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
117
117
  return True
118
118
  except InvalidSignature:
119
119
  return False
120
-
121
-
122
- def ssh_types_to_elliptic_curve(
123
- private_key: serialization.SSHPrivateKeyTypes,
124
- public_key: serialization.SSHPublicKeyTypes,
125
- ) -> Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]:
126
- """Cast SSH key types to elliptic curve."""
127
- if isinstance(private_key, ec.EllipticCurvePrivateKey) and isinstance(
128
- public_key, ec.EllipticCurvePublicKey
129
- ):
130
- return (private_key, public_key)
131
-
132
- raise TypeError(
133
- "The provided key is not an EllipticCurvePrivateKey or EllipticCurvePublicKey"
134
- )