flwr-nightly 1.13.0.dev20241019__py3-none-any.whl → 1.13.0.dev20241106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +2 -2
- flwr/cli/config_utils.py +97 -0
- flwr/cli/log.py +63 -97
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +18 -83
- flwr/client/app.py +13 -14
- flwr/client/clientapp/app.py +1 -2
- flwr/client/{node_state.py → run_info_store.py} +4 -3
- flwr/client/supernode/app.py +6 -8
- flwr/common/constant.py +39 -4
- flwr/common/context.py +9 -4
- flwr/common/date.py +3 -3
- flwr/common/logger.py +103 -0
- flwr/common/serde.py +24 -0
- flwr/common/telemetry.py +0 -6
- flwr/common/typing.py +9 -0
- flwr/proto/exec_pb2.py +6 -6
- flwr/proto/exec_pb2.pyi +8 -2
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +52 -0
- flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
- flwr/proto/serverappio_pb2_grpc.py +376 -0
- flwr/proto/serverappio_pb2_grpc.pyi +147 -0
- flwr/proto/simulationio_pb2.py +38 -0
- flwr/proto/simulationio_pb2.pyi +65 -0
- flwr/proto/simulationio_pb2_grpc.py +171 -0
- flwr/proto/simulationio_pb2_grpc.pyi +68 -0
- flwr/server/app.py +247 -105
- flwr/server/driver/driver.py +15 -1
- flwr/server/driver/grpc_driver.py +26 -33
- flwr/server/driver/inmemory_driver.py +6 -14
- flwr/server/run_serverapp.py +29 -23
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +270 -0
- flwr/server/strategy/fedadam.py +11 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
- flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +180 -21
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +144 -15
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +300 -50
- flwr/server/superlink/{state → linkstate}/utils.py +84 -2
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +132 -0
- flwr/simulation/__init__.py +2 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +57 -131
- flwr/simulation/simulationio_connection.py +86 -0
- flwr/superexec/app.py +6 -134
- flwr/superexec/deployment.py +60 -65
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +34 -63
- flwr/superexec/executor.py +22 -4
- flwr/superexec/simulation.py +13 -8
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/RECORD +77 -64
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/entry_points.txt +1 -0
- flwr/client/node_state_tests.py +0 -66
- flwr/proto/driver_pb2.py +0 -42
- flwr/proto/driver_pb2_grpc.py +0 -239
- flwr/proto/driver_pb2_grpc.pyi +0 -94
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241019.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/WHEEL +0 -0
|
@@ -12,38 +12,51 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""SQLite based implemenation of
|
|
15
|
+
"""SQLite based implementation of the link state."""
|
|
16
16
|
|
|
17
17
|
# pylint: disable=too-many-lines
|
|
18
18
|
|
|
19
19
|
import json
|
|
20
20
|
import re
|
|
21
21
|
import sqlite3
|
|
22
|
+
import threading
|
|
22
23
|
import time
|
|
23
24
|
from collections.abc import Sequence
|
|
24
25
|
from logging import DEBUG, ERROR, WARNING
|
|
25
26
|
from typing import Any, Optional, Union, cast
|
|
26
27
|
from uuid import UUID, uuid4
|
|
27
28
|
|
|
28
|
-
from flwr.common import log, now
|
|
29
|
+
from flwr.common import Context, log, now
|
|
29
30
|
from flwr.common.constant import (
|
|
30
31
|
MESSAGE_TTL_TOLERANCE,
|
|
31
32
|
NODE_ID_NUM_BYTES,
|
|
32
33
|
RUN_ID_NUM_BYTES,
|
|
34
|
+
Status,
|
|
33
35
|
)
|
|
34
|
-
from flwr.common.
|
|
35
|
-
from flwr.
|
|
36
|
-
|
|
37
|
-
|
|
36
|
+
from flwr.common.record import ConfigsRecord
|
|
37
|
+
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
38
|
+
|
|
39
|
+
# pylint: disable=E0611
|
|
40
|
+
from flwr.proto.node_pb2 import Node
|
|
41
|
+
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
|
42
|
+
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
43
|
+
|
|
44
|
+
# pylint: enable=E0611
|
|
38
45
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
39
46
|
|
|
40
|
-
from .
|
|
47
|
+
from .linkstate import LinkState
|
|
41
48
|
from .utils import (
|
|
49
|
+
configsrecord_from_bytes,
|
|
50
|
+
configsrecord_to_bytes,
|
|
51
|
+
context_from_bytes,
|
|
52
|
+
context_to_bytes,
|
|
42
53
|
convert_sint64_to_uint64,
|
|
43
54
|
convert_sint64_values_in_dict_to_uint64,
|
|
44
55
|
convert_uint64_to_sint64,
|
|
45
56
|
convert_uint64_values_in_dict_to_sint64,
|
|
46
57
|
generate_rand_int_from_bytes,
|
|
58
|
+
has_valid_sub_status,
|
|
59
|
+
is_valid_transition,
|
|
47
60
|
make_node_unavailable_taskres,
|
|
48
61
|
)
|
|
49
62
|
|
|
@@ -79,7 +92,33 @@ CREATE TABLE IF NOT EXISTS run(
|
|
|
79
92
|
fab_id TEXT,
|
|
80
93
|
fab_version TEXT,
|
|
81
94
|
fab_hash TEXT,
|
|
82
|
-
override_config TEXT
|
|
95
|
+
override_config TEXT,
|
|
96
|
+
pending_at TEXT,
|
|
97
|
+
starting_at TEXT,
|
|
98
|
+
running_at TEXT,
|
|
99
|
+
finished_at TEXT,
|
|
100
|
+
sub_status TEXT,
|
|
101
|
+
details TEXT,
|
|
102
|
+
federation_options BLOB
|
|
103
|
+
);
|
|
104
|
+
"""
|
|
105
|
+
|
|
106
|
+
SQL_CREATE_TABLE_LOGS = """
|
|
107
|
+
CREATE TABLE IF NOT EXISTS logs (
|
|
108
|
+
timestamp REAL,
|
|
109
|
+
run_id INTEGER,
|
|
110
|
+
node_id INTEGER,
|
|
111
|
+
log TEXT,
|
|
112
|
+
PRIMARY KEY (timestamp, run_id, node_id),
|
|
113
|
+
FOREIGN KEY (run_id) REFERENCES run(run_id)
|
|
114
|
+
);
|
|
115
|
+
"""
|
|
116
|
+
|
|
117
|
+
SQL_CREATE_TABLE_CONTEXT = """
|
|
118
|
+
CREATE TABLE IF NOT EXISTS context(
|
|
119
|
+
run_id INTEGER UNIQUE,
|
|
120
|
+
context BLOB,
|
|
121
|
+
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
83
122
|
);
|
|
84
123
|
"""
|
|
85
124
|
|
|
@@ -126,14 +165,14 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
126
165
|
DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
|
|
127
166
|
|
|
128
167
|
|
|
129
|
-
class
|
|
130
|
-
"""SQLite-based
|
|
168
|
+
class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
169
|
+
"""SQLite-based LinkState implementation."""
|
|
131
170
|
|
|
132
171
|
def __init__(
|
|
133
172
|
self,
|
|
134
173
|
database_path: str,
|
|
135
174
|
) -> None:
|
|
136
|
-
"""Initialize an
|
|
175
|
+
"""Initialize an SqliteLinkState.
|
|
137
176
|
|
|
138
177
|
Parameters
|
|
139
178
|
----------
|
|
@@ -143,6 +182,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
143
182
|
"""
|
|
144
183
|
self.database_path = database_path
|
|
145
184
|
self.conn: Optional[sqlite3.Connection] = None
|
|
185
|
+
self.lock = threading.RLock()
|
|
146
186
|
|
|
147
187
|
def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
|
|
148
188
|
"""Create tables if they don't exist yet.
|
|
@@ -166,6 +206,8 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
166
206
|
|
|
167
207
|
# Create each table if not exists queries
|
|
168
208
|
cur.execute(SQL_CREATE_TABLE_RUN)
|
|
209
|
+
cur.execute(SQL_CREATE_TABLE_LOGS)
|
|
210
|
+
cur.execute(SQL_CREATE_TABLE_CONTEXT)
|
|
169
211
|
cur.execute(SQL_CREATE_TABLE_TASK_INS)
|
|
170
212
|
cur.execute(SQL_CREATE_TABLE_TASK_RES)
|
|
171
213
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
@@ -183,7 +225,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
183
225
|
) -> list[dict[str, Any]]:
|
|
184
226
|
"""Execute a SQL query."""
|
|
185
227
|
if self.conn is None:
|
|
186
|
-
raise AttributeError("
|
|
228
|
+
raise AttributeError("LinkState is not initialized.")
|
|
187
229
|
|
|
188
230
|
if data is None:
|
|
189
231
|
data = []
|
|
@@ -214,11 +256,11 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
214
256
|
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
|
215
257
|
"""Store one TaskIns.
|
|
216
258
|
|
|
217
|
-
Usually, the
|
|
259
|
+
Usually, the ServerAppIo API calls this to schedule instructions.
|
|
218
260
|
|
|
219
|
-
Stores the value of the task_ins in the state and, if successful,
|
|
220
|
-
task_id (UUID) of the task_ins. If, for any reason, storing
|
|
221
|
-
`None` is returned.
|
|
261
|
+
Stores the value of the task_ins in the link state and, if successful,
|
|
262
|
+
returns the task_id (UUID) of the task_ins. If, for any reason, storing
|
|
263
|
+
the task_ins fails, `None` is returned.
|
|
222
264
|
|
|
223
265
|
Constraints
|
|
224
266
|
-----------
|
|
@@ -233,7 +275,6 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
233
275
|
if any(errors):
|
|
234
276
|
log(ERROR, errors)
|
|
235
277
|
return None
|
|
236
|
-
|
|
237
278
|
# Create task_id
|
|
238
279
|
task_id = uuid4()
|
|
239
280
|
|
|
@@ -246,16 +287,36 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
246
287
|
data[0], ["run_id", "producer_node_id", "consumer_node_id"]
|
|
247
288
|
)
|
|
248
289
|
|
|
290
|
+
# Validate run_id
|
|
291
|
+
query = "SELECT run_id FROM run WHERE run_id = ?;"
|
|
292
|
+
if not self.query(query, (data[0]["run_id"],)):
|
|
293
|
+
log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
|
|
294
|
+
return None
|
|
295
|
+
# Validate source node ID
|
|
296
|
+
if task_ins.task.producer.node_id != 0:
|
|
297
|
+
log(
|
|
298
|
+
ERROR,
|
|
299
|
+
"Invalid source node ID for TaskIns: %s",
|
|
300
|
+
task_ins.task.producer.node_id,
|
|
301
|
+
)
|
|
302
|
+
return None
|
|
303
|
+
# Validate destination node ID
|
|
304
|
+
query = "SELECT node_id FROM node WHERE node_id = ?;"
|
|
305
|
+
if not task_ins.task.consumer.anonymous:
|
|
306
|
+
if not self.query(query, (data[0]["consumer_node_id"],)):
|
|
307
|
+
log(
|
|
308
|
+
ERROR,
|
|
309
|
+
"Invalid destination node ID for TaskIns: %s",
|
|
310
|
+
task_ins.task.consumer.node_id,
|
|
311
|
+
)
|
|
312
|
+
return None
|
|
313
|
+
|
|
249
314
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
250
315
|
query = f"INSERT INTO task_ins VALUES({columns});"
|
|
251
316
|
|
|
252
317
|
# Only invalid run_id can trigger IntegrityError.
|
|
253
318
|
# This may need to be changed in the future version with more integrity checks.
|
|
254
|
-
|
|
255
|
-
self.query(query, data)
|
|
256
|
-
except sqlite3.IntegrityError:
|
|
257
|
-
log(ERROR, "`run` is invalid")
|
|
258
|
-
return None
|
|
319
|
+
self.query(query, data)
|
|
259
320
|
|
|
260
321
|
return task_id
|
|
261
322
|
|
|
@@ -452,8 +513,8 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
452
513
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
453
514
|
"""Get TaskRes for task_ids.
|
|
454
515
|
|
|
455
|
-
Usually, the
|
|
456
|
-
previously scheduled.
|
|
516
|
+
Usually, the ServerAppIo API calls this method to get results for instructions
|
|
517
|
+
it has previously scheduled.
|
|
457
518
|
|
|
458
519
|
Retrieves all TaskRes for the given `task_ids` and returns an empty list if
|
|
459
520
|
none could be found.
|
|
@@ -645,7 +706,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
645
706
|
"""
|
|
646
707
|
|
|
647
708
|
if self.conn is None:
|
|
648
|
-
raise AttributeError("
|
|
709
|
+
raise AttributeError("LinkState not intitialized")
|
|
649
710
|
|
|
650
711
|
with self.conn:
|
|
651
712
|
self.conn.execute(query_1, data)
|
|
@@ -656,7 +717,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
656
717
|
def create_node(
|
|
657
718
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
658
719
|
) -> int:
|
|
659
|
-
"""Create, store in state, and return `node_id`."""
|
|
720
|
+
"""Create, store in the link state, and return `node_id`."""
|
|
660
721
|
# Sample a random uint64 as node_id
|
|
661
722
|
uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
|
|
662
723
|
|
|
@@ -706,7 +767,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
706
767
|
params += (public_key,) # type: ignore
|
|
707
768
|
|
|
708
769
|
if self.conn is None:
|
|
709
|
-
raise AttributeError("
|
|
770
|
+
raise AttributeError("LinkState is not initialized.")
|
|
710
771
|
|
|
711
772
|
try:
|
|
712
773
|
with self.conn:
|
|
@@ -753,12 +814,14 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
753
814
|
return uint64_node_id
|
|
754
815
|
return None
|
|
755
816
|
|
|
817
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
756
818
|
def create_run(
|
|
757
819
|
self,
|
|
758
820
|
fab_id: Optional[str],
|
|
759
821
|
fab_version: Optional[str],
|
|
760
822
|
fab_hash: Optional[str],
|
|
761
823
|
override_config: UserConfig,
|
|
824
|
+
federation_options: ConfigsRecord,
|
|
762
825
|
) -> int:
|
|
763
826
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
764
827
|
# Sample a random int64 as run_id
|
|
@@ -773,26 +836,30 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
773
836
|
if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
|
|
774
837
|
query = (
|
|
775
838
|
"INSERT INTO run "
|
|
776
|
-
"(run_id, fab_id, fab_version, fab_hash, override_config
|
|
777
|
-
"
|
|
839
|
+
"(run_id, fab_id, fab_version, fab_hash, override_config, "
|
|
840
|
+
"federation_options, pending_at, starting_at, running_at, finished_at, "
|
|
841
|
+
"sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
|
778
842
|
)
|
|
779
843
|
if fab_hash:
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
844
|
+
fab_id, fab_version = "", ""
|
|
845
|
+
override_config_json = json.dumps(override_config)
|
|
846
|
+
data = [
|
|
847
|
+
sint64_run_id,
|
|
848
|
+
fab_id,
|
|
849
|
+
fab_version,
|
|
850
|
+
fab_hash,
|
|
851
|
+
override_config_json,
|
|
852
|
+
configsrecord_to_bytes(federation_options),
|
|
853
|
+
]
|
|
854
|
+
data += [
|
|
855
|
+
now().isoformat(),
|
|
856
|
+
"",
|
|
857
|
+
"",
|
|
858
|
+
"",
|
|
859
|
+
"",
|
|
860
|
+
"",
|
|
861
|
+
]
|
|
862
|
+
self.query(query, tuple(data))
|
|
796
863
|
return uint64_run_id
|
|
797
864
|
log(ERROR, "Unexpected run creation failure.")
|
|
798
865
|
return 0
|
|
@@ -800,7 +867,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
800
867
|
def store_server_private_public_key(
|
|
801
868
|
self, private_key: bytes, public_key: bytes
|
|
802
869
|
) -> None:
|
|
803
|
-
"""Store `server_private_key` and `server_public_key` in state."""
|
|
870
|
+
"""Store `server_private_key` and `server_public_key` in the link state."""
|
|
804
871
|
query = "SELECT COUNT(*) FROM credential"
|
|
805
872
|
count = self.query(query)[0]["COUNT(*)"]
|
|
806
873
|
if count < 1:
|
|
@@ -833,13 +900,13 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
833
900
|
return public_key
|
|
834
901
|
|
|
835
902
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
    """Store a set of `node_public_keys` in the link state.

    Parameters
    ----------
    public_keys : set[bytes]
        Keys to insert; each key becomes one row of the `public_key` table.
    """
    # One single-element parameter tuple per key
    rows = [(public_key,) for public_key in public_keys]
    self.query("INSERT INTO public_key (public_key) VALUES (?)", rows)
|
|
840
907
|
|
|
841
908
|
def store_node_public_key(self, public_key: bytes) -> None:
    """Store a `node_public_key` in the link state.

    Parameters
    ----------
    public_key : bytes
        Key to insert as a single row of the `public_key` table.
    """
    params = {"public_key": public_key}
    self.query("INSERT INTO public_key (public_key) VALUES (:public_key)", params)
|
|
845
912
|
|
|
@@ -868,6 +935,109 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
868
935
|
log(ERROR, "`run_id` does not exist.")
|
|
869
936
|
return None
|
|
870
937
|
|
|
938
|
+
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Retrieve the statuses for the specified runs.

    Parameters
    ----------
    run_ids : set[int]
        Run IDs (uint64 values) to look up.

    Returns
    -------
    dict[int, RunStatus]
        Mapping from run ID (restored to uint64) to its status, built from
        the rows returned by the query. Run IDs without a matching row in
        the `run` table are absent from the result.
    """
    # De-duplicate once and derive both the placeholders and the bound values
    # from the same sequence, so their counts cannot disagree.
    # Convert the uint64 values to sint64 for SQLite.
    sint64_run_ids = tuple(
        convert_uint64_to_sint64(run_id) for run_id in set(run_ids)
    )
    placeholders = ",".join(["?"] * len(sint64_run_ids))
    query = f"SELECT * FROM run WHERE run_id IN ({placeholders});"
    rows = self.query(query, sint64_run_ids)

    return {
        # Restore uint64 run IDs
        convert_sint64_to_uint64(row["run_id"]): RunStatus(
            status=determine_run_status(row),
            sub_status=row["sub_status"],
            details=row["details"],
        )
        for row in rows
    }
|
|
954
|
+
|
|
955
|
+
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Update the status of the run with the specified `run_id`.

    Parameters
    ----------
    run_id : int
        Run ID (uint64 value) of the run to update.
    new_status : RunStatus
        Status, sub-status and details to store for the run.

    Returns
    -------
    bool
        True if the run row was updated. False (after logging an error) if
        the run does not exist, if the transition from the stored status to
        `new_status` is not valid, or if the sub-status of `new_status` is
        not valid for its status.
    """
    # Convert the uint64 value to sint64 for SQLite
    sint64_run_id = convert_uint64_to_sint64(run_id)
    query = "SELECT * FROM run WHERE run_id = ?;"
    rows = self.query(query, (sint64_run_id,))

    # Check if the run_id exists
    if not rows:
        log(ERROR, "`run_id` is invalid")
        return False

    # Check if the status transition is valid
    row = rows[0]
    current_status = RunStatus(
        status=determine_run_status(row),
        sub_status=row["sub_status"],
        details=row["details"],
    )
    if not is_valid_transition(current_status, new_status):
        log(
            ERROR,
            'Invalid status transition: from "%s" to "%s"',
            current_status.status,
            new_status.status,
        )
        return False

    # Check the sub-status of the status being written. Validating the
    # stored status instead would let any sub-status through.
    if not has_valid_sub_status(new_status):
        log(
            ERROR,
            'Invalid sub-status "%s" for status "%s"',
            new_status.sub_status,
            new_status.status,
        )
        return False

    # Update the status: set the timestamp column matching the new status
    query = "UPDATE run SET %s= ?, sub_status = ?, details = ? "
    query += "WHERE run_id = ?;"

    timestamp_fld = ""
    if new_status.status == Status.STARTING:
        timestamp_fld = "starting_at"
    elif new_status.status == Status.RUNNING:
        timestamp_fld = "running_at"
    elif new_status.status == Status.FINISHED:
        timestamp_fld = "finished_at"

    data = (
        now().isoformat(),
        new_status.sub_status,
        new_status.details,
        sint64_run_id,
    )
    # Only the column name (one of the fixed literals above) is interpolated;
    # all values are bound as parameters.
    self.query(query % timestamp_fld, data)
    return True
|
|
1013
|
+
|
|
1014
|
+
def get_pending_run_id(self) -> Optional[int]:
    """Get the `run_id` of a run with `Status.PENDING` status, if any.

    Returns
    -------
    Optional[int]
        The run ID (restored to uint64) of one run whose `starting_at` is
        still the empty string, or `None` if the query returns no row.
    """
    # An unset `starting_at` means the run is still in PENDING status
    rows = self.query("SELECT * FROM run WHERE starting_at = '' LIMIT 1;")
    if not rows:
        return None
    return convert_sint64_to_uint64(rows[0]["run_id"])
|
|
1025
|
+
|
|
1026
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
    """Retrieve the federation options for the specified `run_id`.

    Returns
    -------
    Optional[ConfigsRecord]
        The record deserialized from the run's `federation_options` column,
        or `None` (after logging an error) if no run row matches.
    """
    # SQLite stores the ID as sint64, so convert before querying
    matches = self.query(
        "SELECT federation_options FROM run WHERE run_id = ?;",
        (convert_uint64_to_sint64(run_id),),
    )
    if matches:
        return configsrecord_from_bytes(matches[0]["federation_options"])

    # No row: the run_id does not exist
    log(ERROR, "`run_id` is invalid")
    return None
|
|
1040
|
+
|
|
871
1041
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
872
1042
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
873
1043
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
@@ -883,6 +1053,72 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
883
1053
|
log(ERROR, "`node_id` does not exist.")
|
|
884
1054
|
return False
|
|
885
1055
|
|
|
1056
|
+
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
    """Get the context for the specified `run_id`.

    Returns
    -------
    Optional[Context]
        The context deserialized from the `context` table, or `None` if no
        row is stored for the run.
    """
    sint64_run_id = convert_uint64_to_sint64(run_id)
    rows = self.query(
        "SELECT context FROM context WHERE run_id = ?;", (sint64_run_id,)
    )
    if not rows:
        return None
    return context_from_bytes(rows[0]["context"])
|
|
1063
|
+
|
|
1064
|
+
def set_serverapp_context(self, run_id: int, context: Context) -> None:
    """Set the context for the specified `run_id`.

    Updates the existing row of the `context` table if one is stored for the
    run, otherwise inserts a new row.

    Raises
    ------
    ValueError
        If inserting the new row fails with `sqlite3.IntegrityError`.
    """
    # Serialize first, then convert the ID to sint64 for SQLite
    serialized = context_to_bytes(context)
    sint64_run_id = convert_uint64_to_sint64(run_id)

    # Is a Context already assigned to this run?
    counted = self.query(
        "SELECT COUNT(*) FROM context WHERE run_id = ?;", (sint64_run_id,)
    )
    if counted[0]["COUNT(*)"] > 0:
        # Overwrite the stored context
        self.query(
            "UPDATE context SET context = ? WHERE run_id = ?;",
            (serialized, sint64_run_id),
        )
        return

    # First context for this run
    try:
        self.query(
            "INSERT INTO context (run_id, context) VALUES (?, ?);",
            (sint64_run_id, serialized),
        )
    except sqlite3.IntegrityError:
        raise ValueError(f"Run {run_id} not found") from None
|
|
1083
|
+
|
|
1084
|
+
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
    """Add a log entry to the ServerApp logs for the specified `run_id`.

    The entry is written to the `logs` table with the current timestamp and
    `node_id` 0.

    Raises
    ------
    ValueError
        If the insert fails with `sqlite3.IntegrityError`.
    """
    # SQLite stores the ID as sint64
    sint64_run_id = convert_uint64_to_sint64(run_id)

    insert_sql = """
        INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?);
    """
    try:
        entry = (now().timestamp(), sint64_run_id, 0, log_message)
        self.query(insert_sql, entry)
    except sqlite3.IntegrityError:
        raise ValueError(f"Run {run_id} not found") from None
|
|
1097
|
+
|
|
1098
|
+
def get_serverapp_log(
    self, run_id: int, after_timestamp: Optional[float]
) -> tuple[str, float]:
    """Get the ServerApp logs for the specified `run_id`.

    Parameters
    ----------
    run_id : int
        Run ID (uint64 value) whose logs are requested.
    after_timestamp : Optional[float]
        Only entries with a strictly greater timestamp are returned;
        `None` is treated as 0.0.

    Returns
    -------
    tuple[str, float]
        The matching log entries concatenated in ascending timestamp order,
        and the timestamp of the last one (0.0 if there is none).

    Raises
    ------
    ValueError
        If no row in the `run` table matches `run_id`.
    """
    sint64_run_id = convert_uint64_to_sint64(run_id)

    # Reject unknown runs up front
    if not self.query(
        "SELECT run_id FROM run WHERE run_id = ?;", (sint64_run_id,)
    ):
        raise ValueError(f"Run {run_id} not found")

    threshold = 0.0 if after_timestamp is None else after_timestamp
    select_sql = """
        SELECT log, timestamp FROM logs
        WHERE run_id = ? AND node_id = ? AND timestamp > ?;
    """
    # ServerApp entries are the ones stored with node_id 0
    entries = sorted(
        self.query(select_sql, (sint64_run_id, 0, threshold)),
        key=lambda entry: entry["timestamp"],
    )
    if entries:
        latest_timestamp = entries[-1]["timestamp"]
    else:
        latest_timestamp = 0.0
    return "".join(entry["log"] for entry in entries), latest_timestamp
|
|
1121
|
+
|
|
886
1122
|
def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
|
|
887
1123
|
"""Check if the TaskIns exists and is valid (not expired).
|
|
888
1124
|
|
|
@@ -967,7 +1203,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
|
|
|
967
1203
|
|
|
968
1204
|
def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
969
1205
|
"""Turn task_dict into protobuf message."""
|
|
970
|
-
recordset =
|
|
1206
|
+
recordset = ProtoRecordSet()
|
|
971
1207
|
recordset.ParseFromString(task_dict["recordset"])
|
|
972
1208
|
|
|
973
1209
|
result = TaskIns(
|
|
@@ -997,7 +1233,7 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
|
997
1233
|
|
|
998
1234
|
def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
999
1235
|
"""Turn task_dict into protobuf message."""
|
|
1000
|
-
recordset =
|
|
1236
|
+
recordset = ProtoRecordSet()
|
|
1001
1237
|
recordset.ParseFromString(task_dict["recordset"])
|
|
1002
1238
|
|
|
1003
1239
|
result = TaskRes(
|
|
@@ -1023,3 +1259,17 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
|
1023
1259
|
),
|
|
1024
1260
|
)
|
|
1025
1261
|
return result
|
|
1262
|
+
|
|
1263
|
+
|
|
1264
|
+
def determine_run_status(row: dict[str, Any]) -> str:
    """Determine the status of the run based on timestamp fields.

    The fields `pending_at`, `starting_at`, `running_at` and `finished_at`
    are inspected in that order; the status is the stage of the last field
    that is set before the first unset one.

    Raises
    ------
    sqlite3.IntegrityError
        If `pending_at` is unset, i.e. the row has no valid status.
    """
    if not row["pending_at"]:
        run_id = convert_sint64_to_uint64(row["run_id"])
        raise sqlite3.IntegrityError(
            f"The run {run_id} does not have a valid status."
        )
    if not row["starting_at"]:
        return Status.PENDING
    if not row["running_at"]:
        return Status.STARTING
    if not row["finished_at"]:
        return Status.RUNNING
    return Status.FINISHED
|
|
@@ -20,10 +20,15 @@ from logging import ERROR
|
|
|
20
20
|
from os import urandom
|
|
21
21
|
from uuid import uuid4
|
|
22
22
|
|
|
23
|
-
from flwr.common import log
|
|
24
|
-
from flwr.common.constant import ErrorCode
|
|
23
|
+
from flwr.common import ConfigsRecord, Context, log, serde
|
|
24
|
+
from flwr.common.constant import ErrorCode, Status, SubStatus
|
|
25
|
+
from flwr.common.typing import RunStatus
|
|
25
26
|
from flwr.proto.error_pb2 import Error # pylint: disable=E0611
|
|
27
|
+
from flwr.proto.message_pb2 import Context as ProtoContext # pylint: disable=E0611
|
|
26
28
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
29
|
+
|
|
30
|
+
# pylint: disable=E0611
|
|
31
|
+
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
|
|
27
32
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
28
33
|
|
|
29
34
|
NODE_UNAVAILABLE_ERROR_REASON = (
|
|
@@ -31,6 +36,17 @@ NODE_UNAVAILABLE_ERROR_REASON = (
|
|
|
31
36
|
"It exceeds the time limit specified in its last ping."
|
|
32
37
|
)
|
|
33
38
|
|
|
39
|
+
# Allowed (current status, new status) pairs; `is_valid_transition` accepts a
# transition only if its pair is listed here.
VALID_RUN_STATUS_TRANSITIONS = {
    (Status.PENDING, Status.STARTING),
    (Status.STARTING, Status.RUNNING),
    (Status.RUNNING, Status.FINISHED),
}
# Sub-statuses that `has_valid_sub_status` accepts for a `Status.FINISHED` run.
VALID_RUN_SUB_STATUSES = {
    SubStatus.COMPLETED,
    SubStatus.FAILED,
    SubStatus.STOPPED,
}
|
|
49
|
+
|
|
34
50
|
|
|
35
51
|
def generate_rand_int_from_bytes(num_bytes: int) -> int:
|
|
36
52
|
"""Generate a random unsigned integer from `num_bytes` bytes."""
|
|
@@ -123,6 +139,28 @@ def convert_sint64_values_in_dict_to_uint64(
|
|
|
123
139
|
data_dict[key] = convert_sint64_to_uint64(data_dict[key])
|
|
124
140
|
|
|
125
141
|
|
|
142
|
+
def context_to_bytes(context: Context) -> bytes:
    """Serialize `Context` to bytes via its protobuf message."""
    proto = serde.context_to_proto(context)
    return proto.SerializeToString()
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def context_from_bytes(context_bytes: bytes) -> Context:
    """Deserialize `Context` from the bytes of its protobuf message."""
    proto = ProtoContext.FromString(context_bytes)
    return serde.context_from_proto(proto)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def configsrecord_to_bytes(configs_record: ConfigsRecord) -> bytes:
    """Serialize a `ConfigsRecord` to bytes via its protobuf message."""
    proto = serde.configs_record_to_proto(configs_record)
    return proto.SerializeToString()
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def configsrecord_from_bytes(configsrecord_bytes: bytes) -> ConfigsRecord:
    """Deserialize `ConfigsRecord` from the bytes of its protobuf message."""
    proto = ProtoConfigsRecord.FromString(configsrecord_bytes)
    return serde.configs_record_from_proto(proto)
|
|
162
|
+
|
|
163
|
+
|
|
126
164
|
def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
|
|
127
165
|
"""Generate a TaskRes with a node unavailable error from a TaskIns."""
|
|
128
166
|
current_time = time.time()
|
|
@@ -146,3 +184,47 @@ def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
|
|
|
146
184
|
),
|
|
147
185
|
),
|
|
148
186
|
)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def is_valid_transition(current_status: RunStatus, new_status: RunStatus) -> bool:
    """Check if a transition between two run statuses is valid.

    Only the `status` fields are compared; sub-status and details are ignored.

    Parameters
    ----------
    current_status : RunStatus
        The current status of the run.
    new_status : RunStatus
        The new status to transition to.

    Returns
    -------
    bool
        True if the transition is valid, False otherwise.
    """
    transition = (current_status.status, new_status.status)
    return transition in VALID_RUN_STATUS_TRANSITIONS
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def has_valid_sub_status(status: RunStatus) -> bool:
    """Check if the 'sub_status' field of the given status is valid.

    Parameters
    ----------
    status : RunStatus
        The status object to be checked.

    Returns
    -------
    bool
        True if the status object has a valid sub-status, False otherwise.

    Notes
    -----
    Only an empty string (i.e., "") is considered a valid sub-status for
    non-finished statuses. The sub-status of a finished status cannot be empty.
    """
    # Non-finished statuses must not carry a sub-status
    if status.status != Status.FINISHED:
        return status.sub_status == ""
    # A finished status needs one of the known sub-statuses
    return status.sub_status in VALID_RUN_SUB_STATUSES
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower SimulationIo service."""
|