flwr 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (339) hide show
  1. flwr/__init__.py +16 -5
  2. flwr/app/error.py +2 -2
  3. flwr/app/exception.py +3 -3
  4. flwr/cli/app.py +19 -0
  5. flwr/cli/{new/templates → app_cmd}/__init__.py +9 -1
  6. flwr/cli/app_cmd/publish.py +285 -0
  7. flwr/cli/app_cmd/review.py +262 -0
  8. flwr/cli/auth_plugin/auth_plugin.py +4 -5
  9. flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
  10. flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
  11. flwr/cli/build.py +60 -18
  12. flwr/cli/cli_account_auth_interceptor.py +24 -7
  13. flwr/cli/config_utils.py +101 -13
  14. flwr/cli/{new/templates/app/code/flwr_tune → federation}/__init__.py +10 -1
  15. flwr/cli/federation/ls.py +140 -0
  16. flwr/cli/federation/show.py +318 -0
  17. flwr/cli/install.py +91 -13
  18. flwr/cli/log.py +52 -9
  19. flwr/cli/login/login.py +7 -4
  20. flwr/cli/ls.py +211 -130
  21. flwr/cli/new/new.py +123 -331
  22. flwr/cli/pull.py +10 -5
  23. flwr/cli/run/run.py +71 -29
  24. flwr/cli/run_utils.py +148 -0
  25. flwr/cli/stop.py +26 -8
  26. flwr/cli/supernode/ls.py +25 -12
  27. flwr/cli/supernode/register.py +9 -4
  28. flwr/cli/supernode/unregister.py +5 -3
  29. flwr/cli/utils.py +239 -16
  30. flwr/client/__init__.py +1 -1
  31. flwr/client/dpfedavg_numpy_client.py +4 -1
  32. flwr/client/grpc_adapter_client/connection.py +8 -9
  33. flwr/client/grpc_rere_client/connection.py +16 -14
  34. flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
  35. flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
  36. flwr/client/message_handler/message_handler.py +2 -2
  37. flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
  38. flwr/client/numpy_client.py +1 -1
  39. flwr/client/rest_client/connection.py +18 -18
  40. flwr/client/run_info_store.py +4 -5
  41. flwr/client/typing.py +1 -1
  42. flwr/clientapp/client_app.py +9 -10
  43. flwr/clientapp/mod/centraldp_mods.py +16 -17
  44. flwr/clientapp/mod/localdp_mod.py +8 -9
  45. flwr/clientapp/typing.py +1 -1
  46. flwr/clientapp/utils.py +3 -3
  47. flwr/common/address.py +1 -2
  48. flwr/common/args.py +3 -4
  49. flwr/common/config.py +13 -16
  50. flwr/common/constant.py +5 -2
  51. flwr/common/differential_privacy.py +3 -4
  52. flwr/common/event_log_plugin/event_log_plugin.py +3 -4
  53. flwr/common/exit/exit.py +15 -2
  54. flwr/common/exit/exit_code.py +19 -0
  55. flwr/common/exit/exit_handler.py +6 -2
  56. flwr/common/exit/signal_handler.py +5 -5
  57. flwr/common/grpc.py +6 -6
  58. flwr/common/inflatable_protobuf_utils.py +1 -1
  59. flwr/common/inflatable_utils.py +38 -21
  60. flwr/common/logger.py +19 -19
  61. flwr/common/message.py +4 -4
  62. flwr/common/object_ref.py +7 -7
  63. flwr/common/record/array.py +3 -3
  64. flwr/common/record/arrayrecord.py +18 -30
  65. flwr/common/record/configrecord.py +3 -3
  66. flwr/common/record/recorddict.py +5 -5
  67. flwr/common/record/typeddict.py +9 -2
  68. flwr/common/recorddict_compat.py +7 -10
  69. flwr/common/retry_invoker.py +20 -20
  70. flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
  71. flwr/common/serde.py +11 -4
  72. flwr/common/serde_utils.py +2 -2
  73. flwr/common/telemetry.py +9 -5
  74. flwr/common/typing.py +58 -37
  75. flwr/compat/client/app.py +38 -37
  76. flwr/compat/client/grpc_client/connection.py +11 -11
  77. flwr/compat/server/app.py +5 -6
  78. flwr/proto/appio_pb2.py +13 -3
  79. flwr/proto/appio_pb2.pyi +134 -65
  80. flwr/proto/appio_pb2_grpc.py +20 -0
  81. flwr/proto/appio_pb2_grpc.pyi +27 -0
  82. flwr/proto/clientappio_pb2.py +17 -7
  83. flwr/proto/clientappio_pb2.pyi +15 -0
  84. flwr/proto/clientappio_pb2_grpc.py +206 -40
  85. flwr/proto/clientappio_pb2_grpc.pyi +168 -53
  86. flwr/proto/control_pb2.py +71 -52
  87. flwr/proto/control_pb2.pyi +277 -111
  88. flwr/proto/control_pb2_grpc.py +249 -40
  89. flwr/proto/control_pb2_grpc.pyi +185 -52
  90. flwr/proto/error_pb2.py +13 -3
  91. flwr/proto/error_pb2.pyi +24 -6
  92. flwr/proto/error_pb2_grpc.py +20 -0
  93. flwr/proto/error_pb2_grpc.pyi +27 -0
  94. flwr/proto/fab_pb2.py +14 -4
  95. flwr/proto/fab_pb2.pyi +59 -31
  96. flwr/proto/fab_pb2_grpc.py +20 -0
  97. flwr/proto/fab_pb2_grpc.pyi +27 -0
  98. flwr/proto/federation_pb2.py +38 -0
  99. flwr/proto/federation_pb2.pyi +56 -0
  100. flwr/proto/federation_pb2_grpc.py +24 -0
  101. flwr/proto/federation_pb2_grpc.pyi +31 -0
  102. flwr/proto/fleet_pb2.py +24 -14
  103. flwr/proto/fleet_pb2.pyi +141 -61
  104. flwr/proto/fleet_pb2_grpc.py +189 -48
  105. flwr/proto/fleet_pb2_grpc.pyi +175 -61
  106. flwr/proto/grpcadapter_pb2.py +14 -4
  107. flwr/proto/grpcadapter_pb2.pyi +38 -16
  108. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  109. flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
  110. flwr/proto/heartbeat_pb2.py +17 -7
  111. flwr/proto/heartbeat_pb2.pyi +51 -22
  112. flwr/proto/heartbeat_pb2_grpc.py +20 -0
  113. flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
  114. flwr/proto/log_pb2.py +13 -3
  115. flwr/proto/log_pb2.pyi +34 -11
  116. flwr/proto/log_pb2_grpc.py +20 -0
  117. flwr/proto/log_pb2_grpc.pyi +27 -0
  118. flwr/proto/message_pb2.py +15 -5
  119. flwr/proto/message_pb2.pyi +154 -86
  120. flwr/proto/message_pb2_grpc.py +20 -0
  121. flwr/proto/message_pb2_grpc.pyi +27 -0
  122. flwr/proto/node_pb2.py +15 -5
  123. flwr/proto/node_pb2.pyi +50 -25
  124. flwr/proto/node_pb2_grpc.py +20 -0
  125. flwr/proto/node_pb2_grpc.pyi +27 -0
  126. flwr/proto/recorddict_pb2.py +13 -3
  127. flwr/proto/recorddict_pb2.pyi +184 -107
  128. flwr/proto/recorddict_pb2_grpc.py +20 -0
  129. flwr/proto/recorddict_pb2_grpc.pyi +27 -0
  130. flwr/proto/run_pb2.py +40 -31
  131. flwr/proto/run_pb2.pyi +158 -84
  132. flwr/proto/run_pb2_grpc.py +20 -0
  133. flwr/proto/run_pb2_grpc.pyi +27 -0
  134. flwr/proto/serverappio_pb2.py +13 -3
  135. flwr/proto/serverappio_pb2.pyi +32 -8
  136. flwr/proto/serverappio_pb2_grpc.py +246 -65
  137. flwr/proto/serverappio_pb2_grpc.pyi +221 -85
  138. flwr/proto/simulationio_pb2.py +16 -8
  139. flwr/proto/simulationio_pb2.pyi +15 -0
  140. flwr/proto/simulationio_pb2_grpc.py +162 -41
  141. flwr/proto/simulationio_pb2_grpc.pyi +149 -55
  142. flwr/proto/transport_pb2.py +20 -10
  143. flwr/proto/transport_pb2.pyi +249 -160
  144. flwr/proto/transport_pb2_grpc.py +35 -4
  145. flwr/proto/transport_pb2_grpc.pyi +38 -8
  146. flwr/server/app.py +39 -17
  147. flwr/server/client_manager.py +4 -5
  148. flwr/server/client_proxy.py +10 -11
  149. flwr/server/compat/app.py +4 -5
  150. flwr/server/compat/app_utils.py +2 -1
  151. flwr/server/compat/grid_client_proxy.py +10 -12
  152. flwr/server/compat/legacy_context.py +3 -4
  153. flwr/server/fleet_event_log_interceptor.py +2 -1
  154. flwr/server/grid/grid.py +2 -3
  155. flwr/server/grid/grpc_grid.py +10 -8
  156. flwr/server/grid/inmemory_grid.py +4 -4
  157. flwr/server/run_serverapp.py +2 -3
  158. flwr/server/server.py +34 -39
  159. flwr/server/server_app.py +7 -8
  160. flwr/server/server_config.py +1 -2
  161. flwr/server/serverapp/app.py +34 -28
  162. flwr/server/serverapp_components.py +4 -5
  163. flwr/server/strategy/aggregate.py +9 -8
  164. flwr/server/strategy/bulyan.py +13 -11
  165. flwr/server/strategy/dp_adaptive_clipping.py +16 -20
  166. flwr/server/strategy/dp_fixed_clipping.py +12 -17
  167. flwr/server/strategy/dpfedavg_adaptive.py +3 -4
  168. flwr/server/strategy/dpfedavg_fixed.py +6 -10
  169. flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
  170. flwr/server/strategy/fedadagrad.py +18 -14
  171. flwr/server/strategy/fedadam.py +16 -14
  172. flwr/server/strategy/fedavg.py +16 -17
  173. flwr/server/strategy/fedavg_android.py +15 -15
  174. flwr/server/strategy/fedavgm.py +21 -18
  175. flwr/server/strategy/fedmedian.py +2 -3
  176. flwr/server/strategy/fedopt.py +11 -10
  177. flwr/server/strategy/fedprox.py +10 -9
  178. flwr/server/strategy/fedtrimmedavg.py +12 -11
  179. flwr/server/strategy/fedxgb_bagging.py +13 -11
  180. flwr/server/strategy/fedxgb_cyclic.py +6 -6
  181. flwr/server/strategy/fedxgb_nn_avg.py +4 -4
  182. flwr/server/strategy/fedyogi.py +16 -14
  183. flwr/server/strategy/krum.py +12 -11
  184. flwr/server/strategy/qfedavg.py +16 -15
  185. flwr/server/strategy/strategy.py +6 -9
  186. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
  187. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
  188. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
  189. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
  190. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
  191. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
  192. flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
  193. flwr/server/superlink/fleet/message_handler/message_handler.py +75 -30
  194. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
  195. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  196. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
  197. flwr/server/superlink/fleet/vce/vce_api.py +15 -9
  198. flwr/server/superlink/linkstate/in_memory_linkstate.py +148 -149
  199. flwr/server/superlink/linkstate/linkstate.py +91 -43
  200. flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
  201. flwr/server/superlink/linkstate/sqlite_linkstate.py +502 -436
  202. flwr/server/superlink/linkstate/utils.py +6 -6
  203. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
  204. flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
  205. flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
  206. flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
  207. flwr/server/superlink/utils.py +4 -6
  208. flwr/server/typing.py +1 -1
  209. flwr/server/utils/tensorboard.py +15 -8
  210. flwr/server/workflow/default_workflows.py +5 -5
  211. flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
  212. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
  213. flwr/serverapp/strategy/bulyan.py +16 -15
  214. flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
  215. flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
  216. flwr/serverapp/strategy/fedadagrad.py +10 -11
  217. flwr/serverapp/strategy/fedadam.py +10 -11
  218. flwr/serverapp/strategy/fedavg.py +9 -10
  219. flwr/serverapp/strategy/fedavgm.py +17 -16
  220. flwr/serverapp/strategy/fedmedian.py +2 -2
  221. flwr/serverapp/strategy/fedopt.py +10 -11
  222. flwr/serverapp/strategy/fedprox.py +7 -8
  223. flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
  224. flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
  225. flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
  226. flwr/serverapp/strategy/fedyogi.py +9 -11
  227. flwr/serverapp/strategy/krum.py +7 -7
  228. flwr/serverapp/strategy/multikrum.py +9 -9
  229. flwr/serverapp/strategy/qfedavg.py +17 -16
  230. flwr/serverapp/strategy/strategy.py +6 -9
  231. flwr/serverapp/strategy/strategy_utils.py +7 -8
  232. flwr/simulation/app.py +46 -42
  233. flwr/simulation/legacy_app.py +12 -12
  234. flwr/simulation/ray_transport/ray_actor.py +10 -11
  235. flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
  236. flwr/simulation/run_simulation.py +43 -43
  237. flwr/simulation/simulationio_connection.py +4 -4
  238. flwr/supercore/cli/flower_superexec.py +3 -4
  239. flwr/supercore/constant.py +34 -1
  240. flwr/supercore/corestate/corestate.py +24 -3
  241. flwr/supercore/corestate/in_memory_corestate.py +138 -0
  242. flwr/supercore/corestate/sqlite_corestate.py +157 -0
  243. flwr/supercore/ffs/disk_ffs.py +1 -2
  244. flwr/supercore/ffs/ffs.py +1 -2
  245. flwr/supercore/ffs/ffs_factory.py +1 -2
  246. flwr/{common → supercore}/heartbeat.py +20 -25
  247. flwr/supercore/object_store/in_memory_object_store.py +1 -2
  248. flwr/supercore/object_store/object_store.py +1 -2
  249. flwr/supercore/object_store/object_store_factory.py +1 -2
  250. flwr/supercore/object_store/sqlite_object_store.py +8 -7
  251. flwr/supercore/primitives/asymmetric.py +1 -1
  252. flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
  253. flwr/supercore/sqlite_mixin.py +37 -34
  254. flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
  255. flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
  256. flwr/supercore/superexec/run_superexec.py +9 -13
  257. flwr/supercore/utils.py +190 -0
  258. flwr/superlink/artifact_provider/artifact_provider.py +1 -2
  259. flwr/superlink/auth_plugin/auth_plugin.py +6 -9
  260. flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
  261. flwr/{cli/new/templates/app → superlink/federation}/__init__.py +10 -1
  262. flwr/superlink/federation/federation_manager.py +64 -0
  263. flwr/superlink/federation/noop_federation_manager.py +71 -0
  264. flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
  265. flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
  266. flwr/superlink/servicer/control/control_grpc.py +7 -6
  267. flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
  268. flwr/superlink/servicer/control/control_servicer.py +190 -23
  269. flwr/supernode/cli/flower_supernode.py +58 -3
  270. flwr/supernode/nodestate/in_memory_nodestate.py +121 -49
  271. flwr/supernode/nodestate/nodestate.py +52 -8
  272. flwr/supernode/nodestate/nodestate_factory.py +7 -4
  273. flwr/supernode/runtime/run_clientapp.py +41 -22
  274. flwr/supernode/servicer/clientappio/clientappio_servicer.py +46 -10
  275. flwr/supernode/start_client_internal.py +165 -46
  276. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/METADATA +9 -11
  277. flwr-1.25.0.dist-info/RECORD +393 -0
  278. flwr/cli/new/templates/app/.gitignore.tpl +0 -163
  279. flwr/cli/new/templates/app/LICENSE.tpl +0 -202
  280. flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
  281. flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
  282. flwr/cli/new/templates/app/README.md.tpl +0 -37
  283. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
  284. flwr/cli/new/templates/app/code/__init__.py +0 -15
  285. flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
  286. flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
  287. flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
  288. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
  289. flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
  290. flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
  291. flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
  292. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
  293. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
  294. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
  295. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
  296. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
  297. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
  298. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
  299. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
  300. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
  301. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
  302. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
  303. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
  304. flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
  305. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
  306. flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
  307. flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
  308. flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
  309. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
  310. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
  311. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
  312. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
  313. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
  314. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
  315. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
  316. flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
  317. flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
  318. flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
  319. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -98
  320. flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
  321. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
  322. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
  323. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
  324. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
  325. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
  326. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
  327. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
  328. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
  329. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
  330. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
  331. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
  332. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
  333. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
  334. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
  335. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
  336. flwr/supercore/object_store/utils.py +0 -43
  337. flwr-1.23.0.dist-info/RECORD +0 -439
  338. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/WHEEL +0 -0
  339. {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/entry_points.txt +0 -0
@@ -18,16 +18,14 @@
18
18
  # pylint: disable=too-many-lines
19
19
 
20
20
  import json
21
- import secrets
22
21
  import sqlite3
23
22
  from collections.abc import Sequence
23
+ from datetime import datetime, timezone
24
24
  from logging import ERROR, WARNING
25
- from typing import Any, Optional, Union, cast
25
+ from typing import Any, cast
26
26
 
27
27
  from flwr.common import Context, Message, Metadata, log, now
28
28
  from flwr.common.constant import (
29
- FLWR_APP_TOKEN_LENGTH,
30
- HEARTBEAT_INTERVAL_INF,
31
29
  HEARTBEAT_PATIENCE,
32
30
  MESSAGE_TTL_TOLERANCE,
33
31
  NODE_ID_NUM_BYTES,
@@ -51,8 +49,10 @@ from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
51
49
  # pylint: enable=E0611
52
50
  from flwr.server.utils.validator import validate_message
53
51
  from flwr.supercore.constant import NodeStatus
54
- from flwr.supercore.sqlite_mixin import SqliteMixin
52
+ from flwr.supercore.corestate.sqlite_corestate import SqliteCoreState
53
+ from flwr.supercore.object_store.object_store import ObjectStore
55
54
  from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
55
+ from flwr.superlink.federation import FederationManager
56
56
 
57
57
  from .linkstate import LinkState
58
58
  from .utils import (
@@ -74,6 +74,7 @@ SQL_CREATE_TABLE_NODE = """
74
74
  CREATE TABLE IF NOT EXISTS node(
75
75
  node_id INTEGER UNIQUE,
76
76
  owner_aid TEXT,
77
+ owner_name TEXT,
77
78
  status TEXT,
78
79
  registered_at TEXT,
79
80
  last_activated_at TEXT NULL,
@@ -106,8 +107,6 @@ CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
106
107
  SQL_CREATE_TABLE_RUN = """
107
108
  CREATE TABLE IF NOT EXISTS run(
108
109
  run_id INTEGER UNIQUE,
109
- active_until REAL,
110
- heartbeat_interval REAL,
111
110
  fab_id TEXT,
112
111
  fab_version TEXT,
113
112
  fab_hash TEXT,
@@ -118,8 +117,12 @@ CREATE TABLE IF NOT EXISTS run(
118
117
  finished_at TEXT,
119
118
  sub_status TEXT,
120
119
  details TEXT,
120
+ federation TEXT,
121
121
  federation_options BLOB,
122
- flwr_aid TEXT
122
+ flwr_aid TEXT,
123
+ bytes_sent INTEGER DEFAULT 0,
124
+ bytes_recv INTEGER DEFAULT 0,
125
+ clientapp_runtime REAL DEFAULT 0.0
123
126
  );
124
127
  """
125
128
 
@@ -179,20 +182,23 @@ CREATE TABLE IF NOT EXISTS message_res(
179
182
  );
180
183
  """
181
184
 
182
- SQL_CREATE_TABLE_TOKEN_STORE = """
183
- CREATE TABLE IF NOT EXISTS token_store (
184
- run_id INTEGER PRIMARY KEY,
185
- token TEXT UNIQUE NOT NULL
186
- );
187
- """
188
-
189
185
 
190
- class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
186
+ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
191
187
  """SQLite-based LinkState implementation."""
192
188
 
193
- def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
194
- """Connect to the DB, enable FK support, and create tables if needed."""
195
- return self._ensure_initialized(
189
+ def __init__(
190
+ self,
191
+ database_path: str,
192
+ federation_manager: FederationManager,
193
+ object_store: ObjectStore,
194
+ ) -> None:
195
+ super().__init__(database_path, object_store)
196
+ federation_manager.linkstate = self
197
+ self._federation_manager = federation_manager
198
+
199
+ def get_sql_statements(self) -> tuple[str, ...]:
200
+ """Return SQL statements for LinkState tables."""
201
+ return super().get_sql_statements() + (
196
202
  SQL_CREATE_TABLE_RUN,
197
203
  SQL_CREATE_TABLE_LOGS,
198
204
  SQL_CREATE_TABLE_CONTEXT,
@@ -200,14 +206,17 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
200
206
  SQL_CREATE_TABLE_MESSAGE_RES,
201
207
  SQL_CREATE_TABLE_NODE,
202
208
  SQL_CREATE_TABLE_PUBLIC_KEY,
203
- SQL_CREATE_TABLE_TOKEN_STORE,
204
209
  SQL_CREATE_INDEX_ONLINE_UNTIL,
205
210
  SQL_CREATE_INDEX_OWNER_AID,
206
211
  SQL_CREATE_INDEX_NODE_STATUS,
207
- log_queries=log_queries,
208
212
  )
209
213
 
210
- def store_message_ins(self, message: Message) -> Optional[str]:
214
+ @property
215
+ def federation_manager(self) -> FederationManager:
216
+ """Get the FederationManager instance."""
217
+ return self._federation_manager
218
+
219
+ def store_message_ins(self, message: Message) -> str | None:
211
220
  """Store one Message."""
212
221
  # Validate message
213
222
  errors = validate_message(message=message, is_reply_message=False)
@@ -223,12 +232,6 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
223
232
  data[0], ["run_id", "src_node_id", "dst_node_id"]
224
233
  )
225
234
 
226
- # Validate run_id
227
- query = "SELECT run_id FROM run WHERE run_id = ?;"
228
- if not self.query(query, (data[0]["run_id"],)):
229
- log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
230
- return None
231
-
232
235
  # Validate source node ID
233
236
  if message.metadata.src_node_id != SUPERLINK_NODE_ID:
234
237
  log(
@@ -238,28 +241,87 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
238
241
  )
239
242
  return None
240
243
 
241
- # Validate destination node ID
242
- query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
243
- if not self.query(
244
- query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
245
- ):
246
- log(
247
- ERROR,
248
- "Invalid destination node ID for Message: %s",
249
- message.metadata.dst_node_id,
250
- )
251
- return None
252
-
253
- columns = ", ".join([f":{key}" for key in data[0]])
254
- query = f"INSERT INTO message_ins VALUES({columns});"
255
-
256
- # Only invalid run_id can trigger IntegrityError.
257
- # This may need to be changed in the future version with more integrity checks.
258
- self.query(query, data)
244
+ with self.conn:
245
+ # Validate run_id
246
+ query = "SELECT federation FROM run WHERE run_id = ?;"
247
+ rows = self.conn.execute(query, (data[0]["run_id"],)).fetchall()
248
+ if not rows:
249
+ log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
250
+ return None
251
+ federation: str = rows[0]["federation"]
252
+
253
+ # Validate destination node ID
254
+ query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
255
+ rows = self.conn.execute(
256
+ query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
257
+ ).fetchall()
258
+ if not rows or not self.federation_manager.has_node(
259
+ message.metadata.dst_node_id, federation
260
+ ):
261
+ log(
262
+ ERROR,
263
+ "Invalid destination node ID for Message: %s",
264
+ message.metadata.dst_node_id,
265
+ )
266
+ return None
267
+
268
+ columns = ", ".join([f":{key}" for key in data[0]])
269
+ query = f"INSERT INTO message_ins VALUES({columns});"
270
+
271
+ # Only invalid run_id can trigger IntegrityError.
272
+ # This may need to be changed in the future version
273
+ # with more integrity checks.
274
+ self.conn.execute(query, data[0])
259
275
 
260
276
  return message.metadata.message_id
261
277
 
262
- def get_message_ins(self, node_id: int, limit: Optional[int]) -> list[Message]:
278
+ def _check_stored_messages(self, message_ids: set[str]) -> None:
279
+ """Check and delete the message if it's invalid."""
280
+ if not message_ids:
281
+ return
282
+
283
+ with self.conn:
284
+ invalid_msg_ids: set[str] = set()
285
+ current_time = now().timestamp()
286
+
287
+ for msg_id in message_ids:
288
+ # Check if message exists
289
+ query = "SELECT * FROM message_ins WHERE message_id = ?;"
290
+ message_row = self.conn.execute(query, (msg_id,)).fetchone()
291
+ if not message_row:
292
+ continue
293
+
294
+ # Check if the message has expired
295
+ available_until = message_row["created_at"] + message_row["ttl"]
296
+ if available_until <= current_time:
297
+ invalid_msg_ids.add(msg_id)
298
+ continue
299
+
300
+ # Check if src_node_id and dst_node_id are in the federation
301
+ # Get federation from run table
302
+ run_id = message_row["run_id"]
303
+ query = "SELECT federation FROM run WHERE run_id = ?;"
304
+ run_row = self.conn.execute(query, (run_id,)).fetchone()
305
+ if not run_row: # This should not happen
306
+ invalid_msg_ids.add(msg_id)
307
+ continue
308
+ federation = run_row["federation"]
309
+
310
+ # Convert sint64 to uint64 for node IDs
311
+ src_node_id = int64_to_uint64(message_row["src_node_id"])
312
+ dst_node_id = int64_to_uint64(message_row["dst_node_id"])
313
+
314
+ # Filter nodes to check if they're in the federation
315
+ filtered = self.federation_manager.filter_nodes(
316
+ {src_node_id, dst_node_id}, federation
317
+ )
318
+ if len(filtered) != 2: # Not both nodes are in the federation
319
+ invalid_msg_ids.add(msg_id)
320
+
321
+ # Delete all invalid messages
322
+ self.delete_messages(invalid_msg_ids)
323
+
324
+ def get_message_ins(self, node_id: int, limit: int | None) -> list[Message]:
263
325
  """Get all Messages that have not been delivered yet."""
264
326
  if limit is not None and limit < 1:
265
327
  raise AssertionError("`limit` must be >= 1")
@@ -268,59 +330,64 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
268
330
  msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
269
331
  raise AssertionError(msg)
270
332
 
271
- data: dict[str, Union[str, int]] = {}
333
+ data: dict[str, str | int] = {}
272
334
 
273
335
  # Convert the uint64 value to sint64 for SQLite
274
336
  data["node_id"] = uint64_to_int64(node_id)
275
337
 
276
- # Retrieve all Messages for node_id
277
- query = """
278
- SELECT message_id
279
- FROM message_ins
280
- WHERE dst_node_id == :node_id
281
- AND delivered_at = ""
282
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
283
- """
284
-
285
- if limit is not None:
286
- query += " LIMIT :limit"
287
- data["limit"] = limit
288
-
289
- query += ";"
290
-
291
- rows = self.query(query, data)
292
-
293
- if rows:
294
- # Prepare query
295
- message_ids = [row["message_id"] for row in rows]
296
- placeholders: str = ",".join([f":id_{i}" for i in range(len(message_ids))])
297
- query = f"""
298
- UPDATE message_ins
299
- SET delivered_at = :delivered_at
300
- WHERE message_id IN ({placeholders})
301
- RETURNING *;
338
+ with self.conn:
339
+ # Retrieve all Messages for node_id
340
+ query = """
341
+ SELECT message_id
342
+ FROM message_ins
343
+ WHERE dst_node_id == :node_id
344
+ AND delivered_at = ""
345
+ AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
302
346
  """
303
347
 
304
- # Prepare data for query
305
- delivered_at = now().isoformat()
306
- data = {"delivered_at": delivered_at}
307
- for index, msg_id in enumerate(message_ids):
308
- data[f"id_{index}"] = str(msg_id)
348
+ if limit is not None:
349
+ query += " LIMIT :limit"
350
+ data["limit"] = limit
309
351
 
310
- # Run query
311
- rows = self.query(query, data)
352
+ query += ";"
312
353
 
313
- for row in rows:
314
- # Convert values from sint64 to uint64
315
- convert_sint64_values_in_dict_to_uint64(
316
- row, ["run_id", "src_node_id", "dst_node_id"]
317
- )
354
+ rows = self.conn.execute(query, data).fetchall()
355
+ message_ids: set[str] = {row["message_id"] for row in rows}
356
+ self._check_stored_messages(message_ids)
357
+
358
+ # Mark retrieved Messages as delivered
359
+ if rows:
360
+ # Prepare query
361
+ placeholders: str = ",".join(
362
+ [f":id_{i}" for i in range(len(message_ids))]
363
+ )
364
+ query = f"""
365
+ UPDATE message_ins
366
+ SET delivered_at = :delivered_at
367
+ WHERE message_id IN ({placeholders})
368
+ RETURNING *;
369
+ """
370
+
371
+ # Prepare data for query
372
+ delivered_at = now().isoformat()
373
+ data = {"delivered_at": delivered_at}
374
+ for index, msg_id in enumerate(message_ids):
375
+ data[f"id_{index}"] = str(msg_id)
376
+
377
+ # Run query
378
+ rows = self.conn.execute(query, data).fetchall()
379
+
380
+ for row in rows:
381
+ # Convert values from sint64 to uint64
382
+ convert_sint64_values_in_dict_to_uint64(
383
+ row, ["run_id", "src_node_id", "dst_node_id"]
384
+ )
318
385
 
319
386
  result = [dict_to_message(row) for row in rows]
320
387
 
321
388
  return result
322
389
 
323
- def store_message_res(self, message: Message) -> Optional[str]:
390
+ def store_message_res(self, message: Message) -> str | None:
324
391
  """Store one Message."""
325
392
  # Validate message
326
393
  errors = validate_message(message=message, is_reply_message=True)
@@ -336,7 +403,8 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
336
403
  ERROR,
337
404
  "Failed to store Message reply: "
338
405
  "The message it replies to with message_id %s does not exist or "
339
- "has expired.",
406
+ "has expired, or was deleted because the target SuperNode was "
407
+ "removed from the federation.",
340
408
  msg_ins_id,
341
409
  )
342
410
  return None
@@ -397,84 +465,92 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
397
465
  # pylint: disable-msg=too-many-locals
398
466
  ret: dict[str, Message] = {}
399
467
 
400
- # Verify Message IDs
401
- current = now().timestamp()
402
- query = f"""
403
- SELECT *
404
- FROM message_ins
405
- WHERE message_id IN ({",".join(["?"] * len(message_ids))});
406
- """
407
- rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
408
- found_message_ins_dict: dict[str, Message] = {}
409
- for row in rows:
410
- convert_sint64_values_in_dict_to_uint64(
411
- row, ["run_id", "src_node_id", "dst_node_id"]
468
+ with self.conn:
469
+ # Verify Message IDs
470
+ self._check_stored_messages(message_ids)
471
+ current = now().timestamp()
472
+ query = f"""
473
+ SELECT *
474
+ FROM message_ins
475
+ WHERE message_id IN ({','.join(['?'] * len(message_ids))});
476
+ """
477
+ rows = self.conn.execute(
478
+ query, tuple(str(message_id) for message_id in message_ids)
479
+ ).fetchall()
480
+ found_message_ins_dict: dict[str, Message] = {}
481
+ for row in rows:
482
+ convert_sint64_values_in_dict_to_uint64(
483
+ row, ["run_id", "src_node_id", "dst_node_id"]
484
+ )
485
+ found_message_ins_dict[row["message_id"]] = dict_to_message(row)
486
+
487
+ ret = verify_message_ids(
488
+ inquired_message_ids=message_ids,
489
+ found_message_ins_dict=found_message_ins_dict,
490
+ current_time=current,
412
491
  )
413
- found_message_ins_dict[row["message_id"]] = dict_to_message(row)
414
492
 
415
- ret = verify_message_ids(
416
- inquired_message_ids=message_ids,
417
- found_message_ins_dict=found_message_ins_dict,
418
- current_time=current,
419
- )
493
+ # Check node availability
494
+ dst_node_ids: set[int] = set()
495
+ for message_id in message_ids:
496
+ in_message = found_message_ins_dict[message_id]
497
+ sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
498
+ dst_node_ids.add(sint_node_id)
499
+ query = f"""
500
+ SELECT node_id, online_until
501
+ FROM node
502
+ WHERE node_id IN ({','.join(['?'] * len(dst_node_ids))})
503
+ AND status != ?
504
+ """
505
+ rows = self.conn.execute(
506
+ query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,)
507
+ ).fetchall()
508
+ tmp_ret_dict = check_node_availability_for_in_message(
509
+ inquired_in_message_ids=message_ids,
510
+ found_in_message_dict=found_message_ins_dict,
511
+ node_id_to_online_until={
512
+ int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
513
+ },
514
+ current_time=current,
515
+ )
516
+ ret.update(tmp_ret_dict)
420
517
 
421
- # Check node availability
422
- dst_node_ids: set[int] = set()
423
- for message_id in message_ids:
424
- in_message = found_message_ins_dict[message_id]
425
- sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
426
- dst_node_ids.add(sint_node_id)
427
- query = f"""
428
- SELECT node_id, online_until
429
- FROM node
430
- WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))})
431
- AND status != ?
432
- """
433
- rows = self.query(query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,))
434
- tmp_ret_dict = check_node_availability_for_in_message(
435
- inquired_in_message_ids=message_ids,
436
- found_in_message_dict=found_message_ins_dict,
437
- node_id_to_online_until={
438
- int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
439
- },
440
- current_time=current,
441
- )
442
- ret.update(tmp_ret_dict)
443
-
444
- # Find all reply Messages
445
- query = f"""
446
- SELECT *
447
- FROM message_res
448
- WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
449
- AND delivered_at = "";
450
- """
451
- rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
452
- for row in rows:
453
- convert_sint64_values_in_dict_to_uint64(
454
- row, ["run_id", "src_node_id", "dst_node_id"]
518
+ # Find all reply Messages
519
+ query = f"""
520
+ SELECT *
521
+ FROM message_res
522
+ WHERE reply_to_message_id IN ({','.join(['?'] * len(message_ids))})
523
+ AND delivered_at = "";
524
+ """
525
+ rows = self.conn.execute(
526
+ query, tuple(str(message_id) for message_id in message_ids)
527
+ ).fetchall()
528
+ for row in rows:
529
+ convert_sint64_values_in_dict_to_uint64(
530
+ row, ["run_id", "src_node_id", "dst_node_id"]
531
+ )
532
+ tmp_ret_dict = verify_found_message_replies(
533
+ inquired_message_ids=message_ids,
534
+ found_message_ins_dict=found_message_ins_dict,
535
+ found_message_res_list=[dict_to_message(row) for row in rows],
536
+ current_time=current,
455
537
  )
456
- tmp_ret_dict = verify_found_message_replies(
457
- inquired_message_ids=message_ids,
458
- found_message_ins_dict=found_message_ins_dict,
459
- found_message_res_list=[dict_to_message(row) for row in rows],
460
- current_time=current,
461
- )
462
- ret.update(tmp_ret_dict)
463
-
464
- # Mark existing reply Messages to be returned as delivered
465
- delivered_at = now().isoformat()
466
- for message_res in ret.values():
467
- message_res.metadata.delivered_at = delivered_at
468
- message_res_ids = [
469
- message_res.metadata.message_id for message_res in ret.values()
470
- ]
471
- query = f"""
472
- UPDATE message_res
473
- SET delivered_at = ?
474
- WHERE message_id IN ({",".join(["?"] * len(message_res_ids))});
475
- """
476
- data: list[Any] = [delivered_at] + message_res_ids
477
- self.query(query, data)
538
+ ret.update(tmp_ret_dict)
539
+
540
+ # Mark existing reply Messages to be returned as delivered
541
+ delivered_at = now().isoformat()
542
+ for message_res in ret.values():
543
+ message_res.metadata.delivered_at = delivered_at
544
+ message_res_ids = [
545
+ message_res.metadata.message_id for message_res in ret.values()
546
+ ]
547
+ query = f"""
548
+ UPDATE message_res
549
+ SET delivered_at = ?
550
+ WHERE message_id IN ({','.join(['?'] * len(message_res_ids))});
551
+ """
552
+ data: list[Any] = [delivered_at] + message_res_ids
553
+ self.conn.execute(query, data)
478
554
 
479
555
  return list(ret.values())
480
556
 
@@ -545,7 +621,11 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
545
621
  return {row["message_id"] for row in rows}
546
622
 
547
623
  def create_node(
548
- self, owner_aid: str, public_key: bytes, heartbeat_interval: float
624
+ self,
625
+ owner_aid: str,
626
+ owner_name: str,
627
+ public_key: bytes,
628
+ heartbeat_interval: float,
549
629
  ) -> int:
550
630
  """Create, store in the link state, and return `node_id`."""
551
631
  # Sample a random uint64 as node_id
@@ -558,10 +638,10 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
558
638
 
559
639
  query = """
560
640
  INSERT INTO node
561
- (node_id, owner_aid, status, registered_at, last_activated_at,
641
+ (node_id, owner_aid, owner_name, status, registered_at, last_activated_at,
562
642
  last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
563
643
  public_key)
564
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
644
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
565
645
  """
566
646
 
567
647
  # Mark the node online until now().timestamp() + heartbeat_interval
@@ -571,6 +651,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
571
651
  (
572
652
  sint64_node_id, # node_id
573
653
  owner_aid, # owner_aid
654
+ owner_name, # owner_name
574
655
  NodeStatus.REGISTERED, # status
575
656
  now().isoformat(), # registered_at
576
657
  None, # last_activated_at
@@ -686,23 +767,26 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
686
767
  if self.conn is None:
687
768
  raise AttributeError("LinkState not initialized")
688
769
 
689
- # Convert the uint64 value to sint64 for SQLite
690
- sint64_run_id = uint64_to_int64(run_id)
691
-
692
- # Validate run ID
693
- query = "SELECT COUNT(*) FROM run WHERE run_id = ?"
694
- rows = self.query(query, (sint64_run_id,))
695
- if rows[0]["COUNT(*)"] == 0:
696
- return set()
697
-
698
- # Retrieve all online nodes
699
- return {
700
- node.node_id for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
701
- }
702
-
703
- def _check_and_tag_offline_nodes(
704
- self, node_ids: Optional[list[int]] = None
705
- ) -> None:
770
+ with self.conn:
771
+ # Convert the uint64 value to sint64 for SQLite
772
+ sint64_run_id = uint64_to_int64(run_id)
773
+
774
+ # Validate run ID
775
+ query = "SELECT federation FROM run WHERE run_id = ?"
776
+ rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
777
+ if not rows:
778
+ return set()
779
+ federation: str = rows[0]["federation"]
780
+
781
+ # Retrieve all online nodes
782
+ node_ids = {
783
+ node.node_id
784
+ for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
785
+ }
786
+ # Filter node IDs by federation
787
+ return self.federation_manager.filter_nodes(node_ids, federation)
788
+
789
+ def _check_and_tag_offline_nodes(self, node_ids: list[int] | None = None) -> None:
706
790
  """Check and tag offline nodes."""
707
791
  # strftime will convert POSIX timestamp to ISO format
708
792
  query = """
@@ -725,9 +809,9 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
725
809
  def get_node_info(
726
810
  self,
727
811
  *,
728
- node_ids: Optional[Sequence[int]] = None,
729
- owner_aids: Optional[Sequence[str]] = None,
730
- statuses: Optional[Sequence[str]] = None,
812
+ node_ids: Sequence[int] | None = None,
813
+ owner_aids: Sequence[str] | None = None,
814
+ statuses: Sequence[str] | None = None,
731
815
  ) -> Sequence[NodeInfo]:
732
816
  """Retrieve information about nodes based on the specified filters."""
733
817
  with self.conn:
@@ -781,7 +865,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
781
865
  # Return the public key
782
866
  return cast(bytes, rows[0]["public_key"])
783
867
 
784
- def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
868
+ def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
785
869
  """Get `node_id` for the specified `public_key` if it exists and is not
786
870
  deleted."""
787
871
  query = "SELECT node_id FROM node WHERE public_key = ? AND status != ?;"
@@ -798,55 +882,61 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
798
882
  # pylint: disable=too-many-arguments,too-many-positional-arguments
799
883
  def create_run(
800
884
  self,
801
- fab_id: Optional[str],
802
- fab_version: Optional[str],
803
- fab_hash: Optional[str],
885
+ fab_id: str | None,
886
+ fab_version: str | None,
887
+ fab_hash: str | None,
804
888
  override_config: UserConfig,
889
+ federation: str,
805
890
  federation_options: ConfigRecord,
806
- flwr_aid: Optional[str],
891
+ flwr_aid: str | None,
807
892
  ) -> int:
808
- """Create a new run for the specified `fab_id` and `fab_version`."""
893
+ """Create a new run."""
809
894
  # Sample a random int64 as run_id
810
895
  uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
811
896
 
812
897
  # Convert the uint64 value to sint64 for SQLite
813
898
  sint64_run_id = uint64_to_int64(uint64_run_id)
814
899
 
815
- # Check conflicts
816
- query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
817
- # If sint64_run_id does not exist
818
- if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
819
- query = (
820
- "INSERT INTO run "
821
- "(run_id, active_until, heartbeat_interval, fab_id, fab_version, "
822
- "fab_hash, override_config, federation_options, pending_at, "
823
- "starting_at, running_at, finished_at, sub_status, details, flwr_aid) "
824
- "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
825
- )
826
- override_config_json = json.dumps(override_config)
827
- data = [
828
- sint64_run_id,
829
- 0, # The `active_until` is not used until the run is started
830
- 0, # This `heartbeat_interval` is not used until the run is started
831
- fab_id,
832
- fab_version,
833
- fab_hash,
834
- override_config_json,
835
- configrecord_to_bytes(federation_options),
836
- now().isoformat(),
837
- "",
838
- "",
839
- "",
840
- "",
841
- "",
842
- flwr_aid or "",
843
- ]
844
- self.query(query, tuple(data))
845
- return uint64_run_id
900
+ with self.conn:
901
+ # Check conflicts
902
+ query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
903
+ # If sint64_run_id does not exist
904
+ row = self.conn.execute(query, (sint64_run_id,)).fetchone()
905
+ if row["COUNT(*)"] == 0:
906
+ query = """
907
+ INSERT INTO run
908
+ (run_id, fab_id, fab_version,
909
+ fab_hash, override_config, federation, federation_options,
910
+ pending_at, starting_at, running_at, finished_at, sub_status,
911
+ details, flwr_aid, bytes_sent, bytes_recv, clientapp_runtime)
912
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
913
+ """
914
+ override_config_json = json.dumps(override_config)
915
+ data = [
916
+ sint64_run_id, # run_id
917
+ fab_id, # fab_id
918
+ fab_version, # fab_version
919
+ fab_hash, # fab_hash
920
+ override_config_json, # override_config
921
+ federation, # federation
922
+ configrecord_to_bytes(federation_options), # federation_options
923
+ now().isoformat(), # pending_at
924
+ "", # starting_at
925
+ "", # running_at
926
+ "", # finished_at
927
+ "", # sub_status
928
+ "", # details
929
+ flwr_aid or "", # flwr_aid
930
+ 0, # bytes_sent
931
+ 0, # bytes_recv
932
+ 0, # clientapp_runtime
933
+ ]
934
+ self.conn.execute(query, tuple(data))
935
+ return uint64_run_id
846
936
  log(ERROR, "Unexpected run creation failure.")
847
937
  return 0
848
938
 
849
- def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
939
+ def get_run_ids(self, flwr_aid: str | None) -> set[int]:
850
940
  """Retrieve all run IDs if `flwr_aid` is not specified.
851
941
 
852
942
  Otherwise, retrieve all run IDs for the specified `flwr_aid`.
@@ -860,32 +950,10 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
860
950
  rows = self.query("SELECT run_id FROM run;", ())
861
951
  return {int64_to_uint64(row["run_id"]) for row in rows}
862
952
 
863
- def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
864
- """Check if any runs are no longer active.
865
-
866
- Marks runs with status 'starting' or 'running' as failed
867
- if they have not sent a heartbeat before `active_until`.
868
- """
869
- sint_run_ids = [uint64_to_int64(run_id) for run_id in run_ids]
870
- query = "UPDATE run SET finished_at = ?, sub_status = ?, details = ? "
871
- query += "WHERE starting_at != '' AND finished_at = '' AND active_until < ?"
872
- query += f" AND run_id IN ({','.join(['?'] * len(run_ids))});"
873
- current = now()
874
- self.query(
875
- query,
876
- (
877
- current.isoformat(),
878
- SubStatus.FAILED,
879
- RUN_FAILURE_DETAILS_NO_HEARTBEAT,
880
- current.timestamp(),
881
- *sint_run_ids,
882
- ),
883
- )
884
-
885
- def get_run(self, run_id: int) -> Optional[Run]:
953
+ def get_run(self, run_id: int) -> Run | None:
886
954
  """Retrieve information about the run with the specified `run_id`."""
887
- # Check if runs are still active
888
- self._check_and_tag_inactive_run(run_ids={run_id})
955
+ # Clean up expired tokens; this will flag inactive runs as needed
956
+ self._cleanup_expired_tokens()
889
957
 
890
958
  # Convert the uint64 value to sint64 for SQLite
891
959
  sint64_run_id = uint64_to_int64(run_id)
@@ -909,14 +977,18 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
909
977
  details=row["details"],
910
978
  ),
911
979
  flwr_aid=row["flwr_aid"],
980
+ federation=row["federation"],
981
+ bytes_sent=row["bytes_sent"],
982
+ bytes_recv=row["bytes_recv"],
983
+ clientapp_runtime=row["clientapp_runtime"],
912
984
  )
913
985
  log(ERROR, "`run_id` does not exist.")
914
986
  return None
915
987
 
916
988
  def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
917
989
  """Retrieve the statuses for the specified runs."""
918
- # Check if runs are still active
919
- self._check_and_tag_inactive_run(run_ids=run_ids)
990
+ # Clean up expired tokens; this will flag inactive runs as needed
991
+ self._cleanup_expired_tokens()
920
992
 
921
993
  # Convert the uint64 value to sint64 for SQLite
922
994
  sint64_run_ids = (uint64_to_int64(run_id) for run_id in set(run_ids))
@@ -935,82 +1007,73 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
935
1007
 
936
1008
  def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
937
1009
  """Update the status of the run with the specified `run_id`."""
938
- # Check if runs are still active
939
- self._check_and_tag_inactive_run(run_ids={run_id})
1010
+ # Clean up expired tokens; this will flag inactive runs as needed
1011
+ self._cleanup_expired_tokens()
940
1012
 
941
- # Convert the uint64 value to sint64 for SQLite
942
- sint64_run_id = uint64_to_int64(run_id)
943
- query = "SELECT * FROM run WHERE run_id = ?;"
944
- rows = self.query(query, (sint64_run_id,))
945
-
946
- # Check if the run_id exists
947
- if not rows:
948
- log(ERROR, "`run_id` is invalid")
949
- return False
1013
+ with self.conn:
1014
+ # Convert the uint64 value to sint64 for SQLite
1015
+ sint64_run_id = uint64_to_int64(run_id)
1016
+ query = "SELECT * FROM run WHERE run_id = ?;"
1017
+ rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
1018
+
1019
+ # Check if the run_id exists
1020
+ if not rows:
1021
+ log(ERROR, "`run_id` is invalid")
1022
+ return False
950
1023
 
951
- # Check if the status transition is valid
952
- row = rows[0]
953
- current_status = RunStatus(
954
- status=determine_run_status(row),
955
- sub_status=row["sub_status"],
956
- details=row["details"],
957
- )
958
- if not is_valid_transition(current_status, new_status):
959
- log(
960
- ERROR,
961
- 'Invalid status transition: from "%s" to "%s"',
962
- current_status.status,
963
- new_status.status,
1024
+ # Check if the status transition is valid
1025
+ row = rows[0]
1026
+ current_status = RunStatus(
1027
+ status=determine_run_status(row),
1028
+ sub_status=row["sub_status"],
1029
+ details=row["details"],
964
1030
  )
965
- return False
1031
+ if not is_valid_transition(current_status, new_status):
1032
+ log(
1033
+ ERROR,
1034
+ 'Invalid status transition: from "%s" to "%s"',
1035
+ current_status.status,
1036
+ new_status.status,
1037
+ )
1038
+ return False
966
1039
 
967
- # Check if the sub-status is valid
968
- if not has_valid_sub_status(current_status):
969
- log(
970
- ERROR,
971
- 'Invalid sub-status "%s" for status "%s"',
972
- current_status.sub_status,
973
- current_status.status,
974
- )
975
- return False
1040
+ # Check if the sub-status is valid
1041
+ if not has_valid_sub_status(current_status):
1042
+ log(
1043
+ ERROR,
1044
+ 'Invalid sub-status "%s" for status "%s"',
1045
+ current_status.sub_status,
1046
+ current_status.status,
1047
+ )
1048
+ return False
976
1049
 
977
- # Update the status
978
- query = "UPDATE run SET %s= ?, sub_status = ?, details = ?, "
979
- query += "active_until = ?, heartbeat_interval = ? "
980
- query += "WHERE run_id = ?;"
1050
+ # Update the status
1051
+ query = """
1052
+ UPDATE run SET %s= ?, sub_status = ?, details = ? WHERE run_id = ?;
1053
+ """
981
1054
 
982
- # Prepare data for query
983
- # Initialize heartbeat_interval and active_until
984
- # when switching to starting or running
985
- current = now()
986
- if new_status.status in (Status.STARTING, Status.RUNNING):
987
- heartbeat_interval = HEARTBEAT_INTERVAL_INF
988
- active_until = current.timestamp() + heartbeat_interval
989
- else:
990
- heartbeat_interval = 0
991
- active_until = 0
992
-
993
- # Determine the timestamp field based on the new status
994
- timestamp_fld = ""
995
- if new_status.status == Status.STARTING:
996
- timestamp_fld = "starting_at"
997
- elif new_status.status == Status.RUNNING:
998
- timestamp_fld = "running_at"
999
- elif new_status.status == Status.FINISHED:
1000
- timestamp_fld = "finished_at"
1001
-
1002
- data = (
1003
- current.isoformat(),
1004
- new_status.sub_status,
1005
- new_status.details,
1006
- active_until,
1007
- heartbeat_interval,
1008
- uint64_to_int64(run_id),
1009
- )
1010
- self.query(query % timestamp_fld, data)
1055
+ # Prepare data for query
1056
+ current = now()
1057
+
1058
+ # Determine the timestamp field based on the new status
1059
+ timestamp_fld = ""
1060
+ if new_status.status == Status.STARTING:
1061
+ timestamp_fld = "starting_at"
1062
+ elif new_status.status == Status.RUNNING:
1063
+ timestamp_fld = "running_at"
1064
+ elif new_status.status == Status.FINISHED:
1065
+ timestamp_fld = "finished_at"
1066
+
1067
+ data = (
1068
+ current.isoformat(),
1069
+ new_status.sub_status,
1070
+ new_status.details,
1071
+ uint64_to_int64(run_id),
1072
+ )
1073
+ self.conn.execute(query % timestamp_fld, data)
1011
1074
  return True
1012
1075
 
1013
- def get_pending_run_id(self) -> Optional[int]:
1076
+ def get_pending_run_id(self) -> int | None:
1014
1077
  """Get the `run_id` of a run with `Status.PENDING` status, if any."""
1015
1078
  pending_run_id = None
1016
1079
 
@@ -1022,7 +1085,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
1022
1085
 
1023
1086
  return pending_run_id
1024
1087
 
1025
- def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
1088
+ def get_federation_options(self, run_id: int) -> ConfigRecord | None:
1026
1089
  """Retrieve the federation options for the specified `run_id`."""
1027
1090
  # Convert the uint64 value to sint64 for SQLite
1028
1091
  sint64_run_id = uint64_to_int64(run_id)
@@ -1080,45 +1143,36 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
1080
1143
  self.conn.execute(query, params)
1081
1144
  return True
1082
1145
 
1083
- def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
1084
- """Acknowledge a heartbeat received from a ServerApp for a given run.
1146
+ def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
1147
+ """Transition runs with expired tokens to failed status.
1085
1148
 
1086
- A run with status `"running"` is considered alive as long as it sends heartbeats
1087
- within the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
1088
- HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before the run is
1089
- marked as `"completed:failed"`.
1149
+ Parameters
1150
+ ----------
1151
+ expired_records : list[tuple[int, float]]
1152
+ List of tuples containing (run_id, active_until timestamp)
1153
+ for expired tokens.
1090
1154
  """
1091
- # Check if runs are still active
1092
- self._check_and_tag_inactive_run(run_ids={run_id})
1093
-
1094
- # Search for the run
1095
- sint_run_id = uint64_to_int64(run_id)
1096
- query = "SELECT * FROM run WHERE run_id = ?;"
1097
- rows = self.query(query, (sint_run_id,))
1098
-
1099
- if not rows:
1100
- log(ERROR, "`run_id` is invalid")
1101
- return False
1102
-
1103
- # Check if the run is of status "running"/"starting"
1104
- row = rows[0]
1105
- status = determine_run_status(row)
1106
- if status not in (Status.RUNNING, Status.STARTING):
1107
- log(
1108
- ERROR,
1109
- 'Cannot acknowledge heartbeat for run with status "%s"',
1110
- status,
1111
- )
1112
- return False
1155
+ if not expired_records:
1156
+ return
1113
1157
 
1114
- # Update the `active_until` and `heartbeat_interval` for the given run
1115
- active_until = now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
1116
- query = "UPDATE run SET active_until = ?, heartbeat_interval = ? "
1117
- query += "WHERE run_id = ?"
1118
- self.query(query, (active_until, heartbeat_interval, sint_run_id))
1119
- return True
1158
+ with self.conn:
1159
+ query = """
1160
+ UPDATE run
1161
+ SET sub_status = ?, details = ?, finished_at = ?
1162
+ WHERE run_id = ?;
1163
+ """
1164
+ data = [
1165
+ (
1166
+ SubStatus.FAILED,
1167
+ RUN_FAILURE_DETAILS_NO_HEARTBEAT,
1168
+ datetime.fromtimestamp(active_until, tz=timezone.utc).isoformat(),
1169
+ uint64_to_int64(run_id),
1170
+ )
1171
+ for run_id, active_until in expired_records
1172
+ ]
1173
+ self.conn.executemany(query, data)
1120
1174
 
1121
- def get_serverapp_context(self, run_id: int) -> Optional[Context]:
1175
+ def get_serverapp_context(self, run_id: int) -> Context | None:
1122
1176
  """Get the context for the specified `run_id`."""
1123
1177
  # Retrieve context if any
1124
1178
  query = "SELECT context FROM context WHERE run_id = ?;"
@@ -1132,19 +1186,21 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
1132
1186
  context_bytes = context_to_bytes(context)
1133
1187
  sint_run_id = uint64_to_int64(run_id)
1134
1188
 
1135
- # Check if any existing Context assigned to the run_id
1136
- query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
1137
- if self.query(query, (sint_run_id,))[0]["COUNT(*)"] > 0:
1138
- # Update context
1139
- query = "UPDATE context SET context = ? WHERE run_id = ?;"
1140
- self.query(query, (context_bytes, sint_run_id))
1141
- else:
1142
- try:
1143
- # Store context
1144
- query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
1145
- self.query(query, (sint_run_id, context_bytes))
1146
- except sqlite3.IntegrityError:
1147
- raise ValueError(f"Run {run_id} not found") from None
1189
+ with self.conn:
1190
+ # Check if any existing Context assigned to the run_id
1191
+ query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
1192
+ row = self.conn.execute(query, (sint_run_id,)).fetchone()
1193
+ if row["COUNT(*)"] > 0:
1194
+ # Update context
1195
+ query = "UPDATE context SET context = ? WHERE run_id = ?;"
1196
+ self.conn.execute(query, (context_bytes, sint_run_id))
1197
+ else:
1198
+ try:
1199
+ # Store context
1200
+ query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
1201
+ self.conn.execute(query, (sint_run_id, context_bytes))
1202
+ except sqlite3.IntegrityError:
1203
+ raise ValueError(f"Run {run_id} not found") from None
1148
1204
 
1149
1205
  def add_serverapp_log(self, run_id: int, log_message: str) -> None:
1150
1206
  """Add a log entry to the ServerApp logs for the specified `run_id`."""
@@ -1161,90 +1217,100 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
1161
1217
  raise ValueError(f"Run {run_id} not found") from None
1162
1218
 
1163
1219
  def get_serverapp_log(
1164
- self, run_id: int, after_timestamp: Optional[float]
1220
+ self, run_id: int, after_timestamp: float | None
1165
1221
  ) -> tuple[str, float]:
1166
1222
  """Get the ServerApp logs for the specified `run_id`."""
1167
1223
  # Convert the uint64 value to sint64 for SQLite
1168
1224
  sint64_run_id = uint64_to_int64(run_id)
1169
1225
 
1170
- # Check if the run_id exists
1171
- query = "SELECT run_id FROM run WHERE run_id = ?;"
1172
- if not self.query(query, (sint64_run_id,)):
1173
- raise ValueError(f"Run {run_id} not found")
1174
-
1175
- # Retrieve logs
1176
- if after_timestamp is None:
1177
- after_timestamp = 0.0
1178
- query = """
1179
- SELECT log, timestamp FROM logs
1180
- WHERE run_id = ? AND node_id = ? AND timestamp > ?;
1181
- """
1182
- rows = self.query(query, (sint64_run_id, 0, after_timestamp))
1183
- rows.sort(key=lambda x: x["timestamp"])
1184
- latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
1226
+ with self.conn:
1227
+ # Check if the run_id exists
1228
+ query = "SELECT run_id FROM run WHERE run_id = ?;"
1229
+ rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
1230
+ if not rows:
1231
+ raise ValueError(f"Run {run_id} not found")
1232
+
1233
+ # Retrieve logs
1234
+ if after_timestamp is None:
1235
+ after_timestamp = 0.0
1236
+ query = """
1237
+ SELECT log, timestamp FROM logs
1238
+ WHERE run_id = ? AND node_id = ? AND timestamp > ?;
1239
+ """
1240
+ rows = self.conn.execute(
1241
+ query, (sint64_run_id, 0, after_timestamp)
1242
+ ).fetchall()
1243
+ rows.sort(key=lambda x: x["timestamp"])
1244
+ latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
1185
1245
  return "".join(row["log"] for row in rows), latest_timestamp
1186
1246
 
1187
- def get_valid_message_ins(self, message_id: str) -> Optional[dict[str, Any]]:
1247
+ def get_valid_message_ins(self, message_id: str) -> dict[str, Any] | None:
1188
1248
  """Check if the Message exists and is valid (not expired).
1189
1249
 
1190
1250
  Return Message if valid.
1191
1251
  """
1192
- query = """
1193
- SELECT *
1194
- FROM message_ins
1195
- WHERE message_id = :message_id
1196
- """
1197
- data = {"message_id": message_id}
1198
- rows = self.query(query, data)
1199
- if not rows:
1200
- # Message does not exist
1201
- return None
1252
+ with self.conn:
1253
+ self._check_stored_messages({message_id})
1254
+ query = """
1255
+ SELECT *
1256
+ FROM message_ins
1257
+ WHERE message_id = :message_id
1258
+ """
1259
+ data = {"message_id": message_id}
1260
+ rows: list[dict[str, Any]] = self.conn.execute(query, data).fetchall()
1261
+ if not rows:
1262
+ # Message does not exist
1263
+ return None
1264
+
1265
+ return rows[0]
1266
+
1267
+ def store_traffic(self, run_id: int, *, bytes_sent: int, bytes_recv: int) -> None:
1268
+ """Store traffic data for the specified `run_id`."""
1269
+ # Validate non-negative values
1270
+ if bytes_sent < 0 or bytes_recv < 0:
1271
+ raise ValueError(
1272
+ f"Negative traffic values for run {run_id}: "
1273
+ f"bytes_sent={bytes_sent}, bytes_recv={bytes_recv}"
1274
+ )
1202
1275
 
1203
- message_ins = rows[0]
1204
- created_at = message_ins["created_at"]
1205
- ttl = message_ins["ttl"]
1206
- current_time = now().timestamp()
1276
+ if bytes_sent == 0 and bytes_recv == 0:
1277
+ raise ValueError(
1278
+ f"Both bytes_sent and bytes_recv cannot be zero for run {run_id}"
1279
+ )
1207
1280
 
1208
- # Check if Message is expired
1209
- if ttl is not None and created_at + ttl <= current_time:
1210
- return None
1281
+ sint64_run_id = uint64_to_int64(run_id)
1211
1282
 
1212
- return message_ins
1283
+ with self.conn:
1284
+ # Check if run exists, performing the update only if it does
1285
+ update_query = """
1286
+ UPDATE run
1287
+ SET bytes_sent = bytes_sent + ?,
1288
+ bytes_recv = bytes_recv + ?
1289
+ WHERE run_id = ?
1290
+ RETURNING run_id;
1291
+ """
1292
+ rows = self.conn.execute(
1293
+ update_query, (bytes_sent, bytes_recv, sint64_run_id)
1294
+ ).fetchall()
1213
1295
 
1214
- def create_token(self, run_id: int) -> Optional[str]:
1215
- """Create a token for the given run ID."""
1216
- token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
1217
- query = "INSERT INTO token_store (run_id, token) VALUES (:run_id, :token);"
1218
- data = {"run_id": uint64_to_int64(run_id), "token": token}
1219
- try:
1220
- self.query(query, data)
1221
- except sqlite3.IntegrityError:
1222
- return None # Token already created for this run ID
1223
- return token
1224
-
1225
- def verify_token(self, run_id: int, token: str) -> bool:
1226
- """Verify a token for the given run ID."""
1227
- query = "SELECT token FROM token_store WHERE run_id = :run_id;"
1228
- data = {"run_id": uint64_to_int64(run_id)}
1229
- rows = self.query(query, data)
1230
- if not rows:
1231
- return False
1232
- return cast(str, rows[0]["token"]) == token
1233
-
1234
- def delete_token(self, run_id: int) -> None:
1235
- """Delete the token for the given run ID."""
1236
- query = "DELETE FROM token_store WHERE run_id = :run_id;"
1237
- data = {"run_id": uint64_to_int64(run_id)}
1238
- self.query(query, data)
1239
-
1240
- def get_run_id_by_token(self, token: str) -> Optional[int]:
1241
- """Get the run ID associated with a given token."""
1242
- query = "SELECT run_id FROM token_store WHERE token = :token;"
1243
- data = {"token": token}
1244
- rows = self.query(query, data)
1245
- if not rows:
1246
- return None
1247
- return int64_to_uint64(rows[0]["run_id"])
1296
+ if not rows:
1297
+ raise ValueError(f"Run {run_id} not found")
1298
+
1299
+ def add_clientapp_runtime(self, run_id: int, runtime: float) -> None:
1300
+ """Add ClientApp runtime to the cumulative total for the specified `run_id`."""
1301
+ sint64_run_id = uint64_to_int64(run_id)
1302
+ with self.conn:
1303
+ # Check if run exists, performing the update only if it does
1304
+ update_query = """
1305
+ UPDATE run
1306
+ SET clientapp_runtime = clientapp_runtime + ?
1307
+ WHERE run_id = ?
1308
+ RETURNING run_id;
1309
+ """
1310
+ rows = self.conn.execute(update_query, (runtime, sint64_run_id)).fetchall()
1311
+
1312
+ if not rows:
1313
+ raise ValueError(f"Run {run_id} not found")
1248
1314
 
1249
1315
 
1250
1316
  def message_to_dict(message: Message) -> dict[str, Any]: