flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (92)
  1. flwr/cli/build.py +2 -2
  2. flwr/cli/config_utils.py +97 -0
  3. flwr/cli/log.py +63 -97
  4. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
  6. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  7. flwr/cli/run/run.py +34 -88
  8. flwr/client/app.py +23 -20
  9. flwr/client/clientapp/app.py +22 -18
  10. flwr/client/nodestate/__init__.py +25 -0
  11. flwr/client/nodestate/in_memory_nodestate.py +38 -0
  12. flwr/client/nodestate/nodestate.py +30 -0
  13. flwr/client/nodestate/nodestate_factory.py +37 -0
  14. flwr/client/{node_state.py → run_info_store.py} +4 -3
  15. flwr/client/supernode/app.py +6 -8
  16. flwr/common/args.py +83 -0
  17. flwr/common/config.py +10 -0
  18. flwr/common/constant.py +39 -5
  19. flwr/common/context.py +9 -4
  20. flwr/common/date.py +3 -3
  21. flwr/common/logger.py +108 -1
  22. flwr/common/object_ref.py +47 -16
  23. flwr/common/serde.py +24 -0
  24. flwr/common/telemetry.py +0 -6
  25. flwr/common/typing.py +10 -1
  26. flwr/proto/exec_pb2.py +14 -17
  27. flwr/proto/exec_pb2.pyi +14 -22
  28. flwr/proto/log_pb2.py +29 -0
  29. flwr/proto/log_pb2.pyi +39 -0
  30. flwr/proto/log_pb2_grpc.py +4 -0
  31. flwr/proto/log_pb2_grpc.pyi +4 -0
  32. flwr/proto/message_pb2.py +8 -8
  33. flwr/proto/message_pb2.pyi +4 -1
  34. flwr/proto/run_pb2.py +32 -27
  35. flwr/proto/run_pb2.pyi +26 -0
  36. flwr/proto/serverappio_pb2.py +52 -0
  37. flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
  38. flwr/proto/serverappio_pb2_grpc.py +376 -0
  39. flwr/proto/serverappio_pb2_grpc.pyi +147 -0
  40. flwr/proto/simulationio_pb2.py +38 -0
  41. flwr/proto/simulationio_pb2.pyi +65 -0
  42. flwr/proto/simulationio_pb2_grpc.py +205 -0
  43. flwr/proto/simulationio_pb2_grpc.pyi +81 -0
  44. flwr/server/app.py +272 -105
  45. flwr/server/driver/driver.py +15 -1
  46. flwr/server/driver/grpc_driver.py +25 -36
  47. flwr/server/driver/inmemory_driver.py +6 -16
  48. flwr/server/run_serverapp.py +29 -23
  49. flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
  50. flwr/server/serverapp/app.py +214 -0
  51. flwr/server/strategy/aggregate.py +4 -4
  52. flwr/server/strategy/fedadam.py +11 -1
  53. flwr/server/superlink/driver/__init__.py +1 -1
  54. flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
  55. flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
  56. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
  57. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
  58. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
  59. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
  60. flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
  61. flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
  62. flwr/server/superlink/fleet/vce/vce_api.py +23 -23
  63. flwr/server/superlink/linkstate/__init__.py +28 -0
  64. flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +184 -36
  65. flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +149 -19
  66. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
  67. flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +306 -65
  68. flwr/server/superlink/{state → linkstate}/utils.py +81 -30
  69. flwr/server/superlink/simulation/__init__.py +15 -0
  70. flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
  71. flwr/server/superlink/simulation/simulationio_servicer.py +153 -0
  72. flwr/simulation/__init__.py +5 -1
  73. flwr/simulation/app.py +273 -345
  74. flwr/simulation/legacy_app.py +382 -0
  75. flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
  76. flwr/simulation/run_simulation.py +57 -131
  77. flwr/simulation/simulationio_connection.py +86 -0
  78. flwr/superexec/app.py +6 -134
  79. flwr/superexec/deployment.py +61 -66
  80. flwr/superexec/exec_grpc.py +15 -8
  81. flwr/superexec/exec_servicer.py +36 -65
  82. flwr/superexec/executor.py +26 -7
  83. flwr/superexec/simulation.py +54 -107
  84. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/METADATA +5 -4
  85. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/RECORD +88 -69
  86. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/entry_points.txt +2 -0
  87. flwr/client/node_state_tests.py +0 -66
  88. flwr/proto/driver_pb2.py +0 -42
  89. flwr/proto/driver_pb2_grpc.py +0 -239
  90. flwr/proto/driver_pb2_grpc.pyi +0 -94
  91. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/LICENSE +0 -0
  92. {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/WHEEL +0 -0
@@ -12,39 +12,51 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """SQLite based implemenation of server state."""
15
+ """SQLite based implemenation of the link state."""
16
16
 
17
17
  # pylint: disable=too-many-lines
18
18
 
19
19
  import json
20
20
  import re
21
21
  import sqlite3
22
+ import threading
22
23
  import time
23
24
  from collections.abc import Sequence
24
25
  from logging import DEBUG, ERROR, WARNING
25
26
  from typing import Any, Optional, Union, cast
26
27
  from uuid import UUID, uuid4
27
28
 
28
- from flwr.common import log, now
29
+ from flwr.common import Context, log, now
29
30
  from flwr.common.constant import (
30
31
  MESSAGE_TTL_TOLERANCE,
31
32
  NODE_ID_NUM_BYTES,
32
33
  RUN_ID_NUM_BYTES,
34
+ Status,
33
35
  )
34
- from flwr.common.typing import Run, UserConfig
35
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
36
- from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
37
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
36
+ from flwr.common.record import ConfigsRecord
37
+ from flwr.common.typing import Run, RunStatus, UserConfig
38
+
39
+ # pylint: disable=E0611
40
+ from flwr.proto.node_pb2 import Node
41
+ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
42
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
43
+
44
+ # pylint: enable=E0611
38
45
  from flwr.server.utils.validator import validate_task_ins_or_res
39
46
 
40
- from .state import State
47
+ from .linkstate import LinkState
41
48
  from .utils import (
49
+ configsrecord_from_bytes,
50
+ configsrecord_to_bytes,
51
+ context_from_bytes,
52
+ context_to_bytes,
42
53
  convert_sint64_to_uint64,
43
54
  convert_sint64_values_in_dict_to_uint64,
44
55
  convert_uint64_to_sint64,
45
56
  convert_uint64_values_in_dict_to_sint64,
46
57
  generate_rand_int_from_bytes,
47
- make_node_unavailable_taskres,
58
+ has_valid_sub_status,
59
+ is_valid_transition,
48
60
  )
49
61
 
50
62
  SQL_CREATE_TABLE_NODE = """
@@ -79,7 +91,33 @@ CREATE TABLE IF NOT EXISTS run(
79
91
  fab_id TEXT,
80
92
  fab_version TEXT,
81
93
  fab_hash TEXT,
82
- override_config TEXT
94
+ override_config TEXT,
95
+ pending_at TEXT,
96
+ starting_at TEXT,
97
+ running_at TEXT,
98
+ finished_at TEXT,
99
+ sub_status TEXT,
100
+ details TEXT,
101
+ federation_options BLOB
102
+ );
103
+ """
104
+
105
+ SQL_CREATE_TABLE_LOGS = """
106
+ CREATE TABLE IF NOT EXISTS logs (
107
+ timestamp REAL,
108
+ run_id INTEGER,
109
+ node_id INTEGER,
110
+ log TEXT,
111
+ PRIMARY KEY (timestamp, run_id, node_id),
112
+ FOREIGN KEY (run_id) REFERENCES run(run_id)
113
+ );
114
+ """
115
+
116
+ SQL_CREATE_TABLE_CONTEXT = """
117
+ CREATE TABLE IF NOT EXISTS context(
118
+ run_id INTEGER UNIQUE,
119
+ context BLOB,
120
+ FOREIGN KEY(run_id) REFERENCES run(run_id)
83
121
  );
84
122
  """
85
123
 
@@ -126,14 +164,14 @@ CREATE TABLE IF NOT EXISTS task_res(
126
164
  DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
127
165
 
128
166
 
129
- class SqliteState(State): # pylint: disable=R0904
130
- """SQLite-based state implementation."""
167
+ class SqliteLinkState(LinkState): # pylint: disable=R0904
168
+ """SQLite-based LinkState implementation."""
131
169
 
132
170
  def __init__(
133
171
  self,
134
172
  database_path: str,
135
173
  ) -> None:
136
- """Initialize an SqliteState.
174
+ """Initialize an SqliteLinkState.
137
175
 
138
176
  Parameters
139
177
  ----------
@@ -143,6 +181,7 @@ class SqliteState(State): # pylint: disable=R0904
143
181
  """
144
182
  self.database_path = database_path
145
183
  self.conn: Optional[sqlite3.Connection] = None
184
+ self.lock = threading.RLock()
146
185
 
147
186
  def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
148
187
  """Create tables if they don't exist yet.
@@ -166,6 +205,8 @@ class SqliteState(State): # pylint: disable=R0904
166
205
 
167
206
  # Create each table if not exists queries
168
207
  cur.execute(SQL_CREATE_TABLE_RUN)
208
+ cur.execute(SQL_CREATE_TABLE_LOGS)
209
+ cur.execute(SQL_CREATE_TABLE_CONTEXT)
169
210
  cur.execute(SQL_CREATE_TABLE_TASK_INS)
170
211
  cur.execute(SQL_CREATE_TABLE_TASK_RES)
171
212
  cur.execute(SQL_CREATE_TABLE_NODE)
@@ -183,7 +224,7 @@ class SqliteState(State): # pylint: disable=R0904
183
224
  ) -> list[dict[str, Any]]:
184
225
  """Execute a SQL query."""
185
226
  if self.conn is None:
186
- raise AttributeError("State is not initialized.")
227
+ raise AttributeError("LinkState is not initialized.")
187
228
 
188
229
  if data is None:
189
230
  data = []
@@ -214,11 +255,11 @@ class SqliteState(State): # pylint: disable=R0904
214
255
  def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
215
256
  """Store one TaskIns.
216
257
 
217
- Usually, the Driver API calls this to schedule instructions.
258
+ Usually, the ServerAppIo API calls this to schedule instructions.
218
259
 
219
- Stores the value of the task_ins in the state and, if successful, returns the
220
- task_id (UUID) of the task_ins. If, for any reason, storing the task_ins fails,
221
- `None` is returned.
260
+ Stores the value of the task_ins in the link state and, if successful,
261
+ returns the task_id (UUID) of the task_ins. If, for any reason, storing
262
+ the task_ins fails, `None` is returned.
222
263
 
223
264
  Constraints
224
265
  -----------
@@ -233,7 +274,6 @@ class SqliteState(State): # pylint: disable=R0904
233
274
  if any(errors):
234
275
  log(ERROR, errors)
235
276
  return None
236
-
237
277
  # Create task_id
238
278
  task_id = uuid4()
239
279
 
@@ -246,16 +286,36 @@ class SqliteState(State): # pylint: disable=R0904
246
286
  data[0], ["run_id", "producer_node_id", "consumer_node_id"]
247
287
  )
248
288
 
289
+ # Validate run_id
290
+ query = "SELECT run_id FROM run WHERE run_id = ?;"
291
+ if not self.query(query, (data[0]["run_id"],)):
292
+ log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
293
+ return None
294
+ # Validate source node ID
295
+ if task_ins.task.producer.node_id != 0:
296
+ log(
297
+ ERROR,
298
+ "Invalid source node ID for TaskIns: %s",
299
+ task_ins.task.producer.node_id,
300
+ )
301
+ return None
302
+ # Validate destination node ID
303
+ query = "SELECT node_id FROM node WHERE node_id = ?;"
304
+ if not task_ins.task.consumer.anonymous:
305
+ if not self.query(query, (data[0]["consumer_node_id"],)):
306
+ log(
307
+ ERROR,
308
+ "Invalid destination node ID for TaskIns: %s",
309
+ task_ins.task.consumer.node_id,
310
+ )
311
+ return None
312
+
249
313
  columns = ", ".join([f":{key}" for key in data[0]])
250
314
  query = f"INSERT INTO task_ins VALUES({columns});"
251
315
 
252
316
  # Only invalid run_id can trigger IntegrityError.
253
317
  # This may need to be changed in the future version with more integrity checks.
254
- try:
255
- self.query(query, data)
256
- except sqlite3.IntegrityError:
257
- log(ERROR, "`run` is invalid")
258
- return None
318
+ self.query(query, data)
259
319
 
260
320
  return task_id
261
321
 
@@ -452,8 +512,8 @@ class SqliteState(State): # pylint: disable=R0904
452
512
  def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
453
513
  """Get TaskRes for task_ids.
454
514
 
455
- Usually, the Driver API calls this method to get results for instructions it has
456
- previously scheduled.
515
+ Usually, the ServerAppIo API calls this method to get results for instructions
516
+ it has previously scheduled.
457
517
 
458
518
  Retrieves all TaskRes for the given `task_ids` and returns and empty list if
459
519
  none could be found.
@@ -579,20 +639,6 @@ class SqliteState(State): # pylint: disable=R0904
579
639
  data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
580
640
  task_ins_rows = self.query(query, data)
581
641
 
582
- # Make TaskRes containing node unavailabe error
583
- for row in task_ins_rows:
584
- for row in rows:
585
- # Convert values from sint64 to uint64
586
- convert_sint64_values_in_dict_to_uint64(
587
- row, ["run_id", "producer_node_id", "consumer_node_id"]
588
- )
589
-
590
- task_ins = dict_to_task_ins(row)
591
- err_taskres = make_node_unavailable_taskres(
592
- ref_taskins=task_ins,
593
- )
594
- result.append(err_taskres)
595
-
596
642
  return result
597
643
 
598
644
  def num_task_ins(self) -> int:
@@ -645,7 +691,7 @@ class SqliteState(State): # pylint: disable=R0904
645
691
  """
646
692
 
647
693
  if self.conn is None:
648
- raise AttributeError("State not intitialized")
694
+ raise AttributeError("LinkState not intitialized")
649
695
 
650
696
  with self.conn:
651
697
  self.conn.execute(query_1, data)
@@ -656,7 +702,7 @@ class SqliteState(State): # pylint: disable=R0904
656
702
  def create_node(
657
703
  self, ping_interval: float, public_key: Optional[bytes] = None
658
704
  ) -> int:
659
- """Create, store in state, and return `node_id`."""
705
+ """Create, store in the link state, and return `node_id`."""
660
706
  # Sample a random uint64 as node_id
661
707
  uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
662
708
 
@@ -706,7 +752,7 @@ class SqliteState(State): # pylint: disable=R0904
706
752
  params += (public_key,) # type: ignore
707
753
 
708
754
  if self.conn is None:
709
- raise AttributeError("State is not initialized.")
755
+ raise AttributeError("LinkState is not initialized.")
710
756
 
711
757
  try:
712
758
  with self.conn:
@@ -753,12 +799,14 @@ class SqliteState(State): # pylint: disable=R0904
753
799
  return uint64_node_id
754
800
  return None
755
801
 
802
+ # pylint: disable=too-many-arguments,too-many-positional-arguments
756
803
  def create_run(
757
804
  self,
758
805
  fab_id: Optional[str],
759
806
  fab_version: Optional[str],
760
807
  fab_hash: Optional[str],
761
808
  override_config: UserConfig,
809
+ federation_options: ConfigsRecord,
762
810
  ) -> int:
763
811
  """Create a new run for the specified `fab_id` and `fab_version`."""
764
812
  # Sample a random int64 as run_id
@@ -773,26 +821,30 @@ class SqliteState(State): # pylint: disable=R0904
773
821
  if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
774
822
  query = (
775
823
  "INSERT INTO run "
776
- "(run_id, fab_id, fab_version, fab_hash, override_config)"
777
- "VALUES (?, ?, ?, ?, ?);"
824
+ "(run_id, fab_id, fab_version, fab_hash, override_config, "
825
+ "federation_options, pending_at, starting_at, running_at, finished_at, "
826
+ "sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
778
827
  )
779
828
  if fab_hash:
780
- self.query(
781
- query,
782
- (sint64_run_id, "", "", fab_hash, json.dumps(override_config)),
783
- )
784
- else:
785
- self.query(
786
- query,
787
- (
788
- sint64_run_id,
789
- fab_id,
790
- fab_version,
791
- "",
792
- json.dumps(override_config),
793
- ),
794
- )
795
- # Note: we need to return the uint64 value of the run_id
829
+ fab_id, fab_version = "", ""
830
+ override_config_json = json.dumps(override_config)
831
+ data = [
832
+ sint64_run_id,
833
+ fab_id,
834
+ fab_version,
835
+ fab_hash,
836
+ override_config_json,
837
+ configsrecord_to_bytes(federation_options),
838
+ ]
839
+ data += [
840
+ now().isoformat(),
841
+ "",
842
+ "",
843
+ "",
844
+ "",
845
+ "",
846
+ ]
847
+ self.query(query, tuple(data))
796
848
  return uint64_run_id
797
849
  log(ERROR, "Unexpected run creation failure.")
798
850
  return 0
@@ -800,7 +852,7 @@ class SqliteState(State): # pylint: disable=R0904
800
852
  def store_server_private_public_key(
801
853
  self, private_key: bytes, public_key: bytes
802
854
  ) -> None:
803
- """Store `server_private_key` and `server_public_key` in state."""
855
+ """Store `server_private_key` and `server_public_key` in the link state."""
804
856
  query = "SELECT COUNT(*) FROM credential"
805
857
  count = self.query(query)[0]["COUNT(*)"]
806
858
  if count < 1:
@@ -833,13 +885,13 @@ class SqliteState(State): # pylint: disable=R0904
833
885
  return public_key
834
886
 
835
887
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
836
- """Store a set of `node_public_keys` in state."""
888
+ """Store a set of `node_public_keys` in the link state."""
837
889
  query = "INSERT INTO public_key (public_key) VALUES (?)"
838
890
  data = [(key,) for key in public_keys]
839
891
  self.query(query, data)
840
892
 
841
893
  def store_node_public_key(self, public_key: bytes) -> None:
842
- """Store a `node_public_key` in state."""
894
+ """Store a `node_public_key` in the link state."""
843
895
  query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
844
896
  self.query(query, {"public_key": public_key})
845
897
 
@@ -850,6 +902,12 @@ class SqliteState(State): # pylint: disable=R0904
850
902
  result: set[bytes] = {row["public_key"] for row in rows}
851
903
  return result
852
904
 
905
+ def get_run_ids(self) -> set[int]:
906
+ """Retrieve all run IDs."""
907
+ query = "SELECT run_id FROM run;"
908
+ rows = self.query(query)
909
+ return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
910
+
853
911
  def get_run(self, run_id: int) -> Optional[Run]:
854
912
  """Retrieve information about the run with the specified `run_id`."""
855
913
  # Convert the uint64 value to sint64 for SQLite
@@ -868,6 +926,109 @@ class SqliteState(State): # pylint: disable=R0904
868
926
  log(ERROR, "`run_id` does not exist.")
869
927
  return None
870
928
 
929
+ def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
930
+ """Retrieve the statuses for the specified runs."""
931
+ # Convert the uint64 value to sint64 for SQLite
932
+ sint64_run_ids = (convert_uint64_to_sint64(run_id) for run_id in set(run_ids))
933
+ query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
934
+ rows = self.query(query, tuple(sint64_run_ids))
935
+
936
+ return {
937
+ # Restore uint64 run IDs
938
+ convert_sint64_to_uint64(row["run_id"]): RunStatus(
939
+ status=determine_run_status(row),
940
+ sub_status=row["sub_status"],
941
+ details=row["details"],
942
+ )
943
+ for row in rows
944
+ }
945
+
946
+ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
947
+ """Update the status of the run with the specified `run_id`."""
948
+ # Convert the uint64 value to sint64 for SQLite
949
+ sint64_run_id = convert_uint64_to_sint64(run_id)
950
+ query = "SELECT * FROM run WHERE run_id = ?;"
951
+ rows = self.query(query, (sint64_run_id,))
952
+
953
+ # Check if the run_id exists
954
+ if not rows:
955
+ log(ERROR, "`run_id` is invalid")
956
+ return False
957
+
958
+ # Check if the status transition is valid
959
+ row = rows[0]
960
+ current_status = RunStatus(
961
+ status=determine_run_status(row),
962
+ sub_status=row["sub_status"],
963
+ details=row["details"],
964
+ )
965
+ if not is_valid_transition(current_status, new_status):
966
+ log(
967
+ ERROR,
968
+ 'Invalid status transition: from "%s" to "%s"',
969
+ current_status.status,
970
+ new_status.status,
971
+ )
972
+ return False
973
+
974
+ # Check if the sub-status is valid
975
+ if not has_valid_sub_status(current_status):
976
+ log(
977
+ ERROR,
978
+ 'Invalid sub-status "%s" for status "%s"',
979
+ current_status.sub_status,
980
+ current_status.status,
981
+ )
982
+ return False
983
+
984
+ # Update the status
985
+ query = "UPDATE run SET %s= ?, sub_status = ?, details = ? "
986
+ query += "WHERE run_id = ?;"
987
+
988
+ timestamp_fld = ""
989
+ if new_status.status == Status.STARTING:
990
+ timestamp_fld = "starting_at"
991
+ elif new_status.status == Status.RUNNING:
992
+ timestamp_fld = "running_at"
993
+ elif new_status.status == Status.FINISHED:
994
+ timestamp_fld = "finished_at"
995
+
996
+ data = (
997
+ now().isoformat(),
998
+ new_status.sub_status,
999
+ new_status.details,
1000
+ sint64_run_id,
1001
+ )
1002
+ self.query(query % timestamp_fld, data)
1003
+ return True
1004
+
1005
+ def get_pending_run_id(self) -> Optional[int]:
1006
+ """Get the `run_id` of a run with `Status.PENDING` status, if any."""
1007
+ pending_run_id = None
1008
+
1009
+ # Fetch all runs with unset `starting_at` (i.e. they are in PENDING status)
1010
+ query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
1011
+ rows = self.query(query)
1012
+ if rows:
1013
+ pending_run_id = convert_sint64_to_uint64(rows[0]["run_id"])
1014
+
1015
+ return pending_run_id
1016
+
1017
+ def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
1018
+ """Retrieve the federation options for the specified `run_id`."""
1019
+ # Convert the uint64 value to sint64 for SQLite
1020
+ sint64_run_id = convert_uint64_to_sint64(run_id)
1021
+ query = "SELECT federation_options FROM run WHERE run_id = ?;"
1022
+ rows = self.query(query, (sint64_run_id,))
1023
+
1024
+ # Check if the run_id exists
1025
+ if not rows:
1026
+ log(ERROR, "`run_id` is invalid")
1027
+ return None
1028
+
1029
+ row = rows[0]
1030
+ return configsrecord_from_bytes(row["federation_options"])
1031
+
871
1032
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
872
1033
  """Acknowledge a ping received from a node, serving as a heartbeat."""
873
1034
  sint64_node_id = convert_uint64_to_sint64(node_id)
@@ -883,6 +1044,72 @@ class SqliteState(State): # pylint: disable=R0904
883
1044
  log(ERROR, "`node_id` does not exist.")
884
1045
  return False
885
1046
 
1047
+ def get_serverapp_context(self, run_id: int) -> Optional[Context]:
1048
+ """Get the context for the specified `run_id`."""
1049
+ # Retrieve context if any
1050
+ query = "SELECT context FROM context WHERE run_id = ?;"
1051
+ rows = self.query(query, (convert_uint64_to_sint64(run_id),))
1052
+ context = context_from_bytes(rows[0]["context"]) if rows else None
1053
+ return context
1054
+
1055
+ def set_serverapp_context(self, run_id: int, context: Context) -> None:
1056
+ """Set the context for the specified `run_id`."""
1057
+ # Convert context to bytes
1058
+ context_bytes = context_to_bytes(context)
1059
+ sint_run_id = convert_uint64_to_sint64(run_id)
1060
+
1061
+ # Check if any existing Context assigned to the run_id
1062
+ query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
1063
+ if self.query(query, (sint_run_id,))[0]["COUNT(*)"] > 0:
1064
+ # Update context
1065
+ query = "UPDATE context SET context = ? WHERE run_id = ?;"
1066
+ self.query(query, (context_bytes, sint_run_id))
1067
+ else:
1068
+ try:
1069
+ # Store context
1070
+ query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
1071
+ self.query(query, (sint_run_id, context_bytes))
1072
+ except sqlite3.IntegrityError:
1073
+ raise ValueError(f"Run {run_id} not found") from None
1074
+
1075
+ def add_serverapp_log(self, run_id: int, log_message: str) -> None:
1076
+ """Add a log entry to the ServerApp logs for the specified `run_id`."""
1077
+ # Convert the uint64 value to sint64 for SQLite
1078
+ sint64_run_id = convert_uint64_to_sint64(run_id)
1079
+
1080
+ # Store log
1081
+ try:
1082
+ query = """
1083
+ INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?);
1084
+ """
1085
+ self.query(query, (now().timestamp(), sint64_run_id, 0, log_message))
1086
+ except sqlite3.IntegrityError:
1087
+ raise ValueError(f"Run {run_id} not found") from None
1088
+
1089
+ def get_serverapp_log(
1090
+ self, run_id: int, after_timestamp: Optional[float]
1091
+ ) -> tuple[str, float]:
1092
+ """Get the ServerApp logs for the specified `run_id`."""
1093
+ # Convert the uint64 value to sint64 for SQLite
1094
+ sint64_run_id = convert_uint64_to_sint64(run_id)
1095
+
1096
+ # Check if the run_id exists
1097
+ query = "SELECT run_id FROM run WHERE run_id = ?;"
1098
+ if not self.query(query, (sint64_run_id,)):
1099
+ raise ValueError(f"Run {run_id} not found")
1100
+
1101
+ # Retrieve logs
1102
+ if after_timestamp is None:
1103
+ after_timestamp = 0.0
1104
+ query = """
1105
+ SELECT log, timestamp FROM logs
1106
+ WHERE run_id = ? AND node_id = ? AND timestamp > ?;
1107
+ """
1108
+ rows = self.query(query, (sint64_run_id, 0, after_timestamp))
1109
+ rows.sort(key=lambda x: x["timestamp"])
1110
+ latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
1111
+ return "".join(row["log"] for row in rows), latest_timestamp
1112
+
886
1113
  def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
887
1114
  """Check if the TaskIns exists and is valid (not expired).
888
1115
 
@@ -967,7 +1194,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
967
1194
 
968
1195
  def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
969
1196
  """Turn task_dict into protobuf message."""
970
- recordset = RecordSet()
1197
+ recordset = ProtoRecordSet()
971
1198
  recordset.ParseFromString(task_dict["recordset"])
972
1199
 
973
1200
  result = TaskIns(
@@ -997,7 +1224,7 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
997
1224
 
998
1225
  def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
999
1226
  """Turn task_dict into protobuf message."""
1000
- recordset = RecordSet()
1227
+ recordset = ProtoRecordSet()
1001
1228
  recordset.ParseFromString(task_dict["recordset"])
1002
1229
 
1003
1230
  result = TaskRes(
@@ -1023,3 +1250,17 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
1023
1250
  ),
1024
1251
  )
1025
1252
  return result
1253
+
1254
+
1255
+ def determine_run_status(row: dict[str, Any]) -> str:
1256
+ """Determine the status of the run based on timestamp fields."""
1257
+ if row["pending_at"]:
1258
+ if row["starting_at"]:
1259
+ if row["running_at"]:
1260
+ if row["finished_at"]:
1261
+ return Status.FINISHED
1262
+ return Status.RUNNING
1263
+ return Status.STARTING
1264
+ return Status.PENDING
1265
+ run_id = convert_sint64_to_uint64(row["run_id"])
1266
+ raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")