flwr-nightly 1.13.0.dev20241022__py3-none-any.whl → 1.13.0.dev20241024__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

@@ -17,22 +17,41 @@
17
17
 
18
18
  import threading
19
19
  import time
20
+ from dataclasses import dataclass
20
21
  from logging import ERROR, WARNING
21
22
  from typing import Optional
22
23
  from uuid import UUID, uuid4
23
24
 
24
- from flwr.common import log, now
25
+ from flwr.common import Context, log, now
25
26
  from flwr.common.constant import (
26
27
  MESSAGE_TTL_TOLERANCE,
27
28
  NODE_ID_NUM_BYTES,
28
29
  RUN_ID_NUM_BYTES,
30
+ Status,
29
31
  )
30
- from flwr.common.typing import Run, UserConfig
32
+ from flwr.common.typing import Run, RunStatus, UserConfig
31
33
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
32
34
  from flwr.server.superlink.linkstate.linkstate import LinkState
33
35
  from flwr.server.utils import validate_task_ins_or_res
34
36
 
35
- from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
37
+ from .utils import (
38
+ generate_rand_int_from_bytes,
39
+ has_valid_sub_status,
40
+ is_valid_transition,
41
+ make_node_unavailable_taskres,
42
+ )
43
+
44
+
45
@dataclass
class RunRecord:
    """Bundle a run together with its current status and lifecycle timestamps.

    The ``*_at`` fields are strings that default to ``""``; an empty string
    means the corresponding status has not been recorded for this run.
    """

    # The run this record describes
    run: Run
    # Current status (status, sub-status and details) of the run
    status: RunStatus
    # One timestamp string per lifecycle stage
    pending_at: str = ""
    starting_at: str = ""
    running_at: str = ""
    finished_at: str = ""
36
55
 
37
56
 
38
57
  class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
@@ -44,8 +63,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
44
63
  self.node_ids: dict[int, tuple[float, float]] = {}
45
64
  self.public_key_to_node_id: dict[bytes, int] = {}
46
65
 
47
- # Map run_id to (fab_id, fab_version)
48
- self.run_ids: dict[int, Run] = {}
66
+ # Map run_id to RunRecord
67
+ self.run_ids: dict[int, RunRecord] = {}
68
+ self.contexts: dict[int, Context] = {}
49
69
  self.task_ins_store: dict[UUID, TaskIns] = {}
50
70
  self.task_res_store: dict[UUID, TaskRes] = {}
51
71
 
@@ -351,13 +371,22 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
351
371
  run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
352
372
 
353
373
  if run_id not in self.run_ids:
354
- self.run_ids[run_id] = Run(
355
- run_id=run_id,
356
- fab_id=fab_id if fab_id else "",
357
- fab_version=fab_version if fab_version else "",
358
- fab_hash=fab_hash if fab_hash else "",
359
- override_config=override_config,
374
+ run_record = RunRecord(
375
+ run=Run(
376
+ run_id=run_id,
377
+ fab_id=fab_id if fab_id else "",
378
+ fab_version=fab_version if fab_version else "",
379
+ fab_hash=fab_hash if fab_hash else "",
380
+ override_config=override_config,
381
+ ),
382
+ status=RunStatus(
383
+ status=Status.PENDING,
384
+ sub_status="",
385
+ details="",
386
+ ),
387
+ pending_at=now().isoformat(),
360
388
  )
389
+ self.run_ids[run_id] = run_record
361
390
  return run_id
362
391
  log(ERROR, "Unexpected run creation failure.")
363
392
  return 0
@@ -401,7 +430,69 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
401
430
  if run_id not in self.run_ids:
402
431
  log(ERROR, "`run_id` is invalid")
403
432
  return None
404
- return self.run_ids[run_id]
433
+ return self.run_ids[run_id].run
434
+
435
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Retrieve the statuses for the specified runs.

    Run IDs that are not registered in ``self.run_ids`` are left out of the
    returned mapping.
    """
    statuses: dict[int, RunStatus] = {}
    with self.lock:
        for requested_id in set(run_ids):
            if requested_id in self.run_ids:
                statuses[requested_id] = self.run_ids[requested_id].status
    return statuses
443
+
444
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Update the status of the run with the specified `run_id`.

    Parameters
    ----------
    run_id : int
        The identifier of the run.
    new_status : RunStatus
        The new status to be assigned to the run.

    Returns
    -------
    bool
        True if the status was updated. False if `run_id` is unknown, if the
        transition from the current status to `new_status` is not allowed, or
        if `new_status` carries an invalid sub-status.
    """
    with self.lock:
        # Check if the run_id exists
        run_record = self.run_ids.get(run_id)
        if run_record is None:
            log(ERROR, "`run_id` is invalid")
            return False

        # Check if the status transition is valid
        current_status = run_record.status
        if not is_valid_transition(current_status, new_status):
            log(
                ERROR,
                'Invalid status transition: from "%s" to "%s"',
                current_status.status,
                new_status.status,
            )
            return False

        # Check the sub-status of the *new* status: that is the value about
        # to be stored. Checking the current status would let a FINISHED
        # update with an empty or unknown sub-status slip through.
        if not has_valid_sub_status(new_status):
            log(
                ERROR,
                'Invalid sub-status "%s" for status "%s"',
                new_status.sub_status,
                new_status.status,
            )
            return False

        # Record when the run entered the new status, then store the status
        if new_status.status == Status.STARTING:
            run_record.starting_at = now().isoformat()
        elif new_status.status == Status.RUNNING:
            run_record.running_at = now().isoformat()
        elif new_status.status == Status.FINISHED:
            run_record.finished_at = now().isoformat()
        run_record.status = new_status
        return True
483
+
484
def get_pending_run_id(self) -> Optional[int]:
    """Get the `run_id` of a run with `Status.PENDING` status, if any.

    The registered runs are scanned in dictionary order and the first pending
    one wins; `None` is returned when no run is pending.
    """
    pending_ids = (
        candidate_id
        for candidate_id, record in self.run_ids.items()
        if record.status.status == Status.PENDING
    )
    # `next` stops at the first match, mirroring an early `break`
    return next(pending_ids, None)
405
496
 
406
497
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
407
498
  """Acknowledge a ping received from a node, serving as a heartbeat."""
@@ -410,3 +501,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
410
501
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
411
502
  return True
412
503
  return False
504
+
505
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
    """Get the context for the specified `run_id`.

    Returns `None` when no context has been stored for that run.
    """
    stored_contexts = self.contexts
    if run_id in stored_contexts:
        return stored_contexts[run_id]
    return None
508
+
509
def set_serverapp_context(self, run_id: int, context: Context) -> None:
    """Set the context for the specified `run_id`.

    Raises
    ------
    ValueError
        If `run_id` does not belong to a registered run.
    """
    if run_id in self.run_ids:
        # Overwrites any context previously stored for this run
        self.contexts[run_id] = context
        return
    raise ValueError(f"Run {run_id} not found")
@@ -19,7 +19,8 @@ import abc
19
19
  from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
- from flwr.common.typing import Run, UserConfig
22
+ from flwr.common import Context
23
+ from flwr.common.typing import Run, RunStatus, UserConfig
23
24
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
24
25
 
25
26
 
@@ -178,6 +179,54 @@ class LinkState(abc.ABC): # pylint: disable=R0904
178
179
  - `fab_version`: The version of the FAB used in the specified run.
179
180
  """
180
181
 
182
@abc.abstractmethod
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Look up the current status of each of the given runs.

    Parameters
    ----------
    run_ids : set[int]
        The identifiers of the runs whose statuses are requested.

    Returns
    -------
    dict[int, RunStatus]
        A mapping from run ID to the status of that run.

    Notes
    -----
    The result only contains run IDs that exist in the State. Unknown run IDs
    are silently left out rather than reported as errors.
    """
201
+
202
@abc.abstractmethod
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Assign a new status to the run identified by `run_id`.

    Parameters
    ----------
    run_id : int
        The identifier of the run to update.
    new_status : RunStatus
        The status the run should move to.

    Returns
    -------
    bool
        True when the status was updated; False when the update was not
        carried out.
    """
218
+
219
@abc.abstractmethod
def get_pending_run_id(self) -> Optional[int]:
    """Return the `run_id` of one run whose status is `Status.PENDING`.

    Returns
    -------
    Optional[int]
        The `run_id` of a `Run` waiting to be started, or None when no
        Run is pending.
    """
229
+
181
230
  @abc.abstractmethod
182
231
  def store_server_private_public_key(
183
232
  self, private_key: bytes, public_key: bytes
@@ -222,3 +271,31 @@ class LinkState(abc.ABC): # pylint: disable=R0904
222
271
  is_acknowledged : bool
223
272
  True if the ping is successfully acknowledged; otherwise, False.
224
273
  """
274
+
275
@abc.abstractmethod
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
    """Fetch the context stored for the run identified by `run_id`.

    Parameters
    ----------
    run_id : int
        The identifier of the run whose context is requested.

    Returns
    -------
    Optional[Context]
        The context stored for `run_id`, or `None` when no context exists
        for that run.
    """
290
+
291
@abc.abstractmethod
def set_serverapp_context(self, run_id: int, context: Context) -> None:
    """Store `context` as the context of the run identified by `run_id`.

    Parameters
    ----------
    run_id : int
        The identifier of the run the context belongs to.
    context : Context
        The context to associate with `run_id`.
    """
@@ -19,31 +19,41 @@
19
19
  import json
20
20
  import re
21
21
  import sqlite3
22
+ import threading
22
23
  import time
23
24
  from collections.abc import Sequence
24
25
  from logging import DEBUG, ERROR, WARNING
25
26
  from typing import Any, Optional, Union, cast
26
27
  from uuid import UUID, uuid4
27
28
 
28
- from flwr.common import log, now
29
+ from flwr.common import Context, log, now
29
30
  from flwr.common.constant import (
30
31
  MESSAGE_TTL_TOLERANCE,
31
32
  NODE_ID_NUM_BYTES,
32
33
  RUN_ID_NUM_BYTES,
34
+ Status,
33
35
  )
34
- from flwr.common.typing import Run, UserConfig
35
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
36
- from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
37
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
36
+ from flwr.common.typing import Run, RunStatus, UserConfig
37
+
38
+ # pylint: disable=E0611
39
+ from flwr.proto.node_pb2 import Node
40
+ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
41
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
42
+
43
+ # pylint: enable=E0611
38
44
  from flwr.server.utils.validator import validate_task_ins_or_res
39
45
 
40
46
  from .linkstate import LinkState
41
47
  from .utils import (
48
+ context_from_bytes,
49
+ context_to_bytes,
42
50
  convert_sint64_to_uint64,
43
51
  convert_sint64_values_in_dict_to_uint64,
44
52
  convert_uint64_to_sint64,
45
53
  convert_uint64_values_in_dict_to_sint64,
46
54
  generate_rand_int_from_bytes,
55
+ has_valid_sub_status,
56
+ is_valid_transition,
47
57
  make_node_unavailable_taskres,
48
58
  )
49
59
 
@@ -79,7 +89,21 @@ CREATE TABLE IF NOT EXISTS run(
79
89
  fab_id TEXT,
80
90
  fab_version TEXT,
81
91
  fab_hash TEXT,
82
- override_config TEXT
92
+ override_config TEXT,
93
+ pending_at TEXT,
94
+ starting_at TEXT,
95
+ running_at TEXT,
96
+ finished_at TEXT,
97
+ sub_status TEXT,
98
+ details TEXT
99
+ );
100
+ """
101
+
102
+ SQL_CREATE_TABLE_CONTEXT = """
103
+ CREATE TABLE IF NOT EXISTS context(
104
+ run_id INTEGER UNIQUE,
105
+ context BLOB,
106
+ FOREIGN KEY(run_id) REFERENCES run(run_id)
83
107
  );
84
108
  """
85
109
 
@@ -133,7 +157,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
133
157
  self,
134
158
  database_path: str,
135
159
  ) -> None:
136
- """Initialize an SqliteState.
160
+ """Initialize an SqliteLinkState.
137
161
 
138
162
  Parameters
139
163
  ----------
@@ -143,6 +167,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
143
167
  """
144
168
  self.database_path = database_path
145
169
  self.conn: Optional[sqlite3.Connection] = None
170
+ self.lock = threading.RLock()
146
171
 
147
172
  def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
148
173
  """Create tables if they don't exist yet.
@@ -166,6 +191,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
166
191
 
167
192
  # Create each table if not exists queries
168
193
  cur.execute(SQL_CREATE_TABLE_RUN)
194
+ cur.execute(SQL_CREATE_TABLE_CONTEXT)
169
195
  cur.execute(SQL_CREATE_TABLE_TASK_INS)
170
196
  cur.execute(SQL_CREATE_TABLE_TASK_RES)
171
197
  cur.execute(SQL_CREATE_TABLE_NODE)
@@ -773,26 +799,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
773
799
  if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
774
800
  query = (
775
801
  "INSERT INTO run "
776
- "(run_id, fab_id, fab_version, fab_hash, override_config)"
777
- "VALUES (?, ?, ?, ?, ?);"
802
+ "(run_id, fab_id, fab_version, fab_hash, override_config, pending_at, "
803
+ "starting_at, running_at, finished_at, sub_status, details)"
804
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
778
805
  )
779
806
  if fab_hash:
780
- self.query(
781
- query,
782
- (sint64_run_id, "", "", fab_hash, json.dumps(override_config)),
783
- )
784
- else:
785
- self.query(
786
- query,
787
- (
788
- sint64_run_id,
789
- fab_id,
790
- fab_version,
791
- "",
792
- json.dumps(override_config),
793
- ),
794
- )
795
- # Note: we need to return the uint64 value of the run_id
807
+ fab_id, fab_version = "", ""
808
+ override_config_json = json.dumps(override_config)
809
+ data = [sint64_run_id, fab_id, fab_version, fab_hash, override_config_json]
810
+ data += [now().isoformat(), "", "", "", "", ""]
811
+ self.query(query, tuple(data))
796
812
  return uint64_run_id
797
813
  log(ERROR, "Unexpected run creation failure.")
798
814
  return 0
@@ -868,6 +884,94 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
868
884
  log(ERROR, "`run_id` does not exist.")
869
885
  return None
870
886
 
887
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Retrieve the statuses for the specified runs.

    Parameters
    ----------
    run_ids : set[int]
        The identifiers of the runs whose statuses are requested.

    Returns
    -------
    dict[int, RunStatus]
        A mapping from (uint64) run ID to its status. Run IDs without a row
        in the `run` table are omitted.
    """
    # Convert the uint64 values to sint64 for SQLite. Deduplicate once and
    # derive the placeholder count from the same collection, so the number of
    # `?` always matches the number of bound parameters.
    sint64_run_ids = tuple({convert_uint64_to_sint64(run_id) for run_id in run_ids})
    if not sint64_run_ids:
        # Nothing requested: skip the query instead of sending `IN ()`
        return {}

    placeholders = ",".join("?" * len(sint64_run_ids))
    query = f"SELECT * FROM run WHERE run_id IN ({placeholders});"
    rows = self.query(query, sint64_run_ids)

    return {
        # Restore uint64 run IDs
        convert_sint64_to_uint64(row["run_id"]): RunStatus(
            status=determine_run_status(row),
            sub_status=row["sub_status"],
            details=row["details"],
        )
        for row in rows
    }
903
+
904
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Update the status of the run with the specified `run_id`.

    Parameters
    ----------
    run_id : int
        The identifier of the run.
    new_status : RunStatus
        The new status to be assigned to the run.

    Returns
    -------
    bool
        True if the status was updated. False if `run_id` is unknown, if the
        transition from the current status to `new_status` is not allowed, or
        if `new_status` carries an invalid sub-status.
    """
    # Convert the uint64 value to sint64 for SQLite
    sint64_run_id = convert_uint64_to_sint64(run_id)
    query = "SELECT * FROM run WHERE run_id = ?;"
    rows = self.query(query, (sint64_run_id,))

    # Check if the run_id exists
    if not rows:
        log(ERROR, "`run_id` is invalid")
        return False

    # Check if the status transition is valid
    row = rows[0]
    current_status = RunStatus(
        status=determine_run_status(row),
        sub_status=row["sub_status"],
        details=row["details"],
    )
    if not is_valid_transition(current_status, new_status):
        log(
            ERROR,
            'Invalid status transition: from "%s" to "%s"',
            current_status.status,
            new_status.status,
        )
        return False

    # Check the sub-status of the *new* status: that is the value about to be
    # written. Checking the current status would let a FINISHED update with an
    # empty or unknown sub-status slip through.
    if not has_valid_sub_status(new_status):
        log(
            ERROR,
            'Invalid sub-status "%s" for status "%s"',
            new_status.sub_status,
            new_status.status,
        )
        return False

    # The column name is taken from this fixed mapping only, never from
    # caller-supplied text, so formatting it into the SQL string is safe.
    timestamp_columns = {
        Status.STARTING: "starting_at",
        Status.RUNNING: "running_at",
        Status.FINISHED: "finished_at",
    }
    timestamp_fld = timestamp_columns.get(new_status.status)
    if timestamp_fld is None:
        # Guard against building `UPDATE run SET = ?` (malformed SQL) should a
        # status without a timestamp column ever pass the transition check.
        log(ERROR, 'No timestamp field for status "%s"', new_status.status)
        return False

    # Update the status
    query = f"UPDATE run SET {timestamp_fld}= ?, sub_status = ?, details = ? "
    query += "WHERE run_id = ?;"
    data = (
        now().isoformat(),
        new_status.sub_status,
        new_status.details,
        sint64_run_id,
    )
    self.query(query, data)
    return True
962
+
963
def get_pending_run_id(self) -> Optional[int]:
    """Get the `run_id` of a run with `Status.PENDING` status, if any.

    A run counts as pending while its `starting_at` column is still the empty
    string. At most one row is fetched.
    """
    rows = self.query("SELECT * FROM run WHERE starting_at = '' LIMIT 1;")
    if not rows:
        return None
    # Stored IDs are sint64; hand back the uint64 representation
    return convert_sint64_to_uint64(rows[0]["run_id"])
974
+
871
975
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
872
976
  """Acknowledge a ping received from a node, serving as a heartbeat."""
873
977
  sint64_node_id = convert_uint64_to_sint64(node_id)
@@ -883,6 +987,34 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
883
987
  log(ERROR, "`node_id` does not exist.")
884
988
  return False
885
989
 
990
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
    """Get the context for the specified `run_id`.

    Returns `None` when the `context` table has no row for that run.
    """
    sint64_run_id = convert_uint64_to_sint64(run_id)
    rows = self.query(
        "SELECT context FROM context WHERE run_id = ?;", (sint64_run_id,)
    )
    if not rows:
        return None
    # The context is stored as serialized bytes; turn it back into a Context
    return context_from_bytes(rows[0]["context"])
997
+
998
def set_serverapp_context(self, run_id: int, context: Context) -> None:
    """Set the context for the specified `run_id`.

    An existing context row for the run is overwritten; otherwise a new row
    is inserted.

    Raises
    ------
    ValueError
        If inserting the new row fails with `sqlite3.IntegrityError`.
    """
    # Serialize the context and map the run ID into SQLite's sint64 range
    serialized = context_to_bytes(context)
    sint_run_id = convert_uint64_to_sint64(run_id)

    # Overwrite in place when a context row already exists for this run
    count_rows = self.query(
        "SELECT COUNT(*) FROM context WHERE run_id = ?;", (sint_run_id,)
    )
    if count_rows[0]["COUNT(*)"] > 0:
        self.query(
            "UPDATE context SET context = ? WHERE run_id = ?;",
            (serialized, sint_run_id),
        )
        return

    # NOTE(review): reporting an unknown run relies on the foreign key on
    # `context.run_id` being enforced by the connection - confirm that
    # foreign key enforcement is switched on where the connection is created.
    try:
        self.query(
            "INSERT INTO context (run_id, context) VALUES (?, ?);",
            (sint_run_id, serialized),
        )
    except sqlite3.IntegrityError:
        raise ValueError(f"Run {run_id} not found") from None
1017
+
886
1018
  def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
887
1019
  """Check if the TaskIns exists and is valid (not expired).
888
1020
 
@@ -967,7 +1099,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
967
1099
 
968
1100
  def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
969
1101
  """Turn task_dict into protobuf message."""
970
- recordset = RecordSet()
1102
+ recordset = ProtoRecordSet()
971
1103
  recordset.ParseFromString(task_dict["recordset"])
972
1104
 
973
1105
  result = TaskIns(
@@ -997,7 +1129,7 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
997
1129
 
998
1130
  def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
999
1131
  """Turn task_dict into protobuf message."""
1000
- recordset = RecordSet()
1132
+ recordset = ProtoRecordSet()
1001
1133
  recordset.ParseFromString(task_dict["recordset"])
1002
1134
 
1003
1135
  result = TaskRes(
@@ -1023,3 +1155,17 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
1023
1155
  ),
1024
1156
  )
1025
1157
  return result
1158
+
1159
+
1160
def determine_run_status(row: dict[str, Any]) -> str:
    """Determine the status of the run based on timestamp fields.

    The timestamps are inspected in lifecycle order (`pending_at`,
    `starting_at`, `running_at`, `finished_at`); the status is the last stage
    reached before the first empty timestamp.

    Raises
    ------
    sqlite3.IntegrityError
        If `pending_at` is empty, i.e. the row has no valid status.
    """
    if not row["pending_at"]:
        run_id = convert_sint64_to_uint64(row["run_id"])
        raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")
    if not row["starting_at"]:
        return Status.PENDING
    if not row["running_at"]:
        return Status.STARTING
    if not row["finished_at"]:
        return Status.RUNNING
    return Status.FINISHED
@@ -20,9 +20,11 @@ from logging import ERROR
20
20
  from os import urandom
21
21
  from uuid import uuid4
22
22
 
23
- from flwr.common import log
24
- from flwr.common.constant import ErrorCode
23
+ from flwr.common import Context, log, serde
24
+ from flwr.common.constant import ErrorCode, Status, SubStatus
25
+ from flwr.common.typing import RunStatus
25
26
  from flwr.proto.error_pb2 import Error # pylint: disable=E0611
27
+ from flwr.proto.message_pb2 import Context as ProtoContext # pylint: disable=E0611
26
28
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
27
29
  from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
28
30
 
@@ -31,6 +33,17 @@ NODE_UNAVAILABLE_ERROR_REASON = (
31
33
  "It exceeds the time limit specified in its last ping."
32
34
  )
33
35
 
36
# Allowed (current status, new status) pairs for a run
VALID_RUN_STATUS_TRANSITIONS = {
    (Status.PENDING, Status.STARTING),
    (Status.STARTING, Status.RUNNING),
    (Status.RUNNING, Status.FINISHED),
}
# Sub-statuses accepted for a run whose status is `Status.FINISHED`
VALID_RUN_SUB_STATUSES = {
    SubStatus.COMPLETED,
    SubStatus.FAILED,
    SubStatus.STOPPED,
}
46
+
34
47
 
35
48
  def generate_rand_int_from_bytes(num_bytes: int) -> int:
36
49
  """Generate a random unsigned integer from `num_bytes` bytes."""
@@ -123,6 +136,16 @@ def convert_sint64_values_in_dict_to_uint64(
123
136
  data_dict[key] = convert_sint64_to_uint64(data_dict[key])
124
137
 
125
138
 
139
def context_to_bytes(context: Context) -> bytes:
    """Serialize `Context` to bytes via its protobuf representation."""
    proto_context = serde.context_to_proto(context)
    return proto_context.SerializeToString()
142
+
143
+
144
def context_from_bytes(context_bytes: bytes) -> Context:
    """Deserialize `Context` from bytes holding its protobuf representation."""
    proto_context = ProtoContext.FromString(context_bytes)
    return serde.context_from_proto(proto_context)
147
+
148
+
126
149
  def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
127
150
  """Generate a TaskRes with a node unavailable error from a TaskIns."""
128
151
  current_time = time.time()
@@ -146,3 +169,47 @@ def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
146
169
  ),
147
170
  ),
148
171
  )
172
+
173
+
174
def is_valid_transition(current_status: RunStatus, new_status: RunStatus) -> bool:
    """Tell whether a run may move from one status to another.

    Only the `status` fields are compared; sub-status and details play no role.

    Parameters
    ----------
    current_status : RunStatus
        The status the run currently has.
    new_status : RunStatus
        The status the run should move to.

    Returns
    -------
    bool
        True if the pair is listed in `VALID_RUN_STATUS_TRANSITIONS`,
        False otherwise.
    """
    transition = (current_status.status, new_status.status)
    return transition in VALID_RUN_STATUS_TRANSITIONS
193
+
194
+
195
def has_valid_sub_status(status: RunStatus) -> bool:
    """Tell whether the `sub_status` of the given status is acceptable.

    Parameters
    ----------
    status : RunStatus
        The status object to inspect.

    Returns
    -------
    bool
        True if the sub-status fits the status, False otherwise.

    Notes
    -----
    A finished status must carry one of `VALID_RUN_SUB_STATUSES`, so it cannot
    be empty. Every other status must have the empty string as sub-status.
    """
    if status.status != Status.FINISHED:
        return status.sub_status == ""
    return status.sub_status in VALID_RUN_SUB_STATUSES