flwr-nightly 1.10.0.dev20240619__py3-none-any.whl → 1.10.0.dev20240707__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (109)
  1. flwr/cli/app.py +3 -0
  2. flwr/cli/build.py +5 -9
  3. flwr/cli/new/new.py +104 -28
  4. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  5. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  6. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
  7. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
  8. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  9. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  10. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  11. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  12. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  13. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
  14. flwr/cli/run/run.py +21 -5
  15. flwr/client/__init__.py +2 -0
  16. flwr/client/app.py +15 -10
  17. flwr/client/client_app.py +30 -5
  18. flwr/client/dpfedavg_numpy_client.py +1 -1
  19. flwr/client/grpc_rere_client/__init__.py +1 -1
  20. flwr/client/grpc_rere_client/connection.py +1 -1
  21. flwr/client/message_handler/__init__.py +1 -1
  22. flwr/client/message_handler/message_handler.py +4 -5
  23. flwr/client/mod/__init__.py +1 -1
  24. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  25. flwr/client/mod/utils.py +1 -1
  26. flwr/client/node_state.py +6 -3
  27. flwr/client/node_state_tests.py +1 -1
  28. flwr/client/rest_client/__init__.py +1 -1
  29. flwr/client/rest_client/connection.py +1 -1
  30. flwr/client/supernode/app.py +12 -4
  31. flwr/client/typing.py +2 -1
  32. flwr/common/address.py +1 -1
  33. flwr/common/config.py +8 -6
  34. flwr/common/constant.py +4 -1
  35. flwr/common/context.py +11 -1
  36. flwr/common/date.py +1 -1
  37. flwr/common/dp.py +1 -1
  38. flwr/common/grpc.py +1 -1
  39. flwr/common/logger.py +13 -0
  40. flwr/common/message.py +0 -17
  41. flwr/common/secure_aggregation/__init__.py +1 -1
  42. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  43. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  44. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  45. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  46. flwr/common/secure_aggregation/quantization.py +1 -1
  47. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  48. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  49. flwr/common/version.py +14 -0
  50. flwr/server/compat/app.py +1 -1
  51. flwr/server/compat/app_utils.py +1 -1
  52. flwr/server/compat/driver_client_proxy.py +1 -1
  53. flwr/server/driver/driver.py +6 -0
  54. flwr/server/driver/grpc_driver.py +85 -63
  55. flwr/server/driver/inmemory_driver.py +28 -26
  56. flwr/server/run_serverapp.py +61 -18
  57. flwr/server/strategy/bulyan.py +1 -1
  58. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  59. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  60. flwr/server/strategy/fedadagrad.py +1 -1
  61. flwr/server/strategy/fedadam.py +1 -1
  62. flwr/server/strategy/fedavg_android.py +1 -1
  63. flwr/server/strategy/fedavgm.py +1 -1
  64. flwr/server/strategy/fedmedian.py +1 -1
  65. flwr/server/strategy/fedopt.py +1 -1
  66. flwr/server/strategy/fedprox.py +1 -1
  67. flwr/server/strategy/fedxgb_bagging.py +1 -1
  68. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  69. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  70. flwr/server/strategy/fedyogi.py +1 -1
  71. flwr/server/strategy/krum.py +1 -1
  72. flwr/server/strategy/qfedavg.py +1 -1
  73. flwr/server/superlink/driver/__init__.py +1 -1
  74. flwr/server/superlink/driver/driver_grpc.py +1 -1
  75. flwr/server/superlink/driver/driver_servicer.py +15 -3
  76. flwr/server/superlink/fleet/__init__.py +1 -1
  77. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  78. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  79. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  80. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  81. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -1
  82. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  83. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  84. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  85. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
  86. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  87. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  88. flwr/server/superlink/fleet/vce/backend/raybackend.py +45 -26
  89. flwr/server/superlink/fleet/vce/vce_api.py +3 -8
  90. flwr/server/superlink/state/__init__.py +1 -1
  91. flwr/server/superlink/state/in_memory_state.py +5 -5
  92. flwr/server/superlink/state/sqlite_state.py +5 -5
  93. flwr/server/superlink/state/state.py +1 -1
  94. flwr/server/superlink/state/state_factory.py +11 -2
  95. flwr/server/superlink/state/utils.py +6 -0
  96. flwr/server/utils/__init__.py +1 -1
  97. flwr/server/utils/tensorboard.py +1 -1
  98. flwr/simulation/__init__.py +1 -1
  99. flwr/simulation/app.py +52 -37
  100. flwr/simulation/ray_transport/__init__.py +1 -1
  101. flwr/simulation/ray_transport/ray_actor.py +0 -6
  102. flwr/simulation/ray_transport/ray_client_proxy.py +17 -10
  103. flwr/simulation/run_simulation.py +47 -28
  104. flwr/superexec/deployment.py +109 -0
  105. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/METADATA +2 -1
  106. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/RECORD +109 -98
  107. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/LICENSE +0 -0
  108. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/WHEEL +0 -0
  109. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/entry_points.txt +0 -0
flwr/common/date.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/dp.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/grpc.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/logger.py CHANGED
@@ -197,6 +197,19 @@ def warn_deprecated_feature(name: str) -> None:
197
197
  )
198
198
 
199
199
 
200
+ def warn_unsupported_feature(name: str) -> None:
201
+ """Warn the user when they use an unsupported feature."""
202
+ log(
203
+ WARN,
204
+ """UNSUPPORTED FEATURE: %s
205
+
206
+ This is an unsupported feature. It will be removed
207
+ entirely in future versions of Flower.
208
+ """,
209
+ name,
210
+ )
211
+
212
+
200
213
  def set_logger_propagation(
201
214
  child_logger: logging.Logger, value: bool = True
202
215
  ) -> logging.Logger:
flwr/common/message.py CHANGED
@@ -48,10 +48,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
48
48
  message_type : str
49
49
  A string that encodes the action to be executed on
50
50
  the receiving end.
51
- partition_id : Optional[int]
52
- An identifier that can be used when loading a particular
53
- data partition for a ClientApp. Making use of this identifier
54
- is more relevant when conducting simulations.
55
51
  """
56
52
 
57
53
  def __init__( # pylint: disable=too-many-arguments
@@ -64,7 +60,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
64
60
  group_id: str,
65
61
  ttl: float,
66
62
  message_type: str,
67
- partition_id: int | None = None,
68
63
  ) -> None:
69
64
  var_dict = {
70
65
  "_run_id": run_id,
@@ -75,7 +70,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
75
70
  "_group_id": group_id,
76
71
  "_ttl": ttl,
77
72
  "_message_type": message_type,
78
- "_partition_id": partition_id,
79
73
  }
80
74
  self.__dict__.update(var_dict)
81
75
 
@@ -149,16 +143,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
149
143
  """Set message_type."""
150
144
  self.__dict__["_message_type"] = value
151
145
 
152
- @property
153
- def partition_id(self) -> int | None:
154
- """An identifier telling which data partition a ClientApp should use."""
155
- return cast(int, self.__dict__["_partition_id"])
156
-
157
- @partition_id.setter
158
- def partition_id(self, value: int) -> None:
159
- """Set partition_id."""
160
- self.__dict__["_partition_id"] = value
161
-
162
146
  def __repr__(self) -> str:
163
147
  """Return a string representation of this instance."""
164
148
  view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
@@ -398,5 +382,4 @@ def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
398
382
  group_id=msg.metadata.group_id,
399
383
  ttl=ttl,
400
384
  message_type=msg.metadata.message_type,
401
- partition_id=msg.metadata.partition_id,
402
385
  )
flwr/common/secure_aggregation/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/secure_aggregation/crypto/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/secure_aggregation/crypto/shamir.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/secure_aggregation/crypto/symmetric_encryption.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/secure_aggregation/ndarrays_arithmetic.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/secure_aggregation/quantization.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/secure_aggregation/secaggplus_constants.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/secure_aggregation/secaggplus_utils.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/version.py CHANGED
@@ -1,3 +1,17 @@
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
1
15
  """Flower package version helper."""
2
16
 
3
17
  import importlib.metadata as importlib_metadata
flwr/server/compat/app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/server/compat/app_utils.py CHANGED
@@ -91,7 +91,7 @@ def _update_client_manager(
91
91
  node_id=node_id,
92
92
  driver=driver,
93
93
  anonymous=False,
94
- run_id=driver.run_id, # type: ignore
94
+ run_id=driver.run.run_id,
95
95
  )
96
96
  if client_manager.register(client_proxy):
97
97
  registered_nodes[node_id] = client_proxy
flwr/server/compat/driver_client_proxy.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/server/driver/driver.py CHANGED
@@ -19,11 +19,17 @@ from abc import ABC, abstractmethod
19
19
  from typing import Iterable, List, Optional
20
20
 
21
21
  from flwr.common import Message, RecordSet
22
+ from flwr.common.typing import Run
22
23
 
23
24
 
24
25
  class Driver(ABC):
25
26
  """Abstract base Driver class for the Driver API."""
26
27
 
28
+ @property
29
+ @abstractmethod
30
+ def run(self) -> Run:
31
+ """Run information."""
32
+
27
33
  @abstractmethod
28
34
  def create_message( # pylint: disable=too-many-arguments
29
35
  self,
flwr/server/driver/grpc_driver.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
17
17
  import time
18
18
  import warnings
19
19
  from logging import DEBUG, ERROR, WARNING
20
- from typing import Iterable, List, Optional, Tuple
20
+ from typing import Iterable, List, Optional, Tuple, cast
21
21
 
22
22
  import grpc
23
23
 
@@ -25,6 +25,7 @@ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, ev
25
25
  from flwr.common.grpc import create_channel
26
26
  from flwr.common.logger import log
27
27
  from flwr.common.serde import message_from_taskres, message_to_taskins
28
+ from flwr.common.typing import Run
28
29
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
29
30
  CreateRunRequest,
30
31
  CreateRunResponse,
@@ -37,6 +38,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
37
38
  )
38
39
  from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
39
40
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
41
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
40
42
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
41
43
 
42
44
  from .driver import Driver
@@ -46,13 +48,24 @@ DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
46
48
  ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
47
49
  [Driver] Error: Not connected.
48
50
 
49
- Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
50
- `GrpcDriverHelper` methods.
51
+ Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
52
+ `GrpcDriverStub` methods.
51
53
  """
52
54
 
53
55
 
54
- class GrpcDriverHelper:
55
- """`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
56
+ class GrpcDriverStub:
57
+ """`GrpcDriverStub` provides access to the gRPC Driver API/service.
58
+
59
+ Parameters
60
+ ----------
61
+ driver_service_address : Optional[str]
62
+ The IPv4 or IPv6 address of the Driver API server.
63
+ Defaults to `"[::]:9091"`.
64
+ root_certificates : Optional[bytes] (default: None)
65
+ The PEM-encoded root certificates as a byte string.
66
+ If provided, a secure connection using the certificates will be
67
+ established to an SSL-enabled Flower server.
68
+ """
56
69
 
57
70
  def __init__(
58
71
  self,
@@ -64,6 +77,10 @@ class GrpcDriverHelper:
64
77
  self.channel: Optional[grpc.Channel] = None
65
78
  self.stub: Optional[DriverStub] = None
66
79
 
80
+ def is_connected(self) -> bool:
81
+ """Return True if connected to the Driver API server, otherwise False."""
82
+ return self.channel is not None
83
+
67
84
  def connect(self) -> None:
68
85
  """Connect to the Driver API."""
69
86
  event(EventType.DRIVER_CONNECT)
@@ -95,18 +112,29 @@ class GrpcDriverHelper:
95
112
  # Check if channel is open
96
113
  if self.stub is None:
97
114
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
98
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
115
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
99
116
 
100
117
  # Call Driver API
101
118
  res: CreateRunResponse = self.stub.CreateRun(request=req)
102
119
  return res
103
120
 
121
+ def get_run(self, req: GetRunRequest) -> GetRunResponse:
122
+ """Get run information."""
123
+ # Check if channel is open
124
+ if self.stub is None:
125
+ log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
126
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
127
+
128
+ # Call gRPC Driver API
129
+ res: GetRunResponse = self.stub.GetRun(request=req)
130
+ return res
131
+
104
132
  def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
105
133
  """Get client IDs."""
106
134
  # Check if channel is open
107
135
  if self.stub is None:
108
136
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
109
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
137
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
110
138
 
111
139
  # Call gRPC Driver API
112
140
  res: GetNodesResponse = self.stub.GetNodes(request=req)
@@ -117,7 +145,7 @@ class GrpcDriverHelper:
117
145
  # Check if channel is open
118
146
  if self.stub is None:
119
147
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
120
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
148
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
121
149
 
122
150
  # Call gRPC Driver API
123
151
  res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
@@ -128,7 +156,7 @@ class GrpcDriverHelper:
128
156
  # Check if channel is open
129
157
  if self.stub is None:
130
158
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
131
- raise ConnectionError("`GrpcDriverHelper` instance not connected")
159
+ raise ConnectionError("`GrpcDriverStub` instance not connected")
132
160
 
133
161
  # Call Driver API
134
162
  res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
@@ -140,56 +168,52 @@ class GrpcDriver(Driver):
140
168
 
141
169
  Parameters
142
170
  ----------
143
- driver_service_address : Optional[str]
144
- The IPv4 or IPv6 address of the Driver API server.
145
- Defaults to `"[::]:9091"`.
146
- certificates : bytes (default: None)
147
- Tuple containing root certificate, server certificate, and private key
148
- to start a secure SSL-enabled server. The tuple is expected to have
149
- three bytes elements in the following order:
150
-
151
- * CA certificate.
152
- * server certificate.
153
- * server private key.
154
- fab_id : str (default: None)
155
- The identifier of the FAB used in the run.
156
- fab_version : str (default: None)
157
- The version of the FAB used in the run.
171
+ run_id : int
172
+ The identifier of the run.
173
+ stub : Optional[GrpcDriverStub] (default: None)
174
+ The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
175
+ If None, an instance connected to "[::]:9091" will be created.
158
176
  """
159
177
 
160
- def __init__(
178
+ def __init__( # pylint: disable=too-many-arguments
161
179
  self,
162
- driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
163
- root_certificates: Optional[bytes] = None,
164
- fab_id: Optional[str] = None,
165
- fab_version: Optional[str] = None,
180
+ run_id: int,
181
+ stub: Optional[GrpcDriverStub] = None,
166
182
  ) -> None:
167
- self.addr = driver_service_address
168
- self.root_certificates = root_certificates
169
- self.driver_helper: Optional[GrpcDriverHelper] = None
170
- self.run_id: Optional[int] = None
171
- self.fab_id = fab_id if fab_id is not None else ""
172
- self.fab_version = fab_version if fab_version is not None else ""
183
+ self._run_id = run_id
184
+ self._run: Optional[Run] = None
185
+ self.stub = stub if stub is not None else GrpcDriverStub()
173
186
  self.node = Node(node_id=0, anonymous=True)
174
187
 
175
- def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
176
- # Check if the GrpcDriverHelper is initialized
177
- if self.driver_helper is None or self.run_id is None:
178
- # Connect and create run
179
- self.driver_helper = GrpcDriverHelper(
180
- driver_service_address=self.addr,
181
- root_certificates=self.root_certificates,
188
+ @property
189
+ def run(self) -> Run:
190
+ """Run information."""
191
+ self._get_stub_and_run_id()
192
+ return Run(**vars(cast(Run, self._run)))
193
+
194
+ def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
195
+ # Check if is initialized
196
+ if self._run is None:
197
+ # Connect
198
+ if not self.stub.is_connected():
199
+ self.stub.connect()
200
+ # Get the run info
201
+ req = GetRunRequest(run_id=self._run_id)
202
+ res = self.stub.get_run(req)
203
+ if not res.HasField("run"):
204
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
205
+ self._run = Run(
206
+ run_id=res.run.run_id,
207
+ fab_id=res.run.fab_id,
208
+ fab_version=res.run.fab_version,
182
209
  )
183
- self.driver_helper.connect()
184
- req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version)
185
- res = self.driver_helper.create_run(req)
186
- self.run_id = res.run_id
187
- return self.driver_helper, self.run_id
210
+
211
+ return self.stub, self._run.run_id
188
212
 
189
213
  def _check_message(self, message: Message) -> None:
190
214
  # Check if the message is valid
191
215
  if not (
192
- message.metadata.run_id == self.run_id
216
+ message.metadata.run_id == cast(Run, self._run).run_id
193
217
  and message.metadata.src_node_id == self.node.node_id
194
218
  and message.metadata.message_id == ""
195
219
  and message.metadata.reply_to_message == ""
@@ -210,7 +234,7 @@ class GrpcDriver(Driver):
210
234
  This method constructs a new `Message` with given content and metadata.
211
235
  The `run_id` and `src_node_id` will be set automatically.
212
236
  """
213
- _, run_id = self._get_grpc_driver_helper_and_run_id()
237
+ _, run_id = self._get_stub_and_run_id()
214
238
  if ttl:
215
239
  warnings.warn(
216
240
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -234,9 +258,9 @@ class GrpcDriver(Driver):
234
258
 
235
259
  def get_node_ids(self) -> List[int]:
236
260
  """Get node IDs."""
237
- grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
238
- # Call GrpcDriverHelper method
239
- res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
261
+ stub, run_id = self._get_stub_and_run_id()
262
+ # Call GrpcDriverStub method
263
+ res = stub.get_nodes(GetNodesRequest(run_id=run_id))
240
264
  return [node.node_id for node in res.nodes]
241
265
 
242
266
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
@@ -245,7 +269,7 @@ class GrpcDriver(Driver):
245
269
  This method takes an iterable of messages and sends each message
246
270
  to the node specified in `dst_node_id`.
247
271
  """
248
- grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
272
+ stub, _ = self._get_stub_and_run_id()
249
273
  # Construct TaskIns
250
274
  task_ins_list: List[TaskIns] = []
251
275
  for msg in messages:
@@ -255,10 +279,8 @@ class GrpcDriver(Driver):
255
279
  taskins = message_to_taskins(msg)
256
280
  # Add to list
257
281
  task_ins_list.append(taskins)
258
- # Call GrpcDriverHelper method
259
- res = grpc_driver_helper.push_task_ins(
260
- PushTaskInsRequest(task_ins_list=task_ins_list)
261
- )
282
+ # Call GrpcDriverStub method
283
+ res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
262
284
  return list(res.task_ids)
263
285
 
264
286
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
@@ -267,9 +289,9 @@ class GrpcDriver(Driver):
267
289
  This method is used to collect messages from the SuperLink that correspond to a
268
290
  set of given message IDs.
269
291
  """
270
- grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
292
+ stub, _ = self._get_stub_and_run_id()
271
293
  # Pull TaskRes
272
- res = grpc_driver.pull_task_res(
294
+ res = stub.pull_task_res(
273
295
  PullTaskResRequest(node=self.node, task_ids=message_ids)
274
296
  )
275
297
  # Convert TaskRes to Message
@@ -308,8 +330,8 @@ class GrpcDriver(Driver):
308
330
 
309
331
  def close(self) -> None:
310
332
  """Disconnect from the SuperLink if connected."""
311
- # Check if GrpcDriverHelper is initialized
312
- if self.driver_helper is None:
333
+ # Check if `connect` was called before
334
+ if not self.stub.is_connected():
313
335
  return
314
336
  # Disconnect
315
- self.driver_helper.disconnect()
337
+ self.stub.disconnect()
flwr/server/driver/inmemory_driver.py CHANGED
@@ -17,11 +17,12 @@
17
17
 
18
18
  import time
19
19
  import warnings
20
- from typing import Iterable, List, Optional
20
+ from typing import Iterable, List, Optional, cast
21
21
  from uuid import UUID
22
22
 
23
23
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
24
24
  from flwr.common.serde import message_from_taskres, message_to_taskins
25
+ from flwr.common.typing import Run
25
26
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
26
27
  from flwr.server.superlink.state import StateFactory
27
28
 
@@ -33,30 +34,27 @@ class InMemoryDriver(Driver):
33
34
 
34
35
  Parameters
35
36
  ----------
37
+ run_id : int
38
+ The identifier of the run.
36
39
  state_factory : StateFactory
37
40
  A StateFactory embedding a state that this driver can interface with.
38
- fab_id : str (default: None)
39
- The identifier of the FAB used in the run.
40
- fab_version : str (default: None)
41
- The version of the FAB used in the run.
42
41
  """
43
42
 
44
43
  def __init__(
45
44
  self,
45
+ run_id: int,
46
46
  state_factory: StateFactory,
47
- fab_id: Optional[str] = None,
48
- fab_version: Optional[str] = None,
49
47
  ) -> None:
50
- self.run_id: Optional[int] = None
51
- self.fab_id = fab_id if fab_id is not None else ""
52
- self.fab_version = fab_version if fab_version is not None else ""
53
- self.node = Node(node_id=0, anonymous=True)
48
+ self._run_id = run_id
49
+ self._run: Optional[Run] = None
54
50
  self.state = state_factory.state()
51
+ self.node = Node(node_id=0, anonymous=True)
55
52
 
56
53
  def _check_message(self, message: Message) -> None:
54
+ self._init_run()
57
55
  # Check if the message is valid
58
56
  if not (
59
- message.metadata.run_id == self.run_id
57
+ message.metadata.run_id == cast(Run, self._run).run_id
60
58
  and message.metadata.src_node_id == self.node.node_id
61
59
  and message.metadata.message_id == ""
62
60
  and message.metadata.reply_to_message == ""
@@ -64,16 +62,20 @@ class InMemoryDriver(Driver):
64
62
  ):
65
63
  raise ValueError(f"Invalid message: {message}")
66
64
 
67
- def _get_run_id(self) -> int:
68
- """Return run_id.
69
-
70
- If unset, create a new run.
71
- """
72
- if self.run_id is None:
73
- self.run_id = self.state.create_run(
74
- fab_id=self.fab_id, fab_version=self.fab_version
75
- )
76
- return self.run_id
65
+ def _init_run(self) -> None:
66
+ """Initialize the run."""
67
+ if self._run is not None:
68
+ return
69
+ run = self.state.get_run(self._run_id)
70
+ if run is None:
71
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
72
+ self._run = run
73
+
74
+ @property
75
+ def run(self) -> Run:
76
+ """Run ID."""
77
+ self._init_run()
78
+ return Run(**vars(cast(Run, self._run)))
77
79
 
78
80
  def create_message( # pylint: disable=too-many-arguments
79
81
  self,
@@ -88,7 +90,7 @@ class InMemoryDriver(Driver):
88
90
  This method constructs a new `Message` with given content and metadata.
89
91
  The `run_id` and `src_node_id` will be set automatically.
90
92
  """
91
- run_id = self._get_run_id()
93
+ self._init_run()
92
94
  if ttl:
93
95
  warnings.warn(
94
96
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -99,7 +101,7 @@ class InMemoryDriver(Driver):
99
101
  ttl_ = DEFAULT_TTL if ttl is None else ttl
100
102
 
101
103
  metadata = Metadata(
102
- run_id=run_id,
104
+ run_id=cast(Run, self._run).run_id,
103
105
  message_id="", # Will be set by the server
104
106
  src_node_id=self.node.node_id,
105
107
  dst_node_id=dst_node_id,
@@ -112,8 +114,8 @@ class InMemoryDriver(Driver):
112
114
 
113
115
  def get_node_ids(self) -> List[int]:
114
116
  """Get node IDs."""
115
- run_id = self._get_run_id()
116
- return list(self.state.get_nodes(run_id))
117
+ self._init_run()
118
+ return list(self.state.get_nodes(cast(Run, self._run).run_id))
117
119
 
118
120
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
119
121
  """Push messages to specified node IDs.