flwr 1.24.0__py3-none-any.whl → 1.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204) hide show
  1. flwr/__init__.py +1 -1
  2. flwr/app/__init__.py +4 -1
  3. flwr/app/message_type.py +29 -0
  4. flwr/app/metadata.py +5 -2
  5. flwr/app/user_config.py +19 -0
  6. flwr/cli/app.py +37 -19
  7. flwr/cli/app_cmd/publish.py +25 -75
  8. flwr/cli/app_cmd/review.py +25 -66
  9. flwr/cli/auth_plugin/auth_plugin.py +5 -10
  10. flwr/cli/auth_plugin/noop_auth_plugin.py +1 -2
  11. flwr/cli/auth_plugin/oidc_cli_plugin.py +38 -38
  12. flwr/cli/build.py +15 -28
  13. flwr/cli/config/__init__.py +21 -0
  14. flwr/cli/config/ls.py +71 -0
  15. flwr/cli/config_migration.py +297 -0
  16. flwr/cli/config_utils.py +63 -156
  17. flwr/cli/constant.py +71 -0
  18. flwr/cli/federation/__init__.py +0 -2
  19. flwr/cli/federation/ls.py +256 -64
  20. flwr/cli/flower_config.py +429 -0
  21. flwr/cli/install.py +23 -62
  22. flwr/cli/log.py +23 -37
  23. flwr/cli/login/login.py +29 -63
  24. flwr/cli/ls.py +72 -61
  25. flwr/cli/new/new.py +98 -309
  26. flwr/cli/pull.py +19 -37
  27. flwr/cli/run/run.py +87 -100
  28. flwr/cli/run_utils.py +23 -5
  29. flwr/cli/stop.py +33 -74
  30. flwr/cli/supernode/ls.py +35 -62
  31. flwr/cli/supernode/register.py +31 -80
  32. flwr/cli/supernode/unregister.py +24 -70
  33. flwr/cli/typing.py +200 -0
  34. flwr/cli/utils.py +160 -412
  35. flwr/client/grpc_adapter_client/connection.py +2 -2
  36. flwr/client/grpc_rere_client/connection.py +9 -6
  37. flwr/client/grpc_rere_client/grpc_adapter.py +1 -1
  38. flwr/client/message_handler/message_handler.py +2 -1
  39. flwr/client/mod/centraldp_mods.py +1 -1
  40. flwr/client/mod/localdp_mod.py +1 -1
  41. flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
  42. flwr/client/rest_client/connection.py +6 -4
  43. flwr/client/run_info_store.py +2 -1
  44. flwr/clientapp/client_app.py +2 -1
  45. flwr/common/__init__.py +3 -2
  46. flwr/common/args.py +5 -5
  47. flwr/common/config.py +12 -17
  48. flwr/common/constant.py +3 -16
  49. flwr/common/context.py +2 -1
  50. flwr/common/exit/exit.py +4 -4
  51. flwr/common/exit/exit_code.py +6 -0
  52. flwr/common/grpc.py +2 -1
  53. flwr/common/logger.py +1 -1
  54. flwr/common/message.py +1 -1
  55. flwr/common/retry_invoker.py +13 -5
  56. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -2
  57. flwr/common/serde.py +13 -5
  58. flwr/common/telemetry.py +1 -1
  59. flwr/common/typing.py +10 -3
  60. flwr/compat/client/app.py +6 -9
  61. flwr/compat/client/grpc_client/connection.py +2 -1
  62. flwr/compat/common/constant.py +29 -0
  63. flwr/compat/server/app.py +1 -1
  64. flwr/proto/clientappio_pb2.py +2 -2
  65. flwr/proto/clientappio_pb2_grpc.py +104 -88
  66. flwr/proto/clientappio_pb2_grpc.pyi +140 -80
  67. flwr/proto/federation_pb2.py +5 -3
  68. flwr/proto/federation_pb2.pyi +32 -2
  69. flwr/proto/fleet_pb2.py +10 -10
  70. flwr/proto/fleet_pb2.pyi +5 -1
  71. flwr/proto/run_pb2.py +18 -26
  72. flwr/proto/run_pb2.pyi +10 -58
  73. flwr/proto/serverappio_pb2.py +2 -2
  74. flwr/proto/serverappio_pb2_grpc.py +138 -207
  75. flwr/proto/serverappio_pb2_grpc.pyi +189 -155
  76. flwr/proto/simulationio_pb2.py +2 -2
  77. flwr/proto/simulationio_pb2_grpc.py +62 -90
  78. flwr/proto/simulationio_pb2_grpc.pyi +95 -55
  79. flwr/server/app.py +7 -13
  80. flwr/server/compat/grid_client_proxy.py +2 -1
  81. flwr/server/grid/grpc_grid.py +5 -5
  82. flwr/server/serverapp/app.py +11 -4
  83. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -1
  84. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +13 -12
  85. flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
  86. flwr/server/superlink/linkstate/__init__.py +2 -2
  87. flwr/server/superlink/linkstate/in_memory_linkstate.py +36 -10
  88. flwr/server/superlink/linkstate/linkstate.py +34 -21
  89. flwr/server/superlink/linkstate/linkstate_factory.py +16 -8
  90. flwr/server/superlink/linkstate/{sqlite_linkstate.py → sql_linkstate.py} +471 -516
  91. flwr/server/superlink/linkstate/utils.py +49 -2
  92. flwr/server/superlink/serverappio/serverappio_servicer.py +1 -33
  93. flwr/server/superlink/simulation/simulationio_servicer.py +0 -19
  94. flwr/server/utils/validator.py +1 -1
  95. flwr/server/workflow/default_workflows.py +2 -1
  96. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
  97. flwr/serverapp/strategy/bulyan.py +7 -1
  98. flwr/serverapp/strategy/dp_fixed_clipping.py +9 -1
  99. flwr/serverapp/strategy/fedavg.py +1 -1
  100. flwr/serverapp/strategy/fedxgb_cyclic.py +1 -1
  101. flwr/simulation/ray_transport/ray_client_proxy.py +2 -6
  102. flwr/simulation/run_simulation.py +3 -12
  103. flwr/simulation/simulationio_connection.py +3 -3
  104. flwr/{common → supercore}/address.py +7 -33
  105. flwr/supercore/app_utils.py +2 -1
  106. flwr/supercore/constant.py +27 -2
  107. flwr/supercore/corestate/{sqlite_corestate.py → sql_corestate.py} +19 -23
  108. flwr/supercore/credential_store/__init__.py +33 -0
  109. flwr/supercore/credential_store/credential_store.py +34 -0
  110. flwr/supercore/credential_store/file_credential_store.py +76 -0
  111. flwr/{common → supercore}/date.py +0 -11
  112. flwr/supercore/ffs/disk_ffs.py +1 -1
  113. flwr/supercore/object_store/object_store_factory.py +14 -6
  114. flwr/supercore/object_store/{sqlite_object_store.py → sql_object_store.py} +115 -117
  115. flwr/supercore/sql_mixin.py +315 -0
  116. flwr/{cli/new/templates → supercore/state}/__init__.py +2 -2
  117. flwr/{cli/new/templates/app/code/flwr_tune → supercore/state/alembic}/__init__.py +2 -2
  118. flwr/supercore/state/alembic/env.py +103 -0
  119. flwr/supercore/state/alembic/script.py.mako +43 -0
  120. flwr/supercore/state/alembic/utils.py +239 -0
  121. flwr/{cli/new/templates/app → supercore/state/alembic/versions}/__init__.py +2 -2
  122. flwr/supercore/state/alembic/versions/rev_2026_01_28_initialize_migration_of_state_tables.py +200 -0
  123. flwr/supercore/state/schema/README.md +121 -0
  124. flwr/{cli/new/templates/app/code → supercore/state/schema}/__init__.py +2 -2
  125. flwr/supercore/state/schema/corestate_tables.py +36 -0
  126. flwr/supercore/state/schema/linkstate_tables.py +152 -0
  127. flwr/supercore/state/schema/objectstore_tables.py +90 -0
  128. flwr/supercore/superexec/run_superexec.py +2 -2
  129. flwr/supercore/utils.py +225 -0
  130. flwr/superlink/federation/federation_manager.py +2 -2
  131. flwr/superlink/federation/noop_federation_manager.py +8 -6
  132. flwr/superlink/servicer/control/control_grpc.py +2 -0
  133. flwr/superlink/servicer/control/control_servicer.py +106 -21
  134. flwr/supernode/cli/flower_supernode.py +2 -1
  135. flwr/supernode/nodestate/in_memory_nodestate.py +62 -1
  136. flwr/supernode/nodestate/nodestate.py +45 -0
  137. flwr/supernode/runtime/run_clientapp.py +14 -14
  138. flwr/supernode/servicer/clientappio/clientappio_servicer.py +13 -5
  139. flwr/supernode/start_client_internal.py +17 -10
  140. {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/METADATA +8 -8
  141. {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/RECORD +144 -184
  142. flwr/cli/federation/show.py +0 -317
  143. flwr/cli/new/templates/app/.gitignore.tpl +0 -163
  144. flwr/cli/new/templates/app/LICENSE.tpl +0 -202
  145. flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
  146. flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
  147. flwr/cli/new/templates/app/README.md.tpl +0 -37
  148. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
  149. flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
  150. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
  151. flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
  152. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
  153. flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
  154. flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
  155. flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
  156. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
  157. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
  158. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
  159. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
  160. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
  161. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
  162. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
  163. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
  164. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
  165. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
  166. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
  167. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
  168. flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
  169. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
  170. flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
  171. flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
  172. flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
  173. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
  174. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
  175. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
  176. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
  177. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
  178. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
  179. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
  180. flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
  181. flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
  182. flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
  183. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -99
  184. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
  185. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
  186. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
  187. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
  188. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
  189. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
  190. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
  191. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
  192. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
  193. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
  194. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
  195. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
  196. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
  197. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
  198. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
  199. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
  200. flwr/common/pyproject.py +0 -42
  201. flwr/supercore/sqlite_mixin.py +0 -159
  202. /flwr/{common → supercore}/version.py +0 -0
  203. {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/WHEEL +0 -0
  204. {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/entry_points.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2026 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,19 +12,22 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """SQLite based implemenation of the link state."""
15
+ """SQLAlchemy-based implementation of the link state."""
16
16
 
17
17
 
18
18
  # pylint: disable=too-many-lines
19
19
 
20
20
  import json
21
- import sqlite3
22
21
  from collections.abc import Sequence
23
22
  from datetime import datetime, timezone
24
23
  from logging import ERROR, WARNING
25
- from typing import Any, cast
24
+ from typing import Any
26
25
 
27
- from flwr.common import Context, Message, Metadata, log, now
26
+ from sqlalchemy import MetaData
27
+ from sqlalchemy.exc import IntegrityError
28
+
29
+ from flwr.app.user_config import UserConfig
30
+ from flwr.common import Context, Message, log, now
28
31
  from flwr.common.constant import (
29
32
  HEARTBEAT_PATIENCE,
30
33
  MESSAGE_TTL_TOLERANCE,
@@ -35,22 +38,15 @@ from flwr.common.constant import (
35
38
  Status,
36
39
  SubStatus,
37
40
  )
38
- from flwr.common.message import make_message
39
41
  from flwr.common.record import ConfigRecord
40
- from flwr.common.serde import recorddict_from_proto, recorddict_to_proto
41
- from flwr.common.serde_utils import error_from_proto, error_to_proto
42
- from flwr.common.typing import Run, RunStatus, UserConfig
43
-
44
- # pylint: disable=E0611
45
- from flwr.proto.error_pb2 import Error as ProtoError
46
- from flwr.proto.node_pb2 import NodeInfo
47
- from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
48
-
49
- # pylint: enable=E0611
42
+ from flwr.common.typing import Run, RunStatus
43
+ from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
50
44
  from flwr.server.utils.validator import validate_message
51
45
  from flwr.supercore.constant import NodeStatus
52
- from flwr.supercore.corestate.sqlite_corestate import SqliteCoreState
46
+ from flwr.supercore.corestate.sql_corestate import SqlCoreState
53
47
  from flwr.supercore.object_store.object_store import ObjectStore
48
+ from flwr.supercore.state.schema.corestate_tables import create_corestate_metadata
49
+ from flwr.supercore.state.schema.linkstate_tables import create_linkstate_metadata
54
50
  from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
55
51
  from flwr.superlink.federation import FederationManager
56
52
 
@@ -63,125 +59,18 @@ from .utils import (
63
59
  context_to_bytes,
64
60
  convert_sint64_values_in_dict_to_uint64,
65
61
  convert_uint64_values_in_dict_to_sint64,
62
+ dict_to_message,
66
63
  generate_rand_int_from_bytes,
67
64
  has_valid_sub_status,
68
65
  is_valid_transition,
66
+ message_to_dict,
69
67
  verify_found_message_replies,
70
68
  verify_message_ids,
71
69
  )
72
70
 
73
- SQL_CREATE_TABLE_NODE = """
74
- CREATE TABLE IF NOT EXISTS node(
75
- node_id INTEGER UNIQUE,
76
- owner_aid TEXT,
77
- owner_name TEXT,
78
- status TEXT,
79
- registered_at TEXT,
80
- last_activated_at TEXT NULL,
81
- last_deactivated_at TEXT NULL,
82
- unregistered_at TEXT NULL,
83
- online_until TIMESTAMP NULL,
84
- heartbeat_interval REAL,
85
- public_key BLOB UNIQUE
86
- );
87
- """
88
-
89
- SQL_CREATE_TABLE_PUBLIC_KEY = """
90
- CREATE TABLE IF NOT EXISTS public_key(
91
- public_key BLOB PRIMARY KEY
92
- );
93
- """
94
-
95
- SQL_CREATE_INDEX_ONLINE_UNTIL = """
96
- CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
97
- """
98
-
99
- SQL_CREATE_INDEX_OWNER_AID = """
100
- CREATE INDEX IF NOT EXISTS idx_node_owner_aid ON node(owner_aid);
101
- """
102
-
103
- SQL_CREATE_INDEX_NODE_STATUS = """
104
- CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
105
- """
106
-
107
- SQL_CREATE_TABLE_RUN = """
108
- CREATE TABLE IF NOT EXISTS run(
109
- run_id INTEGER UNIQUE,
110
- fab_id TEXT,
111
- fab_version TEXT,
112
- fab_hash TEXT,
113
- override_config TEXT,
114
- pending_at TEXT,
115
- starting_at TEXT,
116
- running_at TEXT,
117
- finished_at TEXT,
118
- sub_status TEXT,
119
- details TEXT,
120
- federation TEXT,
121
- federation_options BLOB,
122
- flwr_aid TEXT
123
- );
124
- """
125
-
126
- SQL_CREATE_TABLE_LOGS = """
127
- CREATE TABLE IF NOT EXISTS logs (
128
- timestamp REAL,
129
- run_id INTEGER,
130
- node_id INTEGER,
131
- log TEXT,
132
- PRIMARY KEY (timestamp, run_id, node_id),
133
- FOREIGN KEY (run_id) REFERENCES run(run_id)
134
- );
135
- """
136
-
137
- SQL_CREATE_TABLE_CONTEXT = """
138
- CREATE TABLE IF NOT EXISTS context(
139
- run_id INTEGER UNIQUE,
140
- context BLOB,
141
- FOREIGN KEY(run_id) REFERENCES run(run_id)
142
- );
143
- """
144
-
145
- SQL_CREATE_TABLE_MESSAGE_INS = """
146
- CREATE TABLE IF NOT EXISTS message_ins(
147
- message_id TEXT UNIQUE,
148
- group_id TEXT,
149
- run_id INTEGER,
150
- src_node_id INTEGER,
151
- dst_node_id INTEGER,
152
- reply_to_message_id TEXT,
153
- created_at REAL,
154
- delivered_at TEXT,
155
- ttl REAL,
156
- message_type TEXT,
157
- content BLOB NULL,
158
- error BLOB NULL,
159
- FOREIGN KEY(run_id) REFERENCES run(run_id)
160
- );
161
- """
162
-
163
-
164
- SQL_CREATE_TABLE_MESSAGE_RES = """
165
- CREATE TABLE IF NOT EXISTS message_res(
166
- message_id TEXT UNIQUE,
167
- group_id TEXT,
168
- run_id INTEGER,
169
- src_node_id INTEGER,
170
- dst_node_id INTEGER,
171
- reply_to_message_id TEXT,
172
- created_at REAL,
173
- delivered_at TEXT,
174
- ttl REAL,
175
- message_type TEXT,
176
- content BLOB NULL,
177
- error BLOB NULL,
178
- FOREIGN KEY(run_id) REFERENCES run(run_id)
179
- );
180
- """
181
-
182
-
183
- class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
184
- """SQLite-based LinkState implementation."""
71
+
72
+ class SqlLinkState(LinkState, SqlCoreState): # pylint: disable=R0904
73
+ """SQLAlchemy-based LinkState implementation."""
185
74
 
186
75
  def __init__(
187
76
  self,
@@ -193,24 +82,21 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
193
82
  federation_manager.linkstate = self
194
83
  self._federation_manager = federation_manager
195
84
 
196
- def get_sql_statements(self) -> tuple[str, ...]:
197
- """Return SQL statements for LinkState tables."""
198
- return super().get_sql_statements() + (
199
- SQL_CREATE_TABLE_RUN,
200
- SQL_CREATE_TABLE_LOGS,
201
- SQL_CREATE_TABLE_CONTEXT,
202
- SQL_CREATE_TABLE_MESSAGE_INS,
203
- SQL_CREATE_TABLE_MESSAGE_RES,
204
- SQL_CREATE_TABLE_NODE,
205
- SQL_CREATE_TABLE_PUBLIC_KEY,
206
- SQL_CREATE_INDEX_ONLINE_UNTIL,
207
- SQL_CREATE_INDEX_OWNER_AID,
208
- SQL_CREATE_INDEX_NODE_STATUS,
209
- )
85
+ def get_metadata(self) -> MetaData:
86
+ """Return combined SQLAlchemy MetaData for LinkState and CoreState tables."""
87
+ # Start with linkstate tables
88
+ metadata = create_linkstate_metadata()
89
+
90
+ # Add corestate tables (token_store)
91
+ corestate_metadata = create_corestate_metadata()
92
+ for table in corestate_metadata.tables.values():
93
+ table.to_metadata(metadata)
94
+
95
+ return metadata
210
96
 
211
97
  @property
212
98
  def federation_manager(self) -> FederationManager:
213
- """Get the FederationManager instance."""
99
+ """Return the FederationManager instance."""
214
100
  return self._federation_manager
215
101
 
216
102
  def store_message_ins(self, message: Message) -> str | None:
@@ -238,20 +124,26 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
238
124
  )
239
125
  return None
240
126
 
241
- with self.conn:
127
+ with self.session():
242
128
  # Validate run_id
243
- query = "SELECT federation FROM run WHERE run_id = ?;"
244
- rows = self.conn.execute(query, (data[0]["run_id"],)).fetchall()
129
+ query = "SELECT federation FROM run WHERE run_id = :run_id"
130
+ rows = self.query(query, {"run_id": data[0]["run_id"]})
245
131
  if not rows:
246
132
  log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
247
133
  return None
248
134
  federation: str = rows[0]["federation"]
249
135
 
250
136
  # Validate destination node ID
251
- query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
252
- rows = self.conn.execute(
253
- query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
254
- ).fetchall()
137
+ query = """SELECT node_id FROM node WHERE node_id = :node_id
138
+ AND status IN (:online, :offline)"""
139
+ rows = self.query(
140
+ query,
141
+ {
142
+ "node_id": data[0]["dst_node_id"],
143
+ "online": NodeStatus.ONLINE,
144
+ "offline": NodeStatus.OFFLINE,
145
+ },
146
+ )
255
147
  if not rows or not self.federation_manager.has_node(
256
148
  message.metadata.dst_node_id, federation
257
149
  ):
@@ -262,29 +154,62 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
262
154
  )
263
155
  return None
264
156
 
157
+ # Insert message
265
158
  columns = ", ".join([f":{key}" for key in data[0]])
266
- query = f"INSERT INTO message_ins VALUES({columns});"
159
+ query = f"INSERT INTO message_ins VALUES({columns})"
267
160
 
268
161
  # Only invalid run_id can trigger IntegrityError.
269
162
  # This may need to be changed in the future version
270
163
  # with more integrity checks.
271
- self.conn.execute(query, data[0])
164
+ self.query(query, data[0])
272
165
 
273
166
  return message.metadata.message_id
274
167
 
168
+ # pylint: disable-next=too-many-locals
275
169
  def _check_stored_messages(self, message_ids: set[str]) -> None:
276
170
  """Check and delete the message if it's invalid."""
277
171
  if not message_ids:
278
172
  return
279
173
 
280
- with self.conn:
174
+ with self.session():
175
+ # Batch fetch all messages in one query
176
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
177
+ query = f"""
178
+ SELECT * FROM message_ins
179
+ WHERE message_id IN ({placeholders})
180
+ """
181
+ params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
182
+ message_rows = self.query(query, params)
183
+
184
+ if not message_rows:
185
+ return
186
+
187
+ # Build message lookup dict
188
+ message_dict: dict[str, dict[str, Any]] = {
189
+ row["message_id"]: row for row in message_rows
190
+ }
191
+
192
+ # Collect unique run_ids for batch federation lookup
193
+ run_ids = {row["run_id"] for row in message_rows}
194
+ placeholders = ",".join([f":rid_{i}" for i in range(len(run_ids))])
195
+ query = f"""
196
+ SELECT run_id, federation FROM run
197
+ WHERE run_id IN ({placeholders})
198
+ """
199
+ params = {f"rid_{i}": rid for i, rid in enumerate(run_ids)}
200
+ run_rows = self.query(query, params)
201
+
202
+ # Build run_id to federation mapping
203
+ run_id_to_federation: dict[int, str] = {
204
+ row["run_id"]: row["federation"] for row in run_rows
205
+ }
206
+
281
207
  invalid_msg_ids: set[str] = set()
282
208
  current_time = now().timestamp()
283
209
 
210
+ # Check each message for validity
284
211
  for msg_id in message_ids:
285
- # Check if message exists
286
- query = "SELECT * FROM message_ins WHERE message_id = ?;"
287
- message_row = self.conn.execute(query, (msg_id,)).fetchone()
212
+ message_row = message_dict.get(msg_id)
288
213
  if not message_row:
289
214
  continue
290
215
 
@@ -294,15 +219,12 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
294
219
  invalid_msg_ids.add(msg_id)
295
220
  continue
296
221
 
297
- # Check if src_node_id and dst_node_id are in the federation
298
- # Get federation from run table
222
+ # Check if run exists and get federation
299
223
  run_id = message_row["run_id"]
300
- query = "SELECT federation FROM run WHERE run_id = ?;"
301
- run_row = self.conn.execute(query, (run_id,)).fetchone()
302
- if not run_row: # This should not happen
224
+ federation = run_id_to_federation.get(run_id)
225
+ if not federation:
303
226
  invalid_msg_ids.add(msg_id)
304
227
  continue
305
- federation = run_row["federation"]
306
228
 
307
229
  # Convert sint64 to uint64 for node IDs
308
230
  src_node_id = int64_to_uint64(message_row["src_node_id"])
@@ -327,52 +249,48 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
327
249
  msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
328
250
  raise AssertionError(msg)
329
251
 
330
- data: dict[str, str | int] = {}
252
+ params: dict[str, str | int] = {}
331
253
 
332
254
  # Convert the uint64 value to sint64 for SQLite
333
- data["node_id"] = uint64_to_int64(node_id)
255
+ params["node_id"] = uint64_to_int64(node_id)
334
256
 
335
- with self.conn:
257
+ with self.session():
336
258
  # Retrieve all Messages for node_id
337
259
  query = """
338
260
  SELECT message_id
339
261
  FROM message_ins
340
- WHERE dst_node_id == :node_id
341
- AND delivered_at = ""
342
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
262
+ WHERE dst_node_id = :node_id
263
+ AND delivered_at = ''
264
+ AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
343
265
  """
344
266
 
345
267
  if limit is not None:
346
268
  query += " LIMIT :limit"
347
- data["limit"] = limit
269
+ params["limit"] = limit
348
270
 
349
- query += ";"
350
-
351
- rows = self.conn.execute(query, data).fetchall()
271
+ rows = self.query(query, params)
352
272
  message_ids: set[str] = {row["message_id"] for row in rows}
353
273
  self._check_stored_messages(message_ids)
354
274
 
355
275
  # Mark retrieved Messages as delivered
356
276
  if rows:
357
277
  # Prepare query
358
- placeholders: str = ",".join(
359
- [f":id_{i}" for i in range(len(message_ids))]
360
- )
278
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
361
279
  query = f"""
362
280
  UPDATE message_ins
363
281
  SET delivered_at = :delivered_at
364
282
  WHERE message_id IN ({placeholders})
365
- RETURNING *;
283
+ RETURNING *
366
284
  """
367
285
 
368
286
  # Prepare data for query
369
287
  delivered_at = now().isoformat()
370
- data = {"delivered_at": delivered_at}
288
+ params = {"delivered_at": delivered_at}
371
289
  for index, msg_id in enumerate(message_ids):
372
- data[f"id_{index}"] = str(msg_id)
290
+ params[f"mid_{index}"] = str(msg_id)
373
291
 
374
292
  # Run query
375
- rows = self.conn.execute(query, data).fetchall()
293
+ rows = self.query(query, params)
376
294
 
377
295
  for row in rows:
378
296
  # Convert values from sint64 to uint64
@@ -380,7 +298,7 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
380
298
  row, ["run_id", "src_node_id", "dst_node_id"]
381
299
  )
382
300
 
383
- result = [dict_to_message(row) for row in rows]
301
+ result = [dict_to_message(dict(row)) for row in rows]
384
302
 
385
303
  return result
386
304
 
@@ -406,8 +324,8 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
406
324
  )
407
325
  return None
408
326
 
409
- # Ensure that the dst_node_id of the original message matches the src_node_id of
410
- # reply being processed.
327
+ # Ensure that the dst_node_id of the original message matches the src_node_id
328
+ # of reply being processed.
411
329
  if (
412
330
  msg_ins
413
331
  and message
@@ -437,21 +355,19 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
437
355
  return None
438
356
 
439
357
  # Store Message
440
- data = (message_to_dict(message),)
358
+ msg_dict = message_to_dict(message)
441
359
 
442
360
  # Convert values from uint64 to sint64 for SQLite
443
361
  convert_uint64_values_in_dict_to_sint64(
444
- data[0], ["run_id", "src_node_id", "dst_node_id"]
362
+ msg_dict, ["run_id", "src_node_id", "dst_node_id"]
445
363
  )
446
364
 
447
- columns = ", ".join([f":{key}" for key in data[0]])
448
- query = f"INSERT INTO message_res VALUES({columns});"
365
+ columns = ", ".join([f":{key}" for key in msg_dict])
366
+ query = f"INSERT INTO message_res VALUES({columns})"
449
367
 
450
- # Only invalid run_id can trigger IntegrityError.
451
- # This may need to be changed in the future version with more integrity checks.
452
368
  try:
453
- self.query(query, data)
454
- except sqlite3.IntegrityError:
369
+ self.query(query, msg_dict)
370
+ except IntegrityError:
455
371
  log(ERROR, "`run` is invalid")
456
372
  return None
457
373
 
@@ -459,21 +375,21 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
459
375
 
460
376
  def get_message_res(self, message_ids: set[str]) -> list[Message]:
461
377
  """Get reply Messages for the given Message IDs."""
462
- # pylint: disable-msg=too-many-locals
378
+ # pylint: disable=too-many-locals
463
379
  ret: dict[str, Message] = {}
464
380
 
465
- with self.conn:
381
+ with self.session():
466
382
  # Verify Message IDs
467
383
  self._check_stored_messages(message_ids)
468
384
  current = now().timestamp()
385
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
469
386
  query = f"""
470
387
  SELECT *
471
388
  FROM message_ins
472
- WHERE message_id IN ({','.join(['?'] * len(message_ids))});
389
+ WHERE message_id IN ({placeholders})
473
390
  """
474
- rows = self.conn.execute(
475
- query, tuple(str(message_id) for message_id in message_ids)
476
- ).fetchall()
391
+ params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
392
+ rows = self.query(query, params)
477
393
  found_message_ins_dict: dict[str, Message] = {}
478
394
  for row in rows:
479
395
  convert_sint64_values_in_dict_to_uint64(
@@ -493,15 +409,18 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
493
409
  in_message = found_message_ins_dict[message_id]
494
410
  sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
495
411
  dst_node_ids.add(sint_node_id)
412
+ placeholders = ",".join([f":nid_{i}" for i in range(len(dst_node_ids))])
496
413
  query = f"""
497
414
  SELECT node_id, online_until
498
415
  FROM node
499
- WHERE node_id IN ({','.join(['?'] * len(dst_node_ids))})
500
- AND status != ?
416
+ WHERE node_id IN ({placeholders})
417
+ AND status != :unregistered
501
418
  """
502
- rows = self.conn.execute(
503
- query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,)
504
- ).fetchall()
419
+ node_params: dict[str, int | str] = {
420
+ f"nid_{i}": nid for i, nid in enumerate(dst_node_ids)
421
+ }
422
+ node_params["unregistered"] = NodeStatus.UNREGISTERED
423
+ rows = self.query(query, node_params)
505
424
  tmp_ret_dict = check_node_availability_for_in_message(
506
425
  inquired_in_message_ids=message_ids,
507
426
  found_in_message_dict=found_message_ins_dict,
@@ -513,15 +432,15 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
513
432
  ret.update(tmp_ret_dict)
514
433
 
515
434
  # Find all reply Messages
435
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
516
436
  query = f"""
517
437
  SELECT *
518
438
  FROM message_res
519
- WHERE reply_to_message_id IN ({','.join(['?'] * len(message_ids))})
520
- AND delivered_at = "";
439
+ WHERE reply_to_message_id IN ({placeholders})
440
+ AND delivered_at = ''
521
441
  """
522
- rows = self.conn.execute(
523
- query, tuple(str(message_id) for message_id in message_ids)
524
- ).fetchall()
442
+ params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
443
+ rows = self.query(query, params)
525
444
  for row in rows:
526
445
  convert_sint64_values_in_dict_to_uint64(
527
446
  row, ["run_id", "src_node_id", "dst_node_id"]
@@ -541,13 +460,15 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
541
460
  message_res_ids = [
542
461
  message_res.metadata.message_id for message_res in ret.values()
543
462
  ]
463
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_res_ids))])
544
464
  query = f"""
545
465
  UPDATE message_res
546
- SET delivered_at = ?
547
- WHERE message_id IN ({','.join(['?'] * len(message_res_ids))});
466
+ SET delivered_at = :delivered_at
467
+ WHERE message_id IN ({placeholders})
548
468
  """
549
- data: list[Any] = [delivered_at] + message_res_ids
550
- self.conn.execute(query, data)
469
+ params = {"delivered_at": delivered_at}
470
+ params.update({f"mid_{i}": mid for i, mid in enumerate(message_res_ids)})
471
+ self.query(query, params)
551
472
 
552
473
  return list(ret.values())
553
474
 
@@ -556,64 +477,55 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
556
477
 
557
478
  This includes delivered but not yet deleted.
558
479
  """
559
- query = "SELECT count(*) AS num FROM message_ins;"
560
- rows = self.query(query)
561
- result = rows[0]
562
- num = cast(int, result["num"])
563
- return num
480
+ query = "SELECT count(*) AS num FROM message_ins"
481
+ rows = self.query(query, {})
482
+ return int(rows[0]["num"])
564
483
 
565
484
  def num_message_res(self) -> int:
566
485
  """Calculate the number of reply Messages in store.
567
486
 
568
487
  This includes delivered but not yet deleted.
569
488
  """
570
- query = "SELECT count(*) AS num FROM message_res;"
489
+ query = "SELECT count(*) AS num FROM message_res"
571
490
  rows = self.query(query)
572
- result: dict[str, int] = rows[0]
573
- return result["num"]
491
+ return int(rows[0]["num"])
574
492
 
575
493
  def delete_messages(self, message_ins_ids: set[str]) -> None:
576
494
  """Delete a Message and its reply based on provided Message IDs."""
577
495
  if not message_ins_ids:
578
496
  return
579
- if self.conn is None:
580
- raise AttributeError("LinkState not initialized")
581
497
 
582
- placeholders = ",".join(["?"] * len(message_ins_ids))
583
- data = tuple(str(message_id) for message_id in message_ins_ids)
498
+ placeholders = ",".join([f":mid_{i}" for i in range(len(message_ins_ids))])
499
+ params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ins_ids)}
584
500
 
585
501
  # Delete Message
586
502
  query_1 = f"""
587
503
  DELETE FROM message_ins
588
- WHERE message_id IN ({placeholders});
504
+ WHERE message_id IN ({placeholders})
589
505
  """
590
506
 
591
507
  # Delete reply Message
592
508
  query_2 = f"""
593
509
  DELETE FROM message_res
594
- WHERE reply_to_message_id IN ({placeholders});
510
+ WHERE reply_to_message_id IN ({placeholders})
595
511
  """
596
512
 
597
- with self.conn:
598
- self.conn.execute(query_1, data)
599
- self.conn.execute(query_2, data)
513
+ with self.session():
514
+ self.query(query_1, params)
515
+ self.query(query_2, params)
600
516
 
601
517
  def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
602
518
  """Get all instruction Message IDs for the given run_id."""
603
- if self.conn is None:
604
- raise AttributeError("LinkState not initialized")
605
-
606
519
  query = """
607
520
  SELECT message_id
608
521
  FROM message_ins
609
- WHERE run_id = :run_id;
522
+ WHERE run_id = :run_id
610
523
  """
611
-
612
524
  sint64_run_id = uint64_to_int64(run_id)
613
- data = {"run_id": sint64_run_id}
525
+ params = {"run_id": sint64_run_id}
614
526
 
615
- with self.conn:
616
- rows = self.conn.execute(query, data).fetchall()
527
+ with self.session():
528
+ rows = self.query(query, params)
617
529
 
618
530
  return {row["message_id"] for row in rows}
619
531
 
@@ -638,29 +550,31 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
638
550
  (node_id, owner_aid, owner_name, status, registered_at, last_activated_at,
639
551
  last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
640
552
  public_key)
641
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
553
+ VALUES (:node_id, :owner_aid, :owner_name, :status, :registered_at,
554
+ :last_activated_at, :last_deactivated_at, :unregistered_at, :online_until,
555
+ :heartbeat_interval, :public_key)
642
556
  """
643
557
 
644
558
  # Mark the node online until now().timestamp() + heartbeat_interval
645
559
  try:
646
560
  self.query(
647
561
  query,
648
- (
649
- sint64_node_id, # node_id
650
- owner_aid, # owner_aid
651
- owner_name, # owner_name
652
- NodeStatus.REGISTERED, # status
653
- now().isoformat(), # registered_at
654
- None, # last_activated_at
655
- None, # last_deactivated_at
656
- None, # unregistered_at
657
- None, # online_until, initialized with offline status
658
- heartbeat_interval, # heartbeat_interval
659
- public_key, # public_key
660
- ),
562
+ {
563
+ "node_id": sint64_node_id,
564
+ "owner_aid": owner_aid,
565
+ "owner_name": owner_name,
566
+ "status": NodeStatus.REGISTERED,
567
+ "registered_at": now().isoformat(),
568
+ "last_activated_at": None,
569
+ "last_deactivated_at": None,
570
+ "unregistered_at": None,
571
+ "online_until": None, # initialized with offline status
572
+ "heartbeat_interval": heartbeat_interval,
573
+ "public_key": public_key,
574
+ },
661
575
  )
662
- except sqlite3.IntegrityError as e:
663
- if "UNIQUE constraint failed: node.public_key" in str(e):
576
+ except IntegrityError as e:
577
+ if "node.public_key" in str(e):
664
578
  raise ValueError("Public key already in use.") from None
665
579
  # Must be node ID conflict, almost impossible unless system is compromised
666
580
  log(ERROR, "Unexpected node registration failure.")
@@ -675,21 +589,20 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
675
589
 
676
590
  query = """
677
591
  UPDATE node
678
- SET status = ?, unregistered_at = ?,
679
- online_until = IIF(online_until > ?, ?, online_until)
680
- WHERE node_id = ? AND status != ? AND owner_aid = ?
592
+ SET status = :unregistered, unregistered_at = :unregistered_at,
593
+ online_until = IIF(online_until > :current, :current, online_until)
594
+ WHERE node_id = :node_id AND status != :unregistered
595
+ AND owner_aid = :owner_aid
681
596
  RETURNING node_id
682
597
  """
683
598
  current = now()
684
- params = (
685
- NodeStatus.UNREGISTERED,
686
- current.isoformat(),
687
- current.timestamp(),
688
- current.timestamp(),
689
- sint64_node_id,
690
- NodeStatus.UNREGISTERED,
691
- owner_aid,
692
- )
599
+ params = {
600
+ "unregistered": NodeStatus.UNREGISTERED,
601
+ "unregistered_at": current.isoformat(),
602
+ "current": current.timestamp(),
603
+ "node_id": sint64_node_id,
604
+ "owner_aid": owner_aid,
605
+ }
693
606
 
694
607
  rows = self.query(query, params)
695
608
  if not rows:
@@ -700,58 +613,58 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
700
613
 
701
614
  def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
702
615
  """Activate the node with the specified `node_id`."""
703
- with self.conn:
704
- self._check_and_tag_offline_nodes([node_id])
616
+ self._check_and_tag_offline_nodes([node_id])
705
617
 
706
- # Only activate if the node is currently registered or offline
707
- current_dt = now()
708
- query = """
709
- UPDATE node
710
- SET status = ?,
711
- last_activated_at = ?,
712
- online_until = ?,
713
- heartbeat_interval = ?
714
- WHERE node_id = ? AND status in (?, ?)
715
- RETURNING node_id
716
- """
717
- params = (
718
- NodeStatus.ONLINE,
719
- current_dt.isoformat(),
720
- current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
721
- heartbeat_interval,
722
- uint64_to_int64(node_id),
723
- NodeStatus.REGISTERED,
724
- NodeStatus.OFFLINE,
725
- )
618
+ # Only activate if the node is currently registered or offline
619
+ current_dt = now()
620
+ sint64_node_id = uint64_to_int64(node_id)
621
+ query = """
622
+ UPDATE node
623
+ SET status = :online,
624
+ last_activated_at = :current,
625
+ online_until = :online_until,
626
+ heartbeat_interval = :heartbeat_interval
627
+ WHERE node_id = :node_id AND status IN (:registered, :offline)
628
+ RETURNING node_id
629
+ """
630
+ params = {
631
+ "online": NodeStatus.ONLINE,
632
+ "current": current_dt.isoformat(),
633
+ "online_until": current_dt.timestamp()
634
+ + HEARTBEAT_PATIENCE * heartbeat_interval,
635
+ "heartbeat_interval": heartbeat_interval,
636
+ "node_id": sint64_node_id,
637
+ "registered": NodeStatus.REGISTERED,
638
+ "offline": NodeStatus.OFFLINE,
639
+ }
726
640
 
727
- row = self.conn.execute(query, params).fetchone()
728
- return row is not None
641
+ rows = self.query(query, params)
642
+ return len(rows) > 0
729
643
 
730
644
  def deactivate_node(self, node_id: int) -> bool:
731
645
  """Deactivate the node with the specified `node_id`."""
732
- with self.conn:
733
- self._check_and_tag_offline_nodes([node_id])
646
+ self._check_and_tag_offline_nodes([node_id])
734
647
 
735
- # Only deactivate if the node is currently online
736
- current_dt = now()
737
- query = """
738
- UPDATE node
739
- SET status = ?,
740
- last_deactivated_at = ?,
741
- online_until = ?
742
- WHERE node_id = ? AND status = ?
743
- RETURNING node_id
744
- """
745
- params = (
746
- NodeStatus.OFFLINE,
747
- current_dt.isoformat(),
748
- current_dt.timestamp(),
749
- uint64_to_int64(node_id),
750
- NodeStatus.ONLINE,
751
- )
648
+ # Only deactivate if the node is currently online
649
+ current_dt = now()
650
+ query = """
651
+ UPDATE node
652
+ SET status = :offline,
653
+ last_deactivated_at = :current_iso,
654
+ online_until = :current_ts
655
+ WHERE node_id = :node_id AND status = :online
656
+ RETURNING node_id
657
+ """
658
+ params = {
659
+ "offline": NodeStatus.OFFLINE,
660
+ "current_iso": current_dt.isoformat(),
661
+ "current_ts": current_dt.timestamp(),
662
+ "node_id": uint64_to_int64(node_id),
663
+ "online": NodeStatus.ONLINE,
664
+ }
752
665
 
753
- row = self.conn.execute(query, params).fetchone()
754
- return row is not None
666
+ rows = self.query(query, params)
667
+ return len(rows) > 0
755
668
 
756
669
  def get_nodes(self, run_id: int) -> set[int]:
757
670
  """Retrieve all currently stored node IDs as a set.
@@ -761,16 +674,13 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
761
674
  If the provided `run_id` does not exist or has no matching nodes,
762
675
  an empty `Set` MUST be returned.
763
676
  """
764
- if self.conn is None:
765
- raise AttributeError("LinkState not initialized")
766
-
767
- with self.conn:
677
+ with self.session():
768
678
  # Convert the uint64 value to sint64 for SQLite
769
679
  sint64_run_id = uint64_to_int64(run_id)
770
680
 
771
681
  # Validate run ID
772
- query = "SELECT federation FROM run WHERE run_id = ?"
773
- rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
682
+ query = "SELECT federation FROM run WHERE run_id = :run_id"
683
+ rows = self.query(query, {"run_id": sint64_run_id})
774
684
  if not rows:
775
685
  return set()
776
686
  federation: str = rows[0]["federation"]
@@ -787,23 +697,25 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
787
697
  """Check and tag offline nodes."""
788
698
  # strftime will convert POSIX timestamp to ISO format
789
699
  query = """
790
- UPDATE node SET status = ?,
700
+ UPDATE node SET status = :offline,
791
701
  last_deactivated_at =
792
- strftime("%Y-%m-%dT%H:%M:%f+00:00", online_until, "unixepoch")
793
- WHERE online_until <= ? AND status == ?
702
+ strftime('%Y-%m-%dT%H:%M:%f+00:00', online_until, 'unixepoch')
703
+ WHERE online_until <= :current_time AND status = :online
794
704
  """
795
- params = [
796
- NodeStatus.OFFLINE,
797
- now().timestamp(),
798
- NodeStatus.ONLINE,
799
- ]
705
+ params: dict[str, Any] = {
706
+ "offline": NodeStatus.OFFLINE,
707
+ "current_time": now().timestamp(),
708
+ "online": NodeStatus.ONLINE,
709
+ }
800
710
  if node_ids is not None:
801
- placeholders = ",".join(["?"] * len(node_ids))
711
+ placeholders = ",".join([f":nid_{i}" for i in range(len(node_ids))])
802
712
  query += f" AND node_id IN ({placeholders})"
803
- params.extend(uint64_to_int64(node_id) for node_id in node_ids)
804
- self.conn.execute(query, params)
713
+ params.update(
714
+ {f"nid_{i}": uint64_to_int64(nid) for i, nid in enumerate(node_ids)}
715
+ )
716
+ self.query(query, params)
805
717
 
806
- def get_node_info(
718
+ def get_node_info( # pylint: disable=too-many-locals
807
719
  self,
808
720
  *,
809
721
  node_ids: Sequence[int] | None = None,
@@ -811,32 +723,37 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
811
723
  statuses: Sequence[str] | None = None,
812
724
  ) -> Sequence[NodeInfo]:
813
725
  """Retrieve information about nodes based on the specified filters."""
814
- with self.conn:
726
+ with self.session():
815
727
  self._check_and_tag_offline_nodes()
816
728
 
817
729
  # Build the WHERE clause based on provided filters
818
730
  conditions = []
819
- params: list[Any] = []
731
+ params: dict[str, Any] = {}
820
732
  if node_ids is not None:
821
733
  sint64_node_ids = [uint64_to_int64(node_id) for node_id in node_ids]
822
- placeholders = ",".join(["?"] * len(sint64_node_ids))
734
+ placeholders = ",".join(
735
+ [f":nid_{i}" for i in range(len(sint64_node_ids))]
736
+ )
823
737
  conditions.append(f"node_id IN ({placeholders})")
824
- params.extend(sint64_node_ids)
738
+ for i, nid in enumerate(sint64_node_ids):
739
+ params[f"nid_{i}"] = nid
825
740
  if owner_aids is not None:
826
- placeholders = ",".join(["?"] * len(owner_aids))
741
+ placeholders = ",".join([f":aid_{i}" for i in range(len(owner_aids))])
827
742
  conditions.append(f"owner_aid IN ({placeholders})")
828
- params.extend(owner_aids)
743
+ for i, aid in enumerate(owner_aids):
744
+ params[f"aid_{i}"] = aid
829
745
  if statuses is not None:
830
- placeholders = ",".join(["?"] * len(statuses))
746
+ placeholders = ",".join([f":st_{i}" for i in range(len(statuses))])
831
747
  conditions.append(f"status IN ({placeholders})")
832
- params.extend(statuses)
748
+ for i, status in enumerate(statuses):
749
+ params[f"st_{i}"] = status
833
750
 
834
751
  # Construct the final query
835
752
  query = "SELECT * FROM node"
836
753
  if conditions:
837
754
  query += " WHERE " + " AND ".join(conditions)
838
755
 
839
- rows = self.conn.execute(query, params).fetchall()
756
+ rows = self.query(query, params)
840
757
 
841
758
  result: list[NodeInfo] = []
842
759
  for row in rows:
@@ -846,27 +763,14 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
846
763
 
847
764
  return result
848
765
 
849
- def get_node_public_key(self, node_id: int) -> bytes:
850
- """Get `public_key` for the specified `node_id`."""
851
- # Convert the uint64 value to sint64 for SQLite
852
- sint64_node_id = uint64_to_int64(node_id)
853
-
854
- # Query the public key for the given node_id
855
- query = "SELECT public_key FROM node WHERE node_id = ? AND status != ?;"
856
- rows = self.query(query, (sint64_node_id, NodeStatus.UNREGISTERED))
857
-
858
- # If no result is found, return None
859
- if not rows:
860
- raise ValueError(f"Node ID {node_id} not found")
861
-
862
- # Return the public key
863
- return cast(bytes, rows[0]["public_key"])
864
-
865
766
  def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
866
767
  """Get `node_id` for the specified `public_key` if it exists and is not
867
768
  deleted."""
868
- query = "SELECT node_id FROM node WHERE public_key = ? AND status != ?;"
869
- rows = self.query(query, (public_key, NodeStatus.UNREGISTERED))
769
+ query = """SELECT node_id FROM node
770
+ WHERE public_key = :public_key AND status != :unregistered;"""
771
+ rows = self.query(
772
+ query, {"public_key": public_key, "unregistered": NodeStatus.UNREGISTERED}
773
+ )
870
774
 
871
775
  # If no result is found, return None
872
776
  if not rows:
@@ -876,8 +780,7 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
876
780
  node_id = int64_to_uint64(rows[0]["node_id"])
877
781
  return node_id
878
782
 
879
- # pylint: disable=too-many-arguments,too-many-positional-arguments
880
- def create_run(
783
+ def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments
881
784
  self,
882
785
  fab_id: str | None,
883
786
  fab_version: str | None,
@@ -894,38 +797,43 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
894
797
  # Convert the uint64 value to sint64 for SQLite
895
798
  sint64_run_id = uint64_to_int64(uint64_run_id)
896
799
 
897
- with self.conn:
800
+ with self.session():
898
801
  # Check conflicts
899
- query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
900
- # If sint64_run_id does not exist
901
- row = self.conn.execute(query, (sint64_run_id,)).fetchone()
902
- if row["COUNT(*)"] == 0:
802
+ query = "SELECT COUNT(*) as cnt FROM run WHERE run_id = :run_id"
803
+ rows = self.query(query, {"run_id": sint64_run_id})
804
+ if rows[0]["cnt"] == 0:
903
805
  query = """
904
806
  INSERT INTO run
905
807
  (run_id, fab_id, fab_version,
906
808
  fab_hash, override_config, federation, federation_options,
907
809
  pending_at, starting_at, running_at, finished_at, sub_status,
908
- details, flwr_aid)
909
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
810
+ details, flwr_aid, bytes_sent, bytes_recv, clientapp_runtime)
811
+ VALUES (:run_id, :fab_id, :fab_version, :fab_hash, :override_config,
812
+ :federation, :federation_options, :pending_at, :starting_at,
813
+ :running_at, :finished_at, :sub_status, :details, :flwr_aid,
814
+ :bytes_sent, :bytes_recv, :clientapp_runtime)
910
815
  """
911
816
  override_config_json = json.dumps(override_config)
912
- data = [
913
- sint64_run_id, # run_id
914
- fab_id, # fab_id
915
- fab_version, # fab_version
916
- fab_hash, # fab_hash
917
- override_config_json, # override_config
918
- federation, # federation
919
- configrecord_to_bytes(federation_options), # federation_options
920
- now().isoformat(), # pending_at
921
- "", # starting_at
922
- "", # running_at
923
- "", # finished_at
924
- "", # sub_status
925
- "", # details
926
- flwr_aid or "", # flwr_aid
927
- ]
928
- self.conn.execute(query, tuple(data))
817
+ params = {
818
+ "run_id": sint64_run_id,
819
+ "fab_id": fab_id or "",
820
+ "fab_version": fab_version or "",
821
+ "fab_hash": fab_hash or "",
822
+ "override_config": override_config_json,
823
+ "federation": federation,
824
+ "federation_options": configrecord_to_bytes(federation_options),
825
+ "pending_at": now().isoformat(),
826
+ "starting_at": "",
827
+ "running_at": "",
828
+ "finished_at": "",
829
+ "sub_status": "",
830
+ "details": "",
831
+ "flwr_aid": flwr_aid or "",
832
+ "bytes_sent": 0,
833
+ "bytes_recv": 0,
834
+ "clientapp_runtime": 0.0,
835
+ }
836
+ self.query(query, params)
929
837
  return uint64_run_id
930
838
  log(ERROR, "Unexpected run creation failure.")
931
839
  return 0
@@ -937,11 +845,11 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
937
845
  """
938
846
  if flwr_aid:
939
847
  rows = self.query(
940
- "SELECT run_id FROM run WHERE flwr_aid = ?;",
941
- (flwr_aid,),
848
+ "SELECT run_id FROM run WHERE flwr_aid = :flwr_aid",
849
+ {"flwr_aid": flwr_aid},
942
850
  )
943
851
  else:
944
- rows = self.query("SELECT run_id FROM run;", ())
852
+ rows = self.query("SELECT run_id FROM run", {})
945
853
  return {int64_to_uint64(row["run_id"]) for row in rows}
946
854
 
947
855
  def get_run(self, run_id: int) -> Run | None:
@@ -951,8 +859,8 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
951
859
 
952
860
  # Convert the uint64 value to sint64 for SQLite
953
861
  sint64_run_id = uint64_to_int64(run_id)
954
- query = "SELECT * FROM run WHERE run_id = ?;"
955
- rows = self.query(query, (sint64_run_id,))
862
+ query = "SELECT * FROM run WHERE run_id = :run_id"
863
+ rows = self.query(query, {"run_id": sint64_run_id})
956
864
  if rows:
957
865
  row = rows[0]
958
866
  return Run(
@@ -972,6 +880,9 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
972
880
  ),
973
881
  flwr_aid=row["flwr_aid"],
974
882
  federation=row["federation"],
883
+ bytes_sent=row["bytes_sent"],
884
+ bytes_recv=row["bytes_recv"],
885
+ clientapp_runtime=row["clientapp_runtime"],
975
886
  )
976
887
  log(ERROR, "`run_id` does not exist.")
977
888
  return None
@@ -982,9 +893,10 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
982
893
  self._cleanup_expired_tokens()
983
894
 
984
895
  # Convert the uint64 value to sint64 for SQLite
985
- sint64_run_ids = (uint64_to_int64(run_id) for run_id in set(run_ids))
986
- query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
987
- rows = self.query(query, tuple(sint64_run_ids))
896
+ placeholders = ",".join([f":rid_{i}" for i in range(len(run_ids))])
897
+ query = f"SELECT * FROM run WHERE run_id IN ({placeholders})"
898
+ params = {f"rid_{i}": uint64_to_int64(rid) for i, rid in enumerate(run_ids)}
899
+ rows = self.query(query, params)
988
900
 
989
901
  return {
990
902
  # Restore uint64 run IDs
@@ -1001,11 +913,11 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1001
913
  # Clean up expired tokens; this will flag inactive runs as needed
1002
914
  self._cleanup_expired_tokens()
1003
915
 
1004
- with self.conn:
916
+ with self.session():
1005
917
  # Convert the uint64 value to sint64 for SQLite
1006
918
  sint64_run_id = uint64_to_int64(run_id)
1007
- query = "SELECT * FROM run WHERE run_id = ?;"
1008
- rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
919
+ query = "SELECT * FROM run WHERE run_id = :run_id"
920
+ rows = self.query(query, {"run_id": sint64_run_id})
1009
921
 
1010
922
  # Check if the run_id exists
1011
923
  if not rows:
@@ -1040,7 +952,9 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1040
952
 
1041
953
  # Update the status
1042
954
  query = """
1043
- UPDATE run SET %s= ?, sub_status = ?, details = ? WHERE run_id = ?;
955
+ UPDATE run SET %s = :timestamp,
956
+ sub_status = :sub_status, details = :details
957
+ WHERE run_id = :run_id
1044
958
  """
1045
959
 
1046
960
  # Prepare data for query
@@ -1055,33 +969,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1055
969
  elif new_status.status == Status.FINISHED:
1056
970
  timestamp_fld = "finished_at"
1057
971
 
1058
- data = (
1059
- current.isoformat(),
1060
- new_status.sub_status,
1061
- new_status.details,
1062
- uint64_to_int64(run_id),
1063
- )
1064
- self.conn.execute(query % timestamp_fld, data)
972
+ params = {
973
+ "timestamp": current.isoformat(),
974
+ "sub_status": new_status.sub_status,
975
+ "details": new_status.details,
976
+ "run_id": sint64_run_id,
977
+ }
978
+ self.query(query % timestamp_fld, params)
1065
979
  return True
1066
980
 
1067
981
  def get_pending_run_id(self) -> int | None:
1068
- """Get the `run_id` of a run with `Status.PENDING` status, if any."""
1069
- pending_run_id = None
1070
-
982
+ """Get the `run_id` of a run with `Status.PENDING` status."""
1071
983
  # Fetch all runs with unset `starting_at` (i.e. they are in PENDING status)
1072
- query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
1073
- rows = self.query(query)
984
+ query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1"
985
+ rows = self.query(query, {})
1074
986
  if rows:
1075
- pending_run_id = int64_to_uint64(rows[0]["run_id"])
1076
-
1077
- return pending_run_id
987
+ return int64_to_uint64(rows[0]["run_id"])
988
+ return None
1078
989
 
1079
990
  def get_federation_options(self, run_id: int) -> ConfigRecord | None:
1080
991
  """Retrieve the federation options for the specified `run_id`."""
1081
992
  # Convert the uint64 value to sint64 for SQLite
1082
993
  sint64_run_id = uint64_to_int64(run_id)
1083
- query = "SELECT federation_options FROM run WHERE run_id = ?;"
1084
- rows = self.query(query, (sint64_run_id,))
994
+ query = "SELECT federation_options FROM run WHERE run_id = :run_id"
995
+ rows = self.query(query, {"run_id": sint64_run_id})
1085
996
 
1086
997
  # Check if the run_id exists
1087
998
  if not rows:
@@ -1101,41 +1012,46 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1101
1012
  HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
1102
1013
  the node is marked as offline.
1103
1014
  """
1104
- if self.conn is None:
1105
- raise AttributeError("LinkState not initialized")
1106
-
1107
1015
  sint64_node_id = uint64_to_int64(node_id)
1108
1016
 
1109
- with self.conn:
1110
- # Check if node exists and not deleted
1111
- query = "SELECT status FROM node WHERE node_id = ? AND status != ?"
1112
- row = self.conn.execute(
1113
- query, (sint64_node_id, NodeStatus.UNREGISTERED)
1114
- ).fetchone()
1115
- if row is None:
1116
- return False
1017
+ # Check if the node exists and is not unregistered
1018
+ query = """
1019
+ SELECT status FROM node WHERE node_id = :node_id AND status != :unregistered
1020
+ """
1021
+ rows = self.query(
1022
+ query, {"node_id": sint64_node_id, "unregistered": NodeStatus.UNREGISTERED}
1023
+ )
1024
+ if not rows:
1025
+ return False
1117
1026
 
1118
- # Construct query and params
1119
- current_dt = now()
1120
- query = "UPDATE node SET online_until = ?, heartbeat_interval = ?"
1121
- params: list[Any] = [
1122
- current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
1123
- heartbeat_interval,
1124
- ]
1027
+ # Construct query and params
1028
+ current_dt = now()
1029
+ query = (
1030
+ "UPDATE node SET online_until = :online_until, "
1031
+ "heartbeat_interval = :heartbeat_interval"
1032
+ )
1033
+ params: dict[str, Any] = {
1034
+ "online_until": current_dt.timestamp()
1035
+ + HEARTBEAT_PATIENCE * heartbeat_interval,
1036
+ "heartbeat_interval": heartbeat_interval,
1037
+ }
1125
1038
 
1126
- # Set timestamp if the status changes
1127
- if row["status"] != NodeStatus.ONLINE:
1128
- query += ", status = ?, last_activated_at = ?"
1129
- params += [NodeStatus.ONLINE, current_dt.isoformat()]
1039
+ # Set timestamp if the status changes
1040
+ if rows[0]["status"] != NodeStatus.ONLINE:
1041
+ query += ", status = :online, last_activated_at = :last_activated_at"
1042
+ params["online"] = NodeStatus.ONLINE
1043
+ params["last_activated_at"] = current_dt.isoformat()
1130
1044
 
1131
- # Execute the query, refreshing `online_until` and `heartbeat_interval`
1132
- query += " WHERE node_id = ?"
1133
- params += [sint64_node_id]
1134
- self.conn.execute(query, params)
1135
- return True
1045
+ # Execute the query, refreshing `online_until` and `heartbeat_interval`
1046
+ query += " WHERE node_id = :node_id"
1047
+ params["node_id"] = sint64_node_id
1048
+ self.query(query, params)
1049
+ return True
1136
1050
 
1137
1051
  def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
1138
- """Transition runs with expired tokens to failed status.
1052
+ """Handle cleanup of expired tokens.
1053
+
1054
+ Override in subclasses to add custom cleanup logic.
1139
1055
 
1140
1056
  Parameters
1141
1057
  ----------
@@ -1146,28 +1062,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1146
1062
  if not expired_records:
1147
1063
  return
1148
1064
 
1149
- with self.conn:
1065
+ with self.session():
1150
1066
  query = """
1151
1067
  UPDATE run
1152
- SET sub_status = ?, details = ?, finished_at = ?
1153
- WHERE run_id = ?;
1068
+ SET sub_status = :failed, details = :details, finished_at = :finished_at
1069
+ WHERE run_id = :run_id
1154
1070
  """
1155
1071
  data = [
1156
- (
1157
- SubStatus.FAILED,
1158
- RUN_FAILURE_DETAILS_NO_HEARTBEAT,
1159
- datetime.fromtimestamp(active_until, tz=timezone.utc).isoformat(),
1160
- uint64_to_int64(run_id),
1161
- )
1072
+ {
1073
+ "failed": SubStatus.FAILED,
1074
+ "details": RUN_FAILURE_DETAILS_NO_HEARTBEAT,
1075
+ "finished_at": datetime.fromtimestamp(
1076
+ active_until, tz=timezone.utc
1077
+ ).isoformat(),
1078
+ "run_id": uint64_to_int64(run_id),
1079
+ }
1162
1080
  for run_id, active_until in expired_records
1163
1081
  ]
1164
- self.conn.executemany(query, data)
1082
+ self.query(query, data)
1165
1083
 
1166
1084
  def get_serverapp_context(self, run_id: int) -> Context | None:
1167
1085
  """Get the context for the specified `run_id`."""
1168
1086
  # Retrieve context if any
1169
- query = "SELECT context FROM context WHERE run_id = ?;"
1170
- rows = self.query(query, (uint64_to_int64(run_id),))
1087
+ query = "SELECT context FROM context WHERE run_id = :run_id"
1088
+ rows = self.query(query, {"run_id": uint64_to_int64(run_id)})
1171
1089
  context = context_from_bytes(rows[0]["context"]) if rows else None
1172
1090
  return context
1173
1091
 
@@ -1177,20 +1095,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1177
1095
  context_bytes = context_to_bytes(context)
1178
1096
  sint_run_id = uint64_to_int64(run_id)
1179
1097
 
1180
- with self.conn:
1098
+ with self.session():
1181
1099
  # Check if any existing Context assigned to the run_id
1182
- query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
1183
- row = self.conn.execute(query, (sint_run_id,)).fetchone()
1184
- if row["COUNT(*)"] > 0:
1100
+ query = "SELECT COUNT(*) as count FROM context WHERE run_id = :run_id"
1101
+ row = self.query(query, {"run_id": sint_run_id})[0]
1102
+ if row["count"] > 0:
1185
1103
  # Update context
1186
- query = "UPDATE context SET context = ? WHERE run_id = ?;"
1187
- self.conn.execute(query, (context_bytes, sint_run_id))
1104
+ query = """
1105
+ UPDATE context
1106
+ SET context = :context_bytes WHERE run_id = :run_id
1107
+ """
1108
+ self.query(
1109
+ query, {"context_bytes": context_bytes, "run_id": sint_run_id}
1110
+ )
1188
1111
  else:
1189
1112
  try:
1190
1113
  # Store context
1191
- query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
1192
- self.conn.execute(query, (sint_run_id, context_bytes))
1193
- except sqlite3.IntegrityError:
1114
+ query = (
1115
+ "INSERT INTO context (run_id, context) "
1116
+ "VALUES (:run_id, :context_bytes)"
1117
+ )
1118
+ self.query(
1119
+ query, {"run_id": sint_run_id, "context_bytes": context_bytes}
1120
+ )
1121
+ except IntegrityError:
1194
1122
  raise ValueError(f"Run {run_id} not found") from None
1195
1123
 
1196
1124
  def add_serverapp_log(self, run_id: int, log_message: str) -> None:
@@ -1201,10 +1129,19 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1201
1129
  # Store log
1202
1130
  try:
1203
1131
  query = """
1204
- INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?);
1132
+ INSERT INTO logs (timestamp, run_id, node_id, log)
1133
+ VALUES (:current_ts, :run_id, :node_id, :log)
1205
1134
  """
1206
- self.query(query, (now().timestamp(), sint64_run_id, 0, log_message))
1207
- except sqlite3.IntegrityError:
1135
+ self.query(
1136
+ query,
1137
+ {
1138
+ "current_ts": now().timestamp(),
1139
+ "run_id": sint64_run_id,
1140
+ "node_id": 0,
1141
+ "log": log_message,
1142
+ },
1143
+ )
1144
+ except IntegrityError:
1208
1145
  raise ValueError(f"Run {run_id} not found") from None
1209
1146
 
1210
1147
  def get_serverapp_log(
@@ -1214,10 +1151,10 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1214
1151
  # Convert the uint64 value to sint64 for SQLite
1215
1152
  sint64_run_id = uint64_to_int64(run_id)
1216
1153
 
1217
- with self.conn:
1154
+ with self.session():
1218
1155
  # Check if the run_id exists
1219
- query = "SELECT run_id FROM run WHERE run_id = ?;"
1220
- rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
1156
+ query = "SELECT run_id FROM run WHERE run_id = :run_id"
1157
+ rows = self.query(query, {"run_id": sint64_run_id})
1221
1158
  if not rows:
1222
1159
  raise ValueError(f"Run {run_id} not found")
1223
1160
 
@@ -1226,12 +1163,18 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1226
1163
  after_timestamp = 0.0
1227
1164
  query = """
1228
1165
  SELECT log, timestamp FROM logs
1229
- WHERE run_id = ? AND node_id = ? AND timestamp > ?;
1166
+ WHERE run_id = :run_id AND node_id = :node_id
1167
+ AND timestamp > :after_timestamp
1168
+ ORDER BY timestamp
1230
1169
  """
1231
- rows = self.conn.execute(
1232
- query, (sint64_run_id, 0, after_timestamp)
1233
- ).fetchall()
1234
- rows.sort(key=lambda x: x["timestamp"])
1170
+ rows = self.query(
1171
+ query,
1172
+ {
1173
+ "run_id": sint64_run_id,
1174
+ "node_id": 0,
1175
+ "after_timestamp": after_timestamp,
1176
+ },
1177
+ )
1235
1178
  latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
1236
1179
  return "".join(row["log"] for row in rows), latest_timestamp
1237
1180
 
@@ -1240,62 +1183,74 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
1240
1183
 
1241
1184
  Return Message if valid.
1242
1185
  """
1243
- with self.conn:
1186
+ with self.session():
1244
1187
  self._check_stored_messages({message_id})
1245
1188
  query = """
1246
1189
  SELECT *
1247
1190
  FROM message_ins
1248
1191
  WHERE message_id = :message_id
1249
1192
  """
1250
- data = {"message_id": message_id}
1251
- rows: list[dict[str, Any]] = self.conn.execute(query, data).fetchall()
1193
+ rows = self.query(query, {"message_id": message_id})
1252
1194
  if not rows:
1253
1195
  # Message does not exist
1254
1196
  return None
1255
1197
 
1256
- return rows[0]
1257
-
1258
-
1259
- def message_to_dict(message: Message) -> dict[str, Any]:
1260
- """Transform Message to dict."""
1261
- result = {
1262
- "message_id": message.metadata.message_id,
1263
- "group_id": message.metadata.group_id,
1264
- "run_id": message.metadata.run_id,
1265
- "src_node_id": message.metadata.src_node_id,
1266
- "dst_node_id": message.metadata.dst_node_id,
1267
- "reply_to_message_id": message.metadata.reply_to_message_id,
1268
- "created_at": message.metadata.created_at,
1269
- "delivered_at": message.metadata.delivered_at,
1270
- "ttl": message.metadata.ttl,
1271
- "message_type": message.metadata.message_type,
1272
- "content": None,
1273
- "error": None,
1274
- }
1275
-
1276
- if message.has_content():
1277
- result["content"] = recorddict_to_proto(message.content).SerializeToString()
1278
- else:
1279
- result["error"] = error_to_proto(message.error).SerializeToString()
1280
-
1281
- return result
1282
-
1283
-
1284
- def dict_to_message(message_dict: dict[str, Any]) -> Message:
1285
- """Transform dict to Message."""
1286
- content, error = None, None
1287
- if (b_content := message_dict.pop("content")) is not None:
1288
- content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
1289
- if (b_error := message_dict.pop("error")) is not None:
1290
- error = error_from_proto(ProtoError.FromString(b_error))
1291
-
1292
- # Metadata constructor doesn't allow passing created_at. We set it later
1293
- metadata = Metadata(
1294
- **{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
1295
- )
1296
- msg = make_message(metadata=metadata, content=content, error=error)
1297
- msg.metadata.delivered_at = message_dict["delivered_at"]
1298
- return msg
1198
+ return dict(rows[0])
1199
+
1200
+ def store_traffic(self, run_id: int, *, bytes_sent: int, bytes_recv: int) -> None:
1201
+ """Store traffic data for the specified `run_id`."""
1202
+ # Validate non-negative values
1203
+ if bytes_sent < 0 or bytes_recv < 0:
1204
+ raise ValueError(
1205
+ f"Negative traffic values for run {run_id}: "
1206
+ f"bytes_sent={bytes_sent}, bytes_recv={bytes_recv}"
1207
+ )
1208
+
1209
+ if bytes_sent == 0 and bytes_recv == 0:
1210
+ raise ValueError(
1211
+ f"Both bytes_sent and bytes_recv cannot be zero for run {run_id}"
1212
+ )
1213
+
1214
+ sint64_run_id = uint64_to_int64(run_id)
1215
+
1216
+ with self.session():
1217
+ # Check if run exists, performing the update only if it does
1218
+ update_query = """
1219
+ UPDATE run
1220
+ SET bytes_sent = bytes_sent + :bytes_sent,
1221
+ bytes_recv = bytes_recv + :bytes_recv
1222
+ WHERE run_id = :run_id
1223
+ RETURNING run_id
1224
+ """
1225
+ rows = self.query(
1226
+ update_query,
1227
+ {
1228
+ "bytes_sent": bytes_sent,
1229
+ "bytes_recv": bytes_recv,
1230
+ "run_id": sint64_run_id,
1231
+ },
1232
+ )
1233
+
1234
+ if not rows:
1235
+ raise ValueError(f"Run {run_id} not found")
1236
+
1237
+ def add_clientapp_runtime(self, run_id: int, runtime: float) -> None:
1238
+ """Add ClientApp runtime to the cumulative total for the specified `run_id`."""
1239
+ sint64_run_id = uint64_to_int64(run_id)
1240
+ with self.session():
1241
+ # Check if run exists, performing the update only if it does
1242
+ update_query = """
1243
+ UPDATE run
1244
+ SET clientapp_runtime = clientapp_runtime + :runtime
1245
+ WHERE run_id = :run_id
1246
+ RETURNING run_id
1247
+ """
1248
+ rows = self.query(
1249
+ update_query, {"runtime": runtime, "run_id": sint64_run_id}
1250
+ )
1251
+
1252
+ if not rows:
1253
+ raise ValueError(f"Run {run_id} not found")
1299
1254
 
1300
1255
 
1301
1256
  def determine_run_status(row: dict[str, Any]) -> str:
@@ -1309,4 +1264,4 @@ def determine_run_status(row: dict[str, Any]) -> str:
1309
1264
  return Status.STARTING
1310
1265
  return Status.PENDING
1311
1266
  run_id = int64_to_uint64(row["run_id"])
1312
- raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")
1267
+ raise ValueError(f"The run {run_id} does not have a valid status.")