flwr-nightly 1.10.0.dev20240612__py3-none-any.whl → 1.10.0.dev20240624__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package's registry page for more details.

Files changed (130)
  1. flwr/cli/app.py +3 -0
  2. flwr/cli/build.py +6 -8
  3. flwr/cli/config_utils.py +53 -3
  4. flwr/cli/install.py +35 -20
  5. flwr/cli/new/new.py +104 -28
  6. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  7. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  8. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
  9. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
  10. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  11. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  12. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  13. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  14. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
  16. flwr/cli/run/run.py +46 -2
  17. flwr/client/__init__.py +1 -1
  18. flwr/client/app.py +22 -10
  19. flwr/client/client_app.py +1 -1
  20. flwr/client/dpfedavg_numpy_client.py +1 -1
  21. flwr/client/grpc_adapter_client/__init__.py +15 -0
  22. flwr/client/grpc_adapter_client/connection.py +94 -0
  23. flwr/client/grpc_client/connection.py +5 -1
  24. flwr/client/grpc_rere_client/__init__.py +1 -1
  25. flwr/client/grpc_rere_client/connection.py +9 -2
  26. flwr/client/grpc_rere_client/grpc_adapter.py +133 -0
  27. flwr/client/message_handler/__init__.py +1 -1
  28. flwr/client/message_handler/message_handler.py +1 -1
  29. flwr/client/mod/__init__.py +4 -4
  30. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  31. flwr/client/mod/utils.py +1 -1
  32. flwr/client/rest_client/__init__.py +1 -1
  33. flwr/client/rest_client/connection.py +10 -2
  34. flwr/client/supernode/app.py +141 -41
  35. flwr/common/__init__.py +12 -12
  36. flwr/common/address.py +1 -1
  37. flwr/common/config.py +73 -0
  38. flwr/common/constant.py +16 -1
  39. flwr/common/date.py +1 -1
  40. flwr/common/dp.py +1 -1
  41. flwr/common/grpc.py +1 -1
  42. flwr/common/object_ref.py +39 -5
  43. flwr/common/record/__init__.py +1 -1
  44. flwr/common/secure_aggregation/__init__.py +1 -1
  45. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  46. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  47. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  48. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  49. flwr/common/secure_aggregation/quantization.py +1 -1
  50. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  51. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  52. flwr/common/telemetry.py +4 -0
  53. flwr/common/typing.py +9 -0
  54. flwr/common/version.py +14 -0
  55. flwr/proto/exec_pb2.py +34 -0
  56. flwr/proto/exec_pb2.pyi +55 -0
  57. flwr/proto/exec_pb2_grpc.py +101 -0
  58. flwr/proto/exec_pb2_grpc.pyi +41 -0
  59. flwr/proto/fab_pb2.py +30 -0
  60. flwr/proto/fab_pb2.pyi +56 -0
  61. flwr/proto/fab_pb2_grpc.py +4 -0
  62. flwr/proto/fab_pb2_grpc.pyi +4 -0
  63. flwr/server/__init__.py +2 -2
  64. flwr/server/app.py +62 -25
  65. flwr/server/compat/app.py +1 -1
  66. flwr/server/compat/app_utils.py +1 -1
  67. flwr/server/compat/driver_client_proxy.py +1 -1
  68. flwr/server/driver/driver.py +6 -0
  69. flwr/server/driver/grpc_driver.py +85 -63
  70. flwr/server/driver/inmemory_driver.py +28 -26
  71. flwr/server/run_serverapp.py +65 -20
  72. flwr/server/strategy/__init__.py +2 -2
  73. flwr/server/strategy/bulyan.py +1 -1
  74. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  75. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  76. flwr/server/strategy/fedadagrad.py +1 -1
  77. flwr/server/strategy/fedadam.py +1 -1
  78. flwr/server/strategy/fedavg_android.py +1 -1
  79. flwr/server/strategy/fedavgm.py +1 -1
  80. flwr/server/strategy/fedmedian.py +1 -1
  81. flwr/server/strategy/fedopt.py +1 -1
  82. flwr/server/strategy/fedprox.py +1 -1
  83. flwr/server/strategy/fedxgb_bagging.py +1 -1
  84. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  85. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  86. flwr/server/strategy/fedyogi.py +1 -1
  87. flwr/server/strategy/krum.py +1 -1
  88. flwr/server/strategy/qfedavg.py +1 -1
  89. flwr/server/superlink/driver/__init__.py +1 -1
  90. flwr/server/superlink/driver/driver_grpc.py +1 -1
  91. flwr/server/superlink/driver/driver_servicer.py +15 -3
  92. flwr/server/superlink/fleet/__init__.py +1 -1
  93. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  94. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  95. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  96. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  97. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  98. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  99. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -1
  100. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  101. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  102. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  103. flwr/server/superlink/fleet/message_handler/message_handler.py +4 -4
  104. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  105. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  106. flwr/server/superlink/fleet/vce/backend/raybackend.py +44 -25
  107. flwr/server/superlink/fleet/vce/vce_api.py +3 -1
  108. flwr/server/superlink/state/__init__.py +1 -1
  109. flwr/server/superlink/state/in_memory_state.py +9 -6
  110. flwr/server/superlink/state/sqlite_state.py +7 -4
  111. flwr/server/superlink/state/state.py +6 -5
  112. flwr/server/superlink/state/state_factory.py +11 -2
  113. flwr/server/utils/__init__.py +1 -1
  114. flwr/server/utils/tensorboard.py +1 -1
  115. flwr/simulation/__init__.py +5 -2
  116. flwr/simulation/app.py +1 -1
  117. flwr/simulation/ray_transport/__init__.py +1 -1
  118. flwr/simulation/ray_transport/ray_actor.py +0 -6
  119. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  120. flwr/simulation/run_simulation.py +63 -22
  121. flwr/superexec/__init__.py +21 -0
  122. flwr/superexec/app.py +178 -0
  123. flwr/superexec/exec_grpc.py +51 -0
  124. flwr/superexec/exec_servicer.py +65 -0
  125. flwr/superexec/executor.py +54 -0
  126. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/METADATA +2 -1
  127. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/RECORD +130 -101
  128. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/entry_points.txt +1 -0
  129. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/LICENSE +0 -0
  130. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/WHEEL +0 -0
flwr/proto/fab_pb2.pyi ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
5
+ import builtins
6
+ import google.protobuf.descriptor
7
+ import google.protobuf.message
8
+ import typing
9
+ import typing_extensions
10
+
11
+ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
12
+
13
+ class Fab(google.protobuf.message.Message):
14
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
15
+ HASH_FIELD_NUMBER: builtins.int
16
+ CONTENT_FIELD_NUMBER: builtins.int
17
+ hash: typing.Text
18
+ """This field is the hash of the data field. It is used to identify the data.
19
+ The hash is calculated using the SHA-256 algorithm and is represented as a
20
+ hex string (sha256hex).
21
+ """
22
+
23
+ content: builtins.bytes
24
+ """This field contains the fab file contents a one bytes blob."""
25
+
26
+ def __init__(self,
27
+ *,
28
+ hash: typing.Text = ...,
29
+ content: builtins.bytes = ...,
30
+ ) -> None: ...
31
+ def ClearField(self, field_name: typing_extensions.Literal["content",b"content","hash",b"hash"]) -> None: ...
32
+ global___Fab = Fab
33
+
34
+ class GetFabRequest(google.protobuf.message.Message):
35
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
36
+ HASH_FIELD_NUMBER: builtins.int
37
+ hash: typing.Text
38
+ def __init__(self,
39
+ *,
40
+ hash: typing.Text = ...,
41
+ ) -> None: ...
42
+ def ClearField(self, field_name: typing_extensions.Literal["hash",b"hash"]) -> None: ...
43
+ global___GetFabRequest = GetFabRequest
44
+
45
+ class GetFabResponse(google.protobuf.message.Message):
46
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
47
+ FAB_FIELD_NUMBER: builtins.int
48
+ @property
49
+ def fab(self) -> global___Fab: ...
50
+ def __init__(self,
51
+ *,
52
+ fab: typing.Optional[global___Fab] = ...,
53
+ ) -> None: ...
54
+ def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ...
55
+ def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> None: ...
56
+ global___GetFabResponse = GetFabResponse
@@ -0,0 +1,4 @@
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+
@@ -0,0 +1,4 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
flwr/server/__init__.py CHANGED
@@ -34,12 +34,12 @@ __all__ = [
34
34
  "Driver",
35
35
  "History",
36
36
  "LegacyContext",
37
- "run_server_app",
38
- "run_superlink",
39
37
  "Server",
40
38
  "ServerApp",
41
39
  "ServerConfig",
42
40
  "SimpleClientManager",
41
+ "run_server_app",
42
+ "run_superlink",
43
43
  "start_server",
44
44
  "strategy",
45
45
  "workflow",
flwr/server/app.py CHANGED
@@ -36,6 +36,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
36
36
  from flwr.common.address import parse_address
37
37
  from flwr.common.constant import (
38
38
  MISSING_EXTRA_REST,
39
+ TRANSPORT_TYPE_GRPC_ADAPTER,
39
40
  TRANSPORT_TYPE_GRPC_RERE,
40
41
  TRANSPORT_TYPE_REST,
41
42
  )
@@ -48,6 +49,7 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
48
49
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
49
50
  add_FleetServicer_to_server,
50
51
  )
52
+ from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
51
53
 
52
54
  from .client_manager import ClientManager
53
55
  from .history import History
@@ -55,6 +57,7 @@ from .server import Server, init_defaults, run_fl
55
57
  from .server_config import ServerConfig
56
58
  from .strategy import Strategy
57
59
  from .superlink.driver.driver_grpc import run_driver_api_grpc
60
+ from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
58
61
  from .superlink.fleet.grpc_bidi.grpc_server import (
59
62
  generic_create_grpc_server,
60
63
  start_grpc_server,
@@ -200,15 +203,7 @@ def run_superlink() -> None:
200
203
  args = _parse_args_run_superlink().parse_args()
201
204
 
202
205
  # Parse IP address
203
- parsed_driver_address = parse_address(args.driver_api_address)
204
- if not parsed_driver_address:
205
- sys.exit(f"Driver IP address ({args.driver_api_address}) cannot be parsed.")
206
- driver_host, driver_port, driver_is_v6 = parsed_driver_address
207
- driver_address = (
208
- f"[{driver_host}]:{driver_port}"
209
- if driver_is_v6
210
- else f"{driver_host}:{driver_port}"
211
- )
206
+ driver_address, _, _ = _format_address(args.driver_api_address)
212
207
 
213
208
  # Obtain certificates
214
209
  certificates = _try_obtain_certificates(args)
@@ -226,18 +221,15 @@ def run_superlink() -> None:
226
221
  grpc_servers = [driver_server]
227
222
  bckg_threads = []
228
223
  if not args.fleet_api_address:
229
- args.fleet_api_address = (
230
- ADDRESS_FLEET_API_GRPC_RERE
231
- if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE
232
- else ADDRESS_FLEET_API_REST
233
- )
234
- parsed_fleet_address = parse_address(args.fleet_api_address)
235
- if not parsed_fleet_address:
236
- sys.exit(f"Fleet IP address ({args.fleet_api_address}) cannot be parsed.")
237
- fleet_host, fleet_port, fleet_is_v6 = parsed_fleet_address
238
- fleet_address = (
239
- f"[{fleet_host}]:{fleet_port}" if fleet_is_v6 else f"{fleet_host}:{fleet_port}"
240
- )
224
+ if args.fleet_api_type in [
225
+ TRANSPORT_TYPE_GRPC_RERE,
226
+ TRANSPORT_TYPE_GRPC_ADAPTER,
227
+ ]:
228
+ args.fleet_api_address = ADDRESS_FLEET_API_GRPC_RERE
229
+ elif args.fleet_api_type == TRANSPORT_TYPE_REST:
230
+ args.fleet_api_address = ADDRESS_FLEET_API_REST
231
+
232
+ fleet_address, host, port = _format_address(args.fleet_api_address)
241
233
 
242
234
  num_workers = args.fleet_api_num_workers
243
235
  if num_workers != 1:
@@ -267,8 +259,8 @@ def run_superlink() -> None:
267
259
  fleet_thread = threading.Thread(
268
260
  target=_run_fleet_api_rest,
269
261
  args=(
270
- fleet_host,
271
- fleet_port,
262
+ host,
263
+ port,
272
264
  ssl_keyfile,
273
265
  ssl_certfile,
274
266
  state_factory,
@@ -306,6 +298,13 @@ def run_superlink() -> None:
306
298
  interceptors=interceptors,
307
299
  )
308
300
  grpc_servers.append(fleet_server)
301
+ elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
302
+ fleet_server = _run_fleet_api_grpc_adapter(
303
+ address=fleet_address,
304
+ state_factory=state_factory,
305
+ certificates=certificates,
306
+ )
307
+ grpc_servers.append(fleet_server)
309
308
  else:
310
309
  raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
311
310
 
@@ -325,6 +324,16 @@ def run_superlink() -> None:
325
324
  driver_server.wait_for_termination(timeout=1)
326
325
 
327
326
 
327
+ def _format_address(address: str) -> Tuple[str, str, int]:
328
+ parsed_address = parse_address(address)
329
+ if not parsed_address:
330
+ sys.exit(
331
+ f"Address ({address}) cannot be parsed (expected: URL or IPv4 or IPv6)."
332
+ )
333
+ host, port, is_v6 = parsed_address
334
+ return (f"[{host}]:{port}" if is_v6 else f"{host}:{port}", host, port)
335
+
336
+
328
337
  def _try_setup_client_authentication(
329
338
  args: argparse.Namespace,
330
339
  certificates: Optional[Tuple[bytes, bytes, bytes]],
@@ -422,7 +431,7 @@ def _try_obtain_certificates(
422
431
  log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
423
432
  return None
424
433
  # Check if certificates are provided
425
- if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
434
+ if args.fleet_api_type in [TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_GRPC_ADAPTER]:
426
435
  if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
427
436
  if not isfile(args.ssl_ca_certfile):
428
437
  sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
@@ -494,6 +503,30 @@ def _run_fleet_api_grpc_rere(
494
503
  return fleet_grpc_server
495
504
 
496
505
 
506
+ def _run_fleet_api_grpc_adapter(
507
+ address: str,
508
+ state_factory: StateFactory,
509
+ certificates: Optional[Tuple[bytes, bytes, bytes]],
510
+ ) -> grpc.Server:
511
+ """Run Fleet API (GrpcAdapter)."""
512
+ # Create Fleet API gRPC server
513
+ fleet_servicer = GrpcAdapterServicer(
514
+ state_factory=state_factory,
515
+ )
516
+ fleet_add_servicer_to_server_fn = add_GrpcAdapterServicer_to_server
517
+ fleet_grpc_server = generic_create_grpc_server(
518
+ servicer_and_add_fn=(fleet_servicer, fleet_add_servicer_to_server_fn),
519
+ server_address=address,
520
+ max_message_length=GRPC_MAX_MESSAGE_LENGTH,
521
+ certificates=certificates,
522
+ )
523
+
524
+ log(INFO, "Flower ECE: Starting Fleet API (GrpcAdapter) on %s", address)
525
+ fleet_grpc_server.start()
526
+
527
+ return fleet_grpc_server
528
+
529
+
497
530
  # pylint: disable=import-outside-toplevel,too-many-arguments
498
531
  def _run_fleet_api_rest(
499
532
  host: str,
@@ -609,7 +642,11 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
609
642
  "--fleet-api-type",
610
643
  default=TRANSPORT_TYPE_GRPC_RERE,
611
644
  type=str,
612
- choices=[TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST],
645
+ choices=[
646
+ TRANSPORT_TYPE_GRPC_RERE,
647
+ TRANSPORT_TYPE_GRPC_ADAPTER,
648
+ TRANSPORT_TYPE_REST,
649
+ ],
613
650
  help="Start a gRPC-rere or REST (experimental) Fleet API server.",
614
651
  )
615
652
  parser.add_argument(
flwr/server/compat/app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -91,7 +91,7 @@ def _update_client_manager(
91
91
  node_id=node_id,
92
92
  driver=driver,
93
93
  anonymous=False,
94
- run_id=driver.run_id, # type: ignore
94
+ run_id=driver.run.run_id,
95
95
  )
96
96
  if client_manager.register(client_proxy):
97
97
  registered_nodes[node_id] = client_proxy
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -19,11 +19,17 @@ from abc import ABC, abstractmethod
19
19
  from typing import Iterable, List, Optional
20
20
 
21
21
  from flwr.common import Message, RecordSet
22
+ from flwr.common.typing import Run
22
23
 
23
24
 
24
25
  class Driver(ABC):
25
26
  """Abstract base Driver class for the Driver API."""
26
27
 
28
+ @property
29
+ @abstractmethod
30
+ def run(self) -> Run:
31
+ """Run information."""
32
+
27
33
  @abstractmethod
28
34
  def create_message( # pylint: disable=too-many-arguments
29
35
  self,
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
17
17
  import time
18
18
  import warnings
19
19
  from logging import DEBUG, ERROR, WARNING
20
- from typing import Iterable, List, Optional, Tuple
20
+ from typing import Iterable, List, Optional, Tuple, cast
21
21
 
22
22
  import grpc
23
23
 
@@ -25,6 +25,7 @@ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, ev
25
25
  from flwr.common.grpc import create_channel
26
26
  from flwr.common.logger import log
27
27
  from flwr.common.serde import message_from_taskres, message_to_taskins
28
+ from flwr.common.typing import Run
28
29
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
29
30
  CreateRunRequest,
30
31
  CreateRunResponse,
@@ -37,6 +38,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
37
38
  )
38
39
  from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
39
40
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
41
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
40
42
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
41
43
 
42
44
  from .driver import Driver
@@ -46,13 +48,24 @@ DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
46
48
  ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
47
49
  [Driver] Error: Not connected.
48
50
 
49
- Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
50
- `GrpcDriverHelper` methods.
51
+ Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
52
+ `GrpcDriverStub` methods.
51
53
  """
52
54
 
53
55
 
54
- class GrpcDriverHelper:
55
- """`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
56
+ class GrpcDriverStub:
57
+ """`GrpcDriverStub` provides access to the gRPC Driver API/service.
58
+
59
+ Parameters
60
+ ----------
61
+ driver_service_address : Optional[str]
62
+ The IPv4 or IPv6 address of the Driver API server.
63
+ Defaults to `"[::]:9091"`.
64
+ root_certificates : Optional[bytes] (default: None)
65
+ The PEM-encoded root certificates as a byte string.
66
+ If provided, a secure connection using the certificates will be
67
+ established to an SSL-enabled Flower server.
68
+ """
56
69
 
57
70
  def __init__(
58
71
  self,
@@ -64,6 +77,10 @@ class GrpcDriverHelper:
64
77
  self.channel: Optional[grpc.Channel] = None
65
78
  self.stub: Optional[DriverStub] = None
66
79
 
80
+ def is_connected(self) -> bool:
81
+ """Return True if connected to the Driver API server, otherwise False."""
82
+ return self.channel is not None
83
+
67
84
  def connect(self) -> None:
68
85
  """Connect to the Driver API."""
69
86
  event(EventType.DRIVER_CONNECT)
@@ -95,18 +112,29 @@ class GrpcDriverHelper:
95
112
  # Check if channel is open
96
113
  if self.stub is None:
97
114
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
98
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
115
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
99
116
 
100
117
  # Call Driver API
101
118
  res: CreateRunResponse = self.stub.CreateRun(request=req)
102
119
  return res
103
120
 
121
+ def get_run(self, req: GetRunRequest) -> GetRunResponse:
122
+ """Get run information."""
123
+ # Check if channel is open
124
+ if self.stub is None:
125
+ log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
126
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
127
+
128
+ # Call gRPC Driver API
129
+ res: GetRunResponse = self.stub.GetRun(request=req)
130
+ return res
131
+
104
132
  def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
105
133
  """Get client IDs."""
106
134
  # Check if channel is open
107
135
  if self.stub is None:
108
136
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
109
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
137
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
110
138
 
111
139
  # Call gRPC Driver API
112
140
  res: GetNodesResponse = self.stub.GetNodes(request=req)
@@ -117,7 +145,7 @@ class GrpcDriverHelper:
117
145
  # Check if channel is open
118
146
  if self.stub is None:
119
147
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
120
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
148
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
121
149
 
122
150
  # Call gRPC Driver API
123
151
  res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
@@ -128,7 +156,7 @@ class GrpcDriverHelper:
128
156
  # Check if channel is open
129
157
  if self.stub is None:
130
158
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
131
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
159
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
132
160
 
133
161
  # Call Driver API
134
162
  res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
@@ -140,56 +168,52 @@ class GrpcDriver(Driver):
140
168
 
141
169
  Parameters
142
170
  ----------
143
- driver_service_address : Optional[str]
144
- The IPv4 or IPv6 address of the Driver API server.
145
- Defaults to `"[::]:9091"`.
146
- certificates : bytes (default: None)
147
- Tuple containing root certificate, server certificate, and private key
148
- to start a secure SSL-enabled server. The tuple is expected to have
149
- three bytes elements in the following order:
150
-
151
- * CA certificate.
152
- * server certificate.
153
- * server private key.
154
- fab_id : str (default: None)
155
- The identifier of the FAB used in the run.
156
- fab_version : str (default: None)
157
- The version of the FAB used in the run.
171
+ run_id : int
172
+ The identifier of the run.
173
+ stub : Optional[GrpcDriverStub] (default: None)
174
+ The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
175
+ If None, an instance connected to "[::]:9091" will be created.
158
176
  """
159
177
 
160
- def __init__(
178
+ def __init__( # pylint: disable=too-many-arguments
161
179
  self,
162
- driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
163
- root_certificates: Optional[bytes] = None,
164
- fab_id: Optional[str] = None,
165
- fab_version: Optional[str] = None,
180
+ run_id: int,
181
+ stub: Optional[GrpcDriverStub] = None,
166
182
  ) -> None:
167
- self.addr = driver_service_address
168
- self.root_certificates = root_certificates
169
- self.driver_helper: Optional[GrpcDriverHelper] = None
170
- self.run_id: Optional[int] = None
171
- self.fab_id = fab_id if fab_id is not None else ""
172
- self.fab_version = fab_version if fab_version is not None else ""
183
+ self._run_id = run_id
184
+ self._run: Optional[Run] = None
185
+ self.stub = stub if stub is not None else GrpcDriverStub()
173
186
  self.node = Node(node_id=0, anonymous=True)
174
187
 
175
- def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
176
- # Check if the GrpcDriverHelper is initialized
177
- if self.driver_helper is None or self.run_id is None:
178
- # Connect and create run
179
- self.driver_helper = GrpcDriverHelper(
180
- driver_service_address=self.addr,
181
- root_certificates=self.root_certificates,
188
+ @property
189
+ def run(self) -> Run:
190
+ """Run information."""
191
+ self._get_stub_and_run_id()
192
+ return Run(**vars(cast(Run, self._run)))
193
+
194
+ def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
195
+ # Check if is initialized
196
+ if self._run is None:
197
+ # Connect
198
+ if not self.stub.is_connected():
199
+ self.stub.connect()
200
+ # Get the run info
201
+ req = GetRunRequest(run_id=self._run_id)
202
+ res = self.stub.get_run(req)
203
+ if not res.HasField("run"):
204
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
205
+ self._run = Run(
206
+ run_id=res.run.run_id,
207
+ fab_id=res.run.fab_id,
208
+ fab_version=res.run.fab_version,
182
209
  )
183
- self.driver_helper.connect()
184
- req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version)
185
- res = self.driver_helper.create_run(req)
186
- self.run_id = res.run_id
187
- return self.driver_helper, self.run_id
210
+
211
+ return self.stub, self._run.run_id
188
212
 
189
213
  def _check_message(self, message: Message) -> None:
190
214
  # Check if the message is valid
191
215
  if not (
192
- message.metadata.run_id == self.run_id
216
+ message.metadata.run_id == cast(Run, self._run).run_id
193
217
  and message.metadata.src_node_id == self.node.node_id
194
218
  and message.metadata.message_id == ""
195
219
  and message.metadata.reply_to_message == ""
@@ -210,7 +234,7 @@ class GrpcDriver(Driver):
210
234
  This method constructs a new `Message` with given content and metadata.
211
235
  The `run_id` and `src_node_id` will be set automatically.
212
236
  """
213
- _, run_id = self._get_grpc_driver_helper_and_run_id()
237
+ _, run_id = self._get_stub_and_run_id()
214
238
  if ttl:
215
239
  warnings.warn(
216
240
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -234,9 +258,9 @@ class GrpcDriver(Driver):
234
258
 
235
259
  def get_node_ids(self) -> List[int]:
236
260
  """Get node IDs."""
237
- grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
238
- # Call GrpcDriverHelper method
239
- res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
261
+ stub, run_id = self._get_stub_and_run_id()
262
+ # Call GrpcDriverStub method
263
+ res = stub.get_nodes(GetNodesRequest(run_id=run_id))
240
264
  return [node.node_id for node in res.nodes]
241
265
 
242
266
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
@@ -245,7 +269,7 @@ class GrpcDriver(Driver):
245
269
  This method takes an iterable of messages and sends each message
246
270
  to the node specified in `dst_node_id`.
247
271
  """
248
- grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
272
+ stub, _ = self._get_stub_and_run_id()
249
273
  # Construct TaskIns
250
274
  task_ins_list: List[TaskIns] = []
251
275
  for msg in messages:
@@ -255,10 +279,8 @@ class GrpcDriver(Driver):
255
279
  taskins = message_to_taskins(msg)
256
280
  # Add to list
257
281
  task_ins_list.append(taskins)
258
- # Call GrpcDriverHelper method
259
- res = grpc_driver_helper.push_task_ins(
260
- PushTaskInsRequest(task_ins_list=task_ins_list)
261
- )
282
+ # Call GrpcDriverStub method
283
+ res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
262
284
  return list(res.task_ids)
263
285
 
264
286
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
@@ -267,9 +289,9 @@ class GrpcDriver(Driver):
267
289
  This method is used to collect messages from the SuperLink that correspond to a
268
290
  set of given message IDs.
269
291
  """
270
- grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
292
+ stub, _ = self._get_stub_and_run_id()
271
293
  # Pull TaskRes
272
- res = grpc_driver.pull_task_res(
294
+ res = stub.pull_task_res(
273
295
  PullTaskResRequest(node=self.node, task_ids=message_ids)
274
296
  )
275
297
  # Convert TaskRes to Message
@@ -308,8 +330,8 @@ class GrpcDriver(Driver):
308
330
 
309
331
  def close(self) -> None:
310
332
  """Disconnect from the SuperLink if connected."""
311
- # Check if GrpcDriverHelper is initialized
312
- if self.driver_helper is None:
333
+ # Check if `connect` was called before
334
+ if not self.stub.is_connected():
313
335
  return
314
336
  # Disconnect
315
- self.driver_helper.disconnect()
337
+ self.stub.disconnect()
@@ -17,11 +17,12 @@
17
17
 
18
18
  import time
19
19
  import warnings
20
- from typing import Iterable, List, Optional
20
+ from typing import Iterable, List, Optional, cast
21
21
  from uuid import UUID
22
22
 
23
23
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
24
24
  from flwr.common.serde import message_from_taskres, message_to_taskins
25
+ from flwr.common.typing import Run
25
26
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
26
27
  from flwr.server.superlink.state import StateFactory
27
28
 
@@ -33,30 +34,27 @@ class InMemoryDriver(Driver):
33
34
 
34
35
  Parameters
35
36
  ----------
37
+ run_id : int
38
+ The identifier of the run.
36
39
  state_factory : StateFactory
37
40
  A StateFactory embedding a state that this driver can interface with.
38
- fab_id : str (default: None)
39
- The identifier of the FAB used in the run.
40
- fab_version : str (default: None)
41
- The version of the FAB used in the run.
42
41
  """
43
42
 
44
43
  def __init__(
45
44
  self,
45
+ run_id: int,
46
46
  state_factory: StateFactory,
47
- fab_id: Optional[str] = None,
48
- fab_version: Optional[str] = None,
49
47
  ) -> None:
50
- self.run_id: Optional[int] = None
51
- self.fab_id = fab_id if fab_id is not None else ""
52
- self.fab_version = fab_version if fab_version is not None else ""
53
- self.node = Node(node_id=0, anonymous=True)
48
+ self._run_id = run_id
49
+ self._run: Optional[Run] = None
54
50
  self.state = state_factory.state()
51
+ self.node = Node(node_id=0, anonymous=True)
55
52
 
56
53
  def _check_message(self, message: Message) -> None:
54
+ self._init_run()
57
55
  # Check if the message is valid
58
56
  if not (
59
- message.metadata.run_id == self.run_id
57
+ message.metadata.run_id == cast(Run, self._run).run_id
60
58
  and message.metadata.src_node_id == self.node.node_id
61
59
  and message.metadata.message_id == ""
62
60
  and message.metadata.reply_to_message == ""
@@ -64,16 +62,20 @@ class InMemoryDriver(Driver):
64
62
  ):
65
63
  raise ValueError(f"Invalid message: {message}")
66
64
 
67
- def _get_run_id(self) -> int:
68
- """Return run_id.
69
-
70
- If unset, create a new run.
71
- """
72
- if self.run_id is None:
73
- self.run_id = self.state.create_run(
74
- fab_id=self.fab_id, fab_version=self.fab_version
75
- )
76
- return self.run_id
65
+ def _init_run(self) -> None:
66
+ """Initialize the run."""
67
+ if self._run is not None:
68
+ return
69
+ run = self.state.get_run(self._run_id)
70
+ if run is None:
71
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
72
+ self._run = run
73
+
74
+ @property
75
+ def run(self) -> Run:
76
+ """Run ID."""
77
+ self._init_run()
78
+ return Run(**vars(cast(Run, self._run)))
77
79
 
78
80
  def create_message( # pylint: disable=too-many-arguments
79
81
  self,
@@ -88,7 +90,7 @@ class InMemoryDriver(Driver):
88
90
  This method constructs a new `Message` with given content and metadata.
89
91
  The `run_id` and `src_node_id` will be set automatically.
90
92
  """
91
- run_id = self._get_run_id()
93
+ self._init_run()
92
94
  if ttl:
93
95
  warnings.warn(
94
96
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -99,7 +101,7 @@ class InMemoryDriver(Driver):
99
101
  ttl_ = DEFAULT_TTL if ttl is None else ttl
100
102
 
101
103
  metadata = Metadata(
102
- run_id=run_id,
104
+ run_id=cast(Run, self._run).run_id,
103
105
  message_id="", # Will be set by the server
104
106
  src_node_id=self.node.node_id,
105
107
  dst_node_id=dst_node_id,
@@ -112,8 +114,8 @@ class InMemoryDriver(Driver):
112
114
 
113
115
  def get_node_ids(self) -> List[int]:
114
116
  """Get node IDs."""
115
- run_id = self._get_run_id()
116
- return list(self.state.get_nodes(run_id))
117
+ self._init_run()
118
+ return list(self.state.get_nodes(cast(Run, self._run).run_id))
117
119
 
118
120
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
119
121
  """Push messages to specified node IDs.