flwr-nightly 1.9.0.dev20240419__py3-none-any.whl → 1.9.0.dev20240422__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package's page on its registry for more details.

@@ -28,12 +28,11 @@ def run_supernode() -> None:
28
28
 
29
29
  event(EventType.RUN_SUPERNODE_ENTER)
30
30
 
31
- args = _parse_args_run_supernode().parse_args()
31
+ _ = _parse_args_run_supernode().parse_args()
32
32
 
33
33
  log(
34
34
  DEBUG,
35
- "Flower will load ClientApp `%s`",
36
- getattr(args, "client-app"),
35
+ "Flower SuperNode starting...",
37
36
  )
38
37
 
39
38
  # Graceful shutdown
@@ -48,7 +47,16 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser:
48
47
  description="Start a Flower SuperNode",
49
48
  )
50
49
 
51
- parse_args_run_client_app(parser=parser)
50
+ parser.add_argument(
51
+ "client-app",
52
+ nargs="?",
53
+ default="",
54
+ help="For example: `client:app` or `project.package.module:wrapper.app`. "
55
+ "This is optional and serves as the default ClientApp to be loaded when "
56
+ "the ServerApp does not specify `fab_id` and `fab_version`. "
57
+ "If not provided, defaults to an empty string.",
58
+ )
59
+ _parse_args_common(parser)
52
60
 
53
61
  return parser
54
62
 
@@ -59,6 +67,10 @@ def parse_args_run_client_app(parser: argparse.ArgumentParser) -> None:
59
67
  "client-app",
60
68
  help="For example: `client:app` or `project.package.module:wrapper.app`",
61
69
  )
70
+ _parse_args_common(parser)
71
+
72
+
73
+ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
62
74
  parser.add_argument(
63
75
  "--insecure",
64
76
  action="store_true",
flwr/server/compat/app.py CHANGED
@@ -29,7 +29,7 @@ from flwr.server.server import Server, init_defaults, run_fl
29
29
  from flwr.server.server_config import ServerConfig
30
30
  from flwr.server.strategy import Strategy
31
31
 
32
- from ..driver import Driver
32
+ from ..driver import Driver, GrpcDriver
33
33
  from .app_utils import start_update_client_manager_thread
34
34
 
35
35
  DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
@@ -114,7 +114,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
114
114
  # Create the Driver
115
115
  if isinstance(root_certificates, str):
116
116
  root_certificates = Path(root_certificates).read_bytes()
117
- driver = Driver(
117
+ driver = GrpcDriver(
118
118
  driver_service_address=address, root_certificates=root_certificates
119
119
  )
120
120
 
@@ -89,7 +89,7 @@ def _update_client_manager(
89
89
  for node_id in new_nodes:
90
90
  client_proxy = DriverClientProxy(
91
91
  node_id=node_id,
92
- driver=driver.grpc_driver, # type: ignore
92
+ driver=driver.grpc_driver_helper, # type: ignore
93
93
  anonymous=False,
94
94
  run_id=driver.run_id, # type: ignore
95
95
  )
@@ -25,7 +25,7 @@ from flwr.common import serde
25
25
  from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
26
26
  from flwr.server.client_proxy import ClientProxy
27
27
 
28
- from ..driver.grpc_driver import GrpcDriver
28
+ from ..driver.grpc_driver import GrpcDriverHelper
29
29
 
30
30
  SLEEP_TIME = 1
31
31
 
@@ -33,7 +33,9 @@ SLEEP_TIME = 1
33
33
  class DriverClientProxy(ClientProxy):
34
34
  """Flower client proxy which delegates work using the Driver API."""
35
35
 
36
- def __init__(self, node_id: int, driver: GrpcDriver, anonymous: bool, run_id: int):
36
+ def __init__(
37
+ self, node_id: int, driver: GrpcDriverHelper, anonymous: bool, run_id: int
38
+ ):
37
39
  super().__init__(str(node_id))
38
40
  self.node_id = node_id
39
41
  self.driver = driver
@@ -16,9 +16,10 @@
16
16
 
17
17
 
18
18
  from .driver import Driver
19
- from .grpc_driver import GrpcDriver
19
+ from .grpc_driver import GrpcDriver, GrpcDriverHelper
20
20
 
21
21
  __all__ = [
22
22
  "Driver",
23
23
  "GrpcDriver",
24
+ "GrpcDriverHelper",
24
25
  ]
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,79 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver service client."""
15
+ """Driver (abstract base class)."""
16
16
 
17
- import time
18
- import warnings
19
- from typing import Iterable, List, Optional, Tuple
20
17
 
21
- from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
22
- from flwr.common.serde import message_from_taskres, message_to_taskins
23
- from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
24
- CreateRunRequest,
25
- GetNodesRequest,
26
- PullTaskResRequest,
27
- PushTaskInsRequest,
28
- )
29
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
30
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
18
+ from abc import ABC, abstractmethod
19
+ from typing import Iterable, List, Optional
31
20
 
32
- from .grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver
21
+ from flwr.common import Message, RecordSet
33
22
 
34
23
 
35
- class Driver:
36
- """`Driver` class provides an interface to the Driver API.
37
-
38
- Parameters
39
- ----------
40
- driver_service_address : Optional[str]
41
- The IPv4 or IPv6 address of the Driver API server.
42
- Defaults to `"[::]:9091"`.
43
- certificates : bytes (default: None)
44
- Tuple containing root certificate, server certificate, and private key
45
- to start a secure SSL-enabled server. The tuple is expected to have
46
- three bytes elements in the following order:
47
-
48
- * CA certificate.
49
- * server certificate.
50
- * server private key.
51
- """
52
-
53
- def __init__(
54
- self,
55
- driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
56
- root_certificates: Optional[bytes] = None,
57
- ) -> None:
58
- self.addr = driver_service_address
59
- self.root_certificates = root_certificates
60
- self.grpc_driver: Optional[GrpcDriver] = None
61
- self.run_id: Optional[int] = None
62
- self.node = Node(node_id=0, anonymous=True)
63
-
64
- def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]:
65
- # Check if the GrpcDriver is initialized
66
- if self.grpc_driver is None or self.run_id is None:
67
- # Connect and create run
68
- self.grpc_driver = GrpcDriver(
69
- driver_service_address=self.addr,
70
- root_certificates=self.root_certificates,
71
- )
72
- self.grpc_driver.connect()
73
- res = self.grpc_driver.create_run(CreateRunRequest())
74
- self.run_id = res.run_id
75
- return self.grpc_driver, self.run_id
76
-
77
- def _check_message(self, message: Message) -> None:
78
- # Check if the message is valid
79
- if not (
80
- message.metadata.run_id == self.run_id
81
- and message.metadata.src_node_id == self.node.node_id
82
- and message.metadata.message_id == ""
83
- and message.metadata.reply_to_message == ""
84
- and message.metadata.ttl > 0
85
- ):
86
- raise ValueError(f"Invalid message: {message}")
24
+ class Driver(ABC):
25
+ """Abstract base Driver class for the Driver API."""
87
26
 
27
+ @abstractmethod
88
28
  def create_message( # pylint: disable=too-many-arguments
89
29
  self,
90
30
  content: RecordSet,
@@ -122,35 +62,12 @@ class Driver:
122
62
  message : Message
123
63
  A new `Message` instance with the specified content and metadata.
124
64
  """
125
- _, run_id = self._get_grpc_driver_and_run_id()
126
- if ttl:
127
- warnings.warn(
128
- "A custom TTL was set, but note that the SuperLink does not enforce "
129
- "the TTL yet. The SuperLink will start enforcing the TTL in a future "
130
- "version of Flower.",
131
- stacklevel=2,
132
- )
133
-
134
- ttl_ = DEFAULT_TTL if ttl is None else ttl
135
- metadata = Metadata(
136
- run_id=run_id,
137
- message_id="", # Will be set by the server
138
- src_node_id=self.node.node_id,
139
- dst_node_id=dst_node_id,
140
- reply_to_message="",
141
- group_id=group_id,
142
- ttl=ttl_,
143
- message_type=message_type,
144
- )
145
- return Message(metadata=metadata, content=content)
146
65
 
66
+ @abstractmethod
147
67
  def get_node_ids(self) -> List[int]:
148
68
  """Get node IDs."""
149
- grpc_driver, run_id = self._get_grpc_driver_and_run_id()
150
- # Call GrpcDriver method
151
- res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id))
152
- return [node.node_id for node in res.nodes]
153
69
 
70
+ @abstractmethod
154
71
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
155
72
  """Push messages to specified node IDs.
156
73
 
@@ -168,20 +85,8 @@ class Driver:
168
85
  An iterable of IDs for the messages that were sent, which can be used
169
86
  to pull replies.
170
87
  """
171
- grpc_driver, _ = self._get_grpc_driver_and_run_id()
172
- # Construct TaskIns
173
- task_ins_list: List[TaskIns] = []
174
- for msg in messages:
175
- # Check message
176
- self._check_message(msg)
177
- # Convert Message to TaskIns
178
- taskins = message_to_taskins(msg)
179
- # Add to list
180
- task_ins_list.append(taskins)
181
- # Call GrpcDriver method
182
- res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
183
- return list(res.task_ids)
184
88
 
89
+ @abstractmethod
185
90
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
186
91
  """Pull messages based on message IDs.
187
92
 
@@ -198,15 +103,8 @@ class Driver:
198
103
  messages : Iterable[Message]
199
104
  An iterable of messages received.
200
105
  """
201
- grpc_driver, _ = self._get_grpc_driver_and_run_id()
202
- # Pull TaskRes
203
- res = grpc_driver.pull_task_res(
204
- PullTaskResRequest(node=self.node, task_ids=message_ids)
205
- )
206
- # Convert TaskRes to Message
207
- msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
208
- return msgs
209
106
 
107
+ @abstractmethod
210
108
  def send_and_receive(
211
109
  self,
212
110
  messages: Iterable[Message],
@@ -240,28 +138,3 @@ class Driver:
240
138
  replies for all sent messages. A message remains valid until its TTL,
241
139
  which is not affected by `timeout`.
242
140
  """
243
- # Push messages
244
- msg_ids = set(self.push_messages(messages))
245
-
246
- # Pull messages
247
- end_time = time.time() + (timeout if timeout is not None else 0.0)
248
- ret: List[Message] = []
249
- while timeout is None or time.time() < end_time:
250
- res_msgs = self.pull_messages(msg_ids)
251
- ret.extend(res_msgs)
252
- msg_ids.difference_update(
253
- {msg.metadata.reply_to_message for msg in res_msgs}
254
- )
255
- if len(msg_ids) == 0:
256
- break
257
- # Sleep
258
- time.sleep(3)
259
- return ret
260
-
261
- def close(self) -> None:
262
- """Disconnect from the SuperLink if connected."""
263
- # Check if GrpcDriver is initialized
264
- if self.grpc_driver is None:
265
- return
266
- # Disconnect
267
- self.grpc_driver.disconnect()
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,17 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Flower driver service client."""
16
-
15
+ """Flower gRPC Driver."""
17
16
 
17
+ import time
18
+ import warnings
18
19
  from logging import DEBUG, ERROR, WARNING
19
- from typing import Optional
20
+ from typing import Iterable, List, Optional, Tuple
20
21
 
21
22
  import grpc
22
23
 
23
- from flwr.common import EventType, event
24
+ from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
24
25
  from flwr.common.grpc import create_channel
25
26
  from flwr.common.logger import log
27
+ from flwr.common.serde import message_from_taskres, message_to_taskins
26
28
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
27
29
  CreateRunRequest,
28
30
  CreateRunResponse,
@@ -34,19 +36,23 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
34
36
  PushTaskInsResponse,
35
37
  )
36
38
  from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
39
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
40
+ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
41
+
42
+ from .driver import Driver
37
43
 
38
44
  DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
39
45
 
40
46
  ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
41
47
  [Driver] Error: Not connected.
42
48
 
43
- Call `connect()` on the `GrpcDriver` instance before calling any of the other
44
- `GrpcDriver` methods.
49
+ Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
50
+ `GrpcDriverHelper` methods.
45
51
  """
46
52
 
47
53
 
48
- class GrpcDriver:
49
- """`GrpcDriver` provides access to the gRPC Driver API/service."""
54
+ class GrpcDriverHelper:
55
+ """`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
50
56
 
51
57
  def __init__(
52
58
  self,
@@ -89,7 +95,7 @@ class GrpcDriver:
89
95
  # Check if channel is open
90
96
  if self.stub is None:
91
97
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
92
- raise ConnectionError("`GrpcDriver` instance not connected")
98
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
93
99
 
94
100
  # Call Driver API
95
101
  res: CreateRunResponse = self.stub.CreateRun(request=req)
@@ -100,7 +106,7 @@ class GrpcDriver:
100
106
  # Check if channel is open
101
107
  if self.stub is None:
102
108
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
103
- raise ConnectionError("`GrpcDriver` instance not connected")
109
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
104
110
 
105
111
  # Call gRPC Driver API
106
112
  res: GetNodesResponse = self.stub.GetNodes(request=req)
@@ -111,7 +117,7 @@ class GrpcDriver:
111
117
  # Check if channel is open
112
118
  if self.stub is None:
113
119
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
114
- raise ConnectionError("`GrpcDriver` instance not connected")
120
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
115
121
 
116
122
  # Call gRPC Driver API
117
123
  res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
@@ -122,8 +128,179 @@ class GrpcDriver:
122
128
  # Check if channel is open
123
129
  if self.stub is None:
124
130
  log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
125
- raise ConnectionError("`GrpcDriver` instance not connected")
131
+ raise ConnectionError("`GrpcDriverHelper` instance not connected")
126
132
 
127
133
  # Call Driver API
128
134
  res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
129
135
  return res
136
+
137
+
138
+ class GrpcDriver(Driver):
139
+ """`Driver` class provides an interface to the Driver API.
140
+
141
+ Parameters
142
+ ----------
143
+ driver_service_address : Optional[str]
144
+ The IPv4 or IPv6 address of the Driver API server.
145
+ Defaults to `"[::]:9091"`.
146
+ certificates : bytes (default: None)
147
+ Tuple containing root certificate, server certificate, and private key
148
+ to start a secure SSL-enabled server. The tuple is expected to have
149
+ three bytes elements in the following order:
150
+
151
+ * CA certificate.
152
+ * server certificate.
153
+ * server private key.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
159
+ root_certificates: Optional[bytes] = None,
160
+ ) -> None:
161
+ self.addr = driver_service_address
162
+ self.root_certificates = root_certificates
163
+ self.grpc_driver_helper: Optional[GrpcDriverHelper] = None
164
+ self.run_id: Optional[int] = None
165
+ self.node = Node(node_id=0, anonymous=True)
166
+
167
+ def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
168
+ # Check if the GrpcDriverHelper is initialized
169
+ if self.grpc_driver_helper is None or self.run_id is None:
170
+ # Connect and create run
171
+ self.grpc_driver_helper = GrpcDriverHelper(
172
+ driver_service_address=self.addr,
173
+ root_certificates=self.root_certificates,
174
+ )
175
+ self.grpc_driver_helper.connect()
176
+ res = self.grpc_driver_helper.create_run(CreateRunRequest())
177
+ self.run_id = res.run_id
178
+ return self.grpc_driver_helper, self.run_id
179
+
180
+ def _check_message(self, message: Message) -> None:
181
+ # Check if the message is valid
182
+ if not (
183
+ message.metadata.run_id == self.run_id
184
+ and message.metadata.src_node_id == self.node.node_id
185
+ and message.metadata.message_id == ""
186
+ and message.metadata.reply_to_message == ""
187
+ and message.metadata.ttl > 0
188
+ ):
189
+ raise ValueError(f"Invalid message: {message}")
190
+
191
+ def create_message( # pylint: disable=too-many-arguments
192
+ self,
193
+ content: RecordSet,
194
+ message_type: str,
195
+ dst_node_id: int,
196
+ group_id: str,
197
+ ttl: Optional[float] = None,
198
+ ) -> Message:
199
+ """Create a new message with specified parameters.
200
+
201
+ This method constructs a new `Message` with given content and metadata.
202
+ The `run_id` and `src_node_id` will be set automatically.
203
+ """
204
+ _, run_id = self._get_grpc_driver_helper_and_run_id()
205
+ if ttl:
206
+ warnings.warn(
207
+ "A custom TTL was set, but note that the SuperLink does not enforce "
208
+ "the TTL yet. The SuperLink will start enforcing the TTL in a future "
209
+ "version of Flower.",
210
+ stacklevel=2,
211
+ )
212
+
213
+ ttl_ = DEFAULT_TTL if ttl is None else ttl
214
+ metadata = Metadata(
215
+ run_id=run_id,
216
+ message_id="", # Will be set by the server
217
+ src_node_id=self.node.node_id,
218
+ dst_node_id=dst_node_id,
219
+ reply_to_message="",
220
+ group_id=group_id,
221
+ ttl=ttl_,
222
+ message_type=message_type,
223
+ )
224
+ return Message(metadata=metadata, content=content)
225
+
226
+ def get_node_ids(self) -> List[int]:
227
+ """Get node IDs."""
228
+ grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
229
+ # Call GrpcDriverHelper method
230
+ res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
231
+ return [node.node_id for node in res.nodes]
232
+
233
+ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
234
+ """Push messages to specified node IDs.
235
+
236
+ This method takes an iterable of messages and sends each message
237
+ to the node specified in `dst_node_id`.
238
+ """
239
+ grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
240
+ # Construct TaskIns
241
+ task_ins_list: List[TaskIns] = []
242
+ for msg in messages:
243
+ # Check message
244
+ self._check_message(msg)
245
+ # Convert Message to TaskIns
246
+ taskins = message_to_taskins(msg)
247
+ # Add to list
248
+ task_ins_list.append(taskins)
249
+ # Call GrpcDriverHelper method
250
+ res = grpc_driver_helper.push_task_ins(
251
+ PushTaskInsRequest(task_ins_list=task_ins_list)
252
+ )
253
+ return list(res.task_ids)
254
+
255
+ def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
256
+ """Pull messages based on message IDs.
257
+
258
+ This method is used to collect messages from the SuperLink that correspond to a
259
+ set of given message IDs.
260
+ """
261
+ grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
262
+ # Pull TaskRes
263
+ res = grpc_driver.pull_task_res(
264
+ PullTaskResRequest(node=self.node, task_ids=message_ids)
265
+ )
266
+ # Convert TaskRes to Message
267
+ msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
268
+ return msgs
269
+
270
+ def send_and_receive(
271
+ self,
272
+ messages: Iterable[Message],
273
+ *,
274
+ timeout: Optional[float] = None,
275
+ ) -> Iterable[Message]:
276
+ """Push messages to specified node IDs and pull the reply messages.
277
+
278
+ This method sends a list of messages to their destination node IDs and then
279
+ waits for the replies. It continues to pull replies until either all replies are
280
+ received or the specified timeout duration is exceeded.
281
+ """
282
+ # Push messages
283
+ msg_ids = set(self.push_messages(messages))
284
+
285
+ # Pull messages
286
+ end_time = time.time() + (timeout if timeout is not None else 0.0)
287
+ ret: List[Message] = []
288
+ while timeout is None or time.time() < end_time:
289
+ res_msgs = self.pull_messages(msg_ids)
290
+ ret.extend(res_msgs)
291
+ msg_ids.difference_update(
292
+ {msg.metadata.reply_to_message for msg in res_msgs}
293
+ )
294
+ if len(msg_ids) == 0:
295
+ break
296
+ # Sleep
297
+ time.sleep(3)
298
+ return ret
299
+
300
+ def close(self) -> None:
301
+ """Disconnect from the SuperLink if connected."""
302
+ # Check if GrpcDriverHelper is initialized
303
+ if self.grpc_driver_helper is None:
304
+ return
305
+ # Disconnect
306
+ self.grpc_driver_helper.disconnect()
@@ -25,7 +25,7 @@ from flwr.common import Context, EventType, RecordSet, event
25
25
  from flwr.common.logger import log, update_console_handler
26
26
  from flwr.common.object_ref import load_app
27
27
 
28
- from .driver.driver import Driver
28
+ from .driver import Driver, GrpcDriver
29
29
  from .server_app import LoadServerAppError, ServerApp
30
30
 
31
31
 
@@ -128,13 +128,13 @@ def run_server_app() -> None:
128
128
  server_app_dir = args.dir
129
129
  server_app_attr = getattr(args, "server-app")
130
130
 
131
- # Initialize Driver
132
- driver = Driver(
131
+ # Initialize GrpcDriver
132
+ driver = GrpcDriver(
133
133
  driver_service_address=args.server,
134
134
  root_certificates=root_certificates,
135
135
  )
136
136
 
137
- # Run the Server App with the Driver
137
+ # Run the ServerApp with the Driver
138
138
  run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
139
139
 
140
140
  # Clean up
@@ -29,7 +29,7 @@ import grpc
29
29
  from flwr.client import ClientApp
30
30
  from flwr.common import EventType, event, log
31
31
  from flwr.common.typing import ConfigsRecordValues
32
- from flwr.server.driver.driver import Driver
32
+ from flwr.server.driver import Driver, GrpcDriver
33
33
  from flwr.server.run_serverapp import run
34
34
  from flwr.server.server_app import ServerApp
35
35
  from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
@@ -204,7 +204,7 @@ def _main_loop(
204
204
  serverapp_th = None
205
205
  try:
206
206
  # Initialize Driver
207
- driver = Driver(
207
+ driver = GrpcDriver(
208
208
  driver_service_address=driver_api_address,
209
209
  root_certificates=None,
210
210
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.9.0.dev20240419
3
+ Version: 1.9.0.dev20240422
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0
@@ -51,7 +51,7 @@ flwr/client/numpy_client.py,sha256=u76GWAdHmJM88Agm2EgLQSvO8Jnk225mJTk-_TmPjFE,1
51
51
  flwr/client/rest_client/__init__.py,sha256=ThwOnkMdzxo_UuyTI47Q7y9oSpuTgNT2OuFvJCfuDiw,735
52
52
  flwr/client/rest_client/connection.py,sha256=ZxTFVDXlONqKTX6uYgxshoEWqzqVcQ8QQ2hKS93oLM8,11302
53
53
  flwr/client/supernode/__init__.py,sha256=D5swXxemuRbA2rB_T9B8LwJW-_PucXwmlFQQerwIUv0,793
54
- flwr/client/supernode/app.py,sha256=JXRZ76JdyAkhfaEEqsMiONWVQ0bn8YqzZg9oHC4Qfko,3436
54
+ flwr/client/supernode/app.py,sha256=gauvN8elkIy0vuT0GxT7MmkuBRY74ckZfpxejE7dduM,3861
55
55
  flwr/client/typing.py,sha256=c9EvjlEjasxn1Wqx6bGl6Xg6vM1gMFfmXht-E2i5J-k,1006
56
56
  flwr/common/__init__.py,sha256=dHOptgKxna78CEQLD5Yu0QIsoSgpIIw5AhIUZCHDWAU,3721
57
57
  flwr/common/address.py,sha256=iTAN9jtmIGMrWFnx9XZQl45ZEtQJVZZLYPRBSNVARGI,1882
@@ -124,17 +124,16 @@ flwr/server/app.py,sha256=FriloRrkDHTlB5G7EBn6sH4v5GhiYFf_ZhbdROgjKbY,24199
124
124
  flwr/server/client_manager.py,sha256=T8UDSRJBVD3fyIDI7NTAA-NA7GPrMNNgH2OAF54RRxE,6127
125
125
  flwr/server/client_proxy.py,sha256=4G-oTwhb45sfWLx2uZdcXD98IZwdTS6F88xe3akCdUg,2399
126
126
  flwr/server/compat/__init__.py,sha256=VxnJtJyOjNFQXMNi9hIuzNlZM5n0Hj1p3aq_Pm2udw4,892
127
- flwr/server/compat/app.py,sha256=3Skh76Rg80B4oME1dJOhZvn9eTfVmTNIQ0jCiZ6CzeQ,5271
128
- flwr/server/compat/app_utils.py,sha256=GGmGApka7J9wHY2tiU_ZDejNvtfW_CZ9NZtb8L30M90,3496
129
- flwr/server/compat/driver_client_proxy.py,sha256=QWLl5YJwI6NVADwjQGQJqkLtCfPNT-aRH0NF9yeGEnA,7344
127
+ flwr/server/compat/app.py,sha256=BhF3DySbvKkOIyNXnB1rwZhw8cC8yK_w91Fku8HmC_w,5287
128
+ flwr/server/compat/app_utils.py,sha256=S-M4sGIiZPXXgKFLjlbFP2yN7d-oIj6DaiJNPIZ2z3A,3503
129
+ flwr/server/compat/driver_client_proxy.py,sha256=5XWroBrtA8MrQ5xQjgsju5RauMxNPshYLS_EtONEL1I,7370
130
130
  flwr/server/compat/legacy_context.py,sha256=D2s7PvQoDnTexuRmf1uG9Von7GUj4Qqyr7qLklSlKAM,1766
131
131
  flwr/server/criterion.py,sha256=ypbAexbztzGUxNen9RCHF91QeqiEQix4t4Ih3E-42MM,1061
132
- flwr/server/driver/__init__.py,sha256=yYyVX1FcDiDFM6rw0-DSZpuRy0EoWRfG9puwlQUswFA,820
133
- flwr/server/driver/abc_driver.py,sha256=t9SSSDlo9wT_y2Nl7waGYMTm2VlkvK3_bOb7ggPPlho,5090
134
- flwr/server/driver/driver.py,sha256=AwAxgYRx-FI6NvI5ukmdGlEmQRyp5GZSElFnDZhelj8,10106
135
- flwr/server/driver/grpc_driver.py,sha256=D2n3_Es_DHFgQsq_TjYVEz8RYJJJYoe24E1vozaTFiE,4586
132
+ flwr/server/driver/__init__.py,sha256=bbVL5pyA0Y2HcUK4s5U0B4epI-BuUFyEJbchew_8tJY,862
133
+ flwr/server/driver/driver.py,sha256=t9SSSDlo9wT_y2Nl7waGYMTm2VlkvK3_bOb7ggPPlho,5090
134
+ flwr/server/driver/grpc_driver.py,sha256=U5zfI3uYPUBaoOe4JI32t3dvCoSDacZ6EE0g9B8tKbU,11418
136
135
  flwr/server/history.py,sha256=hDsoBaA4kUa6d1yvDVXuLluBqOBKSm0_fVDtUtYJkmg,5121
137
- flwr/server/run_serverapp.py,sha256=3hoXa57T4L1vOWVWPSSdZ_UyRO-uTwUIrhha6TJAXMg,5592
136
+ flwr/server/run_serverapp.py,sha256=3FqKVdFJ280dOVQQ63fu3kL7yNg_4ggtx2H7ljSBT1c,5604
138
137
  flwr/server/server.py,sha256=UnBRlI6AGTj0nKeRtEQ3IalM3TJmggMKXhDyn8yKZNk,17664
139
138
  flwr/server/server_app.py,sha256=KgAT_HqsfseTLNnfX2ph42PBbVqQ0lFzvYrT90V34y0,4402
140
139
  flwr/server/server_config.py,sha256=CZaHVAsMvGLjpWVcLPkiYxgJN4xfIyAiUrCI3fETKY4,1349
@@ -205,9 +204,9 @@ flwr/simulation/ray_transport/__init__.py,sha256=FsaAnzC4cw4DqoouBCix6496k29jACk
205
204
  flwr/simulation/ray_transport/ray_actor.py,sha256=_wv2eP7qxkCZ-6rMyYWnjLrGPBZRxjvTPjaVk8zIaQ4,19367
206
205
  flwr/simulation/ray_transport/ray_client_proxy.py,sha256=oDu4sEPIOu39vrNi-fqDAe10xtNUXMO49bM2RWfRcyw,6738
207
206
  flwr/simulation/ray_transport/utils.py,sha256=TYdtfg1P9VfTdLMOJlifInGpxWHYs9UfUqIv2wfkRLA,2392
208
- flwr/simulation/run_simulation.py,sha256=HiIH6aa_v56NfKQN5ZBd94NyVfaZNyFs43_kItYsQXU,15685
209
- flwr_nightly-1.9.0.dev20240419.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
210
- flwr_nightly-1.9.0.dev20240419.dist-info/METADATA,sha256=W3tyRxj4LXms8QbNvSBIspZOouKU5DIz-UZ-UAiOsYw,15260
211
- flwr_nightly-1.9.0.dev20240419.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
212
- flwr_nightly-1.9.0.dev20240419.dist-info/entry_points.txt,sha256=DBrrf685V2W9NbbchQwvuqBEpj5ik8tMZNoZg_W2bZY,363
213
- flwr_nightly-1.9.0.dev20240419.dist-info/RECORD,,
207
+ flwr/simulation/run_simulation.py,sha256=nxXNv3r8ODImd5o6f0sa_w5L0I08LD2Udw2OTXStRnQ,15694
208
+ flwr_nightly-1.9.0.dev20240422.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
209
+ flwr_nightly-1.9.0.dev20240422.dist-info/METADATA,sha256=2g_AiXLNJzV4x9RNTWo1h1LjzMpUdhUQ8uNAPPxqlv8,15260
210
+ flwr_nightly-1.9.0.dev20240422.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
211
+ flwr_nightly-1.9.0.dev20240422.dist-info/entry_points.txt,sha256=DBrrf685V2W9NbbchQwvuqBEpj5ik8tMZNoZg_W2bZY,363
212
+ flwr_nightly-1.9.0.dev20240422.dist-info/RECORD,,
@@ -1,140 +0,0 @@
1
- # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Driver (abstract base class)."""
16
-
17
-
18
- from abc import ABC, abstractmethod
19
- from typing import Iterable, List, Optional
20
-
21
- from flwr.common import Message, RecordSet
22
-
23
-
24
- class Driver(ABC):
25
- """Abstract base Driver class for the Driver API."""
26
-
27
- @abstractmethod
28
- def create_message( # pylint: disable=too-many-arguments
29
- self,
30
- content: RecordSet,
31
- message_type: str,
32
- dst_node_id: int,
33
- group_id: str,
34
- ttl: Optional[float] = None,
35
- ) -> Message:
36
- """Create a new message with specified parameters.
37
-
38
- This method constructs a new `Message` with given content and metadata.
39
- The `run_id` and `src_node_id` will be set automatically.
40
-
41
- Parameters
42
- ----------
43
- content : RecordSet
44
- The content for the new message. This holds records that are to be sent
45
- to the destination node.
46
- message_type : str
47
- The type of the message, defining the action to be executed on
48
- the receiving end.
49
- dst_node_id : int
50
- The ID of the destination node to which the message is being sent.
51
- group_id : str
52
- The ID of the group to which this message is associated. In some settings,
53
- this is used as the FL round.
54
- ttl : Optional[float] (default: None)
55
- Time-to-live for the round trip of this message, i.e., the time from sending
56
- this message to receiving a reply. It specifies in seconds the duration for
57
- which the message and its potential reply are considered valid. If unset,
58
- the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
59
-
60
- Returns
61
- -------
62
- message : Message
63
- A new `Message` instance with the specified content and metadata.
64
- """
65
-
66
- @abstractmethod
67
- def get_node_ids(self) -> List[int]:
68
- """Get node IDs."""
69
-
70
- @abstractmethod
71
- def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
72
- """Push messages to specified node IDs.
73
-
74
- This method takes an iterable of messages and sends each message
75
- to the node specified in `dst_node_id`.
76
-
77
- Parameters
78
- ----------
79
- messages : Iterable[Message]
80
- An iterable of messages to be sent.
81
-
82
- Returns
83
- -------
84
- message_ids : Iterable[str]
85
- An iterable of IDs for the messages that were sent, which can be used
86
- to pull replies.
87
- """
88
-
89
- @abstractmethod
90
- def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
91
- """Pull messages based on message IDs.
92
-
93
- This method is used to collect messages from the SuperLink
94
- that correspond to a set of given message IDs.
95
-
96
- Parameters
97
- ----------
98
- message_ids : Iterable[str]
99
- An iterable of message IDs for which reply messages are to be retrieved.
100
-
101
- Returns
102
- -------
103
- messages : Iterable[Message]
104
- An iterable of messages received.
105
- """
106
-
107
- @abstractmethod
108
- def send_and_receive(
109
- self,
110
- messages: Iterable[Message],
111
- *,
112
- timeout: Optional[float] = None,
113
- ) -> Iterable[Message]:
114
- """Push messages to specified node IDs and pull the reply messages.
115
-
116
- This method sends a list of messages to their destination node IDs and then
117
- waits for the replies. It continues to pull replies until either all
118
- replies are received or the specified timeout duration is exceeded.
119
-
120
- Parameters
121
- ----------
122
- messages : Iterable[Message]
123
- An iterable of messages to be sent.
124
- timeout : Optional[float] (default: None)
125
- The timeout duration in seconds. If specified, the method will wait for
126
- replies for this duration. If `None`, there is no time limit and the method
127
- will wait until replies for all messages are received.
128
-
129
- Returns
130
- -------
131
- replies : Iterable[Message]
132
- An iterable of reply messages received from the SuperLink.
133
-
134
- Notes
135
- -----
136
- This method uses `push_messages` to send the messages and `pull_messages`
137
- to collect the replies. If `timeout` is set, the method may not return
138
- replies for all sent messages. A message remains valid until its TTL,
139
- which is not affected by `timeout`.
140
- """