flwr-nightly 1.8.0.dev20240327__py3-none-any.whl → 1.8.0.dev20240402__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package's registry page for more details.

Files changed (30):
  1. flwr/client/app.py +53 -29
  2. flwr/client/client_app.py +16 -0
  3. flwr/client/grpc_rere_client/connection.py +71 -29
  4. flwr/client/heartbeat.py +72 -0
  5. flwr/client/rest_client/connection.py +102 -28
  6. flwr/common/constant.py +20 -0
  7. flwr/common/logger.py +4 -4
  8. flwr/common/message.py +53 -14
  9. flwr/common/retry_invoker.py +24 -13
  10. flwr/proto/fleet_pb2.py +26 -26
  11. flwr/proto/fleet_pb2.pyi +5 -0
  12. flwr/server/compat/driver_client_proxy.py +16 -0
  13. flwr/server/driver/driver.py +15 -5
  14. flwr/server/server_app.py +3 -0
  15. flwr/server/superlink/fleet/message_handler/message_handler.py +3 -2
  16. flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -0
  17. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -4
  18. flwr/server/superlink/fleet/vce/vce_api.py +61 -27
  19. flwr/server/superlink/state/in_memory_state.py +25 -8
  20. flwr/server/superlink/state/sqlite_state.py +53 -5
  21. flwr/server/superlink/state/state.py +1 -1
  22. flwr/server/superlink/state/utils.py +56 -0
  23. flwr/server/workflow/default_workflows.py +1 -4
  24. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +0 -5
  25. flwr/simulation/ray_transport/ray_actor.py +8 -24
  26. {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/METADATA +1 -1
  27. {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/RECORD +30 -28
  28. {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/LICENSE +0 -0
  29. {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/WHEEL +0 -0
  30. {flwr_nightly-1.8.0.dev20240327.dist-info → flwr_nightly-1.8.0.dev20240402.dist-info}/entry_points.txt +0 -0
flwr/common/constant.py CHANGED
@@ -36,6 +36,13 @@ TRANSPORT_TYPES = [
36
36
  TRANSPORT_TYPE_VCE,
37
37
  ]
38
38
 
39
+ # Constants for ping
40
+ PING_DEFAULT_INTERVAL = 30
41
+ PING_CALL_TIMEOUT = 5
42
+ PING_BASE_MULTIPLIER = 0.8
43
+ PING_RANDOM_RANGE = (-0.1, 0.1)
44
+ PING_MAX_INTERVAL = 1e300
45
+
39
46
 
40
47
  class MessageType:
41
48
  """Message type."""
@@ -68,3 +75,16 @@ class SType:
68
75
  def __new__(cls) -> SType:
69
76
  """Prevent instantiation."""
70
77
  raise TypeError(f"{cls.__name__} cannot be instantiated.")
78
+
79
+
80
+ class ErrorCode:
81
+ """Error codes for Message's Error."""
82
+
83
+ UNKNOWN = 0
84
+ LOAD_CLIENT_APP_EXCEPTION = 1
85
+ CLIENT_APP_RAISED_EXCEPTION = 2
86
+ NODE_UNAVAILABLE = 3
87
+
88
+ def __new__(cls) -> ErrorCode:
89
+ """Prevent instantiation."""
90
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
flwr/common/logger.py CHANGED
@@ -164,13 +164,13 @@ logger = logging.getLogger(LOGGER_NAME) # pylint: disable=invalid-name
164
164
  log = logger.log # pylint: disable=invalid-name
165
165
 
166
166
 
167
- def warn_experimental_feature(name: str) -> None:
168
- """Warn the user when they use an experimental feature."""
167
+ def warn_preview_feature(name: str) -> None:
168
+ """Warn the user when they use a preview feature."""
169
169
  log(
170
170
  WARN,
171
- """EXPERIMENTAL FEATURE: %s
171
+ """PREVIEW FEATURE: %s
172
172
 
173
- This is an experimental feature. It could change significantly or be removed
173
+ This is a preview feature. It could change significantly or be removed
174
174
  entirely in future versions of Flower.
175
175
  """,
176
176
  name,
flwr/common/message.py CHANGED
@@ -17,6 +17,7 @@
17
17
  from __future__ import annotations
18
18
 
19
19
  import time
20
+ import warnings
20
21
  from dataclasses import dataclass
21
22
 
22
23
  from .record import RecordSet
@@ -297,22 +298,40 @@ class Message:
297
298
  partition_id=self.metadata.partition_id,
298
299
  )
299
300
 
300
- def create_error_reply(
301
- self,
302
- error: Error,
303
- ttl: float,
304
- ) -> Message:
301
+ def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
305
302
  """Construct a reply message indicating an error happened.
306
303
 
307
304
  Parameters
308
305
  ----------
309
306
  error : Error
310
307
  The error that was encountered.
311
- ttl : float
312
- Time-to-live for this message in seconds.
308
+ ttl : Optional[float] (default: None)
309
+ Time-to-live for this message in seconds. If unset, it will be set based
310
+ on the remaining time for the received message before it expires. This
311
+ follows the equation:
312
+
313
+ ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
313
314
  """
315
+ if ttl:
316
+ warnings.warn(
317
+ "A custom TTL was set, but note that the SuperLink does not enforce "
318
+ "the TTL yet. The SuperLink will start enforcing the TTL in a future "
319
+ "version of Flower.",
320
+ stacklevel=2,
321
+ )
322
+ # If no TTL passed, use default for message creation (will update after
323
+ # message creation)
324
+ ttl_ = DEFAULT_TTL if ttl is None else ttl
314
325
  # Create reply with error
315
- message = Message(metadata=self._create_reply_metadata(ttl), error=error)
326
+ message = Message(metadata=self._create_reply_metadata(ttl_), error=error)
327
+
328
+ if ttl is None:
329
+ # Set TTL equal to the remaining time for the received message to expire
330
+ ttl = self.metadata.ttl - (
331
+ message.metadata.created_at - self.metadata.created_at
332
+ )
333
+ message.metadata.ttl = ttl
334
+
316
335
  return message
317
336
 
318
337
  def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
@@ -327,18 +346,38 @@ class Message:
327
346
  content : RecordSet
328
347
  The content for the reply message.
329
348
  ttl : Optional[float] (default: None)
330
- Time-to-live for this message in seconds. If unset, it will use
331
- the `common.DEFAULT_TTL` value.
349
+ Time-to-live for this message in seconds. If unset, it will be set based
350
+ on the remaining time for the received message before it expires. This
351
+ follows the equation:
352
+
353
+ ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
332
354
 
333
355
  Returns
334
356
  -------
335
357
  Message
336
358
  A new `Message` instance representing the reply.
337
359
  """
338
- if ttl is None:
339
- ttl = DEFAULT_TTL
360
+ if ttl:
361
+ warnings.warn(
362
+ "A custom TTL was set, but note that the SuperLink does not enforce "
363
+ "the TTL yet. The SuperLink will start enforcing the TTL in a future "
364
+ "version of Flower.",
365
+ stacklevel=2,
366
+ )
367
+ # If no TTL passed, use default for message creation (will update after
368
+ # message creation)
369
+ ttl_ = DEFAULT_TTL if ttl is None else ttl
340
370
 
341
- return Message(
342
- metadata=self._create_reply_metadata(ttl),
371
+ message = Message(
372
+ metadata=self._create_reply_metadata(ttl_),
343
373
  content=content,
344
374
  )
375
+
376
+ if ttl is None:
377
+ # Set TTL equal to the remaining time for the received message to expire
378
+ ttl = self.metadata.ttl - (
379
+ message.metadata.created_at - self.metadata.created_at
380
+ )
381
+ message.metadata.ttl = ttl
382
+
383
+ return message
flwr/common/retry_invoker.py CHANGED
@@ -107,7 +107,7 @@ class RetryInvoker:
107
107
 
108
108
  Parameters
109
109
  ----------
110
- wait_factory: Callable[[], Generator[float, None, None]]
110
+ wait_gen_factory: Callable[[], Generator[float, None, None]]
111
111
  A generator yielding successive wait times in seconds. If the generator
112
112
  is finite, the giveup event will be triggered when the generator raises
113
113
  `StopIteration`.
@@ -129,12 +129,12 @@ class RetryInvoker:
129
129
  data class object detailing the invocation.
130
130
  on_giveup: Optional[Callable[[RetryState], None]] (default: None)
131
131
  A callable to be executed in the event that `max_tries` or `max_time` is
132
- exceeded, `should_giveup` returns True, or `wait_factory()` generator raises
132
+ exceeded, `should_giveup` returns True, or `wait_gen_factory()` generator raises
133
133
  `StopInteration`. The parameter is a data class object detailing the
134
134
  invocation.
135
135
  jitter: Optional[Callable[[float], float]] (default: full_jitter)
136
- A function of the value yielded by `wait_factory()` returning the actual time
137
- to wait. This function helps distribute wait times stochastically to avoid
136
+ A function of the value yielded by `wait_gen_factory()` returning the actual
137
+ time to wait. This function helps distribute wait times stochastically to avoid
138
138
  timing collisions across concurrent clients. Wait times are jittered by
139
139
  default using the `full_jitter` function. To disable jittering, pass
140
140
  `jitter=None`.
@@ -142,6 +142,13 @@ class RetryInvoker:
142
142
  A function accepting an exception instance, returning whether or not
143
143
  to give up prematurely before other give-up conditions are evaluated.
144
144
  If set to None, the strategy is to never give up prematurely.
145
+ wait_function: Optional[Callable[[float], None]] (default: None)
146
+ A function that defines how to wait between retry attempts. It accepts
147
+ one argument, the wait time in seconds, allowing the use of various waiting
148
+ mechanisms (e.g., asynchronous waits or event-based synchronization) suitable
149
+ for different execution environments. If set to `None`, the `wait_function`
150
+ defaults to `time.sleep`, which is ideal for synchronous operations. Custom
151
+ functions should manage execution flow to prevent blocking or interference.
145
152
 
146
153
  Examples
147
154
  --------
@@ -159,7 +166,7 @@ class RetryInvoker:
159
166
  # pylint: disable-next=too-many-arguments
160
167
  def __init__(
161
168
  self,
162
- wait_factory: Callable[[], Generator[float, None, None]],
169
+ wait_gen_factory: Callable[[], Generator[float, None, None]],
163
170
  recoverable_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]],
164
171
  max_tries: Optional[int],
165
172
  max_time: Optional[float],
@@ -169,8 +176,9 @@ class RetryInvoker:
169
176
  on_giveup: Optional[Callable[[RetryState], None]] = None,
170
177
  jitter: Optional[Callable[[float], float]] = full_jitter,
171
178
  should_giveup: Optional[Callable[[Exception], bool]] = None,
179
+ wait_function: Optional[Callable[[float], None]] = None,
172
180
  ) -> None:
173
- self.wait_factory = wait_factory
181
+ self.wait_gen_factory = wait_gen_factory
174
182
  self.recoverable_exceptions = recoverable_exceptions
175
183
  self.max_tries = max_tries
176
184
  self.max_time = max_time
@@ -179,6 +187,9 @@ class RetryInvoker:
179
187
  self.on_giveup = on_giveup
180
188
  self.jitter = jitter
181
189
  self.should_giveup = should_giveup
190
+ if wait_function is None:
191
+ wait_function = time.sleep
192
+ self.wait_function = wait_function
182
193
 
183
194
  # pylint: disable-next=too-many-locals
184
195
  def invoke(
@@ -212,13 +223,13 @@ class RetryInvoker:
212
223
  Raises
213
224
  ------
214
225
  Exception
215
- If the number of tries exceeds `max_tries`, if the total time
216
- exceeds `max_time`, if `wait_factory()` generator raises `StopInteration`,
226
+ If the number of tries exceeds `max_tries`, if the total time exceeds
227
+ `max_time`, if `wait_gen_factory()` generator raises `StopInteration`,
217
228
  or if the `should_giveup` returns True for a raised exception.
218
229
 
219
230
  Notes
220
231
  -----
221
- The time between retries is determined by the provided `wait_factory()`
232
+ The time between retries is determined by the provided `wait_gen_factory()`
222
233
  generator and can optionally be jittered using the `jitter` function.
223
234
  The recoverable exceptions that trigger a retry, as well as conditions to
224
235
  stop retries, are also determined by the class's initialization parameters.
@@ -231,13 +242,13 @@ class RetryInvoker:
231
242
  handler(cast(RetryState, ref_state[0]))
232
243
 
233
244
  try_cnt = 0
234
- wait_generator = self.wait_factory()
235
- start = time.time()
245
+ wait_generator = self.wait_gen_factory()
246
+ start = time.monotonic()
236
247
  ref_state: List[Optional[RetryState]] = [None]
237
248
 
238
249
  while True:
239
250
  try_cnt += 1
240
- elapsed_time = time.time() - start
251
+ elapsed_time = time.monotonic() - start
241
252
  state = RetryState(
242
253
  target=target,
243
254
  args=args,
@@ -282,7 +293,7 @@ class RetryInvoker:
282
293
  try_call_event_handler(self.on_backoff)
283
294
 
284
295
  # Sleep
285
- time.sleep(wait_time)
296
+ self.wait_function(state.actual_wait)
286
297
  else:
287
298
  # Trigger success event
288
299
  try_call_event_handler(self.on_success)
flwr/proto/fleet_pb2.py CHANGED
@@ -16,7 +16,7 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
16
  from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2
17
17
 
18
18
 
19
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x86\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3')
19
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x86\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3')
20
20
 
21
21
  _globals = globals()
22
22
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -26,29 +26,29 @@ if _descriptor._USE_C_DESCRIPTORS == False:
26
26
  _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._options = None
27
27
  _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001'
28
28
  _globals['_CREATENODEREQUEST']._serialized_start=84
29
- _globals['_CREATENODEREQUEST']._serialized_end=103
30
- _globals['_CREATENODERESPONSE']._serialized_start=105
31
- _globals['_CREATENODERESPONSE']._serialized_end=157
32
- _globals['_DELETENODEREQUEST']._serialized_start=159
33
- _globals['_DELETENODEREQUEST']._serialized_end=210
34
- _globals['_DELETENODERESPONSE']._serialized_start=212
35
- _globals['_DELETENODERESPONSE']._serialized_end=232
36
- _globals['_PINGREQUEST']._serialized_start=234
37
- _globals['_PINGREQUEST']._serialized_end=302
38
- _globals['_PINGRESPONSE']._serialized_start=304
39
- _globals['_PINGRESPONSE']._serialized_end=335
40
- _globals['_PULLTASKINSREQUEST']._serialized_start=337
41
- _globals['_PULLTASKINSREQUEST']._serialized_end=407
42
- _globals['_PULLTASKINSRESPONSE']._serialized_start=409
43
- _globals['_PULLTASKINSRESPONSE']._serialized_end=516
44
- _globals['_PUSHTASKRESREQUEST']._serialized_start=518
45
- _globals['_PUSHTASKRESREQUEST']._serialized_end=582
46
- _globals['_PUSHTASKRESRESPONSE']._serialized_start=585
47
- _globals['_PUSHTASKRESRESPONSE']._serialized_end=759
48
- _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=713
49
- _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=759
50
- _globals['_RECONNECT']._serialized_start=761
51
- _globals['_RECONNECT']._serialized_end=791
52
- _globals['_FLEET']._serialized_start=794
53
- _globals['_FLEET']._serialized_end=1184
29
+ _globals['_CREATENODEREQUEST']._serialized_end=126
30
+ _globals['_CREATENODERESPONSE']._serialized_start=128
31
+ _globals['_CREATENODERESPONSE']._serialized_end=180
32
+ _globals['_DELETENODEREQUEST']._serialized_start=182
33
+ _globals['_DELETENODEREQUEST']._serialized_end=233
34
+ _globals['_DELETENODERESPONSE']._serialized_start=235
35
+ _globals['_DELETENODERESPONSE']._serialized_end=255
36
+ _globals['_PINGREQUEST']._serialized_start=257
37
+ _globals['_PINGREQUEST']._serialized_end=325
38
+ _globals['_PINGRESPONSE']._serialized_start=327
39
+ _globals['_PINGRESPONSE']._serialized_end=358
40
+ _globals['_PULLTASKINSREQUEST']._serialized_start=360
41
+ _globals['_PULLTASKINSREQUEST']._serialized_end=430
42
+ _globals['_PULLTASKINSRESPONSE']._serialized_start=432
43
+ _globals['_PULLTASKINSRESPONSE']._serialized_end=539
44
+ _globals['_PUSHTASKRESREQUEST']._serialized_start=541
45
+ _globals['_PUSHTASKRESREQUEST']._serialized_end=605
46
+ _globals['_PUSHTASKRESRESPONSE']._serialized_start=608
47
+ _globals['_PUSHTASKRESRESPONSE']._serialized_end=782
48
+ _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=736
49
+ _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=782
50
+ _globals['_RECONNECT']._serialized_start=784
51
+ _globals['_RECONNECT']._serialized_end=814
52
+ _globals['_FLEET']._serialized_start=817
53
+ _globals['_FLEET']._serialized_end=1207
54
54
  # @@protoc_insertion_point(module_scope)
flwr/proto/fleet_pb2.pyi CHANGED
@@ -16,8 +16,13 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
16
16
  class CreateNodeRequest(google.protobuf.message.Message):
17
17
  """CreateNode messages"""
18
18
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
19
+ PING_INTERVAL_FIELD_NUMBER: builtins.int
20
+ ping_interval: builtins.float
19
21
  def __init__(self,
22
+ *,
23
+ ping_interval: builtins.float = ...,
20
24
  ) -> None: ...
25
+ def ClearField(self, field_name: typing_extensions.Literal["ping_interval",b"ping_interval"]) -> None: ...
21
26
  global___CreateNodeRequest = CreateNodeRequest
22
27
 
23
28
  class CreateNodeResponse(google.protobuf.message.Message):
flwr/server/compat/driver_client_proxy.py CHANGED
@@ -170,8 +170,24 @@ class DriverClientProxy(ClientProxy):
170
170
  )
171
171
  if len(task_res_list) == 1:
172
172
  task_res = task_res_list[0]
173
+
174
+ # This will raise an Exception if task_res carries an `error`
175
+ validate_task_res(task_res=task_res)
176
+
173
177
  return serde.recordset_from_proto(task_res.task.recordset)
174
178
 
175
179
  if timeout is not None and time.time() > start_time + timeout:
176
180
  raise RuntimeError("Timeout reached")
177
181
  time.sleep(SLEEP_TIME)
182
+
183
+
184
+ def validate_task_res(
185
+ task_res: task_pb2.TaskRes, # pylint: disable=E1101
186
+ ) -> None:
187
+ """Validate if a TaskRes is empty or not."""
188
+ if not task_res.HasField("task"):
189
+ raise ValueError("Invalid TaskRes, field `task` missing")
190
+ if task_res.task.HasField("error"):
191
+ raise ValueError("Exception during client-side task execution")
192
+ if not task_res.task.HasField("recordset"):
193
+ raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
flwr/server/driver/driver.py CHANGED
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  """Flower driver service client."""
16
16
 
17
-
18
17
  import time
18
+ import warnings
19
19
  from typing import Iterable, List, Optional, Tuple
20
20
 
21
21
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
@@ -91,7 +91,7 @@ class Driver:
91
91
  message_type: str,
92
92
  dst_node_id: int,
93
93
  group_id: str,
94
- ttl: float = DEFAULT_TTL,
94
+ ttl: Optional[float] = None,
95
95
  ) -> Message:
96
96
  """Create a new message with specified parameters.
97
97
 
@@ -111,10 +111,11 @@ class Driver:
111
111
  group_id : str
112
112
  The ID of the group to which this message is associated. In some settings,
113
113
  this is used as the FL round.
114
- ttl : float (default: common.DEFAULT_TTL)
114
+ ttl : Optional[float] (default: None)
115
115
  Time-to-live for the round trip of this message, i.e., the time from sending
116
116
  this message to receiving a reply. It specifies in seconds the duration for
117
- which the message and its potential reply are considered valid.
117
+ which the message and its potential reply are considered valid. If unset,
118
+ the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
118
119
 
119
120
  Returns
120
121
  -------
@@ -122,6 +123,15 @@ class Driver:
122
123
  A new `Message` instance with the specified content and metadata.
123
124
  """
124
125
  _, run_id = self._get_grpc_driver_and_run_id()
126
+ if ttl:
127
+ warnings.warn(
128
+ "A custom TTL was set, but note that the SuperLink does not enforce "
129
+ "the TTL yet. The SuperLink will start enforcing the TTL in a future "
130
+ "version of Flower.",
131
+ stacklevel=2,
132
+ )
133
+
134
+ ttl_ = DEFAULT_TTL if ttl is None else ttl
125
135
  metadata = Metadata(
126
136
  run_id=run_id,
127
137
  message_id="", # Will be set by the server
@@ -129,7 +139,7 @@ class Driver:
129
139
  dst_node_id=dst_node_id,
130
140
  reply_to_message="",
131
141
  group_id=group_id,
132
- ttl=ttl,
142
+ ttl=ttl_,
133
143
  message_type=message_type,
134
144
  )
135
145
  return Message(metadata=metadata, content=content)
flwr/server/server_app.py CHANGED
@@ -18,6 +18,7 @@
18
18
  from typing import Callable, Optional
19
19
 
20
20
  from flwr.common import Context, RecordSet
21
+ from flwr.common.logger import warn_preview_feature
21
22
  from flwr.server.strategy import Strategy
22
23
 
23
24
  from .client_manager import ClientManager
@@ -120,6 +121,8 @@ class ServerApp:
120
121
  """,
121
122
  )
122
123
 
124
+ warn_preview_feature("ServerApp-register-main-function")
125
+
123
126
  # Register provided function with the ServerApp object
124
127
  self._main = main_fn
125
128
 
flwr/server/superlink/fleet/message_handler/message_handler.py CHANGED
@@ -43,7 +43,7 @@ def create_node(
43
43
  ) -> CreateNodeResponse:
44
44
  """."""
45
45
  # Create node
46
- node_id = state.create_node()
46
+ node_id = state.create_node(ping_interval=request.ping_interval)
47
47
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
48
48
 
49
49
 
@@ -63,7 +63,8 @@ def ping(
63
63
  state: State, # pylint: disable=unused-argument
64
64
  ) -> PingResponse:
65
65
  """."""
66
- return PingResponse(success=True)
66
+ res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
67
+ return PingResponse(success=res)
67
68
 
68
69
 
69
70
  def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
flwr/server/superlink/fleet/rest_rere/rest_api.py CHANGED
@@ -21,6 +21,7 @@ from flwr.common.constant import MISSING_EXTRA_REST
21
21
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
22
22
  CreateNodeRequest,
23
23
  DeleteNodeRequest,
24
+ PingRequest,
24
25
  PullTaskInsRequest,
25
26
  PushTaskResRequest,
26
27
  )
@@ -152,11 +153,38 @@ async def push_task_res(request: Request) -> Response: # Check if token is need
152
153
  )
153
154
 
154
155
 
156
+ async def ping(request: Request) -> Response:
157
+ """Ping."""
158
+ _check_headers(request.headers)
159
+
160
+ # Get the request body as raw bytes
161
+ ping_request_bytes: bytes = await request.body()
162
+
163
+ # Deserialize ProtoBuf
164
+ ping_request_proto = PingRequest()
165
+ ping_request_proto.ParseFromString(ping_request_bytes)
166
+
167
+ # Get state from app
168
+ state: State = app.state.STATE_FACTORY.state()
169
+
170
+ # Handle message
171
+ ping_response_proto = message_handler.ping(request=ping_request_proto, state=state)
172
+
173
+ # Return serialized ProtoBuf
174
+ ping_response_bytes = ping_response_proto.SerializeToString()
175
+ return Response(
176
+ status_code=200,
177
+ content=ping_response_bytes,
178
+ headers={"Content-Type": "application/protobuf"},
179
+ )
180
+
181
+
155
182
  routes = [
156
183
  Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
157
184
  Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
158
185
  Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
159
186
  Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
187
+ Route("/api/v0/fleet/ping", ping, methods=["POST"]),
160
188
  ]
161
189
 
162
190
  app: Starlette = Starlette(
flwr/server/superlink/fleet/vce/backend/raybackend.py CHANGED
@@ -20,7 +20,7 @@ from typing import Callable, Dict, List, Tuple, Union
20
20
 
21
21
  import ray
22
22
 
23
- from flwr.client.client_app import ClientApp, LoadClientAppError
23
+ from flwr.client.client_app import ClientApp
24
24
  from flwr.common.context import Context
25
25
  from flwr.common.logger import log
26
26
  from flwr.common.message import Message
@@ -151,7 +151,6 @@ class RayBackend(Backend):
151
151
  )
152
152
 
153
153
  await future
154
-
155
154
  # Fetch result
156
155
  (
157
156
  out_mssg,
@@ -160,13 +159,15 @@ class RayBackend(Backend):
160
159
 
161
160
  return out_mssg, updated_context
162
161
 
163
- except LoadClientAppError as load_ex:
162
+ except Exception as ex:
164
163
  log(
165
164
  ERROR,
166
165
  "An exception was raised when processing a message by %s",
167
166
  self.__class__.__name__,
168
167
  )
169
- raise load_ex
168
+ # add actor back into pool
169
+ await self.pool.add_actor_back_to_pool(future)
170
+ raise ex
170
171
 
171
172
  async def terminate(self) -> None:
172
173
  """Terminate all actors in actor pool."""