flwr 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. flwr/cli/app.py +17 -1
  2. flwr/cli/auth_plugin/__init__.py +15 -6
  3. flwr/cli/auth_plugin/auth_plugin.py +95 -0
  4. flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
  6. flwr/cli/build.py +118 -47
  7. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
  8. flwr/cli/log.py +2 -2
  9. flwr/cli/login/login.py +34 -23
  10. flwr/cli/ls.py +13 -9
  11. flwr/cli/new/new.py +196 -42
  12. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  13. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  14. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  15. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  16. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  17. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  18. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  19. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  20. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  21. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  22. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  24. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  25. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  26. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  27. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  28. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  29. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  30. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  31. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  32. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  33. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  34. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  35. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  36. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  37. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  38. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  39. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  40. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  41. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  42. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  43. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  44. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  45. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  46. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  47. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  49. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  50. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  52. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
  53. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  54. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  55. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  56. flwr/cli/pull.py +100 -0
  57. flwr/cli/run/run.py +11 -7
  58. flwr/cli/stop.py +2 -2
  59. flwr/cli/supernode/__init__.py +25 -0
  60. flwr/cli/supernode/ls.py +260 -0
  61. flwr/cli/supernode/register.py +185 -0
  62. flwr/cli/supernode/unregister.py +138 -0
  63. flwr/cli/utils.py +109 -69
  64. flwr/client/__init__.py +2 -1
  65. flwr/client/grpc_adapter_client/connection.py +6 -8
  66. flwr/client/grpc_rere_client/connection.py +59 -31
  67. flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
  68. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
  69. flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
  70. flwr/client/rest_client/connection.py +82 -37
  71. flwr/clientapp/__init__.py +1 -2
  72. flwr/clientapp/mod/__init__.py +4 -1
  73. flwr/clientapp/mod/centraldp_mods.py +156 -40
  74. flwr/clientapp/mod/localdp_mod.py +169 -0
  75. flwr/clientapp/typing.py +22 -0
  76. flwr/{client/clientapp → clientapp}/utils.py +1 -1
  77. flwr/common/constant.py +56 -13
  78. flwr/common/exit/exit_code.py +24 -10
  79. flwr/common/inflatable_utils.py +10 -10
  80. flwr/common/record/array.py +3 -3
  81. flwr/common/record/arrayrecord.py +10 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  84. flwr/common/serde.py +4 -2
  85. flwr/common/typing.py +7 -6
  86. flwr/compat/client/app.py +1 -1
  87. flwr/compat/client/grpc_client/connection.py +2 -2
  88. flwr/proto/control_pb2.py +48 -31
  89. flwr/proto/control_pb2.pyi +95 -5
  90. flwr/proto/control_pb2_grpc.py +136 -0
  91. flwr/proto/control_pb2_grpc.pyi +52 -0
  92. flwr/proto/fab_pb2.py +11 -7
  93. flwr/proto/fab_pb2.pyi +21 -1
  94. flwr/proto/fleet_pb2.py +31 -23
  95. flwr/proto/fleet_pb2.pyi +63 -23
  96. flwr/proto/fleet_pb2_grpc.py +98 -28
  97. flwr/proto/fleet_pb2_grpc.pyi +45 -13
  98. flwr/proto/node_pb2.py +3 -1
  99. flwr/proto/node_pb2.pyi +48 -0
  100. flwr/server/app.py +152 -114
  101. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
  102. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
  103. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
  104. flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
  105. flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
  106. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +18 -5
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
  110. flwr/server/superlink/linkstate/linkstate.py +107 -24
  111. flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
  112. flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
  113. flwr/server/superlink/linkstate/utils.py +3 -54
  114. flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
  115. flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
  116. flwr/server/utils/validator.py +2 -3
  117. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
  118. flwr/serverapp/strategy/__init__.py +26 -0
  119. flwr/serverapp/strategy/bulyan.py +238 -0
  120. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  121. flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
  122. flwr/serverapp/strategy/fedadagrad.py +0 -3
  123. flwr/serverapp/strategy/fedadam.py +0 -3
  124. flwr/serverapp/strategy/fedavg.py +89 -64
  125. flwr/serverapp/strategy/fedavgm.py +198 -0
  126. flwr/serverapp/strategy/fedmedian.py +105 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +0 -3
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/strategy_utils.py +48 -0
  136. flwr/simulation/app.py +1 -1
  137. flwr/simulation/ray_transport/ray_actor.py +1 -1
  138. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  139. flwr/simulation/run_simulation.py +28 -32
  140. flwr/supercore/cli/flower_superexec.py +26 -1
  141. flwr/supercore/constant.py +41 -0
  142. flwr/supercore/object_store/in_memory_object_store.py +0 -4
  143. flwr/supercore/object_store/object_store_factory.py +26 -6
  144. flwr/supercore/object_store/sqlite_object_store.py +252 -0
  145. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  146. flwr/supercore/primitives/asymmetric.py +117 -0
  147. flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
  148. flwr/supercore/sqlite_mixin.py +156 -0
  149. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  150. flwr/supercore/superexec/run_superexec.py +16 -2
  151. flwr/supercore/utils.py +20 -0
  152. flwr/superlink/artifact_provider/__init__.py +22 -0
  153. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  154. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  155. flwr/superlink/auth_plugin/auth_plugin.py +91 -0
  156. flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
  157. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
  158. flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
  159. flwr/superlink/servicer/control/control_grpc.py +16 -11
  160. flwr/superlink/servicer/control/control_servicer.py +207 -58
  161. flwr/supernode/cli/flower_supernode.py +19 -26
  162. flwr/supernode/runtime/run_clientapp.py +2 -2
  163. flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
  164. flwr/supernode/start_client_internal.py +17 -9
  165. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
  166. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
  167. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  168. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  169. flwr/common/auth_plugin/auth_plugin.py +0 -149
  170. flwr/serverapp/dp_fixed_clipping.py +0 -352
  171. flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
  172. /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  173. /flwr/{client → clientapp}/client_app.py +0 -0
  174. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
  175. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
@@ -18,18 +18,16 @@
18
18
  # pylint: disable=too-many-lines
19
19
 
20
20
  import json
21
- import re
22
21
  import secrets
23
22
  import sqlite3
24
- import time
25
23
  from collections.abc import Sequence
26
- from logging import DEBUG, ERROR, WARNING
24
+ from logging import ERROR, WARNING
27
25
  from typing import Any, Optional, Union, cast
28
26
 
29
27
  from flwr.common import Context, Message, Metadata, log, now
30
28
  from flwr.common.constant import (
31
29
  FLWR_APP_TOKEN_LENGTH,
32
- HEARTBEAT_MAX_INTERVAL,
30
+ HEARTBEAT_INTERVAL_INF,
33
31
  HEARTBEAT_PATIENCE,
34
32
  MESSAGE_TTL_TOLERANCE,
35
33
  NODE_ID_NUM_BYTES,
@@ -47,10 +45,14 @@ from flwr.common.typing import Run, RunStatus, UserConfig
47
45
 
48
46
  # pylint: disable=E0611
49
47
  from flwr.proto.error_pb2 import Error as ProtoError
48
+ from flwr.proto.node_pb2 import NodeInfo
50
49
  from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
51
50
 
52
51
  # pylint: enable=E0611
53
52
  from flwr.server.utils.validator import validate_message
53
+ from flwr.supercore.constant import NodeStatus
54
+ from flwr.supercore.sqlite_mixin import SqliteMixin
55
+ from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
54
56
 
55
57
  from .linkstate import LinkState
56
58
  from .utils import (
@@ -59,9 +61,7 @@ from .utils import (
59
61
  configrecord_to_bytes,
60
62
  context_from_bytes,
61
63
  context_to_bytes,
62
- convert_sint64_to_uint64,
63
64
  convert_sint64_values_in_dict_to_uint64,
64
- convert_uint64_to_sint64,
65
65
  convert_uint64_values_in_dict_to_sint64,
66
66
  generate_rand_int_from_bytes,
67
67
  has_valid_sub_status,
@@ -72,10 +72,16 @@ from .utils import (
72
72
 
73
73
  SQL_CREATE_TABLE_NODE = """
74
74
  CREATE TABLE IF NOT EXISTS node(
75
- node_id INTEGER UNIQUE,
76
- online_until REAL,
77
- heartbeat_interval REAL,
78
- public_key BLOB
75
+ node_id INTEGER UNIQUE,
76
+ owner_aid TEXT,
77
+ status TEXT,
78
+ registered_at TEXT,
79
+ last_activated_at TEXT NULL,
80
+ last_deactivated_at TEXT NULL,
81
+ unregistered_at TEXT NULL,
82
+ online_until TIMESTAMP NULL,
83
+ heartbeat_interval REAL,
84
+ public_key BLOB UNIQUE
79
85
  );
80
86
  """
81
87
 
@@ -89,6 +95,14 @@ SQL_CREATE_INDEX_ONLINE_UNTIL = """
89
95
  CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
90
96
  """
91
97
 
98
+ SQL_CREATE_INDEX_OWNER_AID = """
99
+ CREATE INDEX IF NOT EXISTS idx_node_owner_aid ON node(owner_aid);
100
+ """
101
+
102
+ SQL_CREATE_INDEX_NODE_STATUS = """
103
+ CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
104
+ """
105
+
92
106
  SQL_CREATE_TABLE_RUN = """
93
107
  CREATE TABLE IF NOT EXISTS run(
94
108
  run_id INTEGER UNIQUE,
@@ -172,94 +186,26 @@ CREATE TABLE IF NOT EXISTS token_store (
172
186
  );
173
187
  """
174
188
 
175
- DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
176
189
 
177
-
178
- class SqliteLinkState(LinkState): # pylint: disable=R0904
190
+ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
179
191
  """SQLite-based LinkState implementation."""
180
192
 
181
- def __init__(
182
- self,
183
- database_path: str,
184
- ) -> None:
185
- """Initialize an SqliteLinkState.
186
-
187
- Parameters
188
- ----------
189
- database : (path-like object)
190
- The path to the database file to be opened. Pass ":memory:" to open
191
- a connection to a database that is in RAM, instead of on disk.
192
- """
193
- self.database_path = database_path
194
- self.conn: Optional[sqlite3.Connection] = None
195
-
196
193
  def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
197
- """Create tables if they don't exist yet.
198
-
199
- Parameters
200
- ----------
201
- log_queries : bool
202
- Log each query which is executed.
203
-
204
- Returns
205
- -------
206
- list[tuple[str]]
207
- The list of all tables in the DB.
208
- """
209
- self.conn = sqlite3.connect(self.database_path)
210
- self.conn.execute("PRAGMA foreign_keys = ON;")
211
- self.conn.row_factory = dict_factory
212
- if log_queries:
213
- self.conn.set_trace_callback(lambda query: log(DEBUG, query))
214
- cur = self.conn.cursor()
215
-
216
- # Create each table if not exists queries
217
- cur.execute(SQL_CREATE_TABLE_RUN)
218
- cur.execute(SQL_CREATE_TABLE_LOGS)
219
- cur.execute(SQL_CREATE_TABLE_CONTEXT)
220
- cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
221
- cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
222
- cur.execute(SQL_CREATE_TABLE_NODE)
223
- cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
224
- cur.execute(SQL_CREATE_TABLE_TOKEN_STORE)
225
- cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
226
- res = cur.execute("SELECT name FROM sqlite_schema;")
227
- return res.fetchall()
228
-
229
- def query(
230
- self,
231
- query: str,
232
- data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
233
- ) -> list[dict[str, Any]]:
234
- """Execute a SQL query."""
235
- if self.conn is None:
236
- raise AttributeError("LinkState is not initialized.")
237
-
238
- if data is None:
239
- data = []
240
-
241
- # Clean up whitespace to make the logs nicer
242
- query = re.sub(r"\s+", " ", query)
243
-
244
- try:
245
- with self.conn:
246
- if (
247
- len(data) > 0
248
- and isinstance(data, (tuple, list))
249
- and isinstance(data[0], (tuple, dict))
250
- ):
251
- rows = self.conn.executemany(query, data)
252
- else:
253
- rows = self.conn.execute(query, data)
254
-
255
- # Extract results before committing to support
256
- # INSERT/UPDATE ... RETURNING
257
- # style queries
258
- result = rows.fetchall()
259
- except KeyError as exc:
260
- log(ERROR, {"query": query, "data": data, "exception": exc})
261
-
262
- return result
194
+ """Connect to the DB, enable FK support, and create tables if needed."""
195
+ return self._ensure_initialized(
196
+ SQL_CREATE_TABLE_RUN,
197
+ SQL_CREATE_TABLE_LOGS,
198
+ SQL_CREATE_TABLE_CONTEXT,
199
+ SQL_CREATE_TABLE_MESSAGE_INS,
200
+ SQL_CREATE_TABLE_MESSAGE_RES,
201
+ SQL_CREATE_TABLE_NODE,
202
+ SQL_CREATE_TABLE_PUBLIC_KEY,
203
+ SQL_CREATE_TABLE_TOKEN_STORE,
204
+ SQL_CREATE_INDEX_ONLINE_UNTIL,
205
+ SQL_CREATE_INDEX_OWNER_AID,
206
+ SQL_CREATE_INDEX_NODE_STATUS,
207
+ log_queries=log_queries,
208
+ )
263
209
 
264
210
  def store_message_ins(self, message: Message) -> Optional[str]:
265
211
  """Store one Message."""
@@ -293,8 +239,10 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
293
239
  return None
294
240
 
295
241
  # Validate destination node ID
296
- query = "SELECT node_id FROM node WHERE node_id = ?;"
297
- if not self.query(query, (data[0]["dst_node_id"],)):
242
+ query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
243
+ if not self.query(
244
+ query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
245
+ ):
298
246
  log(
299
247
  ERROR,
300
248
  "Invalid destination node ID for Message: %s",
@@ -323,7 +271,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
323
271
  data: dict[str, Union[str, int]] = {}
324
272
 
325
273
  # Convert the uint64 value to sint64 for SQLite
326
- data["node_id"] = convert_uint64_to_sint64(node_id)
274
+ data["node_id"] = uint64_to_int64(node_id)
327
275
 
328
276
  # Retrieve all Messages for node_id
329
277
  query = """
@@ -398,8 +346,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
398
346
  if (
399
347
  msg_ins
400
348
  and message
401
- and convert_sint64_to_uint64(msg_ins["dst_node_id"])
402
- != res_metadata.src_node_id
349
+ and int64_to_uint64(msg_ins["dst_node_id"]) != res_metadata.src_node_id
403
350
  ):
404
351
  return None
405
352
 
@@ -451,7 +398,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
451
398
  ret: dict[str, Message] = {}
452
399
 
453
400
  # Verify Message IDs
454
- current = time.time()
401
+ current = now().timestamp()
455
402
  query = f"""
456
403
  SELECT *
457
404
  FROM message_ins
@@ -475,20 +422,20 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
475
422
  dst_node_ids: set[int] = set()
476
423
  for message_id in message_ids:
477
424
  in_message = found_message_ins_dict[message_id]
478
- sint_node_id = convert_uint64_to_sint64(in_message.metadata.dst_node_id)
425
+ sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
479
426
  dst_node_ids.add(sint_node_id)
480
427
  query = f"""
481
- SELECT node_id, online_until
482
- FROM node
483
- WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))});
484
- """
485
- rows = self.query(query, tuple(dst_node_ids))
428
+ SELECT node_id, online_until
429
+ FROM node
430
+ WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))})
431
+ AND status != ?
432
+ """
433
+ rows = self.query(query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,))
486
434
  tmp_ret_dict = check_node_availability_for_in_message(
487
435
  inquired_in_message_ids=message_ids,
488
436
  found_in_message_dict=found_message_ins_dict,
489
437
  node_id_to_online_until={
490
- convert_sint64_to_uint64(row["node_id"]): row["online_until"]
491
- for row in rows
438
+ int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
492
439
  },
493
440
  current_time=current,
494
441
  )
@@ -589,7 +536,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
589
536
  WHERE run_id = :run_id;
590
537
  """
591
538
 
592
- sint64_run_id = convert_uint64_to_sint64(run_id)
539
+ sint64_run_id = uint64_to_int64(run_id)
593
540
  data = {"run_id": sint64_run_id}
594
541
 
595
542
  with self.conn:
@@ -597,7 +544,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
597
544
 
598
545
  return {row["message_id"] for row in rows}
599
546
 
600
- def create_node(self, heartbeat_interval: float) -> int:
547
+ def create_node(
548
+ self, owner_aid: str, public_key: bytes, heartbeat_interval: float
549
+ ) -> int:
601
550
  """Create, store in the link state, and return `node_id`."""
602
551
  # Sample a random uint64 as node_id
603
552
  uint64_node_id = generate_rand_int_from_bytes(
@@ -605,50 +554,126 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
605
554
  )
606
555
 
607
556
  # Convert the uint64 value to sint64 for SQLite
608
- sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
557
+ sint64_node_id = uint64_to_int64(uint64_node_id)
609
558
 
610
- query = (
611
- "INSERT INTO node "
612
- "(node_id, online_until, heartbeat_interval, public_key) "
613
- "VALUES (?, ?, ?, ?)"
614
- )
559
+ query = """
560
+ INSERT INTO node
561
+ (node_id, owner_aid, status, registered_at, last_activated_at,
562
+ last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
563
+ public_key)
564
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
565
+ """
615
566
 
616
- # Mark the node online util time.time() + heartbeat_interval
567
+ # Mark the node online until now().timestamp() + heartbeat_interval
617
568
  try:
618
569
  self.query(
619
570
  query,
620
571
  (
621
- sint64_node_id,
622
- time.time() + heartbeat_interval,
623
- heartbeat_interval,
624
- b"", # Initialize with an empty public key
572
+ sint64_node_id, # node_id
573
+ owner_aid, # owner_aid
574
+ NodeStatus.REGISTERED, # status
575
+ now().isoformat(), # registered_at
576
+ None, # last_activated_at
577
+ None, # last_deactivated_at
578
+ None, # unregistered_at
579
+ None, # online_until, initialized with offline status
580
+ heartbeat_interval, # heartbeat_interval
581
+ public_key, # public_key
625
582
  ),
626
583
  )
627
- except sqlite3.IntegrityError:
584
+ except sqlite3.IntegrityError as e:
585
+ if "UNIQUE constraint failed: node.public_key" in str(e):
586
+ raise ValueError("Public key already in use.") from None
587
+ # Must be node ID conflict, almost impossible unless system is compromised
628
588
  log(ERROR, "Unexpected node registration failure.")
629
589
  return 0
630
590
 
631
591
  # Note: we need to return the uint64 value of the node_id
632
592
  return uint64_node_id
633
593
 
634
- def delete_node(self, node_id: int) -> None:
594
+ def delete_node(self, owner_aid: str, node_id: int) -> None:
635
595
  """Delete a node."""
636
- # Convert the uint64 value to sint64 for SQLite
637
- sint64_node_id = convert_uint64_to_sint64(node_id)
596
+ sint64_node_id = uint64_to_int64(node_id)
638
597
 
639
- query = "DELETE FROM node WHERE node_id = ?"
640
- params = (sint64_node_id,)
598
+ query = """
599
+ UPDATE node
600
+ SET status = ?, unregistered_at = ?,
601
+ online_until = IIF(online_until > ?, ?, online_until)
602
+ WHERE node_id = ? AND status != ? AND owner_aid = ?
603
+ RETURNING node_id
604
+ """
605
+ current = now()
606
+ params = (
607
+ NodeStatus.UNREGISTERED,
608
+ current.isoformat(),
609
+ current.timestamp(),
610
+ current.timestamp(),
611
+ sint64_node_id,
612
+ NodeStatus.UNREGISTERED,
613
+ owner_aid,
614
+ )
641
615
 
642
- if self.conn is None:
643
- raise AttributeError("LinkState is not initialized.")
616
+ rows = self.query(query, params)
617
+ if not rows:
618
+ raise ValueError(
619
+ f"Node {node_id} already deleted, not found or unauthorized "
620
+ "deletion attempt."
621
+ )
644
622
 
645
- try:
646
- with self.conn:
647
- rows = self.conn.execute(query, params)
648
- if rows.rowcount < 1:
649
- raise ValueError(f"Node {node_id} not found")
650
- except KeyError as exc:
651
- log(ERROR, {"query": query, "data": params, "exception": exc})
623
+ def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
624
+ """Activate the node with the specified `node_id`."""
625
+ with self.conn:
626
+ self._check_and_tag_offline_nodes([node_id])
627
+
628
+ # Only activate if the node is currently registered or offline
629
+ current_dt = now()
630
+ query = """
631
+ UPDATE node
632
+ SET status = ?,
633
+ last_activated_at = ?,
634
+ online_until = ?,
635
+ heartbeat_interval = ?
636
+ WHERE node_id = ? AND status in (?, ?)
637
+ RETURNING node_id
638
+ """
639
+ params = (
640
+ NodeStatus.ONLINE,
641
+ current_dt.isoformat(),
642
+ current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
643
+ heartbeat_interval,
644
+ uint64_to_int64(node_id),
645
+ NodeStatus.REGISTERED,
646
+ NodeStatus.OFFLINE,
647
+ )
648
+
649
+ row = self.conn.execute(query, params).fetchone()
650
+ return row is not None
651
+
652
+ def deactivate_node(self, node_id: int) -> bool:
653
+ """Deactivate the node with the specified `node_id`."""
654
+ with self.conn:
655
+ self._check_and_tag_offline_nodes([node_id])
656
+
657
+ # Only deactivate if the node is currently online
658
+ current_dt = now()
659
+ query = """
660
+ UPDATE node
661
+ SET status = ?,
662
+ last_deactivated_at = ?,
663
+ online_until = ?
664
+ WHERE node_id = ? AND status = ?
665
+ RETURNING node_id
666
+ """
667
+ params = (
668
+ NodeStatus.OFFLINE,
669
+ current_dt.isoformat(),
670
+ current_dt.timestamp(),
671
+ uint64_to_int64(node_id),
672
+ NodeStatus.ONLINE,
673
+ )
674
+
675
+ row = self.conn.execute(query, params).fetchone()
676
+ return row is not None
652
677
 
653
678
  def get_nodes(self, run_id: int) -> set[int]:
654
679
  """Retrieve all currently stored node IDs as a set.
@@ -658,69 +683,117 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
658
683
  If the provided `run_id` does not exist or has no matching nodes,
659
684
  an empty `Set` MUST be returned.
660
685
  """
686
+ if self.conn is None:
687
+ raise AttributeError("LinkState not initialized")
688
+
661
689
  # Convert the uint64 value to sint64 for SQLite
662
- sint64_run_id = convert_uint64_to_sint64(run_id)
690
+ sint64_run_id = uint64_to_int64(run_id)
663
691
 
664
692
  # Validate run ID
665
- query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
666
- if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
693
+ query = "SELECT COUNT(*) FROM run WHERE run_id = ?"
694
+ rows = self.query(query, (sint64_run_id,))
695
+ if rows[0]["COUNT(*)"] == 0:
667
696
  return set()
668
697
 
669
- # Get nodes
670
- query = "SELECT node_id FROM node WHERE online_until > ?;"
671
- rows = self.query(query, (time.time(),))
672
-
673
- # Convert sint64 node_ids to uint64
674
- result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
675
- return result
676
-
677
- def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
678
- """Set `public_key` for the specified `node_id`."""
679
- # Convert the uint64 value to sint64 for SQLite
680
- sint64_node_id = convert_uint64_to_sint64(node_id)
681
-
682
- # Check if the node exists in the `node` table
683
- query = "SELECT 1 FROM node WHERE node_id = ?"
684
- if not self.query(query, (sint64_node_id,)):
685
- raise ValueError(f"Node {node_id} not found")
686
-
687
- # Check if the public key is already in use in the `node` table
688
- query = "SELECT 1 FROM node WHERE public_key = ?"
689
- if self.query(query, (public_key,)):
690
- raise ValueError("Public key already in use")
698
+ # Retrieve all online nodes
699
+ return {
700
+ node.node_id for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
701
+ }
691
702
 
692
- # Update the `node` table to set the public key for the given node ID
693
- query = "UPDATE node SET public_key = ? WHERE node_id = ?"
694
- self.query(query, (public_key, sint64_node_id))
703
+ def _check_and_tag_offline_nodes(
704
+ self, node_ids: Optional[list[int]] = None
705
+ ) -> None:
706
+ """Check and tag offline nodes."""
707
+ # strftime will convert POSIX timestamp to ISO format
708
+ query = """
709
+ UPDATE node SET status = ?,
710
+ last_deactivated_at =
711
+ strftime("%Y-%m-%dT%H:%M:%f+00:00", online_until, "unixepoch")
712
+ WHERE online_until <= ? AND status == ?
713
+ """
714
+ params = [
715
+ NodeStatus.OFFLINE,
716
+ now().timestamp(),
717
+ NodeStatus.ONLINE,
718
+ ]
719
+ if node_ids is not None:
720
+ placeholders = ",".join(["?"] * len(node_ids))
721
+ query += f" AND node_id IN ({placeholders})"
722
+ params.extend(uint64_to_int64(node_id) for node_id in node_ids)
723
+ self.conn.execute(query, params)
695
724
 
696
- def get_node_public_key(self, node_id: int) -> Optional[bytes]:
725
+ def get_node_info(
726
+ self,
727
+ *,
728
+ node_ids: Optional[Sequence[int]] = None,
729
+ owner_aids: Optional[Sequence[str]] = None,
730
+ statuses: Optional[Sequence[str]] = None,
731
+ ) -> Sequence[NodeInfo]:
732
+ """Retrieve information about nodes based on the specified filters."""
733
+ with self.conn:
734
+ self._check_and_tag_offline_nodes()
735
+
736
+ # Build the WHERE clause based on provided filters
737
+ conditions = []
738
+ params: list[Any] = []
739
+ if node_ids is not None:
740
+ sint64_node_ids = [uint64_to_int64(node_id) for node_id in node_ids]
741
+ placeholders = ",".join(["?"] * len(sint64_node_ids))
742
+ conditions.append(f"node_id IN ({placeholders})")
743
+ params.extend(sint64_node_ids)
744
+ if owner_aids is not None:
745
+ placeholders = ",".join(["?"] * len(owner_aids))
746
+ conditions.append(f"owner_aid IN ({placeholders})")
747
+ params.extend(owner_aids)
748
+ if statuses is not None:
749
+ placeholders = ",".join(["?"] * len(statuses))
750
+ conditions.append(f"status IN ({placeholders})")
751
+ params.extend(statuses)
752
+
753
+ # Construct the final query
754
+ query = "SELECT * FROM node"
755
+ if conditions:
756
+ query += " WHERE " + " AND ".join(conditions)
757
+
758
+ rows = self.conn.execute(query, params).fetchall()
759
+
760
+ result: list[NodeInfo] = []
761
+ for row in rows:
762
+ # Convert sint64 node_id to uint64
763
+ row["node_id"] = int64_to_uint64(row["node_id"])
764
+ result.append(NodeInfo(**row))
765
+
766
+ return result
767
+
768
+ def get_node_public_key(self, node_id: int) -> bytes:
697
769
  """Get `public_key` for the specified `node_id`."""
698
770
  # Convert the uint64 value to sint64 for SQLite
699
- sint64_node_id = convert_uint64_to_sint64(node_id)
771
+ sint64_node_id = uint64_to_int64(node_id)
700
772
 
701
773
  # Query the public key for the given node_id
702
- query = "SELECT public_key FROM node WHERE node_id = ?"
703
- rows = self.query(query, (sint64_node_id,))
774
+ query = "SELECT public_key FROM node WHERE node_id = ? AND status != ?;"
775
+ rows = self.query(query, (sint64_node_id, NodeStatus.UNREGISTERED))
704
776
 
705
777
  # If no result is found, return None
706
778
  if not rows:
707
- raise ValueError(f"Node {node_id} not found")
779
+ raise ValueError(f"Node ID {node_id} not found")
708
780
 
709
- # Return the public key if it is not empty, otherwise return None
710
- return rows[0]["public_key"] or None
781
+ # Return the public key
782
+ return cast(bytes, rows[0]["public_key"])
711
783
 
712
- def get_node_id(self, node_public_key: bytes) -> Optional[int]:
713
- """Retrieve stored `node_id` filtered by `node_public_keys`."""
714
- query = "SELECT node_id FROM node WHERE public_key = :public_key;"
715
- row = self.query(query, {"public_key": node_public_key})
716
- if len(row) > 0:
717
- node_id: int = row[0]["node_id"]
784
+ def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
785
+ """Get `node_id` for the specified `public_key` if it exists and is not
786
+ deleted."""
787
+ query = "SELECT node_id FROM node WHERE public_key = ? AND status != ?;"
788
+ rows = self.query(query, (public_key, NodeStatus.UNREGISTERED))
718
789
 
719
- # Convert the sint64 value to uint64 after reading from SQLite
720
- uint64_node_id = convert_sint64_to_uint64(node_id)
790
+ # If no result is found, return None
791
+ if not rows:
792
+ return None
721
793
 
722
- return uint64_node_id
723
- return None
794
+ # Convert sint64 node_id to uint64
795
+ node_id = int64_to_uint64(rows[0]["node_id"])
796
+ return node_id
724
797
 
725
798
  # pylint: disable=too-many-arguments,too-many-positional-arguments
726
799
  def create_run(
@@ -737,7 +810,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
737
810
  uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
738
811
 
739
812
  # Convert the uint64 value to sint64 for SQLite
740
- sint64_run_id = convert_uint64_to_sint64(uint64_run_id)
813
+ sint64_run_id = uint64_to_int64(uint64_run_id)
741
814
 
742
815
  # Check conflicts
743
816
  query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
@@ -773,28 +846,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
773
846
  log(ERROR, "Unexpected run creation failure.")
774
847
  return 0
775
848
 
776
- def clear_supernode_auth_keys(self) -> None:
777
- """Clear stored `node_public_keys` in the link state if any."""
778
- self.query("DELETE FROM public_key;")
779
-
780
- def store_node_public_keys(self, public_keys: set[bytes]) -> None:
781
- """Store a set of `node_public_keys` in the link state."""
782
- query = "INSERT INTO public_key (public_key) VALUES (?)"
783
- data = [(key,) for key in public_keys]
784
- self.query(query, data)
785
-
786
- def store_node_public_key(self, public_key: bytes) -> None:
787
- """Store a `node_public_key` in the link state."""
788
- query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
789
- self.query(query, {"public_key": public_key})
790
-
791
- def get_node_public_keys(self) -> set[bytes]:
792
- """Retrieve all currently stored `node_public_keys` as a set."""
793
- query = "SELECT public_key FROM public_key"
794
- rows = self.query(query)
795
- result: set[bytes] = {row["public_key"] for row in rows}
796
- return result
797
-
798
849
  def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
799
850
  """Retrieve all run IDs if `flwr_aid` is not specified.
800
851
 
@@ -807,7 +858,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
807
858
  )
808
859
  else:
809
860
  rows = self.query("SELECT run_id FROM run;", ())
810
- return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
861
+ return {int64_to_uint64(row["run_id"]) for row in rows}
811
862
 
812
863
  def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
813
864
  """Check if any runs are no longer active.
@@ -815,7 +866,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
815
866
  Marks runs with status 'starting' or 'running' as failed
816
867
  if they have not sent a heartbeat before `active_until`.
817
868
  """
818
- sint_run_ids = [convert_uint64_to_sint64(run_id) for run_id in run_ids]
869
+ sint_run_ids = [uint64_to_int64(run_id) for run_id in run_ids]
819
870
  query = "UPDATE run SET finished_at = ?, sub_status = ?, details = ? "
820
871
  query += "WHERE starting_at != '' AND finished_at = '' AND active_until < ?"
821
872
  query += f" AND run_id IN ({','.join(['?'] * len(run_ids))});"
@@ -837,13 +888,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
837
888
  self._check_and_tag_inactive_run(run_ids={run_id})
838
889
 
839
890
  # Convert the uint64 value to sint64 for SQLite
840
- sint64_run_id = convert_uint64_to_sint64(run_id)
891
+ sint64_run_id = uint64_to_int64(run_id)
841
892
  query = "SELECT * FROM run WHERE run_id = ?;"
842
893
  rows = self.query(query, (sint64_run_id,))
843
894
  if rows:
844
895
  row = rows[0]
845
896
  return Run(
846
- run_id=convert_sint64_to_uint64(row["run_id"]),
897
+ run_id=int64_to_uint64(row["run_id"]),
847
898
  fab_id=row["fab_id"],
848
899
  fab_version=row["fab_version"],
849
900
  fab_hash=row["fab_hash"],
@@ -868,13 +919,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
868
919
  self._check_and_tag_inactive_run(run_ids=run_ids)
869
920
 
870
921
  # Convert the uint64 value to sint64 for SQLite
871
- sint64_run_ids = (convert_uint64_to_sint64(run_id) for run_id in set(run_ids))
922
+ sint64_run_ids = (uint64_to_int64(run_id) for run_id in set(run_ids))
872
923
  query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
873
924
  rows = self.query(query, tuple(sint64_run_ids))
874
925
 
875
926
  return {
876
927
  # Restore uint64 run IDs
877
- convert_sint64_to_uint64(row["run_id"]): RunStatus(
928
+ int64_to_uint64(row["run_id"]): RunStatus(
878
929
  status=determine_run_status(row),
879
930
  sub_status=row["sub_status"],
880
931
  details=row["details"],
@@ -888,7 +939,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
888
939
  self._check_and_tag_inactive_run(run_ids={run_id})
889
940
 
890
941
  # Convert the uint64 value to sint64 for SQLite
891
- sint64_run_id = convert_uint64_to_sint64(run_id)
942
+ sint64_run_id = uint64_to_int64(run_id)
892
943
  query = "SELECT * FROM run WHERE run_id = ?;"
893
944
  rows = self.query(query, (sint64_run_id,))
894
945
 
@@ -933,7 +984,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
933
984
  # when switching to starting or running
934
985
  current = now()
935
986
  if new_status.status in (Status.STARTING, Status.RUNNING):
936
- heartbeat_interval = HEARTBEAT_MAX_INTERVAL
987
+ heartbeat_interval = HEARTBEAT_INTERVAL_INF
937
988
  active_until = current.timestamp() + heartbeat_interval
938
989
  else:
939
990
  heartbeat_interval = 0
@@ -954,7 +1005,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
954
1005
  new_status.details,
955
1006
  active_until,
956
1007
  heartbeat_interval,
957
- convert_uint64_to_sint64(run_id),
1008
+ uint64_to_int64(run_id),
958
1009
  )
959
1010
  self.query(query % timestamp_fld, data)
960
1011
  return True
@@ -967,14 +1018,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
967
1018
  query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
968
1019
  rows = self.query(query)
969
1020
  if rows:
970
- pending_run_id = convert_sint64_to_uint64(rows[0]["run_id"])
1021
+ pending_run_id = int64_to_uint64(rows[0]["run_id"])
971
1022
 
972
1023
  return pending_run_id
973
1024
 
974
1025
  def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
975
1026
  """Retrieve the federation options for the specified `run_id`."""
976
1027
  # Convert the uint64 value to sint64 for SQLite
977
- sint64_run_id = convert_uint64_to_sint64(run_id)
1028
+ sint64_run_id = uint64_to_int64(run_id)
978
1029
  query = "SELECT federation_options FROM run WHERE run_id = ?;"
979
1030
  rows = self.query(query, (sint64_run_id,))
980
1031
 
@@ -996,26 +1047,38 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
996
1047
  HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
997
1048
  the node is marked as offline.
998
1049
  """
999
- sint64_node_id = convert_uint64_to_sint64(node_id)
1050
+ if self.conn is None:
1051
+ raise AttributeError("LinkState not initialized")
1000
1052
 
1001
- # Check if the node exists in the `node` table
1002
- query = "SELECT 1 FROM node WHERE node_id = ?"
1003
- if not self.query(query, (sint64_node_id,)):
1004
- return False
1053
+ sint64_node_id = uint64_to_int64(node_id)
1005
1054
 
1006
- # Update `online_until` and `heartbeat_interval` for the given `node_id`
1007
- query = (
1008
- "UPDATE node SET online_until = ?, heartbeat_interval = ? WHERE node_id = ?"
1009
- )
1010
- self.query(
1011
- query,
1012
- (
1013
- time.time() + HEARTBEAT_PATIENCE * heartbeat_interval,
1055
+ with self.conn:
1056
+ # Check if node exists and not deleted
1057
+ query = "SELECT status FROM node WHERE node_id = ? AND status != ?"
1058
+ row = self.conn.execute(
1059
+ query, (sint64_node_id, NodeStatus.UNREGISTERED)
1060
+ ).fetchone()
1061
+ if row is None:
1062
+ return False
1063
+
1064
+ # Construct query and params
1065
+ current_dt = now()
1066
+ query = "UPDATE node SET online_until = ?, heartbeat_interval = ?"
1067
+ params: list[Any] = [
1068
+ current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
1014
1069
  heartbeat_interval,
1015
- sint64_node_id,
1016
- ),
1017
- )
1018
- return True
1070
+ ]
1071
+
1072
+ # Set timestamp if the status changes
1073
+ if row["status"] != NodeStatus.ONLINE:
1074
+ query += ", status = ?, last_activated_at = ?"
1075
+ params += [NodeStatus.ONLINE, current_dt.isoformat()]
1076
+
1077
+ # Execute the query, refreshing `online_until` and `heartbeat_interval`
1078
+ query += " WHERE node_id = ?"
1079
+ params += [sint64_node_id]
1080
+ self.conn.execute(query, params)
1081
+ return True
1019
1082
 
1020
1083
  def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
1021
1084
  """Acknowledge a heartbeat received from a ServerApp for a given run.
@@ -1029,7 +1092,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1029
1092
  self._check_and_tag_inactive_run(run_ids={run_id})
1030
1093
 
1031
1094
  # Search for the run
1032
- sint_run_id = convert_uint64_to_sint64(run_id)
1095
+ sint_run_id = uint64_to_int64(run_id)
1033
1096
  query = "SELECT * FROM run WHERE run_id = ?;"
1034
1097
  rows = self.query(query, (sint_run_id,))
1035
1098
 
@@ -1059,7 +1122,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1059
1122
  """Get the context for the specified `run_id`."""
1060
1123
  # Retrieve context if any
1061
1124
  query = "SELECT context FROM context WHERE run_id = ?;"
1062
- rows = self.query(query, (convert_uint64_to_sint64(run_id),))
1125
+ rows = self.query(query, (uint64_to_int64(run_id),))
1063
1126
  context = context_from_bytes(rows[0]["context"]) if rows else None
1064
1127
  return context
1065
1128
 
@@ -1067,7 +1130,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1067
1130
  """Set the context for the specified `run_id`."""
1068
1131
  # Convert context to bytes
1069
1132
  context_bytes = context_to_bytes(context)
1070
- sint_run_id = convert_uint64_to_sint64(run_id)
1133
+ sint_run_id = uint64_to_int64(run_id)
1071
1134
 
1072
1135
  # Check if any existing Context assigned to the run_id
1073
1136
  query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
@@ -1086,7 +1149,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1086
1149
  def add_serverapp_log(self, run_id: int, log_message: str) -> None:
1087
1150
  """Add a log entry to the ServerApp logs for the specified `run_id`."""
1088
1151
  # Convert the uint64 value to sint64 for SQLite
1089
- sint64_run_id = convert_uint64_to_sint64(run_id)
1152
+ sint64_run_id = uint64_to_int64(run_id)
1090
1153
 
1091
1154
  # Store log
1092
1155
  try:
@@ -1102,7 +1165,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1102
1165
  ) -> tuple[str, float]:
1103
1166
  """Get the ServerApp logs for the specified `run_id`."""
1104
1167
  # Convert the uint64 value to sint64 for SQLite
1105
- sint64_run_id = convert_uint64_to_sint64(run_id)
1168
+ sint64_run_id = uint64_to_int64(run_id)
1106
1169
 
1107
1170
  # Check if the run_id exists
1108
1171
  query = "SELECT run_id FROM run WHERE run_id = ?;"
@@ -1140,7 +1203,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1140
1203
  message_ins = rows[0]
1141
1204
  created_at = message_ins["created_at"]
1142
1205
  ttl = message_ins["ttl"]
1143
- current_time = time.time()
1206
+ current_time = now().timestamp()
1144
1207
 
1145
1208
  # Check if Message is expired
1146
1209
  if ttl is not None and created_at + ttl <= current_time:
@@ -1152,7 +1215,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1152
1215
  """Create a token for the given run ID."""
1153
1216
  token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
1154
1217
  query = "INSERT INTO token_store (run_id, token) VALUES (:run_id, :token);"
1155
- data = {"run_id": convert_uint64_to_sint64(run_id), "token": token}
1218
+ data = {"run_id": uint64_to_int64(run_id), "token": token}
1156
1219
  try:
1157
1220
  self.query(query, data)
1158
1221
  except sqlite3.IntegrityError:
@@ -1162,7 +1225,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1162
1225
  def verify_token(self, run_id: int, token: str) -> bool:
1163
1226
  """Verify a token for the given run ID."""
1164
1227
  query = "SELECT token FROM token_store WHERE run_id = :run_id;"
1165
- data = {"run_id": convert_uint64_to_sint64(run_id)}
1228
+ data = {"run_id": uint64_to_int64(run_id)}
1166
1229
  rows = self.query(query, data)
1167
1230
  if not rows:
1168
1231
  return False
@@ -1171,7 +1234,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1171
1234
  def delete_token(self, run_id: int) -> None:
1172
1235
  """Delete the token for the given run ID."""
1173
1236
  query = "DELETE FROM token_store WHERE run_id = :run_id;"
1174
- data = {"run_id": convert_uint64_to_sint64(run_id)}
1237
+ data = {"run_id": uint64_to_int64(run_id)}
1175
1238
  self.query(query, data)
1176
1239
 
1177
1240
  def get_run_id_by_token(self, token: str) -> Optional[int]:
@@ -1181,19 +1244,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1181
1244
  rows = self.query(query, data)
1182
1245
  if not rows:
1183
1246
  return None
1184
- return convert_sint64_to_uint64(rows[0]["run_id"])
1185
-
1186
-
1187
- def dict_factory(
1188
- cursor: sqlite3.Cursor,
1189
- row: sqlite3.Row,
1190
- ) -> dict[str, Any]:
1191
- """Turn SQLite results into dicts.
1192
-
1193
- Less efficent for retrival of large amounts of data but easier to use.
1194
- """
1195
- fields = [column[0] for column in cursor.description]
1196
- return dict(zip(fields, row))
1247
+ return int64_to_uint64(rows[0]["run_id"])
1197
1248
 
1198
1249
 
1199
1250
  def message_to_dict(message: Message) -> dict[str, Any]:
@@ -1248,5 +1299,5 @@ def determine_run_status(row: dict[str, Any]) -> str:
1248
1299
  return Status.RUNNING
1249
1300
  return Status.STARTING
1250
1301
  return Status.PENDING
1251
- run_id = convert_sint64_to_uint64(row["run_id"])
1302
+ run_id = int64_to_uint64(row["run_id"])
1252
1303
  raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")