flwr-nightly 1.13.0.dev20241111__py3-none-any.whl → 1.14.0.dev20241126__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/install.py +0 -16
- flwr/cli/ls.py +228 -0
- flwr/cli/new/new.py +23 -13
- flwr/cli/new/templates/app/README.md.tpl +11 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +4 -2
- flwr/client/app.py +50 -14
- flwr/client/clientapp/app.py +40 -23
- flwr/client/grpc_rere_client/connection.py +7 -12
- flwr/client/rest_client/connection.py +4 -14
- flwr/client/supernode/app.py +31 -53
- flwr/common/args.py +85 -16
- flwr/common/constant.py +24 -6
- flwr/common/date.py +18 -0
- flwr/common/grpc.py +4 -1
- flwr/common/serde.py +10 -0
- flwr/common/typing.py +31 -10
- flwr/proto/exec_pb2.py +22 -13
- flwr/proto/exec_pb2.pyi +44 -0
- flwr/proto/exec_pb2_grpc.py +34 -0
- flwr/proto/exec_pb2_grpc.pyi +13 -0
- flwr/proto/run_pb2.py +30 -30
- flwr/proto/run_pb2.pyi +18 -1
- flwr/server/app.py +47 -77
- flwr/server/driver/grpc_driver.py +66 -16
- flwr/server/run_serverapp.py +8 -238
- flwr/server/serverapp/app.py +49 -29
- flwr/server/superlink/fleet/rest_rere/rest_api.py +10 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +71 -46
- flwr/server/superlink/linkstate/linkstate.py +19 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -113
- flwr/server/superlink/linkstate/utils.py +193 -3
- flwr/simulation/app.py +52 -91
- flwr/simulation/legacy_app.py +21 -1
- flwr/simulation/run_simulation.py +7 -18
- flwr/simulation/simulationio_connection.py +2 -2
- flwr/superexec/deployment.py +12 -6
- flwr/superexec/exec_servicer.py +31 -2
- flwr/superexec/simulation.py +11 -46
- {flwr_nightly-1.13.0.dev20241111.dist-info → flwr_nightly-1.14.0.dev20241126.dist-info}/METADATA +5 -4
- {flwr_nightly-1.13.0.dev20241111.dist-info → flwr_nightly-1.14.0.dev20241126.dist-info}/RECORD +53 -52
- {flwr_nightly-1.13.0.dev20241111.dist-info → flwr_nightly-1.14.0.dev20241126.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241111.dist-info → flwr_nightly-1.14.0.dev20241126.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.13.0.dev20241111.dist-info → flwr_nightly-1.14.0.dev20241126.dist-info}/entry_points.txt +0 -0
|
@@ -40,6 +40,8 @@ from .utils import (
|
|
|
40
40
|
generate_rand_int_from_bytes,
|
|
41
41
|
has_valid_sub_status,
|
|
42
42
|
is_valid_transition,
|
|
43
|
+
verify_found_taskres,
|
|
44
|
+
verify_taskins_ids,
|
|
43
45
|
)
|
|
44
46
|
|
|
45
47
|
|
|
@@ -48,11 +50,6 @@ class RunRecord: # pylint: disable=R0902
|
|
|
48
50
|
"""The record of a specific run, including its status and timestamps."""
|
|
49
51
|
|
|
50
52
|
run: Run
|
|
51
|
-
status: RunStatus
|
|
52
|
-
pending_at: str = ""
|
|
53
|
-
starting_at: str = ""
|
|
54
|
-
running_at: str = ""
|
|
55
|
-
finished_at: str = ""
|
|
56
53
|
logs: list[tuple[float, str]] = field(default_factory=list)
|
|
57
54
|
log_lock: threading.Lock = field(default_factory=threading.Lock)
|
|
58
55
|
|
|
@@ -72,12 +69,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
72
69
|
self.federation_options: dict[int, ConfigsRecord] = {}
|
|
73
70
|
self.task_ins_store: dict[UUID, TaskIns] = {}
|
|
74
71
|
self.task_res_store: dict[UUID, TaskRes] = {}
|
|
72
|
+
self.task_ins_id_to_task_res_id: dict[UUID, UUID] = {}
|
|
75
73
|
|
|
76
74
|
self.node_public_keys: set[bytes] = set()
|
|
77
75
|
self.server_public_key: Optional[bytes] = None
|
|
78
76
|
self.server_private_key: Optional[bytes] = None
|
|
79
77
|
|
|
80
|
-
self.lock = threading.
|
|
78
|
+
self.lock = threading.RLock()
|
|
81
79
|
|
|
82
80
|
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
|
83
81
|
"""Store one TaskIns."""
|
|
@@ -227,42 +225,50 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
227
225
|
task_res.task_id = str(task_id)
|
|
228
226
|
with self.lock:
|
|
229
227
|
self.task_res_store[task_id] = task_res
|
|
228
|
+
self.task_ins_id_to_task_res_id[UUID(task_ins_id)] = task_id
|
|
230
229
|
|
|
231
230
|
# Return the new task_id
|
|
232
231
|
return task_id
|
|
233
232
|
|
|
234
233
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
235
|
-
"""Get
|
|
234
|
+
"""Get TaskRes for the given TaskIns IDs."""
|
|
235
|
+
ret: dict[UUID, TaskRes] = {}
|
|
236
|
+
|
|
236
237
|
with self.lock:
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
if
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
238
|
+
current = time.time()
|
|
239
|
+
|
|
240
|
+
# Verify TaskIns IDs
|
|
241
|
+
ret = verify_taskins_ids(
|
|
242
|
+
inquired_taskins_ids=task_ids,
|
|
243
|
+
found_taskins_dict=self.task_ins_store,
|
|
244
|
+
current_time=current,
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
# Find all TaskRes
|
|
248
|
+
task_res_found: list[TaskRes] = []
|
|
249
|
+
for task_id in task_ids:
|
|
250
|
+
# If TaskRes exists and is not delivered, add it to the list
|
|
251
|
+
if task_res_id := self.task_ins_id_to_task_res_id.get(task_id):
|
|
252
|
+
task_res = self.task_res_store[task_res_id]
|
|
253
|
+
if task_res.task.delivered_at == "":
|
|
254
|
+
task_res_found.append(task_res)
|
|
255
|
+
tmp_ret_dict = verify_found_taskres(
|
|
256
|
+
inquired_taskins_ids=task_ids,
|
|
257
|
+
found_taskins_dict=self.task_ins_store,
|
|
258
|
+
found_taskres_list=task_res_found,
|
|
259
|
+
current_time=current,
|
|
260
|
+
)
|
|
261
|
+
ret.update(tmp_ret_dict)
|
|
262
|
+
|
|
263
|
+
# Mark existing TaskRes to be returned as delivered
|
|
260
264
|
delivered_at = now().isoformat()
|
|
261
|
-
for task_res in
|
|
265
|
+
for task_res in task_res_found:
|
|
262
266
|
task_res.task.delivered_at = delivered_at
|
|
263
267
|
|
|
264
|
-
#
|
|
265
|
-
|
|
268
|
+
# Cleanup
|
|
269
|
+
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
270
|
+
|
|
271
|
+
return list(ret.values())
|
|
266
272
|
|
|
267
273
|
def delete_tasks(self, task_ids: set[UUID]) -> None:
|
|
268
274
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
|
@@ -283,9 +289,25 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
283
289
|
|
|
284
290
|
for task_id in task_ins_to_be_deleted:
|
|
285
291
|
del self.task_ins_store[task_id]
|
|
292
|
+
del self.task_ins_id_to_task_res_id[task_id]
|
|
286
293
|
for task_id in task_res_to_be_deleted:
|
|
287
294
|
del self.task_res_store[task_id]
|
|
288
295
|
|
|
296
|
+
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
297
|
+
"""Delete tasks based on a set of TaskIns IDs."""
|
|
298
|
+
if not task_ids:
|
|
299
|
+
return
|
|
300
|
+
|
|
301
|
+
with self.lock:
|
|
302
|
+
for task_id in task_ids:
|
|
303
|
+
# Delete TaskIns
|
|
304
|
+
if task_id in self.task_ins_store:
|
|
305
|
+
del self.task_ins_store[task_id]
|
|
306
|
+
# Delete TaskRes
|
|
307
|
+
if task_id in self.task_ins_id_to_task_res_id:
|
|
308
|
+
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
|
|
309
|
+
del self.task_res_store[task_res_id]
|
|
310
|
+
|
|
289
311
|
def num_task_ins(self) -> int:
|
|
290
312
|
"""Calculate the number of task_ins in store.
|
|
291
313
|
|
|
@@ -386,13 +408,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
386
408
|
fab_version=fab_version if fab_version else "",
|
|
387
409
|
fab_hash=fab_hash if fab_hash else "",
|
|
388
410
|
override_config=override_config,
|
|
411
|
+
pending_at=now().isoformat(),
|
|
412
|
+
starting_at="",
|
|
413
|
+
running_at="",
|
|
414
|
+
finished_at="",
|
|
415
|
+
status=RunStatus(
|
|
416
|
+
status=Status.PENDING,
|
|
417
|
+
sub_status="",
|
|
418
|
+
details="",
|
|
419
|
+
),
|
|
389
420
|
),
|
|
390
|
-
status=RunStatus(
|
|
391
|
-
status=Status.PENDING,
|
|
392
|
-
sub_status="",
|
|
393
|
-
details="",
|
|
394
|
-
),
|
|
395
|
-
pending_at=now().isoformat(),
|
|
396
421
|
)
|
|
397
422
|
self.run_ids[run_id] = run_record
|
|
398
423
|
|
|
@@ -452,7 +477,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
452
477
|
"""Retrieve the statuses for the specified runs."""
|
|
453
478
|
with self.lock:
|
|
454
479
|
return {
|
|
455
|
-
run_id: self.run_ids[run_id].status
|
|
480
|
+
run_id: self.run_ids[run_id].run.status
|
|
456
481
|
for run_id in set(run_ids)
|
|
457
482
|
if run_id in self.run_ids
|
|
458
483
|
}
|
|
@@ -466,7 +491,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
466
491
|
return False
|
|
467
492
|
|
|
468
493
|
# Check if the status transition is valid
|
|
469
|
-
current_status = self.run_ids[run_id].status
|
|
494
|
+
current_status = self.run_ids[run_id].run.status
|
|
470
495
|
if not is_valid_transition(current_status, new_status):
|
|
471
496
|
log(
|
|
472
497
|
ERROR,
|
|
@@ -489,12 +514,12 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
489
514
|
# Update the status
|
|
490
515
|
run_record = self.run_ids[run_id]
|
|
491
516
|
if new_status.status == Status.STARTING:
|
|
492
|
-
run_record.starting_at = now().isoformat()
|
|
517
|
+
run_record.run.starting_at = now().isoformat()
|
|
493
518
|
elif new_status.status == Status.RUNNING:
|
|
494
|
-
run_record.running_at = now().isoformat()
|
|
519
|
+
run_record.run.running_at = now().isoformat()
|
|
495
520
|
elif new_status.status == Status.FINISHED:
|
|
496
|
-
run_record.finished_at = now().isoformat()
|
|
497
|
-
run_record.status = new_status
|
|
521
|
+
run_record.run.finished_at = now().isoformat()
|
|
522
|
+
run_record.run.status = new_status
|
|
498
523
|
return True
|
|
499
524
|
|
|
500
525
|
def get_pending_run_id(self) -> Optional[int]:
|
|
@@ -504,7 +529,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
504
529
|
# Loop through all registered runs
|
|
505
530
|
for run_id, run_rec in self.run_ids.items():
|
|
506
531
|
# Break once a pending run is found
|
|
507
|
-
if run_rec.status.status == Status.PENDING:
|
|
532
|
+
if run_rec.run.status.status == Status.PENDING:
|
|
508
533
|
pending_run_id = run_id
|
|
509
534
|
break
|
|
510
535
|
|
|
@@ -101,13 +101,27 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
101
101
|
|
|
102
102
|
@abc.abstractmethod
|
|
103
103
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
104
|
-
"""Get TaskRes for
|
|
104
|
+
"""Get TaskRes for the given TaskIns IDs.
|
|
105
105
|
|
|
106
|
-
|
|
107
|
-
|
|
106
|
+
This method is typically called by the ServerAppIo API to obtain
|
|
107
|
+
results (TaskRes) for previously scheduled instructions (TaskIns).
|
|
108
|
+
For each task_id provided, this method returns one of the following responses:
|
|
108
109
|
|
|
109
|
-
|
|
110
|
-
|
|
110
|
+
- An error TaskRes if the corresponding TaskIns does not exist or has expired.
|
|
111
|
+
- An error TaskRes if the corresponding TaskRes exists but has expired.
|
|
112
|
+
- The valid TaskRes if the TaskIns has a corresponding valid TaskRes.
|
|
113
|
+
- Nothing if the TaskIns is still valid and waiting for a TaskRes.
|
|
114
|
+
|
|
115
|
+
Parameters
|
|
116
|
+
----------
|
|
117
|
+
task_ids : set[UUID]
|
|
118
|
+
A set of TaskIns IDs for which to retrieve results (TaskRes).
|
|
119
|
+
|
|
120
|
+
Returns
|
|
121
|
+
-------
|
|
122
|
+
list[TaskRes]
|
|
123
|
+
A list of TaskRes corresponding to the given task IDs. If no
|
|
124
|
+
TaskRes could be found for any of the task IDs, an empty list is returned.
|
|
111
125
|
"""
|
|
112
126
|
|
|
113
127
|
@abc.abstractmethod
|
|
@@ -57,6 +57,8 @@ from .utils import (
|
|
|
57
57
|
generate_rand_int_from_bytes,
|
|
58
58
|
has_valid_sub_status,
|
|
59
59
|
is_valid_transition,
|
|
60
|
+
verify_found_taskres,
|
|
61
|
+
verify_taskins_ids,
|
|
60
62
|
)
|
|
61
63
|
|
|
62
64
|
SQL_CREATE_TABLE_NODE = """
|
|
@@ -510,136 +512,67 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
510
512
|
|
|
511
513
|
# pylint: disable-next=R0912,R0915,R0914
|
|
512
514
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
513
|
-
"""Get TaskRes for
|
|
515
|
+
"""Get TaskRes for the given TaskIns IDs."""
|
|
516
|
+
ret: dict[UUID, TaskRes] = {}
|
|
514
517
|
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
Retrieves all TaskRes for the given `task_ids` and returns and empty list if
|
|
519
|
-
none could be found.
|
|
520
|
-
|
|
521
|
-
Constraints
|
|
522
|
-
-----------
|
|
523
|
-
If `limit` is not `None`, return, at most, `limit` number of TaskRes. The limit
|
|
524
|
-
will only take effect if enough task_ids are in the set AND are currently
|
|
525
|
-
available. If `limit` is set, it has to be greater than zero.
|
|
526
|
-
"""
|
|
527
|
-
# Check if corresponding TaskIns exists and is not expired
|
|
528
|
-
task_ids_placeholders = ",".join([f":id_{i}" for i in range(len(task_ids))])
|
|
518
|
+
# Verify TaskIns IDs
|
|
519
|
+
current = time.time()
|
|
529
520
|
query = f"""
|
|
530
521
|
SELECT *
|
|
531
522
|
FROM task_ins
|
|
532
|
-
WHERE task_id IN ({
|
|
533
|
-
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
523
|
+
WHERE task_id IN ({",".join(["?"] * len(task_ids))});
|
|
534
524
|
"""
|
|
535
|
-
query
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
for index, task_id in enumerate(task_ids):
|
|
539
|
-
task_ins_data[f"id_{index}"] = str(task_id)
|
|
540
|
-
|
|
541
|
-
task_ins_rows = self.query(query, task_ins_data)
|
|
542
|
-
|
|
543
|
-
if not task_ins_rows:
|
|
544
|
-
return []
|
|
545
|
-
|
|
546
|
-
for row in task_ins_rows:
|
|
547
|
-
# Convert values from sint64 to uint64
|
|
525
|
+
rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
|
|
526
|
+
found_task_ins_dict: dict[UUID, TaskIns] = {}
|
|
527
|
+
for row in rows:
|
|
548
528
|
convert_sint64_values_in_dict_to_uint64(
|
|
549
529
|
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
550
530
|
)
|
|
551
|
-
|
|
552
|
-
if task_ins.task.created_at + task_ins.task.ttl <= time.time():
|
|
553
|
-
log(WARNING, "TaskIns with task_id %s is expired.", task_ins.task_id)
|
|
554
|
-
task_ids.remove(UUID(task_ins.task_id))
|
|
531
|
+
found_task_ins_dict[UUID(row["task_id"])] = dict_to_task_ins(row)
|
|
555
532
|
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
533
|
+
ret = verify_taskins_ids(
|
|
534
|
+
inquired_taskins_ids=task_ids,
|
|
535
|
+
found_taskins_dict=found_task_ins_dict,
|
|
536
|
+
current_time=current,
|
|
537
|
+
)
|
|
559
538
|
|
|
560
|
-
|
|
539
|
+
# Find all TaskRes
|
|
561
540
|
query = f"""
|
|
562
541
|
SELECT *
|
|
563
542
|
FROM task_res
|
|
564
|
-
WHERE ancestry IN ({
|
|
565
|
-
AND delivered_at = ""
|
|
566
|
-
"""
|
|
567
|
-
|
|
568
|
-
data: dict[str, Union[str, float, int]] = {}
|
|
569
|
-
|
|
570
|
-
query += ";"
|
|
571
|
-
|
|
572
|
-
for index, task_id in enumerate(task_ids):
|
|
573
|
-
data[f"id_{index}"] = str(task_id)
|
|
574
|
-
|
|
575
|
-
rows = self.query(query, data)
|
|
576
|
-
|
|
577
|
-
if rows:
|
|
578
|
-
# Prepare query
|
|
579
|
-
found_task_ids = [row["task_id"] for row in rows]
|
|
580
|
-
placeholders = ",".join([f":id_{i}" for i in range(len(found_task_ids))])
|
|
581
|
-
query = f"""
|
|
582
|
-
UPDATE task_res
|
|
583
|
-
SET delivered_at = :delivered_at
|
|
584
|
-
WHERE task_id IN ({placeholders})
|
|
585
|
-
RETURNING *;
|
|
586
|
-
"""
|
|
587
|
-
|
|
588
|
-
# Prepare data for query
|
|
589
|
-
delivered_at = now().isoformat()
|
|
590
|
-
data = {"delivered_at": delivered_at}
|
|
591
|
-
for index, task_id in enumerate(found_task_ids):
|
|
592
|
-
data[f"id_{index}"] = str(task_id)
|
|
593
|
-
|
|
594
|
-
# Run query
|
|
595
|
-
rows = self.query(query, data)
|
|
596
|
-
|
|
597
|
-
for row in rows:
|
|
598
|
-
# Convert values from sint64 to uint64
|
|
599
|
-
convert_sint64_values_in_dict_to_uint64(
|
|
600
|
-
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
601
|
-
)
|
|
602
|
-
|
|
603
|
-
result = [dict_to_task_res(row) for row in rows]
|
|
604
|
-
|
|
605
|
-
# 1. Query: Fetch consumer_node_id of remaining task_ids
|
|
606
|
-
# Assume the ancestry field only contains one element
|
|
607
|
-
data.clear()
|
|
608
|
-
replied_task_ids: set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
|
|
609
|
-
remaining_task_ids = task_ids - replied_task_ids
|
|
610
|
-
placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
|
|
611
|
-
query = f"""
|
|
612
|
-
SELECT consumer_node_id
|
|
613
|
-
FROM task_ins
|
|
614
|
-
WHERE task_id IN ({placeholders});
|
|
543
|
+
WHERE ancestry IN ({",".join(["?"] * len(task_ids))})
|
|
544
|
+
AND delivered_at = "";
|
|
615
545
|
"""
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
546
|
+
rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
|
|
547
|
+
for row in rows:
|
|
548
|
+
convert_sint64_values_in_dict_to_uint64(
|
|
549
|
+
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
550
|
+
)
|
|
551
|
+
tmp_ret_dict = verify_found_taskres(
|
|
552
|
+
inquired_taskins_ids=task_ids,
|
|
553
|
+
found_taskins_dict=found_task_ins_dict,
|
|
554
|
+
found_taskres_list=[dict_to_task_res(row) for row in rows],
|
|
555
|
+
current_time=current,
|
|
556
|
+
)
|
|
557
|
+
ret.update(tmp_ret_dict)
|
|
619
558
|
|
|
620
|
-
#
|
|
621
|
-
|
|
559
|
+
# Mark existing TaskRes to be returned as delivered
|
|
560
|
+
delivered_at = now().isoformat()
|
|
561
|
+
for task_res in ret.values():
|
|
562
|
+
task_res.task.delivered_at = delivered_at
|
|
563
|
+
task_res_ids = [task_res.task_id for task_res in ret.values()]
|
|
622
564
|
query = f"""
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
WHERE
|
|
626
|
-
AND online_until < :time;
|
|
565
|
+
UPDATE task_res
|
|
566
|
+
SET delivered_at = ?
|
|
567
|
+
WHERE task_id IN ({",".join(["?"] * len(task_res_ids))});
|
|
627
568
|
"""
|
|
628
|
-
data
|
|
629
|
-
|
|
630
|
-
offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)]
|
|
569
|
+
data: list[Any] = [delivered_at] + task_res_ids
|
|
570
|
+
self.query(query, data)
|
|
631
571
|
|
|
632
|
-
#
|
|
633
|
-
|
|
634
|
-
query = f"""
|
|
635
|
-
SELECT *
|
|
636
|
-
FROM task_ins
|
|
637
|
-
WHERE consumer_node_id IN ({placeholders});
|
|
638
|
-
"""
|
|
639
|
-
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
|
|
640
|
-
task_ins_rows = self.query(query, data)
|
|
572
|
+
# Cleanup
|
|
573
|
+
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
641
574
|
|
|
642
|
-
return
|
|
575
|
+
return list(ret.values())
|
|
643
576
|
|
|
644
577
|
def num_task_ins(self) -> int:
|
|
645
578
|
"""Calculate the number of task_ins in store.
|
|
@@ -699,6 +632,32 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
699
632
|
|
|
700
633
|
return None
|
|
701
634
|
|
|
635
|
+
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
636
|
+
"""Delete tasks based on a set of TaskIns IDs."""
|
|
637
|
+
if not task_ids:
|
|
638
|
+
return
|
|
639
|
+
if self.conn is None:
|
|
640
|
+
raise AttributeError("LinkState not initialized")
|
|
641
|
+
|
|
642
|
+
placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
|
|
643
|
+
data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
|
|
644
|
+
|
|
645
|
+
# Delete task_ins
|
|
646
|
+
query_1 = f"""
|
|
647
|
+
DELETE FROM task_ins
|
|
648
|
+
WHERE task_id IN ({placeholders});
|
|
649
|
+
"""
|
|
650
|
+
|
|
651
|
+
# Delete task_res
|
|
652
|
+
query_2 = f"""
|
|
653
|
+
DELETE FROM task_res
|
|
654
|
+
WHERE ancestry IN ({placeholders});
|
|
655
|
+
"""
|
|
656
|
+
|
|
657
|
+
with self.conn:
|
|
658
|
+
self.conn.execute(query_1, data)
|
|
659
|
+
self.conn.execute(query_2, data)
|
|
660
|
+
|
|
702
661
|
def create_node(
|
|
703
662
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
704
663
|
) -> int:
|
|
@@ -922,6 +881,15 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
922
881
|
fab_version=row["fab_version"],
|
|
923
882
|
fab_hash=row["fab_hash"],
|
|
924
883
|
override_config=json.loads(row["override_config"]),
|
|
884
|
+
pending_at=row["pending_at"],
|
|
885
|
+
starting_at=row["starting_at"],
|
|
886
|
+
running_at=row["running_at"],
|
|
887
|
+
finished_at=row["finished_at"],
|
|
888
|
+
status=RunStatus(
|
|
889
|
+
status=determine_run_status(row),
|
|
890
|
+
sub_status=row["sub_status"],
|
|
891
|
+
details=row["details"],
|
|
892
|
+
),
|
|
925
893
|
)
|
|
926
894
|
log(ERROR, "`run_id` does not exist.")
|
|
927
895
|
return None
|
|
@@ -1255,10 +1223,10 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
|
1255
1223
|
def determine_run_status(row: dict[str, Any]) -> str:
|
|
1256
1224
|
"""Determine the status of the run based on timestamp fields."""
|
|
1257
1225
|
if row["pending_at"]:
|
|
1226
|
+
if row["finished_at"]:
|
|
1227
|
+
return Status.FINISHED
|
|
1258
1228
|
if row["starting_at"]:
|
|
1259
1229
|
if row["running_at"]:
|
|
1260
|
-
if row["finished_at"]:
|
|
1261
|
-
return Status.FINISHED
|
|
1262
1230
|
return Status.RUNNING
|
|
1263
1231
|
return Status.STARTING
|
|
1264
1232
|
return Status.PENDING
|