flwr-nightly 1.13.0.dev20241019__py3-none-any.whl → 1.13.0.dev20241106__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (81)
  1. flwr/cli/build.py +2 -2
  2. flwr/cli/config_utils.py +97 -0
  3. flwr/cli/log.py +63 -97
  4. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
  6. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  7. flwr/cli/run/run.py +18 -83
  8. flwr/client/app.py +13 -14
  9. flwr/client/clientapp/app.py +1 -2
  10. flwr/client/{node_state.py → run_info_store.py} +4 -3
  11. flwr/client/supernode/app.py +6 -8
  12. flwr/common/constant.py +39 -4
  13. flwr/common/context.py +9 -4
  14. flwr/common/date.py +3 -3
  15. flwr/common/logger.py +103 -0
  16. flwr/common/serde.py +24 -0
  17. flwr/common/telemetry.py +0 -6
  18. flwr/common/typing.py +9 -0
  19. flwr/proto/exec_pb2.py +6 -6
  20. flwr/proto/exec_pb2.pyi +8 -2
  21. flwr/proto/log_pb2.py +29 -0
  22. flwr/proto/log_pb2.pyi +39 -0
  23. flwr/proto/log_pb2_grpc.py +4 -0
  24. flwr/proto/log_pb2_grpc.pyi +4 -0
  25. flwr/proto/message_pb2.py +8 -8
  26. flwr/proto/message_pb2.pyi +4 -1
  27. flwr/proto/serverappio_pb2.py +52 -0
  28. flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
  29. flwr/proto/serverappio_pb2_grpc.py +376 -0
  30. flwr/proto/serverappio_pb2_grpc.pyi +147 -0
  31. flwr/proto/simulationio_pb2.py +38 -0
  32. flwr/proto/simulationio_pb2.pyi +65 -0
  33. flwr/proto/simulationio_pb2_grpc.py +171 -0
  34. flwr/proto/simulationio_pb2_grpc.pyi +68 -0
  35. flwr/server/app.py +247 -105
  36. flwr/server/driver/driver.py +15 -1
  37. flwr/server/driver/grpc_driver.py +26 -33
  38. flwr/server/driver/inmemory_driver.py +6 -14
  39. flwr/server/run_serverapp.py +29 -23
  40. flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
  41. flwr/server/serverapp/app.py +270 -0
  42. flwr/server/strategy/fedadam.py +11 -1
  43. flwr/server/superlink/driver/__init__.py +1 -1
  44. flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
  45. flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
  46. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
  47. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
  48. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
  49. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
  50. flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
  51. flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
  52. flwr/server/superlink/fleet/vce/vce_api.py +23 -23
  53. flwr/server/superlink/linkstate/__init__.py +28 -0
  54. flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +180 -21
  55. flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +144 -15
  56. flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
  57. flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +300 -50
  58. flwr/server/superlink/{state → linkstate}/utils.py +84 -2
  59. flwr/server/superlink/simulation/__init__.py +15 -0
  60. flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
  61. flwr/server/superlink/simulation/simulationio_servicer.py +132 -0
  62. flwr/simulation/__init__.py +2 -0
  63. flwr/simulation/app.py +1 -1
  64. flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
  65. flwr/simulation/run_simulation.py +57 -131
  66. flwr/simulation/simulationio_connection.py +86 -0
  67. flwr/superexec/app.py +6 -134
  68. flwr/superexec/deployment.py +60 -65
  69. flwr/superexec/exec_grpc.py +15 -8
  70. flwr/superexec/exec_servicer.py +34 -63
  71. flwr/superexec/executor.py +22 -4
  72. flwr/superexec/simulation.py +13 -8
  73. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/METADATA +1 -1
  74. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/RECORD +77 -64
  75. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/entry_points.txt +1 -0
  76. flwr/client/node_state_tests.py +0 -66
  77. flwr/proto/driver_pb2.py +0 -42
  78. flwr/proto/driver_pb2_grpc.py +0 -239
  79. flwr/proto/driver_pb2_grpc.pyi +0 -94
  80. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/LICENSE +0 -0
  81. {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/WHEEL +0 -0
@@ -12,38 +12,51 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """SQLite based implemenation of server state."""
15
+ """SQLite based implemenation of the link state."""
16
16
 
17
17
  # pylint: disable=too-many-lines
18
18
 
19
19
  import json
20
20
  import re
21
21
  import sqlite3
22
+ import threading
22
23
  import time
23
24
  from collections.abc import Sequence
24
25
  from logging import DEBUG, ERROR, WARNING
25
26
  from typing import Any, Optional, Union, cast
26
27
  from uuid import UUID, uuid4
27
28
 
28
- from flwr.common import log, now
29
+ from flwr.common import Context, log, now
29
30
  from flwr.common.constant import (
30
31
  MESSAGE_TTL_TOLERANCE,
31
32
  NODE_ID_NUM_BYTES,
32
33
  RUN_ID_NUM_BYTES,
34
+ Status,
33
35
  )
34
- from flwr.common.typing import Run, UserConfig
35
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
36
- from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
37
- from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
36
+ from flwr.common.record import ConfigsRecord
37
+ from flwr.common.typing import Run, RunStatus, UserConfig
38
+
39
+ # pylint: disable=E0611
40
+ from flwr.proto.node_pb2 import Node
41
+ from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
42
+ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
43
+
44
+ # pylint: enable=E0611
38
45
  from flwr.server.utils.validator import validate_task_ins_or_res
39
46
 
40
- from .state import State
47
+ from .linkstate import LinkState
41
48
  from .utils import (
49
+ configsrecord_from_bytes,
50
+ configsrecord_to_bytes,
51
+ context_from_bytes,
52
+ context_to_bytes,
42
53
  convert_sint64_to_uint64,
43
54
  convert_sint64_values_in_dict_to_uint64,
44
55
  convert_uint64_to_sint64,
45
56
  convert_uint64_values_in_dict_to_sint64,
46
57
  generate_rand_int_from_bytes,
58
+ has_valid_sub_status,
59
+ is_valid_transition,
47
60
  make_node_unavailable_taskres,
48
61
  )
49
62
 
@@ -79,7 +92,33 @@ CREATE TABLE IF NOT EXISTS run(
79
92
  fab_id TEXT,
80
93
  fab_version TEXT,
81
94
  fab_hash TEXT,
82
- override_config TEXT
95
+ override_config TEXT,
96
+ pending_at TEXT,
97
+ starting_at TEXT,
98
+ running_at TEXT,
99
+ finished_at TEXT,
100
+ sub_status TEXT,
101
+ details TEXT,
102
+ federation_options BLOB
103
+ );
104
+ """
105
+
106
+ SQL_CREATE_TABLE_LOGS = """
107
+ CREATE TABLE IF NOT EXISTS logs (
108
+ timestamp REAL,
109
+ run_id INTEGER,
110
+ node_id INTEGER,
111
+ log TEXT,
112
+ PRIMARY KEY (timestamp, run_id, node_id),
113
+ FOREIGN KEY (run_id) REFERENCES run(run_id)
114
+ );
115
+ """
116
+
117
+ SQL_CREATE_TABLE_CONTEXT = """
118
+ CREATE TABLE IF NOT EXISTS context(
119
+ run_id INTEGER UNIQUE,
120
+ context BLOB,
121
+ FOREIGN KEY(run_id) REFERENCES run(run_id)
83
122
  );
84
123
  """
85
124
 
@@ -126,14 +165,14 @@ CREATE TABLE IF NOT EXISTS task_res(
126
165
  DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
127
166
 
128
167
 
129
- class SqliteState(State): # pylint: disable=R0904
130
- """SQLite-based state implementation."""
168
+ class SqliteLinkState(LinkState): # pylint: disable=R0904
169
+ """SQLite-based LinkState implementation."""
131
170
 
132
171
  def __init__(
133
172
  self,
134
173
  database_path: str,
135
174
  ) -> None:
136
- """Initialize an SqliteState.
175
+ """Initialize an SqliteLinkState.
137
176
 
138
177
  Parameters
139
178
  ----------
@@ -143,6 +182,7 @@ class SqliteState(State): # pylint: disable=R0904
143
182
  """
144
183
  self.database_path = database_path
145
184
  self.conn: Optional[sqlite3.Connection] = None
185
+ self.lock = threading.RLock()
146
186
 
147
187
  def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
148
188
  """Create tables if they don't exist yet.
@@ -166,6 +206,8 @@ class SqliteState(State): # pylint: disable=R0904
166
206
 
167
207
  # Create each table if not exists queries
168
208
  cur.execute(SQL_CREATE_TABLE_RUN)
209
+ cur.execute(SQL_CREATE_TABLE_LOGS)
210
+ cur.execute(SQL_CREATE_TABLE_CONTEXT)
169
211
  cur.execute(SQL_CREATE_TABLE_TASK_INS)
170
212
  cur.execute(SQL_CREATE_TABLE_TASK_RES)
171
213
  cur.execute(SQL_CREATE_TABLE_NODE)
@@ -183,7 +225,7 @@ class SqliteState(State): # pylint: disable=R0904
183
225
  ) -> list[dict[str, Any]]:
184
226
  """Execute a SQL query."""
185
227
  if self.conn is None:
186
- raise AttributeError("State is not initialized.")
228
+ raise AttributeError("LinkState is not initialized.")
187
229
 
188
230
  if data is None:
189
231
  data = []
@@ -214,11 +256,11 @@ class SqliteState(State): # pylint: disable=R0904
214
256
  def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
215
257
  """Store one TaskIns.
216
258
 
217
- Usually, the Driver API calls this to schedule instructions.
259
+ Usually, the ServerAppIo API calls this to schedule instructions.
218
260
 
219
- Stores the value of the task_ins in the state and, if successful, returns the
220
- task_id (UUID) of the task_ins. If, for any reason, storing the task_ins fails,
221
- `None` is returned.
261
+ Stores the value of the task_ins in the link state and, if successful,
262
+ returns the task_id (UUID) of the task_ins. If, for any reason, storing
263
+ the task_ins fails, `None` is returned.
222
264
 
223
265
  Constraints
224
266
  -----------
@@ -233,7 +275,6 @@ class SqliteState(State): # pylint: disable=R0904
233
275
  if any(errors):
234
276
  log(ERROR, errors)
235
277
  return None
236
-
237
278
  # Create task_id
238
279
  task_id = uuid4()
239
280
 
@@ -246,16 +287,36 @@ class SqliteState(State): # pylint: disable=R0904
246
287
  data[0], ["run_id", "producer_node_id", "consumer_node_id"]
247
288
  )
248
289
 
290
+ # Validate run_id
291
+ query = "SELECT run_id FROM run WHERE run_id = ?;"
292
+ if not self.query(query, (data[0]["run_id"],)):
293
+ log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
294
+ return None
295
+ # Validate source node ID
296
+ if task_ins.task.producer.node_id != 0:
297
+ log(
298
+ ERROR,
299
+ "Invalid source node ID for TaskIns: %s",
300
+ task_ins.task.producer.node_id,
301
+ )
302
+ return None
303
+ # Validate destination node ID
304
+ query = "SELECT node_id FROM node WHERE node_id = ?;"
305
+ if not task_ins.task.consumer.anonymous:
306
+ if not self.query(query, (data[0]["consumer_node_id"],)):
307
+ log(
308
+ ERROR,
309
+ "Invalid destination node ID for TaskIns: %s",
310
+ task_ins.task.consumer.node_id,
311
+ )
312
+ return None
313
+
249
314
  columns = ", ".join([f":{key}" for key in data[0]])
250
315
  query = f"INSERT INTO task_ins VALUES({columns});"
251
316
 
252
317
  # Only invalid run_id can trigger IntegrityError.
253
318
  # This may need to be changed in the future version with more integrity checks.
254
- try:
255
- self.query(query, data)
256
- except sqlite3.IntegrityError:
257
- log(ERROR, "`run` is invalid")
258
- return None
319
+ self.query(query, data)
259
320
 
260
321
  return task_id
261
322
 
@@ -452,8 +513,8 @@ class SqliteState(State): # pylint: disable=R0904
452
513
  def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
453
514
  """Get TaskRes for task_ids.
454
515
 
455
- Usually, the Driver API calls this method to get results for instructions it has
456
- previously scheduled.
516
+ Usually, the ServerAppIo API calls this method to get results for instructions
517
+ it has previously scheduled.
457
518
 
458
519
  Retrieves all TaskRes for the given `task_ids` and returns and empty list if
459
520
  none could be found.
@@ -645,7 +706,7 @@ class SqliteState(State): # pylint: disable=R0904
645
706
  """
646
707
 
647
708
  if self.conn is None:
648
- raise AttributeError("State not intitialized")
709
+ raise AttributeError("LinkState not intitialized")
649
710
 
650
711
  with self.conn:
651
712
  self.conn.execute(query_1, data)
@@ -656,7 +717,7 @@ class SqliteState(State): # pylint: disable=R0904
656
717
  def create_node(
657
718
  self, ping_interval: float, public_key: Optional[bytes] = None
658
719
  ) -> int:
659
- """Create, store in state, and return `node_id`."""
720
+ """Create, store in the link state, and return `node_id`."""
660
721
  # Sample a random uint64 as node_id
661
722
  uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
662
723
 
@@ -706,7 +767,7 @@ class SqliteState(State): # pylint: disable=R0904
706
767
  params += (public_key,) # type: ignore
707
768
 
708
769
  if self.conn is None:
709
- raise AttributeError("State is not initialized.")
770
+ raise AttributeError("LinkState is not initialized.")
710
771
 
711
772
  try:
712
773
  with self.conn:
@@ -753,12 +814,14 @@ class SqliteState(State): # pylint: disable=R0904
753
814
  return uint64_node_id
754
815
  return None
755
816
 
817
+ # pylint: disable=too-many-arguments,too-many-positional-arguments
756
818
  def create_run(
757
819
  self,
758
820
  fab_id: Optional[str],
759
821
  fab_version: Optional[str],
760
822
  fab_hash: Optional[str],
761
823
  override_config: UserConfig,
824
+ federation_options: ConfigsRecord,
762
825
  ) -> int:
763
826
  """Create a new run for the specified `fab_id` and `fab_version`."""
764
827
  # Sample a random int64 as run_id
@@ -773,26 +836,30 @@ class SqliteState(State): # pylint: disable=R0904
773
836
  if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
774
837
  query = (
775
838
  "INSERT INTO run "
776
- "(run_id, fab_id, fab_version, fab_hash, override_config)"
777
- "VALUES (?, ?, ?, ?, ?);"
839
+ "(run_id, fab_id, fab_version, fab_hash, override_config, "
840
+ "federation_options, pending_at, starting_at, running_at, finished_at, "
841
+ "sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
778
842
  )
779
843
  if fab_hash:
780
- self.query(
781
- query,
782
- (sint64_run_id, "", "", fab_hash, json.dumps(override_config)),
783
- )
784
- else:
785
- self.query(
786
- query,
787
- (
788
- sint64_run_id,
789
- fab_id,
790
- fab_version,
791
- "",
792
- json.dumps(override_config),
793
- ),
794
- )
795
- # Note: we need to return the uint64 value of the run_id
844
+ fab_id, fab_version = "", ""
845
+ override_config_json = json.dumps(override_config)
846
+ data = [
847
+ sint64_run_id,
848
+ fab_id,
849
+ fab_version,
850
+ fab_hash,
851
+ override_config_json,
852
+ configsrecord_to_bytes(federation_options),
853
+ ]
854
+ data += [
855
+ now().isoformat(),
856
+ "",
857
+ "",
858
+ "",
859
+ "",
860
+ "",
861
+ ]
862
+ self.query(query, tuple(data))
796
863
  return uint64_run_id
797
864
  log(ERROR, "Unexpected run creation failure.")
798
865
  return 0
@@ -800,7 +867,7 @@ class SqliteState(State): # pylint: disable=R0904
800
867
  def store_server_private_public_key(
801
868
  self, private_key: bytes, public_key: bytes
802
869
  ) -> None:
803
- """Store `server_private_key` and `server_public_key` in state."""
870
+ """Store `server_private_key` and `server_public_key` in the link state."""
804
871
  query = "SELECT COUNT(*) FROM credential"
805
872
  count = self.query(query)[0]["COUNT(*)"]
806
873
  if count < 1:
@@ -833,13 +900,13 @@ class SqliteState(State): # pylint: disable=R0904
833
900
  return public_key
834
901
 
835
902
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
836
- """Store a set of `node_public_keys` in state."""
903
+ """Store a set of `node_public_keys` in the link state."""
837
904
  query = "INSERT INTO public_key (public_key) VALUES (?)"
838
905
  data = [(key,) for key in public_keys]
839
906
  self.query(query, data)
840
907
 
841
908
  def store_node_public_key(self, public_key: bytes) -> None:
842
- """Store a `node_public_key` in state."""
909
+ """Store a `node_public_key` in the link state."""
843
910
  query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
844
911
  self.query(query, {"public_key": public_key})
845
912
 
@@ -868,6 +935,109 @@ class SqliteState(State): # pylint: disable=R0904
868
935
  log(ERROR, "`run_id` does not exist.")
869
936
  return None
870
937
 
938
+ def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
939
+ """Retrieve the statuses for the specified runs."""
940
+ # Convert the uint64 value to sint64 for SQLite
941
+ sint64_run_ids = (convert_uint64_to_sint64(run_id) for run_id in set(run_ids))
942
+ query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
943
+ rows = self.query(query, tuple(sint64_run_ids))
944
+
945
+ return {
946
+ # Restore uint64 run IDs
947
+ convert_sint64_to_uint64(row["run_id"]): RunStatus(
948
+ status=determine_run_status(row),
949
+ sub_status=row["sub_status"],
950
+ details=row["details"],
951
+ )
952
+ for row in rows
953
+ }
954
+
955
+ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
956
+ """Update the status of the run with the specified `run_id`."""
957
+ # Convert the uint64 value to sint64 for SQLite
958
+ sint64_run_id = convert_uint64_to_sint64(run_id)
959
+ query = "SELECT * FROM run WHERE run_id = ?;"
960
+ rows = self.query(query, (sint64_run_id,))
961
+
962
+ # Check if the run_id exists
963
+ if not rows:
964
+ log(ERROR, "`run_id` is invalid")
965
+ return False
966
+
967
+ # Check if the status transition is valid
968
+ row = rows[0]
969
+ current_status = RunStatus(
970
+ status=determine_run_status(row),
971
+ sub_status=row["sub_status"],
972
+ details=row["details"],
973
+ )
974
+ if not is_valid_transition(current_status, new_status):
975
+ log(
976
+ ERROR,
977
+ 'Invalid status transition: from "%s" to "%s"',
978
+ current_status.status,
979
+ new_status.status,
980
+ )
981
+ return False
982
+
983
+ # Check if the sub-status is valid
984
+ if not has_valid_sub_status(current_status):
985
+ log(
986
+ ERROR,
987
+ 'Invalid sub-status "%s" for status "%s"',
988
+ current_status.sub_status,
989
+ current_status.status,
990
+ )
991
+ return False
992
+
993
+ # Update the status
994
+ query = "UPDATE run SET %s= ?, sub_status = ?, details = ? "
995
+ query += "WHERE run_id = ?;"
996
+
997
+ timestamp_fld = ""
998
+ if new_status.status == Status.STARTING:
999
+ timestamp_fld = "starting_at"
1000
+ elif new_status.status == Status.RUNNING:
1001
+ timestamp_fld = "running_at"
1002
+ elif new_status.status == Status.FINISHED:
1003
+ timestamp_fld = "finished_at"
1004
+
1005
+ data = (
1006
+ now().isoformat(),
1007
+ new_status.sub_status,
1008
+ new_status.details,
1009
+ sint64_run_id,
1010
+ )
1011
+ self.query(query % timestamp_fld, data)
1012
+ return True
1013
+
1014
+ def get_pending_run_id(self) -> Optional[int]:
1015
+ """Get the `run_id` of a run with `Status.PENDING` status, if any."""
1016
+ pending_run_id = None
1017
+
1018
+ # Fetch all runs with unset `starting_at` (i.e. they are in PENDING status)
1019
+ query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
1020
+ rows = self.query(query)
1021
+ if rows:
1022
+ pending_run_id = convert_sint64_to_uint64(rows[0]["run_id"])
1023
+
1024
+ return pending_run_id
1025
+
1026
+ def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
1027
+ """Retrieve the federation options for the specified `run_id`."""
1028
+ # Convert the uint64 value to sint64 for SQLite
1029
+ sint64_run_id = convert_uint64_to_sint64(run_id)
1030
+ query = "SELECT federation_options FROM run WHERE run_id = ?;"
1031
+ rows = self.query(query, (sint64_run_id,))
1032
+
1033
+ # Check if the run_id exists
1034
+ if not rows:
1035
+ log(ERROR, "`run_id` is invalid")
1036
+ return None
1037
+
1038
+ row = rows[0]
1039
+ return configsrecord_from_bytes(row["federation_options"])
1040
+
871
1041
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
872
1042
  """Acknowledge a ping received from a node, serving as a heartbeat."""
873
1043
  sint64_node_id = convert_uint64_to_sint64(node_id)
@@ -883,6 +1053,72 @@ class SqliteState(State): # pylint: disable=R0904
883
1053
  log(ERROR, "`node_id` does not exist.")
884
1054
  return False
885
1055
 
1056
+ def get_serverapp_context(self, run_id: int) -> Optional[Context]:
1057
+ """Get the context for the specified `run_id`."""
1058
+ # Retrieve context if any
1059
+ query = "SELECT context FROM context WHERE run_id = ?;"
1060
+ rows = self.query(query, (convert_uint64_to_sint64(run_id),))
1061
+ context = context_from_bytes(rows[0]["context"]) if rows else None
1062
+ return context
1063
+
1064
+ def set_serverapp_context(self, run_id: int, context: Context) -> None:
1065
+ """Set the context for the specified `run_id`."""
1066
+ # Convert context to bytes
1067
+ context_bytes = context_to_bytes(context)
1068
+ sint_run_id = convert_uint64_to_sint64(run_id)
1069
+
1070
+ # Check if any existing Context assigned to the run_id
1071
+ query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
1072
+ if self.query(query, (sint_run_id,))[0]["COUNT(*)"] > 0:
1073
+ # Update context
1074
+ query = "UPDATE context SET context = ? WHERE run_id = ?;"
1075
+ self.query(query, (context_bytes, sint_run_id))
1076
+ else:
1077
+ try:
1078
+ # Store context
1079
+ query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
1080
+ self.query(query, (sint_run_id, context_bytes))
1081
+ except sqlite3.IntegrityError:
1082
+ raise ValueError(f"Run {run_id} not found") from None
1083
+
1084
+ def add_serverapp_log(self, run_id: int, log_message: str) -> None:
1085
+ """Add a log entry to the ServerApp logs for the specified `run_id`."""
1086
+ # Convert the uint64 value to sint64 for SQLite
1087
+ sint64_run_id = convert_uint64_to_sint64(run_id)
1088
+
1089
+ # Store log
1090
+ try:
1091
+ query = """
1092
+ INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?);
1093
+ """
1094
+ self.query(query, (now().timestamp(), sint64_run_id, 0, log_message))
1095
+ except sqlite3.IntegrityError:
1096
+ raise ValueError(f"Run {run_id} not found") from None
1097
+
1098
+ def get_serverapp_log(
1099
+ self, run_id: int, after_timestamp: Optional[float]
1100
+ ) -> tuple[str, float]:
1101
+ """Get the ServerApp logs for the specified `run_id`."""
1102
+ # Convert the uint64 value to sint64 for SQLite
1103
+ sint64_run_id = convert_uint64_to_sint64(run_id)
1104
+
1105
+ # Check if the run_id exists
1106
+ query = "SELECT run_id FROM run WHERE run_id = ?;"
1107
+ if not self.query(query, (sint64_run_id,)):
1108
+ raise ValueError(f"Run {run_id} not found")
1109
+
1110
+ # Retrieve logs
1111
+ if after_timestamp is None:
1112
+ after_timestamp = 0.0
1113
+ query = """
1114
+ SELECT log, timestamp FROM logs
1115
+ WHERE run_id = ? AND node_id = ? AND timestamp > ?;
1116
+ """
1117
+ rows = self.query(query, (sint64_run_id, 0, after_timestamp))
1118
+ rows.sort(key=lambda x: x["timestamp"])
1119
+ latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
1120
+ return "".join(row["log"] for row in rows), latest_timestamp
1121
+
886
1122
  def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
887
1123
  """Check if the TaskIns exists and is valid (not expired).
888
1124
 
@@ -967,7 +1203,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
967
1203
 
968
1204
  def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
969
1205
  """Turn task_dict into protobuf message."""
970
- recordset = RecordSet()
1206
+ recordset = ProtoRecordSet()
971
1207
  recordset.ParseFromString(task_dict["recordset"])
972
1208
 
973
1209
  result = TaskIns(
@@ -997,7 +1233,7 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
997
1233
 
998
1234
  def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
999
1235
  """Turn task_dict into protobuf message."""
1000
- recordset = RecordSet()
1236
+ recordset = ProtoRecordSet()
1001
1237
  recordset.ParseFromString(task_dict["recordset"])
1002
1238
 
1003
1239
  result = TaskRes(
@@ -1023,3 +1259,17 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
1023
1259
  ),
1024
1260
  )
1025
1261
  return result
1262
+
1263
+
1264
+ def determine_run_status(row: dict[str, Any]) -> str:
1265
+ """Determine the status of the run based on timestamp fields."""
1266
+ if row["pending_at"]:
1267
+ if row["starting_at"]:
1268
+ if row["running_at"]:
1269
+ if row["finished_at"]:
1270
+ return Status.FINISHED
1271
+ return Status.RUNNING
1272
+ return Status.STARTING
1273
+ return Status.PENDING
1274
+ run_id = convert_sint64_to_uint64(row["run_id"])
1275
+ raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")
@@ -20,10 +20,15 @@ from logging import ERROR
20
20
  from os import urandom
21
21
  from uuid import uuid4
22
22
 
23
- from flwr.common import log
24
- from flwr.common.constant import ErrorCode
23
+ from flwr.common import ConfigsRecord, Context, log, serde
24
+ from flwr.common.constant import ErrorCode, Status, SubStatus
25
+ from flwr.common.typing import RunStatus
25
26
  from flwr.proto.error_pb2 import Error # pylint: disable=E0611
27
+ from flwr.proto.message_pb2 import Context as ProtoContext # pylint: disable=E0611
26
28
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
29
+
30
+ # pylint: disable=E0611
31
+ from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
27
32
  from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
28
33
 
29
34
  NODE_UNAVAILABLE_ERROR_REASON = (
@@ -31,6 +36,17 @@ NODE_UNAVAILABLE_ERROR_REASON = (
31
36
  "It exceeds the time limit specified in its last ping."
32
37
  )
33
38
 
39
+ VALID_RUN_STATUS_TRANSITIONS = {
40
+ (Status.PENDING, Status.STARTING),
41
+ (Status.STARTING, Status.RUNNING),
42
+ (Status.RUNNING, Status.FINISHED),
43
+ }
44
+ VALID_RUN_SUB_STATUSES = {
45
+ SubStatus.COMPLETED,
46
+ SubStatus.FAILED,
47
+ SubStatus.STOPPED,
48
+ }
49
+
34
50
 
35
51
  def generate_rand_int_from_bytes(num_bytes: int) -> int:
36
52
  """Generate a random unsigned integer from `num_bytes` bytes."""
@@ -123,6 +139,28 @@ def convert_sint64_values_in_dict_to_uint64(
123
139
  data_dict[key] = convert_sint64_to_uint64(data_dict[key])
124
140
 
125
141
 
142
+ def context_to_bytes(context: Context) -> bytes:
143
+ """Serialize `Context` to bytes."""
144
+ return serde.context_to_proto(context).SerializeToString()
145
+
146
+
147
+ def context_from_bytes(context_bytes: bytes) -> Context:
148
+ """Deserialize `Context` from bytes."""
149
+ return serde.context_from_proto(ProtoContext.FromString(context_bytes))
150
+
151
+
152
+ def configsrecord_to_bytes(configs_record: ConfigsRecord) -> bytes:
153
+ """Serialize a `ConfigsRecord` to bytes."""
154
+ return serde.configs_record_to_proto(configs_record).SerializeToString()
155
+
156
+
157
+ def configsrecord_from_bytes(configsrecord_bytes: bytes) -> ConfigsRecord:
158
+ """Deserialize `ConfigsRecord` from bytes."""
159
+ return serde.configs_record_from_proto(
160
+ ProtoConfigsRecord.FromString(configsrecord_bytes)
161
+ )
162
+
163
+
126
164
  def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
127
165
  """Generate a TaskRes with a node unavailable error from a TaskIns."""
128
166
  current_time = time.time()
@@ -146,3 +184,47 @@ def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
146
184
  ),
147
185
  ),
148
186
  )
187
+
188
+
189
+ def is_valid_transition(current_status: RunStatus, new_status: RunStatus) -> bool:
190
+ """Check if a transition between two run statuses is valid.
191
+
192
+ Parameters
193
+ ----------
194
+ current_status : RunStatus
195
+ The current status of the run.
196
+ new_status : RunStatus
197
+ The new status to transition to.
198
+
199
+ Returns
200
+ -------
201
+ bool
202
+ True if the transition is valid, False otherwise.
203
+ """
204
+ return (
205
+ current_status.status,
206
+ new_status.status,
207
+ ) in VALID_RUN_STATUS_TRANSITIONS
208
+
209
+
210
+ def has_valid_sub_status(status: RunStatus) -> bool:
211
+ """Check if the 'sub_status' field of the given status is valid.
212
+
213
+ Parameters
214
+ ----------
215
+ status : RunStatus
216
+ The status object to be checked.
217
+
218
+ Returns
219
+ -------
220
+ bool
221
+ True if the status object has a valid sub-status, False otherwise.
222
+
223
+ Notes
224
+ -----
225
+ Only an empty string (i.e., "") is considered a valid sub-status for
226
+ non-finished statuses. The sub-status of a finished status cannot be empty.
227
+ """
228
+ if status.status == Status.FINISHED:
229
+ return status.sub_status in VALID_RUN_SUB_STATUSES
230
+ return status.sub_status == ""
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower SimulationIo service."""