flwr 1.25.0__py3-none-any.whl → 1.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140) hide show
  1. flwr/__init__.py +1 -1
  2. flwr/app/__init__.py +4 -1
  3. flwr/app/message_type.py +29 -0
  4. flwr/app/metadata.py +5 -2
  5. flwr/app/user_config.py +19 -0
  6. flwr/cli/app.py +37 -19
  7. flwr/cli/app_cmd/publish.py +25 -75
  8. flwr/cli/app_cmd/review.py +18 -69
  9. flwr/cli/auth_plugin/auth_plugin.py +5 -10
  10. flwr/cli/auth_plugin/noop_auth_plugin.py +1 -2
  11. flwr/cli/auth_plugin/oidc_cli_plugin.py +38 -38
  12. flwr/cli/build.py +15 -28
  13. flwr/cli/config/__init__.py +21 -0
  14. flwr/cli/config/ls.py +71 -0
  15. flwr/cli/config_migration.py +297 -0
  16. flwr/cli/config_utils.py +63 -156
  17. flwr/cli/constant.py +71 -0
  18. flwr/cli/federation/__init__.py +0 -2
  19. flwr/cli/federation/ls.py +256 -64
  20. flwr/cli/flower_config.py +429 -0
  21. flwr/cli/install.py +23 -62
  22. flwr/cli/log.py +23 -37
  23. flwr/cli/login/login.py +29 -63
  24. flwr/cli/ls.py +28 -58
  25. flwr/cli/new/new.py +9 -29
  26. flwr/cli/pull.py +19 -37
  27. flwr/cli/run/run.py +85 -93
  28. flwr/cli/run_utils.py +1 -1
  29. flwr/cli/stop.py +32 -73
  30. flwr/cli/supernode/ls.py +25 -57
  31. flwr/cli/supernode/register.py +31 -80
  32. flwr/cli/supernode/unregister.py +24 -70
  33. flwr/cli/typing.py +200 -0
  34. flwr/cli/utils.py +160 -275
  35. flwr/client/grpc_rere_client/connection.py +3 -3
  36. flwr/client/grpc_rere_client/grpc_adapter.py +1 -1
  37. flwr/client/message_handler/message_handler.py +2 -1
  38. flwr/client/mod/centraldp_mods.py +1 -1
  39. flwr/client/mod/localdp_mod.py +1 -1
  40. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  41. flwr/client/run_info_store.py +2 -1
  42. flwr/clientapp/client_app.py +2 -1
  43. flwr/common/__init__.py +3 -2
  44. flwr/common/args.py +5 -5
  45. flwr/common/config.py +12 -17
  46. flwr/common/constant.py +3 -16
  47. flwr/common/context.py +2 -1
  48. flwr/common/exit/exit.py +4 -4
  49. flwr/common/exit/exit_code.py +6 -0
  50. flwr/common/grpc.py +2 -1
  51. flwr/common/logger.py +1 -1
  52. flwr/common/message.py +1 -1
  53. flwr/common/retry_invoker.py +13 -5
  54. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -2
  55. flwr/common/serde.py +7 -5
  56. flwr/common/telemetry.py +1 -1
  57. flwr/common/typing.py +4 -3
  58. flwr/compat/client/app.py +6 -9
  59. flwr/compat/client/grpc_client/connection.py +2 -1
  60. flwr/compat/common/constant.py +29 -0
  61. flwr/compat/server/app.py +1 -1
  62. flwr/proto/clientappio_pb2.py +2 -2
  63. flwr/proto/clientappio_pb2_grpc.py +104 -88
  64. flwr/proto/clientappio_pb2_grpc.pyi +140 -80
  65. flwr/proto/federation_pb2.py +5 -3
  66. flwr/proto/federation_pb2.pyi +32 -2
  67. flwr/proto/run_pb2.py +5 -13
  68. flwr/proto/run_pb2.pyi +0 -57
  69. flwr/proto/serverappio_pb2.py +2 -2
  70. flwr/proto/serverappio_pb2_grpc.py +138 -207
  71. flwr/proto/serverappio_pb2_grpc.pyi +189 -155
  72. flwr/proto/simulationio_pb2.py +2 -2
  73. flwr/proto/simulationio_pb2_grpc.py +62 -90
  74. flwr/proto/simulationio_pb2_grpc.pyi +95 -55
  75. flwr/server/app.py +6 -13
  76. flwr/server/compat/grid_client_proxy.py +2 -1
  77. flwr/server/grid/grpc_grid.py +5 -5
  78. flwr/server/serverapp/app.py +11 -4
  79. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -1
  80. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +13 -12
  81. flwr/server/superlink/fleet/message_handler/message_handler.py +6 -5
  82. flwr/server/superlink/linkstate/__init__.py +2 -2
  83. flwr/server/superlink/linkstate/in_memory_linkstate.py +2 -10
  84. flwr/server/superlink/linkstate/linkstate.py +2 -21
  85. flwr/server/superlink/linkstate/linkstate_factory.py +16 -8
  86. flwr/server/superlink/linkstate/{sqlite_linkstate.py → sql_linkstate.py} +432 -534
  87. flwr/server/superlink/linkstate/utils.py +49 -2
  88. flwr/server/superlink/serverappio/serverappio_servicer.py +1 -33
  89. flwr/server/superlink/simulation/simulationio_servicer.py +0 -19
  90. flwr/server/utils/validator.py +1 -1
  91. flwr/server/workflow/default_workflows.py +2 -1
  92. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
  93. flwr/serverapp/strategy/bulyan.py +7 -1
  94. flwr/serverapp/strategy/dp_fixed_clipping.py +9 -1
  95. flwr/serverapp/strategy/fedavg.py +1 -1
  96. flwr/serverapp/strategy/fedxgb_cyclic.py +1 -1
  97. flwr/simulation/ray_transport/ray_client_proxy.py +2 -6
  98. flwr/simulation/run_simulation.py +3 -12
  99. flwr/simulation/simulationio_connection.py +3 -3
  100. flwr/{common → supercore}/address.py +7 -33
  101. flwr/supercore/app_utils.py +2 -1
  102. flwr/supercore/constant.py +24 -2
  103. flwr/supercore/corestate/{sqlite_corestate.py → sql_corestate.py} +19 -23
  104. flwr/supercore/credential_store/__init__.py +33 -0
  105. flwr/supercore/credential_store/credential_store.py +34 -0
  106. flwr/supercore/credential_store/file_credential_store.py +76 -0
  107. flwr/{common → supercore}/date.py +0 -11
  108. flwr/supercore/ffs/disk_ffs.py +1 -1
  109. flwr/supercore/object_store/object_store_factory.py +14 -6
  110. flwr/supercore/object_store/{sqlite_object_store.py → sql_object_store.py} +115 -117
  111. flwr/supercore/sql_mixin.py +315 -0
  112. flwr/supercore/state/__init__.py +15 -0
  113. flwr/supercore/state/alembic/__init__.py +15 -0
  114. flwr/supercore/state/alembic/env.py +103 -0
  115. flwr/supercore/state/alembic/script.py.mako +43 -0
  116. flwr/supercore/state/alembic/utils.py +239 -0
  117. flwr/supercore/state/alembic/versions/__init__.py +15 -0
  118. flwr/supercore/state/alembic/versions/rev_2026_01_28_initialize_migration_of_state_tables.py +200 -0
  119. flwr/supercore/state/schema/README.md +121 -0
  120. flwr/supercore/state/schema/__init__.py +15 -0
  121. flwr/supercore/state/schema/corestate_tables.py +36 -0
  122. flwr/supercore/state/schema/linkstate_tables.py +152 -0
  123. flwr/supercore/state/schema/objectstore_tables.py +90 -0
  124. flwr/supercore/superexec/run_superexec.py +2 -2
  125. flwr/supercore/utils.py +36 -1
  126. flwr/superlink/federation/federation_manager.py +2 -2
  127. flwr/superlink/federation/noop_federation_manager.py +8 -6
  128. flwr/superlink/servicer/control/control_servicer.py +19 -17
  129. flwr/supernode/cli/flower_supernode.py +2 -1
  130. flwr/supernode/runtime/run_clientapp.py +14 -14
  131. flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -8
  132. flwr/supernode/start_client_internal.py +10 -6
  133. {flwr-1.25.0.dist-info → flwr-1.26.0.dist-info}/METADATA +7 -5
  134. {flwr-1.25.0.dist-info → flwr-1.26.0.dist-info}/RECORD +137 -116
  135. flwr/cli/federation/show.py +0 -318
  136. flwr/common/pyproject.py +0 -42
  137. flwr/supercore/sqlite_mixin.py +0 -159
  138. /flwr/{common → supercore}/version.py +0 -0
  139. {flwr-1.25.0.dist-info → flwr-1.26.0.dist-info}/WHEEL +0 -0
  140. {flwr-1.25.0.dist-info → flwr-1.26.0.dist-info}/entry_points.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2026 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,19 +12,22 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """SQLite based implemenation of the link state."""
15
+ """SQLAlchemy-based implementation of the link state."""
16
16
 
17
17
 
18
18
  # pylint: disable=too-many-lines
19
19
 
20
20
  import json
21
- import sqlite3
22
21
  from collections.abc import Sequence
23
22
  from datetime import datetime, timezone
24
23
  from logging import ERROR, WARNING
25
- from typing import Any, cast
24
+ from typing import Any
26
25
 
27
- from flwr.common import Context, Message, Metadata, log, now
26
+ from sqlalchemy import MetaData
27
+ from sqlalchemy.exc import IntegrityError
28
+
29
+ from flwr.app.user_config import UserConfig
30
+ from flwr.common import Context, Message, log, now
28
31
  from flwr.common.constant import (
29
32
  HEARTBEAT_PATIENCE,
30
33
  MESSAGE_TTL_TOLERANCE,
@@ -35,22 +38,15 @@ from flwr.common.constant import (
35
38
  Status,
36
39
  SubStatus,
37
40
  )
38
- from flwr.common.message import make_message
39
41
  from flwr.common.record import ConfigRecord
40
- from flwr.common.serde import recorddict_from_proto, recorddict_to_proto
41
- from flwr.common.serde_utils import error_from_proto, error_to_proto
42
- from flwr.common.typing import Run, RunStatus, UserConfig
43
-
44
- # pylint: disable=E0611
45
- from flwr.proto.error_pb2 import Error as ProtoError
46
- from flwr.proto.node_pb2 import NodeInfo
47
- from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
48
-
49
- # pylint: enable=E0611
42
+ from flwr.common.typing import Run, RunStatus
43
+ from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
50
44
  from flwr.server.utils.validator import validate_message
51
45
  from flwr.supercore.constant import NodeStatus
52
- from flwr.supercore.corestate.sqlite_corestate import SqliteCoreState
46
+ from flwr.supercore.corestate.sql_corestate import SqlCoreState
53
47
  from flwr.supercore.object_store.object_store import ObjectStore
48
+ from flwr.supercore.state.schema.corestate_tables import create_corestate_metadata
49
+ from flwr.supercore.state.schema.linkstate_tables import create_linkstate_metadata
54
50
  from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
55
51
  from flwr.superlink.federation import FederationManager
56
52
 
@@ -63,128 +59,18 @@ from .utils import (
63
59
  context_to_bytes,
64
60
  convert_sint64_values_in_dict_to_uint64,
65
61
  convert_uint64_values_in_dict_to_sint64,
62
+ dict_to_message,
66
63
  generate_rand_int_from_bytes,
67
64
  has_valid_sub_status,
68
65
  is_valid_transition,
66
+ message_to_dict,
69
67
  verify_found_message_replies,
70
68
  verify_message_ids,
71
69
  )
72
70
 
73
- SQL_CREATE_TABLE_NODE = """
74
- CREATE TABLE IF NOT EXISTS node(
75
- node_id INTEGER UNIQUE,
76
- owner_aid TEXT,
77
- owner_name TEXT,
78
- status TEXT,
79
- registered_at TEXT,
80
- last_activated_at TEXT NULL,
81
- last_deactivated_at TEXT NULL,
82
- unregistered_at TEXT NULL,
83
- online_until TIMESTAMP NULL,
84
- heartbeat_interval REAL,
85
- public_key BLOB UNIQUE
86
- );
87
- """
88
-
89
- SQL_CREATE_TABLE_PUBLIC_KEY = """
90
- CREATE TABLE IF NOT EXISTS public_key(
91
- public_key BLOB PRIMARY KEY
92
- );
93
- """
94
-
95
- SQL_CREATE_INDEX_ONLINE_UNTIL = """
96
- CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
97
- """
98
-
99
- SQL_CREATE_INDEX_OWNER_AID = """
100
- CREATE INDEX IF NOT EXISTS idx_node_owner_aid ON node(owner_aid);
101
- """
102
-
103
- SQL_CREATE_INDEX_NODE_STATUS = """
104
- CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
105
- """
106
-
107
- SQL_CREATE_TABLE_RUN = """
108
- CREATE TABLE IF NOT EXISTS run(
109
- run_id INTEGER UNIQUE,
110
- fab_id TEXT,
111
- fab_version TEXT,
112
- fab_hash TEXT,
113
- override_config TEXT,
114
- pending_at TEXT,
115
- starting_at TEXT,
116
- running_at TEXT,
117
- finished_at TEXT,
118
- sub_status TEXT,
119
- details TEXT,
120
- federation TEXT,
121
- federation_options BLOB,
122
- flwr_aid TEXT,
123
- bytes_sent INTEGER DEFAULT 0,
124
- bytes_recv INTEGER DEFAULT 0,
125
- clientapp_runtime REAL DEFAULT 0.0
126
- );
127
- """
128
-
129
- SQL_CREATE_TABLE_LOGS = """
130
- CREATE TABLE IF NOT EXISTS logs (
131
- timestamp REAL,
132
- run_id INTEGER,
133
- node_id INTEGER,
134
- log TEXT,
135
- PRIMARY KEY (timestamp, run_id, node_id),
136
- FOREIGN KEY (run_id) REFERENCES run(run_id)
137
- );
138
- """
139
-
140
- SQL_CREATE_TABLE_CONTEXT = """
141
- CREATE TABLE IF NOT EXISTS context(
142
- run_id INTEGER UNIQUE,
143
- context BLOB,
144
- FOREIGN KEY(run_id) REFERENCES run(run_id)
145
- );
146
- """
147
-
148
- SQL_CREATE_TABLE_MESSAGE_INS = """
149
- CREATE TABLE IF NOT EXISTS message_ins(
150
- message_id TEXT UNIQUE,
151
- group_id TEXT,
152
- run_id INTEGER,
153
- src_node_id INTEGER,
154
- dst_node_id INTEGER,
155
- reply_to_message_id TEXT,
156
- created_at REAL,
157
- delivered_at TEXT,
158
- ttl REAL,
159
- message_type TEXT,
160
- content BLOB NULL,
161
- error BLOB NULL,
162
- FOREIGN KEY(run_id) REFERENCES run(run_id)
163
- );
164
- """
165
-
166
-
167
- SQL_CREATE_TABLE_MESSAGE_RES = """
168
- CREATE TABLE IF NOT EXISTS message_res(
169
- message_id TEXT UNIQUE,
170
- group_id TEXT,
171
- run_id INTEGER,
172
- src_node_id INTEGER,
173
- dst_node_id INTEGER,
174
- reply_to_message_id TEXT,
175
- created_at REAL,
176
- delivered_at TEXT,
177
- ttl REAL,
178
- message_type TEXT,
179
- content BLOB NULL,
180
- error BLOB NULL,
181
- FOREIGN KEY(run_id) REFERENCES run(run_id)
182
- );
183
- """
184
-
185
-
186
- class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
187
- """SQLite-based LinkState implementation."""
71
+
72
+ class SqlLinkState(LinkState, SqlCoreState): # pylint: disable=R0904
73
+ """SQLAlchemy-based LinkState implementation."""
188
74
 
189
75
  def __init__(
190
76
  self,
@@ -196,24 +82,21 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
196
82
  federation_manager.linkstate = self
197
83
  self._federation_manager = federation_manager
198
84
 
199
- def get_sql_statements(self) -> tuple[str, ...]:
200
- """Return SQL statements for LinkState tables."""
201
- return super().get_sql_statements() + (
202
- SQL_CREATE_TABLE_RUN,
203
- SQL_CREATE_TABLE_LOGS,
204
- SQL_CREATE_TABLE_CONTEXT,
205
- SQL_CREATE_TABLE_MESSAGE_INS,
206
- SQL_CREATE_TABLE_MESSAGE_RES,
207
- SQL_CREATE_TABLE_NODE,
208
- SQL_CREATE_TABLE_PUBLIC_KEY,
209
- SQL_CREATE_INDEX_ONLINE_UNTIL,
210
- SQL_CREATE_INDEX_OWNER_AID,
211
- SQL_CREATE_INDEX_NODE_STATUS,
212
- )
85
+ def get_metadata(self) -> MetaData:
86
+ """Return combined SQLAlchemy MetaData for LinkState and CoreState tables."""
87
+ # Start with linkstate tables
88
+ metadata = create_linkstate_metadata()
89
+
90
+ # Add corestate tables (token_store)
91
+ corestate_metadata = create_corestate_metadata()
92
+ for table in corestate_metadata.tables.values():
93
+ table.to_metadata(metadata)
94
+
95
+ return metadata
213
96
 
214
97
  @property
215
98
  def federation_manager(self) -> FederationManager:
216
- """Get the FederationManager instance."""
99
+ """Return the FederationManager instance."""
217
100
  return self._federation_manager
218
101
 
219
102
  def store_message_ins(self, message: Message) -> str | None:
@@ -241,20 +124,26 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
241
124
  )
242
125
  return None
243
126
 
244
- with self.conn:
127
+ with self.session():
245
128
  # Validate run_id
246
- query = "SELECT federation FROM run WHERE run_id = ?;"
247
- rows = self.conn.execute(query, (data[0]["run_id"],)).fetchall()
129
+ query = "SELECT federation FROM run WHERE run_id = :run_id"
130
+ rows = self.query(query, {"run_id": data[0]["run_id"]})
248
131
  if not rows:
249
132
  log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
250
133
  return None
251
134
  federation: str = rows[0]["federation"]
252
135
 
253
136
  # Validate destination node ID
254
- query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
255
- rows = self.conn.execute(
256
- query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
257
- ).fetchall()
137
+ query = """SELECT node_id FROM node WHERE node_id = :node_id
138
+ AND status IN (:online, :offline)"""
139
+ rows = self.query(
140
+ query,
141
+ {
142
+ "node_id": data[0]["dst_node_id"],
143
+ "online": NodeStatus.ONLINE,
144
+ "offline": NodeStatus.OFFLINE,
145
+ },
146
+ )
258
147
  if not rows or not self.federation_manager.has_node(
259
148
  message.metadata.dst_node_id, federation
260
149
  ):
@@ -265,29 +154,62 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
265
154
  )
266
155
  return None
267
156
 
157
+ # Insert message
268
158
  columns = ", ".join([f":{key}" for key in data[0]])
269
- query = f"INSERT INTO message_ins VALUES({columns});"
159
+ query = f"INSERT INTO message_ins VALUES({columns})"
270
160
 
271
161
  # Only invalid run_id can trigger IntegrityError.
272
162
  # This may need to be changed in the future version
273
163
  # with more integrity checks.
274
- self.conn.execute(query, data[0])
164
+ self.query(query, data[0])
275
165
 
276
166
  return message.metadata.message_id
277
167
 
168
+ # pylint: disable-next=too-many-locals
278
169
  def _check_stored_messages(self, message_ids: set[str]) -> None:
279
170
  """Check and delete the message if it's invalid."""
280
171
  if not message_ids:
281
172
  return
282
173
 
283
- with self.conn:
174
+ with self.session():
175
+ # Batch fetch all messages in one query
176
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
177
+ query = f"""
178
+ SELECT * FROM message_ins
179
+ WHERE message_id IN ({placeholders})
180
+ """
181
+ params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
182
+ message_rows = self.query(query, params)
183
+
184
+ if not message_rows:
185
+ return
186
+
187
+ # Build message lookup dict
188
+ message_dict: dict[str, dict[str, Any]] = {
189
+ row["message_id"]: row for row in message_rows
190
+ }
191
+
192
+ # Collect unique run_ids for batch federation lookup
193
+ run_ids = {row["run_id"] for row in message_rows}
194
+ placeholders = ",".join([f":rid_{i}" for i in range(len(run_ids))])
195
+ query = f"""
196
+ SELECT run_id, federation FROM run
197
+ WHERE run_id IN ({placeholders})
198
+ """
199
+ params = {f"rid_{i}": rid for i, rid in enumerate(run_ids)}
200
+ run_rows = self.query(query, params)
201
+
202
+ # Build run_id to federation mapping
203
+ run_id_to_federation: dict[int, str] = {
204
+ row["run_id"]: row["federation"] for row in run_rows
205
+ }
206
+
284
207
  invalid_msg_ids: set[str] = set()
285
208
  current_time = now().timestamp()
286
209
 
210
+ # Check each message for validity
287
211
  for msg_id in message_ids:
288
- # Check if message exists
289
- query = "SELECT * FROM message_ins WHERE message_id = ?;"
290
- message_row = self.conn.execute(query, (msg_id,)).fetchone()
212
+ message_row = message_dict.get(msg_id)
291
213
  if not message_row:
292
214
  continue
293
215
 
@@ -297,15 +219,12 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
297
219
  invalid_msg_ids.add(msg_id)
298
220
  continue
299
221
 
300
- # Check if src_node_id and dst_node_id are in the federation
301
- # Get federation from run table
222
+ # Check if run exists and get federation
302
223
  run_id = message_row["run_id"]
303
- query = "SELECT federation FROM run WHERE run_id = ?;"
304
- run_row = self.conn.execute(query, (run_id,)).fetchone()
305
- if not run_row: # This should not happen
224
+ federation = run_id_to_federation.get(run_id)
225
+ if not federation:
306
226
  invalid_msg_ids.add(msg_id)
307
227
  continue
308
- federation = run_row["federation"]
309
228
 
310
229
  # Convert sint64 to uint64 for node IDs
311
230
  src_node_id = int64_to_uint64(message_row["src_node_id"])
@@ -330,52 +249,48 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
330
249
  msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
331
250
  raise AssertionError(msg)
332
251
 
333
- data: dict[str, str | int] = {}
252
+ params: dict[str, str | int] = {}
334
253
 
335
254
  # Convert the uint64 value to sint64 for SQLite
336
- data["node_id"] = uint64_to_int64(node_id)
255
+ params["node_id"] = uint64_to_int64(node_id)
337
256
 
338
- with self.conn:
257
+ with self.session():
339
258
  # Retrieve all Messages for node_id
340
259
  query = """
341
260
  SELECT message_id
342
261
  FROM message_ins
343
- WHERE dst_node_id == :node_id
344
- AND delivered_at = ""
345
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
262
+ WHERE dst_node_id = :node_id
263
+ AND delivered_at = ''
264
+ AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
346
265
  """
347
266
 
348
267
  if limit is not None:
349
268
  query += " LIMIT :limit"
350
- data["limit"] = limit
351
-
352
- query += ";"
269
+ params["limit"] = limit
353
270
 
354
- rows = self.conn.execute(query, data).fetchall()
271
+ rows = self.query(query, params)
355
272
  message_ids: set[str] = {row["message_id"] for row in rows}
356
273
  self._check_stored_messages(message_ids)
357
274
 
358
275
  # Mark retrieved Messages as delivered
359
276
  if rows:
360
277
  # Prepare query
361
- placeholders: str = ",".join(
362
- [f":id_{i}" for i in range(len(message_ids))]
363
- )
278
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
364
279
  query = f"""
365
280
  UPDATE message_ins
366
281
  SET delivered_at = :delivered_at
367
282
  WHERE message_id IN ({placeholders})
368
- RETURNING *;
283
+ RETURNING *
369
284
  """
370
285
 
371
286
  # Prepare data for query
372
287
  delivered_at = now().isoformat()
373
- data = {"delivered_at": delivered_at}
288
+ params = {"delivered_at": delivered_at}
374
289
  for index, msg_id in enumerate(message_ids):
375
- data[f"id_{index}"] = str(msg_id)
290
+ params[f"mid_{index}"] = str(msg_id)
376
291
 
377
292
  # Run query
378
- rows = self.conn.execute(query, data).fetchall()
293
+ rows = self.query(query, params)
379
294
 
380
295
  for row in rows:
381
296
  # Convert values from sint64 to uint64
@@ -383,7 +298,7 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
383
298
  row, ["run_id", "src_node_id", "dst_node_id"]
384
299
  )
385
300
 
386
- result = [dict_to_message(row) for row in rows]
301
+ result = [dict_to_message(dict(row)) for row in rows]
387
302
 
388
303
  return result
389
304
 
@@ -409,8 +324,8 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
409
324
  )
410
325
  return None
411
326
 
412
- # Ensure that the dst_node_id of the original message matches the src_node_id of
413
- # reply being processed.
327
+ # Ensure that the dst_node_id of the original message matches the src_node_id
328
+ # of reply being processed.
414
329
  if (
415
330
  msg_ins
416
331
  and message
@@ -440,21 +355,19 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
440
355
  return None
441
356
 
442
357
  # Store Message
443
- data = (message_to_dict(message),)
358
+ msg_dict = message_to_dict(message)
444
359
 
445
360
  # Convert values from uint64 to sint64 for SQLite
446
361
  convert_uint64_values_in_dict_to_sint64(
447
- data[0], ["run_id", "src_node_id", "dst_node_id"]
362
+ msg_dict, ["run_id", "src_node_id", "dst_node_id"]
448
363
  )
449
364
 
450
- columns = ", ".join([f":{key}" for key in data[0]])
451
- query = f"INSERT INTO message_res VALUES({columns});"
365
+ columns = ", ".join([f":{key}" for key in msg_dict])
366
+ query = f"INSERT INTO message_res VALUES({columns})"
452
367
 
453
- # Only invalid run_id can trigger IntegrityError.
454
- # This may need to be changed in the future version with more integrity checks.
455
368
  try:
456
- self.query(query, data)
457
- except sqlite3.IntegrityError:
369
+ self.query(query, msg_dict)
370
+ except IntegrityError:
458
371
  log(ERROR, "`run` is invalid")
459
372
  return None
460
373
 
@@ -462,21 +375,21 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
462
375
 
463
376
  def get_message_res(self, message_ids: set[str]) -> list[Message]:
464
377
  """Get reply Messages for the given Message IDs."""
465
- # pylint: disable-msg=too-many-locals
378
+ # pylint: disable=too-many-locals
466
379
  ret: dict[str, Message] = {}
467
380
 
468
- with self.conn:
381
+ with self.session():
469
382
  # Verify Message IDs
470
383
  self._check_stored_messages(message_ids)
471
384
  current = now().timestamp()
385
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
472
386
  query = f"""
473
387
  SELECT *
474
388
  FROM message_ins
475
- WHERE message_id IN ({','.join(['?'] * len(message_ids))});
389
+ WHERE message_id IN ({placeholders})
476
390
  """
477
- rows = self.conn.execute(
478
- query, tuple(str(message_id) for message_id in message_ids)
479
- ).fetchall()
391
+ params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
392
+ rows = self.query(query, params)
480
393
  found_message_ins_dict: dict[str, Message] = {}
481
394
  for row in rows:
482
395
  convert_sint64_values_in_dict_to_uint64(
@@ -496,15 +409,18 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
496
409
  in_message = found_message_ins_dict[message_id]
497
410
  sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
498
411
  dst_node_ids.add(sint_node_id)
412
+ placeholders = ",".join([f":nid_{i}" for i in range(len(dst_node_ids))])
499
413
  query = f"""
500
414
  SELECT node_id, online_until
501
415
  FROM node
502
- WHERE node_id IN ({','.join(['?'] * len(dst_node_ids))})
503
- AND status != ?
416
+ WHERE node_id IN ({placeholders})
417
+ AND status != :unregistered
504
418
  """
505
- rows = self.conn.execute(
506
- query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,)
507
- ).fetchall()
419
+ node_params: dict[str, int | str] = {
420
+ f"nid_{i}": nid for i, nid in enumerate(dst_node_ids)
421
+ }
422
+ node_params["unregistered"] = NodeStatus.UNREGISTERED
423
+ rows = self.query(query, node_params)
508
424
  tmp_ret_dict = check_node_availability_for_in_message(
509
425
  inquired_in_message_ids=message_ids,
510
426
  found_in_message_dict=found_message_ins_dict,
@@ -516,15 +432,15 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
516
432
  ret.update(tmp_ret_dict)
517
433
 
518
434
  # Find all reply Messages
435
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
519
436
  query = f"""
520
437
  SELECT *
521
438
  FROM message_res
522
- WHERE reply_to_message_id IN ({','.join(['?'] * len(message_ids))})
523
- AND delivered_at = "";
439
+ WHERE reply_to_message_id IN ({placeholders})
440
+ AND delivered_at = ''
524
441
  """
525
- rows = self.conn.execute(
526
- query, tuple(str(message_id) for message_id in message_ids)
527
- ).fetchall()
442
+ params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
443
+ rows = self.query(query, params)
528
444
  for row in rows:
529
445
  convert_sint64_values_in_dict_to_uint64(
530
446
  row, ["run_id", "src_node_id", "dst_node_id"]
@@ -544,13 +460,15 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
544
460
  message_res_ids = [
545
461
  message_res.metadata.message_id for message_res in ret.values()
546
462
  ]
463
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_res_ids))])
547
464
  query = f"""
548
465
  UPDATE message_res
549
- SET delivered_at = ?
550
- WHERE message_id IN ({','.join(['?'] * len(message_res_ids))});
466
+ SET delivered_at = :delivered_at
467
+ WHERE message_id IN ({placeholders})
551
468
  """
552
- data: list[Any] = [delivered_at] + message_res_ids
553
- self.conn.execute(query, data)
469
+ params = {"delivered_at": delivered_at}
470
+ params.update({f"mid_{i}": mid for i, mid in enumerate(message_res_ids)})
471
+ self.query(query, params)
554
472
 
555
473
  return list(ret.values())
556
474
 
@@ -559,64 +477,55 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
559
477
 
560
478
  This includes delivered but not yet deleted.
561
479
  """
562
- query = "SELECT count(*) AS num FROM message_ins;"
563
- rows = self.query(query)
564
- result = rows[0]
565
- num = cast(int, result["num"])
566
- return num
480
+ query = "SELECT count(*) AS num FROM message_ins"
481
+ rows = self.query(query, {})
482
+ return int(rows[0]["num"])
567
483
 
568
484
  def num_message_res(self) -> int:
569
485
  """Calculate the number of reply Messages in store.
570
486
 
571
487
  This includes delivered but not yet deleted.
572
488
  """
573
- query = "SELECT count(*) AS num FROM message_res;"
489
+ query = "SELECT count(*) AS num FROM message_res"
574
490
  rows = self.query(query)
575
- result: dict[str, int] = rows[0]
576
- return result["num"]
491
+ return int(rows[0]["num"])
577
492
 
578
493
  def delete_messages(self, message_ins_ids: set[str]) -> None:
579
494
  """Delete a Message and its reply based on provided Message IDs."""
580
495
  if not message_ins_ids:
581
496
  return
582
- if self.conn is None:
583
- raise AttributeError("LinkState not initialized")
584
497
 
585
- placeholders = ",".join(["?"] * len(message_ins_ids))
586
- data = tuple(str(message_id) for message_id in message_ins_ids)
498
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_ins_ids))])
499
+ params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ins_ids)}
587
500
 
588
501
  # Delete Message
589
502
  query_1 = f"""
590
503
  DELETE FROM message_ins
591
- WHERE message_id IN ({placeholders});
504
+ WHERE message_id IN ({placeholders})
592
505
  """
593
506
 
594
507
  # Delete reply Message
595
508
  query_2 = f"""
596
509
  DELETE FROM message_res
597
- WHERE reply_to_message_id IN ({placeholders});
510
+ WHERE reply_to_message_id IN ({placeholders})
598
511
  """
599
512
 
600
- with self.conn:
601
- self.conn.execute(query_1, data)
602
- self.conn.execute(query_2, data)
513
+ with self.session():
514
+ self.query(query_1, params)
515
+ self.query(query_2, params)
603
516
 
604
517
  def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
605
518
  """Get all instruction Message IDs for the given run_id."""
606
- if self.conn is None:
607
- raise AttributeError("LinkState not initialized")
608
-
609
519
  query = """
610
520
  SELECT message_id
611
521
  FROM message_ins
612
- WHERE run_id = :run_id;
522
+ WHERE run_id = :run_id
613
523
  """
614
-
615
524
  sint64_run_id = uint64_to_int64(run_id)
616
- data = {"run_id": sint64_run_id}
525
+ params = {"run_id": sint64_run_id}
617
526
 
618
- with self.conn:
619
- rows = self.conn.execute(query, data).fetchall()
527
+ with self.session():
528
+ rows = self.query(query, params)
620
529
 
621
530
  return {row["message_id"] for row in rows}
622
531
 
@@ -641,29 +550,31 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
641
550
  (node_id, owner_aid, owner_name, status, registered_at, last_activated_at,
642
551
  last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
643
552
  public_key)
644
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
553
+ VALUES (:node_id, :owner_aid, :owner_name, :status, :registered_at,
554
+ :last_activated_at, :last_deactivated_at, :unregistered_at, :online_until,
555
+ :heartbeat_interval, :public_key)
645
556
  """
646
557
 
647
558
  # Mark the node online until now().timestamp() + heartbeat_interval
648
559
  try:
649
560
  self.query(
650
561
  query,
651
- (
652
- sint64_node_id, # node_id
653
- owner_aid, # owner_aid
654
- owner_name, # owner_name
655
- NodeStatus.REGISTERED, # status
656
- now().isoformat(), # registered_at
657
- None, # last_activated_at
658
- None, # last_deactivated_at
659
- None, # unregistered_at
660
- None, # online_until, initialized with offline status
661
- heartbeat_interval, # heartbeat_interval
662
- public_key, # public_key
663
- ),
562
+ {
563
+ "node_id": sint64_node_id,
564
+ "owner_aid": owner_aid,
565
+ "owner_name": owner_name,
566
+ "status": NodeStatus.REGISTERED,
567
+ "registered_at": now().isoformat(),
568
+ "last_activated_at": None,
569
+ "last_deactivated_at": None,
570
+ "unregistered_at": None,
571
+ "online_until": None, # initialized with offline status
572
+ "heartbeat_interval": heartbeat_interval,
573
+ "public_key": public_key,
574
+ },
664
575
  )
665
- except sqlite3.IntegrityError as e:
666
- if "UNIQUE constraint failed: node.public_key" in str(e):
576
+ except IntegrityError as e:
577
+ if "node.public_key" in str(e):
667
578
  raise ValueError("Public key already in use.") from None
668
579
  # Must be node ID conflict, almost impossible unless system is compromised
669
580
  log(ERROR, "Unexpected node registration failure.")
@@ -678,21 +589,20 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
678
589
 
679
590
  query = """
680
591
  UPDATE node
681
- SET status = ?, unregistered_at = ?,
682
- online_until = IIF(online_until > ?, ?, online_until)
683
- WHERE node_id = ? AND status != ? AND owner_aid = ?
592
+ SET status = :unregistered, unregistered_at = :unregistered_at,
593
+ online_until = IIF(online_until > :current, :current, online_until)
594
+ WHERE node_id = :node_id AND status != :unregistered
595
+ AND owner_aid = :owner_aid
684
596
  RETURNING node_id
685
597
  """
686
598
  current = now()
687
- params = (
688
- NodeStatus.UNREGISTERED,
689
- current.isoformat(),
690
- current.timestamp(),
691
- current.timestamp(),
692
- sint64_node_id,
693
- NodeStatus.UNREGISTERED,
694
- owner_aid,
695
- )
599
+ params = {
600
+ "unregistered": NodeStatus.UNREGISTERED,
601
+ "unregistered_at": current.isoformat(),
602
+ "current": current.timestamp(),
603
+ "node_id": sint64_node_id,
604
+ "owner_aid": owner_aid,
605
+ }
696
606
 
697
607
  rows = self.query(query, params)
698
608
  if not rows:
@@ -703,58 +613,58 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
703
613
 
704
614
  def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
705
615
  """Activate the node with the specified `node_id`."""
706
- with self.conn:
707
- self._check_and_tag_offline_nodes([node_id])
616
+ self._check_and_tag_offline_nodes([node_id])
708
617
 
709
- # Only activate if the node is currently registered or offline
710
- current_dt = now()
711
- query = """
712
- UPDATE node
713
- SET status = ?,
714
- last_activated_at = ?,
715
- online_until = ?,
716
- heartbeat_interval = ?
717
- WHERE node_id = ? AND status in (?, ?)
718
- RETURNING node_id
719
- """
720
- params = (
721
- NodeStatus.ONLINE,
722
- current_dt.isoformat(),
723
- current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
724
- heartbeat_interval,
725
- uint64_to_int64(node_id),
726
- NodeStatus.REGISTERED,
727
- NodeStatus.OFFLINE,
728
- )
618
+ # Only activate if the node is currently registered or offline
619
+ current_dt = now()
620
+ sint64_node_id = uint64_to_int64(node_id)
621
+ query = """
622
+ UPDATE node
623
+ SET status = :online,
624
+ last_activated_at = :current,
625
+ online_until = :online_until,
626
+ heartbeat_interval = :heartbeat_interval
627
+ WHERE node_id = :node_id AND status IN (:registered, :offline)
628
+ RETURNING node_id
629
+ """
630
+ params = {
631
+ "online": NodeStatus.ONLINE,
632
+ "current": current_dt.isoformat(),
633
+ "online_until": current_dt.timestamp()
634
+ + HEARTBEAT_PATIENCE * heartbeat_interval,
635
+ "heartbeat_interval": heartbeat_interval,
636
+ "node_id": sint64_node_id,
637
+ "registered": NodeStatus.REGISTERED,
638
+ "offline": NodeStatus.OFFLINE,
639
+ }
729
640
 
730
- row = self.conn.execute(query, params).fetchone()
731
- return row is not None
641
+ rows = self.query(query, params)
642
+ return len(rows) > 0
732
643
 
733
644
  def deactivate_node(self, node_id: int) -> bool:
734
645
  """Deactivate the node with the specified `node_id`."""
735
- with self.conn:
736
- self._check_and_tag_offline_nodes([node_id])
646
+ self._check_and_tag_offline_nodes([node_id])
737
647
 
738
- # Only deactivate if the node is currently online
739
- current_dt = now()
740
- query = """
741
- UPDATE node
742
- SET status = ?,
743
- last_deactivated_at = ?,
744
- online_until = ?
745
- WHERE node_id = ? AND status = ?
746
- RETURNING node_id
747
- """
748
- params = (
749
- NodeStatus.OFFLINE,
750
- current_dt.isoformat(),
751
- current_dt.timestamp(),
752
- uint64_to_int64(node_id),
753
- NodeStatus.ONLINE,
754
- )
648
+ # Only deactivate if the node is currently online
649
+ current_dt = now()
650
+ query = """
651
+ UPDATE node
652
+ SET status = :offline,
653
+ last_deactivated_at = :current_iso,
654
+ online_until = :current_ts
655
+ WHERE node_id = :node_id AND status = :online
656
+ RETURNING node_id
657
+ """
658
+ params = {
659
+ "offline": NodeStatus.OFFLINE,
660
+ "current_iso": current_dt.isoformat(),
661
+ "current_ts": current_dt.timestamp(),
662
+ "node_id": uint64_to_int64(node_id),
663
+ "online": NodeStatus.ONLINE,
664
+ }
755
665
 
756
- row = self.conn.execute(query, params).fetchone()
757
- return row is not None
666
+ rows = self.query(query, params)
667
+ return len(rows) > 0
758
668
 
759
669
  def get_nodes(self, run_id: int) -> set[int]:
760
670
  """Retrieve all currently stored node IDs as a set.
@@ -764,16 +674,13 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
764
674
  If the provided `run_id` does not exist or has no matching nodes,
765
675
  an empty `Set` MUST be returned.
766
676
  """
767
- if self.conn is None:
768
- raise AttributeError("LinkState not initialized")
769
-
770
- with self.conn:
677
+ with self.session():
771
678
  # Convert the uint64 value to sint64 for SQLite
772
679
  sint64_run_id = uint64_to_int64(run_id)
773
680
 
774
681
  # Validate run ID
775
- query = "SELECT federation FROM run WHERE run_id = ?"
776
- rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
682
+ query = "SELECT federation FROM run WHERE run_id = :run_id"
683
+ rows = self.query(query, {"run_id": sint64_run_id})
777
684
  if not rows:
778
685
  return set()
779
686
  federation: str = rows[0]["federation"]
@@ -790,23 +697,25 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
790
697
  """Check and tag offline nodes."""
791
698
  # strftime will convert POSIX timestamp to ISO format
792
699
  query = """
793
- UPDATE node SET status = ?,
700
+ UPDATE node SET status = :offline,
794
701
  last_deactivated_at =
795
- strftime("%Y-%m-%dT%H:%M:%f+00:00", online_until, "unixepoch")
796
- WHERE online_until <= ? AND status == ?
702
+ strftime('%Y-%m-%dT%H:%M:%f+00:00', online_until, 'unixepoch')
703
+ WHERE online_until <= :current_time AND status = :online
797
704
  """
798
- params = [
799
- NodeStatus.OFFLINE,
800
- now().timestamp(),
801
- NodeStatus.ONLINE,
802
- ]
705
+ params: dict[str, Any] = {
706
+ "offline": NodeStatus.OFFLINE,
707
+ "current_time": now().timestamp(),
708
+ "online": NodeStatus.ONLINE,
709
+ }
803
710
  if node_ids is not None:
804
- placeholders = ",".join(["?"] * len(node_ids))
711
+ placeholders = ",".join([f":nid_{i}" for i in range(len(node_ids))])
805
712
  query += f" AND node_id IN ({placeholders})"
806
- params.extend(uint64_to_int64(node_id) for node_id in node_ids)
807
- self.conn.execute(query, params)
713
+ params.update(
714
+ {f"nid_{i}": uint64_to_int64(nid) for i, nid in enumerate(node_ids)}
715
+ )
716
+ self.query(query, params)
808
717
 
809
- def get_node_info(
718
+ def get_node_info( # pylint: disable=too-many-locals
810
719
  self,
811
720
  *,
812
721
  node_ids: Sequence[int] | None = None,
@@ -814,32 +723,37 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
814
723
  statuses: Sequence[str] | None = None,
815
724
  ) -> Sequence[NodeInfo]:
816
725
  """Retrieve information about nodes based on the specified filters."""
817
- with self.conn:
726
+ with self.session():
818
727
  self._check_and_tag_offline_nodes()
819
728
 
820
729
  # Build the WHERE clause based on provided filters
821
730
  conditions = []
822
- params: list[Any] = []
731
+ params: dict[str, Any] = {}
823
732
  if node_ids is not None:
824
733
  sint64_node_ids = [uint64_to_int64(node_id) for node_id in node_ids]
825
- placeholders = ",".join(["?"] * len(sint64_node_ids))
734
+ placeholders = ",".join(
735
+ [f":nid_{i}" for i in range(len(sint64_node_ids))]
736
+ )
826
737
  conditions.append(f"node_id IN ({placeholders})")
827
- params.extend(sint64_node_ids)
738
+ for i, nid in enumerate(sint64_node_ids):
739
+ params[f"nid_{i}"] = nid
828
740
  if owner_aids is not None:
829
- placeholders = ",".join(["?"] * len(owner_aids))
741
+ placeholders = ",".join([f":aid_{i}" for i in range(len(owner_aids))])
830
742
  conditions.append(f"owner_aid IN ({placeholders})")
831
- params.extend(owner_aids)
743
+ for i, aid in enumerate(owner_aids):
744
+ params[f"aid_{i}"] = aid
832
745
  if statuses is not None:
833
- placeholders = ",".join(["?"] * len(statuses))
746
+ placeholders = ",".join([f":st_{i}" for i in range(len(statuses))])
834
747
  conditions.append(f"status IN ({placeholders})")
835
- params.extend(statuses)
748
+ for i, status in enumerate(statuses):
749
+ params[f"st_{i}"] = status
836
750
 
837
751
  # Construct the final query
838
752
  query = "SELECT * FROM node"
839
753
  if conditions:
840
754
  query += " WHERE " + " AND ".join(conditions)
841
755
 
842
- rows = self.conn.execute(query, params).fetchall()
756
+ rows = self.query(query, params)
843
757
 
844
758
  result: list[NodeInfo] = []
845
759
  for row in rows:
@@ -849,27 +763,14 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
849
763
 
850
764
  return result
851
765
 
852
- def get_node_public_key(self, node_id: int) -> bytes:
853
- """Get `public_key` for the specified `node_id`."""
854
- # Convert the uint64 value to sint64 for SQLite
855
- sint64_node_id = uint64_to_int64(node_id)
856
-
857
- # Query the public key for the given node_id
858
- query = "SELECT public_key FROM node WHERE node_id = ? AND status != ?;"
859
- rows = self.query(query, (sint64_node_id, NodeStatus.UNREGISTERED))
860
-
861
- # If no result is found, return None
862
- if not rows:
863
- raise ValueError(f"Node ID {node_id} not found")
864
-
865
- # Return the public key
866
- return cast(bytes, rows[0]["public_key"])
867
-
868
766
  def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
869
767
  """Get `node_id` for the specified `public_key` if it exists and is not
870
768
  deleted."""
871
- query = "SELECT node_id FROM node WHERE public_key = ? AND status != ?;"
872
- rows = self.query(query, (public_key, NodeStatus.UNREGISTERED))
769
+ query = """SELECT node_id FROM node
770
+ WHERE public_key = :public_key AND status != :unregistered;"""
771
+ rows = self.query(
772
+ query, {"public_key": public_key, "unregistered": NodeStatus.UNREGISTERED}
773
+ )
873
774
 
874
775
  # If no result is found, return None
875
776
  if not rows:
@@ -879,8 +780,7 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
879
780
  node_id = int64_to_uint64(rows[0]["node_id"])
880
781
  return node_id
881
782
 
882
- # pylint: disable=too-many-arguments,too-many-positional-arguments
883
- def create_run(
783
+ def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments
884
784
  self,
885
785
  fab_id: str | None,
886
786
  fab_version: str | None,
@@ -897,41 +797,43 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
897
797
  # Convert the uint64 value to sint64 for SQLite
898
798
  sint64_run_id = uint64_to_int64(uint64_run_id)
899
799
 
900
- with self.conn:
800
+ with self.session():
901
801
  # Check conflicts
902
- query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
903
- # If sint64_run_id does not exist
904
- row = self.conn.execute(query, (sint64_run_id,)).fetchone()
905
- if row["COUNT(*)"] == 0:
802
+ query = "SELECT COUNT(*) as cnt FROM run WHERE run_id = :run_id"
803
+ rows = self.query(query, {"run_id": sint64_run_id})
804
+ if rows[0]["cnt"] == 0:
906
805
  query = """
907
806
  INSERT INTO run
908
807
  (run_id, fab_id, fab_version,
909
808
  fab_hash, override_config, federation, federation_options,
910
809
  pending_at, starting_at, running_at, finished_at, sub_status,
911
810
  details, flwr_aid, bytes_sent, bytes_recv, clientapp_runtime)
912
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
811
+ VALUES (:run_id, :fab_id, :fab_version, :fab_hash, :override_config,
812
+ :federation, :federation_options, :pending_at, :starting_at,
813
+ :running_at, :finished_at, :sub_status, :details, :flwr_aid,
814
+ :bytes_sent, :bytes_recv, :clientapp_runtime)
913
815
  """
914
816
  override_config_json = json.dumps(override_config)
915
- data = [
916
- sint64_run_id, # run_id
917
- fab_id, # fab_id
918
- fab_version, # fab_version
919
- fab_hash, # fab_hash
920
- override_config_json, # override_config
921
- federation, # federation
922
- configrecord_to_bytes(federation_options), # federation_options
923
- now().isoformat(), # pending_at
924
- "", # starting_at
925
- "", # running_at
926
- "", # finished_at
927
- "", # sub_status
928
- "", # details
929
- flwr_aid or "", # flwr_aid
930
- 0, # bytes_sent
931
- 0, # bytes_recv
932
- 0, # clientapp_runtime
933
- ]
934
- self.conn.execute(query, tuple(data))
817
+ params = {
818
+ "run_id": sint64_run_id,
819
+ "fab_id": fab_id or "",
820
+ "fab_version": fab_version or "",
821
+ "fab_hash": fab_hash or "",
822
+ "override_config": override_config_json,
823
+ "federation": federation,
824
+ "federation_options": configrecord_to_bytes(federation_options),
825
+ "pending_at": now().isoformat(),
826
+ "starting_at": "",
827
+ "running_at": "",
828
+ "finished_at": "",
829
+ "sub_status": "",
830
+ "details": "",
831
+ "flwr_aid": flwr_aid or "",
832
+ "bytes_sent": 0,
833
+ "bytes_recv": 0,
834
+ "clientapp_runtime": 0.0,
835
+ }
836
+ self.query(query, params)
935
837
  return uint64_run_id
936
838
  log(ERROR, "Unexpected run creation failure.")
937
839
  return 0
@@ -943,11 +845,11 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
943
845
  """
944
846
  if flwr_aid:
945
847
  rows = self.query(
946
- "SELECT run_id FROM run WHERE flwr_aid = ?;",
947
- (flwr_aid,),
848
+ "SELECT run_id FROM run WHERE flwr_aid = :flwr_aid",
849
+ {"flwr_aid": flwr_aid},
948
850
  )
949
851
  else:
950
- rows = self.query("SELECT run_id FROM run;", ())
852
+ rows = self.query("SELECT run_id FROM run", {})
951
853
  return {int64_to_uint64(row["run_id"]) for row in rows}
952
854
 
953
855
  def get_run(self, run_id: int) -> Run | None:
@@ -957,8 +859,8 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
957
859
 
958
860
  # Convert the uint64 value to sint64 for SQLite
959
861
  sint64_run_id = uint64_to_int64(run_id)
960
- query = "SELECT * FROM run WHERE run_id = ?;"
961
- rows = self.query(query, (sint64_run_id,))
862
+ query = "SELECT * FROM run WHERE run_id = :run_id"
863
+ rows = self.query(query, {"run_id": sint64_run_id})
962
864
  if rows:
963
865
  row = rows[0]
964
866
  return Run(
@@ -991,9 +893,10 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
991
893
  self._cleanup_expired_tokens()
992
894
 
993
895
  # Convert the uint64 value to sint64 for SQLite
994
- sint64_run_ids = (uint64_to_int64(run_id) for run_id in set(run_ids))
995
- query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
996
- rows = self.query(query, tuple(sint64_run_ids))
896
+ placeholders = ",".join([f":rid_{i}" for i in range(len(run_ids))])
897
+ query = f"SELECT * FROM run WHERE run_id IN ({placeholders})"
898
+ params = {f"rid_{i}": uint64_to_int64(rid) for i, rid in enumerate(run_ids)}
899
+ rows = self.query(query, params)
997
900
 
998
901
  return {
999
902
  # Restore uint64 run IDs
@@ -1010,11 +913,11 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1010
913
  # Clean up expired tokens; this will flag inactive runs as needed
1011
914
  self._cleanup_expired_tokens()
1012
915
 
1013
- with self.conn:
916
+ with self.session():
1014
917
  # Convert the uint64 value to sint64 for SQLite
1015
918
  sint64_run_id = uint64_to_int64(run_id)
1016
- query = "SELECT * FROM run WHERE run_id = ?;"
1017
- rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
919
+ query = "SELECT * FROM run WHERE run_id = :run_id"
920
+ rows = self.query(query, {"run_id": sint64_run_id})
1018
921
 
1019
922
  # Check if the run_id exists
1020
923
  if not rows:
@@ -1049,7 +952,9 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1049
952
 
1050
953
  # Update the status
1051
954
  query = """
1052
- UPDATE run SET %s= ?, sub_status = ?, details = ? WHERE run_id = ?;
955
+ UPDATE run SET %s = :timestamp,
956
+ sub_status = :sub_status, details = :details
957
+ WHERE run_id = :run_id
1053
958
  """
1054
959
 
1055
960
  # Prepare data for query
@@ -1064,33 +969,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1064
969
  elif new_status.status == Status.FINISHED:
1065
970
  timestamp_fld = "finished_at"
1066
971
 
1067
- data = (
1068
- current.isoformat(),
1069
- new_status.sub_status,
1070
- new_status.details,
1071
- uint64_to_int64(run_id),
1072
- )
1073
- self.conn.execute(query % timestamp_fld, data)
972
+ params = {
973
+ "timestamp": current.isoformat(),
974
+ "sub_status": new_status.sub_status,
975
+ "details": new_status.details,
976
+ "run_id": sint64_run_id,
977
+ }
978
+ self.query(query % timestamp_fld, params)
1074
979
  return True
1075
980
 
1076
981
  def get_pending_run_id(self) -> int | None:
1077
- """Get the `run_id` of a run with `Status.PENDING` status, if any."""
1078
- pending_run_id = None
1079
-
982
+ """Get the `run_id` of a run with `Status.PENDING` status."""
1080
983
  # Fetch all runs with unset `starting_at` (i.e. they are in PENDING status)
1081
- query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
1082
- rows = self.query(query)
984
+ query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1"
985
+ rows = self.query(query, {})
1083
986
  if rows:
1084
- pending_run_id = int64_to_uint64(rows[0]["run_id"])
1085
-
1086
- return pending_run_id
987
+ return int64_to_uint64(rows[0]["run_id"])
988
+ return None
1087
989
 
1088
990
  def get_federation_options(self, run_id: int) -> ConfigRecord | None:
1089
991
  """Retrieve the federation options for the specified `run_id`."""
1090
992
  # Convert the uint64 value to sint64 for SQLite
1091
993
  sint64_run_id = uint64_to_int64(run_id)
1092
- query = "SELECT federation_options FROM run WHERE run_id = ?;"
1093
- rows = self.query(query, (sint64_run_id,))
994
+ query = "SELECT federation_options FROM run WHERE run_id = :run_id"
995
+ rows = self.query(query, {"run_id": sint64_run_id})
1094
996
 
1095
997
  # Check if the run_id exists
1096
998
  if not rows:
@@ -1110,41 +1012,46 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1110
1012
  HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
1111
1013
  the node is marked as offline.
1112
1014
  """
1113
- if self.conn is None:
1114
- raise AttributeError("LinkState not initialized")
1115
-
1116
1015
  sint64_node_id = uint64_to_int64(node_id)
1117
1016
 
1118
- with self.conn:
1119
- # Check if node exists and not deleted
1120
- query = "SELECT status FROM node WHERE node_id = ? AND status != ?"
1121
- row = self.conn.execute(
1122
- query, (sint64_node_id, NodeStatus.UNREGISTERED)
1123
- ).fetchone()
1124
- if row is None:
1125
- return False
1017
+ # Check if the node exists and is not unregistered
1018
+ query = """
1019
+ SELECT status FROM node WHERE node_id = :node_id AND status != :unregistered
1020
+ """
1021
+ rows = self.query(
1022
+ query, {"node_id": sint64_node_id, "unregistered": NodeStatus.UNREGISTERED}
1023
+ )
1024
+ if not rows:
1025
+ return False
1126
1026
 
1127
- # Construct query and params
1128
- current_dt = now()
1129
- query = "UPDATE node SET online_until = ?, heartbeat_interval = ?"
1130
- params: list[Any] = [
1131
- current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
1132
- heartbeat_interval,
1133
- ]
1027
+ # Construct query and params
1028
+ current_dt = now()
1029
+ query = (
1030
+ "UPDATE node SET online_until = :online_until, "
1031
+ "heartbeat_interval = :heartbeat_interval"
1032
+ )
1033
+ params: dict[str, Any] = {
1034
+ "online_until": current_dt.timestamp()
1035
+ + HEARTBEAT_PATIENCE * heartbeat_interval,
1036
+ "heartbeat_interval": heartbeat_interval,
1037
+ }
1134
1038
 
1135
- # Set timestamp if the status changes
1136
- if row["status"] != NodeStatus.ONLINE:
1137
- query += ", status = ?, last_activated_at = ?"
1138
- params += [NodeStatus.ONLINE, current_dt.isoformat()]
1039
+ # Set timestamp if the status changes
1040
+ if rows[0]["status"] != NodeStatus.ONLINE:
1041
+ query += ", status = :online, last_activated_at = :last_activated_at"
1042
+ params["online"] = NodeStatus.ONLINE
1043
+ params["last_activated_at"] = current_dt.isoformat()
1139
1044
 
1140
- # Execute the query, refreshing `online_until` and `heartbeat_interval`
1141
- query += " WHERE node_id = ?"
1142
- params += [sint64_node_id]
1143
- self.conn.execute(query, params)
1144
- return True
1045
+ # Execute the query, refreshing `online_until` and `heartbeat_interval`
1046
+ query += " WHERE node_id = :node_id"
1047
+ params["node_id"] = sint64_node_id
1048
+ self.query(query, params)
1049
+ return True
1145
1050
 
1146
1051
  def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
1147
- """Transition runs with expired tokens to failed status.
1052
+ """Handle cleanup of expired tokens.
1053
+
1054
+ Override in subclasses to add custom cleanup logic.
1148
1055
 
1149
1056
  Parameters
1150
1057
  ----------
@@ -1155,28 +1062,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1155
1062
  if not expired_records:
1156
1063
  return
1157
1064
 
1158
- with self.conn:
1065
+ with self.session():
1159
1066
  query = """
1160
1067
  UPDATE run
1161
- SET sub_status = ?, details = ?, finished_at = ?
1162
- WHERE run_id = ?;
1068
+ SET sub_status = :failed, details = :details, finished_at = :finished_at
1069
+ WHERE run_id = :run_id
1163
1070
  """
1164
1071
  data = [
1165
- (
1166
- SubStatus.FAILED,
1167
- RUN_FAILURE_DETAILS_NO_HEARTBEAT,
1168
- datetime.fromtimestamp(active_until, tz=timezone.utc).isoformat(),
1169
- uint64_to_int64(run_id),
1170
- )
1072
+ {
1073
+ "failed": SubStatus.FAILED,
1074
+ "details": RUN_FAILURE_DETAILS_NO_HEARTBEAT,
1075
+ "finished_at": datetime.fromtimestamp(
1076
+ active_until, tz=timezone.utc
1077
+ ).isoformat(),
1078
+ "run_id": uint64_to_int64(run_id),
1079
+ }
1171
1080
  for run_id, active_until in expired_records
1172
1081
  ]
1173
- self.conn.executemany(query, data)
1082
+ self.query(query, data)
1174
1083
 
1175
1084
  def get_serverapp_context(self, run_id: int) -> Context | None:
1176
1085
  """Get the context for the specified `run_id`."""
1177
1086
  # Retrieve context if any
1178
- query = "SELECT context FROM context WHERE run_id = ?;"
1179
- rows = self.query(query, (uint64_to_int64(run_id),))
1087
+ query = "SELECT context FROM context WHERE run_id = :run_id"
1088
+ rows = self.query(query, {"run_id": uint64_to_int64(run_id)})
1180
1089
  context = context_from_bytes(rows[0]["context"]) if rows else None
1181
1090
  return context
1182
1091
 
@@ -1186,20 +1095,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1186
1095
  context_bytes = context_to_bytes(context)
1187
1096
  sint_run_id = uint64_to_int64(run_id)
1188
1097
 
1189
- with self.conn:
1098
+ with self.session():
1190
1099
  # Check if any existing Context assigned to the run_id
1191
- query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
1192
- row = self.conn.execute(query, (sint_run_id,)).fetchone()
1193
- if row["COUNT(*)"] > 0:
1100
+ query = "SELECT COUNT(*) as count FROM context WHERE run_id = :run_id"
1101
+ row = self.query(query, {"run_id": sint_run_id})[0]
1102
+ if row["count"] > 0:
1194
1103
  # Update context
1195
- query = "UPDATE context SET context = ? WHERE run_id = ?;"
1196
- self.conn.execute(query, (context_bytes, sint_run_id))
1104
+ query = """
1105
+ UPDATE context
1106
+ SET context = :context_bytes WHERE run_id = :run_id
1107
+ """
1108
+ self.query(
1109
+ query, {"context_bytes": context_bytes, "run_id": sint_run_id}
1110
+ )
1197
1111
  else:
1198
1112
  try:
1199
1113
  # Store context
1200
- query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
1201
- self.conn.execute(query, (sint_run_id, context_bytes))
1202
- except sqlite3.IntegrityError:
1114
+ query = (
1115
+ "INSERT INTO context (run_id, context) "
1116
+ "VALUES (:run_id, :context_bytes)"
1117
+ )
1118
+ self.query(
1119
+ query, {"run_id": sint_run_id, "context_bytes": context_bytes}
1120
+ )
1121
+ except IntegrityError:
1203
1122
  raise ValueError(f"Run {run_id} not found") from None
1204
1123
 
1205
1124
  def add_serverapp_log(self, run_id: int, log_message: str) -> None:
@@ -1210,10 +1129,19 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1210
1129
  # Store log
1211
1130
  try:
1212
1131
  query = """
1213
- INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?);
1132
+ INSERT INTO logs (timestamp, run_id, node_id, log)
1133
+ VALUES (:current_ts, :run_id, :node_id, :log)
1214
1134
  """
1215
- self.query(query, (now().timestamp(), sint64_run_id, 0, log_message))
1216
- except sqlite3.IntegrityError:
1135
+ self.query(
1136
+ query,
1137
+ {
1138
+ "current_ts": now().timestamp(),
1139
+ "run_id": sint64_run_id,
1140
+ "node_id": 0,
1141
+ "log": log_message,
1142
+ },
1143
+ )
1144
+ except IntegrityError:
1217
1145
  raise ValueError(f"Run {run_id} not found") from None
1218
1146
 
1219
1147
  def get_serverapp_log(
@@ -1223,10 +1151,10 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1223
1151
  # Convert the uint64 value to sint64 for SQLite
1224
1152
  sint64_run_id = uint64_to_int64(run_id)
1225
1153
 
1226
- with self.conn:
1154
+ with self.session():
1227
1155
  # Check if the run_id exists
1228
- query = "SELECT run_id FROM run WHERE run_id = ?;"
1229
- rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
1156
+ query = "SELECT run_id FROM run WHERE run_id = :run_id"
1157
+ rows = self.query(query, {"run_id": sint64_run_id})
1230
1158
  if not rows:
1231
1159
  raise ValueError(f"Run {run_id} not found")
1232
1160
 
@@ -1235,12 +1163,18 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1235
1163
  after_timestamp = 0.0
1236
1164
  query = """
1237
1165
  SELECT log, timestamp FROM logs
1238
- WHERE run_id = ? AND node_id = ? AND timestamp > ?;
1166
+ WHERE run_id = :run_id AND node_id = :node_id
1167
+ AND timestamp > :after_timestamp
1168
+ ORDER BY timestamp
1239
1169
  """
1240
- rows = self.conn.execute(
1241
- query, (sint64_run_id, 0, after_timestamp)
1242
- ).fetchall()
1243
- rows.sort(key=lambda x: x["timestamp"])
1170
+ rows = self.query(
1171
+ query,
1172
+ {
1173
+ "run_id": sint64_run_id,
1174
+ "node_id": 0,
1175
+ "after_timestamp": after_timestamp,
1176
+ },
1177
+ )
1244
1178
  latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
1245
1179
  return "".join(row["log"] for row in rows), latest_timestamp
1246
1180
 
@@ -1249,20 +1183,19 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1249
1183
 
1250
1184
  Return Message if valid.
1251
1185
  """
1252
- with self.conn:
1186
+ with self.session():
1253
1187
  self._check_stored_messages({message_id})
1254
1188
  query = """
1255
1189
  SELECT *
1256
1190
  FROM message_ins
1257
1191
  WHERE message_id = :message_id
1258
1192
  """
1259
- data = {"message_id": message_id}
1260
- rows: list[dict[str, Any]] = self.conn.execute(query, data).fetchall()
1193
+ rows = self.query(query, {"message_id": message_id})
1261
1194
  if not rows:
1262
1195
  # Message does not exist
1263
1196
  return None
1264
1197
 
1265
- return rows[0]
1198
+ return dict(rows[0])
1266
1199
 
1267
1200
  def store_traffic(self, run_id: int, *, bytes_sent: int, bytes_recv: int) -> None:
1268
1201
  """Store traffic data for the specified `run_id`."""
@@ -1280,18 +1213,23 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1280
1213
 
1281
1214
  sint64_run_id = uint64_to_int64(run_id)
1282
1215
 
1283
- with self.conn:
1216
+ with self.session():
1284
1217
  # Check if run exists, performing the update only if it does
1285
1218
  update_query = """
1286
1219
  UPDATE run
1287
- SET bytes_sent = bytes_sent + ?,
1288
- bytes_recv = bytes_recv + ?
1289
- WHERE run_id = ?
1290
- RETURNING run_id;
1220
+ SET bytes_sent = bytes_sent + :bytes_sent,
1221
+ bytes_recv = bytes_recv + :bytes_recv
1222
+ WHERE run_id = :run_id
1223
+ RETURNING run_id
1291
1224
  """
1292
- rows = self.conn.execute(
1293
- update_query, (bytes_sent, bytes_recv, sint64_run_id)
1294
- ).fetchall()
1225
+ rows = self.query(
1226
+ update_query,
1227
+ {
1228
+ "bytes_sent": bytes_sent,
1229
+ "bytes_recv": bytes_recv,
1230
+ "run_id": sint64_run_id,
1231
+ },
1232
+ )
1295
1233
 
1296
1234
  if not rows:
1297
1235
  raise ValueError(f"Run {run_id} not found")
@@ -1299,62 +1237,22 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1299
1237
  def add_clientapp_runtime(self, run_id: int, runtime: float) -> None:
1300
1238
  """Add ClientApp runtime to the cumulative total for the specified `run_id`."""
1301
1239
  sint64_run_id = uint64_to_int64(run_id)
1302
- with self.conn:
1240
+ with self.session():
1303
1241
  # Check if run exists, performing the update only if it does
1304
1242
  update_query = """
1305
1243
  UPDATE run
1306
- SET clientapp_runtime = clientapp_runtime + ?
1307
- WHERE run_id = ?
1308
- RETURNING run_id;
1244
+ SET clientapp_runtime = clientapp_runtime + :runtime
1245
+ WHERE run_id = :run_id
1246
+ RETURNING run_id
1309
1247
  """
1310
- rows = self.conn.execute(update_query, (runtime, sint64_run_id)).fetchall()
1248
+ rows = self.query(
1249
+ update_query, {"runtime": runtime, "run_id": sint64_run_id}
1250
+ )
1311
1251
 
1312
1252
  if not rows:
1313
1253
  raise ValueError(f"Run {run_id} not found")
1314
1254
 
1315
1255
 
1316
- def message_to_dict(message: Message) -> dict[str, Any]:
1317
- """Transform Message to dict."""
1318
- result = {
1319
- "message_id": message.metadata.message_id,
1320
- "group_id": message.metadata.group_id,
1321
- "run_id": message.metadata.run_id,
1322
- "src_node_id": message.metadata.src_node_id,
1323
- "dst_node_id": message.metadata.dst_node_id,
1324
- "reply_to_message_id": message.metadata.reply_to_message_id,
1325
- "created_at": message.metadata.created_at,
1326
- "delivered_at": message.metadata.delivered_at,
1327
- "ttl": message.metadata.ttl,
1328
- "message_type": message.metadata.message_type,
1329
- "content": None,
1330
- "error": None,
1331
- }
1332
-
1333
- if message.has_content():
1334
- result["content"] = recorddict_to_proto(message.content).SerializeToString()
1335
- else:
1336
- result["error"] = error_to_proto(message.error).SerializeToString()
1337
-
1338
- return result
1339
-
1340
-
1341
- def dict_to_message(message_dict: dict[str, Any]) -> Message:
1342
- """Transform dict to Message."""
1343
- content, error = None, None
1344
- if (b_content := message_dict.pop("content")) is not None:
1345
- content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
1346
- if (b_error := message_dict.pop("error")) is not None:
1347
- error = error_from_proto(ProtoError.FromString(b_error))
1348
-
1349
- # Metadata constructor doesn't allow passing created_at. We set it later
1350
- metadata = Metadata(
1351
- **{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
1352
- )
1353
- msg = make_message(metadata=metadata, content=content, error=error)
1354
- msg.metadata.delivered_at = message_dict["delivered_at"]
1355
- return msg
1356
-
1357
-
1358
1256
  def determine_run_status(row: dict[str, Any]) -> str:
1359
1257
  """Determine the status of the run based on timestamp fields."""
1360
1258
  if row["pending_at"]:
@@ -1366,4 +1264,4 @@ def determine_run_status(row: dict[str, Any]) -> str:
1366
1264
  return Status.STARTING
1367
1265
  return Status.PENDING
1368
1266
  run_id = int64_to_uint64(row["run_id"])
1369
- raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")
1267
+ raise ValueError(f"The run {run_id} does not have a valid status.")