flwr-nightly 1.16.0.dev20250304__py3-none-any.whl → 1.16.0.dev20250305__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/server/compat/app.py +4 -1
- flwr/server/compat/app_utils.py +10 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +236 -2
- flwr/server/superlink/linkstate/linkstate.py +101 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +445 -2
- flwr/server/superlink/linkstate/utils.py +153 -4
- flwr/server/workflow/default_workflows.py +4 -1
- {flwr_nightly-1.16.0.dev20250304.dist-info → flwr_nightly-1.16.0.dev20250305.dist-info}/METADATA +1 -1
- {flwr_nightly-1.16.0.dev20250304.dist-info → flwr_nightly-1.16.0.dev20250305.dist-info}/RECORD +12 -12
- {flwr_nightly-1.16.0.dev20250304.dist-info → flwr_nightly-1.16.0.dev20250305.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.16.0.dev20250304.dist-info → flwr_nightly-1.16.0.dev20250305.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.16.0.dev20250304.dist-info → flwr_nightly-1.16.0.dev20250305.dist-info}/entry_points.txt +0 -0
@@ -26,7 +26,7 @@ from logging import DEBUG, ERROR, WARNING
|
|
26
26
|
from typing import Any, Optional, Union, cast
|
27
27
|
from uuid import UUID, uuid4
|
28
28
|
|
29
|
-
from flwr.common import Context, log, now
|
29
|
+
from flwr.common import Context, Message, Metadata, log, now
|
30
30
|
from flwr.common.constant import (
|
31
31
|
MESSAGE_TTL_TOLERANCE,
|
32
32
|
NODE_ID_NUM_BYTES,
|
@@ -35,15 +35,22 @@ from flwr.common.constant import (
|
|
35
35
|
Status,
|
36
36
|
)
|
37
37
|
from flwr.common.record import ConfigsRecord
|
38
|
+
from flwr.common.serde import (
|
39
|
+
error_from_proto,
|
40
|
+
error_to_proto,
|
41
|
+
recordset_from_proto,
|
42
|
+
recordset_to_proto,
|
43
|
+
)
|
38
44
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
39
45
|
|
40
46
|
# pylint: disable=E0611
|
47
|
+
from flwr.proto.error_pb2 import Error as ProtoError
|
41
48
|
from flwr.proto.node_pb2 import Node
|
42
49
|
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
43
50
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
44
51
|
|
45
52
|
# pylint: enable=E0611
|
46
|
-
from flwr.server.utils.validator import validate_task_ins_or_res
|
53
|
+
from flwr.server.utils.validator import validate_message, validate_task_ins_or_res
|
47
54
|
|
48
55
|
from .linkstate import LinkState
|
49
56
|
from .utils import (
|
@@ -58,7 +65,9 @@ from .utils import (
|
|
58
65
|
generate_rand_int_from_bytes,
|
59
66
|
has_valid_sub_status,
|
60
67
|
is_valid_transition,
|
68
|
+
verify_found_message_replies,
|
61
69
|
verify_found_taskres,
|
70
|
+
verify_message_ids,
|
62
71
|
verify_taskins_ids,
|
63
72
|
)
|
64
73
|
|
@@ -134,6 +143,24 @@ CREATE TABLE IF NOT EXISTS task_ins(
|
|
134
143
|
);
|
135
144
|
"""
|
136
145
|
|
146
|
+
SQL_CREATE_TABLE_MESSAGE_INS = """
|
147
|
+
CREATE TABLE IF NOT EXISTS message_ins(
|
148
|
+
message_id TEXT UNIQUE,
|
149
|
+
group_id TEXT,
|
150
|
+
run_id INTEGER,
|
151
|
+
src_node_id INTEGER,
|
152
|
+
dst_node_id INTEGER,
|
153
|
+
reply_to_message TEXT,
|
154
|
+
created_at REAL,
|
155
|
+
delivered_at TEXT,
|
156
|
+
ttl REAL,
|
157
|
+
message_type TEXT,
|
158
|
+
content BLOB NULL,
|
159
|
+
error BLOB NULL,
|
160
|
+
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
161
|
+
);
|
162
|
+
"""
|
163
|
+
|
137
164
|
SQL_CREATE_TABLE_TASK_RES = """
|
138
165
|
CREATE TABLE IF NOT EXISTS task_res(
|
139
166
|
task_id TEXT UNIQUE,
|
@@ -151,6 +178,25 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
151
178
|
);
|
152
179
|
"""
|
153
180
|
|
181
|
+
|
182
|
+
SQL_CREATE_TABLE_MESSAGE_RES = """
|
183
|
+
CREATE TABLE IF NOT EXISTS message_res(
|
184
|
+
message_id TEXT UNIQUE,
|
185
|
+
group_id TEXT,
|
186
|
+
run_id INTEGER,
|
187
|
+
src_node_id INTEGER,
|
188
|
+
dst_node_id INTEGER,
|
189
|
+
reply_to_message TEXT,
|
190
|
+
created_at REAL,
|
191
|
+
delivered_at TEXT,
|
192
|
+
ttl REAL,
|
193
|
+
message_type TEXT,
|
194
|
+
content BLOB NULL,
|
195
|
+
error BLOB NULL,
|
196
|
+
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
197
|
+
);
|
198
|
+
"""
|
199
|
+
|
154
200
|
DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
|
155
201
|
|
156
202
|
|
@@ -198,6 +244,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
198
244
|
cur.execute(SQL_CREATE_TABLE_CONTEXT)
|
199
245
|
cur.execute(SQL_CREATE_TABLE_TASK_INS)
|
200
246
|
cur.execute(SQL_CREATE_TABLE_TASK_RES)
|
247
|
+
cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
|
248
|
+
cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
|
201
249
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
202
250
|
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
203
251
|
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
@@ -302,6 +350,60 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
302
350
|
|
303
351
|
return task_id
|
304
352
|
|
353
|
+
def store_message_ins(self, message: Message) -> Optional[UUID]:
    """Store one instruction Message.

    The Message is validated and its run and destination node are checked
    against the `run` and `node` tables. Only once every check passes is a
    fresh `message_id` assigned and the row inserted into `message_ins`.

    Parameters
    ----------
    message : Message
        The instruction Message to store. Its `src_node_id` must be
        `SUPERLINK_NODE_ID`.

    Returns
    -------
    Optional[UUID]
        The newly assigned message ID, or `None` if any check failed. On
        failure the passed-in Message is left unmodified.
    """
    # Validate message
    errors = validate_message(message=message, is_reply_message=False)
    if any(errors):
        log(ERROR, errors)
        return None

    metadata = message.metadata

    # Validate run_id (IDs are stored as sint64 in SQLite)
    query = "SELECT run_id FROM run WHERE run_id = ?;"
    if not self.query(query, (convert_uint64_to_sint64(metadata.run_id),)):
        log(ERROR, "Invalid run ID for Message: %s", metadata.run_id)
        return None

    # Validate source node ID
    if metadata.src_node_id != SUPERLINK_NODE_ID:
        log(
            ERROR,
            "Invalid source node ID for Message: %s",
            metadata.src_node_id,
        )
        return None

    # Validate destination node ID (IDs are stored as sint64 in SQLite)
    query = "SELECT node_id FROM node WHERE node_id = ?;"
    if not self.query(query, (convert_uint64_to_sint64(metadata.dst_node_id),)):
        log(
            ERROR,
            "Invalid destination node ID for Message: %s",
            metadata.dst_node_id,
        )
        return None

    # All checks passed: assign the message_id only now, so a rejected
    # Message is never mutated.
    message_id = uuid4()
    # pylint: disable-next=W0212
    metadata._message_id = str(message_id)  # type: ignore
    row = message_to_dict(message)

    # Convert values from uint64 to sint64 for SQLite
    convert_uint64_values_in_dict_to_sint64(
        row, ["run_id", "src_node_id", "dst_node_id"]
    )

    # Name the columns explicitly so the insert does not depend on the dict's
    # key order matching the table's column order.
    columns = ", ".join(row)
    placeholders = ", ".join(f":{key}" for key in row)
    query = f"INSERT INTO message_ins ({columns}) VALUES ({placeholders});"

    # run_id was validated above; a one-element tuple of dicts is passed so
    # `self.query` handles it the same way as other inserts in this class.
    self.query(query, (row,))

    return message_id
|
406
|
+
|
305
407
|
def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
|
306
408
|
"""Get undelivered TaskIns for one node.
|
307
409
|
|
@@ -380,6 +482,67 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
380
482
|
|
381
483
|
return result
|
382
484
|
|
485
|
+
def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
    """Get all Messages that have not been delivered yet.

    Selects rows in `message_ins` addressed to `node_id` that are undelivered
    and whose `created_at + ttl` is still in the future. It marks them as
    delivered with the current ISO-format timestamp and returns them.

    Parameters
    ----------
    node_id : int
        Destination node (uint64). Must not be `SUPERLINK_NODE_ID`.
    limit : Optional[int]
        Maximum number of Messages to return; `None` means no limit.

    Returns
    -------
    list[Message]
        The Messages that were just marked as delivered (possibly empty).

    Raises
    ------
    AssertionError
        If `limit` is smaller than 1 or `node_id` is `SUPERLINK_NODE_ID`.
    """
    if limit is not None and limit < 1:
        raise AssertionError("`limit` must be >= 1")

    if node_id == SUPERLINK_NODE_ID:
        msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
        raise AssertionError(msg)

    data: dict[str, Union[str, int]] = {}

    # Convert the uint64 value to sint64 for SQLite
    data["node_id"] = convert_uint64_to_sint64(node_id)

    # Retrieve IDs of undelivered, unexpired Messages for node_id.
    # The empty string is written as '' (standard SQL). "" is an identifier
    # and only works through SQLite's legacy double-quoted-string fallback.
    query = """
        SELECT message_id
        FROM message_ins
        WHERE dst_node_id = :node_id
        AND delivered_at = ''
        AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
    """

    if limit is not None:
        query += " LIMIT :limit"
        data["limit"] = limit

    query += ";"

    rows = self.query(query, data)
    if not rows:
        return []

    # Mark the selected Messages as delivered and fetch their full rows
    message_ids = [row["message_id"] for row in rows]
    placeholders = ",".join(f":id_{i}" for i in range(len(message_ids)))
    query = f"""
        UPDATE message_ins
        SET delivered_at = :delivered_at
        WHERE message_id IN ({placeholders})
        RETURNING *;
    """

    update_data: dict[str, str] = {"delivered_at": now().isoformat()}
    for index, msg_id in enumerate(message_ids):
        update_data[f"id_{index}"] = str(msg_id)

    rows = self.query(query, update_data)

    for row in rows:
        # Convert values from sint64 to uint64
        convert_sint64_values_in_dict_to_uint64(
            row, ["run_id", "src_node_id", "dst_node_id"]
        )

    return [dict_to_message(row) for row in rows]
|
545
|
+
|
383
546
|
def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
|
384
547
|
"""Store one TaskRes.
|
385
548
|
|
@@ -464,6 +627,84 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
464
627
|
|
465
628
|
return task_id
|
466
629
|
|
630
|
+
def store_message_res(self, message: Message) -> Optional[UUID]:
    """Store one reply Message.

    The reply is accepted only if all of the following hold:

    - it passes `validate_message`;
    - the Message it replies to (`reply_to_message`) exists and has not
      expired;
    - it comes from the node that Message was addressed to;
    - its TTL does not outlive that Message, within
      `MESSAGE_TTL_TOLERANCE`.

    Parameters
    ----------
    message : Message
        The reply Message to store.

    Returns
    -------
    Optional[UUID]
        The newly assigned message ID, or `None` if the reply was rejected.
        Every rejection is logged.
    """
    # Validate message
    errors = validate_message(message=message, is_reply_message=True)
    if any(errors):
        log(ERROR, errors)
        return None

    res_metadata = message.metadata
    msg_ins_id = res_metadata.reply_to_message
    msg_ins = self.get_valid_message_ins(msg_ins_id)
    if msg_ins is None:
        log(
            ERROR,
            "Failed to store Message reply: "
            "The message it replies to with message_id %s does not exist or "
            "has expired.",
            msg_ins_id,
        )
        return None

    # The reply must come from the node the original Message was sent to.
    # Log the rejection like every other failure path.
    if convert_sint64_to_uint64(msg_ins["dst_node_id"]) != res_metadata.src_node_id:
        log(
            ERROR,
            "Failed to store Message reply: "
            "source node ID %s does not match the destination node ID of the "
            "Message it replies to (message_id %s).",
            res_metadata.src_node_id,
            msg_ins_id,
        )
        return None

    # Fail if the reply's TTL exceeds the expiration time of the Message it
    # replies to. Condition:
    #   ins.created_at + ins.ttl >= res.created_at + res.ttl
    # A small tolerance accounts for floating-point precision issues.
    max_allowed_ttl = msg_ins["created_at"] + msg_ins["ttl"] - res_metadata.created_at
    if res_metadata.ttl and (
        res_metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
    ):
        log(
            WARNING,
            "Received Message with TTL %.2f exceeding the allowed maximum "
            "TTL %.2f.",
            res_metadata.ttl,
            max_allowed_ttl,
        )
        return None

    # Create message_id and attach it to the reply
    message_id = uuid4()
    # pylint: disable-next=W0212
    message.metadata._message_id = str(message_id)  # type: ignore
    row = message_to_dict(message)

    # Convert values from uint64 to sint64 for SQLite
    convert_uint64_values_in_dict_to_sint64(
        row, ["run_id", "src_node_id", "dst_node_id"]
    )

    # Name the columns explicitly so the insert does not depend on the dict's
    # key order matching the table's column order.
    columns = ", ".join(row)
    placeholders = ", ".join(f":{key}" for key in row)
    query = f"INSERT INTO message_res ({columns}) VALUES ({placeholders});"

    # Only an invalid run_id is expected to trigger IntegrityError here. This
    # may need revisiting if more integrity constraints are added.
    try:
        self.query(query, (row,))
    except sqlite3.IntegrityError:
        log(ERROR, "`run` is invalid")
        return None

    return message_id
|
707
|
+
|
467
708
|
# pylint: disable-next=R0912,R0915,R0914
|
468
709
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
469
710
|
"""Get TaskRes for the given TaskIns IDs."""
|
@@ -525,6 +766,68 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
525
766
|
|
526
767
|
return list(ret.values())
|
527
768
|
|
769
|
+
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
    """Get reply Messages for the given instruction Message IDs.

    The result combines two sources:

    - replies produced by `verify_message_ids` for inquired IDs, presumably
      error replies for unknown or expired instruction Messages (see that
      helper);
    - stored, not-yet-delivered replies, as filtered by
      `verify_found_message_replies`.

    Every returned reply gets `metadata.delivered_at` set to the current
    ISO-format timestamp, and the matching `message_res` rows are marked
    delivered.

    Parameters
    ----------
    message_ids : set[UUID]
        IDs of the instruction Messages whose replies are requested.

    Returns
    -------
    list[Message]
        The reply Messages; empty when `message_ids` is empty.
    """
    if not message_ids:
        return []

    current = time.time()

    # Build the IN-list placeholders and parameters once. The set is not
    # modified, so its iteration order is stable across both queries.
    id_placeholders = ",".join(["?"] * len(message_ids))
    id_params = tuple(str(message_id) for message_id in message_ids)

    # Load the instruction Messages that are being replied to
    query = f"""
        SELECT *
        FROM message_ins
        WHERE message_id IN ({id_placeholders});
    """
    found_message_ins_dict: dict[UUID, Message] = {}
    for row in self.query(query, id_params):
        convert_sint64_values_in_dict_to_uint64(
            row, ["run_id", "src_node_id", "dst_node_id"]
        )
        found_message_ins_dict[UUID(row["message_id"])] = dict_to_message(row)

    ret = verify_message_ids(
        inquired_message_ids=message_ids,
        found_message_ins_dict=found_message_ins_dict,
        current_time=current,
    )

    # Find all undelivered reply Messages ('' is the standard SQL empty
    # string; "" relies on SQLite's legacy double-quoted-string fallback).
    query = f"""
        SELECT *
        FROM message_res
        WHERE reply_to_message IN ({id_placeholders})
        AND delivered_at = '';
    """
    rows = self.query(query, id_params)
    for row in rows:
        convert_sint64_values_in_dict_to_uint64(
            row, ["run_id", "src_node_id", "dst_node_id"]
        )
    tmp_ret_dict = verify_found_message_replies(
        inquired_message_ids=message_ids,
        found_message_ins_dict=found_message_ins_dict,
        found_message_res_list=[dict_to_message(row) for row in rows],
        current_time=current,
    )
    ret.update(tmp_ret_dict)

    if not ret:
        # Nothing to return, so there is nothing to mark as delivered
        return []

    # Mark existing reply Messages to be returned as delivered
    delivered_at = now().isoformat()
    for message_res in ret.values():
        message_res.metadata.delivered_at = delivered_at
    message_res_ids = [
        message_res.metadata.message_id for message_res in ret.values()
    ]
    res_placeholders = ",".join(["?"] * len(message_res_ids))
    query = f"""
        UPDATE message_res
        SET delivered_at = ?
        WHERE message_id IN ({res_placeholders});
    """
    data: list[Any] = [delivered_at] + message_res_ids
    self.query(query, data)

    return list(ret.values())
|
830
|
+
|
528
831
|
def num_task_ins(self) -> int:
|
529
832
|
"""Calculate the number of task_ins in store.
|
530
833
|
|
@@ -536,6 +839,17 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
536
839
|
num = cast(int, result["num"])
|
537
840
|
return num
|
538
841
|
|
842
|
+
def num_message_ins(self) -> int:
    """Return how many instruction Messages are currently stored.

    Every row of `message_ins` is counted, including Messages that were
    already delivered but have not been deleted yet.
    """
    count_query = "SELECT count(*) AS num FROM message_ins;"
    count_rows = self.query(count_query)
    return cast(int, count_rows[0]["num"])
|
852
|
+
|
539
853
|
def num_task_res(self) -> int:
|
540
854
|
"""Calculate the number of task_res in store.
|
541
855
|
|
@@ -546,6 +860,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
546
860
|
result: dict[str, int] = rows[0]
|
547
861
|
return result["num"]
|
548
862
|
|
863
|
+
def num_message_res(self) -> int:
    """Return how many reply Messages are currently stored.

    Every row of `message_res` is counted, including replies that were
    already delivered but have not been deleted yet.
    """
    count_query = "SELECT count(*) AS num FROM message_res;"
    first_row: dict[str, int] = self.query(count_query)[0]
    return first_row["num"]
|
872
|
+
|
549
873
|
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
550
874
|
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
551
875
|
if not task_ins_ids:
|
@@ -572,6 +896,32 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
572
896
|
self.conn.execute(query_1, data)
|
573
897
|
self.conn.execute(query_2, data)
|
574
898
|
|
899
|
+
def delete_messages(self, message_ins_ids: set[UUID]) -> None:
    """Remove instruction Messages together with the replies pointing at them.

    Two kinds of rows are deleted inside one `with self.conn:` block:

    - rows of `message_ins` whose `message_id` is in `message_ins_ids`;
    - rows of `message_res` whose `reply_to_message` is in it.

    An empty set is a no-op.

    Raises
    ------
    AttributeError
        If the SQLite connection has not been initialized.
    """
    if not message_ins_ids:
        return
    if self.conn is None:
        raise AttributeError("LinkState not initialized")

    id_strings = tuple(str(ins_id) for ins_id in message_ins_ids)
    marks = ",".join("?" for _ in id_strings)

    # Instruction Messages, then their replies (matched via reply_to_message)
    delete_ins_sql = f"""
        DELETE FROM message_ins
        WHERE message_id IN ({marks});
    """
    delete_res_sql = f"""
        DELETE FROM message_res
        WHERE reply_to_message IN ({marks});
    """

    with self.conn:
        for statement in (delete_ins_sql, delete_res_sql):
            self.conn.execute(statement, id_strings)
|
924
|
+
|
575
925
|
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
576
926
|
"""Get all TaskIns IDs for the given run_id."""
|
577
927
|
if self.conn is None:
|
@@ -591,6 +941,25 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
591
941
|
|
592
942
|
return {UUID(row["task_id"]) for row in rows}
|
593
943
|
|
944
|
+
def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
    """Return the IDs of all instruction Messages that belong to `run_id`.

    `run_id` is converted to its sint64 form before the lookup, matching how
    IDs are stored in SQLite.

    Raises
    ------
    AttributeError
        If the SQLite connection has not been initialized.
    """
    if self.conn is None:
        raise AttributeError("LinkState not initialized")

    sql = """
        SELECT message_id
        FROM message_ins
        WHERE run_id = :run_id;
    """
    params = {"run_id": convert_uint64_to_sint64(run_id)}

    with self.conn:
        fetched = self.conn.execute(sql, params).fetchall()

    return set(UUID(record["message_id"]) for record in fetched)
|
962
|
+
|
594
963
|
def create_node(self, ping_interval: float) -> int:
|
595
964
|
"""Create, store in the link state, and return `node_id`."""
|
596
965
|
# Sample a random uint64 as node_id
|
@@ -1028,6 +1397,33 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
1028
1397
|
|
1029
1398
|
return task_ins
|
1030
1399
|
|
1400
|
+
def get_valid_message_ins(self, message_id: str) -> Optional[dict[str, Any]]:
    """Look up an instruction Message row and return it only if unexpired.

    Parameters
    ----------
    message_id : str
        ID of the row in `message_ins` to look up.

    Returns
    -------
    Optional[dict[str, Any]]
        The raw `message_ins` row as returned by `self.query`, if it exists
        and is not expired. A row is expired when `created_at + ttl` is at
        or before the current time; a NULL `ttl` is treated as unexpired.
        Otherwise `None`.
    """
    matches = self.query(
        """
        SELECT *
        FROM message_ins
        WHERE message_id = :message_id
        """,
        {"message_id": message_id},
    )
    if not matches:
        # No Message with this ID
        return None

    candidate = matches[0]
    ttl = candidate["ttl"]
    if ttl is None or candidate["created_at"] + ttl > time.time():
        return candidate

    # The Message has expired
    return None
|
1426
|
+
|
1031
1427
|
|
1032
1428
|
def dict_factory(
|
1033
1429
|
cursor: sqlite3.Cursor,
|
@@ -1077,6 +1473,31 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
|
|
1077
1473
|
return result
|
1078
1474
|
|
1079
1475
|
|
1476
|
+
def message_to_dict(message: Message) -> dict[str, Any]:
    """Flatten a Message into a row dict for `message_ins` / `message_res`.

    Metadata fields are copied as-is, in table column order. Exactly one of
    `content` / `error` is filled with serialized protobuf bytes: `content`
    when the Message has content, `error` otherwise. The other stays `None`.
    """
    meta = message.metadata
    metadata_fields = (
        "message_id",
        "group_id",
        "run_id",
        "src_node_id",
        "dst_node_id",
        "reply_to_message",
        "created_at",
        "delivered_at",
        "ttl",
        "message_type",
    )
    row: dict[str, Any] = {field: getattr(meta, field) for field in metadata_fields}

    has_content = message.has_content()
    row["content"] = (
        recordset_to_proto(message.content).SerializeToString() if has_content else None
    )
    row["error"] = (
        None if has_content else error_to_proto(message.error).SerializeToString()
    )
    return row
|
1499
|
+
|
1500
|
+
|
1080
1501
|
def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
1081
1502
|
"""Turn task_dict into protobuf message."""
|
1082
1503
|
recordset = ProtoRecordSet()
|
@@ -1131,6 +1552,28 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
1131
1552
|
return result
|
1132
1553
|
|
1133
1554
|
|
1555
|
+
def dict_to_message(message_dict: dict[str, Any]) -> Message:
    """Transform a `message_ins` / `message_res` row dict into a Message.

    `content` and `error` hold serialized RecordSet / Error protobuf bytes
    (or `None`) and are deserialized. The remaining metadata columns, except
    the two timestamps, are passed to the `Metadata` constructor.
    `created_at` and `delivered_at` are restored after construction.

    The input dict is not modified, so the same row can be converted more
    than once.

    Parameters
    ----------
    message_dict : dict[str, Any]
        A row as produced by `message_to_dict` (IDs already converted back
        to uint64 by the caller).

    Returns
    -------
    Message
        The reconstructed Message.
    """
    content, error = None, None
    if (b_content := message_dict["content"]) is not None:
        content = recordset_from_proto(ProtoRecordSet.FromString(b_content))
    if (b_error := message_dict["error"]) is not None:
        error = error_from_proto(ProtoError.FromString(b_error))

    # The Metadata constructor doesn't accept created_at, so both timestamps
    # are excluded here and set on the Message below.
    excluded_keys = {"content", "error", "created_at", "delivered_at"}
    metadata = Metadata(
        **{k: v for k, v in message_dict.items() if k not in excluded_keys}
    )
    msg = Message(metadata=metadata, content=content, error=error)
    msg.metadata.__dict__["_created_at"] = message_dict["created_at"]
    msg.metadata.delivered_at = message_dict["delivered_at"]
    return msg
|
1575
|
+
|
1576
|
+
|
1134
1577
|
def determine_run_status(row: dict[str, Any]) -> str:
|
1135
1578
|
"""Determine the status of the run based on timestamp fields."""
|
1136
1579
|
if row["pending_at"]:
|