flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241023__py3-none-any.whl
This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/client/app.py +13 -14
- flwr/client/node_state_tests.py +7 -8
- flwr/client/{node_state.py → run_info_store.py} +3 -3
- flwr/client/supernode/app.py +6 -8
- flwr/common/constant.py +31 -3
- flwr/common/typing.py +9 -0
- flwr/server/app.py +121 -10
- flwr/server/driver/inmemory_driver.py +2 -2
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +78 -0
- flwr/server/superlink/driver/driver_grpc.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +9 -7
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +109 -19
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +59 -11
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +136 -35
- flwr/server/superlink/{state → linkstate}/utils.py +57 -1
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +15 -7
- flwr/superexec/app.py +9 -2
- flwr/superexec/simulation.py +1 -1
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/RECORD +34 -32
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241023.dist-info}/WHEEL +0 -0
|
@@ -12,19 +12,19 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""Abstract base class
|
|
15
|
+
"""Abstract base class LinkState."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
19
|
from typing import Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
-
from flwr.common.typing import Run, UserConfig
|
|
22
|
+
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
23
23
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
class
|
|
27
|
-
"""Abstract
|
|
26
|
+
class LinkState(abc.ABC): # pylint: disable=R0904
|
|
27
|
+
"""Abstract LinkState."""
|
|
28
28
|
|
|
29
29
|
@abc.abstractmethod
|
|
30
30
|
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
|
@@ -32,8 +32,8 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
32
32
|
|
|
33
33
|
Usually, the Driver API calls this to schedule instructions.
|
|
34
34
|
|
|
35
|
-
Stores the value of the `task_ins` in the state and, if successful,
|
|
36
|
-
`task_id` (UUID) of the `task_ins`. If, for any reason,
|
|
35
|
+
Stores the value of the `task_ins` in the link state and, if successful,
|
|
36
|
+
returns the `task_id` (UUID) of the `task_ins`. If, for any reason,
|
|
37
37
|
storing the `task_ins` fails, `None` is returned.
|
|
38
38
|
|
|
39
39
|
Constraints
|
|
@@ -130,11 +130,11 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
130
130
|
def create_node(
|
|
131
131
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
132
132
|
) -> int:
|
|
133
|
-
"""Create, store in state, and return `node_id`."""
|
|
133
|
+
"""Create, store in the link state, and return `node_id`."""
|
|
134
134
|
|
|
135
135
|
@abc.abstractmethod
|
|
136
136
|
def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
|
|
137
|
-
"""Remove `node_id` from state."""
|
|
137
|
+
"""Remove `node_id` from the link state."""
|
|
138
138
|
|
|
139
139
|
@abc.abstractmethod
|
|
140
140
|
def get_nodes(self, run_id: int) -> set[int]:
|
|
@@ -178,11 +178,59 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
178
178
|
- `fab_version`: The version of the FAB used in the specified run.
|
|
179
179
|
"""
|
|
180
180
|
|
|
181
|
+
@abc.abstractmethod
|
|
182
|
+
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
|
|
183
|
+
"""Retrieve the statuses for the specified runs.
|
|
184
|
+
|
|
185
|
+
Parameters
|
|
186
|
+
----------
|
|
187
|
+
run_ids : set[int]
|
|
188
|
+
A set of run identifiers for which to retrieve statuses.
|
|
189
|
+
|
|
190
|
+
Returns
|
|
191
|
+
-------
|
|
192
|
+
dict[int, RunStatus]
|
|
193
|
+
A dictionary mapping each valid run ID to its corresponding status.
|
|
194
|
+
|
|
195
|
+
Notes
|
|
196
|
+
-----
|
|
197
|
+
Only valid run IDs that exist in the State will be included in the returned
|
|
198
|
+
dictionary. If a run ID is not found, it will be omitted from the result.
|
|
199
|
+
"""
|
|
200
|
+
|
|
201
|
+
@abc.abstractmethod
|
|
202
|
+
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
|
|
203
|
+
"""Update the status of the run with the specified `run_id`.
|
|
204
|
+
|
|
205
|
+
Parameters
|
|
206
|
+
----------
|
|
207
|
+
run_id : int
|
|
208
|
+
The identifier of the run.
|
|
209
|
+
new_status : RunStatus
|
|
210
|
+
The new status to be assigned to the run.
|
|
211
|
+
|
|
212
|
+
Returns
|
|
213
|
+
-------
|
|
214
|
+
bool
|
|
215
|
+
True if the status update is successful; False otherwise.
|
|
216
|
+
"""
|
|
217
|
+
|
|
218
|
+
@abc.abstractmethod
|
|
219
|
+
def get_pending_run_id(self) -> Optional[int]:
|
|
220
|
+
"""Get the `run_id` of a run with `Status.PENDING` status.
|
|
221
|
+
|
|
222
|
+
Returns
|
|
223
|
+
-------
|
|
224
|
+
Optional[int]
|
|
225
|
+
The `run_id` of a `Run` that is pending to be started; None if
|
|
226
|
+
there is no Run pending.
|
|
227
|
+
"""
|
|
228
|
+
|
|
181
229
|
@abc.abstractmethod
|
|
182
230
|
def store_server_private_public_key(
|
|
183
231
|
self, private_key: bytes, public_key: bytes
|
|
184
232
|
) -> None:
|
|
185
|
-
"""Store `server_private_key` and `server_public_key` in state."""
|
|
233
|
+
"""Store `server_private_key` and `server_public_key` in the link state."""
|
|
186
234
|
|
|
187
235
|
@abc.abstractmethod
|
|
188
236
|
def get_server_private_key(self) -> Optional[bytes]:
|
|
@@ -194,11 +242,11 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
194
242
|
|
|
195
243
|
@abc.abstractmethod
|
|
196
244
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
197
|
-
"""Store a set of `node_public_keys` in state."""
|
|
245
|
+
"""Store a set of `node_public_keys` in the link state."""
|
|
198
246
|
|
|
199
247
|
@abc.abstractmethod
|
|
200
248
|
def store_node_public_key(self, public_key: bytes) -> None:
|
|
201
|
-
"""Store a `node_public_key` in state."""
|
|
249
|
+
"""Store a `node_public_key` in the link state."""
|
|
202
250
|
|
|
203
251
|
@abc.abstractmethod
|
|
204
252
|
def get_node_public_keys(self) -> set[bytes]:
|
|
@@ -20,13 +20,13 @@ from typing import Optional
|
|
|
20
20
|
|
|
21
21
|
from flwr.common.logger import log
|
|
22
22
|
|
|
23
|
-
from .
|
|
24
|
-
from .
|
|
25
|
-
from .
|
|
23
|
+
from .in_memory_linkstate import InMemoryLinkState
|
|
24
|
+
from .linkstate import LinkState
|
|
25
|
+
from .sqlite_linkstate import SqliteLinkState
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
class
|
|
29
|
-
"""Factory class that creates
|
|
28
|
+
class LinkStateFactory:
|
|
29
|
+
"""Factory class that creates LinkState instances.
|
|
30
30
|
|
|
31
31
|
Parameters
|
|
32
32
|
----------
|
|
@@ -39,19 +39,19 @@ class StateFactory:
|
|
|
39
39
|
|
|
40
40
|
def __init__(self, database: str) -> None:
|
|
41
41
|
self.database = database
|
|
42
|
-
self.state_instance: Optional[
|
|
42
|
+
self.state_instance: Optional[LinkState] = None
|
|
43
43
|
|
|
44
|
-
def state(self) ->
|
|
44
|
+
def state(self) -> LinkState:
|
|
45
45
|
"""Return a State instance and create it, if necessary."""
|
|
46
46
|
# InMemoryState
|
|
47
47
|
if self.database == ":flwr-in-memory-state:":
|
|
48
48
|
if self.state_instance is None:
|
|
49
|
-
self.state_instance =
|
|
49
|
+
self.state_instance = InMemoryLinkState()
|
|
50
50
|
log(DEBUG, "Using InMemoryState")
|
|
51
51
|
return self.state_instance
|
|
52
52
|
|
|
53
53
|
# SqliteState
|
|
54
|
-
state =
|
|
54
|
+
state = SqliteLinkState(self.database)
|
|
55
55
|
state.initialize()
|
|
56
56
|
log(DEBUG, "Using SqliteState")
|
|
57
57
|
return state
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""SQLite based implemenation of
|
|
15
|
+
"""SQLite based implemenation of the link state."""
|
|
16
16
|
|
|
17
17
|
# pylint: disable=too-many-lines
|
|
18
18
|
|
|
@@ -30,20 +30,23 @@ from flwr.common.constant import (
|
|
|
30
30
|
MESSAGE_TTL_TOLERANCE,
|
|
31
31
|
NODE_ID_NUM_BYTES,
|
|
32
32
|
RUN_ID_NUM_BYTES,
|
|
33
|
+
Status,
|
|
33
34
|
)
|
|
34
|
-
from flwr.common.typing import Run, UserConfig
|
|
35
|
+
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
35
36
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
36
37
|
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
|
|
37
38
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
38
39
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
39
40
|
|
|
40
|
-
from .
|
|
41
|
+
from .linkstate import LinkState
|
|
41
42
|
from .utils import (
|
|
42
43
|
convert_sint64_to_uint64,
|
|
43
44
|
convert_sint64_values_in_dict_to_uint64,
|
|
44
45
|
convert_uint64_to_sint64,
|
|
45
46
|
convert_uint64_values_in_dict_to_sint64,
|
|
46
47
|
generate_rand_int_from_bytes,
|
|
48
|
+
has_valid_sub_status,
|
|
49
|
+
is_valid_transition,
|
|
47
50
|
make_node_unavailable_taskres,
|
|
48
51
|
)
|
|
49
52
|
|
|
@@ -79,7 +82,13 @@ CREATE TABLE IF NOT EXISTS run(
|
|
|
79
82
|
fab_id TEXT,
|
|
80
83
|
fab_version TEXT,
|
|
81
84
|
fab_hash TEXT,
|
|
82
|
-
override_config TEXT
|
|
85
|
+
override_config TEXT,
|
|
86
|
+
pending_at TEXT,
|
|
87
|
+
starting_at TEXT,
|
|
88
|
+
running_at TEXT,
|
|
89
|
+
finished_at TEXT,
|
|
90
|
+
sub_status TEXT,
|
|
91
|
+
details TEXT
|
|
83
92
|
);
|
|
84
93
|
"""
|
|
85
94
|
|
|
@@ -126,14 +135,14 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
126
135
|
DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
|
|
127
136
|
|
|
128
137
|
|
|
129
|
-
class
|
|
130
|
-
"""SQLite-based
|
|
138
|
+
class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
139
|
+
"""SQLite-based LinkState implementation."""
|
|
131
140
|
|
|
132
141
|
def __init__(
|
|
133
142
|
self,
|
|
134
143
|
database_path: str,
|
|
135
144
|
) -> None:
|
|
136
|
-
"""Initialize an
|
|
145
|
+
"""Initialize an SqliteLinkState.
|
|
137
146
|
|
|
138
147
|
Parameters
|
|
139
148
|
----------
|
|
@@ -183,7 +192,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
183
192
|
) -> list[dict[str, Any]]:
|
|
184
193
|
"""Execute a SQL query."""
|
|
185
194
|
if self.conn is None:
|
|
186
|
-
raise AttributeError("
|
|
195
|
+
raise AttributeError("LinkState is not initialized.")
|
|
187
196
|
|
|
188
197
|
if data is None:
|
|
189
198
|
data = []
|
|
@@ -216,9 +225,9 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
216
225
|
|
|
217
226
|
Usually, the Driver API calls this to schedule instructions.
|
|
218
227
|
|
|
219
|
-
Stores the value of the task_ins in the state and, if successful,
|
|
220
|
-
task_id (UUID) of the task_ins. If, for any reason, storing
|
|
221
|
-
`None` is returned.
|
|
228
|
+
Stores the value of the task_ins in the link state and, if successful,
|
|
229
|
+
returns the task_id (UUID) of the task_ins. If, for any reason, storing
|
|
230
|
+
the task_ins fails, `None` is returned.
|
|
222
231
|
|
|
223
232
|
Constraints
|
|
224
233
|
-----------
|
|
@@ -645,7 +654,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
645
654
|
"""
|
|
646
655
|
|
|
647
656
|
if self.conn is None:
|
|
648
|
-
raise AttributeError("
|
|
657
|
+
raise AttributeError("LinkState not intitialized")
|
|
649
658
|
|
|
650
659
|
with self.conn:
|
|
651
660
|
self.conn.execute(query_1, data)
|
|
@@ -656,7 +665,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
656
665
|
def create_node(
|
|
657
666
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
658
667
|
) -> int:
|
|
659
|
-
"""Create, store in state, and return `node_id`."""
|
|
668
|
+
"""Create, store in the link state, and return `node_id`."""
|
|
660
669
|
# Sample a random uint64 as node_id
|
|
661
670
|
uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
|
|
662
671
|
|
|
@@ -706,7 +715,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
706
715
|
params += (public_key,) # type: ignore
|
|
707
716
|
|
|
708
717
|
if self.conn is None:
|
|
709
|
-
raise AttributeError("
|
|
718
|
+
raise AttributeError("LinkState is not initialized.")
|
|
710
719
|
|
|
711
720
|
try:
|
|
712
721
|
with self.conn:
|
|
@@ -773,26 +782,16 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
773
782
|
if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
|
|
774
783
|
query = (
|
|
775
784
|
"INSERT INTO run "
|
|
776
|
-
"(run_id, fab_id, fab_version, fab_hash, override_config
|
|
777
|
-
"
|
|
785
|
+
"(run_id, fab_id, fab_version, fab_hash, override_config, pending_at, "
|
|
786
|
+
"starting_at, running_at, finished_at, sub_status, details)"
|
|
787
|
+
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
|
778
788
|
)
|
|
779
789
|
if fab_hash:
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
self.query(
|
|
786
|
-
query,
|
|
787
|
-
(
|
|
788
|
-
sint64_run_id,
|
|
789
|
-
fab_id,
|
|
790
|
-
fab_version,
|
|
791
|
-
"",
|
|
792
|
-
json.dumps(override_config),
|
|
793
|
-
),
|
|
794
|
-
)
|
|
795
|
-
# Note: we need to return the uint64 value of the run_id
|
|
790
|
+
fab_id, fab_version = "", ""
|
|
791
|
+
override_config_json = json.dumps(override_config)
|
|
792
|
+
data = [sint64_run_id, fab_id, fab_version, fab_hash, override_config_json]
|
|
793
|
+
data += [now().isoformat(), "", "", "", "", ""]
|
|
794
|
+
self.query(query, tuple(data))
|
|
796
795
|
return uint64_run_id
|
|
797
796
|
log(ERROR, "Unexpected run creation failure.")
|
|
798
797
|
return 0
|
|
@@ -800,7 +799,7 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
800
799
|
def store_server_private_public_key(
|
|
801
800
|
self, private_key: bytes, public_key: bytes
|
|
802
801
|
) -> None:
|
|
803
|
-
"""Store `server_private_key` and `server_public_key` in state."""
|
|
802
|
+
"""Store `server_private_key` and `server_public_key` in the link state."""
|
|
804
803
|
query = "SELECT COUNT(*) FROM credential"
|
|
805
804
|
count = self.query(query)[0]["COUNT(*)"]
|
|
806
805
|
if count < 1:
|
|
@@ -833,13 +832,13 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
833
832
|
return public_key
|
|
834
833
|
|
|
835
834
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
836
|
-
"""Store a set of `node_public_keys` in state."""
|
|
835
|
+
"""Store a set of `node_public_keys` in the link state."""
|
|
837
836
|
query = "INSERT INTO public_key (public_key) VALUES (?)"
|
|
838
837
|
data = [(key,) for key in public_keys]
|
|
839
838
|
self.query(query, data)
|
|
840
839
|
|
|
841
840
|
def store_node_public_key(self, public_key: bytes) -> None:
|
|
842
|
-
"""Store a `node_public_key` in state."""
|
|
841
|
+
"""Store a `node_public_key` in the link state."""
|
|
843
842
|
query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
|
|
844
843
|
self.query(query, {"public_key": public_key})
|
|
845
844
|
|
|
@@ -868,6 +867,94 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
868
867
|
log(ERROR, "`run_id` does not exist.")
|
|
869
868
|
return None
|
|
870
869
|
|
|
870
|
+
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
|
|
871
|
+
"""Retrieve the statuses for the specified runs."""
|
|
872
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
873
|
+
sint64_run_ids = (convert_uint64_to_sint64(run_id) for run_id in set(run_ids))
|
|
874
|
+
query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
|
|
875
|
+
rows = self.query(query, tuple(sint64_run_ids))
|
|
876
|
+
|
|
877
|
+
return {
|
|
878
|
+
# Restore uint64 run IDs
|
|
879
|
+
convert_sint64_to_uint64(row["run_id"]): RunStatus(
|
|
880
|
+
status=determine_run_status(row),
|
|
881
|
+
sub_status=row["sub_status"],
|
|
882
|
+
details=row["details"],
|
|
883
|
+
)
|
|
884
|
+
for row in rows
|
|
885
|
+
}
|
|
886
|
+
|
|
887
|
+
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
|
|
888
|
+
"""Update the status of the run with the specified `run_id`."""
|
|
889
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
890
|
+
sint64_run_id = convert_uint64_to_sint64(run_id)
|
|
891
|
+
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
892
|
+
rows = self.query(query, (sint64_run_id,))
|
|
893
|
+
|
|
894
|
+
# Check if the run_id exists
|
|
895
|
+
if not rows:
|
|
896
|
+
log(ERROR, "`run_id` is invalid")
|
|
897
|
+
return False
|
|
898
|
+
|
|
899
|
+
# Check if the status transition is valid
|
|
900
|
+
row = rows[0]
|
|
901
|
+
current_status = RunStatus(
|
|
902
|
+
status=determine_run_status(row),
|
|
903
|
+
sub_status=row["sub_status"],
|
|
904
|
+
details=row["details"],
|
|
905
|
+
)
|
|
906
|
+
if not is_valid_transition(current_status, new_status):
|
|
907
|
+
log(
|
|
908
|
+
ERROR,
|
|
909
|
+
'Invalid status transition: from "%s" to "%s"',
|
|
910
|
+
current_status.status,
|
|
911
|
+
new_status.status,
|
|
912
|
+
)
|
|
913
|
+
return False
|
|
914
|
+
|
|
915
|
+
# Check if the sub-status is valid
|
|
916
|
+
if not has_valid_sub_status(current_status):
|
|
917
|
+
log(
|
|
918
|
+
ERROR,
|
|
919
|
+
'Invalid sub-status "%s" for status "%s"',
|
|
920
|
+
current_status.sub_status,
|
|
921
|
+
current_status.status,
|
|
922
|
+
)
|
|
923
|
+
return False
|
|
924
|
+
|
|
925
|
+
# Update the status
|
|
926
|
+
query = "UPDATE run SET %s= ?, sub_status = ?, details = ? "
|
|
927
|
+
query += "WHERE run_id = ?;"
|
|
928
|
+
|
|
929
|
+
timestamp_fld = ""
|
|
930
|
+
if new_status.status == Status.STARTING:
|
|
931
|
+
timestamp_fld = "starting_at"
|
|
932
|
+
elif new_status.status == Status.RUNNING:
|
|
933
|
+
timestamp_fld = "running_at"
|
|
934
|
+
elif new_status.status == Status.FINISHED:
|
|
935
|
+
timestamp_fld = "finished_at"
|
|
936
|
+
|
|
937
|
+
data = (
|
|
938
|
+
now().isoformat(),
|
|
939
|
+
new_status.sub_status,
|
|
940
|
+
new_status.details,
|
|
941
|
+
sint64_run_id,
|
|
942
|
+
)
|
|
943
|
+
self.query(query % timestamp_fld, data)
|
|
944
|
+
return True
|
|
945
|
+
|
|
946
|
+
def get_pending_run_id(self) -> Optional[int]:
|
|
947
|
+
"""Get the `run_id` of a run with `Status.PENDING` status, if any."""
|
|
948
|
+
pending_run_id = None
|
|
949
|
+
|
|
950
|
+
# Fetch all runs with unset `starting_at` (i.e. they are in PENDING status)
|
|
951
|
+
query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
|
|
952
|
+
rows = self.query(query)
|
|
953
|
+
if rows:
|
|
954
|
+
pending_run_id = convert_sint64_to_uint64(rows[0]["run_id"])
|
|
955
|
+
|
|
956
|
+
return pending_run_id
|
|
957
|
+
|
|
871
958
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
872
959
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
873
960
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
@@ -1023,3 +1110,17 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
|
1023
1110
|
),
|
|
1024
1111
|
)
|
|
1025
1112
|
return result
|
|
1113
|
+
|
|
1114
|
+
|
|
1115
|
+
def determine_run_status(row: dict[str, Any]) -> str:
|
|
1116
|
+
"""Determine the status of the run based on timestamp fields."""
|
|
1117
|
+
if row["pending_at"]:
|
|
1118
|
+
if row["starting_at"]:
|
|
1119
|
+
if row["running_at"]:
|
|
1120
|
+
if row["finished_at"]:
|
|
1121
|
+
return Status.FINISHED
|
|
1122
|
+
return Status.RUNNING
|
|
1123
|
+
return Status.STARTING
|
|
1124
|
+
return Status.PENDING
|
|
1125
|
+
run_id = convert_sint64_to_uint64(row["run_id"])
|
|
1126
|
+
raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")
|
|
@@ -21,7 +21,8 @@ from os import urandom
|
|
|
21
21
|
from uuid import uuid4
|
|
22
22
|
|
|
23
23
|
from flwr.common import log
|
|
24
|
-
from flwr.common.constant import ErrorCode
|
|
24
|
+
from flwr.common.constant import ErrorCode, Status, SubStatus
|
|
25
|
+
from flwr.common.typing import RunStatus
|
|
25
26
|
from flwr.proto.error_pb2 import Error # pylint: disable=E0611
|
|
26
27
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
27
28
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
@@ -31,6 +32,17 @@ NODE_UNAVAILABLE_ERROR_REASON = (
|
|
|
31
32
|
"It exceeds the time limit specified in its last ping."
|
|
32
33
|
)
|
|
33
34
|
|
|
35
|
+
VALID_RUN_STATUS_TRANSITIONS = {
|
|
36
|
+
(Status.PENDING, Status.STARTING),
|
|
37
|
+
(Status.STARTING, Status.RUNNING),
|
|
38
|
+
(Status.RUNNING, Status.FINISHED),
|
|
39
|
+
}
|
|
40
|
+
VALID_RUN_SUB_STATUSES = {
|
|
41
|
+
SubStatus.COMPLETED,
|
|
42
|
+
SubStatus.FAILED,
|
|
43
|
+
SubStatus.STOPPED,
|
|
44
|
+
}
|
|
45
|
+
|
|
34
46
|
|
|
35
47
|
def generate_rand_int_from_bytes(num_bytes: int) -> int:
|
|
36
48
|
"""Generate a random unsigned integer from `num_bytes` bytes."""
|
|
@@ -146,3 +158,47 @@ def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
|
|
|
146
158
|
),
|
|
147
159
|
),
|
|
148
160
|
)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def is_valid_transition(current_status: RunStatus, new_status: RunStatus) -> bool:
|
|
164
|
+
"""Check if a transition between two run statuses is valid.
|
|
165
|
+
|
|
166
|
+
Parameters
|
|
167
|
+
----------
|
|
168
|
+
current_status : RunStatus
|
|
169
|
+
The current status of the run.
|
|
170
|
+
new_status : RunStatus
|
|
171
|
+
The new status to transition to.
|
|
172
|
+
|
|
173
|
+
Returns
|
|
174
|
+
-------
|
|
175
|
+
bool
|
|
176
|
+
True if the transition is valid, False otherwise.
|
|
177
|
+
"""
|
|
178
|
+
return (
|
|
179
|
+
current_status.status,
|
|
180
|
+
new_status.status,
|
|
181
|
+
) in VALID_RUN_STATUS_TRANSITIONS
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def has_valid_sub_status(status: RunStatus) -> bool:
|
|
185
|
+
"""Check if the 'sub_status' field of the given status is valid.
|
|
186
|
+
|
|
187
|
+
Parameters
|
|
188
|
+
----------
|
|
189
|
+
status : RunStatus
|
|
190
|
+
The status object to be checked.
|
|
191
|
+
|
|
192
|
+
Returns
|
|
193
|
+
-------
|
|
194
|
+
bool
|
|
195
|
+
True if the status object has a valid sub-status, False otherwise.
|
|
196
|
+
|
|
197
|
+
Notes
|
|
198
|
+
-----
|
|
199
|
+
Only an empty string (i.e., "") is considered a valid sub-status for
|
|
200
|
+
non-finished statuses. The sub-status of a finished status cannot be empty.
|
|
201
|
+
"""
|
|
202
|
+
if status.status == Status.FINISHED:
|
|
203
|
+
return status.sub_status in VALID_RUN_SUB_STATUSES
|
|
204
|
+
return status.sub_status == ""
|
flwr/simulation/app.py
CHANGED
|
@@ -36,7 +36,7 @@ from flwr.server.history import History
|
|
|
36
36
|
from flwr.server.server import Server, init_defaults, run_fl
|
|
37
37
|
from flwr.server.server_config import ServerConfig
|
|
38
38
|
from flwr.server.strategy import Strategy
|
|
39
|
-
from flwr.server.superlink.
|
|
39
|
+
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
|
|
40
40
|
from flwr.simulation.ray_transport.ray_actor import (
|
|
41
41
|
ClientAppActor,
|
|
42
42
|
VirtualClientEngineActor,
|
|
@@ -22,7 +22,7 @@ from typing import Optional
|
|
|
22
22
|
from flwr import common
|
|
23
23
|
from flwr.client import ClientFnExt
|
|
24
24
|
from flwr.client.client_app import ClientApp
|
|
25
|
-
from flwr.client.
|
|
25
|
+
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
|
26
26
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
27
27
|
from flwr.common.constant import (
|
|
28
28
|
NUM_PARTITIONS_KEY,
|
|
@@ -65,7 +65,7 @@ class RayActorClientProxy(ClientProxy):
|
|
|
65
65
|
|
|
66
66
|
self.app_fn = _load_app
|
|
67
67
|
self.actor_pool = actor_pool
|
|
68
|
-
self.proxy_state =
|
|
68
|
+
self.proxy_state = DeprecatedRunInfoStore(
|
|
69
69
|
node_id=node_id,
|
|
70
70
|
node_config={
|
|
71
71
|
PARTITION_ID_KEY: str(partition_id),
|
|
@@ -29,23 +29,24 @@ from typing import Any, Optional
|
|
|
29
29
|
|
|
30
30
|
from flwr.cli.config_utils import load_and_validate
|
|
31
31
|
from flwr.client import ClientApp
|
|
32
|
-
from flwr.common import EventType, event, log
|
|
32
|
+
from flwr.common import EventType, event, log, now
|
|
33
33
|
from flwr.common.config import get_fused_config_from_dir, parse_config_args
|
|
34
|
-
from flwr.common.constant import RUN_ID_NUM_BYTES
|
|
34
|
+
from flwr.common.constant import RUN_ID_NUM_BYTES, Status
|
|
35
35
|
from flwr.common.logger import (
|
|
36
36
|
set_logger_propagation,
|
|
37
37
|
update_console_handler,
|
|
38
38
|
warn_deprecated_feature,
|
|
39
39
|
warn_deprecated_feature_with_example,
|
|
40
40
|
)
|
|
41
|
-
from flwr.common.typing import Run, UserConfig
|
|
41
|
+
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
42
42
|
from flwr.server.driver import Driver, InMemoryDriver
|
|
43
43
|
from flwr.server.run_serverapp import run as run_server_app
|
|
44
44
|
from flwr.server.server_app import ServerApp
|
|
45
45
|
from flwr.server.superlink.fleet import vce
|
|
46
46
|
from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig
|
|
47
|
-
from flwr.server.superlink.
|
|
48
|
-
from flwr.server.superlink.
|
|
47
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
48
|
+
from flwr.server.superlink.linkstate.in_memory_linkstate import RunRecord
|
|
49
|
+
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
|
|
49
50
|
from flwr.simulation.ray_transport.utils import (
|
|
50
51
|
enable_tf_gpu_growth as enable_gpu_growth,
|
|
51
52
|
)
|
|
@@ -389,7 +390,7 @@ def _main_loop(
|
|
|
389
390
|
) -> None:
|
|
390
391
|
"""Start ServerApp on a separate thread, then launch Simulation Engine."""
|
|
391
392
|
# Initialize StateFactory
|
|
392
|
-
state_factory =
|
|
393
|
+
state_factory = LinkStateFactory(":flwr-in-memory-state:")
|
|
393
394
|
|
|
394
395
|
f_stop = threading.Event()
|
|
395
396
|
# A Threading event to indicate if an exception was raised in the ServerApp thread
|
|
@@ -399,7 +400,14 @@ def _main_loop(
|
|
|
399
400
|
try:
|
|
400
401
|
# Register run
|
|
401
402
|
log(DEBUG, "Pre-registering run with id %s", run.run_id)
|
|
402
|
-
|
|
403
|
+
init_status = RunStatus(Status.RUNNING, "", "")
|
|
404
|
+
state_factory.state().run_ids[run.run_id] = RunRecord( # type: ignore
|
|
405
|
+
run=run,
|
|
406
|
+
status=init_status,
|
|
407
|
+
starting_at=now().isoformat(),
|
|
408
|
+
running_at=now().isoformat(),
|
|
409
|
+
finished_at="",
|
|
410
|
+
)
|
|
403
411
|
|
|
404
412
|
if server_app_run_config is None:
|
|
405
413
|
server_app_run_config = {}
|
flwr/superexec/app.py
CHANGED
|
@@ -27,6 +27,7 @@ from flwr.common.address import parse_address
|
|
|
27
27
|
from flwr.common.config import parse_config_args
|
|
28
28
|
from flwr.common.constant import EXEC_API_DEFAULT_ADDRESS
|
|
29
29
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
30
|
+
from flwr.common.logger import warn_deprecated_feature
|
|
30
31
|
from flwr.common.object_ref import load_app, validate
|
|
31
32
|
|
|
32
33
|
from .exec_grpc import run_superexec_api_grpc
|
|
@@ -37,6 +38,12 @@ def run_superexec() -> None:
|
|
|
37
38
|
"""Run Flower SuperExec."""
|
|
38
39
|
log(INFO, "Starting Flower SuperExec")
|
|
39
40
|
|
|
41
|
+
warn_deprecated_feature(
|
|
42
|
+
"Manually launching the SuperExec is deprecated. Since `flwr 1.13.0` "
|
|
43
|
+
"the executor service runs in the SuperLink. Launching it manually is not "
|
|
44
|
+
"recommended."
|
|
45
|
+
)
|
|
46
|
+
|
|
40
47
|
event(EventType.RUN_SUPEREXEC_ENTER)
|
|
41
48
|
|
|
42
49
|
args = _parse_args_run_superexec().parse_args()
|
|
@@ -54,7 +61,7 @@ def run_superexec() -> None:
|
|
|
54
61
|
# Start SuperExec API
|
|
55
62
|
superexec_server: grpc.Server = run_superexec_api_grpc(
|
|
56
63
|
address=address,
|
|
57
|
-
executor=
|
|
64
|
+
executor=load_executor(args),
|
|
58
65
|
certificates=certificates,
|
|
59
66
|
config=parse_config_args(
|
|
60
67
|
[args.executor_config] if args.executor_config else args.executor_config
|
|
@@ -163,7 +170,7 @@ def _try_obtain_certificates(
|
|
|
163
170
|
)
|
|
164
171
|
|
|
165
172
|
|
|
166
|
-
def
|
|
173
|
+
def load_executor(
|
|
167
174
|
args: argparse.Namespace,
|
|
168
175
|
) -> Executor:
|
|
169
176
|
"""Get the executor plugin."""
|
flwr/superexec/simulation.py
CHANGED
|
@@ -29,7 +29,7 @@ from flwr.common.config import unflatten_dict
|
|
|
29
29
|
from flwr.common.constant import RUN_ID_NUM_BYTES
|
|
30
30
|
from flwr.common.logger import log
|
|
31
31
|
from flwr.common.typing import UserConfig
|
|
32
|
-
from flwr.server.superlink.
|
|
32
|
+
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
|
|
33
33
|
|
|
34
34
|
from .executor import Executor, RunTracker
|
|
35
35
|
|