flwr-nightly 1.16.0.dev20250304__py3-none-any.whl → 1.16.0.dev20250306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,7 +26,7 @@ from logging import DEBUG, ERROR, WARNING
26
26
  from typing import Any, Optional, Union, cast
27
27
  from uuid import UUID, uuid4
28
28
 
29
- from flwr.common import Context, log, now
29
+ from flwr.common import Context, Message, Metadata, log, now
30
30
  from flwr.common.constant import (
31
31
  MESSAGE_TTL_TOLERANCE,
32
32
  NODE_ID_NUM_BYTES,
@@ -35,15 +35,22 @@ from flwr.common.constant import (
35
35
  Status,
36
36
  )
37
37
  from flwr.common.record import ConfigsRecord
38
+ from flwr.common.serde import (
39
+ error_from_proto,
40
+ error_to_proto,
41
+ recordset_from_proto,
42
+ recordset_to_proto,
43
+ )
38
44
  from flwr.common.typing import Run, RunStatus, UserConfig
39
45
 
40
46
  # pylint: disable=E0611
47
+ from flwr.proto.error_pb2 import Error as ProtoError
41
48
  from flwr.proto.node_pb2 import Node
42
49
  from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
43
50
  from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
44
51
 
45
52
  # pylint: enable=E0611
46
- from flwr.server.utils.validator import validate_task_ins_or_res
53
+ from flwr.server.utils.validator import validate_message, validate_task_ins_or_res
47
54
 
48
55
  from .linkstate import LinkState
49
56
  from .utils import (
@@ -58,7 +65,9 @@ from .utils import (
58
65
  generate_rand_int_from_bytes,
59
66
  has_valid_sub_status,
60
67
  is_valid_transition,
68
+ verify_found_message_replies,
61
69
  verify_found_taskres,
70
+ verify_message_ids,
62
71
  verify_taskins_ids,
63
72
  )
64
73
 
@@ -134,6 +143,24 @@ CREATE TABLE IF NOT EXISTS task_ins(
134
143
  );
135
144
  """
136
145
 
146
# Table holding instruction Messages written by `store_message_ins`.
# NOTE: column order must match the key order of `message_to_dict`, because rows
# are inserted positionally via `INSERT INTO message_ins VALUES(:message_id, ...)`.
# - run_id / src_node_id / dst_node_id hold uint64 IDs converted to sint64.
# - created_at + ttl is compared against the current time to detect expiry.
# - delivered_at holds an ISO timestamp once delivered; queries treat an empty
#   string as "not yet delivered".
# - content / error hold serialized protobuf bytes (only one of them is set).
SQL_CREATE_TABLE_MESSAGE_INS = """
CREATE TABLE IF NOT EXISTS message_ins(
message_id TEXT UNIQUE,
group_id TEXT,
run_id INTEGER,
src_node_id INTEGER,
dst_node_id INTEGER,
reply_to_message TEXT,
created_at REAL,
delivered_at TEXT,
ttl REAL,
message_type TEXT,
content BLOB NULL,
error BLOB NULL,
FOREIGN KEY(run_id) REFERENCES run(run_id)
);
"""
163
+
137
164
  SQL_CREATE_TABLE_TASK_RES = """
138
165
  CREATE TABLE IF NOT EXISTS task_res(
139
166
  task_id TEXT UNIQUE,
@@ -151,6 +178,25 @@ CREATE TABLE IF NOT EXISTS task_res(
151
178
  );
152
179
  """
153
180
 
181
+
182
# Table holding reply Messages written by `store_message_res`.
# NOTE: column order must match the key order of `message_to_dict`, because rows
# are inserted positionally via `INSERT INTO message_res VALUES(:message_id, ...)`.
# - reply_to_message holds the message_id of the instruction Message being
#   answered (not enforced by a foreign key); `delete_messages` removes replies
#   by this column.
# - run_id / src_node_id / dst_node_id hold uint64 IDs converted to sint64.
# - delivered_at holds an ISO timestamp once delivered; queries treat an empty
#   string as "not yet delivered".
# - content / error hold serialized protobuf bytes (only one of them is set).
SQL_CREATE_TABLE_MESSAGE_RES = """
CREATE TABLE IF NOT EXISTS message_res(
message_id TEXT UNIQUE,
group_id TEXT,
run_id INTEGER,
src_node_id INTEGER,
dst_node_id INTEGER,
reply_to_message TEXT,
created_at REAL,
delivered_at TEXT,
ttl REAL,
message_type TEXT,
content BLOB NULL,
error BLOB NULL,
FOREIGN KEY(run_id) REFERENCES run(run_id)
);
"""
199
+
154
200
  DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
155
201
 
156
202
 
@@ -198,6 +244,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
198
244
  cur.execute(SQL_CREATE_TABLE_CONTEXT)
199
245
  cur.execute(SQL_CREATE_TABLE_TASK_INS)
200
246
  cur.execute(SQL_CREATE_TABLE_TASK_RES)
247
+ cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
248
+ cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
201
249
  cur.execute(SQL_CREATE_TABLE_NODE)
202
250
  cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
203
251
  cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
@@ -302,6 +350,60 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
302
350
 
303
351
  return task_id
304
352
 
353
def store_message_ins(self, message: Message) -> Optional[UUID]:
    """Store one instruction Message.

    The Message is validated, its run and destination node must exist, and its
    source node must be the SuperLink. Only after all checks pass is a fresh
    ``message_id`` assigned to ``message.metadata`` and the row inserted into
    the ``message_ins`` table.

    Parameters
    ----------
    message : Message
        The instruction Message to store.

    Returns
    -------
    Optional[UUID]
        The newly assigned message ID, or ``None`` if a check failed (the reason
        is logged at ERROR level).
    """
    # Validate message
    errors = validate_message(message=message, is_reply_message=False)
    if any(errors):
        log(ERROR, errors)
        return None

    # Validate run_id (uint64 IDs are stored as sint64 in SQLite)
    query = "SELECT run_id FROM run WHERE run_id = ?;"
    if not self.query(query, (convert_uint64_to_sint64(message.metadata.run_id),)):
        log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
        return None

    # Validate source node ID
    if message.metadata.src_node_id != SUPERLINK_NODE_ID:
        log(
            ERROR,
            "Invalid source node ID for Message: %s",
            message.metadata.src_node_id,
        )
        return None

    # Validate destination node ID
    query = "SELECT node_id FROM node WHERE node_id = ?;"
    if not self.query(
        query, (convert_uint64_to_sint64(message.metadata.dst_node_id),)
    ):
        log(
            ERROR,
            "Invalid destination node ID for Message: %s",
            message.metadata.dst_node_id,
        )
        return None

    # All checks passed. Assign the message_id only now, so that a rejected
    # Message is handed back to the caller unmodified.
    message_id = uuid4()
    # pylint: disable-next=W0212
    message.metadata._message_id = str(message_id)  # type: ignore
    data = (message_to_dict(message),)

    # Convert values from uint64 to sint64 for SQLite
    convert_uint64_values_in_dict_to_sint64(
        data[0], ["run_id", "src_node_id", "dst_node_id"]
    )

    columns = ", ".join([f":{key}" for key in data[0]])
    query = f"INSERT INTO message_ins VALUES({columns});"

    # The run_id was checked above, but the insert may still raise an
    # IntegrityError (e.g. if the run is removed concurrently). Handle it the
    # same way as store_message_res instead of letting it propagate.
    try:
        self.query(query, data)
    except sqlite3.IntegrityError:
        log(ERROR, "`run` is invalid")
        return None

    return message_id
406
+
305
407
  def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
306
408
  """Get undelivered TaskIns for one node.
307
409
 
@@ -380,6 +482,67 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
380
482
 
381
483
  return result
382
484
 
485
def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
    """Get undelivered, unexpired instruction Messages for one node.

    Matching rows are marked as delivered (``delivered_at`` set to the current
    ISO timestamp) in the same UPDATE that returns them.

    Parameters
    ----------
    node_id : int
        Destination node ID (uint64). Must not be the SuperLink node ID.
    limit : Optional[int]
        Maximum number of Messages to return; ``None`` means no limit.

    Returns
    -------
    list[Message]
        The delivered Messages (empty if none are pending).

    Raises
    ------
    AssertionError
        If ``limit`` is smaller than 1 or ``node_id`` is the SuperLink node ID.
    """
    if limit is not None and limit < 1:
        raise AssertionError("`limit` must be >= 1")

    if node_id == SUPERLINK_NODE_ID:
        msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
        raise AssertionError(msg)

    data: dict[str, Union[str, int]] = {}

    # Convert the uint64 value to sint64 for SQLite
    data["node_id"] = convert_uint64_to_sint64(node_id)

    # Retrieve all undelivered, unexpired Messages for node_id.
    # '' is a proper string literal; "" only works through SQLite's legacy
    # double-quoted-string fallback, which may be disabled (SQLITE_DQS=0).
    query = """
        SELECT message_id
        FROM message_ins
        WHERE dst_node_id = :node_id
        AND delivered_at = ''
        AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
    """

    if limit is not None:
        query += " LIMIT :limit"
        data["limit"] = limit

    query += ";"

    rows = self.query(query, data)

    if rows:
        # Prepare query
        message_ids = [row["message_id"] for row in rows]
        placeholders: str = ",".join([f":id_{i}" for i in range(len(message_ids))])
        query = f"""
            UPDATE message_ins
            SET delivered_at = :delivered_at
            WHERE message_id IN ({placeholders})
            RETURNING *;
        """

        # Prepare data for query
        delivered_at = now().isoformat()
        data = {"delivered_at": delivered_at}
        for index, msg_id in enumerate(message_ids):
            data[f"id_{index}"] = str(msg_id)

        # Run query; RETURNING yields the rows with the updated delivered_at
        rows = self.query(query, data)

    for row in rows:
        # Convert values from sint64 to uint64
        convert_sint64_values_in_dict_to_uint64(
            row, ["run_id", "src_node_id", "dst_node_id"]
        )

    result = [dict_to_message(row) for row in rows]

    return result
545
+
383
546
  def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
384
547
  """Store one TaskRes.
385
548
 
@@ -464,6 +627,84 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
464
627
 
465
628
  return task_id
466
629
 
630
def store_message_res(self, message: Message) -> Optional[UUID]:
    """Store one reply Message.

    The reply is rejected (``None`` is returned and the reason is logged) when
    it fails validation, when the instruction Message it replies to does not
    exist or has expired, when it was not sent by that instruction's
    destination node, when its TTL outlives the instruction Message, or when
    the insert raises an IntegrityError.

    Parameters
    ----------
    message : Message
        The reply Message; ``metadata.reply_to_message`` must reference a
        stored instruction Message.

    Returns
    -------
    Optional[UUID]
        The newly assigned message ID of the stored reply, or ``None``.
    """
    # Validate message
    errors = validate_message(message=message, is_reply_message=True)
    if any(errors):
        log(ERROR, errors)
        return None

    res_metadata = message.metadata
    msg_ins_id = res_metadata.reply_to_message
    msg_ins = self.get_valid_message_ins(msg_ins_id)
    if msg_ins is None:
        log(
            ERROR,
            "Failed to store Message reply: "
            "The message it replies to with message_id %s does not exist or "
            "has expired.",
            msg_ins_id,
        )
        return None

    # Ensure that the dst_node_id of the original message matches the src_node_id
    # of the reply being processed. Stored IDs are sint64 and must be converted.
    ins_dst_node_id = convert_sint64_to_uint64(msg_ins["dst_node_id"])
    if ins_dst_node_id != res_metadata.src_node_id:
        log(
            ERROR,
            "Failed to store Message reply: the reply's source node ID %s does "
            "not match the destination node ID %s of the message it replies to.",
            res_metadata.src_node_id,
            ins_dst_node_id,
        )
        return None

    # Fail if the Message TTL exceeds the
    # expiration time of the Message it replies to.
    # Condition: ins_metadata.created_at + ins_metadata.ttl ≥
    # res_metadata.created_at + res_metadata.ttl
    # A small tolerance is introduced to account
    # for floating-point precision issues.
    max_allowed_ttl = (
        msg_ins["created_at"] + msg_ins["ttl"] - res_metadata.created_at
    )
    if res_metadata.ttl and (
        res_metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
    ):
        log(
            WARNING,
            "Received Message with TTL %.2f exceeding the allowed maximum "
            "TTL %.2f.",
            res_metadata.ttl,
            max_allowed_ttl,
        )
        return None

    # Create message_id
    message_id = uuid4()

    # Store Message
    # pylint: disable-next=W0212
    message.metadata._message_id = str(message_id)  # type: ignore
    data = (message_to_dict(message),)

    # Convert values from uint64 to sint64 for SQLite
    convert_uint64_values_in_dict_to_sint64(
        data[0], ["run_id", "src_node_id", "dst_node_id"]
    )

    columns = ", ".join([f":{key}" for key in data[0]])
    query = f"INSERT INTO message_res VALUES({columns});"

    # Only invalid run_id can trigger IntegrityError.
    # This may need to be changed in the future version with more integrity checks.
    try:
        self.query(query, data)
    except sqlite3.IntegrityError:
        log(ERROR, "`run` is invalid")
        return None

    return message_id
707
+
467
708
  # pylint: disable-next=R0912,R0915,R0914
468
709
  def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
469
710
  """Get TaskRes for the given TaskIns IDs."""
@@ -525,6 +766,68 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
525
766
 
526
767
  return list(ret.values())
527
768
 
769
def get_message_res(self, message_ids: set[UUID]) -> list[Message]:
    """Get reply Messages for the given instruction Message IDs.

    Looks up the inquired instruction Messages and their not-yet-delivered
    replies, then merges in whatever ``verify_message_ids`` and
    ``verify_found_message_replies`` return (presumably error replies for
    unknown or expired IDs; those helpers are defined elsewhere). Every
    returned Message gets ``metadata.delivered_at`` set, and matching rows in
    ``message_res`` are marked as delivered.

    Parameters
    ----------
    message_ids : set[UUID]
        IDs of the instruction Messages whose replies are requested.

    Returns
    -------
    list[Message]
        The reply Messages (including any produced by the verify helpers).
    """
    # IDs are stored as TEXT; build parameters and placeholders once and
    # reuse them for both lookups below.
    str_message_ids = tuple(str(message_id) for message_id in message_ids)
    placeholders = ",".join(["?"] * len(str_message_ids))

    # Verify Message IDs
    current = time.time()
    query = f"""
        SELECT *
        FROM message_ins
        WHERE message_id IN ({placeholders});
    """
    rows = self.query(query, str_message_ids)
    found_message_ins_dict: dict[UUID, Message] = {}
    for row in rows:
        convert_sint64_values_in_dict_to_uint64(
            row, ["run_id", "src_node_id", "dst_node_id"]
        )
        found_message_ins_dict[UUID(row["message_id"])] = dict_to_message(row)

    ret: dict[UUID, Message] = verify_message_ids(
        inquired_message_ids=message_ids,
        found_message_ins_dict=found_message_ins_dict,
        current_time=current,
    )

    # Find all reply Messages that have not been delivered yet.
    # '' is a proper string literal; "" only works through SQLite's legacy
    # double-quoted-string fallback, which may be disabled (SQLITE_DQS=0).
    query = f"""
        SELECT *
        FROM message_res
        WHERE reply_to_message IN ({placeholders})
        AND delivered_at = '';
    """
    rows = self.query(query, str_message_ids)
    for row in rows:
        convert_sint64_values_in_dict_to_uint64(
            row, ["run_id", "src_node_id", "dst_node_id"]
        )
    tmp_ret_dict = verify_found_message_replies(
        inquired_message_ids=message_ids,
        found_message_ins_dict=found_message_ins_dict,
        found_message_res_list=[dict_to_message(row) for row in rows],
        current_time=current,
    )
    ret.update(tmp_ret_dict)

    # Mark existing reply Messages to be returned as delivered
    delivered_at = now().isoformat()
    for message_res in ret.values():
        message_res.metadata.delivered_at = delivered_at
    message_res_ids = [
        message_res.metadata.message_id for message_res in ret.values()
    ]
    res_placeholders = ",".join(["?"] * len(message_res_ids))
    query = f"""
        UPDATE message_res
        SET delivered_at = ?
        WHERE message_id IN ({res_placeholders});
    """
    data: list[Any] = [delivered_at] + message_res_ids
    self.query(query, data)

    return list(ret.values())
830
+
528
831
  def num_task_ins(self) -> int:
529
832
  """Calculate the number of task_ins in store.
530
833
 
@@ -536,6 +839,17 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
536
839
  num = cast(int, result["num"])
537
840
  return num
538
841
 
842
def num_message_ins(self) -> int:
    """Return how many instruction Messages are currently stored.

    Messages that were already delivered but not yet deleted are counted too.
    """
    count_rows = self.query("SELECT count(*) AS num FROM message_ins;")
    return cast(int, count_rows[0]["num"])
852
+
539
853
  def num_task_res(self) -> int:
540
854
  """Calculate the number of task_res in store.
541
855
 
@@ -546,6 +860,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
546
860
  result: dict[str, int] = rows[0]
547
861
  return result["num"]
548
862
 
863
def num_message_res(self) -> int:
    """Return how many reply Messages are currently stored.

    Replies that were already delivered but not yet deleted are counted too.
    """
    count_row: dict[str, int] = self.query(
        "SELECT count(*) AS num FROM message_res;"
    )[0]
    return count_row["num"]
872
+
549
873
  def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
550
874
  """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
551
875
  if not task_ins_ids:
@@ -572,6 +896,32 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
572
896
  self.conn.execute(query_1, data)
573
897
  self.conn.execute(query_2, data)
574
898
 
899
def delete_messages(self, message_ins_ids: set[UUID]) -> None:
    """Remove instruction Messages together with every reply pointing at them.

    Both DELETE statements are executed inside one ``with self.conn`` block.
    Nothing happens when ``message_ins_ids`` is empty.

    Parameters
    ----------
    message_ins_ids : set[UUID]
        IDs of the instruction Messages to delete.

    Raises
    ------
    AttributeError
        If the LinkState has no open connection.
    """
    if not message_ins_ids:
        return
    if self.conn is None:
        raise AttributeError("LinkState not initialized")

    id_params = tuple(map(str, message_ins_ids))
    marks = ",".join("?" for _ in id_params)

    statements = (
        # Delete the instruction Messages themselves
        f"""
        DELETE FROM message_ins
        WHERE message_id IN ({marks});
        """,
        # Delete the replies that reference them
        f"""
        DELETE FROM message_res
        WHERE reply_to_message IN ({marks});
        """,
    )

    with self.conn:
        for statement in statements:
            self.conn.execute(statement, id_params)
924
+
575
925
  def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
576
926
  """Get all TaskIns IDs for the given run_id."""
577
927
  if self.conn is None:
@@ -591,6 +941,25 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
591
941
 
592
942
  return {UUID(row["task_id"]) for row in rows}
593
943
 
944
def get_message_ids_from_run_id(self, run_id: int) -> set[UUID]:
    """Collect the IDs of all instruction Messages that belong to ``run_id``.

    Parameters
    ----------
    run_id : int
        The run ID (uint64); converted to sint64 before querying SQLite.

    Returns
    -------
    set[UUID]
        IDs of the matching rows in ``message_ins``.

    Raises
    ------
    AttributeError
        If the LinkState has no open connection.
    """
    if self.conn is None:
        raise AttributeError("LinkState not initialized")

    params = {"run_id": convert_uint64_to_sint64(run_id)}
    with self.conn:
        fetched = self.conn.execute(
            """
            SELECT message_id
            FROM message_ins
            WHERE run_id = :run_id;
            """,
            params,
        ).fetchall()

    return set(UUID(record["message_id"]) for record in fetched)
962
+
594
963
  def create_node(self, ping_interval: float) -> int:
595
964
  """Create, store in the link state, and return `node_id`."""
596
965
  # Sample a random uint64 as node_id
@@ -1028,6 +1397,33 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1028
1397
 
1029
1398
  return task_ins
1030
1399
 
1400
def get_valid_message_ins(self, message_id: str) -> Optional[dict[str, Any]]:
    """Return the stored instruction Message row if it exists and has not expired.

    A row counts as expired when its ``ttl`` is set and
    ``created_at + ttl <= time.time()``.

    Parameters
    ----------
    message_id : str
        ID of the instruction Message to look up in ``message_ins``.

    Returns
    -------
    Optional[dict[str, Any]]
        The raw ``message_ins`` row as returned by ``self.query``, not a
        ``Message`` object. Its integer IDs are still in SQLite's sint64 form,
        and ``content``/``error`` are serialized bytes. Returns ``None`` if no
        such row exists or the row has expired.
    """
    query = """
        SELECT *
        FROM message_ins
        WHERE message_id = :message_id
    """
    data = {"message_id": message_id}
    rows = self.query(query, data)
    if not rows:
        # Message does not exist
        return None

    message_ins = rows[0]
    created_at = message_ins["created_at"]
    ttl = message_ins["ttl"]
    current_time = time.time()

    # Check if the instruction Message is expired
    if ttl is not None and created_at + ttl <= current_time:
        return None

    return message_ins
1426
+
1031
1427
 
1032
1428
  def dict_factory(
1033
1429
  cursor: sqlite3.Cursor,
@@ -1077,6 +1473,31 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
1077
1473
  return result
1078
1474
 
1079
1475
 
1476
def message_to_dict(message: Message) -> dict[str, Any]:
    """Flatten a Message into a dict whose keys are the message table columns.

    If the Message has content, ``content`` holds the serialized RecordSet
    protobuf and ``error`` is ``None``. Otherwise ``error`` holds the serialized
    Error protobuf and ``content`` is ``None``.

    NOTE: the key order matters. Rows are inserted with
    ``INSERT INTO ... VALUES(:key, ...)`` built from these keys in order, so it
    must follow the table's column order.
    """
    with_content = message.has_content()
    meta = message.metadata
    return {
        "message_id": meta.message_id,
        "group_id": meta.group_id,
        "run_id": meta.run_id,
        "src_node_id": meta.src_node_id,
        "dst_node_id": meta.dst_node_id,
        "reply_to_message": meta.reply_to_message,
        "created_at": meta.created_at,
        "delivered_at": meta.delivered_at,
        "ttl": meta.ttl,
        "message_type": meta.message_type,
        "content": (
            recordset_to_proto(message.content).SerializeToString()
            if with_content
            else None
        ),
        "error": (
            None
            if with_content
            else error_to_proto(message.error).SerializeToString()
        ),
    }
1499
+
1500
+
1080
1501
  def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
1081
1502
  """Turn task_dict into protobuf message."""
1082
1503
  recordset = ProtoRecordSet()
@@ -1131,6 +1552,28 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
1131
1552
  return result
1132
1553
 
1133
1554
 
1555
def dict_to_message(message_dict: dict[str, Any]) -> Message:
    """Rebuild a Message from a ``message_ins``/``message_res`` row dict.

    Unlike a ``pop``-based implementation, the input dict is left unmodified,
    so the same row can be converted more than once.

    Parameters
    ----------
    message_dict : dict[str, Any]
        Row containing the Metadata fields plus ``created_at``,
        ``delivered_at`` and the serialized ``content``/``error`` columns
        (``None`` when absent). Callers in this module convert the sint64 IDs
        back to uint64 before calling.

    Returns
    -------
    Message
        The reconstructed Message, including ``created_at`` and
        ``delivered_at``.

    Raises
    ------
    KeyError
        If ``content``, ``error``, ``created_at`` or ``delivered_at`` is missing.
    """
    b_content = message_dict["content"]
    b_error = message_dict["error"]
    content = (
        None
        if b_content is None
        else recordset_from_proto(ProtoRecordSet.FromString(b_content))
    )
    error = (
        None if b_error is None else error_from_proto(ProtoError.FromString(b_error))
    )

    # Metadata's constructor accepts neither created_at nor delivered_at, and the
    # serialized payload columns are not Metadata fields. Exclude them here and
    # set the timestamps after construction.
    excluded = {"content", "error", "created_at", "delivered_at"}
    metadata = Metadata(
        **{k: v for k, v in message_dict.items() if k not in excluded}
    )
    msg = Message(metadata=metadata, content=content, error=error)
    # created_at cannot be passed to the constructor, so it is written into the
    # instance __dict__ directly after the Message has been built.
    msg.metadata.__dict__["_created_at"] = message_dict["created_at"]
    msg.metadata.delivered_at = message_dict["delivered_at"]
    return msg
1575
+
1576
+
1134
1577
  def determine_run_status(row: dict[str, Any]) -> str:
1135
1578
  """Determine the status of the run based on timestamp fields."""
1136
1579
  if row["pending_at"]: