flwr-nightly 1.8.0.dev20240328__py3-none-any.whl → 1.9.0.dev20240404__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (31)
  1. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
  2. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  3. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +1 -1
  4. flwr/client/app.py +26 -13
  5. flwr/client/client_app.py +16 -0
  6. flwr/client/grpc_rere_client/connection.py +71 -29
  7. flwr/client/heartbeat.py +72 -0
  8. flwr/client/rest_client/connection.py +102 -28
  9. flwr/common/constant.py +20 -0
  10. flwr/common/logger.py +4 -4
  11. flwr/common/message.py +15 -0
  12. flwr/common/retry_invoker.py +24 -13
  13. flwr/proto/fleet_pb2.py +26 -26
  14. flwr/proto/fleet_pb2.pyi +5 -0
  15. flwr/server/driver/driver.py +15 -5
  16. flwr/server/server_app.py +3 -0
  17. flwr/server/superlink/fleet/message_handler/message_handler.py +3 -2
  18. flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -0
  19. flwr/server/superlink/fleet/vce/vce_api.py +22 -4
  20. flwr/server/superlink/state/in_memory_state.py +25 -8
  21. flwr/server/superlink/state/sqlite_state.py +53 -5
  22. flwr/server/superlink/state/state.py +1 -1
  23. flwr/server/superlink/state/utils.py +56 -0
  24. flwr/server/workflow/default_workflows.py +1 -4
  25. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +0 -5
  26. flwr/simulation/ray_transport/ray_actor.py +2 -22
  27. {flwr_nightly-1.8.0.dev20240328.dist-info → flwr_nightly-1.9.0.dev20240404.dist-info}/METADATA +1 -1
  28. {flwr_nightly-1.8.0.dev20240328.dist-info → flwr_nightly-1.9.0.dev20240404.dist-info}/RECORD +31 -29
  29. {flwr_nightly-1.8.0.dev20240328.dist-info → flwr_nightly-1.9.0.dev20240404.dist-info}/LICENSE +0 -0
  30. {flwr_nightly-1.8.0.dev20240328.dist-info → flwr_nightly-1.9.0.dev20240404.dist-info}/WHEEL +0 -0
  31. {flwr_nightly-1.8.0.dev20240328.dist-info → flwr_nightly-1.9.0.dev20240404.dist-info}/entry_points.txt +0 -0
flwr/common/constant.py CHANGED
@@ -36,6 +36,13 @@ TRANSPORT_TYPES = [
36
36
  TRANSPORT_TYPE_VCE,
37
37
  ]
38
38
 
39
+ # Constants for ping
40
+ PING_DEFAULT_INTERVAL = 30
41
+ PING_CALL_TIMEOUT = 5
42
+ PING_BASE_MULTIPLIER = 0.8
43
+ PING_RANDOM_RANGE = (-0.1, 0.1)
44
+ PING_MAX_INTERVAL = 1e300
45
+
39
46
 
40
47
  class MessageType:
41
48
  """Message type."""
@@ -68,3 +75,16 @@ class SType:
68
75
  def __new__(cls) -> SType:
69
76
  """Prevent instantiation."""
70
77
  raise TypeError(f"{cls.__name__} cannot be instantiated.")
78
+
79
+
80
+ class ErrorCode:
81
+ """Error codes for Message's Error."""
82
+
83
+ UNKNOWN = 0
84
+ LOAD_CLIENT_APP_EXCEPTION = 1
85
+ CLIENT_APP_RAISED_EXCEPTION = 2
86
+ NODE_UNAVAILABLE = 3
87
+
88
+ def __new__(cls) -> ErrorCode:
89
+ """Prevent instantiation."""
90
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
flwr/common/logger.py CHANGED
@@ -164,13 +164,13 @@ logger = logging.getLogger(LOGGER_NAME) # pylint: disable=invalid-name
164
164
  log = logger.log # pylint: disable=invalid-name
165
165
 
166
166
 
167
- def warn_experimental_feature(name: str) -> None:
168
- """Warn the user when they use an experimental feature."""
167
+ def warn_preview_feature(name: str) -> None:
168
+ """Warn the user when they use a preview feature."""
169
169
  log(
170
170
  WARN,
171
- """EXPERIMENTAL FEATURE: %s
171
+ """PREVIEW FEATURE: %s
172
172
 
173
- This is an experimental feature. It could change significantly or be removed
173
+ This is a preview feature. It could change significantly or be removed
174
174
  entirely in future versions of Flower.
175
175
  """,
176
176
  name,
flwr/common/message.py CHANGED
@@ -17,6 +17,7 @@
17
17
  from __future__ import annotations
18
18
 
19
19
  import time
20
+ import warnings
20
21
  from dataclasses import dataclass
21
22
 
22
23
  from .record import RecordSet
@@ -311,6 +312,13 @@ class Message:
311
312
 
312
313
  ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
313
314
  """
315
+ if ttl:
316
+ warnings.warn(
317
+ "A custom TTL was set, but note that the SuperLink does not enforce "
318
+ "the TTL yet. The SuperLink will start enforcing the TTL in a future "
319
+ "version of Flower.",
320
+ stacklevel=2,
321
+ )
314
322
  # If no TTL passed, use default for message creation (will update after
315
323
  # message creation)
316
324
  ttl_ = DEFAULT_TTL if ttl is None else ttl
@@ -349,6 +357,13 @@ class Message:
349
357
  Message
350
358
  A new `Message` instance representing the reply.
351
359
  """
360
+ if ttl:
361
+ warnings.warn(
362
+ "A custom TTL was set, but note that the SuperLink does not enforce "
363
+ "the TTL yet. The SuperLink will start enforcing the TTL in a future "
364
+ "version of Flower.",
365
+ stacklevel=2,
366
+ )
352
367
  # If no TTL passed, use default for message creation (will update after
353
368
  # message creation)
354
369
  ttl_ = DEFAULT_TTL if ttl is None else ttl
flwr/common/retry_invoker.py CHANGED
@@ -107,7 +107,7 @@ class RetryInvoker:
107
107
 
108
108
  Parameters
109
109
  ----------
110
- wait_factory: Callable[[], Generator[float, None, None]]
110
+ wait_gen_factory: Callable[[], Generator[float, None, None]]
111
111
  A generator yielding successive wait times in seconds. If the generator
112
112
  is finite, the giveup event will be triggered when the generator raises
113
113
  `StopIteration`.
@@ -129,12 +129,12 @@ class RetryInvoker:
129
129
  data class object detailing the invocation.
130
130
  on_giveup: Optional[Callable[[RetryState], None]] (default: None)
131
131
  A callable to be executed in the event that `max_tries` or `max_time` is
132
- exceeded, `should_giveup` returns True, or `wait_factory()` generator raises
132
+ exceeded, `should_giveup` returns True, or `wait_gen_factory()` generator raises
133
133
  `StopInteration`. The parameter is a data class object detailing the
134
134
  invocation.
135
135
  jitter: Optional[Callable[[float], float]] (default: full_jitter)
136
- A function of the value yielded by `wait_factory()` returning the actual time
137
- to wait. This function helps distribute wait times stochastically to avoid
136
+ A function of the value yielded by `wait_gen_factory()` returning the actual
137
+ time to wait. This function helps distribute wait times stochastically to avoid
138
138
  timing collisions across concurrent clients. Wait times are jittered by
139
139
  default using the `full_jitter` function. To disable jittering, pass
140
140
  `jitter=None`.
@@ -142,6 +142,13 @@ class RetryInvoker:
142
142
  A function accepting an exception instance, returning whether or not
143
143
  to give up prematurely before other give-up conditions are evaluated.
144
144
  If set to None, the strategy is to never give up prematurely.
145
+ wait_function: Optional[Callable[[float], None]] (default: None)
146
+ A function that defines how to wait between retry attempts. It accepts
147
+ one argument, the wait time in seconds, allowing the use of various waiting
148
+ mechanisms (e.g., asynchronous waits or event-based synchronization) suitable
149
+ for different execution environments. If set to `None`, the `wait_function`
150
+ defaults to `time.sleep`, which is ideal for synchronous operations. Custom
151
+ functions should manage execution flow to prevent blocking or interference.
145
152
 
146
153
  Examples
147
154
  --------
@@ -159,7 +166,7 @@ class RetryInvoker:
159
166
  # pylint: disable-next=too-many-arguments
160
167
  def __init__(
161
168
  self,
162
- wait_factory: Callable[[], Generator[float, None, None]],
169
+ wait_gen_factory: Callable[[], Generator[float, None, None]],
163
170
  recoverable_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]],
164
171
  max_tries: Optional[int],
165
172
  max_time: Optional[float],
@@ -169,8 +176,9 @@ class RetryInvoker:
169
176
  on_giveup: Optional[Callable[[RetryState], None]] = None,
170
177
  jitter: Optional[Callable[[float], float]] = full_jitter,
171
178
  should_giveup: Optional[Callable[[Exception], bool]] = None,
179
+ wait_function: Optional[Callable[[float], None]] = None,
172
180
  ) -> None:
173
- self.wait_factory = wait_factory
181
+ self.wait_gen_factory = wait_gen_factory
174
182
  self.recoverable_exceptions = recoverable_exceptions
175
183
  self.max_tries = max_tries
176
184
  self.max_time = max_time
@@ -179,6 +187,9 @@ class RetryInvoker:
179
187
  self.on_giveup = on_giveup
180
188
  self.jitter = jitter
181
189
  self.should_giveup = should_giveup
190
+ if wait_function is None:
191
+ wait_function = time.sleep
192
+ self.wait_function = wait_function
182
193
 
183
194
  # pylint: disable-next=too-many-locals
184
195
  def invoke(
@@ -212,13 +223,13 @@ class RetryInvoker:
212
223
  Raises
213
224
  ------
214
225
  Exception
215
- If the number of tries exceeds `max_tries`, if the total time
216
- exceeds `max_time`, if `wait_factory()` generator raises `StopInteration`,
226
+ If the number of tries exceeds `max_tries`, if the total time exceeds
227
+ `max_time`, if `wait_gen_factory()` generator raises `StopInteration`,
217
228
  or if the `should_giveup` returns True for a raised exception.
218
229
 
219
230
  Notes
220
231
  -----
221
- The time between retries is determined by the provided `wait_factory()`
232
+ The time between retries is determined by the provided `wait_gen_factory()`
222
233
  generator and can optionally be jittered using the `jitter` function.
223
234
  The recoverable exceptions that trigger a retry, as well as conditions to
224
235
  stop retries, are also determined by the class's initialization parameters.
@@ -231,13 +242,13 @@ class RetryInvoker:
231
242
  handler(cast(RetryState, ref_state[0]))
232
243
 
233
244
  try_cnt = 0
234
- wait_generator = self.wait_factory()
235
- start = time.time()
245
+ wait_generator = self.wait_gen_factory()
246
+ start = time.monotonic()
236
247
  ref_state: List[Optional[RetryState]] = [None]
237
248
 
238
249
  while True:
239
250
  try_cnt += 1
240
- elapsed_time = time.time() - start
251
+ elapsed_time = time.monotonic() - start
241
252
  state = RetryState(
242
253
  target=target,
243
254
  args=args,
@@ -282,7 +293,7 @@ class RetryInvoker:
282
293
  try_call_event_handler(self.on_backoff)
283
294
 
284
295
  # Sleep
285
- time.sleep(wait_time)
296
+ self.wait_function(state.actual_wait)
286
297
  else:
287
298
  # Trigger success event
288
299
  try_call_event_handler(self.on_success)
flwr/proto/fleet_pb2.py CHANGED
@@ -16,7 +16,7 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
16
  from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2
17
17
 
18
18
 
19
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x86\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3')
19
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x86\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3')
20
20
 
21
21
  _globals = globals()
22
22
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -26,29 +26,29 @@ if _descriptor._USE_C_DESCRIPTORS == False:
26
26
  _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._options = None
27
27
  _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001'
28
28
  _globals['_CREATENODEREQUEST']._serialized_start=84
29
- _globals['_CREATENODEREQUEST']._serialized_end=103
30
- _globals['_CREATENODERESPONSE']._serialized_start=105
31
- _globals['_CREATENODERESPONSE']._serialized_end=157
32
- _globals['_DELETENODEREQUEST']._serialized_start=159
33
- _globals['_DELETENODEREQUEST']._serialized_end=210
34
- _globals['_DELETENODERESPONSE']._serialized_start=212
35
- _globals['_DELETENODERESPONSE']._serialized_end=232
36
- _globals['_PINGREQUEST']._serialized_start=234
37
- _globals['_PINGREQUEST']._serialized_end=302
38
- _globals['_PINGRESPONSE']._serialized_start=304
39
- _globals['_PINGRESPONSE']._serialized_end=335
40
- _globals['_PULLTASKINSREQUEST']._serialized_start=337
41
- _globals['_PULLTASKINSREQUEST']._serialized_end=407
42
- _globals['_PULLTASKINSRESPONSE']._serialized_start=409
43
- _globals['_PULLTASKINSRESPONSE']._serialized_end=516
44
- _globals['_PUSHTASKRESREQUEST']._serialized_start=518
45
- _globals['_PUSHTASKRESREQUEST']._serialized_end=582
46
- _globals['_PUSHTASKRESRESPONSE']._serialized_start=585
47
- _globals['_PUSHTASKRESRESPONSE']._serialized_end=759
48
- _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=713
49
- _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=759
50
- _globals['_RECONNECT']._serialized_start=761
51
- _globals['_RECONNECT']._serialized_end=791
52
- _globals['_FLEET']._serialized_start=794
53
- _globals['_FLEET']._serialized_end=1184
29
+ _globals['_CREATENODEREQUEST']._serialized_end=126
30
+ _globals['_CREATENODERESPONSE']._serialized_start=128
31
+ _globals['_CREATENODERESPONSE']._serialized_end=180
32
+ _globals['_DELETENODEREQUEST']._serialized_start=182
33
+ _globals['_DELETENODEREQUEST']._serialized_end=233
34
+ _globals['_DELETENODERESPONSE']._serialized_start=235
35
+ _globals['_DELETENODERESPONSE']._serialized_end=255
36
+ _globals['_PINGREQUEST']._serialized_start=257
37
+ _globals['_PINGREQUEST']._serialized_end=325
38
+ _globals['_PINGRESPONSE']._serialized_start=327
39
+ _globals['_PINGRESPONSE']._serialized_end=358
40
+ _globals['_PULLTASKINSREQUEST']._serialized_start=360
41
+ _globals['_PULLTASKINSREQUEST']._serialized_end=430
42
+ _globals['_PULLTASKINSRESPONSE']._serialized_start=432
43
+ _globals['_PULLTASKINSRESPONSE']._serialized_end=539
44
+ _globals['_PUSHTASKRESREQUEST']._serialized_start=541
45
+ _globals['_PUSHTASKRESREQUEST']._serialized_end=605
46
+ _globals['_PUSHTASKRESRESPONSE']._serialized_start=608
47
+ _globals['_PUSHTASKRESRESPONSE']._serialized_end=782
48
+ _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=736
49
+ _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=782
50
+ _globals['_RECONNECT']._serialized_start=784
51
+ _globals['_RECONNECT']._serialized_end=814
52
+ _globals['_FLEET']._serialized_start=817
53
+ _globals['_FLEET']._serialized_end=1207
54
54
  # @@protoc_insertion_point(module_scope)
flwr/proto/fleet_pb2.pyi CHANGED
@@ -16,8 +16,13 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
16
16
  class CreateNodeRequest(google.protobuf.message.Message):
17
17
  """CreateNode messages"""
18
18
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
19
+ PING_INTERVAL_FIELD_NUMBER: builtins.int
20
+ ping_interval: builtins.float
19
21
  def __init__(self,
22
+ *,
23
+ ping_interval: builtins.float = ...,
20
24
  ) -> None: ...
25
+ def ClearField(self, field_name: typing_extensions.Literal["ping_interval",b"ping_interval"]) -> None: ...
21
26
  global___CreateNodeRequest = CreateNodeRequest
22
27
 
23
28
  class CreateNodeResponse(google.protobuf.message.Message):
flwr/server/driver/driver.py CHANGED
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  """Flower driver service client."""
16
16
 
17
-
18
17
  import time
18
+ import warnings
19
19
  from typing import Iterable, List, Optional, Tuple
20
20
 
21
21
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
@@ -91,7 +91,7 @@ class Driver:
91
91
  message_type: str,
92
92
  dst_node_id: int,
93
93
  group_id: str,
94
- ttl: float = DEFAULT_TTL,
94
+ ttl: Optional[float] = None,
95
95
  ) -> Message:
96
96
  """Create a new message with specified parameters.
97
97
 
@@ -111,10 +111,11 @@ class Driver:
111
111
  group_id : str
112
112
  The ID of the group to which this message is associated. In some settings,
113
113
  this is used as the FL round.
114
- ttl : float (default: common.DEFAULT_TTL)
114
+ ttl : Optional[float] (default: None)
115
115
  Time-to-live for the round trip of this message, i.e., the time from sending
116
116
  this message to receiving a reply. It specifies in seconds the duration for
117
- which the message and its potential reply are considered valid.
117
+ which the message and its potential reply are considered valid. If unset,
118
+ the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
118
119
 
119
120
  Returns
120
121
  -------
@@ -122,6 +123,15 @@ class Driver:
122
123
  A new `Message` instance with the specified content and metadata.
123
124
  """
124
125
  _, run_id = self._get_grpc_driver_and_run_id()
126
+ if ttl:
127
+ warnings.warn(
128
+ "A custom TTL was set, but note that the SuperLink does not enforce "
129
+ "the TTL yet. The SuperLink will start enforcing the TTL in a future "
130
+ "version of Flower.",
131
+ stacklevel=2,
132
+ )
133
+
134
+ ttl_ = DEFAULT_TTL if ttl is None else ttl
125
135
  metadata = Metadata(
126
136
  run_id=run_id,
127
137
  message_id="", # Will be set by the server
@@ -129,7 +139,7 @@ class Driver:
129
139
  dst_node_id=dst_node_id,
130
140
  reply_to_message="",
131
141
  group_id=group_id,
132
- ttl=ttl,
142
+ ttl=ttl_,
133
143
  message_type=message_type,
134
144
  )
135
145
  return Message(metadata=metadata, content=content)
flwr/server/server_app.py CHANGED
@@ -18,6 +18,7 @@
18
18
  from typing import Callable, Optional
19
19
 
20
20
  from flwr.common import Context, RecordSet
21
+ from flwr.common.logger import warn_preview_feature
21
22
  from flwr.server.strategy import Strategy
22
23
 
23
24
  from .client_manager import ClientManager
@@ -120,6 +121,8 @@ class ServerApp:
120
121
  """,
121
122
  )
122
123
 
124
+ warn_preview_feature("ServerApp-register-main-function")
125
+
123
126
  # Register provided function with the ServerApp object
124
127
  self._main = main_fn
125
128
 
flwr/server/superlink/fleet/message_handler/message_handler.py CHANGED
@@ -43,7 +43,7 @@ def create_node(
43
43
  ) -> CreateNodeResponse:
44
44
  """."""
45
45
  # Create node
46
- node_id = state.create_node()
46
+ node_id = state.create_node(ping_interval=request.ping_interval)
47
47
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
48
48
 
49
49
 
@@ -63,7 +63,8 @@ def ping(
63
63
  state: State, # pylint: disable=unused-argument
64
64
  ) -> PingResponse:
65
65
  """."""
66
- return PingResponse(success=True)
66
+ res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
67
+ return PingResponse(success=res)
67
68
 
68
69
 
69
70
  def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
flwr/server/superlink/fleet/rest_rere/rest_api.py CHANGED
@@ -21,6 +21,7 @@ from flwr.common.constant import MISSING_EXTRA_REST
21
21
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
22
22
  CreateNodeRequest,
23
23
  DeleteNodeRequest,
24
+ PingRequest,
24
25
  PullTaskInsRequest,
25
26
  PushTaskResRequest,
26
27
  )
@@ -152,11 +153,38 @@ async def push_task_res(request: Request) -> Response: # Check if token is need
152
153
  )
153
154
 
154
155
 
156
+ async def ping(request: Request) -> Response:
157
+ """Ping."""
158
+ _check_headers(request.headers)
159
+
160
+ # Get the request body as raw bytes
161
+ ping_request_bytes: bytes = await request.body()
162
+
163
+ # Deserialize ProtoBuf
164
+ ping_request_proto = PingRequest()
165
+ ping_request_proto.ParseFromString(ping_request_bytes)
166
+
167
+ # Get state from app
168
+ state: State = app.state.STATE_FACTORY.state()
169
+
170
+ # Handle message
171
+ ping_response_proto = message_handler.ping(request=ping_request_proto, state=state)
172
+
173
+ # Return serialized ProtoBuf
174
+ ping_response_bytes = ping_response_proto.SerializeToString()
175
+ return Response(
176
+ status_code=200,
177
+ content=ping_response_bytes,
178
+ headers={"Content-Type": "application/protobuf"},
179
+ )
180
+
181
+
155
182
  routes = [
156
183
  Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
157
184
  Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
158
185
  Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
159
186
  Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
187
+ Route("/api/v0/fleet/ping", ping, methods=["POST"]),
160
188
  ]
161
189
 
162
190
  app: Starlette = Starlette(
flwr/server/superlink/fleet/vce/vce_api.py CHANGED
@@ -22,8 +22,9 @@ import traceback
22
22
  from logging import DEBUG, ERROR, INFO, WARN
23
23
  from typing import Callable, Dict, List, Optional
24
24
 
25
- from flwr.client.client_app import ClientApp, LoadClientAppError
25
+ from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
26
26
  from flwr.client.node_state import NodeState
27
+ from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode
27
28
  from flwr.common.logger import log
28
29
  from flwr.common.message import Error
29
30
  from flwr.common.object_ref import load_app
@@ -43,7 +44,7 @@ def _register_nodes(
43
44
  nodes_mapping: NodeToPartitionMapping = {}
44
45
  state = state_factory.state()
45
46
  for i in range(num_nodes):
46
- node_id = state.create_node()
47
+ node_id = state.create_node(ping_interval=PING_MAX_INTERVAL)
47
48
  nodes_mapping[node_id] = i
48
49
  log(INFO, "Registered %i nodes", len(nodes_mapping))
49
50
  return nodes_mapping
@@ -93,9 +94,18 @@ async def worker(
93
94
  except Exception as ex: # pylint: disable=broad-exception-caught
94
95
  log(ERROR, ex)
95
96
  log(ERROR, traceback.format_exc())
97
+
98
+ if isinstance(ex, ClientAppException):
99
+ e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
100
+ elif isinstance(ex, LoadClientAppError):
101
+ e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
102
+ else:
103
+ e_code = ErrorCode.UNKNOWN
104
+
96
105
  reason = str(type(ex)) + ":<'" + str(ex) + "'>"
97
- error = Error(code=0, reason=reason)
98
- out_mssg = message.create_error_reply(error=error)
106
+ out_mssg = message.create_error_reply(
107
+ error=Error(code=e_code, reason=reason)
108
+ )
99
109
 
100
110
  finally:
101
111
  if out_mssg:
@@ -223,6 +233,7 @@ async def run(
223
233
 
224
234
 
225
235
  # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches
236
+ # pylint: disable=too-many-statements
226
237
  def start_vce(
227
238
  backend_name: str,
228
239
  backend_config_json_stream: str,
@@ -341,6 +352,13 @@ def start_vce(
341
352
  )
342
353
  )
343
354
  except LoadClientAppError as loadapp_ex:
355
+ f_stop_delay = 10
356
+ log(
357
+ ERROR,
358
+ "LoadClientAppError exception encountered. Terminating simulation in %is",
359
+ f_stop_delay,
360
+ )
361
+ time.sleep(f_stop_delay)
344
362
  f_stop.set() # set termination event
345
363
  raise loadapp_ex
346
364
  except Exception as ex:
flwr/server/superlink/state/in_memory_state.py CHANGED
@@ -27,6 +27,8 @@ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
27
27
  from flwr.server.superlink.state.state import State
28
28
  from flwr.server.utils import validate_task_ins_or_res
29
29
 
30
+ from .utils import make_node_unavailable_taskres
31
+
30
32
 
31
33
  class InMemoryState(State):
32
34
  """In-memory State implementation."""
@@ -129,15 +131,32 @@ class InMemoryState(State):
129
131
  with self.lock:
130
132
  # Find TaskRes that were not delivered yet
131
133
  task_res_list: List[TaskRes] = []
134
+ replied_task_ids: Set[UUID] = set()
132
135
  for _, task_res in self.task_res_store.items():
133
- if (
134
- UUID(task_res.task.ancestry[0]) in task_ids
135
- and task_res.task.delivered_at == ""
136
- ):
136
+ reply_to = UUID(task_res.task.ancestry[0])
137
+ if reply_to in task_ids and task_res.task.delivered_at == "":
137
138
  task_res_list.append(task_res)
139
+ replied_task_ids.add(reply_to)
138
140
  if limit and len(task_res_list) == limit:
139
141
  break
140
142
 
143
+ # Check if the node is offline
144
+ for task_id in task_ids - replied_task_ids:
145
+ if limit and len(task_res_list) == limit:
146
+ break
147
+ task_ins = self.task_ins_store.get(task_id)
148
+ if task_ins is None:
149
+ continue
150
+ node_id = task_ins.task.consumer.node_id
151
+ online_until, _ = self.node_ids[node_id]
152
+ # Generate a TaskRes containing an error reply if the node is offline.
153
+ if online_until < time.time():
154
+ err_taskres = make_node_unavailable_taskres(
155
+ ref_taskins=task_ins,
156
+ )
157
+ self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
158
+ task_res_list.append(err_taskres)
159
+
141
160
  # Mark all of them as delivered
142
161
  delivered_at = now().isoformat()
143
162
  for task_res in task_res_list:
@@ -182,16 +201,14 @@ class InMemoryState(State):
182
201
  """
183
202
  return len(self.task_res_store)
184
203
 
185
- def create_node(self) -> int:
204
+ def create_node(self, ping_interval: float) -> int:
186
205
  """Create, store in state, and return `node_id`."""
187
206
  # Sample a random int64 as node_id
188
207
  node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
189
208
 
190
209
  with self.lock:
191
210
  if node_id not in self.node_ids:
192
- # Default ping interval is 30s
193
- # TODO: change 1e9 to 30s # pylint: disable=W0511
194
- self.node_ids[node_id] = (time.time() + 1e9, 1e9)
211
+ self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
195
212
  return node_id
196
213
  log(ERROR, "Unexpected node registration failure.")
197
214
  return 0
flwr/server/superlink/state/sqlite_state.py CHANGED
@@ -30,6 +30,7 @@ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
30
30
  from flwr.server.utils.validator import validate_task_ins_or_res
31
31
 
32
32
  from .state import State
33
+ from .utils import make_node_unavailable_taskres
33
34
 
34
35
  SQL_CREATE_TABLE_NODE = """
35
36
  CREATE TABLE IF NOT EXISTS node(
@@ -344,6 +345,7 @@ class SqliteState(State):
344
345
 
345
346
  return task_id
346
347
 
348
+ # pylint: disable-next=R0914
347
349
  def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]:
348
350
  """Get TaskRes for task_ids.
349
351
 
@@ -374,7 +376,7 @@ class SqliteState(State):
374
376
  AND delivered_at = ""
375
377
  """
376
378
 
377
- data: Dict[str, Union[str, int]] = {}
379
+ data: Dict[str, Union[str, float, int]] = {}
378
380
 
379
381
  if limit is not None:
380
382
  query += " LIMIT :limit"
@@ -408,6 +410,54 @@ class SqliteState(State):
408
410
  rows = self.query(query, data)
409
411
 
410
412
  result = [dict_to_task_res(row) for row in rows]
413
+
414
+ # 1. Query: Fetch consumer_node_id of remaining task_ids
415
+ # Assume the ancestry field only contains one element
416
+ data.clear()
417
+ replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
418
+ remaining_task_ids = task_ids - replied_task_ids
419
+ placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
420
+ query = f"""
421
+ SELECT consumer_node_id
422
+ FROM task_ins
423
+ WHERE task_id IN ({placeholders});
424
+ """
425
+ for index, task_id in enumerate(remaining_task_ids):
426
+ data[f"id_{index}"] = str(task_id)
427
+ node_ids = [int(row["consumer_node_id"]) for row in self.query(query, data)]
428
+
429
+ # 2. Query: Select offline nodes
430
+ placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))])
431
+ query = f"""
432
+ SELECT node_id
433
+ FROM node
434
+ WHERE node_id IN ({placeholders})
435
+ AND online_until < :time;
436
+ """
437
+ data = {f"id_{i}": str(node_id) for i, node_id in enumerate(node_ids)}
438
+ data["time"] = time.time()
439
+ offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)]
440
+
441
+ # 3. Query: Select TaskIns for offline nodes
442
+ placeholders = ",".join([f":id_{i}" for i in range(len(offline_node_ids))])
443
+ query = f"""
444
+ SELECT *
445
+ FROM task_ins
446
+ WHERE consumer_node_id IN ({placeholders});
447
+ """
448
+ data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
449
+ task_ins_rows = self.query(query, data)
450
+
451
+ # Make TaskRes containing node unavailabe error
452
+ for row in task_ins_rows:
453
+ if limit and len(result) == limit:
454
+ break
455
+ task_ins = dict_to_task_ins(row)
456
+ err_taskres = make_node_unavailable_taskres(
457
+ ref_taskins=task_ins,
458
+ )
459
+ result.append(err_taskres)
460
+
411
461
  return result
412
462
 
413
463
  def num_task_ins(self) -> int:
@@ -468,7 +518,7 @@ class SqliteState(State):
468
518
 
469
519
  return None
470
520
 
471
- def create_node(self) -> int:
521
+ def create_node(self, ping_interval: float) -> int:
472
522
  """Create, store in state, and return `node_id`."""
473
523
  # Sample a random int64 as node_id
474
524
  node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
@@ -478,9 +528,7 @@ class SqliteState(State):
478
528
  )
479
529
 
480
530
  try:
481
- # Default ping interval is 30s
482
- # TODO: change 1e9 to 30s # pylint: disable=W0511
483
- self.query(query, (node_id, time.time() + 1e9, 1e9))
531
+ self.query(query, (node_id, time.time() + ping_interval, ping_interval))
484
532
  except sqlite3.IntegrityError:
485
533
  log(ERROR, "Unexpected node registration failure.")
486
534
  return 0
flwr/server/superlink/state/state.py CHANGED
@@ -132,7 +132,7 @@ class State(abc.ABC):
132
132
  """Delete all delivered TaskIns/TaskRes pairs."""
133
133
 
134
134
  @abc.abstractmethod
135
- def create_node(self) -> int:
135
+ def create_node(self, ping_interval: float) -> int:
136
136
  """Create, store in state, and return `node_id`."""
137
137
 
138
138
  @abc.abstractmethod