flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.15.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (311) hide show
  1. flwr/cli/app.py +16 -2
  2. flwr/cli/build.py +181 -0
  3. flwr/cli/cli_user_auth_interceptor.py +90 -0
  4. flwr/cli/config_utils.py +343 -0
  5. flwr/cli/example.py +4 -1
  6. flwr/cli/install.py +253 -0
  7. flwr/cli/log.py +182 -0
  8. flwr/{server/superlink/state → cli/login}/__init__.py +4 -10
  9. flwr/cli/login/login.py +88 -0
  10. flwr/cli/ls.py +327 -0
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +210 -66
  13. flwr/cli/new/templates/app/.gitignore.tpl +163 -0
  14. flwr/cli/new/templates/app/LICENSE.tpl +202 -0
  15. flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
  16. flwr/cli/new/templates/app/README.flowertune.md.tpl +66 -0
  17. flwr/cli/new/templates/app/README.md.tpl +16 -32
  18. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
  19. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  20. flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
  21. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +55 -0
  22. flwr/cli/new/templates/app/code/client.jax.py.tpl +50 -0
  23. flwr/cli/new/templates/app/code/client.mlx.py.tpl +73 -0
  24. flwr/cli/new/templates/app/code/client.numpy.py.tpl +7 -7
  25. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +30 -21
  26. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +63 -0
  27. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +57 -1
  28. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
  29. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  30. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +126 -0
  31. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +87 -0
  32. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +78 -0
  33. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +94 -0
  34. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
  35. flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
  36. flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
  37. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +38 -0
  38. flwr/cli/new/templates/app/code/server.jax.py.tpl +26 -0
  39. flwr/cli/new/templates/app/code/server.mlx.py.tpl +31 -0
  40. flwr/cli/new/templates/app/code/server.numpy.py.tpl +22 -9
  41. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  42. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +36 -0
  43. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  44. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
  45. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +102 -0
  46. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  47. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  48. flwr/cli/new/templates/app/code/task.numpy.py.tpl +7 -0
  49. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +29 -24
  50. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +67 -0
  51. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  52. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
  53. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
  54. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +68 -0
  55. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +46 -0
  56. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +35 -0
  57. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  58. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  59. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  60. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +35 -0
  61. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  62. flwr/cli/run/__init__.py +1 -0
  63. flwr/cli/run/run.py +212 -34
  64. flwr/cli/stop.py +130 -0
  65. flwr/cli/utils.py +240 -5
  66. flwr/client/__init__.py +3 -2
  67. flwr/client/app.py +432 -255
  68. flwr/client/client.py +1 -11
  69. flwr/client/client_app.py +74 -13
  70. flwr/client/clientapp/__init__.py +22 -0
  71. flwr/client/clientapp/app.py +259 -0
  72. flwr/client/clientapp/clientappio_servicer.py +244 -0
  73. flwr/client/clientapp/utils.py +115 -0
  74. flwr/client/dpfedavg_numpy_client.py +7 -8
  75. flwr/client/grpc_adapter_client/__init__.py +15 -0
  76. flwr/client/grpc_adapter_client/connection.py +98 -0
  77. flwr/client/grpc_client/connection.py +21 -7
  78. flwr/client/grpc_rere_client/__init__.py +1 -1
  79. flwr/client/grpc_rere_client/client_interceptor.py +176 -0
  80. flwr/client/grpc_rere_client/connection.py +163 -56
  81. flwr/client/grpc_rere_client/grpc_adapter.py +167 -0
  82. flwr/client/heartbeat.py +74 -0
  83. flwr/client/message_handler/__init__.py +1 -1
  84. flwr/client/message_handler/message_handler.py +10 -11
  85. flwr/client/mod/__init__.py +5 -5
  86. flwr/client/mod/centraldp_mods.py +4 -2
  87. flwr/client/mod/comms_mods.py +5 -4
  88. flwr/client/mod/localdp_mod.py +10 -5
  89. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  90. flwr/client/mod/secure_aggregation/secaggplus_mod.py +26 -26
  91. flwr/client/mod/utils.py +2 -4
  92. flwr/client/nodestate/__init__.py +26 -0
  93. flwr/client/nodestate/in_memory_nodestate.py +38 -0
  94. flwr/client/nodestate/nodestate.py +31 -0
  95. flwr/client/nodestate/nodestate_factory.py +38 -0
  96. flwr/client/numpy_client.py +8 -31
  97. flwr/client/rest_client/__init__.py +1 -1
  98. flwr/client/rest_client/connection.py +199 -176
  99. flwr/client/run_info_store.py +112 -0
  100. flwr/client/supernode/__init__.py +24 -0
  101. flwr/client/supernode/app.py +321 -0
  102. flwr/client/typing.py +1 -0
  103. flwr/common/__init__.py +17 -11
  104. flwr/common/address.py +47 -3
  105. flwr/common/args.py +153 -0
  106. flwr/common/auth_plugin/__init__.py +24 -0
  107. flwr/common/auth_plugin/auth_plugin.py +121 -0
  108. flwr/common/config.py +243 -0
  109. flwr/common/constant.py +132 -1
  110. flwr/common/context.py +32 -2
  111. flwr/common/date.py +22 -4
  112. flwr/common/differential_privacy.py +2 -2
  113. flwr/common/dp.py +2 -4
  114. flwr/common/exit_handlers.py +3 -3
  115. flwr/common/grpc.py +164 -5
  116. flwr/common/logger.py +230 -12
  117. flwr/common/message.py +191 -106
  118. flwr/common/object_ref.py +179 -44
  119. flwr/common/pyproject.py +1 -0
  120. flwr/common/record/__init__.py +2 -1
  121. flwr/common/record/configsrecord.py +58 -18
  122. flwr/common/record/metricsrecord.py +57 -17
  123. flwr/common/record/parametersrecord.py +88 -20
  124. flwr/common/record/recordset.py +153 -30
  125. flwr/common/record/typeddict.py +30 -55
  126. flwr/common/recordset_compat.py +31 -12
  127. flwr/common/retry_invoker.py +123 -30
  128. flwr/common/secure_aggregation/__init__.py +1 -1
  129. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  130. flwr/common/secure_aggregation/crypto/shamir.py +11 -11
  131. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +68 -4
  132. flwr/common/secure_aggregation/ndarrays_arithmetic.py +17 -17
  133. flwr/common/secure_aggregation/quantization.py +8 -8
  134. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  135. flwr/common/secure_aggregation/secaggplus_utils.py +10 -12
  136. flwr/common/serde.py +298 -19
  137. flwr/common/telemetry.py +65 -29
  138. flwr/common/typing.py +120 -19
  139. flwr/common/version.py +17 -3
  140. flwr/proto/clientappio_pb2.py +45 -0
  141. flwr/proto/clientappio_pb2.pyi +132 -0
  142. flwr/proto/clientappio_pb2_grpc.py +135 -0
  143. flwr/proto/clientappio_pb2_grpc.pyi +53 -0
  144. flwr/proto/exec_pb2.py +62 -0
  145. flwr/proto/exec_pb2.pyi +212 -0
  146. flwr/proto/exec_pb2_grpc.py +237 -0
  147. flwr/proto/exec_pb2_grpc.pyi +93 -0
  148. flwr/proto/fab_pb2.py +31 -0
  149. flwr/proto/fab_pb2.pyi +65 -0
  150. flwr/proto/fab_pb2_grpc.py +4 -0
  151. flwr/proto/fab_pb2_grpc.pyi +4 -0
  152. flwr/proto/fleet_pb2.py +42 -23
  153. flwr/proto/fleet_pb2.pyi +123 -1
  154. flwr/proto/fleet_pb2_grpc.py +170 -0
  155. flwr/proto/fleet_pb2_grpc.pyi +61 -0
  156. flwr/proto/grpcadapter_pb2.py +32 -0
  157. flwr/proto/grpcadapter_pb2.pyi +43 -0
  158. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  159. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  160. flwr/proto/log_pb2.py +29 -0
  161. flwr/proto/log_pb2.pyi +39 -0
  162. flwr/proto/log_pb2_grpc.py +4 -0
  163. flwr/proto/log_pb2_grpc.pyi +4 -0
  164. flwr/proto/message_pb2.py +41 -0
  165. flwr/proto/message_pb2.pyi +128 -0
  166. flwr/proto/message_pb2_grpc.py +4 -0
  167. flwr/proto/message_pb2_grpc.pyi +4 -0
  168. flwr/proto/node_pb2.py +1 -1
  169. flwr/proto/recordset_pb2.py +35 -33
  170. flwr/proto/recordset_pb2.pyi +40 -14
  171. flwr/proto/run_pb2.py +64 -0
  172. flwr/proto/run_pb2.pyi +268 -0
  173. flwr/proto/run_pb2_grpc.py +4 -0
  174. flwr/proto/run_pb2_grpc.pyi +4 -0
  175. flwr/proto/serverappio_pb2.py +52 -0
  176. flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +62 -20
  177. flwr/proto/serverappio_pb2_grpc.py +410 -0
  178. flwr/proto/serverappio_pb2_grpc.pyi +160 -0
  179. flwr/proto/simulationio_pb2.py +38 -0
  180. flwr/proto/simulationio_pb2.pyi +65 -0
  181. flwr/proto/simulationio_pb2_grpc.py +239 -0
  182. flwr/proto/simulationio_pb2_grpc.pyi +94 -0
  183. flwr/proto/task_pb2.py +7 -8
  184. flwr/proto/task_pb2.pyi +8 -5
  185. flwr/proto/transport_pb2.py +8 -8
  186. flwr/proto/transport_pb2.pyi +9 -6
  187. flwr/server/__init__.py +2 -10
  188. flwr/server/app.py +579 -402
  189. flwr/server/client_manager.py +8 -6
  190. flwr/server/compat/app.py +6 -62
  191. flwr/server/compat/app_utils.py +14 -8
  192. flwr/server/compat/driver_client_proxy.py +25 -58
  193. flwr/server/compat/legacy_context.py +5 -4
  194. flwr/server/driver/__init__.py +2 -0
  195. flwr/server/driver/driver.py +36 -131
  196. flwr/server/driver/grpc_driver.py +217 -81
  197. flwr/server/driver/inmemory_driver.py +182 -0
  198. flwr/server/history.py +28 -29
  199. flwr/server/run_serverapp.py +15 -126
  200. flwr/server/server.py +50 -44
  201. flwr/server/server_app.py +59 -10
  202. flwr/server/serverapp/__init__.py +22 -0
  203. flwr/server/serverapp/app.py +256 -0
  204. flwr/server/serverapp_components.py +52 -0
  205. flwr/server/strategy/__init__.py +2 -2
  206. flwr/server/strategy/aggregate.py +37 -23
  207. flwr/server/strategy/bulyan.py +9 -9
  208. flwr/server/strategy/dp_adaptive_clipping.py +25 -25
  209. flwr/server/strategy/dp_fixed_clipping.py +23 -22
  210. flwr/server/strategy/dpfedavg_adaptive.py +8 -8
  211. flwr/server/strategy/dpfedavg_fixed.py +13 -12
  212. flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
  213. flwr/server/strategy/fedadagrad.py +9 -9
  214. flwr/server/strategy/fedadam.py +20 -10
  215. flwr/server/strategy/fedavg.py +16 -16
  216. flwr/server/strategy/fedavg_android.py +17 -17
  217. flwr/server/strategy/fedavgm.py +9 -9
  218. flwr/server/strategy/fedmedian.py +5 -5
  219. flwr/server/strategy/fedopt.py +6 -6
  220. flwr/server/strategy/fedprox.py +7 -7
  221. flwr/server/strategy/fedtrimmedavg.py +8 -8
  222. flwr/server/strategy/fedxgb_bagging.py +12 -12
  223. flwr/server/strategy/fedxgb_cyclic.py +10 -10
  224. flwr/server/strategy/fedxgb_nn_avg.py +6 -6
  225. flwr/server/strategy/fedyogi.py +9 -9
  226. flwr/server/strategy/krum.py +9 -9
  227. flwr/server/strategy/qfedavg.py +16 -16
  228. flwr/server/strategy/strategy.py +10 -10
  229. flwr/server/superlink/driver/__init__.py +2 -2
  230. flwr/server/superlink/driver/serverappio_grpc.py +61 -0
  231. flwr/server/superlink/driver/serverappio_servicer.py +363 -0
  232. flwr/server/superlink/ffs/__init__.py +24 -0
  233. flwr/server/superlink/ffs/disk_ffs.py +108 -0
  234. flwr/server/superlink/ffs/ffs.py +79 -0
  235. flwr/server/superlink/ffs/ffs_factory.py +47 -0
  236. flwr/server/superlink/fleet/__init__.py +1 -1
  237. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  238. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +162 -0
  239. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  240. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +4 -2
  241. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -2
  242. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  243. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -154
  244. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  245. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +120 -13
  246. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +228 -0
  247. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  248. flwr/server/superlink/fleet/message_handler/message_handler.py +153 -9
  249. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  250. flwr/server/superlink/fleet/rest_rere/rest_api.py +119 -81
  251. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  252. flwr/server/superlink/fleet/vce/backend/__init__.py +4 -4
  253. flwr/server/superlink/fleet/vce/backend/backend.py +8 -9
  254. flwr/server/superlink/fleet/vce/backend/raybackend.py +87 -68
  255. flwr/server/superlink/fleet/vce/vce_api.py +208 -146
  256. flwr/server/superlink/linkstate/__init__.py +28 -0
  257. flwr/server/superlink/linkstate/in_memory_linkstate.py +581 -0
  258. flwr/server/superlink/linkstate/linkstate.py +389 -0
  259. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +19 -10
  260. flwr/server/superlink/linkstate/sqlite_linkstate.py +1236 -0
  261. flwr/server/superlink/linkstate/utils.py +389 -0
  262. flwr/server/superlink/simulation/__init__.py +15 -0
  263. flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
  264. flwr/server/superlink/simulation/simulationio_servicer.py +186 -0
  265. flwr/server/superlink/utils.py +65 -0
  266. flwr/server/typing.py +2 -0
  267. flwr/server/utils/__init__.py +1 -1
  268. flwr/server/utils/tensorboard.py +5 -5
  269. flwr/server/utils/validator.py +31 -11
  270. flwr/server/workflow/default_workflows.py +70 -26
  271. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -0
  272. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +40 -27
  273. flwr/simulation/__init__.py +12 -5
  274. flwr/simulation/app.py +247 -315
  275. flwr/simulation/legacy_app.py +402 -0
  276. flwr/simulation/ray_transport/__init__.py +1 -1
  277. flwr/simulation/ray_transport/ray_actor.py +42 -67
  278. flwr/simulation/ray_transport/ray_client_proxy.py +37 -17
  279. flwr/simulation/ray_transport/utils.py +1 -0
  280. flwr/simulation/run_simulation.py +306 -163
  281. flwr/simulation/simulationio_connection.py +89 -0
  282. flwr/superexec/__init__.py +15 -0
  283. flwr/superexec/app.py +59 -0
  284. flwr/superexec/deployment.py +188 -0
  285. flwr/superexec/exec_grpc.py +80 -0
  286. flwr/superexec/exec_servicer.py +231 -0
  287. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  288. flwr/superexec/executor.py +96 -0
  289. flwr/superexec/simulation.py +124 -0
  290. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.15.0.dev20250114.dist-info}/METADATA +33 -26
  291. flwr_nightly-1.15.0.dev20250114.dist-info/RECORD +328 -0
  292. flwr_nightly-1.15.0.dev20250114.dist-info/entry_points.txt +12 -0
  293. flwr/cli/flower_toml.py +0 -140
  294. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  295. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  296. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  297. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  298. flwr/client/node_state.py +0 -48
  299. flwr/client/node_state_tests.py +0 -65
  300. flwr/proto/driver_pb2.py +0 -44
  301. flwr/proto/driver_pb2_grpc.py +0 -169
  302. flwr/proto/driver_pb2_grpc.pyi +0 -66
  303. flwr/server/superlink/driver/driver_grpc.py +0 -54
  304. flwr/server/superlink/driver/driver_servicer.py +0 -129
  305. flwr/server/superlink/state/in_memory_state.py +0 -230
  306. flwr/server/superlink/state/sqlite_state.py +0 -630
  307. flwr/server/superlink/state/state.py +0 -154
  308. flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
  309. flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
  310. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.15.0.dev20250114.dist-info}/LICENSE +0 -0
  311. {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.15.0.dev20250114.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1236 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """SQLite based implementation of the link state."""
16
+
17
+
18
+ # pylint: disable=too-many-lines
19
+
20
+ import json
21
+ import re
22
+ import sqlite3
23
+ import time
24
+ from collections.abc import Sequence
25
+ from logging import DEBUG, ERROR, WARNING
26
+ from typing import Any, Optional, Union, cast
27
+ from uuid import UUID, uuid4
28
+
29
+ from flwr.common import Context, log, now
30
+ from flwr.common.constant import (
31
+ MESSAGE_TTL_TOLERANCE,
32
+ NODE_ID_NUM_BYTES,
33
+ RUN_ID_NUM_BYTES,
34
+ Status,
35
+ )
36
+ from flwr.common.record import ConfigsRecord
37
+ from flwr.common.typing import Run, RunStatus, UserConfig
38
+
39
+ # pylint: disable=E0611
40
+ from flwr.proto.node_pb2 import Node
41
+ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
42
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
43
+
44
+ # pylint: enable=E0611
45
+ from flwr.server.utils.validator import validate_task_ins_or_res
46
+
47
+ from .linkstate import LinkState
48
+ from .utils import (
49
+ configsrecord_from_bytes,
50
+ configsrecord_to_bytes,
51
+ context_from_bytes,
52
+ context_to_bytes,
53
+ convert_sint64_to_uint64,
54
+ convert_sint64_values_in_dict_to_uint64,
55
+ convert_uint64_to_sint64,
56
+ convert_uint64_values_in_dict_to_sint64,
57
+ generate_rand_int_from_bytes,
58
+ has_valid_sub_status,
59
+ is_valid_transition,
60
+ verify_found_taskres,
61
+ verify_taskins_ids,
62
+ )
63
+
64
+ SQL_CREATE_TABLE_NODE = """
65
+ CREATE TABLE IF NOT EXISTS node(
66
+ node_id INTEGER UNIQUE,
67
+ online_until REAL,
68
+ ping_interval REAL,
69
+ public_key BLOB
70
+ );
71
+ """
72
+
73
+ SQL_CREATE_TABLE_CREDENTIAL = """
74
+ CREATE TABLE IF NOT EXISTS credential(
75
+ private_key BLOB PRIMARY KEY,
76
+ public_key BLOB
77
+ );
78
+ """
79
+
80
+ SQL_CREATE_TABLE_PUBLIC_KEY = """
81
+ CREATE TABLE IF NOT EXISTS public_key(
82
+ public_key BLOB PRIMARY KEY
83
+ );
84
+ """
85
+
86
+ SQL_CREATE_INDEX_ONLINE_UNTIL = """
87
+ CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
88
+ """
89
+
90
+ SQL_CREATE_TABLE_RUN = """
91
+ CREATE TABLE IF NOT EXISTS run(
92
+ run_id INTEGER UNIQUE,
93
+ fab_id TEXT,
94
+ fab_version TEXT,
95
+ fab_hash TEXT,
96
+ override_config TEXT,
97
+ pending_at TEXT,
98
+ starting_at TEXT,
99
+ running_at TEXT,
100
+ finished_at TEXT,
101
+ sub_status TEXT,
102
+ details TEXT,
103
+ federation_options BLOB
104
+ );
105
+ """
106
+
107
+ SQL_CREATE_TABLE_LOGS = """
108
+ CREATE TABLE IF NOT EXISTS logs (
109
+ timestamp REAL,
110
+ run_id INTEGER,
111
+ node_id INTEGER,
112
+ log TEXT,
113
+ PRIMARY KEY (timestamp, run_id, node_id),
114
+ FOREIGN KEY (run_id) REFERENCES run(run_id)
115
+ );
116
+ """
117
+
118
+ SQL_CREATE_TABLE_CONTEXT = """
119
+ CREATE TABLE IF NOT EXISTS context(
120
+ run_id INTEGER UNIQUE,
121
+ context BLOB,
122
+ FOREIGN KEY(run_id) REFERENCES run(run_id)
123
+ );
124
+ """
125
+
126
+ SQL_CREATE_TABLE_TASK_INS = """
127
+ CREATE TABLE IF NOT EXISTS task_ins(
128
+ task_id TEXT UNIQUE,
129
+ group_id TEXT,
130
+ run_id INTEGER,
131
+ producer_anonymous BOOLEAN,
132
+ producer_node_id INTEGER,
133
+ consumer_anonymous BOOLEAN,
134
+ consumer_node_id INTEGER,
135
+ created_at REAL,
136
+ delivered_at TEXT,
137
+ pushed_at REAL,
138
+ ttl REAL,
139
+ ancestry TEXT,
140
+ task_type TEXT,
141
+ recordset BLOB,
142
+ FOREIGN KEY(run_id) REFERENCES run(run_id)
143
+ );
144
+ """
145
+
146
+ SQL_CREATE_TABLE_TASK_RES = """
147
+ CREATE TABLE IF NOT EXISTS task_res(
148
+ task_id TEXT UNIQUE,
149
+ group_id TEXT,
150
+ run_id INTEGER,
151
+ producer_anonymous BOOLEAN,
152
+ producer_node_id INTEGER,
153
+ consumer_anonymous BOOLEAN,
154
+ consumer_node_id INTEGER,
155
+ created_at REAL,
156
+ delivered_at TEXT,
157
+ pushed_at REAL,
158
+ ttl REAL,
159
+ ancestry TEXT,
160
+ task_type TEXT,
161
+ recordset BLOB,
162
+ FOREIGN KEY(run_id) REFERENCES run(run_id)
163
+ );
164
+ """
165
+
166
+ DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
167
+
168
+
169
+ class SqliteLinkState(LinkState): # pylint: disable=R0904
170
+ """SQLite-based LinkState implementation."""
171
+
172
+ def __init__(
173
+ self,
174
+ database_path: str,
175
+ ) -> None:
176
+ """Initialize an SqliteLinkState.
177
+
178
+ Parameters
179
+ ----------
180
+ database_path : (path-like object)
181
+ The path to the database file to be opened. Pass ":memory:" to open
182
+ a connection to a database that is in RAM, instead of on disk.
183
+ """
184
+ self.database_path = database_path
185
+ self.conn: Optional[sqlite3.Connection] = None
186
+
187
+ def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
188
+ """Create tables if they don't exist yet.
189
+
190
+ Parameters
191
+ ----------
192
+ log_queries : bool
193
+ Log each query which is executed.
194
+
195
+ Returns
196
+ -------
197
+ list[tuple[str]]
198
+ The list of all tables in the DB.
199
+ """
200
+ self.conn = sqlite3.connect(self.database_path)
201
+ self.conn.execute("PRAGMA foreign_keys = ON;")
202
+ self.conn.row_factory = dict_factory
203
+ if log_queries:
204
+ self.conn.set_trace_callback(lambda query: log(DEBUG, query))
205
+ cur = self.conn.cursor()
206
+
207
+ # Create each table if not exists queries
208
+ cur.execute(SQL_CREATE_TABLE_RUN)
209
+ cur.execute(SQL_CREATE_TABLE_LOGS)
210
+ cur.execute(SQL_CREATE_TABLE_CONTEXT)
211
+ cur.execute(SQL_CREATE_TABLE_TASK_INS)
212
+ cur.execute(SQL_CREATE_TABLE_TASK_RES)
213
+ cur.execute(SQL_CREATE_TABLE_NODE)
214
+ cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
215
+ cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
216
+ cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
217
+ res = cur.execute("SELECT name FROM sqlite_schema;")
218
+ return res.fetchall()
219
+
220
+ def query(
221
+ self,
222
+ query: str,
223
+ data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
224
+ ) -> list[dict[str, Any]]:
225
+ """Execute a SQL query."""
226
+ if self.conn is None:
227
+ raise AttributeError("LinkState is not initialized.")
228
+
229
+ if data is None:
230
+ data = []
231
+
232
+ # Clean up whitespace to make the logs nicer
233
+ query = re.sub(r"\s+", " ", query)
234
+
235
+ try:
236
+ with self.conn:
237
+ if (
238
+ len(data) > 0
239
+ and isinstance(data, (tuple, list))
240
+ and isinstance(data[0], (tuple, dict))
241
+ ):
242
+ rows = self.conn.executemany(query, data)
243
+ else:
244
+ rows = self.conn.execute(query, data)
245
+
246
+ # Extract results before committing to support
247
+ # INSERT/UPDATE ... RETURNING
248
+ # style queries
249
+ result = rows.fetchall()
250
+ except KeyError as exc:
251
+ log(ERROR, {"query": query, "data": data, "exception": exc})
252
+
253
+ return result
254
+
255
+ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
256
+ """Store one TaskIns.
257
+
258
+ Usually, the ServerAppIo API calls this to schedule instructions.
259
+
260
+ Stores the value of the task_ins in the link state and, if successful,
261
+ returns the task_id (UUID) of the task_ins. If, for any reason, storing
262
+ the task_ins fails, `None` is returned.
263
+
264
+ Constraints
265
+ -----------
266
+ If `task_ins.task.consumer.anonymous` is `True`, then
267
+ `task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
268
+
269
+ If `task_ins.task.consumer.anonymous` is `False`, then
270
+ `task_ins.task.consumer.node_id` MUST be set (not 0)
271
+ """
272
+ # Validate task
273
+ errors = validate_task_ins_or_res(task_ins)
274
+ if any(errors):
275
+ log(ERROR, errors)
276
+ return None
277
+ # Create task_id
278
+ task_id = uuid4()
279
+
280
+ # Store TaskIns
281
+ task_ins.task_id = str(task_id)
282
+ data = (task_ins_to_dict(task_ins),)
283
+
284
+ # Convert values from uint64 to sint64 for SQLite
285
+ convert_uint64_values_in_dict_to_sint64(
286
+ data[0], ["run_id", "producer_node_id", "consumer_node_id"]
287
+ )
288
+
289
+ # Validate run_id
290
+ query = "SELECT run_id FROM run WHERE run_id = ?;"
291
+ if not self.query(query, (data[0]["run_id"],)):
292
+ log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
293
+ return None
294
+ # Validate source node ID
295
+ if task_ins.task.producer.node_id != 0:
296
+ log(
297
+ ERROR,
298
+ "Invalid source node ID for TaskIns: %s",
299
+ task_ins.task.producer.node_id,
300
+ )
301
+ return None
302
+ # Validate destination node ID
303
+ query = "SELECT node_id FROM node WHERE node_id = ?;"
304
+ if not task_ins.task.consumer.anonymous:
305
+ if not self.query(query, (data[0]["consumer_node_id"],)):
306
+ log(
307
+ ERROR,
308
+ "Invalid destination node ID for TaskIns: %s",
309
+ task_ins.task.consumer.node_id,
310
+ )
311
+ return None
312
+
313
+ columns = ", ".join([f":{key}" for key in data[0]])
314
+ query = f"INSERT INTO task_ins VALUES({columns});"
315
+
316
+ # Only invalid run_id can trigger IntegrityError.
317
+ # This may need to be changed in the future version with more integrity checks.
318
+ self.query(query, data)
319
+
320
+ return task_id
321
+
322
+ def get_task_ins(
323
+ self, node_id: Optional[int], limit: Optional[int]
324
+ ) -> list[TaskIns]:
325
+ """Get undelivered TaskIns for one node (either anonymous or with ID).
326
+
327
+ Usually, the Fleet API calls this for Nodes planning to work on one or more
328
+ TaskIns.
329
+
330
+ Constraints
331
+ -----------
332
+ If `node_id` is not `None`, retrieve all TaskIns where
333
+
334
+ 1. the `task_ins.task.consumer.node_id` equals `node_id` AND
335
+ 2. the `task_ins.task.consumer.anonymous` equals `False` AND
336
+ 3. the `task_ins.task.delivered_at` equals `""`.
337
+
338
+ If `node_id` is `None`, retrieve all TaskIns where the
339
+ `task_ins.task.consumer.node_id` equals `0` and
340
+ `task_ins.task.consumer.anonymous` is set to `True`.
341
+
342
+ `delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
343
+ the result.
344
+
345
+ If `limit` is not `None`, return, at most, `limit` number of `task_ins`. If
346
+ `limit` is set, it has to be greater than zero.
347
+ """
348
+ if limit is not None and limit < 1:
349
+ raise AssertionError("`limit` must be >= 1")
350
+
351
+ if node_id == 0:
352
+ msg = (
353
+ "`node_id` must be >= 1"
354
+ "\n\n For requesting anonymous tasks use `node_id` equal `None`"
355
+ )
356
+ raise AssertionError(msg)
357
+
358
+ data: dict[str, Union[str, int]] = {}
359
+
360
+ if node_id is None:
361
+ # Retrieve all anonymous Tasks
362
+ query = """
363
+ SELECT task_id
364
+ FROM task_ins
365
+ WHERE consumer_anonymous == 1
366
+ AND consumer_node_id == 0
367
+ AND delivered_at = ""
368
+ AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
369
+ """
370
+ else:
371
+ # Convert the uint64 value to sint64 for SQLite
372
+ data["node_id"] = convert_uint64_to_sint64(node_id)
373
+
374
+ # Retrieve all TaskIns for node_id
375
+ query = """
376
+ SELECT task_id
377
+ FROM task_ins
378
+ WHERE consumer_anonymous == 0
379
+ AND consumer_node_id == :node_id
380
+ AND delivered_at = ""
381
+ AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
382
+ """
383
+
384
+ if limit is not None:
385
+ query += " LIMIT :limit"
386
+ data["limit"] = limit
387
+
388
+ query += ";"
389
+
390
+ rows = self.query(query, data)
391
+
392
+ if rows:
393
+ # Prepare query
394
+ task_ids = [row["task_id"] for row in rows]
395
+ placeholders: str = ",".join([f":id_{i}" for i in range(len(task_ids))])
396
+ query = f"""
397
+ UPDATE task_ins
398
+ SET delivered_at = :delivered_at
399
+ WHERE task_id IN ({placeholders})
400
+ RETURNING *;
401
+ """
402
+
403
+ # Prepare data for query
404
+ delivered_at = now().isoformat()
405
+ data = {"delivered_at": delivered_at}
406
+ for index, task_id in enumerate(task_ids):
407
+ data[f"id_{index}"] = str(task_id)
408
+
409
+ # Run query
410
+ rows = self.query(query, data)
411
+
412
+ for row in rows:
413
+ # Convert values from sint64 to uint64
414
+ convert_sint64_values_in_dict_to_uint64(
415
+ row, ["run_id", "producer_node_id", "consumer_node_id"]
416
+ )
417
+
418
+ result = [dict_to_task_ins(row) for row in rows]
419
+
420
+ return result
421
+
422
+ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
423
+ """Store one TaskRes.
424
+
425
+ Usually, the Fleet API calls this when Nodes return their results.
426
+
427
+ Stores the TaskRes and, if successful, returns the `task_id` (UUID) of
428
+ the `task_res`. If storing the `task_res` fails, `None` is returned.
429
+
430
+ Constraints
431
+ -----------
432
+ If `task_res.task.consumer.anonymous` is `True`, then
433
+ `task_res.task.consumer.node_id` MUST NOT be set (equal 0).
434
+
435
+ If `task_res.task.consumer.anonymous` is `False`, then
436
+ `task_res.task.consumer.node_id` MUST be set (not 0)
437
+ """
438
+ # Validate task
439
+ errors = validate_task_ins_or_res(task_res)
440
+ if any(errors):
441
+ log(ERROR, errors)
442
+ return None
443
+
444
+ # Create task_id
445
+ task_id = uuid4()
446
+
447
+ task_ins_id = task_res.task.ancestry[0]
448
+ task_ins = self.get_valid_task_ins(task_ins_id)
449
+ if task_ins is None:
450
+ log(
451
+ ERROR,
452
+ "Failed to store TaskRes: "
453
+ "TaskIns with task_id %s does not exist or has expired.",
454
+ task_ins_id,
455
+ )
456
+ return None
457
+
458
+ # Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
459
+ if (
460
+ task_ins
461
+ and task_res
462
+ and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous)
463
+ and convert_sint64_to_uint64(task_ins["consumer_node_id"])
464
+ != task_res.task.producer.node_id
465
+ ):
466
+ return None
467
+
468
+ # Fail if the TaskRes TTL exceeds the
469
+ # expiration time of the TaskIns it replies to.
470
+ # Condition: TaskIns.created_at + TaskIns.ttl ≥
471
+ # TaskRes.created_at + TaskRes.ttl
472
+ # A small tolerance is introduced to account
473
+ # for floating-point precision issues.
474
+ max_allowed_ttl = (
475
+ task_ins["created_at"] + task_ins["ttl"] - task_res.task.created_at
476
+ )
477
+ if task_res.task.ttl and (
478
+ task_res.task.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
479
+ ):
480
+ log(
481
+ WARNING,
482
+ "Received TaskRes with TTL %.2f "
483
+ "exceeding the allowed maximum TTL %.2f.",
484
+ task_res.task.ttl,
485
+ max_allowed_ttl,
486
+ )
487
+ return None
488
+
489
+ # Store TaskRes
490
+ task_res.task_id = str(task_id)
491
+ data = (task_res_to_dict(task_res),)
492
+
493
+ # Convert values from uint64 to sint64 for SQLite
494
+ convert_uint64_values_in_dict_to_sint64(
495
+ data[0], ["run_id", "producer_node_id", "consumer_node_id"]
496
+ )
497
+
498
+ columns = ", ".join([f":{key}" for key in data[0]])
499
+ query = f"INSERT INTO task_res VALUES({columns});"
500
+
501
+ # Only invalid run_id can trigger IntegrityError.
502
+ # This may need to be changed in the future version with more integrity checks.
503
+ try:
504
+ self.query(query, data)
505
+ except sqlite3.IntegrityError:
506
+ log(ERROR, "`run` is invalid")
507
+ return None
508
+
509
+ return task_id
510
+
511
    # pylint: disable-next=R0912,R0915,R0914
    def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
        """Get TaskRes for the given TaskIns IDs.

        Looks up the inquired TaskIns, collects all undelivered TaskRes whose
        ancestry points at them, and marks every TaskRes being returned as
        delivered so it is not handed out again.
        """
        ret: dict[UUID, TaskRes] = {}

        # Verify TaskIns IDs
        current = time.time()
        query = f"""
            SELECT *
            FROM task_ins
            WHERE task_id IN ({",".join(["?"] * len(task_ids))});
        """
        rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
        found_task_ins_dict: dict[UUID, TaskIns] = {}
        for row in rows:
            # Restore uint64 values (SQLite stores them as sint64)
            convert_sint64_values_in_dict_to_uint64(
                row, ["run_id", "producer_node_id", "consumer_node_id"]
            )
            found_task_ins_dict[UUID(row["task_id"])] = dict_to_task_ins(row)

        # Pre-populate the result for missing/expired TaskIns
        ret = verify_taskins_ids(
            inquired_taskins_ids=task_ids,
            found_taskins_dict=found_task_ins_dict,
            current_time=current,
        )

        # Find all TaskRes (only undelivered ones, i.e. empty delivered_at)
        query = f"""
            SELECT *
            FROM task_res
            WHERE ancestry IN ({",".join(["?"] * len(task_ids))})
            AND delivered_at = "";
        """
        rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
        for row in rows:
            convert_sint64_values_in_dict_to_uint64(
                row, ["run_id", "producer_node_id", "consumer_node_id"]
            )
        tmp_ret_dict = verify_found_taskres(
            inquired_taskins_ids=task_ids,
            found_taskins_dict=found_task_ins_dict,
            found_taskres_list=[dict_to_task_res(row) for row in rows],
            current_time=current,
        )
        ret.update(tmp_ret_dict)

        # Mark existing TaskRes to be returned as delivered
        delivered_at = now().isoformat()
        for task_res in ret.values():
            task_res.task.delivered_at = delivered_at
        task_res_ids = [task_res.task_id for task_res in ret.values()]
        query = f"""
            UPDATE task_res
            SET delivered_at = ?
            WHERE task_id IN ({",".join(["?"] * len(task_res_ids))});
        """
        data: list[Any] = [delivered_at] + task_res_ids
        self.query(query, data)

        return list(ret.values())
571
+
572
+ def num_task_ins(self) -> int:
573
+ """Calculate the number of task_ins in store.
574
+
575
+ This includes delivered but not yet deleted task_ins.
576
+ """
577
+ query = "SELECT count(*) AS num FROM task_ins;"
578
+ rows = self.query(query)
579
+ result = rows[0]
580
+ num = cast(int, result["num"])
581
+ return num
582
+
583
+ def num_task_res(self) -> int:
584
+ """Calculate the number of task_res in store.
585
+
586
+ This includes delivered but not yet deleted task_res.
587
+ """
588
+ query = "SELECT count(*) AS num FROM task_res;"
589
+ rows = self.query(query)
590
+ result: dict[str, int] = rows[0]
591
+ return result["num"]
592
+
593
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
594
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
595
+ if not task_ins_ids:
596
+ return
597
+ if self.conn is None:
598
+ raise AttributeError("LinkState not initialized")
599
+
600
+ placeholders = ",".join(["?"] * len(task_ins_ids))
601
+ data = tuple(str(task_id) for task_id in task_ins_ids)
602
+
603
+ # Delete task_ins
604
+ query_1 = f"""
605
+ DELETE FROM task_ins
606
+ WHERE task_id IN ({placeholders});
607
+ """
608
+
609
+ # Delete task_res
610
+ query_2 = f"""
611
+ DELETE FROM task_res
612
+ WHERE ancestry IN ({placeholders});
613
+ """
614
+
615
+ with self.conn:
616
+ self.conn.execute(query_1, data)
617
+ self.conn.execute(query_2, data)
618
+
619
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
620
+ """Get all TaskIns IDs for the given run_id."""
621
+ if self.conn is None:
622
+ raise AttributeError("LinkState not initialized")
623
+
624
+ query = """
625
+ SELECT task_id
626
+ FROM task_ins
627
+ WHERE run_id = :run_id;
628
+ """
629
+
630
+ sint64_run_id = convert_uint64_to_sint64(run_id)
631
+ data = {"run_id": sint64_run_id}
632
+
633
+ with self.conn:
634
+ rows = self.conn.execute(query, data).fetchall()
635
+
636
+ return {UUID(row["task_id"]) for row in rows}
637
+
638
+ def create_node(self, ping_interval: float) -> int:
639
+ """Create, store in the link state, and return `node_id`."""
640
+ # Sample a random uint64 as node_id
641
+ uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
642
+
643
+ # Convert the uint64 value to sint64 for SQLite
644
+ sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
645
+
646
+ query = (
647
+ "INSERT INTO node "
648
+ "(node_id, online_until, ping_interval, public_key) "
649
+ "VALUES (?, ?, ?, ?)"
650
+ )
651
+
652
+ try:
653
+ self.query(
654
+ query,
655
+ (
656
+ sint64_node_id,
657
+ time.time() + ping_interval,
658
+ ping_interval,
659
+ b"", # Initialize with an empty public key
660
+ ),
661
+ )
662
+ except sqlite3.IntegrityError:
663
+ log(ERROR, "Unexpected node registration failure.")
664
+ return 0
665
+
666
+ # Note: we need to return the uint64 value of the node_id
667
+ return uint64_node_id
668
+
669
    def delete_node(self, node_id: int) -> None:
        """Delete a node.

        Raises
        ------
        ValueError
            If no node with the given `node_id` exists.
        """
        # Convert the uint64 value to sint64 for SQLite
        sint64_node_id = convert_uint64_to_sint64(node_id)

        query = "DELETE FROM node WHERE node_id = ?"
        params = (sint64_node_id,)

        if self.conn is None:
            raise AttributeError("LinkState is not initialized.")

        try:
            with self.conn:
                rows = self.conn.execute(query, params)
                # rowcount < 1 means no row matched the given node_id
                if rows.rowcount < 1:
                    raise ValueError(f"Node {node_id} not found")
        # NOTE(review): sqlite3's execute() does not raise KeyError, so this
        # handler appears unreachable and the ValueError above always
        # propagates — confirm whether sqlite3.Error was intended here.
        except KeyError as exc:
            log(ERROR, {"query": query, "data": params, "exception": exc})
687
+
688
    def get_nodes(self, run_id: int) -> set[int]:
        """Retrieve all currently stored node IDs as a set.

        Constraints
        -----------
        If the provided `run_id` does not exist or has no matching nodes,
        an empty `Set` MUST be returned.
        """
        # Convert the uint64 value to sint64 for SQLite
        sint64_run_id = convert_uint64_to_sint64(run_id)

        # Validate run ID
        query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
        if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
            return set()

        # Get nodes
        # NOTE(review): once the run is validated, nodes are filtered only by
        # their online window (`online_until`), not by membership in the run —
        # confirm this is the intended semantics.
        query = "SELECT node_id FROM node WHERE online_until > ?;"
        rows = self.query(query, (time.time(),))

        # Convert sint64 node_ids to uint64
        result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
        return result
711
+
712
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
713
+ """Set `public_key` for the specified `node_id`."""
714
+ # Convert the uint64 value to sint64 for SQLite
715
+ sint64_node_id = convert_uint64_to_sint64(node_id)
716
+
717
+ # Check if the node exists in the `node` table
718
+ query = "SELECT 1 FROM node WHERE node_id = ?"
719
+ if not self.query(query, (sint64_node_id,)):
720
+ raise ValueError(f"Node {node_id} not found")
721
+
722
+ # Check if the public key is already in use in the `node` table
723
+ query = "SELECT 1 FROM node WHERE public_key = ?"
724
+ if self.query(query, (public_key,)):
725
+ raise ValueError("Public key already in use")
726
+
727
+ # Update the `node` table to set the public key for the given node ID
728
+ query = "UPDATE node SET public_key = ? WHERE node_id = ?"
729
+ self.query(query, (public_key, sint64_node_id))
730
+
731
    def get_node_public_key(self, node_id: int) -> Optional[bytes]:
        """Get `public_key` for the specified `node_id`.

        Returns None if the node exists but has an empty public key.

        Raises
        ------
        ValueError
            If no node with the given `node_id` exists.
        """
        # Convert the uint64 value to sint64 for SQLite
        sint64_node_id = convert_uint64_to_sint64(node_id)

        # Query the public key for the given node_id
        query = "SELECT public_key FROM node WHERE node_id = ?"
        rows = self.query(query, (sint64_node_id,))

        # An unknown node is an error (not a None return)
        if not rows:
            raise ValueError(f"Node {node_id} not found")

        # Return the public key if it is not empty, otherwise return None
        return rows[0]["public_key"] or None
746
+
747
+ def get_node_id(self, node_public_key: bytes) -> Optional[int]:
748
+ """Retrieve stored `node_id` filtered by `node_public_keys`."""
749
+ query = "SELECT node_id FROM node WHERE public_key = :public_key;"
750
+ row = self.query(query, {"public_key": node_public_key})
751
+ if len(row) > 0:
752
+ node_id: int = row[0]["node_id"]
753
+
754
+ # Convert the sint64 value to uint64 after reading from SQLite
755
+ uint64_node_id = convert_sint64_to_uint64(node_id)
756
+
757
+ return uint64_node_id
758
+ return None
759
+
760
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def create_run(
        self,
        fab_id: Optional[str],
        fab_version: Optional[str],
        fab_hash: Optional[str],
        override_config: UserConfig,
        federation_options: ConfigsRecord,
    ) -> int:
        """Create a new run for the specified `fab_id` and `fab_version`.

        Returns the new uint64 run ID, or 0 if the randomly drawn ID already
        exists (an extremely unlikely collision).
        """
        # Sample a random int64 as run_id
        uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

        # Convert the uint64 value to sint64 for SQLite
        sint64_run_id = convert_uint64_to_sint64(uint64_run_id)

        # Check conflicts
        query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
        # If sint64_run_id does not exist
        if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
            query = (
                "INSERT INTO run "
                "(run_id, fab_id, fab_version, fab_hash, override_config, "
                "federation_options, pending_at, starting_at, running_at, finished_at, "
                "sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
            )
            # override_config is persisted as JSON; federation_options as bytes
            override_config_json = json.dumps(override_config)
            data = [
                sint64_run_id,
                fab_id,
                fab_version,
                fab_hash,
                override_config_json,
                configsrecord_to_bytes(federation_options),
            ]
            # Status fields: pending_at is stamped now; starting_at,
            # running_at, finished_at, sub_status, and details start out empty
            data += [
                now().isoformat(),
                "",
                "",
                "",
                "",
                "",
            ]
            self.query(query, tuple(data))
            return uint64_run_id
        log(ERROR, "Unexpected run creation failure.")
        return 0
807
+
808
+ def store_server_private_public_key(
809
+ self, private_key: bytes, public_key: bytes
810
+ ) -> None:
811
+ """Store `server_private_key` and `server_public_key` in the link state."""
812
+ query = "SELECT COUNT(*) FROM credential"
813
+ count = self.query(query)[0]["COUNT(*)"]
814
+ if count < 1:
815
+ query = (
816
+ "INSERT OR REPLACE INTO credential (private_key, public_key) "
817
+ "VALUES (:private_key, :public_key)"
818
+ )
819
+ self.query(query, {"private_key": private_key, "public_key": public_key})
820
+ else:
821
+ raise RuntimeError("Server private and public key already set")
822
+
823
+ def get_server_private_key(self) -> Optional[bytes]:
824
+ """Retrieve `server_private_key` in urlsafe bytes."""
825
+ query = "SELECT private_key FROM credential"
826
+ rows = self.query(query)
827
+ try:
828
+ private_key: Optional[bytes] = rows[0]["private_key"]
829
+ except IndexError:
830
+ private_key = None
831
+ return private_key
832
+
833
+ def get_server_public_key(self) -> Optional[bytes]:
834
+ """Retrieve `server_public_key` in urlsafe bytes."""
835
+ query = "SELECT public_key FROM credential"
836
+ rows = self.query(query)
837
+ try:
838
+ public_key: Optional[bytes] = rows[0]["public_key"]
839
+ except IndexError:
840
+ public_key = None
841
+ return public_key
842
+
843
+ def clear_supernode_auth_keys_and_credentials(self) -> None:
844
+ """Clear stored `node_public_keys` and credentials in the link state if any."""
845
+ queries = ["DELETE FROM public_key;", "DELETE FROM credential;"]
846
+ for query in queries:
847
+ self.query(query)
848
+
849
+ def store_node_public_keys(self, public_keys: set[bytes]) -> None:
850
+ """Store a set of `node_public_keys` in the link state."""
851
+ query = "INSERT INTO public_key (public_key) VALUES (?)"
852
+ data = [(key,) for key in public_keys]
853
+ self.query(query, data)
854
+
855
+ def store_node_public_key(self, public_key: bytes) -> None:
856
+ """Store a `node_public_key` in the link state."""
857
+ query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
858
+ self.query(query, {"public_key": public_key})
859
+
860
+ def get_node_public_keys(self) -> set[bytes]:
861
+ """Retrieve all currently stored `node_public_keys` as a set."""
862
+ query = "SELECT public_key FROM public_key"
863
+ rows = self.query(query)
864
+ result: set[bytes] = {row["public_key"] for row in rows}
865
+ return result
866
+
867
+ def get_run_ids(self) -> set[int]:
868
+ """Retrieve all run IDs."""
869
+ query = "SELECT run_id FROM run;"
870
+ rows = self.query(query)
871
+ return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
872
+
873
    def get_run(self, run_id: int) -> Optional[Run]:
        """Retrieve information about the run with the specified `run_id`.

        Returns None (and logs an error) if the run does not exist.
        """
        # Convert the uint64 value to sint64 for SQLite
        sint64_run_id = convert_uint64_to_sint64(run_id)
        query = "SELECT * FROM run WHERE run_id = ?;"
        rows = self.query(query, (sint64_run_id,))
        if rows:
            row = rows[0]
            return Run(
                run_id=convert_sint64_to_uint64(row["run_id"]),
                fab_id=row["fab_id"],
                fab_version=row["fab_version"],
                fab_hash=row["fab_hash"],
                # Persisted as JSON; restore the original mapping
                override_config=json.loads(row["override_config"]),
                pending_at=row["pending_at"],
                starting_at=row["starting_at"],
                running_at=row["running_at"],
                finished_at=row["finished_at"],
                # The status is derived from which timestamp fields are set
                status=RunStatus(
                    status=determine_run_status(row),
                    sub_status=row["sub_status"],
                    details=row["details"],
                ),
            )
        log(ERROR, "`run_id` does not exist.")
        return None
899
+
900
+ def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
901
+ """Retrieve the statuses for the specified runs."""
902
+ # Convert the uint64 value to sint64 for SQLite
903
+ sint64_run_ids = (convert_uint64_to_sint64(run_id) for run_id in set(run_ids))
904
+ query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
905
+ rows = self.query(query, tuple(sint64_run_ids))
906
+
907
+ return {
908
+ # Restore uint64 run IDs
909
+ convert_sint64_to_uint64(row["run_id"]): RunStatus(
910
+ status=determine_run_status(row),
911
+ sub_status=row["sub_status"],
912
+ details=row["details"],
913
+ )
914
+ for row in rows
915
+ }
916
+
917
    def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
        """Update the status of the run with the specified `run_id`.

        Returns True on success; False if the run does not exist or the
        requested status transition is invalid.
        """
        # Convert the uint64 value to sint64 for SQLite
        sint64_run_id = convert_uint64_to_sint64(run_id)
        query = "SELECT * FROM run WHERE run_id = ?;"
        rows = self.query(query, (sint64_run_id,))

        # Check if the run_id exists
        if not rows:
            log(ERROR, "`run_id` is invalid")
            return False

        # Check if the status transition is valid
        row = rows[0]
        current_status = RunStatus(
            status=determine_run_status(row),
            sub_status=row["sub_status"],
            details=row["details"],
        )
        if not is_valid_transition(current_status, new_status):
            log(
                ERROR,
                'Invalid status transition: from "%s" to "%s"',
                current_status.status,
                new_status.status,
            )
            return False

        # Check if the sub-status is valid
        # NOTE(review): this validates the *current* status' sub-status;
        # confirm whether `new_status` was intended here instead.
        if not has_valid_sub_status(current_status):
            log(
                ERROR,
                'Invalid sub-status "%s" for status "%s"',
                current_status.sub_status,
                current_status.status,
            )
            return False

        # Update the status; the timestamp column to stamp (%s) depends on
        # the new status
        query = "UPDATE run SET %s= ?, sub_status = ?, details = ? "
        query += "WHERE run_id = ?;"

        timestamp_fld = ""
        if new_status.status == Status.STARTING:
            timestamp_fld = "starting_at"
        elif new_status.status == Status.RUNNING:
            timestamp_fld = "running_at"
        elif new_status.status == Status.FINISHED:
            timestamp_fld = "finished_at"

        data = (
            now().isoformat(),
            new_status.sub_status,
            new_status.details,
            sint64_run_id,
        )
        self.query(query % timestamp_fld, data)
        return True
975
+
976
+ def get_pending_run_id(self) -> Optional[int]:
977
+ """Get the `run_id` of a run with `Status.PENDING` status, if any."""
978
+ pending_run_id = None
979
+
980
+ # Fetch all runs with unset `starting_at` (i.e. they are in PENDING status)
981
+ query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
982
+ rows = self.query(query)
983
+ if rows:
984
+ pending_run_id = convert_sint64_to_uint64(rows[0]["run_id"])
985
+
986
+ return pending_run_id
987
+
988
+ def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
989
+ """Retrieve the federation options for the specified `run_id`."""
990
+ # Convert the uint64 value to sint64 for SQLite
991
+ sint64_run_id = convert_uint64_to_sint64(run_id)
992
+ query = "SELECT federation_options FROM run WHERE run_id = ?;"
993
+ rows = self.query(query, (sint64_run_id,))
994
+
995
+ # Check if the run_id exists
996
+ if not rows:
997
+ log(ERROR, "`run_id` is invalid")
998
+ return None
999
+
1000
+ row = rows[0]
1001
+ return configsrecord_from_bytes(row["federation_options"])
1002
+
1003
+ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
1004
+ """Acknowledge a ping received from a node, serving as a heartbeat."""
1005
+ sint64_node_id = convert_uint64_to_sint64(node_id)
1006
+
1007
+ # Check if the node exists in the `node` table
1008
+ query = "SELECT 1 FROM node WHERE node_id = ?"
1009
+ if not self.query(query, (sint64_node_id,)):
1010
+ return False
1011
+
1012
+ # Update `online_until` and `ping_interval` for the given `node_id`
1013
+ query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
1014
+ self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
1015
+ return True
1016
+
1017
+ def get_serverapp_context(self, run_id: int) -> Optional[Context]:
1018
+ """Get the context for the specified `run_id`."""
1019
+ # Retrieve context if any
1020
+ query = "SELECT context FROM context WHERE run_id = ?;"
1021
+ rows = self.query(query, (convert_uint64_to_sint64(run_id),))
1022
+ context = context_from_bytes(rows[0]["context"]) if rows else None
1023
+ return context
1024
+
1025
+ def set_serverapp_context(self, run_id: int, context: Context) -> None:
1026
+ """Set the context for the specified `run_id`."""
1027
+ # Convert context to bytes
1028
+ context_bytes = context_to_bytes(context)
1029
+ sint_run_id = convert_uint64_to_sint64(run_id)
1030
+
1031
+ # Check if any existing Context assigned to the run_id
1032
+ query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
1033
+ if self.query(query, (sint_run_id,))[0]["COUNT(*)"] > 0:
1034
+ # Update context
1035
+ query = "UPDATE context SET context = ? WHERE run_id = ?;"
1036
+ self.query(query, (context_bytes, sint_run_id))
1037
+ else:
1038
+ try:
1039
+ # Store context
1040
+ query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
1041
+ self.query(query, (sint_run_id, context_bytes))
1042
+ except sqlite3.IntegrityError:
1043
+ raise ValueError(f"Run {run_id} not found") from None
1044
+
1045
+ def add_serverapp_log(self, run_id: int, log_message: str) -> None:
1046
+ """Add a log entry to the ServerApp logs for the specified `run_id`."""
1047
+ # Convert the uint64 value to sint64 for SQLite
1048
+ sint64_run_id = convert_uint64_to_sint64(run_id)
1049
+
1050
+ # Store log
1051
+ try:
1052
+ query = """
1053
+ INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?);
1054
+ """
1055
+ self.query(query, (now().timestamp(), sint64_run_id, 0, log_message))
1056
+ except sqlite3.IntegrityError:
1057
+ raise ValueError(f"Run {run_id} not found") from None
1058
+
1059
    def get_serverapp_log(
        self, run_id: int, after_timestamp: Optional[float]
    ) -> tuple[str, float]:
        """Get the ServerApp logs for the specified `run_id`.

        Returns the concatenated log messages newer than `after_timestamp`
        (all messages when None), together with the newest returned
        timestamp (0.0 when nothing matched).

        Raises
        ------
        ValueError
            If the run does not exist.
        """
        # Convert the uint64 value to sint64 for SQLite
        sint64_run_id = convert_uint64_to_sint64(run_id)

        # Check if the run_id exists
        query = "SELECT run_id FROM run WHERE run_id = ?;"
        if not self.query(query, (sint64_run_id,)):
            raise ValueError(f"Run {run_id} not found")

        # Retrieve logs
        if after_timestamp is None:
            after_timestamp = 0.0
        # node_id = 0 matches the value written by `add_serverapp_log`
        query = """
            SELECT log, timestamp FROM logs
            WHERE run_id = ? AND node_id = ? AND timestamp > ?;
        """
        rows = self.query(query, (sint64_run_id, 0, after_timestamp))
        # Return logs in chronological order
        rows.sort(key=lambda x: x["timestamp"])
        latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
        return "".join(row["log"] for row in rows), latest_timestamp
1082
+
1083
+ def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
1084
+ """Check if the TaskIns exists and is valid (not expired).
1085
+
1086
+ Return TaskIns if valid.
1087
+ """
1088
+ query = """
1089
+ SELECT *
1090
+ FROM task_ins
1091
+ WHERE task_id = :task_id
1092
+ """
1093
+ data = {"task_id": task_id}
1094
+ rows = self.query(query, data)
1095
+ if not rows:
1096
+ # TaskIns does not exist
1097
+ return None
1098
+
1099
+ task_ins = rows[0]
1100
+ created_at = task_ins["created_at"]
1101
+ ttl = task_ins["ttl"]
1102
+ current_time = time.time()
1103
+
1104
+ # Check if TaskIns is expired
1105
+ if ttl is not None and created_at + ttl <= current_time:
1106
+ return None
1107
+
1108
+ return task_ins
1109
+
1110
+
1111
def dict_factory(
    cursor: sqlite3.Cursor,
    row: sqlite3.Row,
) -> dict[str, Any]:
    """Turn SQLite results into dicts.

    Less efficient for retrieval of large amounts of data but easier to use.
    """
    # cursor.description yields one 7-tuple per column; index 0 is the name
    column_names = (description[0] for description in cursor.description)
    return dict(zip(column_names, row))
1121
+
1122
+
1123
def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
    """Transform TaskIns to dict."""
    task = task_msg.task
    return {
        "task_id": task_msg.task_id,
        "group_id": task_msg.group_id,
        "run_id": task_msg.run_id,
        "producer_anonymous": task.producer.anonymous,
        "producer_node_id": task.producer.node_id,
        "consumer_anonymous": task.consumer.anonymous,
        "consumer_node_id": task.consumer.node_id,
        "created_at": task.created_at,
        "delivered_at": task.delivered_at,
        "pushed_at": task.pushed_at,
        "ttl": task.ttl,
        # The ancestry list is flattened to a comma-separated string
        "ancestry": ",".join(task.ancestry),
        "task_type": task.task_type,
        # The RecordSet is stored as serialized protobuf bytes
        "recordset": task.recordset.SerializeToString(),
    }
1142
+
1143
+
1144
def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
    """Transform TaskRes to dict."""
    task = task_msg.task
    return {
        "task_id": task_msg.task_id,
        "group_id": task_msg.group_id,
        "run_id": task_msg.run_id,
        "producer_anonymous": task.producer.anonymous,
        "producer_node_id": task.producer.node_id,
        "consumer_anonymous": task.consumer.anonymous,
        "consumer_node_id": task.consumer.node_id,
        "created_at": task.created_at,
        "delivered_at": task.delivered_at,
        "pushed_at": task.pushed_at,
        "ttl": task.ttl,
        # The ancestry list is flattened to a comma-separated string
        "ancestry": ",".join(task.ancestry),
        "task_type": task.task_type,
        # The RecordSet is stored as serialized protobuf bytes
        "recordset": task.recordset.SerializeToString(),
    }
1163
+
1164
+
1165
def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
    """Turn task_dict into protobuf message."""
    # Restore the RecordSet from its serialized protobuf bytes
    recordset = ProtoRecordSet()
    recordset.ParseFromString(task_dict["recordset"])

    task = Task(
        producer=Node(
            node_id=task_dict["producer_node_id"],
            anonymous=task_dict["producer_anonymous"],
        ),
        consumer=Node(
            node_id=task_dict["consumer_node_id"],
            anonymous=task_dict["consumer_anonymous"],
        ),
        created_at=task_dict["created_at"],
        delivered_at=task_dict["delivered_at"],
        pushed_at=task_dict["pushed_at"],
        ttl=task_dict["ttl"],
        # Ancestry was stored as a comma-separated string
        ancestry=task_dict["ancestry"].split(","),
        task_type=task_dict["task_type"],
        recordset=recordset,
    )
    return TaskIns(
        task_id=task_dict["task_id"],
        group_id=task_dict["group_id"],
        run_id=task_dict["run_id"],
        task=task,
    )
1193
+
1194
+
1195
def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
    """Turn task_dict into protobuf message."""
    # Restore the RecordSet from its serialized protobuf bytes
    recordset = ProtoRecordSet()
    recordset.ParseFromString(task_dict["recordset"])

    task = Task(
        producer=Node(
            node_id=task_dict["producer_node_id"],
            anonymous=task_dict["producer_anonymous"],
        ),
        consumer=Node(
            node_id=task_dict["consumer_node_id"],
            anonymous=task_dict["consumer_anonymous"],
        ),
        created_at=task_dict["created_at"],
        delivered_at=task_dict["delivered_at"],
        pushed_at=task_dict["pushed_at"],
        ttl=task_dict["ttl"],
        # Ancestry was stored as a comma-separated string
        ancestry=task_dict["ancestry"].split(","),
        task_type=task_dict["task_type"],
        recordset=recordset,
    )
    return TaskRes(
        task_id=task_dict["task_id"],
        group_id=task_dict["group_id"],
        run_id=task_dict["run_id"],
        task=task,
    )
1223
+
1224
+
1225
def determine_run_status(row: dict[str, Any]) -> str:
    """Determine the status of the run based on timestamp fields."""
    if not row["pending_at"]:
        # Every valid run must at least have been registered as pending
        run_id = convert_sint64_to_uint64(row["run_id"])
        raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")
    if row["finished_at"]:
        return Status.FINISHED
    if row["starting_at"]:
        return Status.RUNNING if row["running_at"] else Status.STARTING
    return Status.PENDING