flwr-nightly 1.16.0.dev20250306__py3-none-any.whl → 1.16.0.dev20250308__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/rest_client/connection.py +4 -6
- flwr/common/message.py +7 -7
- flwr/common/record/recordset.py +4 -12
- flwr/common/serde.py +8 -126
- flwr/server/compat/driver_client_proxy.py +2 -2
- flwr/server/driver/inmemory_driver.py +15 -18
- flwr/server/superlink/driver/serverappio_servicer.py +18 -23
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +32 -35
- flwr/server/superlink/linkstate/in_memory_linkstate.py +1 -221
- flwr/server/superlink/linkstate/linkstate.py +0 -113
- flwr/server/superlink/linkstate/sqlite_linkstate.py +2 -511
- flwr/server/superlink/linkstate/utils.py +2 -179
- flwr/server/utils/__init__.py +0 -2
- flwr/server/utils/validator.py +0 -88
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
- flwr/superexec/exec_servicer.py +3 -3
- {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/METADATA +1 -1
- {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/RECORD +25 -30
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.16.0.dev20250306.dist-info → flwr_nightly-1.16.0.dev20250308.dist-info}/entry_points.txt +0 -0
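Most of the churn below is in flwr/server/superlink/linkstate/sqlite_linkstate.py, where the legacy TaskIns/TaskRes code path is removed and only the Message-based path remains. As a rough orientation, here is a minimal sketch of the surviving instruction/reply surface, assembled solely from method signatures that appear as context lines in the hunks below; the stub Message class is a stand-in so the snippet is self-contained, not flwr's real Message type.

# Orientation sketch: signatures copied from context lines of the diff below.
# The Message stub stands in for flwr's own Message type; bodies are elided.
from typing import Any, Optional
from uuid import UUID


class Message:  # stand-in stub, not the real flwr Message
    pass


class MessageLinkStateSurface:  # illustrative subset of SqliteLinkState
    def store_message_ins(self, message: Message) -> Optional[UUID]: ...
    def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]: ...
    def store_message_res(self, message: Message) -> Optional[UUID]: ...
    def get_message_res(self, message_ids: set[UUID]) -> list[Message]: ...
    def num_message_ins(self) -> int: ...
    def num_message_res(self) -> int: ...
    def delete_messages(self, message_ins_ids: set[UUID]) -> None: ...
    def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]: ...
    def get_valid_message_ins(self, message_id: str) -> Optional[dict[str, Any]]: ...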
@@ -45,12 +45,10 @@ from flwr.common.typing import Run, RunStatus, UserConfig

 # pylint: disable=E0611
 from flwr.proto.error_pb2 import Error as ProtoError
-from flwr.proto.node_pb2 import Node
 from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
-from flwr.proto.task_pb2 import Task, TaskIns, TaskRes

 # pylint: enable=E0611
-from flwr.server.utils.validator import validate_message
+from flwr.server.utils.validator import validate_message

 from .linkstate import LinkState
 from .utils import (
@@ -66,9 +64,7 @@ from .utils import (
     has_valid_sub_status,
     is_valid_transition,
     verify_found_message_replies,
-    verify_found_taskres,
     verify_message_ids,
-    verify_taskins_ids,
 )

 SQL_CREATE_TABLE_NODE = """
@@ -126,23 +122,6 @@ CREATE TABLE IF NOT EXISTS context(
 );
 """

-SQL_CREATE_TABLE_TASK_INS = """
-CREATE TABLE IF NOT EXISTS task_ins(
-    task_id TEXT UNIQUE,
-    group_id TEXT,
-    run_id INTEGER,
-    producer_node_id INTEGER,
-    consumer_node_id INTEGER,
-    created_at REAL,
-    delivered_at TEXT,
-    ttl REAL,
-    ancestry TEXT,
-    task_type TEXT,
-    recordset BLOB,
-    FOREIGN KEY(run_id) REFERENCES run(run_id)
-);
-"""
-
 SQL_CREATE_TABLE_MESSAGE_INS = """
 CREATE TABLE IF NOT EXISTS message_ins(
     message_id TEXT UNIQUE,
@@ -161,23 +140,6 @@ CREATE TABLE IF NOT EXISTS message_ins(
 );
 """

-SQL_CREATE_TABLE_TASK_RES = """
-CREATE TABLE IF NOT EXISTS task_res(
-    task_id TEXT UNIQUE,
-    group_id TEXT,
-    run_id INTEGER,
-    producer_node_id INTEGER,
-    consumer_node_id INTEGER,
-    created_at REAL,
-    delivered_at TEXT,
-    ttl REAL,
-    ancestry TEXT,
-    task_type TEXT,
-    recordset BLOB,
-    FOREIGN KEY(run_id) REFERENCES run(run_id)
-);
-"""
-

 SQL_CREATE_TABLE_MESSAGE_RES = """
 CREATE TABLE IF NOT EXISTS message_res(
@@ -242,8 +204,6 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
         cur.execute(SQL_CREATE_TABLE_RUN)
         cur.execute(SQL_CREATE_TABLE_LOGS)
         cur.execute(SQL_CREATE_TABLE_CONTEXT)
-        cur.execute(SQL_CREATE_TABLE_TASK_INS)
-        cur.execute(SQL_CREATE_TABLE_TASK_RES)
         cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
         cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
         cur.execute(SQL_CREATE_TABLE_NODE)
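With the two task tables gone, the schema set up by SqliteLinkState only includes the run, logs, context, message_ins, message_res, and node tables created above. A small, hypothetical verification sketch (the "state.db" path is an assumption standing for whatever SQLite file the LinkState was initialized with):

# Hypothetical check, not part of the diff: list the tables of an initialized
# LinkState database and confirm the task tables are gone while the message
# tables remain.
import sqlite3

conn = sqlite3.connect("state.db")  # assumed path to an initialized LinkState DB
tables = {
    row[0]
    for row in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table';")
}
conn.close()

assert {"message_ins", "message_res"}.issubset(tables)
assert not {"task_ins", "task_res"} & tables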
@@ -287,69 +247,6 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904

         return result

-    def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
-        """Store one TaskIns.
-
-        Usually, the ServerAppIo API calls this to schedule instructions.
-
-        Stores the value of the task_ins in the link state and, if successful,
-        returns the task_id (UUID) of the task_ins. If, for any reason, storing
-        the task_ins fails, `None` is returned.
-
-        Constraints
-        -----------
-
-        `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
-        """
-        # Validate task
-        errors = validate_task_ins_or_res(task_ins)
-        if any(errors):
-            log(ERROR, errors)
-            return None
-        # Create task_id
-        task_id = uuid4()
-
-        # Store TaskIns
-        task_ins.task_id = str(task_id)
-        data = (task_ins_to_dict(task_ins),)
-
-        # Convert values from uint64 to sint64 for SQLite
-        convert_uint64_values_in_dict_to_sint64(
-            data[0], ["run_id", "producer_node_id", "consumer_node_id"]
-        )
-
-        # Validate run_id
-        query = "SELECT run_id FROM run WHERE run_id = ?;"
-        if not self.query(query, (data[0]["run_id"],)):
-            log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
-            return None
-        # Validate source node ID
-        if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
-            log(
-                ERROR,
-                "Invalid source node ID for TaskIns: %s",
-                task_ins.task.producer.node_id,
-            )
-            return None
-        # Validate destination node ID
-        query = "SELECT node_id FROM node WHERE node_id = ?;"
-        if not self.query(query, (data[0]["consumer_node_id"],)):
-            log(
-                ERROR,
-                "Invalid destination node ID for TaskIns: %s",
-                task_ins.task.consumer.node_id,
-            )
-            return None
-
-        columns = ", ".join([f":{key}" for key in data[0]])
-        query = f"INSERT INTO task_ins VALUES({columns});"
-
-        # Only invalid run_id can trigger IntegrityError.
-        # This may need to be changed in the future version with more integrity checks.
-        self.query(query, data)
-
-        return task_id
-
     def store_message_ins(self, message: Message) -> Optional[UUID]:
         """Store one Message."""
         # Validate message
@@ -404,84 +301,6 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904

         return message_id

-    def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
-        """Get undelivered TaskIns for one node.
-
-        Usually, the Fleet API calls this for Nodes planning to work on one or more
-        TaskIns.
-
-        Constraints
-        -----------
-        Retrieve all TaskIns where
-
-        1. the `task_ins.task.consumer.node_id` equals `node_id` AND
-        2. the `task_ins.task.delivered_at` equals `""`.
-
-        `delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
-        the result.
-
-        If `limit` is not `None`, return, at most, `limit` number of `task_ins`. If
-        `limit` is set, it has to be greater than zero.
-        """
-        if limit is not None and limit < 1:
-            raise AssertionError("`limit` must be >= 1")
-
-        if node_id == SUPERLINK_NODE_ID:
-            msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
-            raise AssertionError(msg)
-
-        data: dict[str, Union[str, int]] = {}
-
-        # Convert the uint64 value to sint64 for SQLite
-        data["node_id"] = convert_uint64_to_sint64(node_id)
-
-        # Retrieve all TaskIns for node_id
-        query = """
-            SELECT task_id
-            FROM task_ins
-            WHERE consumer_node_id == :node_id
-            AND delivered_at = ""
-            AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
-        """
-
-        if limit is not None:
-            query += " LIMIT :limit"
-            data["limit"] = limit
-
-        query += ";"
-
-        rows = self.query(query, data)
-
-        if rows:
-            # Prepare query
-            task_ids = [row["task_id"] for row in rows]
-            placeholders: str = ",".join([f":id_{i}" for i in range(len(task_ids))])
-            query = f"""
-                UPDATE task_ins
-                SET delivered_at = :delivered_at
-                WHERE task_id IN ({placeholders})
-                RETURNING *;
-            """
-
-            # Prepare data for query
-            delivered_at = now().isoformat()
-            data = {"delivered_at": delivered_at}
-            for index, task_id in enumerate(task_ids):
-                data[f"id_{index}"] = str(task_id)
-
-            # Run query
-            rows = self.query(query, data)
-
-            for row in rows:
-                # Convert values from sint64 to uint64
-                convert_sint64_values_in_dict_to_uint64(
-                    row, ["run_id", "producer_node_id", "consumer_node_id"]
-                )
-
-        result = [dict_to_task_ins(row) for row in rows]
-
-        return result
-
     def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
         """Get all Messages that have not been delivered yet."""
         if limit is not None and limit < 1:
@@ -543,90 +362,6 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904

         return result

-    def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
-        """Store one TaskRes.
-
-        Usually, the Fleet API calls this when Nodes return their results.
-
-        Stores the TaskRes and, if successful, returns the `task_id` (UUID) of
-        the `task_res`. If storing the `task_res` fails, `None` is returned.
-
-        Constraints
-        -----------
-        `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
-        """
-        # Validate task
-        errors = validate_task_ins_or_res(task_res)
-        if any(errors):
-            log(ERROR, errors)
-            return None
-
-        # Create task_id
-        task_id = uuid4()
-
-        task_ins_id = task_res.task.ancestry[0]
-        task_ins = self.get_valid_task_ins(task_ins_id)
-        if task_ins is None:
-            log(
-                ERROR,
-                "Failed to store TaskRes: "
-                "TaskIns with task_id %s does not exist or has expired.",
-                task_ins_id,
-            )
-            return None
-
-        # Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
-        if (
-            task_ins
-            and task_res
-            and convert_sint64_to_uint64(task_ins["consumer_node_id"])
-            != task_res.task.producer.node_id
-        ):
-            return None
-
-        # Fail if the TaskRes TTL exceeds the
-        # expiration time of the TaskIns it replies to.
-        # Condition: TaskIns.created_at + TaskIns.ttl ≥
-        #            TaskRes.created_at + TaskRes.ttl
-        # A small tolerance is introduced to account
-        # for floating-point precision issues.
-        max_allowed_ttl = (
-            task_ins["created_at"] + task_ins["ttl"] - task_res.task.created_at
-        )
-        if task_res.task.ttl and (
-            task_res.task.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
-        ):
-            log(
-                WARNING,
-                "Received TaskRes with TTL %.2f "
-                "exceeding the allowed maximum TTL %.2f.",
-                task_res.task.ttl,
-                max_allowed_ttl,
-            )
-            return None
-
-        # Store TaskRes
-        task_res.task_id = str(task_id)
-        data = (task_res_to_dict(task_res),)
-
-        # Convert values from uint64 to sint64 for SQLite
-        convert_uint64_values_in_dict_to_sint64(
-            data[0], ["run_id", "producer_node_id", "consumer_node_id"]
-        )
-
-        columns = ", ".join([f":{key}" for key in data[0]])
-        query = f"INSERT INTO task_res VALUES({columns});"
-
-        # Only invalid run_id can trigger IntegrityError.
-        # This may need to be changed in the future version with more integrity checks.
-        try:
-            self.query(query, data)
-        except sqlite3.IntegrityError:
-            log(ERROR, "`run` is invalid")
-            return None
-
-        return task_id
-
     def store_message_res(self, message: Message) -> Optional[UUID]:
         """Store one Message."""
         # Validate message
@@ -705,67 +440,6 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904

         return message_id

-    # pylint: disable-next=R0912,R0915,R0914
-    def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
-        """Get TaskRes for the given TaskIns IDs."""
-        ret: dict[UUID, TaskRes] = {}
-
-        # Verify TaskIns IDs
-        current = time.time()
-        query = f"""
-            SELECT *
-            FROM task_ins
-            WHERE task_id IN ({",".join(["?"] * len(task_ids))});
-        """
-        rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
-        found_task_ins_dict: dict[UUID, TaskIns] = {}
-        for row in rows:
-            convert_sint64_values_in_dict_to_uint64(
-                row, ["run_id", "producer_node_id", "consumer_node_id"]
-            )
-            found_task_ins_dict[UUID(row["task_id"])] = dict_to_task_ins(row)
-
-        ret = verify_taskins_ids(
-            inquired_taskins_ids=task_ids,
-            found_taskins_dict=found_task_ins_dict,
-            current_time=current,
-        )
-
-        # Find all TaskRes
-        query = f"""
-            SELECT *
-            FROM task_res
-            WHERE ancestry IN ({",".join(["?"] * len(task_ids))})
-            AND delivered_at = "";
-        """
-        rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
-        for row in rows:
-            convert_sint64_values_in_dict_to_uint64(
-                row, ["run_id", "producer_node_id", "consumer_node_id"]
-            )
-        tmp_ret_dict = verify_found_taskres(
-            inquired_taskins_ids=task_ids,
-            found_taskins_dict=found_task_ins_dict,
-            found_taskres_list=[dict_to_task_res(row) for row in rows],
-            current_time=current,
-        )
-        ret.update(tmp_ret_dict)
-
-        # Mark existing TaskRes to be returned as delivered
-        delivered_at = now().isoformat()
-        for task_res in ret.values():
-            task_res.task.delivered_at = delivered_at
-        task_res_ids = [task_res.task_id for task_res in ret.values()]
-        query = f"""
-            UPDATE task_res
-            SET delivered_at = ?
-            WHERE task_id IN ({",".join(["?"] * len(task_res_ids))});
-        """
-        data: list[Any] = [delivered_at] + task_res_ids
-        self.query(query, data)
-
-        return list(ret.values())
-
     def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
         """Get reply Messages for the given Message IDs."""
         ret: dict[UUID, Message] = {}
@@ -828,17 +502,6 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904

         return list(ret.values())

-    def num_task_ins(self) -> int:
-        """Calculate the number of task_ins in store.
-
-        This includes delivered but not yet deleted task_ins.
-        """
-        query = "SELECT count(*) AS num FROM task_ins;"
-        rows = self.query(query)
-        result = rows[0]
-        num = cast(int, result["num"])
-        return num
-
     def num_message_ins(self) -> int:
         """Calculate the number of instruction Messages in store.

@@ -850,16 +513,6 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
         num = cast(int, result["num"])
         return num

-    def num_task_res(self) -> int:
-        """Calculate the number of task_res in store.
-
-        This includes delivered but not yet deleted task_res.
-        """
-        query = "SELECT count(*) AS num FROM task_res;"
-        rows = self.query(query)
-        result: dict[str, int] = rows[0]
-        return result["num"]
-
     def num_message_res(self) -> int:
         """Calculate the number of reply Messages in store.

@@ -870,32 +523,6 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
         result: dict[str, int] = rows[0]
         return result["num"]

-    def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
-        """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
-        if not task_ins_ids:
-            return
-        if self.conn is None:
-            raise AttributeError("LinkState not initialized")
-
-        placeholders = ",".join(["?"] * len(task_ins_ids))
-        data = tuple(str(task_id) for task_id in task_ins_ids)
-
-        # Delete task_ins
-        query_1 = f"""
-            DELETE FROM task_ins
-            WHERE task_id IN ({placeholders});
-        """
-
-        # Delete task_res
-        query_2 = f"""
-            DELETE FROM task_res
-            WHERE ancestry IN ({placeholders});
-        """
-
-        with self.conn:
-            self.conn.execute(query_1, data)
-            self.conn.execute(query_2, data)
-
     def delete_messages(self, message_ins_ids: set[UUID]) -> None:
         """Delete a Message and its reply based on provided Message IDs."""
         if not message_ins_ids:
@@ -922,25 +549,6 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
             self.conn.execute(query_1, data)
             self.conn.execute(query_2, data)

-    def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
-        """Get all TaskIns IDs for the given run_id."""
-        if self.conn is None:
-            raise AttributeError("LinkState not initialized")
-
-        query = """
-            SELECT task_id
-            FROM task_ins
-            WHERE run_id = :run_id;
-        """
-
-        sint64_run_id = convert_uint64_to_sint64(run_id)
-        data = {"run_id": sint64_run_id}
-
-        with self.conn:
-            rows = self.conn.execute(query, data).fetchall()
-
-        return {UUID(row["task_id"]) for row in rows}
-
     def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
         """Get all instruction Message IDs for the given run_id."""
         if self.conn is None:
|
|
1370
978
|
latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
|
1371
979
|
return "".join(row["log"] for row in rows), latest_timestamp
|
1372
980
|
|
1373
|
-
def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
|
1374
|
-
"""Check if the TaskIns exists and is valid (not expired).
|
1375
|
-
|
1376
|
-
Return TaskIns if valid.
|
1377
|
-
"""
|
1378
|
-
query = """
|
1379
|
-
SELECT *
|
1380
|
-
FROM task_ins
|
1381
|
-
WHERE task_id = :task_id
|
1382
|
-
"""
|
1383
|
-
data = {"task_id": task_id}
|
1384
|
-
rows = self.query(query, data)
|
1385
|
-
if not rows:
|
1386
|
-
# TaskIns does not exist
|
1387
|
-
return None
|
1388
|
-
|
1389
|
-
task_ins = rows[0]
|
1390
|
-
created_at = task_ins["created_at"]
|
1391
|
-
ttl = task_ins["ttl"]
|
1392
|
-
current_time = time.time()
|
1393
|
-
|
1394
|
-
# Check if TaskIns is expired
|
1395
|
-
if ttl is not None and created_at + ttl <= current_time:
|
1396
|
-
return None
|
1397
|
-
|
1398
|
-
return task_ins
|
1399
|
-
|
1400
981
|
def get_valid_message_ins(self, message_id: str) -> Optional[dict[str, Any]]:
|
1401
982
|
"""Check if the Message exists and is valid (not expired).
|
1402
983
|
|
@@ -1418,7 +999,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
1418
999
|
ttl = message_ins["ttl"]
|
1419
1000
|
current_time = time.time()
|
1420
1001
|
|
1421
|
-
# Check if
|
1002
|
+
# Check if Message is expired
|
1422
1003
|
if ttl is not None and created_at + ttl <= current_time:
|
1423
1004
|
return None
|
1424
1005
|
|
@@ -1437,42 +1018,6 @@ def dict_factory(
|
|
1437
1018
|
return dict(zip(fields, row))
|
1438
1019
|
|
1439
1020
|
|
1440
|
-
def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
|
1441
|
-
"""Transform TaskIns to dict."""
|
1442
|
-
result = {
|
1443
|
-
"task_id": task_msg.task_id,
|
1444
|
-
"group_id": task_msg.group_id,
|
1445
|
-
"run_id": task_msg.run_id,
|
1446
|
-
"producer_node_id": task_msg.task.producer.node_id,
|
1447
|
-
"consumer_node_id": task_msg.task.consumer.node_id,
|
1448
|
-
"created_at": task_msg.task.created_at,
|
1449
|
-
"delivered_at": task_msg.task.delivered_at,
|
1450
|
-
"ttl": task_msg.task.ttl,
|
1451
|
-
"ancestry": ",".join(task_msg.task.ancestry),
|
1452
|
-
"task_type": task_msg.task.task_type,
|
1453
|
-
"recordset": task_msg.task.recordset.SerializeToString(),
|
1454
|
-
}
|
1455
|
-
return result
|
1456
|
-
|
1457
|
-
|
1458
|
-
def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
|
1459
|
-
"""Transform TaskRes to dict."""
|
1460
|
-
result = {
|
1461
|
-
"task_id": task_msg.task_id,
|
1462
|
-
"group_id": task_msg.group_id,
|
1463
|
-
"run_id": task_msg.run_id,
|
1464
|
-
"producer_node_id": task_msg.task.producer.node_id,
|
1465
|
-
"consumer_node_id": task_msg.task.consumer.node_id,
|
1466
|
-
"created_at": task_msg.task.created_at,
|
1467
|
-
"delivered_at": task_msg.task.delivered_at,
|
1468
|
-
"ttl": task_msg.task.ttl,
|
1469
|
-
"ancestry": ",".join(task_msg.task.ancestry),
|
1470
|
-
"task_type": task_msg.task.task_type,
|
1471
|
-
"recordset": task_msg.task.recordset.SerializeToString(),
|
1472
|
-
}
|
1473
|
-
return result
|
1474
|
-
|
1475
|
-
|
1476
1021
|
def message_to_dict(message: Message) -> dict[str, Any]:
|
1477
1022
|
"""Transform Message to dict."""
|
1478
1023
|
result = {
|
@@ -1498,60 +1043,6 @@ def message_to_dict(message: Message) -> dict[str, Any]:
|
|
1498
1043
|
return result
|
1499
1044
|
|
1500
1045
|
|
1501
|
-
def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
1502
|
-
"""Turn task_dict into protobuf message."""
|
1503
|
-
recordset = ProtoRecordSet()
|
1504
|
-
recordset.ParseFromString(task_dict["recordset"])
|
1505
|
-
|
1506
|
-
result = TaskIns(
|
1507
|
-
task_id=task_dict["task_id"],
|
1508
|
-
group_id=task_dict["group_id"],
|
1509
|
-
run_id=task_dict["run_id"],
|
1510
|
-
task=Task(
|
1511
|
-
producer=Node(
|
1512
|
-
node_id=task_dict["producer_node_id"],
|
1513
|
-
),
|
1514
|
-
consumer=Node(
|
1515
|
-
node_id=task_dict["consumer_node_id"],
|
1516
|
-
),
|
1517
|
-
created_at=task_dict["created_at"],
|
1518
|
-
delivered_at=task_dict["delivered_at"],
|
1519
|
-
ttl=task_dict["ttl"],
|
1520
|
-
ancestry=task_dict["ancestry"].split(","),
|
1521
|
-
task_type=task_dict["task_type"],
|
1522
|
-
recordset=recordset,
|
1523
|
-
),
|
1524
|
-
)
|
1525
|
-
return result
|
1526
|
-
|
1527
|
-
|
1528
|
-
def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
1529
|
-
"""Turn task_dict into protobuf message."""
|
1530
|
-
recordset = ProtoRecordSet()
|
1531
|
-
recordset.ParseFromString(task_dict["recordset"])
|
1532
|
-
|
1533
|
-
result = TaskRes(
|
1534
|
-
task_id=task_dict["task_id"],
|
1535
|
-
group_id=task_dict["group_id"],
|
1536
|
-
run_id=task_dict["run_id"],
|
1537
|
-
task=Task(
|
1538
|
-
producer=Node(
|
1539
|
-
node_id=task_dict["producer_node_id"],
|
1540
|
-
),
|
1541
|
-
consumer=Node(
|
1542
|
-
node_id=task_dict["consumer_node_id"],
|
1543
|
-
),
|
1544
|
-
created_at=task_dict["created_at"],
|
1545
|
-
delivered_at=task_dict["delivered_at"],
|
1546
|
-
ttl=task_dict["ttl"],
|
1547
|
-
ancestry=task_dict["ancestry"].split(","),
|
1548
|
-
task_type=task_dict["task_type"],
|
1549
|
-
recordset=recordset,
|
1550
|
-
),
|
1551
|
-
)
|
1552
|
-
return result
|
1553
|
-
|
1554
|
-
|
1555
1046
|
def dict_to_message(message_dict: dict[str, Any]) -> Message:
|
1556
1047
|
"""Transform dict to Message."""
|
1557
1048
|
content, error = None, None
|