flwr-1.23.0-py3-none-any.whl → flwr-1.24.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (292)
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +19 -0
  5. flwr/cli/app_cmd/__init__.py +23 -0
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +252 -0
  8. flwr/cli/auth_plugin/auth_plugin.py +4 -5
  9. flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
  10. flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
  11. flwr/cli/build.py +60 -18
  12. flwr/cli/cli_account_auth_interceptor.py +24 -7
  13. flwr/cli/config_utils.py +101 -13
  14. flwr/cli/federation/__init__.py +24 -0
  15. flwr/cli/federation/ls.py +140 -0
  16. flwr/cli/federation/show.py +317 -0
  17. flwr/cli/install.py +91 -13
  18. flwr/cli/log.py +52 -9
  19. flwr/cli/login/login.py +7 -4
  20. flwr/cli/ls.py +170 -130
  21. flwr/cli/new/new.py +33 -50
  22. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
  23. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  25. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  30. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  31. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  32. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  33. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  34. flwr/cli/pull.py +10 -5
  35. flwr/cli/run/run.py +77 -30
  36. flwr/cli/run_utils.py +130 -0
  37. flwr/cli/stop.py +25 -7
  38. flwr/cli/supernode/ls.py +16 -8
  39. flwr/cli/supernode/register.py +9 -4
  40. flwr/cli/supernode/unregister.py +5 -3
  41. flwr/cli/utils.py +376 -16
  42. flwr/client/__init__.py +1 -1
  43. flwr/client/dpfedavg_numpy_client.py +4 -1
  44. flwr/client/grpc_adapter_client/connection.py +6 -7
  45. flwr/client/grpc_rere_client/connection.py +10 -11
  46. flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
  47. flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
  48. flwr/client/message_handler/message_handler.py +2 -2
  49. flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
  50. flwr/client/numpy_client.py +1 -1
  51. flwr/client/rest_client/connection.py +12 -14
  52. flwr/client/run_info_store.py +4 -5
  53. flwr/client/typing.py +1 -1
  54. flwr/clientapp/client_app.py +9 -10
  55. flwr/clientapp/mod/centraldp_mods.py +16 -17
  56. flwr/clientapp/mod/localdp_mod.py +8 -9
  57. flwr/clientapp/typing.py +1 -1
  58. flwr/clientapp/utils.py +3 -3
  59. flwr/common/address.py +1 -2
  60. flwr/common/args.py +3 -4
  61. flwr/common/config.py +13 -16
  62. flwr/common/constant.py +5 -2
  63. flwr/common/differential_privacy.py +3 -4
  64. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  65. flwr/common/exit/exit.py +15 -2
  66. flwr/common/exit/exit_code.py +19 -0
  67. flwr/common/exit/exit_handler.py +6 -2
  68. flwr/common/exit/signal_handler.py +5 -5
  69. flwr/common/grpc.py +6 -6
  70. flwr/common/inflatable_protobuf_utils.py +1 -1
  71. flwr/common/inflatable_utils.py +38 -21
  72. flwr/common/logger.py +19 -19
  73. flwr/common/message.py +4 -4
  74. flwr/common/object_ref.py +7 -7
  75. flwr/common/record/array.py +3 -3
  76. flwr/common/record/arrayrecord.py +18 -30
  77. flwr/common/record/configrecord.py +3 -3
  78. flwr/common/record/recorddict.py +5 -5
  79. flwr/common/record/typeddict.py +9 -2
  80. flwr/common/recorddict_compat.py +7 -10
  81. flwr/common/retry_invoker.py +20 -20
  82. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  83. flwr/common/serde.py +5 -4
  84. flwr/common/serde_utils.py +2 -2
  85. flwr/common/telemetry.py +9 -5
  86. flwr/common/typing.py +52 -37
  87. flwr/compat/client/app.py +38 -37
  88. flwr/compat/client/grpc_client/connection.py +11 -11
  89. flwr/compat/server/app.py +5 -6
  90. flwr/proto/appio_pb2.py +13 -3
  91. flwr/proto/appio_pb2.pyi +134 -65
  92. flwr/proto/appio_pb2_grpc.py +20 -0
  93. flwr/proto/appio_pb2_grpc.pyi +27 -0
  94. flwr/proto/clientappio_pb2.py +17 -7
  95. flwr/proto/clientappio_pb2.pyi +15 -0
  96. flwr/proto/clientappio_pb2_grpc.py +206 -40
  97. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  98. flwr/proto/control_pb2.py +71 -52
  99. flwr/proto/control_pb2.pyi +277 -111
  100. flwr/proto/control_pb2_grpc.py +249 -40
  101. flwr/proto/control_pb2_grpc.pyi +185 -52
  102. flwr/proto/error_pb2.py +13 -3
  103. flwr/proto/error_pb2.pyi +24 -6
  104. flwr/proto/error_pb2_grpc.py +20 -0
  105. flwr/proto/error_pb2_grpc.pyi +27 -0
  106. flwr/proto/fab_pb2.py +14 -4
  107. flwr/proto/fab_pb2.pyi +59 -31
  108. flwr/proto/fab_pb2_grpc.py +20 -0
  109. flwr/proto/fab_pb2_grpc.pyi +27 -0
  110. flwr/proto/federation_pb2.py +38 -0
  111. flwr/proto/federation_pb2.pyi +56 -0
  112. flwr/proto/federation_pb2_grpc.py +24 -0
  113. flwr/proto/federation_pb2_grpc.pyi +31 -0
  114. flwr/proto/fleet_pb2.py +14 -4
  115. flwr/proto/fleet_pb2.pyi +137 -61
  116. flwr/proto/fleet_pb2_grpc.py +189 -48
  117. flwr/proto/fleet_pb2_grpc.pyi +175 -61
  118. flwr/proto/grpcadapter_pb2.py +14 -4
  119. flwr/proto/grpcadapter_pb2.pyi +38 -16
  120. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  121. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  122. flwr/proto/heartbeat_pb2.py +17 -7
  123. flwr/proto/heartbeat_pb2.pyi +51 -22
  124. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  125. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  126. flwr/proto/log_pb2.py +13 -3
  127. flwr/proto/log_pb2.pyi +34 -11
  128. flwr/proto/log_pb2_grpc.py +20 -0
  129. flwr/proto/log_pb2_grpc.pyi +27 -0
  130. flwr/proto/message_pb2.py +15 -5
  131. flwr/proto/message_pb2.pyi +154 -86
  132. flwr/proto/message_pb2_grpc.py +20 -0
  133. flwr/proto/message_pb2_grpc.pyi +27 -0
  134. flwr/proto/node_pb2.py +15 -5
  135. flwr/proto/node_pb2.pyi +50 -25
  136. flwr/proto/node_pb2_grpc.py +20 -0
  137. flwr/proto/node_pb2_grpc.pyi +27 -0
  138. flwr/proto/recorddict_pb2.py +13 -3
  139. flwr/proto/recorddict_pb2.pyi +184 -107
  140. flwr/proto/recorddict_pb2_grpc.py +20 -0
  141. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  142. flwr/proto/run_pb2.py +40 -31
  143. flwr/proto/run_pb2.pyi +149 -84
  144. flwr/proto/run_pb2_grpc.py +20 -0
  145. flwr/proto/run_pb2_grpc.pyi +27 -0
  146. flwr/proto/serverappio_pb2.py +13 -3
  147. flwr/proto/serverappio_pb2.pyi +32 -8
  148. flwr/proto/serverappio_pb2_grpc.py +246 -65
  149. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  150. flwr/proto/simulationio_pb2.py +16 -8
  151. flwr/proto/simulationio_pb2.pyi +15 -0
  152. flwr/proto/simulationio_pb2_grpc.py +162 -41
  153. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  154. flwr/proto/transport_pb2.py +20 -10
  155. flwr/proto/transport_pb2.pyi +249 -160
  156. flwr/proto/transport_pb2_grpc.py +35 -4
  157. flwr/proto/transport_pb2_grpc.pyi +38 -8
  158. flwr/server/app.py +38 -17
  159. flwr/server/client_manager.py +4 -5
  160. flwr/server/client_proxy.py +10 -11
  161. flwr/server/compat/app.py +4 -5
  162. flwr/server/compat/app_utils.py +2 -1
  163. flwr/server/compat/grid_client_proxy.py +10 -12
  164. flwr/server/compat/legacy_context.py +3 -4
  165. flwr/server/fleet_event_log_interceptor.py +2 -1
  166. flwr/server/grid/grid.py +2 -3
  167. flwr/server/grid/grpc_grid.py +10 -8
  168. flwr/server/grid/inmemory_grid.py +4 -4
  169. flwr/server/run_serverapp.py +2 -3
  170. flwr/server/server.py +34 -39
  171. flwr/server/server_app.py +7 -8
  172. flwr/server/server_config.py +1 -2
  173. flwr/server/serverapp/app.py +34 -28
  174. flwr/server/serverapp_components.py +4 -5
  175. flwr/server/strategy/aggregate.py +9 -8
  176. flwr/server/strategy/bulyan.py +13 -11
  177. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  178. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  179. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  180. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  181. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  182. flwr/server/strategy/fedadagrad.py +18 -14
  183. flwr/server/strategy/fedadam.py +16 -14
  184. flwr/server/strategy/fedavg.py +16 -17
  185. flwr/server/strategy/fedavg_android.py +15 -15
  186. flwr/server/strategy/fedavgm.py +21 -18
  187. flwr/server/strategy/fedmedian.py +2 -3
  188. flwr/server/strategy/fedopt.py +11 -10
  189. flwr/server/strategy/fedprox.py +10 -9
  190. flwr/server/strategy/fedtrimmedavg.py +12 -11
  191. flwr/server/strategy/fedxgb_bagging.py +13 -11
  192. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  193. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  194. flwr/server/strategy/fedyogi.py +16 -14
  195. flwr/server/strategy/krum.py +12 -11
  196. flwr/server/strategy/qfedavg.py +16 -15
  197. flwr/server/strategy/strategy.py +6 -9
  198. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
  199. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  200. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  201. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  202. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  203. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
  204. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
  205. flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
  206. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
  207. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  208. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  209. flwr/server/superlink/fleet/vce/vce_api.py +15 -9
  210. flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
  211. flwr/server/superlink/linkstate/linkstate.py +59 -43
  212. flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
  213. flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
  214. flwr/server/superlink/linkstate/utils.py +6 -6
  215. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  216. flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
  217. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  218. flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
  219. flwr/server/superlink/utils.py +4 -6
  220. flwr/server/typing.py +1 -1
  221. flwr/server/utils/tensorboard.py +15 -8
  222. flwr/server/workflow/default_workflows.py +5 -5
  223. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  224. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
  225. flwr/serverapp/strategy/bulyan.py +16 -15
  226. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  227. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  228. flwr/serverapp/strategy/fedadagrad.py +10 -11
  229. flwr/serverapp/strategy/fedadam.py +10 -11
  230. flwr/serverapp/strategy/fedavg.py +9 -10
  231. flwr/serverapp/strategy/fedavgm.py +17 -16
  232. flwr/serverapp/strategy/fedmedian.py +2 -2
  233. flwr/serverapp/strategy/fedopt.py +10 -11
  234. flwr/serverapp/strategy/fedprox.py +7 -8
  235. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  236. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  237. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  238. flwr/serverapp/strategy/fedyogi.py +9 -11
  239. flwr/serverapp/strategy/krum.py +7 -7
  240. flwr/serverapp/strategy/multikrum.py +9 -9
  241. flwr/serverapp/strategy/qfedavg.py +17 -16
  242. flwr/serverapp/strategy/strategy.py +6 -9
  243. flwr/serverapp/strategy/strategy_utils.py +7 -8
  244. flwr/simulation/app.py +46 -42
  245. flwr/simulation/legacy_app.py +12 -12
  246. flwr/simulation/ray_transport/ray_actor.py +10 -11
  247. flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
  248. flwr/simulation/run_simulation.py +43 -43
  249. flwr/simulation/simulationio_connection.py +4 -4
  250. flwr/supercore/cli/flower_superexec.py +3 -4
  251. flwr/supercore/constant.py +31 -1
  252. flwr/supercore/corestate/corestate.py +24 -3
  253. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  254. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  255. flwr/supercore/ffs/disk_ffs.py +1 -2
  256. flwr/supercore/ffs/ffs.py +1 -2
  257. flwr/supercore/ffs/ffs_factory.py +1 -2
  258. flwr/{common → supercore}/heartbeat.py +20 -25
  259. flwr/supercore/object_store/in_memory_object_store.py +1 -2
  260. flwr/supercore/object_store/object_store.py +1 -2
  261. flwr/supercore/object_store/object_store_factory.py +1 -2
  262. flwr/supercore/object_store/sqlite_object_store.py +8 -7
  263. flwr/supercore/primitives/asymmetric.py +1 -1
  264. flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
  265. flwr/supercore/sqlite_mixin.py +37 -34
  266. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  267. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  268. flwr/supercore/superexec/run_superexec.py +9 -13
  269. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  270. flwr/superlink/auth_plugin/auth_plugin.py +6 -9
  271. flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
  272. flwr/superlink/federation/__init__.py +24 -0
  273. flwr/superlink/federation/federation_manager.py +64 -0
  274. flwr/superlink/federation/noop_federation_manager.py +71 -0
  275. flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
  276. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  277. flwr/superlink/servicer/control/control_grpc.py +5 -6
  278. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  279. flwr/superlink/servicer/control/control_servicer.py +102 -18
  280. flwr/supernode/cli/flower_supernode.py +58 -3
  281. flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
  282. flwr/supernode/nodestate/nodestate.py +7 -8
  283. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  284. flwr/supernode/runtime/run_clientapp.py +41 -22
  285. flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
  286. flwr/supernode/start_client_internal.py +158 -42
  287. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
  288. flwr-1.24.0.dist-info/RECORD +454 -0
  289. flwr/supercore/object_store/utils.py +0 -43
  290. flwr-1.23.0.dist-info/RECORD +0 -439
  291. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
  292. {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
@@ -18,16 +18,14 @@
18
18
  # pylint: disable=too-many-lines
19
19
 
20
20
  import json
21
- import secrets
22
21
  import sqlite3
23
22
  from collections.abc import Sequence
23
+ from datetime import datetime, timezone
24
24
  from logging import ERROR, WARNING
25
- from typing import Any, Optional, Union, cast
25
+ from typing import Any, cast
26
26
 
27
27
  from flwr.common import Context, Message, Metadata, log, now
28
28
  from flwr.common.constant import (
29
- FLWR_APP_TOKEN_LENGTH,
30
- HEARTBEAT_INTERVAL_INF,
31
29
  HEARTBEAT_PATIENCE,
32
30
  MESSAGE_TTL_TOLERANCE,
33
31
  NODE_ID_NUM_BYTES,
@@ -51,8 +49,10 @@ from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
51
49
  # pylint: enable=E0611
52
50
  from flwr.server.utils.validator import validate_message
53
51
  from flwr.supercore.constant import NodeStatus
54
- from flwr.supercore.sqlite_mixin import SqliteMixin
52
+ from flwr.supercore.corestate.sqlite_corestate import SqliteCoreState
53
+ from flwr.supercore.object_store.object_store import ObjectStore
55
54
  from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
55
+ from flwr.superlink.federation import FederationManager
56
56
 
57
57
  from .linkstate import LinkState
58
58
  from .utils import (
@@ -74,6 +74,7 @@ SQL_CREATE_TABLE_NODE = """
74
74
  CREATE TABLE IF NOT EXISTS node(
75
75
  node_id INTEGER UNIQUE,
76
76
  owner_aid TEXT,
77
+ owner_name TEXT,
77
78
  status TEXT,
78
79
  registered_at TEXT,
79
80
  last_activated_at TEXT NULL,
@@ -106,8 +107,6 @@ CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
106
107
  SQL_CREATE_TABLE_RUN = """
107
108
  CREATE TABLE IF NOT EXISTS run(
108
109
  run_id INTEGER UNIQUE,
109
- active_until REAL,
110
- heartbeat_interval REAL,
111
110
  fab_id TEXT,
112
111
  fab_version TEXT,
113
112
  fab_hash TEXT,
@@ -118,6 +117,7 @@ CREATE TABLE IF NOT EXISTS run(
118
117
  finished_at TEXT,
119
118
  sub_status TEXT,
120
119
  details TEXT,
120
+ federation TEXT,
121
121
  federation_options BLOB,
122
122
  flwr_aid TEXT
123
123
  );
@@ -179,20 +179,23 @@ CREATE TABLE IF NOT EXISTS message_res(
179
179
  );
180
180
  """
181
181
 
182
- SQL_CREATE_TABLE_TOKEN_STORE = """
183
- CREATE TABLE IF NOT EXISTS token_store (
184
- run_id INTEGER PRIMARY KEY,
185
- token TEXT UNIQUE NOT NULL
186
- );
187
- """
188
-
189
182
 
190
- class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
183
+ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
191
184
  """SQLite-based LinkState implementation."""
192
185
 
193
- def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
194
- """Connect to the DB, enable FK support, and create tables if needed."""
195
- return self._ensure_initialized(
186
+ def __init__(
187
+ self,
188
+ database_path: str,
189
+ federation_manager: FederationManager,
190
+ object_store: ObjectStore,
191
+ ) -> None:
192
+ super().__init__(database_path, object_store)
193
+ federation_manager.linkstate = self
194
+ self._federation_manager = federation_manager
195
+
196
+ def get_sql_statements(self) -> tuple[str, ...]:
197
+ """Return SQL statements for LinkState tables."""
198
+ return super().get_sql_statements() + (
196
199
  SQL_CREATE_TABLE_RUN,
197
200
  SQL_CREATE_TABLE_LOGS,
198
201
  SQL_CREATE_TABLE_CONTEXT,
@@ -200,14 +203,17 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
200
203
  SQL_CREATE_TABLE_MESSAGE_RES,
201
204
  SQL_CREATE_TABLE_NODE,
202
205
  SQL_CREATE_TABLE_PUBLIC_KEY,
203
- SQL_CREATE_TABLE_TOKEN_STORE,
204
206
  SQL_CREATE_INDEX_ONLINE_UNTIL,
205
207
  SQL_CREATE_INDEX_OWNER_AID,
206
208
  SQL_CREATE_INDEX_NODE_STATUS,
207
- log_queries=log_queries,
208
209
  )
209
210
 
210
- def store_message_ins(self, message: Message) -> Optional[str]:
211
+ @property
212
+ def federation_manager(self) -> FederationManager:
213
+ """Get the FederationManager instance."""
214
+ return self._federation_manager
215
+
216
+ def store_message_ins(self, message: Message) -> str | None:
211
217
  """Store one Message."""
212
218
  # Validate message
213
219
  errors = validate_message(message=message, is_reply_message=False)
@@ -223,12 +229,6 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
223
229
  data[0], ["run_id", "src_node_id", "dst_node_id"]
224
230
  )
225
231
 
226
- # Validate run_id
227
- query = "SELECT run_id FROM run WHERE run_id = ?;"
228
- if not self.query(query, (data[0]["run_id"],)):
229
- log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
230
- return None
231
-
232
232
  # Validate source node ID
233
233
  if message.metadata.src_node_id != SUPERLINK_NODE_ID:
234
234
  log(
@@ -238,28 +238,87 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
238
238
  )
239
239
  return None
240
240
 
241
- # Validate destination node ID
242
- query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
243
- if not self.query(
244
- query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
245
- ):
246
- log(
247
- ERROR,
248
- "Invalid destination node ID for Message: %s",
249
- message.metadata.dst_node_id,
250
- )
251
- return None
252
-
253
- columns = ", ".join([f":{key}" for key in data[0]])
254
- query = f"INSERT INTO message_ins VALUES({columns});"
255
-
256
- # Only invalid run_id can trigger IntegrityError.
257
- # This may need to be changed in the future version with more integrity checks.
258
- self.query(query, data)
241
+ with self.conn:
242
+ # Validate run_id
243
+ query = "SELECT federation FROM run WHERE run_id = ?;"
244
+ rows = self.conn.execute(query, (data[0]["run_id"],)).fetchall()
245
+ if not rows:
246
+ log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
247
+ return None
248
+ federation: str = rows[0]["federation"]
249
+
250
+ # Validate destination node ID
251
+ query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
252
+ rows = self.conn.execute(
253
+ query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
254
+ ).fetchall()
255
+ if not rows or not self.federation_manager.has_node(
256
+ message.metadata.dst_node_id, federation
257
+ ):
258
+ log(
259
+ ERROR,
260
+ "Invalid destination node ID for Message: %s",
261
+ message.metadata.dst_node_id,
262
+ )
263
+ return None
264
+
265
+ columns = ", ".join([f":{key}" for key in data[0]])
266
+ query = f"INSERT INTO message_ins VALUES({columns});"
267
+
268
+ # Only invalid run_id can trigger IntegrityError.
269
+ # This may need to be changed in the future version
270
+ # with more integrity checks.
271
+ self.conn.execute(query, data[0])
259
272
 
260
273
  return message.metadata.message_id
261
274
 
262
- def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
275
+ def _check_stored_messages(self, message_ids: set[str]) -> None:
276
+ """Check and delete the message if it's invalid."""
277
+ if not message_ids:
278
+ return
279
+
280
+ with self.conn:
281
+ invalid_msg_ids: set[str] = set()
282
+ current_time = now().timestamp()
283
+
284
+ for msg_id in message_ids:
285
+ # Check if message exists
286
+ query = "SELECT * FROM message_ins WHERE message_id = ?;"
287
+ message_row = self.conn.execute(query, (msg_id,)).fetchone()
288
+ if not message_row:
289
+ continue
290
+
291
+ # Check if the message has expired
292
+ available_until = message_row["created_at"] + message_row["ttl"]
293
+ if available_until <= current_time:
294
+ invalid_msg_ids.add(msg_id)
295
+ continue
296
+
297
+ # Check if src_node_id and dst_node_id are in the federation
298
+ # Get federation from run table
299
+ run_id = message_row["run_id"]
300
+ query = "SELECT federation FROM run WHERE run_id = ?;"
301
+ run_row = self.conn.execute(query, (run_id,)).fetchone()
302
+ if not run_row: # This should not happen
303
+ invalid_msg_ids.add(msg_id)
304
+ continue
305
+ federation = run_row["federation"]
306
+
307
+ # Convert sint64 to uint64 for node IDs
308
+ src_node_id = int64_to_uint64(message_row["src_node_id"])
309
+ dst_node_id = int64_to_uint64(message_row["dst_node_id"])
310
+
311
+ # Filter nodes to check if they're in the federation
312
+ filtered = self.federation_manager.filter_nodes(
313
+ {src_node_id, dst_node_id}, federation
314
+ )
315
+ if len(filtered) != 2: # Not both nodes are in the federation
316
+ invalid_msg_ids.add(msg_id)
317
+
318
+ # Delete all invalid messages
319
+ self.delete_messages(invalid_msg_ids)
320
+
321
+ def get_message_ins(self, node_id: int, limit: int | None) -> list[Message]:
263
322
  """Get all Messages that have not been delivered yet."""
264
323
  if limit is not None and limit < 1:
265
324
  raise AssertionError("`limit` must be >= 1")
@@ -268,59 +327,64 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
268
327
  msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
269
328
  raise AssertionError(msg)
270
329
 
271
- data: dict[str, Union[str, int]] = {}
330
+ data: dict[str, str | int] = {}
272
331
 
273
332
  # Convert the uint64 value to sint64 for SQLite
274
333
  data["node_id"] = uint64_to_int64(node_id)
275
334
 
276
- # Retrieve all Messages for node_id
277
- query = """
278
- SELECT message_id
279
- FROM message_ins
280
- WHERE dst_node_id == :node_id
281
- AND delivered_at = ""
282
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
283
- """
284
-
285
- if limit is not None:
286
- query += " LIMIT :limit"
287
- data["limit"] = limit
288
-
289
- query += ";"
290
-
291
- rows = self.query(query, data)
292
-
293
- if rows:
294
- # Prepare query
295
- message_ids = [row["message_id"] for row in rows]
296
- placeholders: str = ",".join([f":id_{i}" for i in range(len(message_ids))])
297
- query = f"""
298
- UPDATE message_ins
299
- SET delivered_at = :delivered_at
300
- WHERE message_id IN ({placeholders})
301
- RETURNING *;
335
+ with self.conn:
336
+ # Retrieve all Messages for node_id
337
+ query = """
338
+ SELECT message_id
339
+ FROM message_ins
340
+ WHERE dst_node_id == :node_id
341
+ AND delivered_at = ""
342
+ AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
302
343
  """
303
344
 
304
- # Prepare data for query
305
- delivered_at = now().isoformat()
306
- data = {"delivered_at": delivered_at}
307
- for index, msg_id in enumerate(message_ids):
308
- data[f"id_{index}"] = str(msg_id)
345
+ if limit is not None:
346
+ query += " LIMIT :limit"
347
+ data["limit"] = limit
309
348
 
310
- # Run query
311
- rows = self.query(query, data)
349
+ query += ";"
312
350
 
313
- for row in rows:
314
- # Convert values from sint64 to uint64
315
- convert_sint64_values_in_dict_to_uint64(
316
- row, ["run_id", "src_node_id", "dst_node_id"]
317
- )
351
+ rows = self.conn.execute(query, data).fetchall()
352
+ message_ids: set[str] = {row["message_id"] for row in rows}
353
+ self._check_stored_messages(message_ids)
354
+
355
+ # Mark retrieved Messages as delivered
356
+ if rows:
357
+ # Prepare query
358
+ placeholders: str = ",".join(
359
+ [f":id_{i}" for i in range(len(message_ids))]
360
+ )
361
+ query = f"""
362
+ UPDATE message_ins
363
+ SET delivered_at = :delivered_at
364
+ WHERE message_id IN ({placeholders})
365
+ RETURNING *;
366
+ """
367
+
368
+ # Prepare data for query
369
+ delivered_at = now().isoformat()
370
+ data = {"delivered_at": delivered_at}
371
+ for index, msg_id in enumerate(message_ids):
372
+ data[f"id_{index}"] = str(msg_id)
373
+
374
+ # Run query
375
+ rows = self.conn.execute(query, data).fetchall()
376
+
377
+ for row in rows:
378
+ # Convert values from sint64 to uint64
379
+ convert_sint64_values_in_dict_to_uint64(
380
+ row, ["run_id", "src_node_id", "dst_node_id"]
381
+ )
318
382
 
319
383
  result = [dict_to_message(row) for row in rows]
320
384
 
321
385
  return result
322
386
 
323
- def store_message_res(self, message: Message) -> Optional[str]:
387
+ def store_message_res(self, message: Message) -> str | None:
324
388
  """Store one Message."""
325
389
  # Validate message
326
390
  errors = validate_message(message=message, is_reply_message=True)
@@ -336,7 +400,8 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
336
400
  ERROR,
337
401
  "Failed to store Message reply: "
338
402
  "The message it replies to with message_id %s does not exist or "
339
- "has expired.",
403
+ "has expired, or was deleted because the target SuperNode was "
404
+ "removed from the federation.",
340
405
  msg_ins_id,
341
406
  )
342
407
  return None
@@ -397,84 +462,92 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
397
462
  # pylint: disable-msg=too-many-locals
398
463
  ret: dict[str, Message] = {}
399
464
 
400
- # Verify Message IDs
401
- current = now().timestamp()
402
- query = f"""
403
- SELECT *
404
- FROM message_ins
405
- WHERE message_id IN ({",".join(["?"] * len(message_ids))});
406
- """
407
- rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
408
- found_message_ins_dict: dict[str, Message] = {}
409
- for row in rows:
410
- convert_sint64_values_in_dict_to_uint64(
411
- row, ["run_id", "src_node_id", "dst_node_id"]
465
+ with self.conn:
466
+ # Verify Message IDs
467
+ self._check_stored_messages(message_ids)
468
+ current = now().timestamp()
469
+ query = f"""
470
+ SELECT *
471
+ FROM message_ins
472
+ WHERE message_id IN ({','.join(['?'] * len(message_ids))});
473
+ """
474
+ rows = self.conn.execute(
475
+ query, tuple(str(message_id) for message_id in message_ids)
476
+ ).fetchall()
477
+ found_message_ins_dict: dict[str, Message] = {}
478
+ for row in rows:
479
+ convert_sint64_values_in_dict_to_uint64(
480
+ row, ["run_id", "src_node_id", "dst_node_id"]
481
+ )
482
+ found_message_ins_dict[row["message_id"]] = dict_to_message(row)
483
+
484
+ ret = verify_message_ids(
485
+ inquired_message_ids=message_ids,
486
+ found_message_ins_dict=found_message_ins_dict,
487
+ current_time=current,
412
488
  )
413
- found_message_ins_dict[row["message_id"]] = dict_to_message(row)
414
489
 
415
- ret = verify_message_ids(
416
- inquired_message_ids=message_ids,
417
- found_message_ins_dict=found_message_ins_dict,
418
- current_time=current,
419
- )
490
+ # Check node availability
491
+ dst_node_ids: set[int] = set()
492
+ for message_id in message_ids:
493
+ in_message = found_message_ins_dict[message_id]
494
+ sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
495
+ dst_node_ids.add(sint_node_id)
496
+ query = f"""
497
+ SELECT node_id, online_until
498
+ FROM node
499
+ WHERE node_id IN ({','.join(['?'] * len(dst_node_ids))})
500
+ AND status != ?
501
+ """
502
+ rows = self.conn.execute(
503
+ query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,)
504
+ ).fetchall()
505
+ tmp_ret_dict = check_node_availability_for_in_message(
506
+ inquired_in_message_ids=message_ids,
507
+ found_in_message_dict=found_message_ins_dict,
508
+ node_id_to_online_until={
509
+ int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
510
+ },
511
+ current_time=current,
512
+ )
513
+ ret.update(tmp_ret_dict)
420
514
 
421
- # Check node availability
422
- dst_node_ids: set[int] = set()
423
- for message_id in message_ids:
424
- in_message = found_message_ins_dict[message_id]
425
- sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
426
- dst_node_ids.add(sint_node_id)
427
- query = f"""
428
- SELECT node_id, online_until
429
- FROM node
430
- WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))})
431
- AND status != ?
432
- """
433
- rows = self.query(query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,))
434
- tmp_ret_dict = check_node_availability_for_in_message(
435
- inquired_in_message_ids=message_ids,
436
- found_in_message_dict=found_message_ins_dict,
437
- node_id_to_online_until={
438
- int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
439
- },
440
- current_time=current,
441
- )
442
- ret.update(tmp_ret_dict)
443
-
444
- # Find all reply Messages
445
- query = f"""
446
- SELECT *
447
- FROM message_res
448
- WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
449
- AND delivered_at = "";
450
- """
451
- rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
452
- for row in rows:
453
- convert_sint64_values_in_dict_to_uint64(
454
- row, ["run_id", "src_node_id", "dst_node_id"]
515
+ # Find all reply Messages
516
+ query = f"""
517
+ SELECT *
518
+ FROM message_res
519
+ WHERE reply_to_message_id IN ({','.join(['?'] * len(message_ids))})
520
+ AND delivered_at = "";
521
+ """
522
+ rows = self.conn.execute(
523
+ query, tuple(str(message_id) for message_id in message_ids)
524
+ ).fetchall()
525
+ for row in rows:
526
+ convert_sint64_values_in_dict_to_uint64(
527
+ row, ["run_id", "src_node_id", "dst_node_id"]
528
+ )
529
+ tmp_ret_dict = verify_found_message_replies(
530
+ inquired_message_ids=message_ids,
531
+ found_message_ins_dict=found_message_ins_dict,
532
+ found_message_res_list=[dict_to_message(row) for row in rows],
533
+ current_time=current,
455
534
  )
456
- tmp_ret_dict = verify_found_message_replies(
457
- inquired_message_ids=message_ids,
458
- found_message_ins_dict=found_message_ins_dict,
459
- found_message_res_list=[dict_to_message(row) for row in rows],
460
- current_time=current,
461
- )
462
- ret.update(tmp_ret_dict)
463
-
464
- # Mark existing reply Messages to be returned as delivered
465
- delivered_at = now().isoformat()
466
- for message_res in ret.values():
467
- message_res.metadata.delivered_at = delivered_at
468
- message_res_ids = [
469
- message_res.metadata.message_id for message_res in ret.values()
470
- ]
471
- query = f"""
472
- UPDATE message_res
473
- SET delivered_at = ?
474
- WHERE message_id IN ({",".join(["?"] * len(message_res_ids))});
475
- """
476
- data: list[Any] = [delivered_at] + message_res_ids
477
- self.query(query, data)
535
+ ret.update(tmp_ret_dict)
536
+
537
+ # Mark existing reply Messages to be returned as delivered
538
+ delivered_at = now().isoformat()
539
+ for message_res in ret.values():
540
+ message_res.metadata.delivered_at = delivered_at
541
+ message_res_ids = [
542
+ message_res.metadata.message_id for message_res in ret.values()
543
+ ]
544
+ query = f"""
545
+ UPDATE message_res
546
+ SET delivered_at = ?
547
+ WHERE message_id IN ({','.join(['?'] * len(message_res_ids))});
548
+ """
549
+ data: list[Any] = [delivered_at] + message_res_ids
550
+ self.conn.execute(query, data)
478
551
 
479
552
  return list(ret.values())
480
553
 
@@ -545,7 +618,11 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
545
618
  return {row["message_id"] for row in rows}
546
619
 
547
620
  def create_node(
548
- self, owner_aid: str, public_key: bytes, heartbeat_interval: float
621
+ self,
622
+ owner_aid: str,
623
+ owner_name: str,
624
+ public_key: bytes,
625
+ heartbeat_interval: float,
549
626
  ) -> int:
550
627
  """Create, store in the link state, and return `node_id`."""
551
628
  # Sample a random uint64 as node_id
@@ -558,10 +635,10 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
558
635
 
559
636
  query = """
560
637
  INSERT INTO node
561
- (node_id, owner_aid, status, registered_at, last_activated_at,
638
+ (node_id, owner_aid, owner_name, status, registered_at, last_activated_at,
562
639
  last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
563
640
  public_key)
564
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
641
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
565
642
  """
566
643
 
567
644
  # Mark the node online until now().timestamp() + heartbeat_interval
@@ -571,6 +648,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
571
648
  (
572
649
  sint64_node_id, # node_id
573
650
  owner_aid, # owner_aid
651
+ owner_name, # owner_name
574
652
  NodeStatus.REGISTERED, # status
575
653
  now().isoformat(), # registered_at
576
654
  None, # last_activated_at
@@ -686,23 +764,26 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
686
764
  if self.conn is None:
687
765
  raise AttributeError("LinkState not initialized")
688
766
 
689
- # Convert the uint64 value to sint64 for SQLite
690
- sint64_run_id = uint64_to_int64(run_id)
691
-
692
- # Validate run ID
693
- query = "SELECT COUNT(*) FROM run WHERE run_id = ?"
694
- rows = self.query(query, (sint64_run_id,))
695
- if rows[0]["COUNT(*)"] == 0:
696
- return set()
697
-
698
- # Retrieve all online nodes
699
- return {
700
- node.node_id for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
701
- }
702
-
703
- def _check_and_tag_offline_nodes(
704
- self, node_ids: Optional[list[int]] = None
705
- ) -> None:
767
+ with self.conn:
768
+ # Convert the uint64 value to sint64 for SQLite
769
+ sint64_run_id = uint64_to_int64(run_id)
770
+
771
+ # Validate run ID
772
+ query = "SELECT federation FROM run WHERE run_id = ?"
773
+ rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
774
+ if not rows:
775
+ return set()
776
+ federation: str = rows[0]["federation"]
777
+
778
+ # Retrieve all online nodes
779
+ node_ids = {
780
+ node.node_id
781
+ for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
782
+ }
783
+ # Filter node IDs by federation
784
+ return self.federation_manager.filter_nodes(node_ids, federation)
785
+
786
+ def _check_and_tag_offline_nodes(self, node_ids: list[int] | None = None) -> None:
706
787
  """Check and tag offline nodes."""
707
788
  # strftime will convert POSIX timestamp to ISO format
708
789
  query = """
@@ -725,9 +806,9 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
725
806
  def get_node_info(
726
807
  self,
727
808
  *,
728
- node_ids: Optional[Sequence[int]] = None,
729
- owner_aids: Optional[Sequence[str]] = None,
730
- statuses: Optional[Sequence[str]] = None,
809
+ node_ids: Sequence[int] | None = None,
810
+ owner_aids: Sequence[str] | None = None,
811
+ statuses: Sequence[str] | None = None,
731
812
  ) -> Sequence[NodeInfo]:
732
813
  """Retrieve information about nodes based on the specified filters."""
733
814
  with self.conn:
@@ -781,7 +862,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
781
862
  # Return the public key
782
863
  return cast(bytes, rows[0]["public_key"])
783
864
 
784
- def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
865
+ def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
785
866
  """Get `node_id` for the specified `public_key` if it exists and is not
786
867
  deleted."""
787
868
  query = "SELECT node_id FROM node WHERE public_key = ? AND status != ?;"
@@ -798,55 +879,58 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
798
879
  # pylint: disable=too-many-arguments,too-many-positional-arguments
799
880
  def create_run(
800
881
  self,
801
- fab_id: Optional[str],
802
- fab_version: Optional[str],
803
- fab_hash: Optional[str],
882
+ fab_id: str | None,
883
+ fab_version: str | None,
884
+ fab_hash: str | None,
804
885
  override_config: UserConfig,
886
+ federation: str,
805
887
  federation_options: ConfigRecord,
806
- flwr_aid: Optional[str],
888
+ flwr_aid: str | None,
807
889
  ) -> int:
808
- """Create a new run for the specified `fab_id` and `fab_version`."""
890
+ """Create a new run."""
809
891
  # Sample a random int64 as run_id
810
892
  uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
811
893
 
812
894
  # Convert the uint64 value to sint64 for SQLite
813
895
  sint64_run_id = uint64_to_int64(uint64_run_id)
814
896
 
815
- # Check conflicts
816
- query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
817
- # If sint64_run_id does not exist
818
- if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
819
- query = (
820
- "INSERT INTO run "
821
- "(run_id, active_until, heartbeat_interval, fab_id, fab_version, "
822
- "fab_hash, override_config, federation_options, pending_at, "
823
- "starting_at, running_at, finished_at, sub_status, details, flwr_aid) "
824
- "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
825
- )
826
- override_config_json = json.dumps(override_config)
827
- data = [
828
- sint64_run_id,
829
- 0, # The `active_until` is not used until the run is started
830
- 0, # This `heartbeat_interval` is not used until the run is started
831
- fab_id,
832
- fab_version,
833
- fab_hash,
834
- override_config_json,
835
- configrecord_to_bytes(federation_options),
836
- now().isoformat(),
837
- "",
838
- "",
839
- "",
840
- "",
841
- "",
842
- flwr_aid or "",
843
- ]
844
- self.query(query, tuple(data))
845
- return uint64_run_id
897
+ with self.conn:
898
+ # Check conflicts
899
+ query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
900
+ # If sint64_run_id does not exist
901
+ row = self.conn.execute(query, (sint64_run_id,)).fetchone()
902
+ if row["COUNT(*)"] == 0:
903
+ query = """
904
+ INSERT INTO run
905
+ (run_id, fab_id, fab_version,
906
+ fab_hash, override_config, federation, federation_options,
907
+ pending_at, starting_at, running_at, finished_at, sub_status,
908
+ details, flwr_aid)
909
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
910
+ """
911
+ override_config_json = json.dumps(override_config)
912
+ data = [
913
+ sint64_run_id, # run_id
914
+ fab_id, # fab_id
915
+ fab_version, # fab_version
916
+ fab_hash, # fab_hash
917
+ override_config_json, # override_config
918
+ federation, # federation
919
+ configrecord_to_bytes(federation_options), # federation_options
920
+ now().isoformat(), # pending_at
921
+ "", # starting_at
922
+ "", # running_at
923
+ "", # finished_at
924
+ "", # sub_status
925
+ "", # details
926
+ flwr_aid or "", # flwr_aid
927
+ ]
928
+ self.conn.execute(query, tuple(data))
929
+ return uint64_run_id
846
930
  log(ERROR, "Unexpected run creation failure.")
847
931
  return 0
848
932
 
849
- def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
933
+ def get_run_ids(self, flwr_aid: str | None) -> set[int]:
850
934
  """Retrieve all run IDs if `flwr_aid` is not specified.
851
935
 
852
936
  Otherwise, retrieve all run IDs for the specified `flwr_aid`.
@@ -860,32 +944,10 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
860
944
  rows = self.query("SELECT run_id FROM run;", ())
861
945
  return {int64_to_uint64(row["run_id"]) for row in rows}
862
946
 
863
- def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
864
- """Check if any runs are no longer active.
865
-
866
- Marks runs with status 'starting' or 'running' as failed
867
- if they have not sent a heartbeat before `active_until`.
868
- """
869
- sint_run_ids = [uint64_to_int64(run_id) for run_id in run_ids]
870
- query = "UPDATE run SET finished_at = ?, sub_status = ?, details = ? "
871
- query += "WHERE starting_at != '' AND finished_at = '' AND active_until < ?"
872
- query += f" AND run_id IN ({','.join(['?'] * len(run_ids))});"
873
- current = now()
874
- self.query(
875
- query,
876
- (
877
- current.isoformat(),
878
- SubStatus.FAILED,
879
- RUN_FAILURE_DETAILS_NO_HEARTBEAT,
880
- current.timestamp(),
881
- *sint_run_ids,
882
- ),
883
- )
884
-
885
- def get_run(self, run_id: int) -> Optional[Run]:
947
+ def get_run(self, run_id: int) -> Run | None:
886
948
  """Retrieve information about the run with the specified `run_id`."""
887
- # Check if runs are still active
888
- self._check_and_tag_inactive_run(run_ids={run_id})
949
+ # Clean up expired tokens; this will flag inactive runs as needed
950
+ self._cleanup_expired_tokens()
889
951
 
890
952
  # Convert the uint64 value to sint64 for SQLite
891
953
  sint64_run_id = uint64_to_int64(run_id)
@@ -909,14 +971,15 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
909
971
  details=row["details"],
910
972
  ),
911
973
  flwr_aid=row["flwr_aid"],
974
+ federation=row["federation"],
912
975
  )
913
976
  log(ERROR, "`run_id` does not exist.")
914
977
  return None
915
978
 
916
979
  def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
917
980
  """Retrieve the statuses for the specified runs."""
918
- # Check if runs are still active
919
- self._check_and_tag_inactive_run(run_ids=run_ids)
981
+ # Clean up expired tokens; this will flag inactive runs as needed
982
+ self._cleanup_expired_tokens()
920
983
 
921
984
  # Convert the uint64 value to sint64 for SQLite
922
985
  sint64_run_ids = (uint64_to_int64(run_id) for run_id in set(run_ids))
@@ -935,82 +998,73 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
935
998
 
936
999
  def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
937
1000
  """Update the status of the run with the specified `run_id`."""
938
- # Check if runs are still active
939
- self._check_and_tag_inactive_run(run_ids={run_id})
1001
+ # Clean up expired tokens; this will flag inactive runs as needed
1002
+ self._cleanup_expired_tokens()
940
1003
 
941
- # Convert the uint64 value to sint64 for SQLite
942
- sint64_run_id = uint64_to_int64(run_id)
943
- query = "SELECT * FROM run WHERE run_id = ?;"
944
- rows = self.query(query, (sint64_run_id,))
945
-
946
- # Check if the run_id exists
947
- if not rows:
948
- log(ERROR, "`run_id` is invalid")
949
- return False
1004
+ with self.conn:
1005
+ # Convert the uint64 value to sint64 for SQLite
1006
+ sint64_run_id = uint64_to_int64(run_id)
1007
+ query = "SELECT * FROM run WHERE run_id = ?;"
1008
+ rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
1009
+
1010
+ # Check if the run_id exists
1011
+ if not rows:
1012
+ log(ERROR, "`run_id` is invalid")
1013
+ return False
950
1014
 
951
- # Check if the status transition is valid
952
- row = rows[0]
953
- current_status = RunStatus(
954
- status=determine_run_status(row),
955
- sub_status=row["sub_status"],
956
- details=row["details"],
957
- )
958
- if not is_valid_transition(current_status, new_status):
959
- log(
960
- ERROR,
961
- 'Invalid status transition: from "%s" to "%s"',
962
- current_status.status,
963
- new_status.status,
1015
+ # Check if the status transition is valid
1016
+ row = rows[0]
1017
+ current_status = RunStatus(
1018
+ status=determine_run_status(row),
1019
+ sub_status=row["sub_status"],
1020
+ details=row["details"],
964
1021
  )
965
- return False
1022
+ if not is_valid_transition(current_status, new_status):
1023
+ log(
1024
+ ERROR,
1025
+ 'Invalid status transition: from "%s" to "%s"',
1026
+ current_status.status,
1027
+ new_status.status,
1028
+ )
1029
+ return False
966
1030
 
967
- # Check if the sub-status is valid
968
- if not has_valid_sub_status(current_status):
969
- log(
970
- ERROR,
971
- 'Invalid sub-status "%s" for status "%s"',
972
- current_status.sub_status,
973
- current_status.status,
974
- )
975
- return False
1031
+ # Check if the sub-status is valid
1032
+ if not has_valid_sub_status(current_status):
1033
+ log(
1034
+ ERROR,
1035
+ 'Invalid sub-status "%s" for status "%s"',
1036
+ current_status.sub_status,
1037
+ current_status.status,
1038
+ )
1039
+ return False
976
1040
 
977
- # Update the status
978
- query = "UPDATE run SET %s= ?, sub_status = ?, details = ?, "
979
- query += "active_until = ?, heartbeat_interval = ? "
980
- query += "WHERE run_id = ?;"
1041
+ # Update the status
1042
+ query = """
1043
+ UPDATE run SET %s= ?, sub_status = ?, details = ? WHERE run_id = ?;
1044
+ """
981
1045
 
982
- # Prepare data for query
983
- # Initialize heartbeat_interval and active_until
984
- # when switching to starting or running
985
- current = now()
986
- if new_status.status in (Status.STARTING, Status.RUNNING):
987
- heartbeat_interval = HEARTBEAT_INTERVAL_INF
988
- active_until = current.timestamp() + heartbeat_interval
989
- else:
990
- heartbeat_interval = 0
991
- active_until = 0
992
-
993
- # Determine the timestamp field based on the new status
994
- timestamp_fld = ""
995
- if new_status.status == Status.STARTING:
996
- timestamp_fld = "starting_at"
997
- elif new_status.status == Status.RUNNING:
998
- timestamp_fld = "running_at"
999
- elif new_status.status == Status.FINISHED:
1000
- timestamp_fld = "finished_at"
1001
-
1002
- data = (
1003
- current.isoformat(),
1004
- new_status.sub_status,
1005
- new_status.details,
1006
- active_until,
1007
- heartbeat_interval,
1008
- uint64_to_int64(run_id),
1009
- )
1010
- self.query(query % timestamp_fld, data)
1046
+ # Prepare data for query
1047
+ current = now()
1048
+
1049
+ # Determine the timestamp field based on the new status
1050
+ timestamp_fld = ""
1051
+ if new_status.status == Status.STARTING:
1052
+ timestamp_fld = "starting_at"
1053
+ elif new_status.status == Status.RUNNING:
1054
+ timestamp_fld = "running_at"
1055
+ elif new_status.status == Status.FINISHED:
1056
+ timestamp_fld = "finished_at"
1057
+
1058
+ data = (
1059
+ current.isoformat(),
1060
+ new_status.sub_status,
1061
+ new_status.details,
1062
+ uint64_to_int64(run_id),
1063
+ )
1064
+ self.conn.execute(query % timestamp_fld, data)
1011
1065
  return True
1012
1066
 
1013
- def get_pending_run_id(self) -> Optional[int]:
1067
+ def get_pending_run_id(self) -> int | None:
1014
1068
  """Get the `run_id` of a run with `Status.PENDING` status, if any."""
1015
1069
  pending_run_id = None
1016
1070
 
@@ -1022,7 +1076,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
1022
1076
 
1023
1077
  return pending_run_id
1024
1078
 
1025
- def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
1079
+ def get_federation_options(self, run_id: int) -> ConfigRecord | None:
1026
1080
  """Retrieve the federation options for the specified `run_id`."""
1027
1081
  # Convert the uint64 value to sint64 for SQLite
1028
1082
  sint64_run_id = uint64_to_int64(run_id)
@@ -1080,45 +1134,36 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
1080
1134
  self.conn.execute(query, params)
1081
1135
  return True
1082
1136
 
1083
- def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
1084
- """Acknowledge a heartbeat received from a ServerApp for a given run.
1137
+ def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
1138
+ """Transition runs with expired tokens to failed status.
1085
1139
 
1086
- A run with status `"running"` is considered alive as long as it sends heartbeats
1087
- within the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
1088
- HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before the run is
1089
- marked as `"completed:failed"`.
1140
+ Parameters
1141
+ ----------
1142
+ expired_records : list[tuple[int, float]]
1143
+ List of tuples containing (run_id, active_until timestamp)
1144
+ for expired tokens.
1090
1145
  """
1091
- # Check if runs are still active
1092
- self._check_and_tag_inactive_run(run_ids={run_id})
1093
-
1094
- # Search for the run
1095
- sint_run_id = uint64_to_int64(run_id)
1096
- query = "SELECT * FROM run WHERE run_id = ?;"
1097
- rows = self.query(query, (sint_run_id,))
1098
-
1099
- if not rows:
1100
- log(ERROR, "`run_id` is invalid")
1101
- return False
1102
-
1103
- # Check if the run is of status "running"/"starting"
1104
- row = rows[0]
1105
- status = determine_run_status(row)
1106
- if status not in (Status.RUNNING, Status.STARTING):
1107
- log(
1108
- ERROR,
1109
- 'Cannot acknowledge heartbeat for run with status "%s"',
1110
- status,
1111
- )
1112
- return False
1146
+ if not expired_records:
1147
+ return
1113
1148
 
1114
- # Update the `active_until` and `heartbeat_interval` for the given run
1115
- active_until = now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
1116
- query = "UPDATE run SET active_until = ?, heartbeat_interval = ? "
1117
- query += "WHERE run_id = ?"
1118
- self.query(query, (active_until, heartbeat_interval, sint_run_id))
1119
- return True
1149
+ with self.conn:
1150
+ query = """
1151
+ UPDATE run
1152
+ SET sub_status = ?, details = ?, finished_at = ?
1153
+ WHERE run_id = ?;
1154
+ """
1155
+ data = [
1156
+ (
1157
+ SubStatus.FAILED,
1158
+ RUN_FAILURE_DETAILS_NO_HEARTBEAT,
1159
+ datetime.fromtimestamp(active_until, tz=timezone.utc).isoformat(),
1160
+ uint64_to_int64(run_id),
1161
+ )
1162
+ for run_id, active_until in expired_records
1163
+ ]
1164
+ self.conn.executemany(query, data)
1120
1165
 
1121
- def get_serverapp_context(self, run_id: int) -> Optional[Context]:
1166
+ def get_serverapp_context(self, run_id: int) -> Context | None:
1122
1167
  """Get the context for the specified `run_id`."""
1123
1168
  # Retrieve context if any
1124
1169
  query = "SELECT context FROM context WHERE run_id = ?;"
@@ -1132,19 +1177,21 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
1132
1177
  context_bytes = context_to_bytes(context)
1133
1178
  sint_run_id = uint64_to_int64(run_id)
1134
1179
 
1135
- # Check if any existing Context assigned to the run_id
1136
- query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
1137
- if self.query(query, (sint_run_id,))[0]["COUNT(*)"] > 0:
1138
- # Update context
1139
- query = "UPDATE context SET context = ? WHERE run_id = ?;"
1140
- self.query(query, (context_bytes, sint_run_id))
1141
- else:
1142
- try:
1143
- # Store context
1144
- query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
1145
- self.query(query, (sint_run_id, context_bytes))
1146
- except sqlite3.IntegrityError:
1147
- raise ValueError(f"Run {run_id} not found") from None
1180
+ with self.conn:
1181
+ # Check if any existing Context assigned to the run_id
1182
+ query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
1183
+ row = self.conn.execute(query, (sint_run_id,)).fetchone()
1184
+ if row["COUNT(*)"] > 0:
1185
+ # Update context
1186
+ query = "UPDATE context SET context = ? WHERE run_id = ?;"
1187
+ self.conn.execute(query, (context_bytes, sint_run_id))
1188
+ else:
1189
+ try:
1190
+ # Store context
1191
+ query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
1192
+ self.conn.execute(query, (sint_run_id, context_bytes))
1193
+ except sqlite3.IntegrityError:
1194
+ raise ValueError(f"Run {run_id} not found") from None
1148
1195
 
1149
1196
  def add_serverapp_log(self, run_id: int, log_message: str) -> None:
1150
1197
  """Add a log entry to the ServerApp logs for the specified `run_id`."""
@@ -1161,90 +1208,52 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
1161
1208
  raise ValueError(f"Run {run_id} not found") from None
1162
1209
 
1163
1210
  def get_serverapp_log(
1164
- self, run_id: int, after_timestamp: Optional[float]
1211
+ self, run_id: int, after_timestamp: float | None
1165
1212
  ) -> tuple[str, float]:
1166
1213
  """Get the ServerApp logs for the specified `run_id`."""
1167
1214
  # Convert the uint64 value to sint64 for SQLite
1168
1215
  sint64_run_id = uint64_to_int64(run_id)
1169
1216
 
1170
- # Check if the run_id exists
1171
- query = "SELECT run_id FROM run WHERE run_id = ?;"
1172
- if not self.query(query, (sint64_run_id,)):
1173
- raise ValueError(f"Run {run_id} not found")
1174
-
1175
- # Retrieve logs
1176
- if after_timestamp is None:
1177
- after_timestamp = 0.0
1178
- query = """
1179
- SELECT log, timestamp FROM logs
1180
- WHERE run_id = ? AND node_id = ? AND timestamp > ?;
1181
- """
1182
- rows = self.query(query, (sint64_run_id, 0, after_timestamp))
1183
- rows.sort(key=lambda x: x["timestamp"])
1184
- latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
1217
+ with self.conn:
1218
+ # Check if the run_id exists
1219
+ query = "SELECT run_id FROM run WHERE run_id = ?;"
1220
+ rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
1221
+ if not rows:
1222
+ raise ValueError(f"Run {run_id} not found")
1223
+
1224
+ # Retrieve logs
1225
+ if after_timestamp is None:
1226
+ after_timestamp = 0.0
1227
+ query = """
1228
+ SELECT log, timestamp FROM logs
1229
+ WHERE run_id = ? AND node_id = ? AND timestamp > ?;
1230
+ """
1231
+ rows = self.conn.execute(
1232
+ query, (sint64_run_id, 0, after_timestamp)
1233
+ ).fetchall()
1234
+ rows.sort(key=lambda x: x["timestamp"])
1235
+ latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
1185
1236
  return "".join(row["log"] for row in rows), latest_timestamp
1186
1237
 
1187
- def get_valid_message_ins(self, message_id: str) -> Optional[dict[str, Any]]:
1238
+ def get_valid_message_ins(self, message_id: str) -> dict[str, Any] | None:
1188
1239
  """Check if the Message exists and is valid (not expired).
1189
1240
 
1190
1241
  Return Message if valid.
1191
1242
  """
1192
- query = """
1193
- SELECT *
1194
- FROM message_ins
1195
- WHERE message_id = :message_id
1196
- """
1197
- data = {"message_id": message_id}
1198
- rows = self.query(query, data)
1199
- if not rows:
1200
- # Message does not exist
1201
- return None
1202
-
1203
- message_ins = rows[0]
1204
- created_at = message_ins["created_at"]
1205
- ttl = message_ins["ttl"]
1206
- current_time = now().timestamp()
1207
-
1208
- # Check if Message is expired
1209
- if ttl is not None and created_at + ttl <= current_time:
1210
- return None
1211
-
1212
- return message_ins
1243
+ with self.conn:
1244
+ self._check_stored_messages({message_id})
1245
+ query = """
1246
+ SELECT *
1247
+ FROM message_ins
1248
+ WHERE message_id = :message_id
1249
+ """
1250
+ data = {"message_id": message_id}
1251
+ rows: list[dict[str, Any]] = self.conn.execute(query, data).fetchall()
1252
+ if not rows:
1253
+ # Message does not exist
1254
+ return None
1213
1255
 
1214
- def create_token(self, run_id: int) -> Optional[str]:
1215
- """Create a token for the given run ID."""
1216
- token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
1217
- query = "INSERT INTO token_store (run_id, token) VALUES (:run_id, :token);"
1218
- data = {"run_id": uint64_to_int64(run_id), "token": token}
1219
- try:
1220
- self.query(query, data)
1221
- except sqlite3.IntegrityError:
1222
- return None # Token already created for this run ID
1223
- return token
1224
-
1225
- def verify_token(self, run_id: int, token: str) -> bool:
1226
- """Verify a token for the given run ID."""
1227
- query = "SELECT token FROM token_store WHERE run_id = :run_id;"
1228
- data = {"run_id": uint64_to_int64(run_id)}
1229
- rows = self.query(query, data)
1230
- if not rows:
1231
- return False
1232
- return cast(str, rows[0]["token"]) == token
1233
-
1234
- def delete_token(self, run_id: int) -> None:
1235
- """Delete the token for the given run ID."""
1236
- query = "DELETE FROM token_store WHERE run_id = :run_id;"
1237
- data = {"run_id": uint64_to_int64(run_id)}
1238
- self.query(query, data)
1239
-
1240
- def get_run_id_by_token(self, token: str) -> Optional[int]:
1241
- """Get the run ID associated with a given token."""
1242
- query = "SELECT run_id FROM token_store WHERE token = :token;"
1243
- data = {"token": token}
1244
- rows = self.query(query, data)
1245
- if not rows:
1246
- return None
1247
- return int64_to_uint64(rows[0]["run_id"])
1256
+ return rows[0]
1248
1257
 
1249
1258
 
1250
1259
  def message_to_dict(message: Message) -> dict[str, Any]: