flwr-nightly 1.13.0.dev20241022__py3-none-any.whl → 1.13.0.dev20241024__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/client/app.py +2 -3
- flwr/client/supernode/app.py +6 -8
- flwr/common/constant.py +29 -0
- flwr/common/typing.py +9 -0
- flwr/proto/driver_pb2.py +24 -15
- flwr/proto/driver_pb2.pyi +59 -0
- flwr/proto/driver_pb2_grpc.py +68 -0
- flwr/proto/driver_pb2_grpc.pyi +26 -0
- flwr/server/app.py +74 -3
- flwr/server/run_serverapp.py +13 -9
- flwr/server/serverapp/app.py +59 -1
- flwr/server/superlink/driver/driver_servicer.py +16 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +113 -12
- flwr/server/superlink/linkstate/linkstate.py +78 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +173 -27
- flwr/server/superlink/linkstate/utils.py +69 -2
- flwr/simulation/run_simulation.py +23 -7
- flwr/superexec/app.py +3 -138
- flwr/superexec/deployment.py +34 -25
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +11 -1
- flwr/superexec/executor.py +19 -0
- flwr/superexec/simulation.py +8 -0
- {flwr_nightly-1.13.0.dev20241022.dist-info → flwr_nightly-1.13.0.dev20241024.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241022.dist-info → flwr_nightly-1.13.0.dev20241024.dist-info}/RECORD +28 -29
- flwr/client/node_state_tests.py +0 -65
- {flwr_nightly-1.13.0.dev20241022.dist-info → flwr_nightly-1.13.0.dev20241024.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241022.dist-info → flwr_nightly-1.13.0.dev20241024.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.13.0.dev20241022.dist-info → flwr_nightly-1.13.0.dev20241024.dist-info}/entry_points.txt +0 -0
|
@@ -17,22 +17,41 @@
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
19
|
import time
|
|
20
|
+
from dataclasses import dataclass
|
|
20
21
|
from logging import ERROR, WARNING
|
|
21
22
|
from typing import Optional
|
|
22
23
|
from uuid import UUID, uuid4
|
|
23
24
|
|
|
24
|
-
from flwr.common import log, now
|
|
25
|
+
from flwr.common import Context, log, now
|
|
25
26
|
from flwr.common.constant import (
|
|
26
27
|
MESSAGE_TTL_TOLERANCE,
|
|
27
28
|
NODE_ID_NUM_BYTES,
|
|
28
29
|
RUN_ID_NUM_BYTES,
|
|
30
|
+
Status,
|
|
29
31
|
)
|
|
30
|
-
from flwr.common.typing import Run, UserConfig
|
|
32
|
+
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
31
33
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
32
34
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
33
35
|
from flwr.server.utils import validate_task_ins_or_res
|
|
34
36
|
|
|
35
|
-
from .utils import
|
|
37
|
+
from .utils import (
|
|
38
|
+
generate_rand_int_from_bytes,
|
|
39
|
+
has_valid_sub_status,
|
|
40
|
+
is_valid_transition,
|
|
41
|
+
make_node_unavailable_taskres,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class RunRecord:
    """Bundle a run with its current status and its lifecycle timestamps.

    Every ``*_at`` field is a string that defaults to ``""``; an empty string
    means the corresponding lifecycle stage has not been recorded.
    """

    # The run itself (ID, FAB information and override config)
    run: Run
    # Current status of the run, including sub-status and details
    status: RunStatus
    # Timestamp strings, one per lifecycle stage
    pending_at: str = ""
    starting_at: str = ""
    running_at: str = ""
    finished_at: str = ""
|
|
36
55
|
|
|
37
56
|
|
|
38
57
|
class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
@@ -44,8 +63,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
44
63
|
self.node_ids: dict[int, tuple[float, float]] = {}
|
|
45
64
|
self.public_key_to_node_id: dict[bytes, int] = {}
|
|
46
65
|
|
|
47
|
-
# Map run_id to
|
|
48
|
-
self.run_ids: dict[int,
|
|
66
|
+
# Map run_id to RunRecord
|
|
67
|
+
self.run_ids: dict[int, RunRecord] = {}
|
|
68
|
+
self.contexts: dict[int, Context] = {}
|
|
49
69
|
self.task_ins_store: dict[UUID, TaskIns] = {}
|
|
50
70
|
self.task_res_store: dict[UUID, TaskRes] = {}
|
|
51
71
|
|
|
@@ -351,13 +371,22 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
351
371
|
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
352
372
|
|
|
353
373
|
if run_id not in self.run_ids:
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
374
|
+
run_record = RunRecord(
|
|
375
|
+
run=Run(
|
|
376
|
+
run_id=run_id,
|
|
377
|
+
fab_id=fab_id if fab_id else "",
|
|
378
|
+
fab_version=fab_version if fab_version else "",
|
|
379
|
+
fab_hash=fab_hash if fab_hash else "",
|
|
380
|
+
override_config=override_config,
|
|
381
|
+
),
|
|
382
|
+
status=RunStatus(
|
|
383
|
+
status=Status.PENDING,
|
|
384
|
+
sub_status="",
|
|
385
|
+
details="",
|
|
386
|
+
),
|
|
387
|
+
pending_at=now().isoformat(),
|
|
360
388
|
)
|
|
389
|
+
self.run_ids[run_id] = run_record
|
|
361
390
|
return run_id
|
|
362
391
|
log(ERROR, "Unexpected run creation failure.")
|
|
363
392
|
return 0
|
|
@@ -401,7 +430,69 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
401
430
|
if run_id not in self.run_ids:
|
|
402
431
|
log(ERROR, "`run_id` is invalid")
|
|
403
432
|
return None
|
|
404
|
-
return self.run_ids[run_id]
|
|
433
|
+
return self.run_ids[run_id].run
|
|
434
|
+
|
|
435
|
+
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Return the status of every requested run that is registered.

    Run IDs that are not present in ``self.run_ids`` are left out of the result.
    """
    statuses: dict[int, RunStatus] = {}
    with self.lock:
        for requested_id in set(run_ids):
            # Skip IDs that were never registered
            if requested_id in self.run_ids:
                statuses[requested_id] = self.run_ids[requested_id].status
    return statuses
|
|
443
|
+
|
|
444
|
+
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Update the status of the run with the specified `run_id`.

    The update is rejected, and an error is logged, when the run is unknown,
    when the transition from the current status to `new_status` is not allowed,
    or when `new_status` carries a sub-status that is not valid for its status.

    Parameters
    ----------
    run_id : int
        The identifier of the run.
    new_status : RunStatus
        The new status to be assigned to the run.

    Returns
    -------
    bool
        True if the status was updated; False otherwise.
    """
    with self.lock:
        # Check if the run_id exists
        if run_id not in self.run_ids:
            log(ERROR, "`run_id` is invalid")
            return False

        # Check if the status transition is valid
        run_record = self.run_ids[run_id]
        current_status = run_record.status
        if not is_valid_transition(current_status, new_status):
            log(
                ERROR,
                'Invalid status transition: from "%s" to "%s"',
                current_status.status,
                new_status.status,
            )
            return False

        # Check if the sub-status is valid. It is the *incoming* status that
        # must be validated: the stored one was already checked when it was
        # set, whereas an unchecked `new_status` could, for example, store a
        # finished run without any sub-status.
        if not has_valid_sub_status(new_status):
            log(
                ERROR,
                'Invalid sub-status "%s" for status "%s"',
                new_status.sub_status,
                new_status.status,
            )
            return False

        # Record the timestamp of the lifecycle stage being entered
        if new_status.status == Status.STARTING:
            run_record.starting_at = now().isoformat()
        elif new_status.status == Status.RUNNING:
            run_record.running_at = now().isoformat()
        elif new_status.status == Status.FINISHED:
            run_record.finished_at = now().isoformat()

        # Update the status
        run_record.status = new_status
        return True
|
|
483
|
+
|
|
484
|
+
def get_pending_run_id(self) -> Optional[int]:
    """Return the ID of the first registered run that is still pending.

    Runs are inspected in the iteration order of ``self.run_ids``; ``None`` is
    returned when no run has `Status.PENDING` status.
    """
    pending_ids = (
        candidate_id
        for candidate_id, record in self.run_ids.items()
        if record.status.status == Status.PENDING
    )
    # Stop at the first match, or fall back to None when there is none
    return next(pending_ids, None)
|
|
405
496
|
|
|
406
497
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
407
498
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
@@ -410,3 +501,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
410
501
|
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
411
502
|
return True
|
|
412
503
|
return False
|
|
504
|
+
|
|
505
|
+
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
    """Return the context stored for `run_id`, or None if there is none."""
    stored_context = self.contexts.get(run_id)
    return stored_context
|
|
508
|
+
|
|
509
|
+
def set_serverapp_context(self, run_id: int, context: Context) -> None:
    """Store `context` for `run_id`, replacing any previously stored context.

    Raises
    ------
    ValueError
        If `run_id` is not a registered run.
    """
    run_is_registered = run_id in self.run_ids
    if not run_is_registered:
        raise ValueError(f"Run {run_id} not found")
    self.contexts[run_id] = context
|
|
@@ -19,7 +19,8 @@ import abc
|
|
|
19
19
|
from typing import Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
-
from flwr.common
|
|
22
|
+
from flwr.common import Context
|
|
23
|
+
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
23
24
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
24
25
|
|
|
25
26
|
|
|
@@ -178,6 +179,54 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
178
179
|
- `fab_version`: The version of the FAB used in the specified run.
|
|
179
180
|
"""
|
|
180
181
|
|
|
182
|
+
@abc.abstractmethod
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Look up the current status of each requested run.

    Parameters
    ----------
    run_ids : set[int]
        Identifiers of the runs whose statuses should be returned.

    Returns
    -------
    dict[int, RunStatus]
        Mapping from each known run ID to its status.

    Notes
    -----
    Run IDs that do not exist in the LinkState are omitted from the returned
    dictionary rather than reported as errors.
    """
|
|
201
|
+
|
|
202
|
+
@abc.abstractmethod
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Assign a new status to the run identified by `run_id`.

    Parameters
    ----------
    run_id : int
        Identifier of the run to update.
    new_status : RunStatus
        Status the run should take on.

    Returns
    -------
    bool
        True when the status was updated, False when the update failed.
    """
|
|
218
|
+
|
|
219
|
+
@abc.abstractmethod
def get_pending_run_id(self) -> Optional[int]:
    """Return the `run_id` of a run that is in `Status.PENDING` status.

    Returns
    -------
    Optional[int]
        The `run_id` of a run that is waiting to be started, or None when no
        run is pending.
    """
|
|
229
|
+
|
|
181
230
|
@abc.abstractmethod
|
|
182
231
|
def store_server_private_public_key(
|
|
183
232
|
self, private_key: bytes, public_key: bytes
|
|
@@ -222,3 +271,31 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
222
271
|
is_acknowledged : bool
|
|
223
272
|
True if the ping is successfully acknowledged; otherwise, False.
|
|
224
273
|
"""
|
|
274
|
+
|
|
275
|
+
@abc.abstractmethod
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
    """Return the context associated with `run_id`.

    Parameters
    ----------
    run_id : int
        Identifier of the run whose context is requested.

    Returns
    -------
    Optional[Context]
        The context stored for `run_id`, or None when no context exists for
        that run.
    """
|
|
290
|
+
|
|
291
|
+
@abc.abstractmethod
def set_serverapp_context(self, run_id: int, context: Context) -> None:
    """Associate `context` with the run identified by `run_id`.

    Parameters
    ----------
    run_id : int
        Identifier of the run the context belongs to.
    context : Context
        The context to store for that run.
    """
|
|
@@ -19,31 +19,41 @@
|
|
|
19
19
|
import json
|
|
20
20
|
import re
|
|
21
21
|
import sqlite3
|
|
22
|
+
import threading
|
|
22
23
|
import time
|
|
23
24
|
from collections.abc import Sequence
|
|
24
25
|
from logging import DEBUG, ERROR, WARNING
|
|
25
26
|
from typing import Any, Optional, Union, cast
|
|
26
27
|
from uuid import UUID, uuid4
|
|
27
28
|
|
|
28
|
-
from flwr.common import log, now
|
|
29
|
+
from flwr.common import Context, log, now
|
|
29
30
|
from flwr.common.constant import (
|
|
30
31
|
MESSAGE_TTL_TOLERANCE,
|
|
31
32
|
NODE_ID_NUM_BYTES,
|
|
32
33
|
RUN_ID_NUM_BYTES,
|
|
34
|
+
Status,
|
|
33
35
|
)
|
|
34
|
-
from flwr.common.typing import Run, UserConfig
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
from flwr.proto.
|
|
36
|
+
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
37
|
+
|
|
38
|
+
# pylint: disable=E0611
|
|
39
|
+
from flwr.proto.node_pb2 import Node
|
|
40
|
+
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
|
41
|
+
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
42
|
+
|
|
43
|
+
# pylint: enable=E0611
|
|
38
44
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
39
45
|
|
|
40
46
|
from .linkstate import LinkState
|
|
41
47
|
from .utils import (
|
|
48
|
+
context_from_bytes,
|
|
49
|
+
context_to_bytes,
|
|
42
50
|
convert_sint64_to_uint64,
|
|
43
51
|
convert_sint64_values_in_dict_to_uint64,
|
|
44
52
|
convert_uint64_to_sint64,
|
|
45
53
|
convert_uint64_values_in_dict_to_sint64,
|
|
46
54
|
generate_rand_int_from_bytes,
|
|
55
|
+
has_valid_sub_status,
|
|
56
|
+
is_valid_transition,
|
|
47
57
|
make_node_unavailable_taskres,
|
|
48
58
|
)
|
|
49
59
|
|
|
@@ -79,7 +89,21 @@ CREATE TABLE IF NOT EXISTS run(
|
|
|
79
89
|
fab_id TEXT,
|
|
80
90
|
fab_version TEXT,
|
|
81
91
|
fab_hash TEXT,
|
|
82
|
-
override_config TEXT
|
|
92
|
+
override_config TEXT,
|
|
93
|
+
pending_at TEXT,
|
|
94
|
+
starting_at TEXT,
|
|
95
|
+
running_at TEXT,
|
|
96
|
+
finished_at TEXT,
|
|
97
|
+
sub_status TEXT,
|
|
98
|
+
details TEXT
|
|
99
|
+
);
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
SQL_CREATE_TABLE_CONTEXT = """
|
|
103
|
+
CREATE TABLE IF NOT EXISTS context(
|
|
104
|
+
run_id INTEGER UNIQUE,
|
|
105
|
+
context BLOB,
|
|
106
|
+
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
83
107
|
);
|
|
84
108
|
"""
|
|
85
109
|
|
|
@@ -133,7 +157,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
133
157
|
self,
|
|
134
158
|
database_path: str,
|
|
135
159
|
) -> None:
|
|
136
|
-
"""Initialize an
|
|
160
|
+
"""Initialize an SqliteLinkState.
|
|
137
161
|
|
|
138
162
|
Parameters
|
|
139
163
|
----------
|
|
@@ -143,6 +167,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
143
167
|
"""
|
|
144
168
|
self.database_path = database_path
|
|
145
169
|
self.conn: Optional[sqlite3.Connection] = None
|
|
170
|
+
self.lock = threading.RLock()
|
|
146
171
|
|
|
147
172
|
def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
|
|
148
173
|
"""Create tables if they don't exist yet.
|
|
@@ -166,6 +191,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
166
191
|
|
|
167
192
|
# Create each table if not exists queries
|
|
168
193
|
cur.execute(SQL_CREATE_TABLE_RUN)
|
|
194
|
+
cur.execute(SQL_CREATE_TABLE_CONTEXT)
|
|
169
195
|
cur.execute(SQL_CREATE_TABLE_TASK_INS)
|
|
170
196
|
cur.execute(SQL_CREATE_TABLE_TASK_RES)
|
|
171
197
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
@@ -773,26 +799,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
773
799
|
if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
|
|
774
800
|
query = (
|
|
775
801
|
"INSERT INTO run "
|
|
776
|
-
"(run_id, fab_id, fab_version, fab_hash, override_config
|
|
777
|
-
"
|
|
802
|
+
"(run_id, fab_id, fab_version, fab_hash, override_config, pending_at, "
|
|
803
|
+
"starting_at, running_at, finished_at, sub_status, details)"
|
|
804
|
+
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
|
778
805
|
)
|
|
779
806
|
if fab_hash:
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
self.query(
|
|
786
|
-
query,
|
|
787
|
-
(
|
|
788
|
-
sint64_run_id,
|
|
789
|
-
fab_id,
|
|
790
|
-
fab_version,
|
|
791
|
-
"",
|
|
792
|
-
json.dumps(override_config),
|
|
793
|
-
),
|
|
794
|
-
)
|
|
795
|
-
# Note: we need to return the uint64 value of the run_id
|
|
807
|
+
fab_id, fab_version = "", ""
|
|
808
|
+
override_config_json = json.dumps(override_config)
|
|
809
|
+
data = [sint64_run_id, fab_id, fab_version, fab_hash, override_config_json]
|
|
810
|
+
data += [now().isoformat(), "", "", "", "", ""]
|
|
811
|
+
self.query(query, tuple(data))
|
|
796
812
|
return uint64_run_id
|
|
797
813
|
log(ERROR, "Unexpected run creation failure.")
|
|
798
814
|
return 0
|
|
@@ -868,6 +884,94 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
868
884
|
log(ERROR, "`run_id` does not exist.")
|
|
869
885
|
return None
|
|
870
886
|
|
|
887
|
+
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Retrieve the statuses for the specified runs.

    Parameters
    ----------
    run_ids : set[int]
        Identifiers of the runs whose statuses should be returned.

    Returns
    -------
    dict[int, RunStatus]
        Mapping from run ID (as uint64) to its status. Run IDs without a
        matching row in the `run` table are omitted.
    """
    # Deduplicate once and convert the uint64 values to sint64 for SQLite.
    # Materializing the collection guarantees that the number of placeholders
    # always equals the number of bound parameters, even if the caller passes
    # an iterable with duplicates.
    sint64_run_ids = tuple(
        {convert_uint64_to_sint64(run_id) for run_id in run_ids}
    )
    if not sint64_run_ids:
        # Nothing was requested, so there is nothing to look up
        return {}

    placeholders = ",".join(["?"] * len(sint64_run_ids))
    query = f"SELECT * FROM run WHERE run_id IN ({placeholders});"
    rows = self.query(query, sint64_run_ids)

    return {
        # Restore uint64 run IDs
        convert_sint64_to_uint64(row["run_id"]): RunStatus(
            status=determine_run_status(row),
            sub_status=row["sub_status"],
            details=row["details"],
        )
        for row in rows
    }
|
|
903
|
+
|
|
904
|
+
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Update the status of the run with the specified `run_id`.

    The update is rejected, and an error is logged, when the run is unknown,
    when the transition from the current status to `new_status` is not allowed,
    or when `new_status` carries a sub-status that is not valid for its status.

    Parameters
    ----------
    run_id : int
        The identifier of the run.
    new_status : RunStatus
        The new status to be assigned to the run.

    Returns
    -------
    bool
        True if the status was updated; False otherwise.
    """
    # Convert the uint64 value to sint64 for SQLite
    sint64_run_id = convert_uint64_to_sint64(run_id)
    query = "SELECT * FROM run WHERE run_id = ?;"
    rows = self.query(query, (sint64_run_id,))

    # Check if the run_id exists
    if not rows:
        log(ERROR, "`run_id` is invalid")
        return False

    # Check if the status transition is valid
    row = rows[0]
    current_status = RunStatus(
        status=determine_run_status(row),
        sub_status=row["sub_status"],
        details=row["details"],
    )
    if not is_valid_transition(current_status, new_status):
        log(
            ERROR,
            'Invalid status transition: from "%s" to "%s"',
            current_status.status,
            new_status.status,
        )
        return False

    # Check if the sub-status is valid. It is the *incoming* status that must
    # be validated: checking the stored one would let, for example, a finished
    # run be written without any sub-status.
    if not has_valid_sub_status(new_status):
        log(
            ERROR,
            'Invalid sub-status "%s" for status "%s"',
            new_status.sub_status,
            new_status.status,
        )
        return False

    # Pick the timestamp column of the lifecycle stage being entered. The
    # column name comes from this fixed set only, never from caller input.
    timestamp_fld = ""
    if new_status.status == Status.STARTING:
        timestamp_fld = "starting_at"
    elif new_status.status == Status.RUNNING:
        timestamp_fld = "running_at"
    elif new_status.status == Status.FINISHED:
        timestamp_fld = "finished_at"

    # Update the status
    query = "UPDATE run SET %s= ?, sub_status = ?, details = ? "
    query += "WHERE run_id = ?;"
    data = (
        now().isoformat(),
        new_status.sub_status,
        new_status.details,
        sint64_run_id,
    )
    self.query(query % timestamp_fld, data)
    return True
|
|
962
|
+
|
|
963
|
+
def get_pending_run_id(self) -> Optional[int]:
    """Return the `run_id` of one run in `Status.PENDING` status, if any.

    A run counts as pending while its `starting_at` column is still the empty
    string. At most one row is fetched.
    """
    rows = self.query("SELECT * FROM run WHERE starting_at = '' LIMIT 1;")
    if not rows:
        return None
    # Restore the uint64 run ID from its sint64 storage form
    return convert_sint64_to_uint64(rows[0]["run_id"])
|
|
974
|
+
|
|
871
975
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
872
976
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
873
977
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
@@ -883,6 +987,34 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
883
987
|
log(ERROR, "`node_id` does not exist.")
|
|
884
988
|
return False
|
|
885
989
|
|
|
990
|
+
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
    """Return the context stored for `run_id`, or None if there is none."""
    sint64_run_id = convert_uint64_to_sint64(run_id)
    rows = self.query(
        "SELECT context FROM context WHERE run_id = ?;", (sint64_run_id,)
    )
    if not rows:
        return None
    # The context is stored serialized; turn it back into a `Context`
    return context_from_bytes(rows[0]["context"])
|
|
997
|
+
|
|
998
|
+
def set_serverapp_context(self, run_id: int, context: Context) -> None:
    """Store `context` for `run_id`, replacing any previously stored context.

    Raises
    ------
    ValueError
        If inserting the context fails with an SQLite integrity error.
    """
    # Serialize the context and convert the run ID for storage
    serialized_context = context_to_bytes(context)
    sint_run_id = convert_uint64_to_sint64(run_id)

    # Find out whether a context is already assigned to this run
    count_rows = self.query(
        "SELECT COUNT(*) FROM context WHERE run_id = ?;", (sint_run_id,)
    )
    if count_rows[0]["COUNT(*)"] > 0:
        # Overwrite the existing context
        self.query(
            "UPDATE context SET context = ? WHERE run_id = ?;",
            (serialized_context, sint_run_id),
        )
        return

    # No context yet: insert a new row
    try:
        self.query(
            "INSERT INTO context (run_id, context) VALUES (?, ?);",
            (sint_run_id, serialized_context),
        )
    except sqlite3.IntegrityError:
        raise ValueError(f"Run {run_id} not found") from None
|
|
1017
|
+
|
|
886
1018
|
def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
|
|
887
1019
|
"""Check if the TaskIns exists and is valid (not expired).
|
|
888
1020
|
|
|
@@ -967,7 +1099,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
|
|
|
967
1099
|
|
|
968
1100
|
def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
969
1101
|
"""Turn task_dict into protobuf message."""
|
|
970
|
-
recordset =
|
|
1102
|
+
recordset = ProtoRecordSet()
|
|
971
1103
|
recordset.ParseFromString(task_dict["recordset"])
|
|
972
1104
|
|
|
973
1105
|
result = TaskIns(
|
|
@@ -997,7 +1129,7 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
|
997
1129
|
|
|
998
1130
|
def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
999
1131
|
"""Turn task_dict into protobuf message."""
|
|
1000
|
-
recordset =
|
|
1132
|
+
recordset = ProtoRecordSet()
|
|
1001
1133
|
recordset.ParseFromString(task_dict["recordset"])
|
|
1002
1134
|
|
|
1003
1135
|
result = TaskRes(
|
|
@@ -1023,3 +1155,17 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
|
1023
1155
|
),
|
|
1024
1156
|
)
|
|
1025
1157
|
return result
|
|
1158
|
+
|
|
1159
|
+
|
|
1160
|
+
def determine_run_status(row: dict[str, Any]) -> str:
    """Derive a run's status from the timestamp columns of its row.

    The status is the latest lifecycle stage reached, checking the timestamps
    in order: `pending_at`, `starting_at`, `running_at`, `finished_at`.

    Raises
    ------
    sqlite3.IntegrityError
        If `pending_at` is unset, in which case no status can be derived.
    """
    if not row["pending_at"]:
        run_id = convert_sint64_to_uint64(row["run_id"])
        raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")
    # From here on, the first unset timestamp decides the status
    if not row["starting_at"]:
        return Status.PENDING
    if not row["running_at"]:
        return Status.STARTING
    if not row["finished_at"]:
        return Status.RUNNING
    return Status.FINISHED
|
|
@@ -20,9 +20,11 @@ from logging import ERROR
|
|
|
20
20
|
from os import urandom
|
|
21
21
|
from uuid import uuid4
|
|
22
22
|
|
|
23
|
-
from flwr.common import log
|
|
24
|
-
from flwr.common.constant import ErrorCode
|
|
23
|
+
from flwr.common import Context, log, serde
|
|
24
|
+
from flwr.common.constant import ErrorCode, Status, SubStatus
|
|
25
|
+
from flwr.common.typing import RunStatus
|
|
25
26
|
from flwr.proto.error_pb2 import Error # pylint: disable=E0611
|
|
27
|
+
from flwr.proto.message_pb2 import Context as ProtoContext # pylint: disable=E0611
|
|
26
28
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
27
29
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
28
30
|
|
|
@@ -31,6 +33,17 @@ NODE_UNAVAILABLE_ERROR_REASON = (
|
|
|
31
33
|
"It exceeds the time limit specified in its last ping."
|
|
32
34
|
)
|
|
33
35
|
|
|
36
|
+
# Allowed (current status, new status) pairs for a run
VALID_RUN_STATUS_TRANSITIONS = {
    (Status.PENDING, Status.STARTING),
    (Status.STARTING, Status.RUNNING),
    (Status.RUNNING, Status.FINISHED),
}
# Sub-statuses accepted for a run in `Status.FINISHED` status
VALID_RUN_SUB_STATUSES = {
    SubStatus.COMPLETED,
    SubStatus.FAILED,
    SubStatus.STOPPED,
}
|
|
46
|
+
|
|
34
47
|
|
|
35
48
|
def generate_rand_int_from_bytes(num_bytes: int) -> int:
|
|
36
49
|
"""Generate a random unsigned integer from `num_bytes` bytes."""
|
|
@@ -123,6 +136,16 @@ def convert_sint64_values_in_dict_to_uint64(
|
|
|
123
136
|
data_dict[key] = convert_sint64_to_uint64(data_dict[key])
|
|
124
137
|
|
|
125
138
|
|
|
139
|
+
def context_to_bytes(context: Context) -> bytes:
    """Convert `Context` to its protobuf message and serialize it to bytes."""
    proto_context = serde.context_to_proto(context)
    return proto_context.SerializeToString()
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def context_from_bytes(context_bytes: bytes) -> Context:
    """Parse bytes into a protobuf message and convert it to `Context`."""
    proto_context = ProtoContext.FromString(context_bytes)
    return serde.context_from_proto(proto_context)
|
|
147
|
+
|
|
148
|
+
|
|
126
149
|
def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
|
|
127
150
|
"""Generate a TaskRes with a node unavailable error from a TaskIns."""
|
|
128
151
|
current_time = time.time()
|
|
@@ -146,3 +169,47 @@ def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
|
|
|
146
169
|
),
|
|
147
170
|
),
|
|
148
171
|
)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def is_valid_transition(current_status: RunStatus, new_status: RunStatus) -> bool:
    """Tell whether a run may move from one status to another.

    Only the `status` fields are compared; sub-status and details are ignored.

    Parameters
    ----------
    current_status : RunStatus
        The status the run has now.
    new_status : RunStatus
        The status the run should move to.

    Returns
    -------
    bool
        True if the pair is listed in `VALID_RUN_STATUS_TRANSITIONS`, False
        otherwise.
    """
    transition = (current_status.status, new_status.status)
    return transition in VALID_RUN_STATUS_TRANSITIONS
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def has_valid_sub_status(status: RunStatus) -> bool:
    """Tell whether the `sub_status` of a status object fits its `status`.

    Parameters
    ----------
    status : RunStatus
        The status object to check.

    Returns
    -------
    bool
        True if the sub-status is valid for the status, False otherwise.

    Notes
    -----
    A finished status must carry one of `VALID_RUN_SUB_STATUSES`, so its
    sub-status cannot be empty. Every other status must carry the empty
    string (i.e., "") as its sub-status.
    """
    if status.status != Status.FINISHED:
        # Non-finished statuses must not carry a sub-status
        return status.sub_status == ""
    return status.sub_status in VALID_RUN_SUB_STATUSES
|