flwr-nightly 1.13.0.dev20241106__py3-none-any.whl → 1.13.0.dev20241117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (58)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +37 -0
  3. flwr/cli/install.py +5 -3
  4. flwr/cli/ls.py +228 -0
  5. flwr/cli/run/run.py +16 -5
  6. flwr/client/app.py +68 -19
  7. flwr/client/clientapp/app.py +51 -35
  8. flwr/client/grpc_rere_client/connection.py +2 -12
  9. flwr/client/nodestate/__init__.py +25 -0
  10. flwr/client/nodestate/in_memory_nodestate.py +38 -0
  11. flwr/client/nodestate/nodestate.py +30 -0
  12. flwr/client/nodestate/nodestate_factory.py +37 -0
  13. flwr/client/rest_client/connection.py +4 -14
  14. flwr/client/supernode/app.py +57 -53
  15. flwr/common/args.py +148 -0
  16. flwr/common/config.py +10 -0
  17. flwr/common/constant.py +21 -7
  18. flwr/common/date.py +18 -0
  19. flwr/common/logger.py +6 -2
  20. flwr/common/object_ref.py +47 -16
  21. flwr/common/serde.py +10 -0
  22. flwr/common/typing.py +32 -11
  23. flwr/proto/exec_pb2.py +23 -17
  24. flwr/proto/exec_pb2.pyi +50 -20
  25. flwr/proto/exec_pb2_grpc.py +34 -0
  26. flwr/proto/exec_pb2_grpc.pyi +13 -0
  27. flwr/proto/run_pb2.py +32 -27
  28. flwr/proto/run_pb2.pyi +44 -1
  29. flwr/proto/simulationio_pb2.py +2 -2
  30. flwr/proto/simulationio_pb2_grpc.py +34 -0
  31. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  32. flwr/server/app.py +83 -87
  33. flwr/server/driver/driver.py +1 -1
  34. flwr/server/driver/grpc_driver.py +6 -20
  35. flwr/server/driver/inmemory_driver.py +1 -3
  36. flwr/server/run_serverapp.py +8 -238
  37. flwr/server/serverapp/app.py +44 -89
  38. flwr/server/strategy/aggregate.py +4 -4
  39. flwr/server/superlink/fleet/rest_rere/rest_api.py +10 -9
  40. flwr/server/superlink/linkstate/in_memory_linkstate.py +76 -62
  41. flwr/server/superlink/linkstate/linkstate.py +24 -9
  42. flwr/server/superlink/linkstate/sqlite_linkstate.py +87 -128
  43. flwr/server/superlink/linkstate/utils.py +191 -32
  44. flwr/server/superlink/simulation/simulationio_servicer.py +22 -1
  45. flwr/simulation/__init__.py +3 -1
  46. flwr/simulation/app.py +245 -352
  47. flwr/simulation/legacy_app.py +402 -0
  48. flwr/simulation/run_simulation.py +8 -19
  49. flwr/simulation/simulationio_connection.py +2 -2
  50. flwr/superexec/deployment.py +13 -7
  51. flwr/superexec/exec_servicer.py +32 -3
  52. flwr/superexec/executor.py +4 -3
  53. flwr/superexec/simulation.py +52 -145
  54. {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/METADATA +10 -7
  55. {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/RECORD +58 -51
  56. {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/entry_points.txt +1 -0
  57. {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/LICENSE +0 -0
  58. {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/WHEEL +0 -0
@@ -40,7 +40,8 @@ from .utils import (
40
40
  generate_rand_int_from_bytes,
41
41
  has_valid_sub_status,
42
42
  is_valid_transition,
43
- make_node_unavailable_taskres,
43
+ verify_found_taskres,
44
+ verify_taskins_ids,
44
45
  )
45
46
 
46
47
 
@@ -49,11 +50,6 @@ class RunRecord: # pylint: disable=R0902
49
50
  """The record of a specific run, including its status and timestamps."""
50
51
 
51
52
  run: Run
52
- status: RunStatus
53
- pending_at: str = ""
54
- starting_at: str = ""
55
- running_at: str = ""
56
- finished_at: str = ""
57
53
  logs: list[tuple[float, str]] = field(default_factory=list)
58
54
  log_lock: threading.Lock = field(default_factory=threading.Lock)
59
55
 
@@ -73,12 +69,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
73
69
  self.federation_options: dict[int, ConfigsRecord] = {}
74
70
  self.task_ins_store: dict[UUID, TaskIns] = {}
75
71
  self.task_res_store: dict[UUID, TaskRes] = {}
72
+ self.task_ins_id_to_task_res_id: dict[UUID, UUID] = {}
76
73
 
77
74
  self.node_public_keys: set[bytes] = set()
78
75
  self.server_public_key: Optional[bytes] = None
79
76
  self.server_private_key: Optional[bytes] = None
80
77
 
81
- self.lock = threading.Lock()
78
+ self.lock = threading.RLock()
82
79
 
83
80
  def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
84
81
  """Store one TaskIns."""
@@ -228,57 +225,50 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
228
225
  task_res.task_id = str(task_id)
229
226
  with self.lock:
230
227
  self.task_res_store[task_id] = task_res
228
+ self.task_ins_id_to_task_res_id[UUID(task_ins_id)] = task_id
231
229
 
232
230
  # Return the new task_id
233
231
  return task_id
234
232
 
235
233
  def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
236
- """Get all TaskRes that have not been delivered yet."""
234
+ """Get TaskRes for the given TaskIns IDs."""
235
+ ret: dict[UUID, TaskRes] = {}
236
+
237
237
  with self.lock:
238
- # Find TaskRes that were not delivered yet
239
- task_res_list: list[TaskRes] = []
240
- replied_task_ids: set[UUID] = set()
241
- for _, task_res in self.task_res_store.items():
242
- reply_to = UUID(task_res.task.ancestry[0])
243
-
244
- # Check if corresponding TaskIns exists and is not expired
245
- task_ins = self.task_ins_store.get(reply_to)
246
- if task_ins is None:
247
- log(WARNING, "TaskIns with task_id %s does not exist.", reply_to)
248
- task_ids.remove(reply_to)
249
- continue
250
-
251
- if task_ins.task.created_at + task_ins.task.ttl <= time.time():
252
- log(WARNING, "TaskIns with task_id %s is expired.", reply_to)
253
- task_ids.remove(reply_to)
254
- continue
255
-
256
- if reply_to in task_ids and task_res.task.delivered_at == "":
257
- task_res_list.append(task_res)
258
- replied_task_ids.add(reply_to)
259
-
260
- # Check if the node is offline
261
- for task_id in task_ids - replied_task_ids:
262
- task_ins = self.task_ins_store.get(task_id)
263
- if task_ins is None:
264
- continue
265
- node_id = task_ins.task.consumer.node_id
266
- online_until, _ = self.node_ids[node_id]
267
- # Generate a TaskRes containing an error reply if the node is offline.
268
- if online_until < time.time():
269
- err_taskres = make_node_unavailable_taskres(
270
- ref_taskins=task_ins,
271
- )
272
- self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
273
- task_res_list.append(err_taskres)
274
-
275
- # Mark all of them as delivered
238
+ current = time.time()
239
+
240
+ # Verify TaskIns IDs
241
+ ret = verify_taskins_ids(
242
+ inquired_taskins_ids=task_ids,
243
+ found_taskins_dict=self.task_ins_store,
244
+ current_time=current,
245
+ )
246
+
247
+ # Find all TaskRes
248
+ task_res_found: list[TaskRes] = []
249
+ for task_id in task_ids:
250
+ # If TaskRes exists and is not delivered, add it to the list
251
+ if task_res_id := self.task_ins_id_to_task_res_id.get(task_id):
252
+ task_res = self.task_res_store[task_res_id]
253
+ if task_res.task.delivered_at == "":
254
+ task_res_found.append(task_res)
255
+ tmp_ret_dict = verify_found_taskres(
256
+ inquired_taskins_ids=task_ids,
257
+ found_taskins_dict=self.task_ins_store,
258
+ found_taskres_list=task_res_found,
259
+ current_time=current,
260
+ )
261
+ ret.update(tmp_ret_dict)
262
+
263
+ # Mark existing TaskRes to be returned as delivered
276
264
  delivered_at = now().isoformat()
277
- for task_res in task_res_list:
265
+ for task_res in task_res_found:
278
266
  task_res.task.delivered_at = delivered_at
279
267
 
280
- # Return TaskRes
281
- return task_res_list
268
+ # Cleanup
269
+ self._force_delete_tasks_by_ids(set(ret.keys()))
270
+
271
+ return list(ret.values())
282
272
 
283
273
  def delete_tasks(self, task_ids: set[UUID]) -> None:
284
274
  """Delete all delivered TaskIns/TaskRes pairs."""
@@ -299,9 +289,25 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
299
289
 
300
290
  for task_id in task_ins_to_be_deleted:
301
291
  del self.task_ins_store[task_id]
292
+ del self.task_ins_id_to_task_res_id[task_id]
302
293
  for task_id in task_res_to_be_deleted:
303
294
  del self.task_res_store[task_id]
304
295
 
296
+ def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
297
+ """Delete tasks based on a set of TaskIns IDs."""
298
+ if not task_ids:
299
+ return
300
+
301
+ with self.lock:
302
+ for task_id in task_ids:
303
+ # Delete TaskIns
304
+ if task_id in self.task_ins_store:
305
+ del self.task_ins_store[task_id]
306
+ # Delete TaskRes
307
+ if task_id in self.task_ins_id_to_task_res_id:
308
+ task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
309
+ del self.task_res_store[task_res_id]
310
+
305
311
  def num_task_ins(self) -> int:
306
312
  """Calculate the number of task_ins in store.
307
313
 
@@ -402,13 +408,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
402
408
  fab_version=fab_version if fab_version else "",
403
409
  fab_hash=fab_hash if fab_hash else "",
404
410
  override_config=override_config,
411
+ pending_at=now().isoformat(),
412
+ starting_at="",
413
+ running_at="",
414
+ finished_at="",
415
+ status=RunStatus(
416
+ status=Status.PENDING,
417
+ sub_status="",
418
+ details="",
419
+ ),
405
420
  ),
406
- status=RunStatus(
407
- status=Status.PENDING,
408
- sub_status="",
409
- details="",
410
- ),
411
- pending_at=now().isoformat(),
412
421
  )
413
422
  self.run_ids[run_id] = run_record
414
423
 
@@ -451,6 +460,11 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
451
460
  """Retrieve all currently stored `node_public_keys` as a set."""
452
461
  return self.node_public_keys
453
462
 
463
+ def get_run_ids(self) -> set[int]:
464
+ """Retrieve all run IDs."""
465
+ with self.lock:
466
+ return set(self.run_ids.keys())
467
+
454
468
  def get_run(self, run_id: int) -> Optional[Run]:
455
469
  """Retrieve information about the run with the specified `run_id`."""
456
470
  with self.lock:
@@ -463,7 +477,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
463
477
  """Retrieve the statuses for the specified runs."""
464
478
  with self.lock:
465
479
  return {
466
- run_id: self.run_ids[run_id].status
480
+ run_id: self.run_ids[run_id].run.status
467
481
  for run_id in set(run_ids)
468
482
  if run_id in self.run_ids
469
483
  }
@@ -477,7 +491,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
477
491
  return False
478
492
 
479
493
  # Check if the status transition is valid
480
- current_status = self.run_ids[run_id].status
494
+ current_status = self.run_ids[run_id].run.status
481
495
  if not is_valid_transition(current_status, new_status):
482
496
  log(
483
497
  ERROR,
@@ -500,12 +514,12 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
500
514
  # Update the status
501
515
  run_record = self.run_ids[run_id]
502
516
  if new_status.status == Status.STARTING:
503
- run_record.starting_at = now().isoformat()
517
+ run_record.run.starting_at = now().isoformat()
504
518
  elif new_status.status == Status.RUNNING:
505
- run_record.running_at = now().isoformat()
519
+ run_record.run.running_at = now().isoformat()
506
520
  elif new_status.status == Status.FINISHED:
507
- run_record.finished_at = now().isoformat()
508
- run_record.status = new_status
521
+ run_record.run.finished_at = now().isoformat()
522
+ run_record.run.status = new_status
509
523
  return True
510
524
 
511
525
  def get_pending_run_id(self) -> Optional[int]:
@@ -515,7 +529,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
515
529
  # Loop through all registered runs
516
530
  for run_id, run_rec in self.run_ids.items():
517
531
  # Break once a pending run is found
518
- if run_rec.status.status == Status.PENDING:
532
+ if run_rec.run.status.status == Status.PENDING:
519
533
  pending_run_id = run_id
520
534
  break
521
535
 
@@ -101,13 +101,27 @@ class LinkState(abc.ABC): # pylint: disable=R0904
101
101
 
102
102
  @abc.abstractmethod
103
103
  def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
104
- """Get TaskRes for task_ids.
104
+ """Get TaskRes for the given TaskIns IDs.
105
105
 
106
- Usually, the ServerAppIo API calls this method to get results for instructions
107
- it has previously scheduled.
106
+ This method is typically called by the ServerAppIo API to obtain
107
+ results (TaskRes) for previously scheduled instructions (TaskIns).
108
+ For each task_id provided, this method returns one of the following responses:
108
109
 
109
- Retrieves all TaskRes for the given `task_ids` and returns and empty list of
110
- none could be found.
110
+ - An error TaskRes if the corresponding TaskIns does not exist or has expired.
111
+ - An error TaskRes if the corresponding TaskRes exists but has expired.
112
+ - The valid TaskRes if the TaskIns has a corresponding valid TaskRes.
113
+ - Nothing if the TaskIns is still valid and waiting for a TaskRes.
114
+
115
+ Parameters
116
+ ----------
117
+ task_ids : set[UUID]
118
+ A set of TaskIns IDs for which to retrieve results (TaskRes).
119
+
120
+ Returns
121
+ -------
122
+ list[TaskRes]
123
+ A list of TaskRes corresponding to the given task IDs. If no
124
+ TaskRes could be found for any of the task IDs, an empty list is returned.
111
125
  """
112
126
 
113
127
  @abc.abstractmethod
@@ -163,6 +177,10 @@ class LinkState(abc.ABC): # pylint: disable=R0904
163
177
  ) -> int:
164
178
  """Create a new run for the specified `fab_hash`."""
165
179
 
180
+ @abc.abstractmethod
181
+ def get_run_ids(self) -> set[int]:
182
+ """Retrieve all run IDs."""
183
+
166
184
  @abc.abstractmethod
167
185
  def get_run(self, run_id: int) -> Optional[Run]:
168
186
  """Retrieve information about the run with the specified `run_id`.
@@ -175,10 +193,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
175
193
  Returns
176
194
  -------
177
195
  Optional[Run]
178
- A dataclass instance containing three elements if `run_id` is valid:
179
- - `run_id`: The identifier of the run, same as the specified `run_id`.
180
- - `fab_id`: The identifier of the FAB used in the specified run.
181
- - `fab_version`: The version of the FAB used in the specified run.
196
+ The `Run` instance if found; otherwise, `None`.
182
197
  """
183
198
 
184
199
  @abc.abstractmethod
@@ -57,7 +57,8 @@ from .utils import (
57
57
  generate_rand_int_from_bytes,
58
58
  has_valid_sub_status,
59
59
  is_valid_transition,
60
- make_node_unavailable_taskres,
60
+ verify_found_taskres,
61
+ verify_taskins_ids,
61
62
  )
62
63
 
63
64
  SQL_CREATE_TABLE_NODE = """
@@ -511,150 +512,67 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
511
512
 
512
513
  # pylint: disable-next=R0912,R0915,R0914
513
514
  def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
514
- """Get TaskRes for task_ids.
515
+ """Get TaskRes for the given TaskIns IDs."""
516
+ ret: dict[UUID, TaskRes] = {}
515
517
 
516
- Usually, the ServerAppIo API calls this method to get results for instructions
517
- it has previously scheduled.
518
-
519
- Retrieves all TaskRes for the given `task_ids` and returns and empty list if
520
- none could be found.
521
-
522
- Constraints
523
- -----------
524
- If `limit` is not `None`, return, at most, `limit` number of TaskRes. The limit
525
- will only take effect if enough task_ids are in the set AND are currently
526
- available. If `limit` is set, it has to be greater than zero.
527
- """
528
- # Check if corresponding TaskIns exists and is not expired
529
- task_ids_placeholders = ",".join([f":id_{i}" for i in range(len(task_ids))])
518
+ # Verify TaskIns IDs
519
+ current = time.time()
530
520
  query = f"""
531
521
  SELECT *
532
522
  FROM task_ins
533
- WHERE task_id IN ({task_ids_placeholders})
534
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
523
+ WHERE task_id IN ({",".join(["?"] * len(task_ids))});
535
524
  """
536
- query += ";"
537
-
538
- task_ins_data = {}
539
- for index, task_id in enumerate(task_ids):
540
- task_ins_data[f"id_{index}"] = str(task_id)
541
-
542
- task_ins_rows = self.query(query, task_ins_data)
543
-
544
- if not task_ins_rows:
545
- return []
546
-
547
- for row in task_ins_rows:
548
- # Convert values from sint64 to uint64
525
+ rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
526
+ found_task_ins_dict: dict[UUID, TaskIns] = {}
527
+ for row in rows:
549
528
  convert_sint64_values_in_dict_to_uint64(
550
529
  row, ["run_id", "producer_node_id", "consumer_node_id"]
551
530
  )
552
- task_ins = dict_to_task_ins(row)
553
- if task_ins.task.created_at + task_ins.task.ttl <= time.time():
554
- log(WARNING, "TaskIns with task_id %s is expired.", task_ins.task_id)
555
- task_ids.remove(UUID(task_ins.task_id))
531
+ found_task_ins_dict[UUID(row["task_id"])] = dict_to_task_ins(row)
556
532
 
557
- # Retrieve all anonymous Tasks
558
- if len(task_ids) == 0:
559
- return []
533
+ ret = verify_taskins_ids(
534
+ inquired_taskins_ids=task_ids,
535
+ found_taskins_dict=found_task_ins_dict,
536
+ current_time=current,
537
+ )
560
538
 
561
- placeholders = ",".join([f":id_{i}" for i in range(len(task_ids))])
539
+ # Find all TaskRes
562
540
  query = f"""
563
541
  SELECT *
564
542
  FROM task_res
565
- WHERE ancestry IN ({placeholders})
566
- AND delivered_at = ""
543
+ WHERE ancestry IN ({",".join(["?"] * len(task_ids))})
544
+ AND delivered_at = "";
567
545
  """
546
+ rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
547
+ for row in rows:
548
+ convert_sint64_values_in_dict_to_uint64(
549
+ row, ["run_id", "producer_node_id", "consumer_node_id"]
550
+ )
551
+ tmp_ret_dict = verify_found_taskres(
552
+ inquired_taskins_ids=task_ids,
553
+ found_taskins_dict=found_task_ins_dict,
554
+ found_taskres_list=[dict_to_task_res(row) for row in rows],
555
+ current_time=current,
556
+ )
557
+ ret.update(tmp_ret_dict)
568
558
 
569
- data: dict[str, Union[str, float, int]] = {}
570
-
571
- query += ";"
572
-
573
- for index, task_id in enumerate(task_ids):
574
- data[f"id_{index}"] = str(task_id)
575
-
576
- rows = self.query(query, data)
577
-
578
- if rows:
579
- # Prepare query
580
- found_task_ids = [row["task_id"] for row in rows]
581
- placeholders = ",".join([f":id_{i}" for i in range(len(found_task_ids))])
582
- query = f"""
583
- UPDATE task_res
584
- SET delivered_at = :delivered_at
585
- WHERE task_id IN ({placeholders})
586
- RETURNING *;
587
- """
588
-
589
- # Prepare data for query
590
- delivered_at = now().isoformat()
591
- data = {"delivered_at": delivered_at}
592
- for index, task_id in enumerate(found_task_ids):
593
- data[f"id_{index}"] = str(task_id)
594
-
595
- # Run query
596
- rows = self.query(query, data)
597
-
598
- for row in rows:
599
- # Convert values from sint64 to uint64
600
- convert_sint64_values_in_dict_to_uint64(
601
- row, ["run_id", "producer_node_id", "consumer_node_id"]
602
- )
603
-
604
- result = [dict_to_task_res(row) for row in rows]
605
-
606
- # 1. Query: Fetch consumer_node_id of remaining task_ids
607
- # Assume the ancestry field only contains one element
608
- data.clear()
609
- replied_task_ids: set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
610
- remaining_task_ids = task_ids - replied_task_ids
611
- placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
612
- query = f"""
613
- SELECT consumer_node_id
614
- FROM task_ins
615
- WHERE task_id IN ({placeholders});
616
- """
617
- for index, task_id in enumerate(remaining_task_ids):
618
- data[f"id_{index}"] = str(task_id)
619
- node_ids = [int(row["consumer_node_id"]) for row in self.query(query, data)]
620
-
621
- # 2. Query: Select offline nodes
622
- placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))])
623
- query = f"""
624
- SELECT node_id
625
- FROM node
626
- WHERE node_id IN ({placeholders})
627
- AND online_until < :time;
628
- """
629
- data = {f"id_{i}": str(node_id) for i, node_id in enumerate(node_ids)}
630
- data["time"] = time.time()
631
- offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)]
632
-
633
- # 3. Query: Select TaskIns for offline nodes
634
- placeholders = ",".join([f":id_{i}" for i in range(len(offline_node_ids))])
559
+ # Mark existing TaskRes to be returned as delivered
560
+ delivered_at = now().isoformat()
561
+ for task_res in ret.values():
562
+ task_res.task.delivered_at = delivered_at
563
+ task_res_ids = [task_res.task_id for task_res in ret.values()]
635
564
  query = f"""
636
- SELECT *
637
- FROM task_ins
638
- WHERE consumer_node_id IN ({placeholders});
565
+ UPDATE task_res
566
+ SET delivered_at = ?
567
+ WHERE task_id IN ({",".join(["?"] * len(task_res_ids))});
639
568
  """
640
- data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
641
- task_ins_rows = self.query(query, data)
642
-
643
- # Make TaskRes containing node unavailabe error
644
- for row in task_ins_rows:
645
- for row in rows:
646
- # Convert values from sint64 to uint64
647
- convert_sint64_values_in_dict_to_uint64(
648
- row, ["run_id", "producer_node_id", "consumer_node_id"]
649
- )
569
+ data: list[Any] = [delivered_at] + task_res_ids
570
+ self.query(query, data)
650
571
 
651
- task_ins = dict_to_task_ins(row)
652
- err_taskres = make_node_unavailable_taskres(
653
- ref_taskins=task_ins,
654
- )
655
- result.append(err_taskres)
572
+ # Cleanup
573
+ self._force_delete_tasks_by_ids(set(ret.keys()))
656
574
 
657
- return result
575
+ return list(ret.values())
658
576
 
659
577
  def num_task_ins(self) -> int:
660
578
  """Calculate the number of task_ins in store.
@@ -714,6 +632,32 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
714
632
 
715
633
  return None
716
634
 
635
+ def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
636
+ """Delete tasks based on a set of TaskIns IDs."""
637
+ if not task_ids:
638
+ return
639
+ if self.conn is None:
640
+ raise AttributeError("LinkState not initialized")
641
+
642
+ placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
643
+ data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
644
+
645
+ # Delete task_ins
646
+ query_1 = f"""
647
+ DELETE FROM task_ins
648
+ WHERE task_id IN ({placeholders});
649
+ """
650
+
651
+ # Delete task_res
652
+ query_2 = f"""
653
+ DELETE FROM task_res
654
+ WHERE ancestry IN ({placeholders});
655
+ """
656
+
657
+ with self.conn:
658
+ self.conn.execute(query_1, data)
659
+ self.conn.execute(query_2, data)
660
+
717
661
  def create_node(
718
662
  self, ping_interval: float, public_key: Optional[bytes] = None
719
663
  ) -> int:
@@ -917,6 +861,12 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
917
861
  result: set[bytes] = {row["public_key"] for row in rows}
918
862
  return result
919
863
 
864
+ def get_run_ids(self) -> set[int]:
865
+ """Retrieve all run IDs."""
866
+ query = "SELECT run_id FROM run;"
867
+ rows = self.query(query)
868
+ return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
869
+
920
870
  def get_run(self, run_id: int) -> Optional[Run]:
921
871
  """Retrieve information about the run with the specified `run_id`."""
922
872
  # Convert the uint64 value to sint64 for SQLite
@@ -931,6 +881,15 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
931
881
  fab_version=row["fab_version"],
932
882
  fab_hash=row["fab_hash"],
933
883
  override_config=json.loads(row["override_config"]),
884
+ pending_at=row["pending_at"],
885
+ starting_at=row["starting_at"],
886
+ running_at=row["running_at"],
887
+ finished_at=row["finished_at"],
888
+ status=RunStatus(
889
+ status=determine_run_status(row),
890
+ sub_status=row["sub_status"],
891
+ details=row["details"],
892
+ ),
934
893
  )
935
894
  log(ERROR, "`run_id` does not exist.")
936
895
  return None
@@ -1264,10 +1223,10 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
1264
1223
  def determine_run_status(row: dict[str, Any]) -> str:
1265
1224
  """Determine the status of the run based on timestamp fields."""
1266
1225
  if row["pending_at"]:
1226
+ if row["finished_at"]:
1227
+ return Status.FINISHED
1267
1228
  if row["starting_at"]:
1268
1229
  if row["running_at"]:
1269
- if row["finished_at"]:
1270
- return Status.FINISHED
1271
1230
  return Status.RUNNING
1272
1231
  return Status.STARTING
1273
1232
  return Status.PENDING