flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. flwr/app/__init__.py +15 -0
  2. flwr/app/error.py +68 -0
  3. flwr/app/metadata.py +223 -0
  4. flwr/cli/build.py +94 -59
  5. flwr/cli/log.py +3 -3
  6. flwr/cli/login/login.py +3 -7
  7. flwr/cli/ls.py +15 -36
  8. flwr/cli/new/new.py +12 -4
  9. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  10. flwr/cli/new/templates/app/README.md.tpl +5 -0
  11. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  12. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  13. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  16. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  17. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  18. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  19. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  20. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  21. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  22. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  23. flwr/cli/run/run.py +48 -49
  24. flwr/cli/stop.py +2 -2
  25. flwr/cli/utils.py +38 -5
  26. flwr/client/__init__.py +2 -2
  27. flwr/client/client_app.py +1 -1
  28. flwr/client/clientapp/__init__.py +0 -7
  29. flwr/client/grpc_adapter_client/connection.py +15 -8
  30. flwr/client/grpc_rere_client/connection.py +142 -97
  31. flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
  32. flwr/client/message_handler/message_handler.py +1 -1
  33. flwr/client/mod/comms_mods.py +36 -17
  34. flwr/client/rest_client/connection.py +176 -103
  35. flwr/clientapp/__init__.py +15 -0
  36. flwr/common/__init__.py +2 -2
  37. flwr/common/auth_plugin/__init__.py +2 -0
  38. flwr/common/auth_plugin/auth_plugin.py +29 -3
  39. flwr/common/constant.py +39 -8
  40. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  41. flwr/common/exit/exit_code.py +16 -1
  42. flwr/common/exit_handlers.py +30 -0
  43. flwr/common/grpc.py +12 -1
  44. flwr/common/heartbeat.py +165 -0
  45. flwr/common/inflatable.py +290 -0
  46. flwr/common/inflatable_protobuf_utils.py +141 -0
  47. flwr/common/inflatable_utils.py +508 -0
  48. flwr/common/message.py +110 -242
  49. flwr/common/record/__init__.py +2 -1
  50. flwr/common/record/array.py +402 -0
  51. flwr/common/record/arraychunk.py +59 -0
  52. flwr/common/record/arrayrecord.py +103 -225
  53. flwr/common/record/configrecord.py +59 -4
  54. flwr/common/record/conversion_utils.py +1 -1
  55. flwr/common/record/metricrecord.py +55 -4
  56. flwr/common/record/recorddict.py +69 -1
  57. flwr/common/recorddict_compat.py +2 -2
  58. flwr/common/retry_invoker.py +5 -1
  59. flwr/common/serde.py +59 -211
  60. flwr/common/serde_utils.py +175 -0
  61. flwr/common/typing.py +5 -3
  62. flwr/compat/__init__.py +15 -0
  63. flwr/compat/client/__init__.py +15 -0
  64. flwr/{client → compat/client}/app.py +28 -185
  65. flwr/compat/common/__init__.py +15 -0
  66. flwr/compat/server/__init__.py +15 -0
  67. flwr/compat/server/app.py +174 -0
  68. flwr/compat/simulation/__init__.py +15 -0
  69. flwr/proto/appio_pb2.py +43 -0
  70. flwr/proto/appio_pb2.pyi +151 -0
  71. flwr/proto/appio_pb2_grpc.py +4 -0
  72. flwr/proto/appio_pb2_grpc.pyi +4 -0
  73. flwr/proto/clientappio_pb2.py +12 -19
  74. flwr/proto/clientappio_pb2.pyi +23 -101
  75. flwr/proto/clientappio_pb2_grpc.py +269 -28
  76. flwr/proto/clientappio_pb2_grpc.pyi +114 -20
  77. flwr/proto/fleet_pb2.py +24 -27
  78. flwr/proto/fleet_pb2.pyi +19 -35
  79. flwr/proto/fleet_pb2_grpc.py +117 -13
  80. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  81. flwr/proto/heartbeat_pb2.py +33 -0
  82. flwr/proto/heartbeat_pb2.pyi +66 -0
  83. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  84. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  85. flwr/proto/message_pb2.py +28 -11
  86. flwr/proto/message_pb2.pyi +125 -0
  87. flwr/proto/recorddict_pb2.py +16 -28
  88. flwr/proto/recorddict_pb2.pyi +46 -64
  89. flwr/proto/run_pb2.py +24 -32
  90. flwr/proto/run_pb2.pyi +4 -52
  91. flwr/proto/serverappio_pb2.py +9 -23
  92. flwr/proto/serverappio_pb2.pyi +0 -110
  93. flwr/proto/serverappio_pb2_grpc.py +177 -72
  94. flwr/proto/serverappio_pb2_grpc.pyi +75 -33
  95. flwr/proto/simulationio_pb2.py +12 -11
  96. flwr/proto/simulationio_pb2_grpc.py +35 -0
  97. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  98. flwr/server/__init__.py +1 -1
  99. flwr/server/app.py +69 -187
  100. flwr/server/compat/app_utils.py +50 -28
  101. flwr/server/fleet_event_log_interceptor.py +6 -2
  102. flwr/server/grid/grpc_grid.py +148 -41
  103. flwr/server/grid/inmemory_grid.py +5 -4
  104. flwr/server/serverapp/app.py +45 -17
  105. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
  106. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
  107. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  108. flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
  109. flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
  110. flwr/server/superlink/fleet/vce/vce_api.py +6 -3
  111. flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
  112. flwr/server/superlink/linkstate/linkstate.py +53 -20
  113. flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
  114. flwr/server/superlink/linkstate/utils.py +33 -29
  115. flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
  116. flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
  117. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  118. flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
  119. flwr/server/superlink/utils.py +9 -2
  120. flwr/server/utils/validator.py +2 -2
  121. flwr/serverapp/__init__.py +15 -0
  122. flwr/simulation/app.py +25 -0
  123. flwr/simulation/run_simulation.py +17 -0
  124. flwr/supercore/__init__.py +15 -0
  125. flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
  126. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  127. flwr/supercore/grpc_health/__init__.py +22 -0
  128. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  129. flwr/supercore/license_plugin/__init__.py +22 -0
  130. flwr/supercore/license_plugin/license_plugin.py +26 -0
  131. flwr/supercore/object_store/__init__.py +24 -0
  132. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  133. flwr/supercore/object_store/object_store.py +170 -0
  134. flwr/supercore/object_store/object_store_factory.py +44 -0
  135. flwr/supercore/object_store/utils.py +43 -0
  136. flwr/supercore/scheduler/__init__.py +22 -0
  137. flwr/supercore/scheduler/plugin.py +71 -0
  138. flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
  139. flwr/superexec/deployment.py +7 -4
  140. flwr/superexec/exec_event_log_interceptor.py +8 -4
  141. flwr/superexec/exec_grpc.py +25 -5
  142. flwr/superexec/exec_license_interceptor.py +82 -0
  143. flwr/superexec/exec_servicer.py +135 -24
  144. flwr/superexec/exec_user_auth_interceptor.py +45 -8
  145. flwr/superexec/executor.py +5 -1
  146. flwr/superexec/simulation.py +8 -3
  147. flwr/superlink/__init__.py +15 -0
  148. flwr/{client/supernode → supernode}/__init__.py +0 -7
  149. flwr/supernode/cli/__init__.py +24 -0
  150. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
  151. flwr/supernode/cli/flwr_clientapp.py +88 -0
  152. flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
  153. flwr/supernode/nodestate/nodestate.py +227 -0
  154. flwr/supernode/runtime/__init__.py +15 -0
  155. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
  156. flwr/supernode/scheduler/__init__.py +22 -0
  157. flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
  158. flwr/supernode/servicer/__init__.py +15 -0
  159. flwr/supernode/servicer/clientappio/__init__.py +22 -0
  160. flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
  161. flwr/supernode/start_client_internal.py +589 -0
  162. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
  163. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
  164. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
  165. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
  166. flwr/client/clientapp/clientappio_servicer.py +0 -244
  167. flwr/client/heartbeat.py +0 -74
  168. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  169. /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
  170. /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
  171. /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  172. /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  173. /flwr/{client → supernode}/nodestate/__init__.py +0 -0
  174. /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
@@ -18,19 +18,22 @@
18
18
  import threading
19
19
  import time
20
20
  from bisect import bisect_right
21
+ from collections import defaultdict
21
22
  from dataclasses import dataclass, field
22
23
  from logging import ERROR, WARNING
23
24
  from typing import Optional
24
- from uuid import UUID, uuid4
25
25
 
26
26
  from flwr.common import Context, Message, log, now
27
27
  from flwr.common.constant import (
28
+ HEARTBEAT_MAX_INTERVAL,
29
+ HEARTBEAT_PATIENCE,
28
30
  MESSAGE_TTL_TOLERANCE,
29
31
  NODE_ID_NUM_BYTES,
30
- PING_PATIENCE,
32
+ RUN_FAILURE_DETAILS_NO_HEARTBEAT,
31
33
  RUN_ID_NUM_BYTES,
32
34
  SUPERLINK_NODE_ID,
33
35
  Status,
36
+ SubStatus,
34
37
  )
35
38
  from flwr.common.record import ConfigRecord
36
39
  from flwr.common.typing import Run, RunStatus, UserConfig
@@ -52,8 +55,11 @@ class RunRecord: # pylint: disable=R0902
52
55
  """The record of a specific run, including its status and timestamps."""
53
56
 
54
57
  run: Run
58
+ active_until: float = 0.0
59
+ heartbeat_interval: float = 0.0
55
60
  logs: list[tuple[float, str]] = field(default_factory=list)
56
61
  log_lock: threading.Lock = field(default_factory=threading.Lock)
62
+ lock: threading.RLock = field(default_factory=threading.RLock)
57
63
 
58
64
 
59
65
  class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
@@ -61,7 +67,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
61
67
 
62
68
  def __init__(self) -> None:
63
69
 
64
- # Map node_id to (online_until, ping_interval)
70
+ # Map node_id to (online_until, heartbeat_interval)
65
71
  self.node_ids: dict[int, tuple[float, float]] = {}
66
72
  self.public_key_to_node_id: dict[bytes, int] = {}
67
73
  self.node_id_to_public_key: dict[int, bytes] = {}
@@ -70,15 +76,18 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
70
76
  self.run_ids: dict[int, RunRecord] = {}
71
77
  self.contexts: dict[int, Context] = {}
72
78
  self.federation_options: dict[int, ConfigRecord] = {}
73
- self.message_ins_store: dict[UUID, Message] = {}
74
- self.message_res_store: dict[UUID, Message] = {}
75
- self.message_ins_id_to_message_res_id: dict[UUID, UUID] = {}
79
+ self.message_ins_store: dict[str, Message] = {}
80
+ self.message_res_store: dict[str, Message] = {}
81
+ self.message_ins_id_to_message_res_id: dict[str, str] = {}
82
+
83
+ # Map flwr_aid to run_ids for O(1) reverse index lookup
84
+ self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
76
85
 
77
86
  self.node_public_keys: set[bytes] = set()
78
87
 
79
88
  self.lock = threading.RLock()
80
89
 
81
- def store_message_ins(self, message: Message) -> Optional[UUID]:
90
+ def store_message_ins(self, message: Message) -> Optional[str]:
82
91
  """Store one Message."""
83
92
  # Validate message
84
93
  errors = validate_message(message, is_reply_message=False)
@@ -106,12 +115,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
106
115
  )
107
116
  return None
108
117
 
109
- # Create message_id
110
- message_id = uuid4()
111
-
112
- # Store Message
113
- # pylint: disable-next=W0212
114
- message.metadata._message_id = str(message_id) # type: ignore
118
+ message_id = message.metadata.message_id
115
119
  with self.lock:
116
120
  self.message_ins_store[message_id] = message
117
121
 
@@ -147,7 +151,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
147
151
  return message_ins_list
148
152
 
149
153
  # pylint: disable=R0911
150
- def store_message_res(self, message: Message) -> Optional[UUID]:
154
+ def store_message_res(self, message: Message) -> Optional[str]:
151
155
  """Store one Message."""
152
156
  # Validate message
153
157
  errors = validate_message(message, is_reply_message=True)
@@ -159,7 +163,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
159
163
  with self.lock:
160
164
  # Check if the Message it is replying to exists and is valid
161
165
  msg_ins_id = res_metadata.reply_to_message_id
162
- msg_ins = self.message_ins_store.get(UUID(msg_ins_id))
166
+ msg_ins = self.message_ins_store.get(msg_ins_id)
163
167
 
164
168
  # Ensure that dst_node_id of original Message matches the src_node_id of
165
169
  # reply Message.
@@ -214,22 +218,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
214
218
  log(ERROR, "`metadata.run_id` is invalid")
215
219
  return None
216
220
 
217
- # Create message_id
218
- message_id = uuid4()
219
-
220
- # Store Message
221
- # pylint: disable-next=W0212
222
- message.metadata._message_id = str(message_id) # type: ignore
221
+ message_id = message.metadata.message_id
223
222
  with self.lock:
224
223
  self.message_res_store[message_id] = message
225
- self.message_ins_id_to_message_res_id[UUID(msg_ins_id)] = message_id
224
+ self.message_ins_id_to_message_res_id[msg_ins_id] = message_id
226
225
 
227
226
  # Return the new message_id
228
227
  return message_id
229
228
 
230
- def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
229
+ def get_message_res(self, message_ids: set[str]) -> list[Message]:
231
230
  """Get reply Messages for the given Message IDs."""
232
- ret: dict[UUID, Message] = {}
231
+ ret: dict[str, Message] = {}
233
232
 
234
233
  with self.lock:
235
234
  current = time.time()
@@ -250,7 +249,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
250
249
  inquired_in_message_ids=message_ids,
251
250
  found_in_message_dict=self.message_ins_store,
252
251
  node_id_to_online_until={
253
- node_id: self.node_ids[node_id][0] for node_id in dst_node_ids
252
+ node_id: self.node_ids[node_id][0]
253
+ for node_id in dst_node_ids
254
+ if node_id in self.node_ids
254
255
  },
255
256
  current_time=current,
256
257
  )
@@ -281,7 +282,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
281
282
 
282
283
  return list(ret.values())
283
284
 
284
- def delete_messages(self, message_ins_ids: set[UUID]) -> None:
285
+ def delete_messages(self, message_ins_ids: set[str]) -> None:
285
286
  """Delete a Message and its reply based on provided Message IDs."""
286
287
  if not message_ins_ids:
287
288
  return
@@ -298,9 +299,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
298
299
  )
299
300
  del self.message_res_store[message_res_id]
300
301
 
301
- def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
302
+ def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
302
303
  """Get all instruction Message IDs for the given run_id."""
303
- message_id_list: set[UUID] = set()
304
+ message_id_list: set[str] = set()
304
305
  with self.lock:
305
306
  for message_id, message in self.message_ins_store.items():
306
307
  if message.metadata.run_id == run_id:
@@ -322,7 +323,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
322
323
  """
323
324
  return len(self.message_res_store)
324
325
 
325
- def create_node(self, ping_interval: float) -> int:
326
+ def create_node(self, heartbeat_interval: float) -> int:
326
327
  """Create, store in the link state, and return `node_id`."""
327
328
  # Sample a random int64 as node_id
328
329
  node_id = generate_rand_int_from_bytes(
@@ -334,8 +335,11 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
334
335
  log(ERROR, "Unexpected node registration failure.")
335
336
  return 0
336
337
 
337
- # Mark the node online util time.time() + ping_interval
338
- self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
338
+ # Mark the node online until time.time() + heartbeat_interval
339
+ self.node_ids[node_id] = (
340
+ time.time() + heartbeat_interval,
341
+ heartbeat_interval,
342
+ )
339
343
  return node_id
340
344
 
341
345
  def delete_node(self, node_id: int) -> None:
@@ -400,6 +404,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
400
404
  fab_hash: Optional[str],
401
405
  override_config: UserConfig,
402
406
  federation_options: ConfigRecord,
407
+ flwr_aid: Optional[str],
403
408
  ) -> int:
404
409
  """Create a new run for the specified `fab_hash`."""
405
410
  # Sample a random int64 as run_id
@@ -423,9 +428,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
423
428
  sub_status="",
424
429
  details="",
425
430
  ),
431
+ flwr_aid=flwr_aid if flwr_aid else "",
426
432
  ),
427
433
  )
428
434
  self.run_ids[run_id] = run_record
435
+ # Add run_id to the flwr_aid_to_run_ids mapping if flwr_aid is provided
436
+ if flwr_aid:
437
+ self.flwr_aid_to_run_ids[flwr_aid].add(run_id)
429
438
 
430
439
  # Record federation options. Leave empty if not passed
431
440
  self.federation_options[run_id] = federation_options
@@ -453,13 +462,42 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
453
462
  with self.lock:
454
463
  return self.node_public_keys.copy()
455
464
 
456
- def get_run_ids(self) -> set[int]:
457
- """Retrieve all run IDs."""
465
+ def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
466
+ """Retrieve all run IDs if `flwr_aid` is not specified.
467
+
468
+ Otherwise, retrieve all run IDs for the specified `flwr_aid`.
469
+ """
458
470
  with self.lock:
471
+ if flwr_aid is not None:
472
+ # Return run IDs for the specified flwr_aid
473
+ return set(self.flwr_aid_to_run_ids.get(flwr_aid, ()))
459
474
  return set(self.run_ids.keys())
460
475
 
476
+ def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
477
+ """Check if any runs are no longer active.
478
+
479
+ Marks runs with status 'starting' or 'running' as failed
480
+ if they have not sent a heartbeat before `active_until`.
481
+ """
482
+ current = now()
483
+ for record in (self.run_ids.get(run_id) for run_id in run_ids):
484
+ if record is None:
485
+ continue
486
+ with record.lock:
487
+ if record.run.status.status in (Status.STARTING, Status.RUNNING):
488
+ if record.active_until < current.timestamp():
489
+ record.run.status = RunStatus(
490
+ status=Status.FINISHED,
491
+ sub_status=SubStatus.FAILED,
492
+ details=RUN_FAILURE_DETAILS_NO_HEARTBEAT,
493
+ )
494
+ record.run.finished_at = now().isoformat()
495
+
461
496
  def get_run(self, run_id: int) -> Optional[Run]:
462
497
  """Retrieve information about the run with the specified `run_id`."""
498
+ # Check if runs are still active
499
+ self._check_and_tag_inactive_run(run_ids={run_id})
500
+
463
501
  with self.lock:
464
502
  if run_id not in self.run_ids:
465
503
  log(ERROR, "`run_id` is invalid")
@@ -468,6 +506,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
468
506
 
469
507
  def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
470
508
  """Retrieve the statuses for the specified runs."""
509
+ # Check if runs are still active
510
+ self._check_and_tag_inactive_run(run_ids=run_ids)
511
+
471
512
  with self.lock:
472
513
  return {
473
514
  run_id: self.run_ids[run_id].run.status
@@ -477,12 +518,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
477
518
 
478
519
  def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
479
520
  """Update the status of the run with the specified `run_id`."""
521
+ # Check if runs are still active
522
+ self._check_and_tag_inactive_run(run_ids={run_id})
523
+
480
524
  with self.lock:
481
525
  # Check if the run_id exists
482
526
  if run_id not in self.run_ids:
483
527
  log(ERROR, "`run_id` is invalid")
484
528
  return False
485
529
 
530
+ with self.run_ids[run_id].lock:
486
531
  # Check if the status transition is valid
487
532
  current_status = self.run_ids[run_id].run.status
488
533
  if not is_valid_transition(current_status, new_status):
@@ -504,14 +549,23 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
504
549
  )
505
550
  return False
506
551
 
507
- # Update the status
552
+ # Initialize heartbeat_interval and active_until
553
+ # when switching to starting or running
554
+ current = now()
508
555
  run_record = self.run_ids[run_id]
556
+ if new_status.status in (Status.STARTING, Status.RUNNING):
557
+ run_record.heartbeat_interval = HEARTBEAT_MAX_INTERVAL
558
+ run_record.active_until = (
559
+ current.timestamp() + run_record.heartbeat_interval
560
+ )
561
+
562
+ # Update the run status
509
563
  if new_status.status == Status.STARTING:
510
- run_record.run.starting_at = now().isoformat()
564
+ run_record.run.starting_at = current.isoformat()
511
565
  elif new_status.status == Status.RUNNING:
512
- run_record.run.running_at = now().isoformat()
566
+ run_record.run.running_at = current.isoformat()
513
567
  elif new_status.status == Status.FINISHED:
514
- run_record.run.finished_at = now().isoformat()
568
+ run_record.run.finished_at = current.isoformat()
515
569
  run_record.run.status = new_status
516
570
  return True
517
571
 
@@ -536,21 +590,62 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
536
590
  return None
537
591
  return self.federation_options[run_id]
538
592
 
539
- def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
540
- """Acknowledge a ping received from a node, serving as a heartbeat.
593
+ def acknowledge_node_heartbeat(
594
+ self, node_id: int, heartbeat_interval: float
595
+ ) -> bool:
596
+ """Acknowledge a heartbeat received from a node, serving as a heartbeat.
541
597
 
542
- It allows for one missed ping (in a PING_PATIENCE * ping_interval) before
543
- marking the node as offline, where PING_PATIENCE = 2 in default.
598
+ A node is considered online as long as it sends heartbeats within
599
+ the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
600
+ HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
601
+ the node is marked as offline.
544
602
  """
545
603
  with self.lock:
546
604
  if node_id in self.node_ids:
547
605
  self.node_ids[node_id] = (
548
- time.time() + PING_PATIENCE * ping_interval,
549
- ping_interval,
606
+ time.time() + HEARTBEAT_PATIENCE * heartbeat_interval,
607
+ heartbeat_interval,
550
608
  )
551
609
  return True
552
610
  return False
553
611
 
612
+ def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
613
+ """Acknowledge a heartbeat received from a ServerApp for a given run.
614
+
615
+ A run with status `"running"` is considered alive as long as it sends heartbeats
616
+ within the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
617
+ HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before the run is
618
+ marked as `"completed:failed"`.
619
+ """
620
+ with self.lock:
621
+ # Search for the run
622
+ record = self.run_ids.get(run_id)
623
+
624
+ # Check if the run_id exists
625
+ if record is None:
626
+ log(ERROR, "`run_id` is invalid")
627
+ return False
628
+
629
+ with record.lock:
630
+ # Check if runs are still active
631
+ self._check_and_tag_inactive_run(run_ids={run_id})
632
+
633
+ # Check if the run is of status "running"/"starting"
634
+ current_status = record.run.status
635
+ if current_status.status not in (Status.RUNNING, Status.STARTING):
636
+ log(
637
+ ERROR,
638
+ 'Cannot acknowledge heartbeat for run with status "%s"',
639
+ current_status.status,
640
+ )
641
+ return False
642
+
643
+ # Update the `active_until` and `heartbeat_interval` for the given run
644
+ current = now().timestamp()
645
+ record.active_until = current + HEARTBEAT_PATIENCE * heartbeat_interval
646
+ record.heartbeat_interval = heartbeat_interval
647
+ return True
648
+
554
649
  def get_serverapp_context(self, run_id: int) -> Optional[Context]:
555
650
  """Get the context for the specified `run_id`."""
556
651
  return self.contexts.get(run_id)
@@ -17,7 +17,6 @@
17
17
 
18
18
  import abc
19
19
  from typing import Optional
20
- from uuid import UUID
21
20
 
22
21
  from flwr.common import Context, Message
23
22
  from flwr.common.record import ConfigRecord
@@ -28,13 +27,13 @@ class LinkState(abc.ABC): # pylint: disable=R0904
28
27
  """Abstract LinkState."""
29
28
 
30
29
  @abc.abstractmethod
31
- def store_message_ins(self, message: Message) -> Optional[UUID]:
30
+ def store_message_ins(self, message: Message) -> Optional[str]:
32
31
  """Store one Message.
33
32
 
34
33
  Usually, the ServerAppIo API calls this to schedule instructions.
35
34
 
36
35
  Stores the value of the `message` in the link state and, if successful,
37
- returns the `message_id` (UUID) of the `message`. If, for any reason,
36
+ returns the `message_id` (str) of the `message`. If, for any reason,
38
37
  storing the `message` fails, `None` is returned.
39
38
 
40
39
  Constraints
@@ -61,12 +60,12 @@ class LinkState(abc.ABC): # pylint: disable=R0904
61
60
  """
62
61
 
63
62
  @abc.abstractmethod
64
- def store_message_res(self, message: Message) -> Optional[UUID]:
63
+ def store_message_res(self, message: Message) -> Optional[str]:
65
64
  """Store one Message.
66
65
 
67
66
  Usually, the Fleet API calls this for Nodes returning results.
68
67
 
69
- Stores the Message and, if successful, returns the `message_id` (UUID) of
68
+ Stores the Message and, if successful, returns the `message_id` (str) of
70
69
  the `message`. If storing the `message` fails, `None` is returned.
71
70
 
72
71
  Constraints
@@ -78,7 +77,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
78
77
  """
79
78
 
80
79
  @abc.abstractmethod
81
- def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
80
+ def get_message_res(self, message_ids: set[str]) -> list[Message]:
82
81
  """Get reply Messages for the given Message IDs.
83
82
 
84
83
  This method is typically called by the ServerAppIo API to obtain
@@ -94,7 +93,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
94
93
 
95
94
  Parameters
96
95
  ----------
97
- message_ids : set[UUID]
96
+ message_ids : set[str]
98
97
  A set of Message IDs used to retrieve reply Messages responding to them.
99
98
 
100
99
  Returns
@@ -113,22 +112,22 @@ class LinkState(abc.ABC): # pylint: disable=R0904
113
112
  """Calculate the number of reply Messages in store."""
114
113
 
115
114
  @abc.abstractmethod
116
- def delete_messages(self, message_ins_ids: set[UUID]) -> None:
115
+ def delete_messages(self, message_ins_ids: set[str]) -> None:
117
116
  """Delete a Message and its reply based on provided Message IDs.
118
117
 
119
118
  Parameters
120
119
  ----------
121
- message_ins_ids : set[UUID]
120
+ message_ins_ids : set[str]
122
121
  A set of Message IDs. For each ID in the set, the corresponding
123
122
  Message and its associated reply Message will be deleted.
124
123
  """
125
124
 
126
125
  @abc.abstractmethod
127
- def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
126
+ def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
128
127
  """Get all instruction Message IDs for the given run_id."""
129
128
 
130
129
  @abc.abstractmethod
131
- def create_node(self, ping_interval: float) -> int:
130
+ def create_node(self, heartbeat_interval: float) -> int:
132
131
  """Create, store in the link state, and return `node_id`."""
133
132
 
134
133
  @abc.abstractmethod
@@ -165,12 +164,16 @@ class LinkState(abc.ABC): # pylint: disable=R0904
165
164
  fab_hash: Optional[str],
166
165
  override_config: UserConfig,
167
166
  federation_options: ConfigRecord,
167
+ flwr_aid: Optional[str],
168
168
  ) -> int:
169
169
  """Create a new run for the specified `fab_hash`."""
170
170
 
171
171
  @abc.abstractmethod
172
- def get_run_ids(self) -> set[int]:
173
- """Retrieve all run IDs."""
172
+ def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
173
+ """Retrieve all run IDs if `flwr_aid` is not specified.
174
+
175
+ Otherwise, retrieve all run IDs for the specified `flwr_aid`.
176
+ """
174
177
 
175
178
  @abc.abstractmethod
176
179
  def get_run(self, run_id: int) -> Optional[Run]:
@@ -267,22 +270,52 @@ class LinkState(abc.ABC): # pylint: disable=R0904
267
270
  """Retrieve all currently stored `node_public_keys` as a set."""
268
271
 
269
272
  @abc.abstractmethod
270
- def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
271
- """Acknowledge a ping received from a node, serving as a heartbeat.
273
+ def acknowledge_node_heartbeat(
274
+ self, node_id: int, heartbeat_interval: float
275
+ ) -> bool:
276
+ """Acknowledge a heartbeat received from a node.
277
+
278
+ A node is considered online as long as it sends heartbeats within
279
+ the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
280
+ HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
281
+ the node is marked as offline.
272
282
 
273
283
  Parameters
274
284
  ----------
275
285
  node_id : int
276
- The `node_id` from which the ping was received.
277
- ping_interval : float
286
+ The `node_id` from which the heartbeat was received.
287
+ heartbeat_interval : float
288
+ The interval (in seconds) from the current timestamp within which the next
289
+ heartbeat from this node must be received. This acts as a hard deadline to
290
+ ensure an accurate assessment of the node's availability.
291
+
292
+ Returns
293
+ -------
294
+ is_acknowledged : bool
295
+ True if the heartbeat is successfully acknowledged; otherwise, False.
296
+ """
297
+
298
+ @abc.abstractmethod
299
+ def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
300
+ """Acknowledge a heartbeat received from a ServerApp for a given run.
301
+
302
+ A run with status `"running"` is considered alive as long as it sends heartbeats
303
+ within the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
304
+ HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before the run is
305
+ marked as `"completed:failed"`.
306
+
307
+ Parameters
308
+ ----------
309
+ run_id : int
310
+ The `run_id` from which the heartbeat was received.
311
+ heartbeat_interval : float
278
312
  The interval (in seconds) from the current timestamp within which the next
279
- ping from this node must be received. This acts as a hard deadline to ensure
280
- an accurate assessment of the node's availability.
313
+ heartbeat from the ServerApp for this run must be received.
281
314
 
282
315
  Returns
283
316
  -------
284
317
  is_acknowledged : bool
285
- True if the ping is successfully acknowledged; otherwise, False.
318
+ True if the heartbeat is successfully acknowledged; otherwise, False.
286
319
  """
287
320
 
288
321
  @abc.abstractmethod