flwr-nightly 1.10.0.dev20240619__py3-none-any.whl → 1.10.0.dev20240620__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (86)
  1. flwr/cli/app.py +3 -0
  2. flwr/cli/build.py +3 -7
  3. flwr/cli/new/new.py +1 -1
  4. flwr/cli/run/run.py +8 -1
  5. flwr/client/client_app.py +1 -1
  6. flwr/client/dpfedavg_numpy_client.py +1 -1
  7. flwr/client/grpc_rere_client/__init__.py +1 -1
  8. flwr/client/grpc_rere_client/connection.py +1 -1
  9. flwr/client/message_handler/__init__.py +1 -1
  10. flwr/client/message_handler/message_handler.py +1 -1
  11. flwr/client/mod/__init__.py +1 -1
  12. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  13. flwr/client/mod/utils.py +1 -1
  14. flwr/client/rest_client/__init__.py +1 -1
  15. flwr/client/rest_client/connection.py +1 -1
  16. flwr/client/supernode/app.py +1 -1
  17. flwr/common/address.py +1 -1
  18. flwr/common/config.py +8 -6
  19. flwr/common/constant.py +1 -1
  20. flwr/common/date.py +1 -1
  21. flwr/common/dp.py +1 -1
  22. flwr/common/grpc.py +1 -1
  23. flwr/common/secure_aggregation/__init__.py +1 -1
  24. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  25. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  26. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  27. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  28. flwr/common/secure_aggregation/quantization.py +1 -1
  29. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  30. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  31. flwr/common/version.py +14 -0
  32. flwr/server/compat/app.py +1 -1
  33. flwr/server/compat/app_utils.py +1 -1
  34. flwr/server/compat/driver_client_proxy.py +1 -1
  35. flwr/server/driver/driver.py +6 -0
  36. flwr/server/driver/grpc_driver.py +85 -63
  37. flwr/server/driver/inmemory_driver.py +28 -26
  38. flwr/server/run_serverapp.py +12 -7
  39. flwr/server/strategy/bulyan.py +1 -1
  40. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  41. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  42. flwr/server/strategy/fedadagrad.py +1 -1
  43. flwr/server/strategy/fedadam.py +1 -1
  44. flwr/server/strategy/fedavg_android.py +1 -1
  45. flwr/server/strategy/fedavgm.py +1 -1
  46. flwr/server/strategy/fedmedian.py +1 -1
  47. flwr/server/strategy/fedopt.py +1 -1
  48. flwr/server/strategy/fedprox.py +1 -1
  49. flwr/server/strategy/fedxgb_bagging.py +1 -1
  50. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  51. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  52. flwr/server/strategy/fedyogi.py +1 -1
  53. flwr/server/strategy/krum.py +1 -1
  54. flwr/server/strategy/qfedavg.py +1 -1
  55. flwr/server/superlink/driver/__init__.py +1 -1
  56. flwr/server/superlink/driver/driver_grpc.py +1 -1
  57. flwr/server/superlink/driver/driver_servicer.py +15 -3
  58. flwr/server/superlink/fleet/__init__.py +1 -1
  59. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  60. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  61. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  62. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  63. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -1
  64. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  65. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  66. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  67. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
  68. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  69. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  70. flwr/server/superlink/state/__init__.py +1 -1
  71. flwr/server/superlink/state/in_memory_state.py +1 -1
  72. flwr/server/superlink/state/sqlite_state.py +1 -1
  73. flwr/server/superlink/state/state.py +1 -1
  74. flwr/server/superlink/state/state_factory.py +11 -2
  75. flwr/server/utils/__init__.py +1 -1
  76. flwr/server/utils/tensorboard.py +1 -1
  77. flwr/simulation/__init__.py +1 -1
  78. flwr/simulation/app.py +1 -1
  79. flwr/simulation/ray_transport/__init__.py +1 -1
  80. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  81. flwr/simulation/run_simulation.py +15 -8
  82. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/METADATA +2 -1
  83. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/RECORD +86 -86
  84. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/LICENSE +0 -0
  85. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/WHEEL +0 -0
  86. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240620.dist-info}/entry_points.txt +0 -0
flwr/cli/app.py CHANGED
@@ -15,6 +15,7 @@
15
15
  """Flower command line interface."""
16
16
 
17
17
  import typer
18
+ from typer.main import get_command
18
19
 
19
20
  from .build import build
20
21
  from .example import example
@@ -37,5 +38,7 @@ app.command()(run)
37
38
  app.command()(build)
38
39
  app.command()(install)
39
40
 
41
+ typer_click_object = get_command(app)
42
+
40
43
  if __name__ == "__main__":
41
44
  app()
flwr/cli/build.py CHANGED
@@ -36,13 +36,9 @@ def build(
36
36
  ) -> str:
37
37
  """Build a Flower project into a Flower App Bundle (FAB).
38
38
 
39
- You can run `flwr build` without any argument to bundle the current directory:
40
-
41
- `flwr build`
42
-
43
- You can also build a specific directory:
44
-
45
- `flwr build --directory ./projects/flower-hello-world`
39
+ You can run ``flwr build`` without any arguments to bundle the current directory,
40
+ or you can use ``--directory`` to build a specific directory:
41
+ ``flwr build --directory ./projects/flower-hello-world``.
46
42
  """
47
43
  if directory is None:
48
44
  directory = Path.cwd()
flwr/cli/new/new.py CHANGED
@@ -190,7 +190,7 @@ def new(
190
190
  )
191
191
  print(
192
192
  typer.style(
193
- f" cd {project_name}\n" + " pip install -e .\n flwr run\n",
193
+ f" cd {package_name}\n" + " pip install -e .\n flwr run\n",
194
194
  fg=typer.colors.BRIGHT_CYAN,
195
195
  bold=True,
196
196
  )
flwr/cli/run/run.py CHANGED
@@ -41,7 +41,10 @@ class Engine(str, Enum):
41
41
  def run(
42
42
  engine: Annotated[
43
43
  Optional[Engine],
44
- typer.Option(case_sensitive=False, help="The execution engine to run the app"),
44
+ typer.Option(
45
+ case_sensitive=False,
46
+ help="The engine to run FL with (currently only simulation is supported).",
47
+ ),
45
48
  ] = None,
46
49
  use_superexec: Annotated[
47
50
  bool,
@@ -87,12 +90,16 @@ def run(
87
90
 
88
91
  if engine == Engine.SIMULATION:
89
92
  num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
93
+ backend_config = config["flower"]["engine"]["simulation"].get(
94
+ "backend_config", None
95
+ )
90
96
 
91
97
  typer.secho("Starting run... ", fg=typer.colors.BLUE)
92
98
  _run_simulation(
93
99
  server_app_attr=server_app_ref,
94
100
  client_app_attr=client_app_ref,
95
101
  num_supernodes=num_supernodes,
102
+ backend_config=backend_config,
96
103
  )
97
104
  else:
98
105
  typer.secho(
flwr/client/client_app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/client/mod/utils.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -267,7 +267,7 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser:
267
267
  "--flwr-dir",
268
268
  default=None,
269
269
  help="""The path containing installed Flower Apps.
270
- By default, this value isequal to:
270
+ By default, this value is equal to:
271
271
 
272
272
  - `$FLWR_HOME/` if `$FLWR_HOME` is defined
273
273
  - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
flwr/common/address.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/config.py CHANGED
@@ -24,14 +24,16 @@ from flwr.cli.config_utils import validate_fields
24
24
  from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
25
25
 
26
26
 
27
- def get_flwr_dir() -> Path:
27
+ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
28
28
  """Return the Flower home directory based on env variables."""
29
- return Path(
30
- os.getenv(
31
- FLWR_HOME,
32
- f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr",
29
+ if provided_path is None or not Path(provided_path).is_dir():
30
+ return Path(
31
+ os.getenv(
32
+ FLWR_HOME,
33
+ f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr",
34
+ )
33
35
  )
34
- )
36
+ return Path(provided_path).absolute()
35
37
 
36
38
 
37
39
  def get_project_dir(
flwr/common/constant.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/date.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/dp.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/grpc.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/version.py CHANGED
@@ -1,3 +1,17 @@
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
1
15
  """Flower package version helper."""
2
16
 
3
17
  import importlib.metadata as importlib_metadata
flwr/server/compat/app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -91,7 +91,7 @@ def _update_client_manager(
91
91
  node_id=node_id,
92
92
  driver=driver,
93
93
  anonymous=False,
94
- run_id=driver.run_id, # type: ignore
94
+ run_id=driver.run.run_id,
95
95
  )
96
96
  if client_manager.register(client_proxy):
97
97
  registered_nodes[node_id] = client_proxy
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -19,11 +19,17 @@ from abc import ABC, abstractmethod
19
19
  from typing import Iterable, List, Optional
20
20
 
21
21
  from flwr.common import Message, RecordSet
22
+ from flwr.common.typing import Run
22
23
 
23
24
 
24
25
  class Driver(ABC):
25
26
  """Abstract base Driver class for the Driver API."""
26
27
 
28
+ @property
29
+ @abstractmethod
30
+ def run(self) -> Run:
31
+ """Run information."""
32
+
27
33
  @abstractmethod
28
34
  def create_message( # pylint: disable=too-many-arguments
29
35
  self,
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
17
17
  import time
18
18
  import warnings
19
19
  from logging import DEBUG, ERROR, WARNING
20
- from typing import Iterable, List, Optional, Tuple
20
+ from typing import Iterable, List, Optional, Tuple, cast
21
21
 
22
22
  import grpc
23
23
 
@@ -25,6 +25,7 @@ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, ev
25
25
  from flwr.common.grpc import create_channel
26
26
  from flwr.common.logger import log
27
27
  from flwr.common.serde import message_from_taskres, message_to_taskins
28
+ from flwr.common.typing import Run
28
29
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
29
30
  CreateRunRequest,
30
31
  CreateRunResponse,
@@ -37,6 +38,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
37
38
  )
38
39
  from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
39
40
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
41
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
40
42
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
41
43
 
42
44
  from .driver import Driver
@@ -46,13 +48,24 @@ DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
46
48
  ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
47
49
  [Driver] Error: Not connected.
48
50
 
49
- Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
50
- `GrpcDriverHelper` methods.
51
+ Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
52
+ `GrpcDriverStub` methods.
51
53
  """
52
54
 
53
55
 
54
- class GrpcDriverHelper:
55
- """`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
56
+ class GrpcDriverStub:
57
+ """`GrpcDriverStub` provides access to the gRPC Driver API/service.
58
+
59
+ Parameters
60
+ ----------
61
+ driver_service_address : Optional[str]
62
+ The IPv4 or IPv6 address of the Driver API server.
63
+ Defaults to `"[::]:9091"`.
64
+ root_certificates : Optional[bytes] (default: None)
65
+ The PEM-encoded root certificates as a byte string.
66
+ If provided, a secure connection using the certificates will be
67
+ established to an SSL-enabled Flower server.
68
+ """
56
69
 
57
70
  def __init__(
58
71
  self,
@@ -64,6 +77,10 @@ class GrpcDriverHelper:
64
77
  self.channel: Optional[grpc.Channel] = None
65
78
  self.stub: Optional[DriverStub] = None
66
79
 
80
+ def is_connected(self) -> bool:
81
+ """Return True if connected to the Driver API server, otherwise False."""
82
+ return self.channel is not None
83
+
67
84
  def connect(self) -> None:
68
85
  """Connect to the Driver API."""
69
86
  event(EventType.DRIVER_CONNECT)
@@ -95,18 +112,29 @@ class GrpcDriverHelper:
95
112
  # Check if channel is open
96
113
  if self.stub is None:
97
114
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
98
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
115
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
99
116
 
100
117
  # Call Driver API
101
118
  res: CreateRunResponse = self.stub.CreateRun(request=req)
102
119
  return res
103
120
 
121
+ def get_run(self, req: GetRunRequest) -> GetRunResponse:
122
+ """Get run information."""
123
+ # Check if channel is open
124
+ if self.stub is None:
125
+ log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
126
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
127
+
128
+ # Call gRPC Driver API
129
+ res: GetRunResponse = self.stub.GetRun(request=req)
130
+ return res
131
+
104
132
  def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
105
133
  """Get client IDs."""
106
134
  # Check if channel is open
107
135
  if self.stub is None:
108
136
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
109
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
137
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
110
138
 
111
139
  # Call gRPC Driver API
112
140
  res: GetNodesResponse = self.stub.GetNodes(request=req)
@@ -117,7 +145,7 @@ class GrpcDriverHelper:
117
145
  # Check if channel is open
118
146
  if self.stub is None:
119
147
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
120
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
148
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
121
149
 
122
150
  # Call gRPC Driver API
123
151
  res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
@@ -128,7 +156,7 @@ class GrpcDriverHelper:
128
156
  # Check if channel is open
129
157
  if self.stub is None:
130
158
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
131
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
159
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
132
160
 
133
161
  # Call Driver API
134
162
  res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
@@ -140,56 +168,52 @@ class GrpcDriver(Driver):
140
168
 
141
169
  Parameters
142
170
  ----------
143
- driver_service_address : Optional[str]
144
- The IPv4 or IPv6 address of the Driver API server.
145
- Defaults to `"[::]:9091"`.
146
- certificates : bytes (default: None)
147
- Tuple containing root certificate, server certificate, and private key
148
- to start a secure SSL-enabled server. The tuple is expected to have
149
- three bytes elements in the following order:
150
-
151
- * CA certificate.
152
- * server certificate.
153
- * server private key.
154
- fab_id : str (default: None)
155
- The identifier of the FAB used in the run.
156
- fab_version : str (default: None)
157
- The version of the FAB used in the run.
171
+ run_id : int
172
+ The identifier of the run.
173
+ stub : Optional[GrpcDriverStub] (default: None)
174
+ The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
175
+ If None, an instance connected to "[::]:9091" will be created.
158
176
  """
159
177
 
160
- def __init__(
178
+ def __init__( # pylint: disable=too-many-arguments
161
179
  self,
162
- driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
163
- root_certificates: Optional[bytes] = None,
164
- fab_id: Optional[str] = None,
165
- fab_version: Optional[str] = None,
180
+ run_id: int,
181
+ stub: Optional[GrpcDriverStub] = None,
166
182
  ) -> None:
167
- self.addr = driver_service_address
168
- self.root_certificates = root_certificates
169
- self.driver_helper: Optional[GrpcDriverHelper] = None
170
- self.run_id: Optional[int] = None
171
- self.fab_id = fab_id if fab_id is not None else ""
172
- self.fab_version = fab_version if fab_version is not None else ""
183
+ self._run_id = run_id
184
+ self._run: Optional[Run] = None
185
+ self.stub = stub if stub is not None else GrpcDriverStub()
173
186
  self.node = Node(node_id=0, anonymous=True)
174
187
 
175
- def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
176
- # Check if the GrpcDriverHelper is initialized
177
- if self.driver_helper is None or self.run_id is None:
178
- # Connect and create run
179
- self.driver_helper = GrpcDriverHelper(
180
- driver_service_address=self.addr,
181
- root_certificates=self.root_certificates,
188
+ @property
189
+ def run(self) -> Run:
190
+ """Run information."""
191
+ self._get_stub_and_run_id()
192
+ return Run(**vars(cast(Run, self._run)))
193
+
194
+ def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
195
+ # Check if is initialized
196
+ if self._run is None:
197
+ # Connect
198
+ if not self.stub.is_connected():
199
+ self.stub.connect()
200
+ # Get the run info
201
+ req = GetRunRequest(run_id=self._run_id)
202
+ res = self.stub.get_run(req)
203
+ if not res.HasField("run"):
204
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
205
+ self._run = Run(
206
+ run_id=res.run.run_id,
207
+ fab_id=res.run.fab_id,
208
+ fab_version=res.run.fab_version,
182
209
  )
183
- self.driver_helper.connect()
184
- req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version)
185
- res = self.driver_helper.create_run(req)
186
- self.run_id = res.run_id
187
- return self.driver_helper, self.run_id
210
+
211
+ return self.stub, self._run.run_id
188
212
 
189
213
  def _check_message(self, message: Message) -> None:
190
214
  # Check if the message is valid
191
215
  if not (
192
- message.metadata.run_id == self.run_id
216
+ message.metadata.run_id == cast(Run, self._run).run_id
193
217
  and message.metadata.src_node_id == self.node.node_id
194
218
  and message.metadata.message_id == ""
195
219
  and message.metadata.reply_to_message == ""
@@ -210,7 +234,7 @@ class GrpcDriver(Driver):
210
234
  This method constructs a new `Message` with given content and metadata.
211
235
  The `run_id` and `src_node_id` will be set automatically.
212
236
  """
213
- _, run_id = self._get_grpc_driver_helper_and_run_id()
237
+ _, run_id = self._get_stub_and_run_id()
214
238
  if ttl:
215
239
  warnings.warn(
216
240
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -234,9 +258,9 @@ class GrpcDriver(Driver):
234
258
 
235
259
  def get_node_ids(self) -> List[int]:
236
260
  """Get node IDs."""
237
- grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
238
- # Call GrpcDriverHelper method
239
- res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
261
+ stub, run_id = self._get_stub_and_run_id()
262
+ # Call GrpcDriverStub method
263
+ res = stub.get_nodes(GetNodesRequest(run_id=run_id))
240
264
  return [node.node_id for node in res.nodes]
241
265
 
242
266
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
@@ -245,7 +269,7 @@ class GrpcDriver(Driver):
245
269
  This method takes an iterable of messages and sends each message
246
270
  to the node specified in `dst_node_id`.
247
271
  """
248
- grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
272
+ stub, _ = self._get_stub_and_run_id()
249
273
  # Construct TaskIns
250
274
  task_ins_list: List[TaskIns] = []
251
275
  for msg in messages:
@@ -255,10 +279,8 @@ class GrpcDriver(Driver):
255
279
  taskins = message_to_taskins(msg)
256
280
  # Add to list
257
281
  task_ins_list.append(taskins)
258
- # Call GrpcDriverHelper method
259
- res = grpc_driver_helper.push_task_ins(
260
- PushTaskInsRequest(task_ins_list=task_ins_list)
261
- )
282
+ # Call GrpcDriverStub method
283
+ res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
262
284
  return list(res.task_ids)
263
285
 
264
286
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
@@ -267,9 +289,9 @@ class GrpcDriver(Driver):
267
289
  This method is used to collect messages from the SuperLink that correspond to a
268
290
  set of given message IDs.
269
291
  """
270
- grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
292
+ stub, _ = self._get_stub_and_run_id()
271
293
  # Pull TaskRes
272
- res = grpc_driver.pull_task_res(
294
+ res = stub.pull_task_res(
273
295
  PullTaskResRequest(node=self.node, task_ids=message_ids)
274
296
  )
275
297
  # Convert TaskRes to Message
@@ -308,8 +330,8 @@ class GrpcDriver(Driver):
308
330
 
309
331
  def close(self) -> None:
310
332
  """Disconnect from the SuperLink if connected."""
311
- # Check if GrpcDriverHelper is initialized
312
- if self.driver_helper is None:
333
+ # Check if `connect` was called before
334
+ if not self.stub.is_connected():
313
335
  return
314
336
  # Disconnect
315
- self.driver_helper.disconnect()
337
+ self.stub.disconnect()