flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +2 -2
- flwr/cli/config_utils.py +97 -0
- flwr/cli/log.py +63 -97
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +34 -88
- flwr/client/app.py +23 -20
- flwr/client/clientapp/app.py +22 -18
- flwr/client/nodestate/__init__.py +25 -0
- flwr/client/nodestate/in_memory_nodestate.py +38 -0
- flwr/client/nodestate/nodestate.py +30 -0
- flwr/client/nodestate/nodestate_factory.py +37 -0
- flwr/client/{node_state.py → run_info_store.py} +4 -3
- flwr/client/supernode/app.py +6 -8
- flwr/common/args.py +83 -0
- flwr/common/config.py +10 -0
- flwr/common/constant.py +39 -5
- flwr/common/context.py +9 -4
- flwr/common/date.py +3 -3
- flwr/common/logger.py +108 -1
- flwr/common/object_ref.py +47 -16
- flwr/common/serde.py +24 -0
- flwr/common/telemetry.py +0 -6
- flwr/common/typing.py +10 -1
- flwr/proto/exec_pb2.py +14 -17
- flwr/proto/exec_pb2.pyi +14 -22
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +4 -1
- flwr/proto/run_pb2.py +32 -27
- flwr/proto/run_pb2.pyi +26 -0
- flwr/proto/serverappio_pb2.py +52 -0
- flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
- flwr/proto/serverappio_pb2_grpc.py +376 -0
- flwr/proto/serverappio_pb2_grpc.pyi +147 -0
- flwr/proto/simulationio_pb2.py +38 -0
- flwr/proto/simulationio_pb2.pyi +65 -0
- flwr/proto/simulationio_pb2_grpc.py +205 -0
- flwr/proto/simulationio_pb2_grpc.pyi +81 -0
- flwr/server/app.py +272 -105
- flwr/server/driver/driver.py +15 -1
- flwr/server/driver/grpc_driver.py +25 -36
- flwr/server/driver/inmemory_driver.py +6 -16
- flwr/server/run_serverapp.py +29 -23
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +214 -0
- flwr/server/strategy/aggregate.py +4 -4
- flwr/server/strategy/fedadam.py +11 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
- flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +184 -36
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +149 -19
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +306 -65
- flwr/server/superlink/{state → linkstate}/utils.py +81 -30
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +153 -0
- flwr/simulation/__init__.py +5 -1
- flwr/simulation/app.py +273 -345
- flwr/simulation/legacy_app.py +382 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +57 -131
- flwr/simulation/simulationio_connection.py +86 -0
- flwr/superexec/app.py +6 -134
- flwr/superexec/deployment.py +61 -66
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +36 -65
- flwr/superexec/executor.py +26 -7
- flwr/superexec/simulation.py +54 -107
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/METADATA +5 -4
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/RECORD +88 -69
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/entry_points.txt +2 -0
- flwr/client/node_state_tests.py +0 -66
- flwr/proto/driver_pb2.py +0 -42
- flwr/proto/driver_pb2_grpc.py +0 -239
- flwr/proto/driver_pb2_grpc.pyi +0 -94
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/WHEEL +0 -0
|
@@ -12,39 +12,51 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""SQLite based implemenation of
|
|
15
|
+
"""SQLite based implemenation of the link state."""
|
|
16
16
|
|
|
17
17
|
# pylint: disable=too-many-lines
|
|
18
18
|
|
|
19
19
|
import json
|
|
20
20
|
import re
|
|
21
21
|
import sqlite3
|
|
22
|
+
import threading
|
|
22
23
|
import time
|
|
23
24
|
from collections.abc import Sequence
|
|
24
25
|
from logging import DEBUG, ERROR, WARNING
|
|
25
26
|
from typing import Any, Optional, Union, cast
|
|
26
27
|
from uuid import UUID, uuid4
|
|
27
28
|
|
|
28
|
-
from flwr.common import log, now
|
|
29
|
+
from flwr.common import Context, log, now
|
|
29
30
|
from flwr.common.constant import (
|
|
30
31
|
MESSAGE_TTL_TOLERANCE,
|
|
31
32
|
NODE_ID_NUM_BYTES,
|
|
32
33
|
RUN_ID_NUM_BYTES,
|
|
34
|
+
Status,
|
|
33
35
|
)
|
|
34
|
-
from flwr.common.
|
|
35
|
-
from flwr.
|
|
36
|
-
|
|
37
|
-
|
|
36
|
+
from flwr.common.record import ConfigsRecord
|
|
37
|
+
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
38
|
+
|
|
39
|
+
# pylint: disable=E0611
|
|
40
|
+
from flwr.proto.node_pb2 import Node
|
|
41
|
+
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
|
42
|
+
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
43
|
+
|
|
44
|
+
# pylint: enable=E0611
|
|
38
45
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
39
46
|
|
|
40
|
-
from .
|
|
47
|
+
from .linkstate import LinkState
|
|
41
48
|
from .utils import (
|
|
49
|
+
configsrecord_from_bytes,
|
|
50
|
+
configsrecord_to_bytes,
|
|
51
|
+
context_from_bytes,
|
|
52
|
+
context_to_bytes,
|
|
42
53
|
convert_sint64_to_uint64,
|
|
43
54
|
convert_sint64_values_in_dict_to_uint64,
|
|
44
55
|
convert_uint64_to_sint64,
|
|
45
56
|
convert_uint64_values_in_dict_to_sint64,
|
|
46
57
|
generate_rand_int_from_bytes,
|
|
47
|
-
|
|
58
|
+
has_valid_sub_status,
|
|
59
|
+
is_valid_transition,
|
|
48
60
|
)
|
|
49
61
|
|
|
50
62
|
SQL_CREATE_TABLE_NODE = """
|
|
@@ -79,7 +91,33 @@ CREATE TABLE IF NOT EXISTS run(
|
|
|
79
91
|
fab_id TEXT,
|
|
80
92
|
fab_version TEXT,
|
|
81
93
|
fab_hash TEXT,
|
|
82
|
-
override_config TEXT
|
|
94
|
+
override_config TEXT,
|
|
95
|
+
pending_at TEXT,
|
|
96
|
+
starting_at TEXT,
|
|
97
|
+
running_at TEXT,
|
|
98
|
+
finished_at TEXT,
|
|
99
|
+
sub_status TEXT,
|
|
100
|
+
details TEXT,
|
|
101
|
+
federation_options BLOB
|
|
102
|
+
);
|
|
103
|
+
"""
|
|
104
|
+
|
|
105
|
+
SQL_CREATE_TABLE_LOGS = """
|
|
106
|
+
CREATE TABLE IF NOT EXISTS logs (
|
|
107
|
+
timestamp REAL,
|
|
108
|
+
run_id INTEGER,
|
|
109
|
+
node_id INTEGER,
|
|
110
|
+
log TEXT,
|
|
111
|
+
PRIMARY KEY (timestamp, run_id, node_id),
|
|
112
|
+
FOREIGN KEY (run_id) REFERENCES run(run_id)
|
|
113
|
+
);
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
SQL_CREATE_TABLE_CONTEXT = """
|
|
117
|
+
CREATE TABLE IF NOT EXISTS context(
|
|
118
|
+
run_id INTEGER UNIQUE,
|
|
119
|
+
context BLOB,
|
|
120
|
+
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
83
121
|
);
|
|
84
122
|
"""
|
|
85
123
|
|
|
@@ -126,14 +164,14 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
126
164
|
DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
|
|
127
165
|
|
|
128
166
|
|
|
129
|
-
class
|
|
130
|
-
"""SQLite-based
|
|
167
|
+
class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
168
|
+
"""SQLite-based LinkState implementation."""
|
|
131
169
|
|
|
132
170
|
def __init__(
|
|
133
171
|
self,
|
|
134
172
|
database_path: str,
|
|
135
173
|
) -> None:
|
|
136
|
-
"""Initialize an
|
|
174
|
+
"""Initialize an SqliteLinkState.
|
|
137
175
|
|
|
138
176
|
Parameters
|
|
139
177
|
----------
|
|
@@ -143,6 +181,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
143
181
|
"""
|
|
144
182
|
self.database_path = database_path
|
|
145
183
|
self.conn: Optional[sqlite3.Connection] = None
|
|
184
|
+
self.lock = threading.RLock()
|
|
146
185
|
|
|
147
186
|
def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
|
|
148
187
|
"""Create tables if they don't exist yet.
|
|
@@ -166,6 +205,8 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
166
205
|
|
|
167
206
|
# Create each table if not exists queries
|
|
168
207
|
cur.execute(SQL_CREATE_TABLE_RUN)
|
|
208
|
+
cur.execute(SQL_CREATE_TABLE_LOGS)
|
|
209
|
+
cur.execute(SQL_CREATE_TABLE_CONTEXT)
|
|
169
210
|
cur.execute(SQL_CREATE_TABLE_TASK_INS)
|
|
170
211
|
cur.execute(SQL_CREATE_TABLE_TASK_RES)
|
|
171
212
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
@@ -183,7 +224,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
183
224
|
) -> list[dict[str, Any]]:
|
|
184
225
|
"""Execute a SQL query."""
|
|
185
226
|
if self.conn is None:
|
|
186
|
-
raise AttributeError("
|
|
227
|
+
raise AttributeError("LinkState is not initialized.")
|
|
187
228
|
|
|
188
229
|
if data is None:
|
|
189
230
|
data = []
|
|
@@ -214,11 +255,11 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
214
255
|
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
|
215
256
|
"""Store one TaskIns.
|
|
216
257
|
|
|
217
|
-
Usually, the
|
|
258
|
+
Usually, the ServerAppIo API calls this to schedule instructions.
|
|
218
259
|
|
|
219
|
-
Stores the value of the task_ins in the state and, if successful,
|
|
220
|
-
task_id (UUID) of the task_ins. If, for any reason, storing
|
|
221
|
-
`None` is returned.
|
|
260
|
+
Stores the value of the task_ins in the link state and, if successful,
|
|
261
|
+
returns the task_id (UUID) of the task_ins. If, for any reason, storing
|
|
262
|
+
the task_ins fails, `None` is returned.
|
|
222
263
|
|
|
223
264
|
Constraints
|
|
224
265
|
-----------
|
|
@@ -233,7 +274,6 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
233
274
|
if any(errors):
|
|
234
275
|
log(ERROR, errors)
|
|
235
276
|
return None
|
|
236
|
-
|
|
237
277
|
# Create task_id
|
|
238
278
|
task_id = uuid4()
|
|
239
279
|
|
|
@@ -246,16 +286,36 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
246
286
|
data[0], ["run_id", "producer_node_id", "consumer_node_id"]
|
|
247
287
|
)
|
|
248
288
|
|
|
289
|
+
# Validate run_id
|
|
290
|
+
query = "SELECT run_id FROM run WHERE run_id = ?;"
|
|
291
|
+
if not self.query(query, (data[0]["run_id"],)):
|
|
292
|
+
log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
|
|
293
|
+
return None
|
|
294
|
+
# Validate source node ID
|
|
295
|
+
if task_ins.task.producer.node_id != 0:
|
|
296
|
+
log(
|
|
297
|
+
ERROR,
|
|
298
|
+
"Invalid source node ID for TaskIns: %s",
|
|
299
|
+
task_ins.task.producer.node_id,
|
|
300
|
+
)
|
|
301
|
+
return None
|
|
302
|
+
# Validate destination node ID
|
|
303
|
+
query = "SELECT node_id FROM node WHERE node_id = ?;"
|
|
304
|
+
if not task_ins.task.consumer.anonymous:
|
|
305
|
+
if not self.query(query, (data[0]["consumer_node_id"],)):
|
|
306
|
+
log(
|
|
307
|
+
ERROR,
|
|
308
|
+
"Invalid destination node ID for TaskIns: %s",
|
|
309
|
+
task_ins.task.consumer.node_id,
|
|
310
|
+
)
|
|
311
|
+
return None
|
|
312
|
+
|
|
249
313
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
250
314
|
query = f"INSERT INTO task_ins VALUES({columns});"
|
|
251
315
|
|
|
252
316
|
# Only invalid run_id can trigger IntegrityError.
|
|
253
317
|
# This may need to be changed in the future version with more integrity checks.
|
|
254
|
-
|
|
255
|
-
self.query(query, data)
|
|
256
|
-
except sqlite3.IntegrityError:
|
|
257
|
-
log(ERROR, "`run` is invalid")
|
|
258
|
-
return None
|
|
318
|
+
self.query(query, data)
|
|
259
319
|
|
|
260
320
|
return task_id
|
|
261
321
|
|
|
@@ -452,8 +512,8 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
452
512
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
453
513
|
"""Get TaskRes for task_ids.
|
|
454
514
|
|
|
455
|
-
Usually, the
|
|
456
|
-
previously scheduled.
|
|
515
|
+
Usually, the ServerAppIo API calls this method to get results for instructions
|
|
516
|
+
it has previously scheduled.
|
|
457
517
|
|
|
458
518
|
Retrieves all TaskRes for the given `task_ids` and returns and empty list if
|
|
459
519
|
none could be found.
|
|
@@ -579,20 +639,6 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
579
639
|
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
|
|
580
640
|
task_ins_rows = self.query(query, data)
|
|
581
641
|
|
|
582
|
-
# Make TaskRes containing node unavailabe error
|
|
583
|
-
for row in task_ins_rows:
|
|
584
|
-
for row in rows:
|
|
585
|
-
# Convert values from sint64 to uint64
|
|
586
|
-
convert_sint64_values_in_dict_to_uint64(
|
|
587
|
-
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
588
|
-
)
|
|
589
|
-
|
|
590
|
-
task_ins = dict_to_task_ins(row)
|
|
591
|
-
err_taskres = make_node_unavailable_taskres(
|
|
592
|
-
ref_taskins=task_ins,
|
|
593
|
-
)
|
|
594
|
-
result.append(err_taskres)
|
|
595
|
-
|
|
596
642
|
return result
|
|
597
643
|
|
|
598
644
|
def num_task_ins(self) -> int:
|
|
@@ -645,7 +691,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
645
691
|
"""
|
|
646
692
|
|
|
647
693
|
if self.conn is None:
|
|
648
|
-
raise AttributeError("
|
|
694
|
+
raise AttributeError("LinkState not intitialized")
|
|
649
695
|
|
|
650
696
|
with self.conn:
|
|
651
697
|
self.conn.execute(query_1, data)
|
|
@@ -656,7 +702,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
656
702
|
def create_node(
|
|
657
703
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
658
704
|
) -> int:
|
|
659
|
-
"""Create, store in state, and return `node_id`."""
|
|
705
|
+
"""Create, store in the link state, and return `node_id`."""
|
|
660
706
|
# Sample a random uint64 as node_id
|
|
661
707
|
uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
|
|
662
708
|
|
|
@@ -706,7 +752,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
706
752
|
params += (public_key,) # type: ignore
|
|
707
753
|
|
|
708
754
|
if self.conn is None:
|
|
709
|
-
raise AttributeError("
|
|
755
|
+
raise AttributeError("LinkState is not initialized.")
|
|
710
756
|
|
|
711
757
|
try:
|
|
712
758
|
with self.conn:
|
|
@@ -753,12 +799,14 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
753
799
|
return uint64_node_id
|
|
754
800
|
return None
|
|
755
801
|
|
|
802
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
756
803
|
def create_run(
|
|
757
804
|
self,
|
|
758
805
|
fab_id: Optional[str],
|
|
759
806
|
fab_version: Optional[str],
|
|
760
807
|
fab_hash: Optional[str],
|
|
761
808
|
override_config: UserConfig,
|
|
809
|
+
federation_options: ConfigsRecord,
|
|
762
810
|
) -> int:
|
|
763
811
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
764
812
|
# Sample a random int64 as run_id
|
|
@@ -773,26 +821,30 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
773
821
|
if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
|
|
774
822
|
query = (
|
|
775
823
|
"INSERT INTO run "
|
|
776
|
-
"(run_id, fab_id, fab_version, fab_hash, override_config
|
|
777
|
-
"
|
|
824
|
+
"(run_id, fab_id, fab_version, fab_hash, override_config, "
|
|
825
|
+
"federation_options, pending_at, starting_at, running_at, finished_at, "
|
|
826
|
+
"sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
|
778
827
|
)
|
|
779
828
|
if fab_hash:
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
829
|
+
fab_id, fab_version = "", ""
|
|
830
|
+
override_config_json = json.dumps(override_config)
|
|
831
|
+
data = [
|
|
832
|
+
sint64_run_id,
|
|
833
|
+
fab_id,
|
|
834
|
+
fab_version,
|
|
835
|
+
fab_hash,
|
|
836
|
+
override_config_json,
|
|
837
|
+
configsrecord_to_bytes(federation_options),
|
|
838
|
+
]
|
|
839
|
+
data += [
|
|
840
|
+
now().isoformat(),
|
|
841
|
+
"",
|
|
842
|
+
"",
|
|
843
|
+
"",
|
|
844
|
+
"",
|
|
845
|
+
"",
|
|
846
|
+
]
|
|
847
|
+
self.query(query, tuple(data))
|
|
796
848
|
return uint64_run_id
|
|
797
849
|
log(ERROR, "Unexpected run creation failure.")
|
|
798
850
|
return 0
|
|
@@ -800,7 +852,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
800
852
|
def store_server_private_public_key(
|
|
801
853
|
self, private_key: bytes, public_key: bytes
|
|
802
854
|
) -> None:
|
|
803
|
-
"""Store `server_private_key` and `server_public_key` in state."""
|
|
855
|
+
"""Store `server_private_key` and `server_public_key` in the link state."""
|
|
804
856
|
query = "SELECT COUNT(*) FROM credential"
|
|
805
857
|
count = self.query(query)[0]["COUNT(*)"]
|
|
806
858
|
if count < 1:
|
|
@@ -833,13 +885,13 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
833
885
|
return public_key
|
|
834
886
|
|
|
835
887
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
836
|
-
"""Store a set of `node_public_keys` in state."""
|
|
888
|
+
"""Store a set of `node_public_keys` in the link state."""
|
|
837
889
|
query = "INSERT INTO public_key (public_key) VALUES (?)"
|
|
838
890
|
data = [(key,) for key in public_keys]
|
|
839
891
|
self.query(query, data)
|
|
840
892
|
|
|
841
893
|
def store_node_public_key(self, public_key: bytes) -> None:
|
|
842
|
-
"""Store a `node_public_key` in state."""
|
|
894
|
+
"""Store a `node_public_key` in the link state."""
|
|
843
895
|
query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
|
|
844
896
|
self.query(query, {"public_key": public_key})
|
|
845
897
|
|
|
@@ -850,6 +902,12 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
850
902
|
result: set[bytes] = {row["public_key"] for row in rows}
|
|
851
903
|
return result
|
|
852
904
|
|
|
905
|
+
def get_run_ids(self) -> set[int]:
    """Return the IDs of all runs stored in the `run` table.

    Returns
    -------
    set[int]
        Every stored `run_id`, converted from the sint64 representation
        kept in SQLite back to uint64.
    """
    run_ids: set[int] = set()
    for record in self.query("SELECT run_id FROM run;"):
        run_ids.add(convert_sint64_to_uint64(record["run_id"]))
    return run_ids
|
|
910
|
+
|
|
853
911
|
def get_run(self, run_id: int) -> Optional[Run]:
|
|
854
912
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
855
913
|
# Convert the uint64 value to sint64 for SQLite
|
|
@@ -868,6 +926,109 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
868
926
|
log(ERROR, "`run_id` does not exist.")
|
|
869
927
|
return None
|
|
870
928
|
|
|
929
|
+
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Retrieve the statuses for the specified runs.

    Parameters
    ----------
    run_ids : set[int]
        The uint64 IDs of the runs to look up.

    Returns
    -------
    dict[int, RunStatus]
        A mapping from uint64 run ID to its status. Run IDs that have no row
        in the `run` table are absent from the mapping.
    """
    # Convert the uint64 values to sint64 for SQLite. Deduplicate once and use
    # the same collection for both the placeholders and the bound values, so
    # their counts can never disagree (even if a caller passes duplicates).
    sint64_run_ids = tuple({convert_uint64_to_sint64(run_id) for run_id in run_ids})
    if not sint64_run_ids:
        # Nothing to look up; avoid issuing a query with an empty IN list
        return {}

    placeholders = ",".join(["?"] * len(sint64_run_ids))
    query = f"SELECT * FROM run WHERE run_id IN ({placeholders});"
    rows = self.query(query, sint64_run_ids)

    return {
        # Restore uint64 run IDs
        convert_sint64_to_uint64(row["run_id"]): RunStatus(
            status=determine_run_status(row),
            sub_status=row["sub_status"],
            details=row["details"],
        )
        for row in rows
    }
|
|
945
|
+
|
|
946
|
+
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Update the status of the run with the specified `run_id`.

    Parameters
    ----------
    run_id : int
        The uint64 ID of the run to update.
    new_status : RunStatus
        The status (with sub-status and details) to store for the run.

    Returns
    -------
    bool
        `True` if the row was updated. `False` (after logging an error) if the
        run does not exist, the transition from the stored status is not
        valid, the new sub-status is not valid for the new status, or the new
        status has no associated timestamp column.
    """
    # Convert the uint64 value to sint64 for SQLite
    sint64_run_id = convert_uint64_to_sint64(run_id)
    rows = self.query("SELECT * FROM run WHERE run_id = ?;", (sint64_run_id,))

    # Check if the run_id exists
    if not rows:
        log(ERROR, "`run_id` is invalid")
        return False

    # Check if the status transition is valid
    row = rows[0]
    current_status = RunStatus(
        status=determine_run_status(row),
        sub_status=row["sub_status"],
        details=row["details"],
    )
    if not is_valid_transition(current_status, new_status):
        log(
            ERROR,
            'Invalid status transition: from "%s" to "%s"',
            current_status.status,
            new_status.status,
        )
        return False

    # Check if the sub-status is valid. It is the *new* status that is about
    # to be written, so that is the one that must be validated.
    if not has_valid_sub_status(new_status):
        log(
            ERROR,
            'Invalid sub-status "%s" for status "%s"',
            new_status.sub_status,
            new_status.status,
        )
        return False

    # Each status that can be entered stamps its own timestamp column
    timestamp_columns = {
        Status.STARTING: "starting_at",
        Status.RUNNING: "running_at",
        Status.FINISHED: "finished_at",
    }
    timestamp_fld = timestamp_columns.get(new_status.status)
    if timestamp_fld is None:
        # Without a column name the UPDATE statement would be malformed.
        # NOTE(review): presumably `is_valid_transition` already rejects such
        # targets (e.g. PENDING) — confirm; this guard is defensive.
        log(ERROR, 'No timestamp column for status "%s"', new_status.status)
        return False

    # The column name comes from the fixed mapping above, never from input,
    # so interpolating it into the statement is safe.
    query = (
        f"UPDATE run SET {timestamp_fld} = ?, sub_status = ?, details = ? "
        "WHERE run_id = ?;"
    )
    data = (
        now().isoformat(),
        new_status.sub_status,
        new_status.details,
        sint64_run_id,
    )
    self.query(query, data)
    return True
|
|
1004
|
+
|
|
1005
|
+
def get_pending_run_id(self) -> Optional[int]:
    """Return the `run_id` of one run that has not started yet, if any.

    A run is selected when its `starting_at` column is still the empty
    string; at most one row is fetched.

    Returns
    -------
    Optional[int]
        The uint64 run ID of such a run, or `None` when no row matches.
    """
    # NOTE(review): only `starting_at` is inspected here — confirm whether
    # rows with `finished_at` set should be excluded as well.
    candidates = self.query("SELECT * FROM run WHERE starting_at = '' LIMIT 1;")
    if not candidates:
        return None
    return convert_sint64_to_uint64(candidates[0]["run_id"])
|
|
1016
|
+
|
|
1017
|
+
def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
    """Look up the federation options stored for `run_id`.

    Parameters
    ----------
    run_id : int
        The uint64 ID of the run.

    Returns
    -------
    Optional[ConfigsRecord]
        The record deserialized from the `federation_options` column, or
        `None` (after logging an error) when the run has no row.
    """
    # SQLite stores the ID as sint64
    matches = self.query(
        "SELECT federation_options FROM run WHERE run_id = ?;",
        (convert_uint64_to_sint64(run_id),),
    )

    if matches:
        return configsrecord_from_bytes(matches[0]["federation_options"])

    # No row for this run_id
    log(ERROR, "`run_id` is invalid")
    return None
|
|
1031
|
+
|
|
871
1032
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
872
1033
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
873
1034
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
@@ -883,6 +1044,72 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
883
1044
|
log(ERROR, "`node_id` does not exist.")
|
|
884
1045
|
return False
|
|
885
1046
|
|
|
1047
|
+
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
    """Return the context stored for `run_id`, if there is one.

    Parameters
    ----------
    run_id : int
        The uint64 ID of the run.

    Returns
    -------
    Optional[Context]
        The context deserialized from the `context` table, or `None` when no
        row exists for the run.
    """
    sint64_run_id = convert_uint64_to_sint64(run_id)
    matches = self.query(
        "SELECT context FROM context WHERE run_id = ?;", (sint64_run_id,)
    )
    if not matches:
        return None
    return context_from_bytes(matches[0]["context"])
|
|
1054
|
+
|
|
1055
|
+
def set_serverapp_context(self, run_id: int, context: Context) -> None:
    """Store `context` for `run_id`, replacing any previously stored one.

    Parameters
    ----------
    run_id : int
        The uint64 ID of the run.
    context : Context
        The context to serialize and persist.

    Raises
    ------
    ValueError
        If inserting a new row fails with `sqlite3.IntegrityError`.
    """
    serialized = context_to_bytes(context)
    sint64_id = convert_uint64_to_sint64(run_id)

    # An existing row is overwritten in place
    existing = self.query(
        "SELECT COUNT(*) FROM context WHERE run_id = ?;", (sint64_id,)
    )
    if existing[0]["COUNT(*)"] > 0:
        self.query(
            "UPDATE context SET context = ? WHERE run_id = ?;",
            (serialized, sint64_id),
        )
        return

    # Otherwise create the row; an integrity failure is reported as an
    # unknown run.
    # NOTE(review): this relies on the foreign-key constraint being enforced
    # on the connection — confirm.
    try:
        self.query(
            "INSERT INTO context (run_id, context) VALUES (?, ?);",
            (sint64_id, serialized),
        )
    except sqlite3.IntegrityError:
        raise ValueError(f"Run {run_id} not found") from None
|
|
1074
|
+
|
|
1075
|
+
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
    """Append `log_message` to the ServerApp logs of `run_id`.

    The entry is stored with the current timestamp and node ID 0.

    Parameters
    ----------
    run_id : int
        The uint64 ID of the run the message belongs to.
    log_message : str
        The text to store.

    Raises
    ------
    ValueError
        If the insert fails with `sqlite3.IntegrityError`.
    """
    # SQLite stores the ID as sint64
    sint64_run_id = convert_uint64_to_sint64(run_id)
    insert_stmt = """
        INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?);
    """

    # NOTE(review): any IntegrityError is reported as "run not found"; a
    # primary-key collision on (timestamp, run_id, node_id) would presumably
    # surface the same way — confirm.
    try:
        entry = (now().timestamp(), sint64_run_id, 0, log_message)
        self.query(insert_stmt, entry)
    except sqlite3.IntegrityError:
        raise ValueError(f"Run {run_id} not found") from None
|
|
1088
|
+
|
|
1089
|
+
def get_serverapp_log(
    self, run_id: int, after_timestamp: Optional[float]
) -> tuple[str, float]:
    """Fetch the ServerApp log text recorded for `run_id`.

    Parameters
    ----------
    run_id : int
        The uint64 ID of the run.
    after_timestamp : Optional[float]
        Only entries with a strictly greater timestamp are returned; `None`
        is treated as 0.0.

    Returns
    -------
    tuple[str, float]
        The matching log entries concatenated in ascending timestamp order,
        and the timestamp of the newest returned entry (0.0 if there is none).

    Raises
    ------
    ValueError
        If no run with `run_id` exists.
    """
    # SQLite stores the ID as sint64
    sint64_run_id = convert_uint64_to_sint64(run_id)

    # Refuse unknown runs up front
    if not self.query("SELECT run_id FROM run WHERE run_id = ?;", (sint64_run_id,)):
        raise ValueError(f"Run {run_id} not found")

    threshold = 0.0 if after_timestamp is None else after_timestamp
    select_stmt = """
        SELECT log, timestamp FROM logs
        WHERE run_id = ? AND node_id = ? AND timestamp > ?;
    """
    # Node ID 0 is the one used when ServerApp entries are written
    entries = sorted(
        self.query(select_stmt, (sint64_run_id, 0, threshold)),
        key=lambda entry: entry["timestamp"],
    )
    if not entries:
        return "", 0.0
    return "".join(entry["log"] for entry in entries), entries[-1]["timestamp"]
|
|
1112
|
+
|
|
886
1113
|
def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
|
|
887
1114
|
"""Check if the TaskIns exists and is valid (not expired).
|
|
888
1115
|
|
|
@@ -967,7 +1194,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
|
|
|
967
1194
|
|
|
968
1195
|
def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
969
1196
|
"""Turn task_dict into protobuf message."""
|
|
970
|
-
recordset =
|
|
1197
|
+
recordset = ProtoRecordSet()
|
|
971
1198
|
recordset.ParseFromString(task_dict["recordset"])
|
|
972
1199
|
|
|
973
1200
|
result = TaskIns(
|
|
@@ -997,7 +1224,7 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
|
997
1224
|
|
|
998
1225
|
def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
999
1226
|
"""Turn task_dict into protobuf message."""
|
|
1000
|
-
recordset =
|
|
1227
|
+
recordset = ProtoRecordSet()
|
|
1001
1228
|
recordset.ParseFromString(task_dict["recordset"])
|
|
1002
1229
|
|
|
1003
1230
|
result = TaskRes(
|
|
@@ -1023,3 +1250,17 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
|
1023
1250
|
),
|
|
1024
1251
|
)
|
|
1025
1252
|
return result
|
|
1253
|
+
|
|
1254
|
+
|
|
1255
|
+
def determine_run_status(row: dict[str, Any]) -> str:
    """Determine the status of the run based on timestamp fields.

    Parameters
    ----------
    row : dict[str, Any]
        A row of the `run` table; the `pending_at`, `starting_at`,
        `running_at` and `finished_at` fields are inspected (and `run_id`
        when reporting an error).

    Returns
    -------
    str
        One of `Status.PENDING`, `Status.STARTING`, `Status.RUNNING` or
        `Status.FINISHED`.

    Raises
    ------
    sqlite3.IntegrityError
        If `pending_at` is not set, i.e. the row has no valid status.
    """
    if not row["pending_at"]:
        run_id = convert_sint64_to_uint64(row["run_id"])
        raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")

    # A set `finished_at` takes precedence: a run that ended before it ever
    # reached STARTING/RUNNING must still be reported as finished rather than
    # as pending or starting.
    if row["finished_at"]:
        return Status.FINISHED
    if not row["starting_at"]:
        return Status.PENDING
    if not row["running_at"]:
        return Status.STARTING
    return Status.RUNNING
|