flwr-nightly 1.23.0.dev20251020__py3-none-any.whl → 1.23.0.dev20251022__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/ls.py +5 -5
- flwr/cli/supernode/ls.py +8 -15
- flwr/common/constant.py +12 -4
- flwr/common/inflatable_utils.py +10 -10
- flwr/common/record/array.py +3 -3
- flwr/proto/control_pb2.py +6 -6
- flwr/proto/control_pb2.pyi +0 -5
- flwr/server/app.py +4 -6
- flwr/server/superlink/fleet/vce/vce_api.py +2 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +5 -1
- flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +64 -146
- flwr/server/superlink/linkstate/utils.py +3 -54
- flwr/simulation/run_simulation.py +2 -1
- flwr/supercore/constant.py +3 -0
- flwr/supercore/object_store/object_store_factory.py +26 -6
- flwr/supercore/object_store/sqlite_object_store.py +252 -0
- flwr/supercore/sqlite_mixin.py +156 -0
- flwr/supercore/utils.py +20 -0
- flwr/superlink/servicer/control/control_servicer.py +6 -70
- {flwr_nightly-1.23.0.dev20251020.dist-info → flwr_nightly-1.23.0.dev20251022.dist-info}/METADATA +1 -1
- {flwr_nightly-1.23.0.dev20251020.dist-info → flwr_nightly-1.23.0.dev20251022.dist-info}/RECORD +24 -22
- {flwr_nightly-1.23.0.dev20251020.dist-info → flwr_nightly-1.23.0.dev20251022.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.23.0.dev20251020.dist-info → flwr_nightly-1.23.0.dev20251022.dist-info}/entry_points.txt +0 -0
|
@@ -18,11 +18,10 @@
|
|
|
18
18
|
# pylint: disable=too-many-lines
|
|
19
19
|
|
|
20
20
|
import json
|
|
21
|
-
import re
|
|
22
21
|
import secrets
|
|
23
22
|
import sqlite3
|
|
24
23
|
from collections.abc import Sequence
|
|
25
|
-
from logging import
|
|
24
|
+
from logging import ERROR, WARNING
|
|
26
25
|
from typing import Any, Optional, Union, cast
|
|
27
26
|
|
|
28
27
|
from flwr.common import Context, Message, Metadata, log, now
|
|
@@ -52,6 +51,8 @@ from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
|
|
52
51
|
# pylint: enable=E0611
|
|
53
52
|
from flwr.server.utils.validator import validate_message
|
|
54
53
|
from flwr.supercore.constant import NodeStatus
|
|
54
|
+
from flwr.supercore.sqlite_mixin import SqliteMixin
|
|
55
|
+
from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
|
|
55
56
|
|
|
56
57
|
from .linkstate import LinkState
|
|
57
58
|
from .utils import (
|
|
@@ -60,9 +61,7 @@ from .utils import (
|
|
|
60
61
|
configrecord_to_bytes,
|
|
61
62
|
context_from_bytes,
|
|
62
63
|
context_to_bytes,
|
|
63
|
-
convert_sint64_to_uint64,
|
|
64
64
|
convert_sint64_values_in_dict_to_uint64,
|
|
65
|
-
convert_uint64_to_sint64,
|
|
66
65
|
convert_uint64_values_in_dict_to_sint64,
|
|
67
66
|
generate_rand_int_from_bytes,
|
|
68
67
|
has_valid_sub_status,
|
|
@@ -183,95 +182,25 @@ CREATE TABLE IF NOT EXISTS token_store (
|
|
|
183
182
|
);
|
|
184
183
|
"""
|
|
185
184
|
|
|
186
|
-
DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
|
|
187
185
|
|
|
188
|
-
|
|
189
|
-
class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
186
|
+
class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
190
187
|
"""SQLite-based LinkState implementation."""
|
|
191
188
|
|
|
192
|
-
def __init__(
|
|
193
|
-
self,
|
|
194
|
-
database_path: str,
|
|
195
|
-
) -> None:
|
|
196
|
-
"""Initialize an SqliteLinkState.
|
|
197
|
-
|
|
198
|
-
Parameters
|
|
199
|
-
----------
|
|
200
|
-
database : (path-like object)
|
|
201
|
-
The path to the database file to be opened. Pass ":memory:" to open
|
|
202
|
-
a connection to a database that is in RAM, instead of on disk.
|
|
203
|
-
"""
|
|
204
|
-
self.database_path = database_path
|
|
205
|
-
self.conn: Optional[sqlite3.Connection] = None
|
|
206
|
-
|
|
207
189
|
def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
|
|
208
|
-
"""
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
self.conn.row_factory = dict_factory
|
|
223
|
-
if log_queries:
|
|
224
|
-
self.conn.set_trace_callback(lambda query: log(DEBUG, query))
|
|
225
|
-
cur = self.conn.cursor()
|
|
226
|
-
|
|
227
|
-
# Create each table if not exists queries
|
|
228
|
-
cur.execute(SQL_CREATE_TABLE_RUN)
|
|
229
|
-
cur.execute(SQL_CREATE_TABLE_LOGS)
|
|
230
|
-
cur.execute(SQL_CREATE_TABLE_CONTEXT)
|
|
231
|
-
cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
|
|
232
|
-
cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
|
|
233
|
-
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
234
|
-
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
|
235
|
-
cur.execute(SQL_CREATE_TABLE_TOKEN_STORE)
|
|
236
|
-
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
237
|
-
cur.execute(SQL_CREATE_INDEX_OWNER_AID)
|
|
238
|
-
res = cur.execute("SELECT name FROM sqlite_schema;")
|
|
239
|
-
return res.fetchall()
|
|
240
|
-
|
|
241
|
-
def query(
|
|
242
|
-
self,
|
|
243
|
-
query: str,
|
|
244
|
-
data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
|
|
245
|
-
) -> list[dict[str, Any]]:
|
|
246
|
-
"""Execute a SQL query."""
|
|
247
|
-
if self.conn is None:
|
|
248
|
-
raise AttributeError("LinkState is not initialized.")
|
|
249
|
-
|
|
250
|
-
if data is None:
|
|
251
|
-
data = []
|
|
252
|
-
|
|
253
|
-
# Clean up whitespace to make the logs nicer
|
|
254
|
-
query = re.sub(r"\s+", " ", query)
|
|
255
|
-
|
|
256
|
-
try:
|
|
257
|
-
with self.conn:
|
|
258
|
-
if (
|
|
259
|
-
len(data) > 0
|
|
260
|
-
and isinstance(data, (tuple, list))
|
|
261
|
-
and isinstance(data[0], (tuple, dict))
|
|
262
|
-
):
|
|
263
|
-
rows = self.conn.executemany(query, data)
|
|
264
|
-
else:
|
|
265
|
-
rows = self.conn.execute(query, data)
|
|
266
|
-
|
|
267
|
-
# Extract results before committing to support
|
|
268
|
-
# INSERT/UPDATE ... RETURNING
|
|
269
|
-
# style queries
|
|
270
|
-
result = rows.fetchall()
|
|
271
|
-
except KeyError as exc:
|
|
272
|
-
log(ERROR, {"query": query, "data": data, "exception": exc})
|
|
273
|
-
|
|
274
|
-
return result
|
|
190
|
+
"""Connect to the DB, enable FK support, and create tables if needed."""
|
|
191
|
+
return self._ensure_initialized(
|
|
192
|
+
SQL_CREATE_TABLE_RUN,
|
|
193
|
+
SQL_CREATE_TABLE_LOGS,
|
|
194
|
+
SQL_CREATE_TABLE_CONTEXT,
|
|
195
|
+
SQL_CREATE_TABLE_MESSAGE_INS,
|
|
196
|
+
SQL_CREATE_TABLE_MESSAGE_RES,
|
|
197
|
+
SQL_CREATE_TABLE_NODE,
|
|
198
|
+
SQL_CREATE_TABLE_PUBLIC_KEY,
|
|
199
|
+
SQL_CREATE_TABLE_TOKEN_STORE,
|
|
200
|
+
SQL_CREATE_INDEX_ONLINE_UNTIL,
|
|
201
|
+
SQL_CREATE_INDEX_OWNER_AID,
|
|
202
|
+
log_queries=log_queries,
|
|
203
|
+
)
|
|
275
204
|
|
|
276
205
|
def store_message_ins(self, message: Message) -> Optional[str]:
|
|
277
206
|
"""Store one Message."""
|
|
@@ -335,7 +264,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
335
264
|
data: dict[str, Union[str, int]] = {}
|
|
336
265
|
|
|
337
266
|
# Convert the uint64 value to sint64 for SQLite
|
|
338
|
-
data["node_id"] =
|
|
267
|
+
data["node_id"] = uint64_to_int64(node_id)
|
|
339
268
|
|
|
340
269
|
# Retrieve all Messages for node_id
|
|
341
270
|
query = """
|
|
@@ -410,8 +339,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
410
339
|
if (
|
|
411
340
|
msg_ins
|
|
412
341
|
and message
|
|
413
|
-
and
|
|
414
|
-
!= res_metadata.src_node_id
|
|
342
|
+
and int64_to_uint64(msg_ins["dst_node_id"]) != res_metadata.src_node_id
|
|
415
343
|
):
|
|
416
344
|
return None
|
|
417
345
|
|
|
@@ -487,20 +415,20 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
487
415
|
dst_node_ids: set[int] = set()
|
|
488
416
|
for message_id in message_ids:
|
|
489
417
|
in_message = found_message_ins_dict[message_id]
|
|
490
|
-
sint_node_id =
|
|
418
|
+
sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
|
|
491
419
|
dst_node_ids.add(sint_node_id)
|
|
492
420
|
query = f"""
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
421
|
+
SELECT node_id, online_until
|
|
422
|
+
FROM node
|
|
423
|
+
WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))})
|
|
424
|
+
AND status != ?
|
|
425
|
+
"""
|
|
426
|
+
rows = self.query(query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,))
|
|
498
427
|
tmp_ret_dict = check_node_availability_for_in_message(
|
|
499
428
|
inquired_in_message_ids=message_ids,
|
|
500
429
|
found_in_message_dict=found_message_ins_dict,
|
|
501
430
|
node_id_to_online_until={
|
|
502
|
-
|
|
503
|
-
for row in rows
|
|
431
|
+
int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
|
|
504
432
|
},
|
|
505
433
|
current_time=current,
|
|
506
434
|
)
|
|
@@ -601,7 +529,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
601
529
|
WHERE run_id = :run_id;
|
|
602
530
|
"""
|
|
603
531
|
|
|
604
|
-
sint64_run_id =
|
|
532
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
605
533
|
data = {"run_id": sint64_run_id}
|
|
606
534
|
|
|
607
535
|
with self.conn:
|
|
@@ -619,7 +547,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
619
547
|
)
|
|
620
548
|
|
|
621
549
|
# Convert the uint64 value to sint64 for SQLite
|
|
622
|
-
sint64_node_id =
|
|
550
|
+
sint64_node_id = uint64_to_int64(uint64_node_id)
|
|
623
551
|
|
|
624
552
|
query = """
|
|
625
553
|
INSERT INTO node
|
|
@@ -658,17 +586,21 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
658
586
|
|
|
659
587
|
def delete_node(self, owner_aid: str, node_id: int) -> None:
|
|
660
588
|
"""Delete a node."""
|
|
661
|
-
sint64_node_id =
|
|
589
|
+
sint64_node_id = uint64_to_int64(node_id)
|
|
662
590
|
|
|
663
591
|
query = """
|
|
664
592
|
UPDATE node
|
|
665
|
-
SET status = ?, unregistered_at =
|
|
593
|
+
SET status = ?, unregistered_at = ?,
|
|
594
|
+
online_until = IIF(online_until > ?, ?, online_until)
|
|
666
595
|
WHERE node_id = ? AND status != ? AND owner_aid = ?
|
|
667
596
|
RETURNING node_id
|
|
668
597
|
"""
|
|
598
|
+
current = now()
|
|
669
599
|
params = (
|
|
670
600
|
NodeStatus.UNREGISTERED,
|
|
671
|
-
|
|
601
|
+
current.isoformat(),
|
|
602
|
+
current.timestamp(),
|
|
603
|
+
current.timestamp(),
|
|
672
604
|
sint64_node_id,
|
|
673
605
|
NodeStatus.UNREGISTERED,
|
|
674
606
|
owner_aid,
|
|
@@ -693,7 +625,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
693
625
|
raise AttributeError("LinkState not initialized")
|
|
694
626
|
|
|
695
627
|
# Convert the uint64 value to sint64 for SQLite
|
|
696
|
-
sint64_run_id =
|
|
628
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
697
629
|
|
|
698
630
|
# Validate run ID
|
|
699
631
|
query = "SELECT COUNT(*) FROM run WHERE run_id = ?"
|
|
@@ -738,9 +670,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
738
670
|
conditions = []
|
|
739
671
|
params = []
|
|
740
672
|
if node_ids is not None:
|
|
741
|
-
sint64_node_ids = [
|
|
742
|
-
convert_uint64_to_sint64(node_id) for node_id in node_ids
|
|
743
|
-
]
|
|
673
|
+
sint64_node_ids = [uint64_to_int64(node_id) for node_id in node_ids]
|
|
744
674
|
placeholders = ",".join(["?"] * len(sint64_node_ids))
|
|
745
675
|
conditions.append(f"node_id IN ({placeholders})")
|
|
746
676
|
params.extend(sint64_node_ids)
|
|
@@ -763,7 +693,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
763
693
|
result: list[NodeInfo] = []
|
|
764
694
|
for row in rows:
|
|
765
695
|
# Convert sint64 node_id to uint64
|
|
766
|
-
row["node_id"] =
|
|
696
|
+
row["node_id"] = int64_to_uint64(row["node_id"])
|
|
767
697
|
result.append(NodeInfo(**row))
|
|
768
698
|
|
|
769
699
|
return result
|
|
@@ -771,7 +701,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
771
701
|
def get_node_public_key(self, node_id: int) -> bytes:
|
|
772
702
|
"""Get `public_key` for the specified `node_id`."""
|
|
773
703
|
# Convert the uint64 value to sint64 for SQLite
|
|
774
|
-
sint64_node_id =
|
|
704
|
+
sint64_node_id = uint64_to_int64(node_id)
|
|
775
705
|
|
|
776
706
|
# Query the public key for the given node_id
|
|
777
707
|
query = "SELECT public_key FROM node WHERE node_id = ? AND status != ?;"
|
|
@@ -795,7 +725,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
795
725
|
return None
|
|
796
726
|
|
|
797
727
|
# Convert sint64 node_id to uint64
|
|
798
|
-
node_id =
|
|
728
|
+
node_id = int64_to_uint64(rows[0]["node_id"])
|
|
799
729
|
return node_id
|
|
800
730
|
|
|
801
731
|
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
@@ -813,7 +743,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
813
743
|
uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
814
744
|
|
|
815
745
|
# Convert the uint64 value to sint64 for SQLite
|
|
816
|
-
sint64_run_id =
|
|
746
|
+
sint64_run_id = uint64_to_int64(uint64_run_id)
|
|
817
747
|
|
|
818
748
|
# Check conflicts
|
|
819
749
|
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
|
|
@@ -861,7 +791,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
861
791
|
)
|
|
862
792
|
else:
|
|
863
793
|
rows = self.query("SELECT run_id FROM run;", ())
|
|
864
|
-
return {
|
|
794
|
+
return {int64_to_uint64(row["run_id"]) for row in rows}
|
|
865
795
|
|
|
866
796
|
def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
|
|
867
797
|
"""Check if any runs are no longer active.
|
|
@@ -869,7 +799,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
869
799
|
Marks runs with status 'starting' or 'running' as failed
|
|
870
800
|
if they have not sent a heartbeat before `active_until`.
|
|
871
801
|
"""
|
|
872
|
-
sint_run_ids = [
|
|
802
|
+
sint_run_ids = [uint64_to_int64(run_id) for run_id in run_ids]
|
|
873
803
|
query = "UPDATE run SET finished_at = ?, sub_status = ?, details = ? "
|
|
874
804
|
query += "WHERE starting_at != '' AND finished_at = '' AND active_until < ?"
|
|
875
805
|
query += f" AND run_id IN ({','.join(['?'] * len(run_ids))});"
|
|
@@ -891,13 +821,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
891
821
|
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
892
822
|
|
|
893
823
|
# Convert the uint64 value to sint64 for SQLite
|
|
894
|
-
sint64_run_id =
|
|
824
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
895
825
|
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
896
826
|
rows = self.query(query, (sint64_run_id,))
|
|
897
827
|
if rows:
|
|
898
828
|
row = rows[0]
|
|
899
829
|
return Run(
|
|
900
|
-
run_id=
|
|
830
|
+
run_id=int64_to_uint64(row["run_id"]),
|
|
901
831
|
fab_id=row["fab_id"],
|
|
902
832
|
fab_version=row["fab_version"],
|
|
903
833
|
fab_hash=row["fab_hash"],
|
|
@@ -922,13 +852,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
922
852
|
self._check_and_tag_inactive_run(run_ids=run_ids)
|
|
923
853
|
|
|
924
854
|
# Convert the uint64 value to sint64 for SQLite
|
|
925
|
-
sint64_run_ids = (
|
|
855
|
+
sint64_run_ids = (uint64_to_int64(run_id) for run_id in set(run_ids))
|
|
926
856
|
query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
|
|
927
857
|
rows = self.query(query, tuple(sint64_run_ids))
|
|
928
858
|
|
|
929
859
|
return {
|
|
930
860
|
# Restore uint64 run IDs
|
|
931
|
-
|
|
861
|
+
int64_to_uint64(row["run_id"]): RunStatus(
|
|
932
862
|
status=determine_run_status(row),
|
|
933
863
|
sub_status=row["sub_status"],
|
|
934
864
|
details=row["details"],
|
|
@@ -942,7 +872,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
942
872
|
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
943
873
|
|
|
944
874
|
# Convert the uint64 value to sint64 for SQLite
|
|
945
|
-
sint64_run_id =
|
|
875
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
946
876
|
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
947
877
|
rows = self.query(query, (sint64_run_id,))
|
|
948
878
|
|
|
@@ -1008,7 +938,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1008
938
|
new_status.details,
|
|
1009
939
|
active_until,
|
|
1010
940
|
heartbeat_interval,
|
|
1011
|
-
|
|
941
|
+
uint64_to_int64(run_id),
|
|
1012
942
|
)
|
|
1013
943
|
self.query(query % timestamp_fld, data)
|
|
1014
944
|
return True
|
|
@@ -1021,14 +951,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1021
951
|
query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
|
|
1022
952
|
rows = self.query(query)
|
|
1023
953
|
if rows:
|
|
1024
|
-
pending_run_id =
|
|
954
|
+
pending_run_id = int64_to_uint64(rows[0]["run_id"])
|
|
1025
955
|
|
|
1026
956
|
return pending_run_id
|
|
1027
957
|
|
|
1028
958
|
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
|
1029
959
|
"""Retrieve the federation options for the specified `run_id`."""
|
|
1030
960
|
# Convert the uint64 value to sint64 for SQLite
|
|
1031
|
-
sint64_run_id =
|
|
961
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
1032
962
|
query = "SELECT federation_options FROM run WHERE run_id = ?;"
|
|
1033
963
|
rows = self.query(query, (sint64_run_id,))
|
|
1034
964
|
|
|
@@ -1053,7 +983,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1053
983
|
if self.conn is None:
|
|
1054
984
|
raise AttributeError("LinkState not initialized")
|
|
1055
985
|
|
|
1056
|
-
sint64_node_id =
|
|
986
|
+
sint64_node_id = uint64_to_int64(node_id)
|
|
1057
987
|
|
|
1058
988
|
with self.conn:
|
|
1059
989
|
# Check if node exists and not deleted
|
|
@@ -1095,7 +1025,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1095
1025
|
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
1096
1026
|
|
|
1097
1027
|
# Search for the run
|
|
1098
|
-
sint_run_id =
|
|
1028
|
+
sint_run_id = uint64_to_int64(run_id)
|
|
1099
1029
|
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
1100
1030
|
rows = self.query(query, (sint_run_id,))
|
|
1101
1031
|
|
|
@@ -1125,7 +1055,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1125
1055
|
"""Get the context for the specified `run_id`."""
|
|
1126
1056
|
# Retrieve context if any
|
|
1127
1057
|
query = "SELECT context FROM context WHERE run_id = ?;"
|
|
1128
|
-
rows = self.query(query, (
|
|
1058
|
+
rows = self.query(query, (uint64_to_int64(run_id),))
|
|
1129
1059
|
context = context_from_bytes(rows[0]["context"]) if rows else None
|
|
1130
1060
|
return context
|
|
1131
1061
|
|
|
@@ -1133,7 +1063,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1133
1063
|
"""Set the context for the specified `run_id`."""
|
|
1134
1064
|
# Convert context to bytes
|
|
1135
1065
|
context_bytes = context_to_bytes(context)
|
|
1136
|
-
sint_run_id =
|
|
1066
|
+
sint_run_id = uint64_to_int64(run_id)
|
|
1137
1067
|
|
|
1138
1068
|
# Check if any existing Context assigned to the run_id
|
|
1139
1069
|
query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
|
|
@@ -1152,7 +1082,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1152
1082
|
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
|
|
1153
1083
|
"""Add a log entry to the ServerApp logs for the specified `run_id`."""
|
|
1154
1084
|
# Convert the uint64 value to sint64 for SQLite
|
|
1155
|
-
sint64_run_id =
|
|
1085
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
1156
1086
|
|
|
1157
1087
|
# Store log
|
|
1158
1088
|
try:
|
|
@@ -1168,7 +1098,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1168
1098
|
) -> tuple[str, float]:
|
|
1169
1099
|
"""Get the ServerApp logs for the specified `run_id`."""
|
|
1170
1100
|
# Convert the uint64 value to sint64 for SQLite
|
|
1171
|
-
sint64_run_id =
|
|
1101
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
1172
1102
|
|
|
1173
1103
|
# Check if the run_id exists
|
|
1174
1104
|
query = "SELECT run_id FROM run WHERE run_id = ?;"
|
|
@@ -1218,7 +1148,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1218
1148
|
"""Create a token for the given run ID."""
|
|
1219
1149
|
token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
|
|
1220
1150
|
query = "INSERT INTO token_store (run_id, token) VALUES (:run_id, :token);"
|
|
1221
|
-
data = {"run_id":
|
|
1151
|
+
data = {"run_id": uint64_to_int64(run_id), "token": token}
|
|
1222
1152
|
try:
|
|
1223
1153
|
self.query(query, data)
|
|
1224
1154
|
except sqlite3.IntegrityError:
|
|
@@ -1228,7 +1158,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1228
1158
|
def verify_token(self, run_id: int, token: str) -> bool:
|
|
1229
1159
|
"""Verify a token for the given run ID."""
|
|
1230
1160
|
query = "SELECT token FROM token_store WHERE run_id = :run_id;"
|
|
1231
|
-
data = {"run_id":
|
|
1161
|
+
data = {"run_id": uint64_to_int64(run_id)}
|
|
1232
1162
|
rows = self.query(query, data)
|
|
1233
1163
|
if not rows:
|
|
1234
1164
|
return False
|
|
@@ -1237,7 +1167,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1237
1167
|
def delete_token(self, run_id: int) -> None:
|
|
1238
1168
|
"""Delete the token for the given run ID."""
|
|
1239
1169
|
query = "DELETE FROM token_store WHERE run_id = :run_id;"
|
|
1240
|
-
data = {"run_id":
|
|
1170
|
+
data = {"run_id": uint64_to_int64(run_id)}
|
|
1241
1171
|
self.query(query, data)
|
|
1242
1172
|
|
|
1243
1173
|
def get_run_id_by_token(self, token: str) -> Optional[int]:
|
|
@@ -1247,19 +1177,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1247
1177
|
rows = self.query(query, data)
|
|
1248
1178
|
if not rows:
|
|
1249
1179
|
return None
|
|
1250
|
-
return
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
def dict_factory(
|
|
1254
|
-
cursor: sqlite3.Cursor,
|
|
1255
|
-
row: sqlite3.Row,
|
|
1256
|
-
) -> dict[str, Any]:
|
|
1257
|
-
"""Turn SQLite results into dicts.
|
|
1258
|
-
|
|
1259
|
-
Less efficent for retrival of large amounts of data but easier to use.
|
|
1260
|
-
"""
|
|
1261
|
-
fields = [column[0] for column in cursor.description]
|
|
1262
|
-
return dict(zip(fields, row))
|
|
1180
|
+
return int64_to_uint64(rows[0]["run_id"])
|
|
1263
1181
|
|
|
1264
1182
|
|
|
1265
1183
|
def message_to_dict(message: Message) -> dict[str, Any]:
|
|
@@ -1314,5 +1232,5 @@ def determine_run_status(row: dict[str, Any]) -> str:
|
|
|
1314
1232
|
return Status.RUNNING
|
|
1315
1233
|
return Status.STARTING
|
|
1316
1234
|
return Status.PENDING
|
|
1317
|
-
run_id =
|
|
1235
|
+
run_id = int64_to_uint64(row["run_id"])
|
|
1318
1236
|
raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")
|
|
@@ -33,6 +33,7 @@ from flwr.common.typing import RunStatus
|
|
|
33
33
|
# pylint: disable=E0611
|
|
34
34
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
35
35
|
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
|
36
|
+
from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
|
|
36
37
|
|
|
37
38
|
# pylint: enable=E0611
|
|
38
39
|
VALID_RUN_STATUS_TRANSITIONS = {
|
|
@@ -76,58 +77,6 @@ def generate_rand_int_from_bytes(
|
|
|
76
77
|
return num
|
|
77
78
|
|
|
78
79
|
|
|
79
|
-
def convert_uint64_to_sint64(u: int) -> int:
|
|
80
|
-
"""Convert a uint64 value to a sint64 value with the same bit sequence.
|
|
81
|
-
|
|
82
|
-
Parameters
|
|
83
|
-
----------
|
|
84
|
-
u : int
|
|
85
|
-
The unsigned 64-bit integer to convert.
|
|
86
|
-
|
|
87
|
-
Returns
|
|
88
|
-
-------
|
|
89
|
-
int
|
|
90
|
-
The signed 64-bit integer equivalent.
|
|
91
|
-
|
|
92
|
-
The signed 64-bit integer will have the same bit pattern as the
|
|
93
|
-
unsigned 64-bit integer but may have a different decimal value.
|
|
94
|
-
|
|
95
|
-
For numbers within the range [0, `sint64` max value], the decimal
|
|
96
|
-
value remains the same. However, for numbers greater than the `sint64`
|
|
97
|
-
max value, the decimal value will differ due to the wraparound caused
|
|
98
|
-
by the sign bit.
|
|
99
|
-
"""
|
|
100
|
-
if u >= (1 << 63):
|
|
101
|
-
return u - (1 << 64)
|
|
102
|
-
return u
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
def convert_sint64_to_uint64(s: int) -> int:
|
|
106
|
-
"""Convert a sint64 value to a uint64 value with the same bit sequence.
|
|
107
|
-
|
|
108
|
-
Parameters
|
|
109
|
-
----------
|
|
110
|
-
s : int
|
|
111
|
-
The signed 64-bit integer to convert.
|
|
112
|
-
|
|
113
|
-
Returns
|
|
114
|
-
-------
|
|
115
|
-
int
|
|
116
|
-
The unsigned 64-bit integer equivalent.
|
|
117
|
-
|
|
118
|
-
The unsigned 64-bit integer will have the same bit pattern as the
|
|
119
|
-
signed 64-bit integer but may have a different decimal value.
|
|
120
|
-
|
|
121
|
-
For negative `sint64` values, the conversion adds 2^64 to the
|
|
122
|
-
signed value to obtain the equivalent `uint64` value. For non-negative
|
|
123
|
-
`sint64` values, the decimal value remains unchanged in the `uint64`
|
|
124
|
-
representation.
|
|
125
|
-
"""
|
|
126
|
-
if s < 0:
|
|
127
|
-
return s + (1 << 64)
|
|
128
|
-
return s
|
|
129
|
-
|
|
130
|
-
|
|
131
80
|
def convert_uint64_values_in_dict_to_sint64(
|
|
132
81
|
data_dict: dict[str, int], keys: list[str]
|
|
133
82
|
) -> None:
|
|
@@ -142,7 +91,7 @@ def convert_uint64_values_in_dict_to_sint64(
|
|
|
142
91
|
"""
|
|
143
92
|
for key in keys:
|
|
144
93
|
if key in data_dict:
|
|
145
|
-
data_dict[key] =
|
|
94
|
+
data_dict[key] = uint64_to_int64(data_dict[key])
|
|
146
95
|
|
|
147
96
|
|
|
148
97
|
def convert_sint64_values_in_dict_to_uint64(
|
|
@@ -159,7 +108,7 @@ def convert_sint64_values_in_dict_to_uint64(
|
|
|
159
108
|
"""
|
|
160
109
|
for key in keys:
|
|
161
110
|
if key in data_dict:
|
|
162
|
-
data_dict[key] =
|
|
111
|
+
data_dict[key] = int64_to_uint64(data_dict[key])
|
|
163
112
|
|
|
164
113
|
|
|
165
114
|
def context_to_bytes(context: Context) -> bytes:
|
|
@@ -51,6 +51,7 @@ from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
|
|
|
51
51
|
from flwr.simulation.ray_transport.utils import (
|
|
52
52
|
enable_tf_gpu_growth as enable_gpu_growth,
|
|
53
53
|
)
|
|
54
|
+
from flwr.supercore.constant import FLWR_IN_MEMORY_DB_NAME
|
|
54
55
|
|
|
55
56
|
|
|
56
57
|
def _replace_keys(d: Any, match: str, target: str) -> Any:
|
|
@@ -336,7 +337,7 @@ def _main_loop(
|
|
|
336
337
|
) -> Context:
|
|
337
338
|
"""Start ServerApp on a separate thread, then launch Simulation Engine."""
|
|
338
339
|
# Initialize StateFactory
|
|
339
|
-
state_factory = LinkStateFactory(
|
|
340
|
+
state_factory = LinkStateFactory(FLWR_IN_MEMORY_DB_NAME)
|
|
340
341
|
|
|
341
342
|
f_stop = threading.Event()
|
|
342
343
|
# A Threading event to indicate if an exception was raised in the ServerApp thread
|
flwr/supercore/constant.py
CHANGED
|
@@ -20,6 +20,9 @@ from __future__ import annotations
|
|
|
20
20
|
# Top-level key in YAML config for exec plugin settings
|
|
21
21
|
EXEC_PLUGIN_SECTION = "exec_plugin"
|
|
22
22
|
|
|
23
|
+
# Flower in-memory Python-based database name
|
|
24
|
+
FLWR_IN_MEMORY_DB_NAME = ":flwr-in-memory:"
|
|
25
|
+
|
|
23
26
|
|
|
24
27
|
class NodeStatus:
|
|
25
28
|
"""Event log writer types."""
|
|
@@ -19,15 +19,27 @@ from logging import DEBUG
|
|
|
19
19
|
from typing import Optional
|
|
20
20
|
|
|
21
21
|
from flwr.common.logger import log
|
|
22
|
+
from flwr.supercore.constant import FLWR_IN_MEMORY_DB_NAME
|
|
22
23
|
|
|
23
24
|
from .in_memory_object_store import InMemoryObjectStore
|
|
24
25
|
from .object_store import ObjectStore
|
|
26
|
+
from .sqlite_object_store import SqliteObjectStore
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
class ObjectStoreFactory:
|
|
28
|
-
"""Factory class that creates ObjectStore instances.
|
|
30
|
+
"""Factory class that creates ObjectStore instances.
|
|
29
31
|
|
|
30
|
-
|
|
32
|
+
Parameters
|
|
33
|
+
----------
|
|
34
|
+
database : str (default: FLWR_IN_MEMORY_DB_NAME)
|
|
35
|
+
A string representing the path to the database file that will be opened.
|
|
36
|
+
Note that passing ":memory:" will open a connection to a database that is
|
|
37
|
+
in RAM, instead of on disk. And FLWR_IN_MEMORY_DB_NAME will create an
|
|
38
|
+
Python-based in-memory ObjectStore.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
def __init__(self, database: str = FLWR_IN_MEMORY_DB_NAME) -> None:
|
|
42
|
+
self.database = database
|
|
31
43
|
self.store_instance: Optional[ObjectStore] = None
|
|
32
44
|
|
|
33
45
|
def store(self) -> ObjectStore:
|
|
@@ -38,7 +50,15 @@ class ObjectStoreFactory:
|
|
|
38
50
|
ObjectStore
|
|
39
51
|
An ObjectStore instance for storing objects by object_id.
|
|
40
52
|
"""
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
53
|
+
# InMemoryObjectStore
|
|
54
|
+
if self.database == FLWR_IN_MEMORY_DB_NAME:
|
|
55
|
+
if self.store_instance is None:
|
|
56
|
+
self.store_instance = InMemoryObjectStore()
|
|
57
|
+
log(DEBUG, "Using InMemoryObjectStore")
|
|
58
|
+
return self.store_instance
|
|
59
|
+
|
|
60
|
+
# SqliteObjectStore
|
|
61
|
+
store = SqliteObjectStore(self.database)
|
|
62
|
+
store.initialize()
|
|
63
|
+
log(DEBUG, "Using SqliteObjectStore")
|
|
64
|
+
return store
|