flwr-nightly 1.23.0.dev20251020__py3-none-any.whl → 1.23.0.dev20251022__py3-none-any.whl

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.

Potentially problematic release.


This version of flwr-nightly might be problematic. See the registry's advisory for this release for more details.

@@ -18,11 +18,10 @@
18
18
  # pylint: disable=too-many-lines
19
19
 
20
20
  import json
21
- import re
22
21
  import secrets
23
22
  import sqlite3
24
23
  from collections.abc import Sequence
25
- from logging import DEBUG, ERROR, WARNING
24
+ from logging import ERROR, WARNING
26
25
  from typing import Any, Optional, Union, cast
27
26
 
28
27
  from flwr.common import Context, Message, Metadata, log, now
@@ -52,6 +51,8 @@ from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
52
51
  # pylint: enable=E0611
53
52
  from flwr.server.utils.validator import validate_message
54
53
  from flwr.supercore.constant import NodeStatus
54
+ from flwr.supercore.sqlite_mixin import SqliteMixin
55
+ from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
55
56
 
56
57
  from .linkstate import LinkState
57
58
  from .utils import (
@@ -60,9 +61,7 @@ from .utils import (
60
61
  configrecord_to_bytes,
61
62
  context_from_bytes,
62
63
  context_to_bytes,
63
- convert_sint64_to_uint64,
64
64
  convert_sint64_values_in_dict_to_uint64,
65
- convert_uint64_to_sint64,
66
65
  convert_uint64_values_in_dict_to_sint64,
67
66
  generate_rand_int_from_bytes,
68
67
  has_valid_sub_status,
@@ -183,95 +182,25 @@ CREATE TABLE IF NOT EXISTS token_store (
183
182
  );
184
183
  """
185
184
 
186
- DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
187
185
 
188
-
189
- class SqliteLinkState(LinkState): # pylint: disable=R0904
186
+ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
190
187
  """SQLite-based LinkState implementation."""
191
188
 
192
- def __init__(
193
- self,
194
- database_path: str,
195
- ) -> None:
196
- """Initialize an SqliteLinkState.
197
-
198
- Parameters
199
- ----------
200
- database : (path-like object)
201
- The path to the database file to be opened. Pass ":memory:" to open
202
- a connection to a database that is in RAM, instead of on disk.
203
- """
204
- self.database_path = database_path
205
- self.conn: Optional[sqlite3.Connection] = None
206
-
207
189
  def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
208
- """Create tables if they don't exist yet.
209
-
210
- Parameters
211
- ----------
212
- log_queries : bool
213
- Log each query which is executed.
214
-
215
- Returns
216
- -------
217
- list[tuple[str]]
218
- The list of all tables in the DB.
219
- """
220
- self.conn = sqlite3.connect(self.database_path)
221
- self.conn.execute("PRAGMA foreign_keys = ON;")
222
- self.conn.row_factory = dict_factory
223
- if log_queries:
224
- self.conn.set_trace_callback(lambda query: log(DEBUG, query))
225
- cur = self.conn.cursor()
226
-
227
- # Create each table if not exists queries
228
- cur.execute(SQL_CREATE_TABLE_RUN)
229
- cur.execute(SQL_CREATE_TABLE_LOGS)
230
- cur.execute(SQL_CREATE_TABLE_CONTEXT)
231
- cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
232
- cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
233
- cur.execute(SQL_CREATE_TABLE_NODE)
234
- cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
235
- cur.execute(SQL_CREATE_TABLE_TOKEN_STORE)
236
- cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
237
- cur.execute(SQL_CREATE_INDEX_OWNER_AID)
238
- res = cur.execute("SELECT name FROM sqlite_schema;")
239
- return res.fetchall()
240
-
241
- def query(
242
- self,
243
- query: str,
244
- data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
245
- ) -> list[dict[str, Any]]:
246
- """Execute a SQL query."""
247
- if self.conn is None:
248
- raise AttributeError("LinkState is not initialized.")
249
-
250
- if data is None:
251
- data = []
252
-
253
- # Clean up whitespace to make the logs nicer
254
- query = re.sub(r"\s+", " ", query)
255
-
256
- try:
257
- with self.conn:
258
- if (
259
- len(data) > 0
260
- and isinstance(data, (tuple, list))
261
- and isinstance(data[0], (tuple, dict))
262
- ):
263
- rows = self.conn.executemany(query, data)
264
- else:
265
- rows = self.conn.execute(query, data)
266
-
267
- # Extract results before committing to support
268
- # INSERT/UPDATE ... RETURNING
269
- # style queries
270
- result = rows.fetchall()
271
- except KeyError as exc:
272
- log(ERROR, {"query": query, "data": data, "exception": exc})
273
-
274
- return result
190
+ """Connect to the DB, enable FK support, and create tables if needed."""
191
+ return self._ensure_initialized(
192
+ SQL_CREATE_TABLE_RUN,
193
+ SQL_CREATE_TABLE_LOGS,
194
+ SQL_CREATE_TABLE_CONTEXT,
195
+ SQL_CREATE_TABLE_MESSAGE_INS,
196
+ SQL_CREATE_TABLE_MESSAGE_RES,
197
+ SQL_CREATE_TABLE_NODE,
198
+ SQL_CREATE_TABLE_PUBLIC_KEY,
199
+ SQL_CREATE_TABLE_TOKEN_STORE,
200
+ SQL_CREATE_INDEX_ONLINE_UNTIL,
201
+ SQL_CREATE_INDEX_OWNER_AID,
202
+ log_queries=log_queries,
203
+ )
275
204
 
276
205
  def store_message_ins(self, message: Message) -> Optional[str]:
277
206
  """Store one Message."""
@@ -335,7 +264,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
335
264
  data: dict[str, Union[str, int]] = {}
336
265
 
337
266
  # Convert the uint64 value to sint64 for SQLite
338
- data["node_id"] = convert_uint64_to_sint64(node_id)
267
+ data["node_id"] = uint64_to_int64(node_id)
339
268
 
340
269
  # Retrieve all Messages for node_id
341
270
  query = """
@@ -410,8 +339,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
410
339
  if (
411
340
  msg_ins
412
341
  and message
413
- and convert_sint64_to_uint64(msg_ins["dst_node_id"])
414
- != res_metadata.src_node_id
342
+ and int64_to_uint64(msg_ins["dst_node_id"]) != res_metadata.src_node_id
415
343
  ):
416
344
  return None
417
345
 
@@ -487,20 +415,20 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
487
415
  dst_node_ids: set[int] = set()
488
416
  for message_id in message_ids:
489
417
  in_message = found_message_ins_dict[message_id]
490
- sint_node_id = convert_uint64_to_sint64(in_message.metadata.dst_node_id)
418
+ sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
491
419
  dst_node_ids.add(sint_node_id)
492
420
  query = f"""
493
- SELECT node_id, online_until
494
- FROM node
495
- WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))});
496
- """
497
- rows = self.query(query, tuple(dst_node_ids))
421
+ SELECT node_id, online_until
422
+ FROM node
423
+ WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))})
424
+ AND status != ?
425
+ """
426
+ rows = self.query(query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,))
498
427
  tmp_ret_dict = check_node_availability_for_in_message(
499
428
  inquired_in_message_ids=message_ids,
500
429
  found_in_message_dict=found_message_ins_dict,
501
430
  node_id_to_online_until={
502
- convert_sint64_to_uint64(row["node_id"]): row["online_until"]
503
- for row in rows
431
+ int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
504
432
  },
505
433
  current_time=current,
506
434
  )
@@ -601,7 +529,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
601
529
  WHERE run_id = :run_id;
602
530
  """
603
531
 
604
- sint64_run_id = convert_uint64_to_sint64(run_id)
532
+ sint64_run_id = uint64_to_int64(run_id)
605
533
  data = {"run_id": sint64_run_id}
606
534
 
607
535
  with self.conn:
@@ -619,7 +547,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
619
547
  )
620
548
 
621
549
  # Convert the uint64 value to sint64 for SQLite
622
- sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
550
+ sint64_node_id = uint64_to_int64(uint64_node_id)
623
551
 
624
552
  query = """
625
553
  INSERT INTO node
@@ -658,17 +586,21 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
658
586
 
659
587
  def delete_node(self, owner_aid: str, node_id: int) -> None:
660
588
  """Delete a node."""
661
- sint64_node_id = convert_uint64_to_sint64(node_id)
589
+ sint64_node_id = uint64_to_int64(node_id)
662
590
 
663
591
  query = """
664
592
  UPDATE node
665
- SET status = ?, unregistered_at = ?
593
+ SET status = ?, unregistered_at = ?,
594
+ online_until = IIF(online_until > ?, ?, online_until)
666
595
  WHERE node_id = ? AND status != ? AND owner_aid = ?
667
596
  RETURNING node_id
668
597
  """
598
+ current = now()
669
599
  params = (
670
600
  NodeStatus.UNREGISTERED,
671
- now().isoformat(),
601
+ current.isoformat(),
602
+ current.timestamp(),
603
+ current.timestamp(),
672
604
  sint64_node_id,
673
605
  NodeStatus.UNREGISTERED,
674
606
  owner_aid,
@@ -693,7 +625,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
693
625
  raise AttributeError("LinkState not initialized")
694
626
 
695
627
  # Convert the uint64 value to sint64 for SQLite
696
- sint64_run_id = convert_uint64_to_sint64(run_id)
628
+ sint64_run_id = uint64_to_int64(run_id)
697
629
 
698
630
  # Validate run ID
699
631
  query = "SELECT COUNT(*) FROM run WHERE run_id = ?"
@@ -738,9 +670,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
738
670
  conditions = []
739
671
  params = []
740
672
  if node_ids is not None:
741
- sint64_node_ids = [
742
- convert_uint64_to_sint64(node_id) for node_id in node_ids
743
- ]
673
+ sint64_node_ids = [uint64_to_int64(node_id) for node_id in node_ids]
744
674
  placeholders = ",".join(["?"] * len(sint64_node_ids))
745
675
  conditions.append(f"node_id IN ({placeholders})")
746
676
  params.extend(sint64_node_ids)
@@ -763,7 +693,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
763
693
  result: list[NodeInfo] = []
764
694
  for row in rows:
765
695
  # Convert sint64 node_id to uint64
766
- row["node_id"] = convert_sint64_to_uint64(row["node_id"])
696
+ row["node_id"] = int64_to_uint64(row["node_id"])
767
697
  result.append(NodeInfo(**row))
768
698
 
769
699
  return result
@@ -771,7 +701,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
771
701
  def get_node_public_key(self, node_id: int) -> bytes:
772
702
  """Get `public_key` for the specified `node_id`."""
773
703
  # Convert the uint64 value to sint64 for SQLite
774
- sint64_node_id = convert_uint64_to_sint64(node_id)
704
+ sint64_node_id = uint64_to_int64(node_id)
775
705
 
776
706
  # Query the public key for the given node_id
777
707
  query = "SELECT public_key FROM node WHERE node_id = ? AND status != ?;"
@@ -795,7 +725,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
795
725
  return None
796
726
 
797
727
  # Convert sint64 node_id to uint64
798
- node_id = convert_sint64_to_uint64(rows[0]["node_id"])
728
+ node_id = int64_to_uint64(rows[0]["node_id"])
799
729
  return node_id
800
730
 
801
731
  # pylint: disable=too-many-arguments,too-many-positional-arguments
@@ -813,7 +743,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
813
743
  uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
814
744
 
815
745
  # Convert the uint64 value to sint64 for SQLite
816
- sint64_run_id = convert_uint64_to_sint64(uint64_run_id)
746
+ sint64_run_id = uint64_to_int64(uint64_run_id)
817
747
 
818
748
  # Check conflicts
819
749
  query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
@@ -861,7 +791,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
861
791
  )
862
792
  else:
863
793
  rows = self.query("SELECT run_id FROM run;", ())
864
- return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
794
+ return {int64_to_uint64(row["run_id"]) for row in rows}
865
795
 
866
796
  def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
867
797
  """Check if any runs are no longer active.
@@ -869,7 +799,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
869
799
  Marks runs with status 'starting' or 'running' as failed
870
800
  if they have not sent a heartbeat before `active_until`.
871
801
  """
872
- sint_run_ids = [convert_uint64_to_sint64(run_id) for run_id in run_ids]
802
+ sint_run_ids = [uint64_to_int64(run_id) for run_id in run_ids]
873
803
  query = "UPDATE run SET finished_at = ?, sub_status = ?, details = ? "
874
804
  query += "WHERE starting_at != '' AND finished_at = '' AND active_until < ?"
875
805
  query += f" AND run_id IN ({','.join(['?'] * len(run_ids))});"
@@ -891,13 +821,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
891
821
  self._check_and_tag_inactive_run(run_ids={run_id})
892
822
 
893
823
  # Convert the uint64 value to sint64 for SQLite
894
- sint64_run_id = convert_uint64_to_sint64(run_id)
824
+ sint64_run_id = uint64_to_int64(run_id)
895
825
  query = "SELECT * FROM run WHERE run_id = ?;"
896
826
  rows = self.query(query, (sint64_run_id,))
897
827
  if rows:
898
828
  row = rows[0]
899
829
  return Run(
900
- run_id=convert_sint64_to_uint64(row["run_id"]),
830
+ run_id=int64_to_uint64(row["run_id"]),
901
831
  fab_id=row["fab_id"],
902
832
  fab_version=row["fab_version"],
903
833
  fab_hash=row["fab_hash"],
@@ -922,13 +852,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
922
852
  self._check_and_tag_inactive_run(run_ids=run_ids)
923
853
 
924
854
  # Convert the uint64 value to sint64 for SQLite
925
- sint64_run_ids = (convert_uint64_to_sint64(run_id) for run_id in set(run_ids))
855
+ sint64_run_ids = (uint64_to_int64(run_id) for run_id in set(run_ids))
926
856
  query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
927
857
  rows = self.query(query, tuple(sint64_run_ids))
928
858
 
929
859
  return {
930
860
  # Restore uint64 run IDs
931
- convert_sint64_to_uint64(row["run_id"]): RunStatus(
861
+ int64_to_uint64(row["run_id"]): RunStatus(
932
862
  status=determine_run_status(row),
933
863
  sub_status=row["sub_status"],
934
864
  details=row["details"],
@@ -942,7 +872,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
942
872
  self._check_and_tag_inactive_run(run_ids={run_id})
943
873
 
944
874
  # Convert the uint64 value to sint64 for SQLite
945
- sint64_run_id = convert_uint64_to_sint64(run_id)
875
+ sint64_run_id = uint64_to_int64(run_id)
946
876
  query = "SELECT * FROM run WHERE run_id = ?;"
947
877
  rows = self.query(query, (sint64_run_id,))
948
878
 
@@ -1008,7 +938,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1008
938
  new_status.details,
1009
939
  active_until,
1010
940
  heartbeat_interval,
1011
- convert_uint64_to_sint64(run_id),
941
+ uint64_to_int64(run_id),
1012
942
  )
1013
943
  self.query(query % timestamp_fld, data)
1014
944
  return True
@@ -1021,14 +951,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1021
951
  query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
1022
952
  rows = self.query(query)
1023
953
  if rows:
1024
- pending_run_id = convert_sint64_to_uint64(rows[0]["run_id"])
954
+ pending_run_id = int64_to_uint64(rows[0]["run_id"])
1025
955
 
1026
956
  return pending_run_id
1027
957
 
1028
958
  def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
1029
959
  """Retrieve the federation options for the specified `run_id`."""
1030
960
  # Convert the uint64 value to sint64 for SQLite
1031
- sint64_run_id = convert_uint64_to_sint64(run_id)
961
+ sint64_run_id = uint64_to_int64(run_id)
1032
962
  query = "SELECT federation_options FROM run WHERE run_id = ?;"
1033
963
  rows = self.query(query, (sint64_run_id,))
1034
964
 
@@ -1053,7 +983,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1053
983
  if self.conn is None:
1054
984
  raise AttributeError("LinkState not initialized")
1055
985
 
1056
- sint64_node_id = convert_uint64_to_sint64(node_id)
986
+ sint64_node_id = uint64_to_int64(node_id)
1057
987
 
1058
988
  with self.conn:
1059
989
  # Check if node exists and not deleted
@@ -1095,7 +1025,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1095
1025
  self._check_and_tag_inactive_run(run_ids={run_id})
1096
1026
 
1097
1027
  # Search for the run
1098
- sint_run_id = convert_uint64_to_sint64(run_id)
1028
+ sint_run_id = uint64_to_int64(run_id)
1099
1029
  query = "SELECT * FROM run WHERE run_id = ?;"
1100
1030
  rows = self.query(query, (sint_run_id,))
1101
1031
 
@@ -1125,7 +1055,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1125
1055
  """Get the context for the specified `run_id`."""
1126
1056
  # Retrieve context if any
1127
1057
  query = "SELECT context FROM context WHERE run_id = ?;"
1128
- rows = self.query(query, (convert_uint64_to_sint64(run_id),))
1058
+ rows = self.query(query, (uint64_to_int64(run_id),))
1129
1059
  context = context_from_bytes(rows[0]["context"]) if rows else None
1130
1060
  return context
1131
1061
 
@@ -1133,7 +1063,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1133
1063
  """Set the context for the specified `run_id`."""
1134
1064
  # Convert context to bytes
1135
1065
  context_bytes = context_to_bytes(context)
1136
- sint_run_id = convert_uint64_to_sint64(run_id)
1066
+ sint_run_id = uint64_to_int64(run_id)
1137
1067
 
1138
1068
  # Check if any existing Context assigned to the run_id
1139
1069
  query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
@@ -1152,7 +1082,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1152
1082
  def add_serverapp_log(self, run_id: int, log_message: str) -> None:
1153
1083
  """Add a log entry to the ServerApp logs for the specified `run_id`."""
1154
1084
  # Convert the uint64 value to sint64 for SQLite
1155
- sint64_run_id = convert_uint64_to_sint64(run_id)
1085
+ sint64_run_id = uint64_to_int64(run_id)
1156
1086
 
1157
1087
  # Store log
1158
1088
  try:
@@ -1168,7 +1098,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1168
1098
  ) -> tuple[str, float]:
1169
1099
  """Get the ServerApp logs for the specified `run_id`."""
1170
1100
  # Convert the uint64 value to sint64 for SQLite
1171
- sint64_run_id = convert_uint64_to_sint64(run_id)
1101
+ sint64_run_id = uint64_to_int64(run_id)
1172
1102
 
1173
1103
  # Check if the run_id exists
1174
1104
  query = "SELECT run_id FROM run WHERE run_id = ?;"
@@ -1218,7 +1148,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1218
1148
  """Create a token for the given run ID."""
1219
1149
  token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
1220
1150
  query = "INSERT INTO token_store (run_id, token) VALUES (:run_id, :token);"
1221
- data = {"run_id": convert_uint64_to_sint64(run_id), "token": token}
1151
+ data = {"run_id": uint64_to_int64(run_id), "token": token}
1222
1152
  try:
1223
1153
  self.query(query, data)
1224
1154
  except sqlite3.IntegrityError:
@@ -1228,7 +1158,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1228
1158
  def verify_token(self, run_id: int, token: str) -> bool:
1229
1159
  """Verify a token for the given run ID."""
1230
1160
  query = "SELECT token FROM token_store WHERE run_id = :run_id;"
1231
- data = {"run_id": convert_uint64_to_sint64(run_id)}
1161
+ data = {"run_id": uint64_to_int64(run_id)}
1232
1162
  rows = self.query(query, data)
1233
1163
  if not rows:
1234
1164
  return False
@@ -1237,7 +1167,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1237
1167
  def delete_token(self, run_id: int) -> None:
1238
1168
  """Delete the token for the given run ID."""
1239
1169
  query = "DELETE FROM token_store WHERE run_id = :run_id;"
1240
- data = {"run_id": convert_uint64_to_sint64(run_id)}
1170
+ data = {"run_id": uint64_to_int64(run_id)}
1241
1171
  self.query(query, data)
1242
1172
 
1243
1173
  def get_run_id_by_token(self, token: str) -> Optional[int]:
@@ -1247,19 +1177,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1247
1177
  rows = self.query(query, data)
1248
1178
  if not rows:
1249
1179
  return None
1250
- return convert_sint64_to_uint64(rows[0]["run_id"])
1251
-
1252
-
1253
- def dict_factory(
1254
- cursor: sqlite3.Cursor,
1255
- row: sqlite3.Row,
1256
- ) -> dict[str, Any]:
1257
- """Turn SQLite results into dicts.
1258
-
1259
- Less efficent for retrival of large amounts of data but easier to use.
1260
- """
1261
- fields = [column[0] for column in cursor.description]
1262
- return dict(zip(fields, row))
1180
+ return int64_to_uint64(rows[0]["run_id"])
1263
1181
 
1264
1182
 
1265
1183
  def message_to_dict(message: Message) -> dict[str, Any]:
@@ -1314,5 +1232,5 @@ def determine_run_status(row: dict[str, Any]) -> str:
1314
1232
  return Status.RUNNING
1315
1233
  return Status.STARTING
1316
1234
  return Status.PENDING
1317
- run_id = convert_sint64_to_uint64(row["run_id"])
1235
+ run_id = int64_to_uint64(row["run_id"])
1318
1236
  raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")
@@ -33,6 +33,7 @@ from flwr.common.typing import RunStatus
33
33
  # pylint: disable=E0611
34
34
  from flwr.proto.message_pb2 import Context as ProtoContext
35
35
  from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
36
+ from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
36
37
 
37
38
  # pylint: enable=E0611
38
39
  VALID_RUN_STATUS_TRANSITIONS = {
@@ -76,58 +77,6 @@ def generate_rand_int_from_bytes(
76
77
  return num
77
78
 
78
79
 
79
- def convert_uint64_to_sint64(u: int) -> int:
80
- """Convert a uint64 value to a sint64 value with the same bit sequence.
81
-
82
- Parameters
83
- ----------
84
- u : int
85
- The unsigned 64-bit integer to convert.
86
-
87
- Returns
88
- -------
89
- int
90
- The signed 64-bit integer equivalent.
91
-
92
- The signed 64-bit integer will have the same bit pattern as the
93
- unsigned 64-bit integer but may have a different decimal value.
94
-
95
- For numbers within the range [0, `sint64` max value], the decimal
96
- value remains the same. However, for numbers greater than the `sint64`
97
- max value, the decimal value will differ due to the wraparound caused
98
- by the sign bit.
99
- """
100
- if u >= (1 << 63):
101
- return u - (1 << 64)
102
- return u
103
-
104
-
105
- def convert_sint64_to_uint64(s: int) -> int:
106
- """Convert a sint64 value to a uint64 value with the same bit sequence.
107
-
108
- Parameters
109
- ----------
110
- s : int
111
- The signed 64-bit integer to convert.
112
-
113
- Returns
114
- -------
115
- int
116
- The unsigned 64-bit integer equivalent.
117
-
118
- The unsigned 64-bit integer will have the same bit pattern as the
119
- signed 64-bit integer but may have a different decimal value.
120
-
121
- For negative `sint64` values, the conversion adds 2^64 to the
122
- signed value to obtain the equivalent `uint64` value. For non-negative
123
- `sint64` values, the decimal value remains unchanged in the `uint64`
124
- representation.
125
- """
126
- if s < 0:
127
- return s + (1 << 64)
128
- return s
129
-
130
-
131
80
  def convert_uint64_values_in_dict_to_sint64(
132
81
  data_dict: dict[str, int], keys: list[str]
133
82
  ) -> None:
@@ -142,7 +91,7 @@ def convert_uint64_values_in_dict_to_sint64(
142
91
  """
143
92
  for key in keys:
144
93
  if key in data_dict:
145
- data_dict[key] = convert_uint64_to_sint64(data_dict[key])
94
+ data_dict[key] = uint64_to_int64(data_dict[key])
146
95
 
147
96
 
148
97
  def convert_sint64_values_in_dict_to_uint64(
@@ -159,7 +108,7 @@ def convert_sint64_values_in_dict_to_uint64(
159
108
  """
160
109
  for key in keys:
161
110
  if key in data_dict:
162
- data_dict[key] = convert_sint64_to_uint64(data_dict[key])
111
+ data_dict[key] = int64_to_uint64(data_dict[key])
163
112
 
164
113
 
165
114
  def context_to_bytes(context: Context) -> bytes:
@@ -51,6 +51,7 @@ from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
51
51
  from flwr.simulation.ray_transport.utils import (
52
52
  enable_tf_gpu_growth as enable_gpu_growth,
53
53
  )
54
+ from flwr.supercore.constant import FLWR_IN_MEMORY_DB_NAME
54
55
 
55
56
 
56
57
  def _replace_keys(d: Any, match: str, target: str) -> Any:
@@ -336,7 +337,7 @@ def _main_loop(
336
337
  ) -> Context:
337
338
  """Start ServerApp on a separate thread, then launch Simulation Engine."""
338
339
  # Initialize StateFactory
339
- state_factory = LinkStateFactory(":flwr-in-memory-state:")
340
+ state_factory = LinkStateFactory(FLWR_IN_MEMORY_DB_NAME)
340
341
 
341
342
  f_stop = threading.Event()
342
343
  # A Threading event to indicate if an exception was raised in the ServerApp thread
@@ -20,6 +20,9 @@ from __future__ import annotations
20
20
  # Top-level key in YAML config for exec plugin settings
21
21
  EXEC_PLUGIN_SECTION = "exec_plugin"
22
22
 
23
+ # Flower in-memory Python-based database name
24
+ FLWR_IN_MEMORY_DB_NAME = ":flwr-in-memory:"
25
+
23
26
 
24
27
  class NodeStatus:
25
28
  """Event log writer types."""
@@ -19,15 +19,27 @@ from logging import DEBUG
19
19
  from typing import Optional
20
20
 
21
21
  from flwr.common.logger import log
22
+ from flwr.supercore.constant import FLWR_IN_MEMORY_DB_NAME
22
23
 
23
24
  from .in_memory_object_store import InMemoryObjectStore
24
25
  from .object_store import ObjectStore
26
+ from .sqlite_object_store import SqliteObjectStore
25
27
 
26
28
 
27
29
  class ObjectStoreFactory:
28
- """Factory class that creates ObjectStore instances."""
30
+ """Factory class that creates ObjectStore instances.
29
31
 
30
- def __init__(self) -> None:
32
+ Parameters
33
+ ----------
34
+ database : str (default: FLWR_IN_MEMORY_DB_NAME)
35
+ A string representing the path to the database file that will be opened.
36
+ Note that passing ":memory:" will open a connection to a database that is
37
+ in RAM, instead of on disk. And FLWR_IN_MEMORY_DB_NAME will create an
38
+ Python-based in-memory ObjectStore.
39
+ """
40
+
41
+ def __init__(self, database: str = FLWR_IN_MEMORY_DB_NAME) -> None:
42
+ self.database = database
31
43
  self.store_instance: Optional[ObjectStore] = None
32
44
 
33
45
  def store(self) -> ObjectStore:
@@ -38,7 +50,15 @@ class ObjectStoreFactory:
38
50
  ObjectStore
39
51
  An ObjectStore instance for storing objects by object_id.
40
52
  """
41
- if self.store_instance is None:
42
- self.store_instance = InMemoryObjectStore()
43
- log(DEBUG, "Using InMemoryObjectStore")
44
- return self.store_instance
53
+ # InMemoryObjectStore
54
+ if self.database == FLWR_IN_MEMORY_DB_NAME:
55
+ if self.store_instance is None:
56
+ self.store_instance = InMemoryObjectStore()
57
+ log(DEBUG, "Using InMemoryObjectStore")
58
+ return self.store_instance
59
+
60
+ # SqliteObjectStore
61
+ store = SqliteObjectStore(self.database)
62
+ store.initialize()
63
+ log(DEBUG, "Using SqliteObjectStore")
64
+ return store