flwr-nightly 1.8.0.dev20240315__py3-none-any.whl → 1.15.0.dev20250115__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (312) hide show
  1. flwr/cli/app.py +16 -2
  2. flwr/cli/build.py +181 -0
  3. flwr/cli/cli_user_auth_interceptor.py +90 -0
  4. flwr/cli/config_utils.py +343 -0
  5. flwr/cli/example.py +4 -1
  6. flwr/cli/install.py +253 -0
  7. flwr/cli/log.py +182 -0
  8. flwr/{server/superlink/state → cli/login}/__init__.py +4 -10
  9. flwr/cli/login/login.py +88 -0
  10. flwr/cli/ls.py +327 -0
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +210 -66
  13. flwr/cli/new/templates/app/.gitignore.tpl +163 -0
  14. flwr/cli/new/templates/app/LICENSE.tpl +202 -0
  15. flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
  16. flwr/cli/new/templates/app/README.flowertune.md.tpl +66 -0
  17. flwr/cli/new/templates/app/README.md.tpl +16 -32
  18. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
  19. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  20. flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
  21. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +55 -0
  22. flwr/cli/new/templates/app/code/client.jax.py.tpl +50 -0
  23. flwr/cli/new/templates/app/code/client.mlx.py.tpl +73 -0
  24. flwr/cli/new/templates/app/code/client.numpy.py.tpl +7 -7
  25. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +30 -21
  26. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +63 -0
  27. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +57 -1
  28. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
  29. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  30. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +126 -0
  31. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +87 -0
  32. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +78 -0
  33. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +94 -0
  34. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
  35. flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
  36. flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
  37. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +38 -0
  38. flwr/cli/new/templates/app/code/server.jax.py.tpl +26 -0
  39. flwr/cli/new/templates/app/code/server.mlx.py.tpl +31 -0
  40. flwr/cli/new/templates/app/code/server.numpy.py.tpl +22 -9
  41. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
  42. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +36 -0
  43. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
  44. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
  45. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +102 -0
  46. flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
  47. flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
  48. flwr/cli/new/templates/app/code/task.numpy.py.tpl +7 -0
  49. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +29 -24
  50. flwr/cli/new/templates/app/code/task.sklearn.py.tpl +67 -0
  51. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
  52. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
  53. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
  54. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +68 -0
  55. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +46 -0
  56. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +35 -0
  57. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
  58. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
  59. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
  60. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +35 -0
  61. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
  62. flwr/cli/run/__init__.py +1 -0
  63. flwr/cli/run/run.py +212 -34
  64. flwr/cli/stop.py +130 -0
  65. flwr/cli/utils.py +240 -5
  66. flwr/client/__init__.py +3 -2
  67. flwr/client/app.py +432 -255
  68. flwr/client/client.py +1 -11
  69. flwr/client/client_app.py +74 -13
  70. flwr/client/clientapp/__init__.py +22 -0
  71. flwr/client/clientapp/app.py +259 -0
  72. flwr/client/clientapp/clientappio_servicer.py +244 -0
  73. flwr/client/clientapp/utils.py +115 -0
  74. flwr/client/dpfedavg_numpy_client.py +7 -8
  75. flwr/client/grpc_adapter_client/__init__.py +15 -0
  76. flwr/client/grpc_adapter_client/connection.py +98 -0
  77. flwr/client/grpc_client/connection.py +21 -7
  78. flwr/client/grpc_rere_client/__init__.py +1 -1
  79. flwr/client/grpc_rere_client/client_interceptor.py +176 -0
  80. flwr/client/grpc_rere_client/connection.py +163 -56
  81. flwr/client/grpc_rere_client/grpc_adapter.py +167 -0
  82. flwr/client/heartbeat.py +74 -0
  83. flwr/client/message_handler/__init__.py +1 -1
  84. flwr/client/message_handler/message_handler.py +10 -11
  85. flwr/client/mod/__init__.py +5 -5
  86. flwr/client/mod/centraldp_mods.py +4 -2
  87. flwr/client/mod/comms_mods.py +5 -4
  88. flwr/client/mod/localdp_mod.py +10 -5
  89. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  90. flwr/client/mod/secure_aggregation/secaggplus_mod.py +26 -26
  91. flwr/client/mod/utils.py +2 -4
  92. flwr/client/nodestate/__init__.py +26 -0
  93. flwr/client/nodestate/in_memory_nodestate.py +38 -0
  94. flwr/client/nodestate/nodestate.py +31 -0
  95. flwr/client/nodestate/nodestate_factory.py +38 -0
  96. flwr/client/numpy_client.py +8 -31
  97. flwr/client/rest_client/__init__.py +1 -1
  98. flwr/client/rest_client/connection.py +199 -176
  99. flwr/client/run_info_store.py +112 -0
  100. flwr/client/supernode/__init__.py +24 -0
  101. flwr/client/supernode/app.py +321 -0
  102. flwr/client/typing.py +1 -0
  103. flwr/common/__init__.py +17 -11
  104. flwr/common/address.py +47 -3
  105. flwr/common/args.py +153 -0
  106. flwr/common/auth_plugin/__init__.py +24 -0
  107. flwr/common/auth_plugin/auth_plugin.py +121 -0
  108. flwr/common/config.py +243 -0
  109. flwr/common/constant.py +135 -1
  110. flwr/common/context.py +32 -2
  111. flwr/common/date.py +22 -4
  112. flwr/common/differential_privacy.py +2 -2
  113. flwr/common/dp.py +2 -4
  114. flwr/common/exit_handlers.py +3 -3
  115. flwr/common/grpc.py +164 -5
  116. flwr/common/logger.py +230 -12
  117. flwr/common/message.py +191 -106
  118. flwr/common/object_ref.py +179 -44
  119. flwr/common/pyproject.py +1 -0
  120. flwr/common/record/__init__.py +2 -1
  121. flwr/common/record/configsrecord.py +58 -18
  122. flwr/common/record/metricsrecord.py +57 -17
  123. flwr/common/record/parametersrecord.py +88 -20
  124. flwr/common/record/recordset.py +153 -30
  125. flwr/common/record/typeddict.py +30 -55
  126. flwr/common/recordset_compat.py +31 -12
  127. flwr/common/retry_invoker.py +123 -30
  128. flwr/common/secure_aggregation/__init__.py +1 -1
  129. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  130. flwr/common/secure_aggregation/crypto/shamir.py +11 -11
  131. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +68 -4
  132. flwr/common/secure_aggregation/ndarrays_arithmetic.py +17 -17
  133. flwr/common/secure_aggregation/quantization.py +8 -8
  134. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  135. flwr/common/secure_aggregation/secaggplus_utils.py +10 -12
  136. flwr/common/serde.py +304 -23
  137. flwr/common/telemetry.py +65 -29
  138. flwr/common/typing.py +120 -19
  139. flwr/common/version.py +17 -3
  140. flwr/proto/clientappio_pb2.py +45 -0
  141. flwr/proto/clientappio_pb2.pyi +132 -0
  142. flwr/proto/clientappio_pb2_grpc.py +135 -0
  143. flwr/proto/clientappio_pb2_grpc.pyi +53 -0
  144. flwr/proto/exec_pb2.py +62 -0
  145. flwr/proto/exec_pb2.pyi +212 -0
  146. flwr/proto/exec_pb2_grpc.py +237 -0
  147. flwr/proto/exec_pb2_grpc.pyi +93 -0
  148. flwr/proto/fab_pb2.py +31 -0
  149. flwr/proto/fab_pb2.pyi +65 -0
  150. flwr/proto/fab_pb2_grpc.py +4 -0
  151. flwr/proto/fab_pb2_grpc.pyi +4 -0
  152. flwr/proto/fleet_pb2.py +42 -23
  153. flwr/proto/fleet_pb2.pyi +123 -1
  154. flwr/proto/fleet_pb2_grpc.py +170 -0
  155. flwr/proto/fleet_pb2_grpc.pyi +61 -0
  156. flwr/proto/grpcadapter_pb2.py +32 -0
  157. flwr/proto/grpcadapter_pb2.pyi +43 -0
  158. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  159. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  160. flwr/proto/log_pb2.py +29 -0
  161. flwr/proto/log_pb2.pyi +39 -0
  162. flwr/proto/log_pb2_grpc.py +4 -0
  163. flwr/proto/log_pb2_grpc.pyi +4 -0
  164. flwr/proto/message_pb2.py +41 -0
  165. flwr/proto/message_pb2.pyi +128 -0
  166. flwr/proto/message_pb2_grpc.py +4 -0
  167. flwr/proto/message_pb2_grpc.pyi +4 -0
  168. flwr/proto/node_pb2.py +2 -2
  169. flwr/proto/node_pb2.pyi +1 -4
  170. flwr/proto/recordset_pb2.py +35 -33
  171. flwr/proto/recordset_pb2.pyi +40 -14
  172. flwr/proto/run_pb2.py +64 -0
  173. flwr/proto/run_pb2.pyi +268 -0
  174. flwr/proto/run_pb2_grpc.py +4 -0
  175. flwr/proto/run_pb2_grpc.pyi +4 -0
  176. flwr/proto/serverappio_pb2.py +52 -0
  177. flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +62 -20
  178. flwr/proto/serverappio_pb2_grpc.py +410 -0
  179. flwr/proto/serverappio_pb2_grpc.pyi +160 -0
  180. flwr/proto/simulationio_pb2.py +38 -0
  181. flwr/proto/simulationio_pb2.pyi +65 -0
  182. flwr/proto/simulationio_pb2_grpc.py +239 -0
  183. flwr/proto/simulationio_pb2_grpc.pyi +94 -0
  184. flwr/proto/task_pb2.py +7 -8
  185. flwr/proto/task_pb2.pyi +8 -5
  186. flwr/proto/transport_pb2.py +8 -8
  187. flwr/proto/transport_pb2.pyi +9 -6
  188. flwr/server/__init__.py +2 -10
  189. flwr/server/app.py +579 -402
  190. flwr/server/client_manager.py +8 -6
  191. flwr/server/compat/app.py +6 -62
  192. flwr/server/compat/app_utils.py +14 -9
  193. flwr/server/compat/driver_client_proxy.py +25 -59
  194. flwr/server/compat/legacy_context.py +5 -4
  195. flwr/server/driver/__init__.py +2 -0
  196. flwr/server/driver/driver.py +36 -131
  197. flwr/server/driver/grpc_driver.py +220 -81
  198. flwr/server/driver/inmemory_driver.py +183 -0
  199. flwr/server/history.py +28 -29
  200. flwr/server/run_serverapp.py +15 -126
  201. flwr/server/server.py +50 -44
  202. flwr/server/server_app.py +59 -10
  203. flwr/server/serverapp/__init__.py +22 -0
  204. flwr/server/serverapp/app.py +256 -0
  205. flwr/server/serverapp_components.py +52 -0
  206. flwr/server/strategy/__init__.py +2 -2
  207. flwr/server/strategy/aggregate.py +37 -23
  208. flwr/server/strategy/bulyan.py +9 -9
  209. flwr/server/strategy/dp_adaptive_clipping.py +25 -25
  210. flwr/server/strategy/dp_fixed_clipping.py +23 -22
  211. flwr/server/strategy/dpfedavg_adaptive.py +8 -8
  212. flwr/server/strategy/dpfedavg_fixed.py +13 -12
  213. flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
  214. flwr/server/strategy/fedadagrad.py +9 -9
  215. flwr/server/strategy/fedadam.py +20 -10
  216. flwr/server/strategy/fedavg.py +16 -16
  217. flwr/server/strategy/fedavg_android.py +17 -17
  218. flwr/server/strategy/fedavgm.py +9 -9
  219. flwr/server/strategy/fedmedian.py +5 -5
  220. flwr/server/strategy/fedopt.py +6 -6
  221. flwr/server/strategy/fedprox.py +7 -7
  222. flwr/server/strategy/fedtrimmedavg.py +8 -8
  223. flwr/server/strategy/fedxgb_bagging.py +12 -12
  224. flwr/server/strategy/fedxgb_cyclic.py +10 -10
  225. flwr/server/strategy/fedxgb_nn_avg.py +6 -6
  226. flwr/server/strategy/fedyogi.py +9 -9
  227. flwr/server/strategy/krum.py +9 -9
  228. flwr/server/strategy/qfedavg.py +16 -16
  229. flwr/server/strategy/strategy.py +10 -10
  230. flwr/server/superlink/driver/__init__.py +2 -2
  231. flwr/server/superlink/driver/serverappio_grpc.py +61 -0
  232. flwr/server/superlink/driver/serverappio_servicer.py +361 -0
  233. flwr/server/superlink/ffs/__init__.py +24 -0
  234. flwr/server/superlink/ffs/disk_ffs.py +108 -0
  235. flwr/server/superlink/ffs/ffs.py +79 -0
  236. flwr/server/superlink/ffs/ffs_factory.py +47 -0
  237. flwr/server/superlink/fleet/__init__.py +1 -1
  238. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  239. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +162 -0
  240. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  241. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +4 -2
  242. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -2
  243. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  244. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -154
  245. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  246. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +120 -13
  247. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +228 -0
  248. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  249. flwr/server/superlink/fleet/message_handler/message_handler.py +156 -13
  250. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  251. flwr/server/superlink/fleet/rest_rere/rest_api.py +119 -81
  252. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  253. flwr/server/superlink/fleet/vce/backend/__init__.py +4 -4
  254. flwr/server/superlink/fleet/vce/backend/backend.py +8 -9
  255. flwr/server/superlink/fleet/vce/backend/raybackend.py +87 -68
  256. flwr/server/superlink/fleet/vce/vce_api.py +208 -146
  257. flwr/server/superlink/linkstate/__init__.py +28 -0
  258. flwr/server/superlink/linkstate/in_memory_linkstate.py +569 -0
  259. flwr/server/superlink/linkstate/linkstate.py +376 -0
  260. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +19 -10
  261. flwr/server/superlink/linkstate/sqlite_linkstate.py +1196 -0
  262. flwr/server/superlink/linkstate/utils.py +399 -0
  263. flwr/server/superlink/simulation/__init__.py +15 -0
  264. flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
  265. flwr/server/superlink/simulation/simulationio_servicer.py +186 -0
  266. flwr/server/superlink/utils.py +65 -0
  267. flwr/server/typing.py +2 -0
  268. flwr/server/utils/__init__.py +1 -1
  269. flwr/server/utils/tensorboard.py +5 -5
  270. flwr/server/utils/validator.py +40 -45
  271. flwr/server/workflow/default_workflows.py +70 -26
  272. flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -0
  273. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +40 -27
  274. flwr/simulation/__init__.py +12 -5
  275. flwr/simulation/app.py +247 -315
  276. flwr/simulation/legacy_app.py +404 -0
  277. flwr/simulation/ray_transport/__init__.py +1 -1
  278. flwr/simulation/ray_transport/ray_actor.py +42 -67
  279. flwr/simulation/ray_transport/ray_client_proxy.py +37 -17
  280. flwr/simulation/ray_transport/utils.py +1 -0
  281. flwr/simulation/run_simulation.py +306 -163
  282. flwr/simulation/simulationio_connection.py +89 -0
  283. flwr/superexec/__init__.py +15 -0
  284. flwr/superexec/app.py +59 -0
  285. flwr/superexec/deployment.py +188 -0
  286. flwr/superexec/exec_grpc.py +80 -0
  287. flwr/superexec/exec_servicer.py +231 -0
  288. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  289. flwr/superexec/executor.py +96 -0
  290. flwr/superexec/simulation.py +124 -0
  291. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.15.0.dev20250115.dist-info}/METADATA +33 -26
  292. flwr_nightly-1.15.0.dev20250115.dist-info/RECORD +328 -0
  293. flwr_nightly-1.15.0.dev20250115.dist-info/entry_points.txt +12 -0
  294. flwr/cli/flower_toml.py +0 -140
  295. flwr/cli/new/templates/app/flower.toml.tpl +0 -13
  296. flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
  297. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
  298. flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
  299. flwr/client/node_state.py +0 -48
  300. flwr/client/node_state_tests.py +0 -65
  301. flwr/proto/driver_pb2.py +0 -44
  302. flwr/proto/driver_pb2_grpc.py +0 -169
  303. flwr/proto/driver_pb2_grpc.pyi +0 -66
  304. flwr/server/superlink/driver/driver_grpc.py +0 -54
  305. flwr/server/superlink/driver/driver_servicer.py +0 -129
  306. flwr/server/superlink/state/in_memory_state.py +0 -230
  307. flwr/server/superlink/state/sqlite_state.py +0 -630
  308. flwr/server/superlink/state/state.py +0 -154
  309. flwr_nightly-1.8.0.dev20240315.dist-info/RECORD +0 -211
  310. flwr_nightly-1.8.0.dev20240315.dist-info/entry_points.txt +0 -9
  311. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.15.0.dev20250115.dist-info}/LICENSE +0 -0
  312. {flwr_nightly-1.8.0.dev20240315.dist-info → flwr_nightly-1.15.0.dev20250115.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1196 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """SQLite based implementation of the link state."""
16
+
17
+
18
+ # pylint: disable=too-many-lines
19
+
20
+ import json
21
+ import re
22
+ import sqlite3
23
+ import time
24
+ from collections.abc import Sequence
25
+ from logging import DEBUG, ERROR, WARNING
26
+ from typing import Any, Optional, Union, cast
27
+ from uuid import UUID, uuid4
28
+
29
+ from flwr.common import Context, log, now
30
+ from flwr.common.constant import (
31
+ MESSAGE_TTL_TOLERANCE,
32
+ NODE_ID_NUM_BYTES,
33
+ RUN_ID_NUM_BYTES,
34
+ SUPERLINK_NODE_ID,
35
+ Status,
36
+ )
37
+ from flwr.common.record import ConfigsRecord
38
+ from flwr.common.typing import Run, RunStatus, UserConfig
39
+
40
+ # pylint: disable=E0611
41
+ from flwr.proto.node_pb2 import Node
42
+ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
43
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
44
+
45
+ # pylint: enable=E0611
46
+ from flwr.server.utils.validator import validate_task_ins_or_res
47
+
48
+ from .linkstate import LinkState
49
+ from .utils import (
50
+ configsrecord_from_bytes,
51
+ configsrecord_to_bytes,
52
+ context_from_bytes,
53
+ context_to_bytes,
54
+ convert_sint64_to_uint64,
55
+ convert_sint64_values_in_dict_to_uint64,
56
+ convert_uint64_to_sint64,
57
+ convert_uint64_values_in_dict_to_sint64,
58
+ generate_rand_int_from_bytes,
59
+ has_valid_sub_status,
60
+ is_valid_transition,
61
+ verify_found_taskres,
62
+ verify_taskins_ids,
63
+ )
64
+
65
# DDL executed by `SqliteLinkState.initialize()`. All statements use
# IF NOT EXISTS so re-initializing an existing database file is a no-op.

# One row per registered node; `online_until` is indexed separately
# (see SQL_CREATE_INDEX_ONLINE_UNTIL below).
SQL_CREATE_TABLE_NODE = """
CREATE TABLE IF NOT EXISTS node(
    node_id INTEGER UNIQUE,
    online_until REAL,
    ping_interval REAL,
    public_key BLOB
);
"""

# Server key pair storage; keyed on the private key.
SQL_CREATE_TABLE_CREDENTIAL = """
CREATE TABLE IF NOT EXISTS credential(
    private_key BLOB PRIMARY KEY,
    public_key BLOB
);
"""

# Set of known public keys (presumably node keys — confirm against callers).
SQL_CREATE_TABLE_PUBLIC_KEY = """
CREATE TABLE IF NOT EXISTS public_key(
    public_key BLOB PRIMARY KEY
);
"""

# Index supporting liveness queries on the node table.
SQL_CREATE_INDEX_ONLINE_UNTIL = """
CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
"""

# One row per run. The `*_at` timestamp columns plus `sub_status`/`details`
# together encode the run status (cf. `RunStatus` and the transition helpers
# imported from `.utils`).
SQL_CREATE_TABLE_RUN = """
CREATE TABLE IF NOT EXISTS run(
    run_id INTEGER UNIQUE,
    fab_id TEXT,
    fab_version TEXT,
    fab_hash TEXT,
    override_config TEXT,
    pending_at TEXT,
    starting_at TEXT,
    running_at TEXT,
    finished_at TEXT,
    sub_status TEXT,
    details TEXT,
    federation_options BLOB
);
"""

# Append-only log lines, keyed by (timestamp, run, node).
SQL_CREATE_TABLE_LOGS = """
CREATE TABLE IF NOT EXISTS logs (
    timestamp REAL,
    run_id INTEGER,
    node_id INTEGER,
    log TEXT,
    PRIMARY KEY (timestamp, run_id, node_id),
    FOREIGN KEY (run_id) REFERENCES run(run_id)
);
"""

# Serialized `Context` blob, at most one per run (run_id is UNIQUE).
SQL_CREATE_TABLE_CONTEXT = """
CREATE TABLE IF NOT EXISTS context(
    run_id INTEGER UNIQUE,
    context BLOB,
    FOREIGN KEY(run_id) REFERENCES run(run_id)
);
"""

# task_ins and task_res share an identical schema; `ancestry` links a
# TaskRes back to the TaskIns it answers, and `delivered_at = ''` marks
# an undelivered row.
SQL_CREATE_TABLE_TASK_INS = """
CREATE TABLE IF NOT EXISTS task_ins(
    task_id TEXT UNIQUE,
    group_id TEXT,
    run_id INTEGER,
    producer_node_id INTEGER,
    consumer_node_id INTEGER,
    created_at REAL,
    delivered_at TEXT,
    pushed_at REAL,
    ttl REAL,
    ancestry TEXT,
    task_type TEXT,
    recordset BLOB,
    FOREIGN KEY(run_id) REFERENCES run(run_id)
);
"""

SQL_CREATE_TABLE_TASK_RES = """
CREATE TABLE IF NOT EXISTS task_res(
    task_id TEXT UNIQUE,
    group_id TEXT,
    run_id INTEGER,
    producer_node_id INTEGER,
    consumer_node_id INTEGER,
    created_at REAL,
    delivered_at TEXT,
    pushed_at REAL,
    ttl REAL,
    ancestry TEXT,
    task_type TEXT,
    recordset BLOB,
    FOREIGN KEY(run_id) REFERENCES run(run_id)
);
"""

# Shape of a single parameter set accepted by `SqliteLinkState.query`.
DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
164
+
165
+
166
+ class SqliteLinkState(LinkState): # pylint: disable=R0904
167
+ """SQLite-based LinkState implementation."""
168
+
169
    def __init__(
        self,
        database_path: str,
    ) -> None:
        """Initialize an SqliteLinkState.

        Parameters
        ----------
        database_path : str
            The path to the database file to be opened. Pass ":memory:" to open
            a connection to a database that is in RAM, instead of on disk.
        """
        self.database_path = database_path
        # The connection is opened lazily by `initialize()`; query helpers
        # raise if it is still None.
        self.conn: Optional[sqlite3.Connection] = None
183
+
184
    def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
        """Create tables if they don't exist yet.

        Parameters
        ----------
        log_queries : bool
            Log each query which is executed.

        Returns
        -------
        list[tuple[str]]
            The list of all tables in the DB.
        """
        self.conn = sqlite3.connect(self.database_path)
        # Enforce the FOREIGN KEY clauses declared in the DDL (off by
        # default in SQLite).
        self.conn.execute("PRAGMA foreign_keys = ON;")
        # Rows come back as dicts everywhere (see `query()` return type).
        self.conn.row_factory = dict_factory
        if log_queries:
            self.conn.set_trace_callback(lambda query: log(DEBUG, query))
        cur = self.conn.cursor()

        # Create each table if not exists queries
        cur.execute(SQL_CREATE_TABLE_RUN)
        cur.execute(SQL_CREATE_TABLE_LOGS)
        cur.execute(SQL_CREATE_TABLE_CONTEXT)
        cur.execute(SQL_CREATE_TABLE_TASK_INS)
        cur.execute(SQL_CREATE_TABLE_TASK_RES)
        cur.execute(SQL_CREATE_TABLE_NODE)
        cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
        cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
        cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
        # NOTE(review): `sqlite_schema` is the modern alias introduced in
        # SQLite 3.33 (2020); `sqlite_master` would also work on older
        # library versions — confirm the minimum supported SQLite version.
        res = cur.execute("SELECT name FROM sqlite_schema;")
        return res.fetchall()
216
+
217
+ def query(
218
+ self,
219
+ query: str,
220
+ data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
221
+ ) -> list[dict[str, Any]]:
222
+ """Execute a SQL query."""
223
+ if self.conn is None:
224
+ raise AttributeError("LinkState is not initialized.")
225
+
226
+ if data is None:
227
+ data = []
228
+
229
+ # Clean up whitespace to make the logs nicer
230
+ query = re.sub(r"\s+", " ", query)
231
+
232
+ try:
233
+ with self.conn:
234
+ if (
235
+ len(data) > 0
236
+ and isinstance(data, (tuple, list))
237
+ and isinstance(data[0], (tuple, dict))
238
+ ):
239
+ rows = self.conn.executemany(query, data)
240
+ else:
241
+ rows = self.conn.execute(query, data)
242
+
243
+ # Extract results before committing to support
244
+ # INSERT/UPDATE ... RETURNING
245
+ # style queries
246
+ result = rows.fetchall()
247
+ except KeyError as exc:
248
+ log(ERROR, {"query": query, "data": data, "exception": exc})
249
+
250
+ return result
251
+
252
    def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
        """Store one TaskIns.

        Usually, the ServerAppIo API calls this to schedule instructions.

        Stores the value of the task_ins in the link state and, if successful,
        returns the task_id (UUID) of the task_ins. If, for any reason, storing
        the task_ins fails, `None` is returned.

        Constraints
        -----------

        `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
        """
        # Validate task
        errors = validate_task_ins_or_res(task_ins)
        if any(errors):
            log(ERROR, errors)
            return None
        # Create task_id
        task_id = uuid4()

        # Store TaskIns
        # The generated UUID is written back into the proto before it is
        # flattened to a row dict.
        task_ins.task_id = str(task_id)
        data = (task_ins_to_dict(task_ins),)

        # Convert values from uint64 to sint64 for SQLite
        # (SQLite INTEGER is signed 64-bit; IDs are unsigned in the API).
        convert_uint64_values_in_dict_to_sint64(
            data[0], ["run_id", "producer_node_id", "consumer_node_id"]
        )

        # Validate run_id
        query = "SELECT run_id FROM run WHERE run_id = ?;"
        if not self.query(query, (data[0]["run_id"],)):
            log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
            return None
        # Validate source node ID
        # Only the SuperLink itself may produce a TaskIns.
        if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
            log(
                ERROR,
                "Invalid source node ID for TaskIns: %s",
                task_ins.task.producer.node_id,
            )
            return None
        # Validate destination node ID
        # The consumer must be a node currently registered in the node table.
        query = "SELECT node_id FROM node WHERE node_id = ?;"
        if not self.query(query, (data[0]["consumer_node_id"],)):
            log(
                ERROR,
                "Invalid destination node ID for TaskIns: %s",
                task_ins.task.consumer.node_id,
            )
            return None

        # Build a named-placeholder INSERT from the row dict's keys.
        columns = ", ".join([f":{key}" for key in data[0]])
        query = f"INSERT INTO task_ins VALUES({columns});"

        # Only invalid run_id can trigger IntegrityError.
        # This may need to be changed in the future version with more integrity checks.
        self.query(query, data)

        return task_id
314
+
315
    def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
        """Get undelivered TaskIns for one node.

        Usually, the Fleet API calls this for Nodes planning to work on one or more
        TaskIns.

        Constraints
        -----------
        Retrieve all TaskIns where

        1. the `task_ins.task.consumer.node_id` equals `node_id` AND
        2. the `task_ins.task.delivered_at` equals `""`.

        `delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
        the result.

        If `limit` is not `None`, return, at most, `limit` number of `task_ins`. If
        `limit` is set, it has to be greater than zero.
        """
        if limit is not None and limit < 1:
            raise AssertionError("`limit` must be >= 1")

        if node_id == SUPERLINK_NODE_ID:
            msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
            raise AssertionError(msg)

        data: dict[str, Union[str, int]] = {}

        # Convert the uint64 value to sint64 for SQLite
        data["node_id"] = convert_uint64_to_sint64(node_id)

        # Retrieve all TaskIns for node_id
        # Phase 1: select only IDs of undelivered, unexpired rows.
        query = """
            SELECT task_id
            FROM task_ins
            WHERE consumer_node_id == :node_id
            AND delivered_at = ""
            AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
        """

        if limit is not None:
            query += " LIMIT :limit"
            data["limit"] = limit

        query += ";"

        rows = self.query(query, data)

        if rows:
            # Prepare query
            # Phase 2: mark the selected rows delivered and fetch their full
            # contents in one statement via UPDATE ... RETURNING.
            task_ids = [row["task_id"] for row in rows]
            placeholders: str = ",".join([f":id_{i}" for i in range(len(task_ids))])
            query = f"""
                UPDATE task_ins
                SET delivered_at = :delivered_at
                WHERE task_id IN ({placeholders})
                RETURNING *;
            """

            # Prepare data for query
            delivered_at = now().isoformat()
            data = {"delivered_at": delivered_at}
            for index, task_id in enumerate(task_ids):
                data[f"id_{index}"] = str(task_id)

            # Run query
            rows = self.query(query, data)

        for row in rows:
            # Convert values from sint64 to uint64
            convert_sint64_values_in_dict_to_uint64(
                row, ["run_id", "producer_node_id", "consumer_node_id"]
            )

        result = [dict_to_task_ins(row) for row in rows]

        return result
392
+
393
    def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
        """Store one TaskRes.

        Usually, the Fleet API calls this when Nodes return their results.

        Stores the TaskRes and, if successful, returns the `task_id` (UUID) of
        the `task_res`. If storing the `task_res` fails, `None` is returned.

        Constraints
        -----------
        `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
        """
        # Validate task
        errors = validate_task_ins_or_res(task_res)
        if any(errors):
            log(ERROR, errors)
            return None

        # Create task_id
        task_id = uuid4()

        # The answered TaskIns is the first ancestry entry. NOTE(review):
        # presumably `validate_task_ins_or_res` guarantees ancestry is
        # non-empty — confirm, otherwise this indexing could raise.
        task_ins_id = task_res.task.ancestry[0]
        task_ins = self.get_valid_task_ins(task_ins_id)
        if task_ins is None:
            log(
                ERROR,
                "Failed to store TaskRes: "
                "TaskIns with task_id %s does not exist or has expired.",
                task_ins_id,
            )
            return None

        # Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
        # (task_ins holds sint64 DB values, hence the conversion before compare.)
        if (
            task_ins
            and task_res
            and convert_sint64_to_uint64(task_ins["consumer_node_id"])
            != task_res.task.producer.node_id
        ):
            return None

        # Fail if the TaskRes TTL exceeds the
        # expiration time of the TaskIns it replies to.
        # Condition: TaskIns.created_at + TaskIns.ttl ≥
        # TaskRes.created_at + TaskRes.ttl
        # A small tolerance is introduced to account
        # for floating-point precision issues.
        max_allowed_ttl = (
            task_ins["created_at"] + task_ins["ttl"] - task_res.task.created_at
        )
        if task_res.task.ttl and (
            task_res.task.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
        ):
            log(
                WARNING,
                "Received TaskRes with TTL %.2f "
                "exceeding the allowed maximum TTL %.2f.",
                task_res.task.ttl,
                max_allowed_ttl,
            )
            return None

        # Store TaskRes
        task_res.task_id = str(task_id)
        data = (task_res_to_dict(task_res),)

        # Convert values from uint64 to sint64 for SQLite
        convert_uint64_values_in_dict_to_sint64(
            data[0], ["run_id", "producer_node_id", "consumer_node_id"]
        )

        # Named-placeholder INSERT built from the row dict's keys.
        columns = ", ".join([f":{key}" for key in data[0]])
        query = f"INSERT INTO task_res VALUES({columns});"

        # Only invalid run_id can trigger IntegrityError.
        # This may need to be changed in the future version with more integrity checks.
        try:
            self.query(query, data)
        except sqlite3.IntegrityError:
            log(ERROR, "`run` is invalid")
            return None

        return task_id
476
+
477
    # pylint: disable-next=R0912,R0915,R0914
    def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
        """Get TaskRes for the given TaskIns IDs.

        For each inquired TaskIns ID, either the stored TaskRes answering it
        or an error TaskRes produced by the verification helpers is returned;
        every returned TaskRes is marked delivered.
        """
        ret: dict[UUID, TaskRes] = {}

        # Verify TaskIns IDs
        current = time.time()
        query = f"""
            SELECT *
            FROM task_ins
            WHERE task_id IN ({",".join(["?"] * len(task_ids))});
        """
        rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
        found_task_ins_dict: dict[UUID, TaskIns] = {}
        for row in rows:
            convert_sint64_values_in_dict_to_uint64(
                row, ["run_id", "producer_node_id", "consumer_node_id"]
            )
            found_task_ins_dict[UUID(row["task_id"])] = dict_to_task_ins(row)

        # Seeds `ret` with error replies for missing/expired TaskIns
        # (exact semantics defined in `.utils.verify_taskins_ids`).
        ret = verify_taskins_ids(
            inquired_taskins_ids=task_ids,
            found_taskins_dict=found_task_ins_dict,
            current_time=current,
        )

        # Find all TaskRes
        # `ancestry` stores the answered TaskIns ID; only undelivered
        # TaskRes rows (delivered_at = "") are considered.
        query = f"""
            SELECT *
            FROM task_res
            WHERE ancestry IN ({",".join(["?"] * len(task_ids))})
            AND delivered_at = "";
        """
        rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
        for row in rows:
            convert_sint64_values_in_dict_to_uint64(
                row, ["run_id", "producer_node_id", "consumer_node_id"]
            )
        tmp_ret_dict = verify_found_taskres(
            inquired_taskins_ids=task_ids,
            found_taskins_dict=found_task_ins_dict,
            found_taskres_list=[dict_to_task_res(row) for row in rows],
            current_time=current,
        )
        ret.update(tmp_ret_dict)

        # Mark existing TaskRes to be returned as delivered
        delivered_at = now().isoformat()
        for task_res in ret.values():
            task_res.task.delivered_at = delivered_at
        task_res_ids = [task_res.task_id for task_res in ret.values()]
        query = f"""
            UPDATE task_res
            SET delivered_at = ?
            WHERE task_id IN ({",".join(["?"] * len(task_res_ids))});
        """
        data: list[Any] = [delivered_at] + task_res_ids
        self.query(query, data)

        return list(ret.values())
537
+
538
+ def num_task_ins(self) -> int:
539
+ """Calculate the number of task_ins in store.
540
+
541
+ This includes delivered but not yet deleted task_ins.
542
+ """
543
+ query = "SELECT count(*) AS num FROM task_ins;"
544
+ rows = self.query(query)
545
+ result = rows[0]
546
+ num = cast(int, result["num"])
547
+ return num
548
+
549
+ def num_task_res(self) -> int:
550
+ """Calculate the number of task_res in store.
551
+
552
+ This includes delivered but not yet deleted task_res.
553
+ """
554
+ query = "SELECT count(*) AS num FROM task_res;"
555
+ rows = self.query(query)
556
+ result: dict[str, int] = rows[0]
557
+ return result["num"]
558
+
559
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
560
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
561
+ if not task_ins_ids:
562
+ return
563
+ if self.conn is None:
564
+ raise AttributeError("LinkState not initialized")
565
+
566
+ placeholders = ",".join(["?"] * len(task_ins_ids))
567
+ data = tuple(str(task_id) for task_id in task_ins_ids)
568
+
569
+ # Delete task_ins
570
+ query_1 = f"""
571
+ DELETE FROM task_ins
572
+ WHERE task_id IN ({placeholders});
573
+ """
574
+
575
+ # Delete task_res
576
+ query_2 = f"""
577
+ DELETE FROM task_res
578
+ WHERE ancestry IN ({placeholders});
579
+ """
580
+
581
+ with self.conn:
582
+ self.conn.execute(query_1, data)
583
+ self.conn.execute(query_2, data)
584
+
585
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
586
+ """Get all TaskIns IDs for the given run_id."""
587
+ if self.conn is None:
588
+ raise AttributeError("LinkState not initialized")
589
+
590
+ query = """
591
+ SELECT task_id
592
+ FROM task_ins
593
+ WHERE run_id = :run_id;
594
+ """
595
+
596
+ sint64_run_id = convert_uint64_to_sint64(run_id)
597
+ data = {"run_id": sint64_run_id}
598
+
599
+ with self.conn:
600
+ rows = self.conn.execute(query, data).fetchall()
601
+
602
+ return {UUID(row["task_id"]) for row in rows}
603
+
604
+ def create_node(self, ping_interval: float) -> int:
605
+ """Create, store in the link state, and return `node_id`."""
606
+ # Sample a random uint64 as node_id
607
+ uint64_node_id = generate_rand_int_from_bytes(
608
+ NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
609
+ )
610
+
611
+ # Convert the uint64 value to sint64 for SQLite
612
+ sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
613
+
614
+ query = (
615
+ "INSERT INTO node "
616
+ "(node_id, online_until, ping_interval, public_key) "
617
+ "VALUES (?, ?, ?, ?)"
618
+ )
619
+
620
+ try:
621
+ self.query(
622
+ query,
623
+ (
624
+ sint64_node_id,
625
+ time.time() + ping_interval,
626
+ ping_interval,
627
+ b"", # Initialize with an empty public key
628
+ ),
629
+ )
630
+ except sqlite3.IntegrityError:
631
+ log(ERROR, "Unexpected node registration failure.")
632
+ return 0
633
+
634
+ # Note: we need to return the uint64 value of the node_id
635
+ return uint64_node_id
636
+
637
    def delete_node(self, node_id: int) -> None:
        """Delete a node.

        Raises
        ------
        ValueError
            If no node with the given `node_id` exists.
        """
        # Convert the uint64 value to sint64 for SQLite
        sint64_node_id = convert_uint64_to_sint64(node_id)

        query = "DELETE FROM node WHERE node_id = ?"
        params = (sint64_node_id,)

        if self.conn is None:
            raise AttributeError("LinkState is not initialized.")

        try:
            with self.conn:
                rows = self.conn.execute(query, params)
                if rows.rowcount < 1:
                    raise ValueError(f"Node {node_id} not found")
        # NOTE(review): sqlite3 `execute` does not raise KeyError, so this
        # handler looks unreachable — confirm whether `sqlite3.Error` was
        # intended here instead.
        except KeyError as exc:
            log(ERROR, {"query": query, "data": params, "exception": exc})
655
+
656
+ def get_nodes(self, run_id: int) -> set[int]:
657
+ """Retrieve all currently stored node IDs as a set.
658
+
659
+ Constraints
660
+ -----------
661
+ If the provided `run_id` does not exist or has no matching nodes,
662
+ an empty `Set` MUST be returned.
663
+ """
664
+ # Convert the uint64 value to sint64 for SQLite
665
+ sint64_run_id = convert_uint64_to_sint64(run_id)
666
+
667
+ # Validate run ID
668
+ query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
669
+ if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
670
+ return set()
671
+
672
+ # Get nodes
673
+ query = "SELECT node_id FROM node WHERE online_until > ?;"
674
+ rows = self.query(query, (time.time(),))
675
+
676
+ # Convert sint64 node_ids to uint64
677
+ result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
678
+ return result
679
+
680
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
681
+ """Set `public_key` for the specified `node_id`."""
682
+ # Convert the uint64 value to sint64 for SQLite
683
+ sint64_node_id = convert_uint64_to_sint64(node_id)
684
+
685
+ # Check if the node exists in the `node` table
686
+ query = "SELECT 1 FROM node WHERE node_id = ?"
687
+ if not self.query(query, (sint64_node_id,)):
688
+ raise ValueError(f"Node {node_id} not found")
689
+
690
+ # Check if the public key is already in use in the `node` table
691
+ query = "SELECT 1 FROM node WHERE public_key = ?"
692
+ if self.query(query, (public_key,)):
693
+ raise ValueError("Public key already in use")
694
+
695
+ # Update the `node` table to set the public key for the given node ID
696
+ query = "UPDATE node SET public_key = ? WHERE node_id = ?"
697
+ self.query(query, (public_key, sint64_node_id))
698
+
699
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
700
+ """Get `public_key` for the specified `node_id`."""
701
+ # Convert the uint64 value to sint64 for SQLite
702
+ sint64_node_id = convert_uint64_to_sint64(node_id)
703
+
704
+ # Query the public key for the given node_id
705
+ query = "SELECT public_key FROM node WHERE node_id = ?"
706
+ rows = self.query(query, (sint64_node_id,))
707
+
708
+ # If no result is found, return None
709
+ if not rows:
710
+ raise ValueError(f"Node {node_id} not found")
711
+
712
+ # Return the public key if it is not empty, otherwise return None
713
+ return rows[0]["public_key"] or None
714
+
715
+ def get_node_id(self, node_public_key: bytes) -> Optional[int]:
716
+ """Retrieve stored `node_id` filtered by `node_public_keys`."""
717
+ query = "SELECT node_id FROM node WHERE public_key = :public_key;"
718
+ row = self.query(query, {"public_key": node_public_key})
719
+ if len(row) > 0:
720
+ node_id: int = row[0]["node_id"]
721
+
722
+ # Convert the sint64 value to uint64 after reading from SQLite
723
+ uint64_node_id = convert_sint64_to_uint64(node_id)
724
+
725
+ return uint64_node_id
726
+ return None
727
+
728
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def create_run(
        self,
        fab_id: Optional[str],
        fab_version: Optional[str],
        fab_hash: Optional[str],
        override_config: UserConfig,
        federation_options: ConfigsRecord,
    ) -> int:
        """Create a new run for the specified `fab_id` and `fab_version`.

        Returns the new uint64 `run_id`, or 0 if the randomly sampled ID
        collided with an existing run.
        """
        # Sample a random int64 as run_id
        uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

        # Convert the uint64 value to sint64 for SQLite
        sint64_run_id = convert_uint64_to_sint64(uint64_run_id)

        # Check conflicts
        query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
        # If sint64_run_id does not exist
        if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
            query = (
                "INSERT INTO run "
                "(run_id, fab_id, fab_version, fab_hash, override_config, "
                "federation_options, pending_at, starting_at, running_at, finished_at, "
                "sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
            )
            override_config_json = json.dumps(override_config)
            data = [
                sint64_run_id,
                fab_id,
                fab_version,
                fab_hash,
                override_config_json,
                configsrecord_to_bytes(federation_options),
            ]
            # Only `pending_at` is stamped at creation; the remaining
            # timestamp/status columns start out as empty strings.
            data += [
                now().isoformat(),
                "",
                "",
                "",
                "",
                "",
            ]
            self.query(query, tuple(data))
            return uint64_run_id
        log(ERROR, "Unexpected run creation failure.")
        return 0
775
+
776
+ def store_server_private_public_key(
777
+ self, private_key: bytes, public_key: bytes
778
+ ) -> None:
779
+ """Store `server_private_key` and `server_public_key` in the link state."""
780
+ query = "SELECT COUNT(*) FROM credential"
781
+ count = self.query(query)[0]["COUNT(*)"]
782
+ if count < 1:
783
+ query = (
784
+ "INSERT OR REPLACE INTO credential (private_key, public_key) "
785
+ "VALUES (:private_key, :public_key)"
786
+ )
787
+ self.query(query, {"private_key": private_key, "public_key": public_key})
788
+ else:
789
+ raise RuntimeError("Server private and public key already set")
790
+
791
+ def get_server_private_key(self) -> Optional[bytes]:
792
+ """Retrieve `server_private_key` in urlsafe bytes."""
793
+ query = "SELECT private_key FROM credential"
794
+ rows = self.query(query)
795
+ try:
796
+ private_key: Optional[bytes] = rows[0]["private_key"]
797
+ except IndexError:
798
+ private_key = None
799
+ return private_key
800
+
801
+ def get_server_public_key(self) -> Optional[bytes]:
802
+ """Retrieve `server_public_key` in urlsafe bytes."""
803
+ query = "SELECT public_key FROM credential"
804
+ rows = self.query(query)
805
+ try:
806
+ public_key: Optional[bytes] = rows[0]["public_key"]
807
+ except IndexError:
808
+ public_key = None
809
+ return public_key
810
+
811
+ def clear_supernode_auth_keys_and_credentials(self) -> None:
812
+ """Clear stored `node_public_keys` and credentials in the link state if any."""
813
+ queries = ["DELETE FROM public_key;", "DELETE FROM credential;"]
814
+ for query in queries:
815
+ self.query(query)
816
+
817
+ def store_node_public_keys(self, public_keys: set[bytes]) -> None:
818
+ """Store a set of `node_public_keys` in the link state."""
819
+ query = "INSERT INTO public_key (public_key) VALUES (?)"
820
+ data = [(key,) for key in public_keys]
821
+ self.query(query, data)
822
+
823
+ def store_node_public_key(self, public_key: bytes) -> None:
824
+ """Store a `node_public_key` in the link state."""
825
+ query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
826
+ self.query(query, {"public_key": public_key})
827
+
828
+ def get_node_public_keys(self) -> set[bytes]:
829
+ """Retrieve all currently stored `node_public_keys` as a set."""
830
+ query = "SELECT public_key FROM public_key"
831
+ rows = self.query(query)
832
+ result: set[bytes] = {row["public_key"] for row in rows}
833
+ return result
834
+
835
+ def get_run_ids(self) -> set[int]:
836
+ """Retrieve all run IDs."""
837
+ query = "SELECT run_id FROM run;"
838
+ rows = self.query(query)
839
+ return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
840
+
841
    def get_run(self, run_id: int) -> Optional[Run]:
        """Retrieve information about the run with the specified `run_id`.

        Returns None (and logs an error) if the run does not exist.
        """
        # Convert the uint64 value to sint64 for SQLite
        sint64_run_id = convert_uint64_to_sint64(run_id)
        query = "SELECT * FROM run WHERE run_id = ?;"
        rows = self.query(query, (sint64_run_id,))
        if rows:
            row = rows[0]
            return Run(
                run_id=convert_sint64_to_uint64(row["run_id"]),
                fab_id=row["fab_id"],
                fab_version=row["fab_version"],
                fab_hash=row["fab_hash"],
                override_config=json.loads(row["override_config"]),
                pending_at=row["pending_at"],
                starting_at=row["starting_at"],
                running_at=row["running_at"],
                finished_at=row["finished_at"],
                # The overall status is derived from the timestamp columns
                status=RunStatus(
                    status=determine_run_status(row),
                    sub_status=row["sub_status"],
                    details=row["details"],
                ),
            )
        log(ERROR, "`run_id` does not exist.")
        return None
867
+
868
    def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
        """Retrieve the statuses for the specified runs.

        Unknown run IDs are silently omitted from the returned mapping.
        """
        # Convert the uint64 value to sint64 for SQLite
        sint64_run_ids = (convert_uint64_to_sint64(run_id) for run_id in set(run_ids))
        query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
        rows = self.query(query, tuple(sint64_run_ids))

        return {
            # Restore uint64 run IDs
            convert_sint64_to_uint64(row["run_id"]): RunStatus(
                status=determine_run_status(row),
                sub_status=row["sub_status"],
                details=row["details"],
            )
            for row in rows
        }
884
+
885
    def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
        """Update the status of the run with the specified `run_id`.

        Returns True on success; returns False (and logs an error) if the run
        does not exist or the requested status transition is not allowed.
        """
        # Convert the uint64 value to sint64 for SQLite
        sint64_run_id = convert_uint64_to_sint64(run_id)
        query = "SELECT * FROM run WHERE run_id = ?;"
        rows = self.query(query, (sint64_run_id,))

        # Check if the run_id exists
        if not rows:
            log(ERROR, "`run_id` is invalid")
            return False

        # Check if the status transition is valid
        row = rows[0]
        current_status = RunStatus(
            status=determine_run_status(row),
            sub_status=row["sub_status"],
            details=row["details"],
        )
        if not is_valid_transition(current_status, new_status):
            log(
                ERROR,
                'Invalid status transition: from "%s" to "%s"',
                current_status.status,
                new_status.status,
            )
            return False

        # Check if the sub-status is valid
        if not has_valid_sub_status(current_status):
            log(
                ERROR,
                'Invalid sub-status "%s" for status "%s"',
                current_status.sub_status,
                current_status.status,
            )
            return False

        # Update the status; the %s placeholder is filled with the timestamp
        # column that corresponds to the new status.
        query = "UPDATE run SET %s= ?, sub_status = ?, details = ? "
        query += "WHERE run_id = ?;"

        # NOTE(review): if `new_status.status` matched none of the branches
        # below, the substitution would produce invalid SQL ("SET = ?") —
        # presumably prevented by `is_valid_transition`; confirm.
        timestamp_fld = ""
        if new_status.status == Status.STARTING:
            timestamp_fld = "starting_at"
        elif new_status.status == Status.RUNNING:
            timestamp_fld = "running_at"
        elif new_status.status == Status.FINISHED:
            timestamp_fld = "finished_at"

        data = (
            now().isoformat(),
            new_status.sub_status,
            new_status.details,
            sint64_run_id,
        )
        self.query(query % timestamp_fld, data)
        return True
943
+
944
+ def get_pending_run_id(self) -> Optional[int]:
945
+ """Get the `run_id` of a run with `Status.PENDING` status, if any."""
946
+ pending_run_id = None
947
+
948
+ # Fetch all runs with unset `starting_at` (i.e. they are in PENDING status)
949
+ query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
950
+ rows = self.query(query)
951
+ if rows:
952
+ pending_run_id = convert_sint64_to_uint64(rows[0]["run_id"])
953
+
954
+ return pending_run_id
955
+
956
+ def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
957
+ """Retrieve the federation options for the specified `run_id`."""
958
+ # Convert the uint64 value to sint64 for SQLite
959
+ sint64_run_id = convert_uint64_to_sint64(run_id)
960
+ query = "SELECT federation_options FROM run WHERE run_id = ?;"
961
+ rows = self.query(query, (sint64_run_id,))
962
+
963
+ # Check if the run_id exists
964
+ if not rows:
965
+ log(ERROR, "`run_id` is invalid")
966
+ return None
967
+
968
+ row = rows[0]
969
+ return configsrecord_from_bytes(row["federation_options"])
970
+
971
+ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
972
+ """Acknowledge a ping received from a node, serving as a heartbeat."""
973
+ sint64_node_id = convert_uint64_to_sint64(node_id)
974
+
975
+ # Check if the node exists in the `node` table
976
+ query = "SELECT 1 FROM node WHERE node_id = ?"
977
+ if not self.query(query, (sint64_node_id,)):
978
+ return False
979
+
980
+ # Update `online_until` and `ping_interval` for the given `node_id`
981
+ query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
982
+ self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
983
+ return True
984
+
985
+ def get_serverapp_context(self, run_id: int) -> Optional[Context]:
986
+ """Get the context for the specified `run_id`."""
987
+ # Retrieve context if any
988
+ query = "SELECT context FROM context WHERE run_id = ?;"
989
+ rows = self.query(query, (convert_uint64_to_sint64(run_id),))
990
+ context = context_from_bytes(rows[0]["context"]) if rows else None
991
+ return context
992
+
993
+ def set_serverapp_context(self, run_id: int, context: Context) -> None:
994
+ """Set the context for the specified `run_id`."""
995
+ # Convert context to bytes
996
+ context_bytes = context_to_bytes(context)
997
+ sint_run_id = convert_uint64_to_sint64(run_id)
998
+
999
+ # Check if any existing Context assigned to the run_id
1000
+ query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
1001
+ if self.query(query, (sint_run_id,))[0]["COUNT(*)"] > 0:
1002
+ # Update context
1003
+ query = "UPDATE context SET context = ? WHERE run_id = ?;"
1004
+ self.query(query, (context_bytes, sint_run_id))
1005
+ else:
1006
+ try:
1007
+ # Store context
1008
+ query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
1009
+ self.query(query, (sint_run_id, context_bytes))
1010
+ except sqlite3.IntegrityError:
1011
+ raise ValueError(f"Run {run_id} not found") from None
1012
+
1013
+ def add_serverapp_log(self, run_id: int, log_message: str) -> None:
1014
+ """Add a log entry to the ServerApp logs for the specified `run_id`."""
1015
+ # Convert the uint64 value to sint64 for SQLite
1016
+ sint64_run_id = convert_uint64_to_sint64(run_id)
1017
+
1018
+ # Store log
1019
+ try:
1020
+ query = """
1021
+ INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?);
1022
+ """
1023
+ self.query(query, (now().timestamp(), sint64_run_id, 0, log_message))
1024
+ except sqlite3.IntegrityError:
1025
+ raise ValueError(f"Run {run_id} not found") from None
1026
+
1027
+ def get_serverapp_log(
1028
+ self, run_id: int, after_timestamp: Optional[float]
1029
+ ) -> tuple[str, float]:
1030
+ """Get the ServerApp logs for the specified `run_id`."""
1031
+ # Convert the uint64 value to sint64 for SQLite
1032
+ sint64_run_id = convert_uint64_to_sint64(run_id)
1033
+
1034
+ # Check if the run_id exists
1035
+ query = "SELECT run_id FROM run WHERE run_id = ?;"
1036
+ if not self.query(query, (sint64_run_id,)):
1037
+ raise ValueError(f"Run {run_id} not found")
1038
+
1039
+ # Retrieve logs
1040
+ if after_timestamp is None:
1041
+ after_timestamp = 0.0
1042
+ query = """
1043
+ SELECT log, timestamp FROM logs
1044
+ WHERE run_id = ? AND node_id = ? AND timestamp > ?;
1045
+ """
1046
+ rows = self.query(query, (sint64_run_id, 0, after_timestamp))
1047
+ rows.sort(key=lambda x: x["timestamp"])
1048
+ latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
1049
+ return "".join(row["log"] for row in rows), latest_timestamp
1050
+
1051
+ def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
1052
+ """Check if the TaskIns exists and is valid (not expired).
1053
+
1054
+ Return TaskIns if valid.
1055
+ """
1056
+ query = """
1057
+ SELECT *
1058
+ FROM task_ins
1059
+ WHERE task_id = :task_id
1060
+ """
1061
+ data = {"task_id": task_id}
1062
+ rows = self.query(query, data)
1063
+ if not rows:
1064
+ # TaskIns does not exist
1065
+ return None
1066
+
1067
+ task_ins = rows[0]
1068
+ created_at = task_ins["created_at"]
1069
+ ttl = task_ins["ttl"]
1070
+ current_time = time.time()
1071
+
1072
+ # Check if TaskIns is expired
1073
+ if ttl is not None and created_at + ttl <= current_time:
1074
+ return None
1075
+
1076
+ return task_ins
1077
+
1078
+
1079
def dict_factory(
    cursor: sqlite3.Cursor,
    row: sqlite3.Row,
) -> dict[str, Any]:
    """Turn SQLite results into dicts.

    Less efficient for retrieval of large amounts of data but easier to use.
    """
    column_names = (column[0] for column in cursor.description)
    return dict(zip(column_names, row))
1089
+
1090
+
1091
def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
    """Transform TaskIns to dict."""
    task = task_msg.task
    return {
        "task_id": task_msg.task_id,
        "group_id": task_msg.group_id,
        "run_id": task_msg.run_id,
        "producer_node_id": task.producer.node_id,
        "consumer_node_id": task.consumer.node_id,
        "created_at": task.created_at,
        "delivered_at": task.delivered_at,
        "pushed_at": task.pushed_at,
        "ttl": task.ttl,
        # Ancestry is flattened to a comma-separated string column
        "ancestry": ",".join(task.ancestry),
        "task_type": task.task_type,
        # The record set is stored as a serialized protobuf blob
        "recordset": task.recordset.SerializeToString(),
    }
1108
+
1109
+
1110
def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
    """Transform TaskRes to dict."""
    task = task_msg.task
    return {
        "task_id": task_msg.task_id,
        "group_id": task_msg.group_id,
        "run_id": task_msg.run_id,
        "producer_node_id": task.producer.node_id,
        "consumer_node_id": task.consumer.node_id,
        "created_at": task.created_at,
        "delivered_at": task.delivered_at,
        "pushed_at": task.pushed_at,
        "ttl": task.ttl,
        # Ancestry is flattened to a comma-separated string column
        "ancestry": ",".join(task.ancestry),
        "task_type": task.task_type,
        # The record set is stored as a serialized protobuf blob
        "recordset": task.recordset.SerializeToString(),
    }
1127
+
1128
+
1129
def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
    """Turn task_dict into protobuf message."""
    # Deserialize the stored record-set blob
    recordset = ProtoRecordSet()
    recordset.ParseFromString(task_dict["recordset"])

    task = Task(
        producer=Node(node_id=task_dict["producer_node_id"]),
        consumer=Node(node_id=task_dict["consumer_node_id"]),
        created_at=task_dict["created_at"],
        delivered_at=task_dict["delivered_at"],
        pushed_at=task_dict["pushed_at"],
        ttl=task_dict["ttl"],
        # The ancestry column holds a comma-separated list
        ancestry=task_dict["ancestry"].split(","),
        task_type=task_dict["task_type"],
        recordset=recordset,
    )
    return TaskIns(
        task_id=task_dict["task_id"],
        group_id=task_dict["group_id"],
        run_id=task_dict["run_id"],
        task=task,
    )
1155
+
1156
+
1157
def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
    """Turn task_dict into protobuf message."""
    # Deserialize the stored record-set blob
    recordset = ProtoRecordSet()
    recordset.ParseFromString(task_dict["recordset"])

    task = Task(
        producer=Node(node_id=task_dict["producer_node_id"]),
        consumer=Node(node_id=task_dict["consumer_node_id"]),
        created_at=task_dict["created_at"],
        delivered_at=task_dict["delivered_at"],
        pushed_at=task_dict["pushed_at"],
        ttl=task_dict["ttl"],
        # The ancestry column holds a comma-separated list
        ancestry=task_dict["ancestry"].split(","),
        task_type=task_dict["task_type"],
        recordset=recordset,
    )
    return TaskRes(
        task_id=task_dict["task_id"],
        group_id=task_dict["group_id"],
        run_id=task_dict["run_id"],
        task=task,
    )
1183
+
1184
+
1185
def determine_run_status(row: dict[str, Any]) -> str:
    """Determine the status of the run based on timestamp fields."""
    # A run without `pending_at` was never created properly
    if not row["pending_at"]:
        run_id = convert_sint64_to_uint64(row["run_id"])
        raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")
    # Guard clauses, most final state first
    if row["finished_at"]:
        return Status.FINISHED
    if not row["starting_at"]:
        return Status.PENDING
    if row["running_at"]:
        return Status.RUNNING
    return Status.STARTING