flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +94 -59
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/new.py +12 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +48 -49
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +38 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +15 -8
- flwr/client/grpc_rere_client/connection.py +142 -97
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +176 -103
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +39 -8
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit_code.py +16 -1
- flwr/common/exit_handlers.py +30 -0
- flwr/common/grpc.py +12 -1
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_protobuf_utils.py +141 -0
- flwr/common/inflatable_utils.py +508 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +402 -0
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -211
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +28 -185
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/appio_pb2.py +43 -0
- flwr/proto/appio_pb2.pyi +151 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +12 -19
- flwr/proto/clientappio_pb2.pyi +23 -101
- flwr/proto/clientappio_pb2_grpc.py +269 -28
- flwr/proto/clientappio_pb2_grpc.pyi +114 -20
- flwr/proto/fleet_pb2.py +24 -27
- flwr/proto/fleet_pb2.pyi +19 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +9 -23
- flwr/proto/serverappio_pb2.pyi +0 -110
- flwr/proto/serverappio_pb2_grpc.py +177 -72
- flwr/proto/serverappio_pb2_grpc.pyi +75 -33
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +69 -187
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +6 -2
- flwr/server/grid/grpc_grid.py +148 -41
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +45 -17
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
- flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
- flwr/server/superlink/utils.py +9 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +25 -0
- flwr/simulation/run_simulation.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +22 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +170 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/supercore/object_store/utils.py +43 -0
- flwr/supercore/scheduler/__init__.py +22 -0
- flwr/supercore/scheduler/plugin.py +71 -0
- flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
- flwr/superexec/deployment.py +7 -4
- flwr/superexec/exec_event_log_interceptor.py +8 -4
- flwr/superexec/exec_grpc.py +25 -5
- flwr/superexec/exec_license_interceptor.py +82 -0
- flwr/superexec/exec_servicer.py +135 -24
- flwr/superexec/exec_user_auth_interceptor.py +45 -8
- flwr/superexec/executor.py +5 -1
- flwr/superexec/simulation.py +8 -3
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/supernode/cli/__init__.py +24 -0
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
- flwr/supernode/cli/flwr_clientapp.py +88 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
- flwr/supernode/nodestate/nodestate.py +227 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
- flwr/supernode/scheduler/__init__.py +22 -0
- flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +22 -0
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
- flwr/supernode/start_client_internal.py +589 -0
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
- flwr/client/clientapp/clientappio_servicer.py +0 -244
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
|
@@ -18,19 +18,22 @@
|
|
|
18
18
|
import threading
|
|
19
19
|
import time
|
|
20
20
|
from bisect import bisect_right
|
|
21
|
+
from collections import defaultdict
|
|
21
22
|
from dataclasses import dataclass, field
|
|
22
23
|
from logging import ERROR, WARNING
|
|
23
24
|
from typing import Optional
|
|
24
|
-
from uuid import UUID, uuid4
|
|
25
25
|
|
|
26
26
|
from flwr.common import Context, Message, log, now
|
|
27
27
|
from flwr.common.constant import (
|
|
28
|
+
HEARTBEAT_MAX_INTERVAL,
|
|
29
|
+
HEARTBEAT_PATIENCE,
|
|
28
30
|
MESSAGE_TTL_TOLERANCE,
|
|
29
31
|
NODE_ID_NUM_BYTES,
|
|
30
|
-
|
|
32
|
+
RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
31
33
|
RUN_ID_NUM_BYTES,
|
|
32
34
|
SUPERLINK_NODE_ID,
|
|
33
35
|
Status,
|
|
36
|
+
SubStatus,
|
|
34
37
|
)
|
|
35
38
|
from flwr.common.record import ConfigRecord
|
|
36
39
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
@@ -52,8 +55,11 @@ class RunRecord: # pylint: disable=R0902
|
|
|
52
55
|
"""The record of a specific run, including its status and timestamps."""
|
|
53
56
|
|
|
54
57
|
run: Run
|
|
58
|
+
active_until: float = 0.0
|
|
59
|
+
heartbeat_interval: float = 0.0
|
|
55
60
|
logs: list[tuple[float, str]] = field(default_factory=list)
|
|
56
61
|
log_lock: threading.Lock = field(default_factory=threading.Lock)
|
|
62
|
+
lock: threading.RLock = field(default_factory=threading.RLock)
|
|
57
63
|
|
|
58
64
|
|
|
59
65
|
class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
@@ -61,7 +67,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
61
67
|
|
|
62
68
|
def __init__(self) -> None:
|
|
63
69
|
|
|
64
|
-
# Map node_id to (online_until,
|
|
70
|
+
# Map node_id to (online_until, heartbeat_interval)
|
|
65
71
|
self.node_ids: dict[int, tuple[float, float]] = {}
|
|
66
72
|
self.public_key_to_node_id: dict[bytes, int] = {}
|
|
67
73
|
self.node_id_to_public_key: dict[int, bytes] = {}
|
|
@@ -70,15 +76,18 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
70
76
|
self.run_ids: dict[int, RunRecord] = {}
|
|
71
77
|
self.contexts: dict[int, Context] = {}
|
|
72
78
|
self.federation_options: dict[int, ConfigRecord] = {}
|
|
73
|
-
self.message_ins_store: dict[
|
|
74
|
-
self.message_res_store: dict[
|
|
75
|
-
self.message_ins_id_to_message_res_id: dict[
|
|
79
|
+
self.message_ins_store: dict[str, Message] = {}
|
|
80
|
+
self.message_res_store: dict[str, Message] = {}
|
|
81
|
+
self.message_ins_id_to_message_res_id: dict[str, str] = {}
|
|
82
|
+
|
|
83
|
+
# Map flwr_aid to run_ids for O(1) reverse index lookup
|
|
84
|
+
self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
|
|
76
85
|
|
|
77
86
|
self.node_public_keys: set[bytes] = set()
|
|
78
87
|
|
|
79
88
|
self.lock = threading.RLock()
|
|
80
89
|
|
|
81
|
-
def store_message_ins(self, message: Message) -> Optional[
|
|
90
|
+
def store_message_ins(self, message: Message) -> Optional[str]:
|
|
82
91
|
"""Store one Message."""
|
|
83
92
|
# Validate message
|
|
84
93
|
errors = validate_message(message, is_reply_message=False)
|
|
@@ -106,12 +115,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
106
115
|
)
|
|
107
116
|
return None
|
|
108
117
|
|
|
109
|
-
|
|
110
|
-
message_id = uuid4()
|
|
111
|
-
|
|
112
|
-
# Store Message
|
|
113
|
-
# pylint: disable-next=W0212
|
|
114
|
-
message.metadata._message_id = str(message_id) # type: ignore
|
|
118
|
+
message_id = message.metadata.message_id
|
|
115
119
|
with self.lock:
|
|
116
120
|
self.message_ins_store[message_id] = message
|
|
117
121
|
|
|
@@ -147,7 +151,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
147
151
|
return message_ins_list
|
|
148
152
|
|
|
149
153
|
# pylint: disable=R0911
|
|
150
|
-
def store_message_res(self, message: Message) -> Optional[
|
|
154
|
+
def store_message_res(self, message: Message) -> Optional[str]:
|
|
151
155
|
"""Store one Message."""
|
|
152
156
|
# Validate message
|
|
153
157
|
errors = validate_message(message, is_reply_message=True)
|
|
@@ -159,7 +163,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
159
163
|
with self.lock:
|
|
160
164
|
# Check if the Message it is replying to exists and is valid
|
|
161
165
|
msg_ins_id = res_metadata.reply_to_message_id
|
|
162
|
-
msg_ins = self.message_ins_store.get(
|
|
166
|
+
msg_ins = self.message_ins_store.get(msg_ins_id)
|
|
163
167
|
|
|
164
168
|
# Ensure that dst_node_id of original Message matches the src_node_id of
|
|
165
169
|
# reply Message.
|
|
@@ -214,22 +218,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
214
218
|
log(ERROR, "`metadata.run_id` is invalid")
|
|
215
219
|
return None
|
|
216
220
|
|
|
217
|
-
|
|
218
|
-
message_id = uuid4()
|
|
219
|
-
|
|
220
|
-
# Store Message
|
|
221
|
-
# pylint: disable-next=W0212
|
|
222
|
-
message.metadata._message_id = str(message_id) # type: ignore
|
|
221
|
+
message_id = message.metadata.message_id
|
|
223
222
|
with self.lock:
|
|
224
223
|
self.message_res_store[message_id] = message
|
|
225
|
-
self.message_ins_id_to_message_res_id[
|
|
224
|
+
self.message_ins_id_to_message_res_id[msg_ins_id] = message_id
|
|
226
225
|
|
|
227
226
|
# Return the new message_id
|
|
228
227
|
return message_id
|
|
229
228
|
|
|
230
|
-
def get_message_res(self, message_ids: set[
|
|
229
|
+
def get_message_res(self, message_ids: set[str]) -> list[Message]:
|
|
231
230
|
"""Get reply Messages for the given Message IDs."""
|
|
232
|
-
ret: dict[
|
|
231
|
+
ret: dict[str, Message] = {}
|
|
233
232
|
|
|
234
233
|
with self.lock:
|
|
235
234
|
current = time.time()
|
|
@@ -250,7 +249,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
250
249
|
inquired_in_message_ids=message_ids,
|
|
251
250
|
found_in_message_dict=self.message_ins_store,
|
|
252
251
|
node_id_to_online_until={
|
|
253
|
-
node_id: self.node_ids[node_id][0]
|
|
252
|
+
node_id: self.node_ids[node_id][0]
|
|
253
|
+
for node_id in dst_node_ids
|
|
254
|
+
if node_id in self.node_ids
|
|
254
255
|
},
|
|
255
256
|
current_time=current,
|
|
256
257
|
)
|
|
@@ -281,7 +282,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
281
282
|
|
|
282
283
|
return list(ret.values())
|
|
283
284
|
|
|
284
|
-
def delete_messages(self, message_ins_ids: set[
|
|
285
|
+
def delete_messages(self, message_ins_ids: set[str]) -> None:
|
|
285
286
|
"""Delete a Message and its reply based on provided Message IDs."""
|
|
286
287
|
if not message_ins_ids:
|
|
287
288
|
return
|
|
@@ -298,9 +299,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
298
299
|
)
|
|
299
300
|
del self.message_res_store[message_res_id]
|
|
300
301
|
|
|
301
|
-
def get_message_ids_from_run_id(self, run_id: int) -> set[
|
|
302
|
+
def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
|
|
302
303
|
"""Get all instruction Message IDs for the given run_id."""
|
|
303
|
-
message_id_list: set[
|
|
304
|
+
message_id_list: set[str] = set()
|
|
304
305
|
with self.lock:
|
|
305
306
|
for message_id, message in self.message_ins_store.items():
|
|
306
307
|
if message.metadata.run_id == run_id:
|
|
@@ -322,7 +323,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
322
323
|
"""
|
|
323
324
|
return len(self.message_res_store)
|
|
324
325
|
|
|
325
|
-
def create_node(self,
|
|
326
|
+
def create_node(self, heartbeat_interval: float) -> int:
|
|
326
327
|
"""Create, store in the link state, and return `node_id`."""
|
|
327
328
|
# Sample a random int64 as node_id
|
|
328
329
|
node_id = generate_rand_int_from_bytes(
|
|
@@ -334,8 +335,11 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
334
335
|
log(ERROR, "Unexpected node registration failure.")
|
|
335
336
|
return 0
|
|
336
337
|
|
|
337
|
-
# Mark the node online
|
|
338
|
-
self.node_ids[node_id] = (
|
|
338
|
+
# Mark the node online until time.time() + heartbeat_interval
|
|
339
|
+
self.node_ids[node_id] = (
|
|
340
|
+
time.time() + heartbeat_interval,
|
|
341
|
+
heartbeat_interval,
|
|
342
|
+
)
|
|
339
343
|
return node_id
|
|
340
344
|
|
|
341
345
|
def delete_node(self, node_id: int) -> None:
|
|
@@ -400,6 +404,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
400
404
|
fab_hash: Optional[str],
|
|
401
405
|
override_config: UserConfig,
|
|
402
406
|
federation_options: ConfigRecord,
|
|
407
|
+
flwr_aid: Optional[str],
|
|
403
408
|
) -> int:
|
|
404
409
|
"""Create a new run for the specified `fab_hash`."""
|
|
405
410
|
# Sample a random int64 as run_id
|
|
@@ -423,9 +428,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
423
428
|
sub_status="",
|
|
424
429
|
details="",
|
|
425
430
|
),
|
|
431
|
+
flwr_aid=flwr_aid if flwr_aid else "",
|
|
426
432
|
),
|
|
427
433
|
)
|
|
428
434
|
self.run_ids[run_id] = run_record
|
|
435
|
+
# Add run_id to the flwr_aid_to_run_ids mapping if flwr_aid is provided
|
|
436
|
+
if flwr_aid:
|
|
437
|
+
self.flwr_aid_to_run_ids[flwr_aid].add(run_id)
|
|
429
438
|
|
|
430
439
|
# Record federation options. Leave empty if not passed
|
|
431
440
|
self.federation_options[run_id] = federation_options
|
|
@@ -453,13 +462,42 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
453
462
|
with self.lock:
|
|
454
463
|
return self.node_public_keys.copy()
|
|
455
464
|
|
|
456
|
-
def get_run_ids(self) -> set[int]:
|
|
457
|
-
"""Retrieve all run IDs.
|
|
465
|
+
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
|
466
|
+
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
|
467
|
+
|
|
468
|
+
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
|
469
|
+
"""
|
|
458
470
|
with self.lock:
|
|
471
|
+
if flwr_aid is not None:
|
|
472
|
+
# Return run IDs for the specified flwr_aid
|
|
473
|
+
return set(self.flwr_aid_to_run_ids.get(flwr_aid, ()))
|
|
459
474
|
return set(self.run_ids.keys())
|
|
460
475
|
|
|
476
|
+
def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
|
|
477
|
+
"""Check if any runs are no longer active.
|
|
478
|
+
|
|
479
|
+
Marks runs with status 'starting' or 'running' as failed
|
|
480
|
+
if they have not sent a heartbeat before `active_until`.
|
|
481
|
+
"""
|
|
482
|
+
current = now()
|
|
483
|
+
for record in (self.run_ids.get(run_id) for run_id in run_ids):
|
|
484
|
+
if record is None:
|
|
485
|
+
continue
|
|
486
|
+
with record.lock:
|
|
487
|
+
if record.run.status.status in (Status.STARTING, Status.RUNNING):
|
|
488
|
+
if record.active_until < current.timestamp():
|
|
489
|
+
record.run.status = RunStatus(
|
|
490
|
+
status=Status.FINISHED,
|
|
491
|
+
sub_status=SubStatus.FAILED,
|
|
492
|
+
details=RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
493
|
+
)
|
|
494
|
+
record.run.finished_at = now().isoformat()
|
|
495
|
+
|
|
461
496
|
def get_run(self, run_id: int) -> Optional[Run]:
|
|
462
497
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
498
|
+
# Check if runs are still active
|
|
499
|
+
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
500
|
+
|
|
463
501
|
with self.lock:
|
|
464
502
|
if run_id not in self.run_ids:
|
|
465
503
|
log(ERROR, "`run_id` is invalid")
|
|
@@ -468,6 +506,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
468
506
|
|
|
469
507
|
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
|
|
470
508
|
"""Retrieve the statuses for the specified runs."""
|
|
509
|
+
# Check if runs are still active
|
|
510
|
+
self._check_and_tag_inactive_run(run_ids=run_ids)
|
|
511
|
+
|
|
471
512
|
with self.lock:
|
|
472
513
|
return {
|
|
473
514
|
run_id: self.run_ids[run_id].run.status
|
|
@@ -477,12 +518,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
477
518
|
|
|
478
519
|
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
|
|
479
520
|
"""Update the status of the run with the specified `run_id`."""
|
|
521
|
+
# Check if runs are still active
|
|
522
|
+
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
523
|
+
|
|
480
524
|
with self.lock:
|
|
481
525
|
# Check if the run_id exists
|
|
482
526
|
if run_id not in self.run_ids:
|
|
483
527
|
log(ERROR, "`run_id` is invalid")
|
|
484
528
|
return False
|
|
485
529
|
|
|
530
|
+
with self.run_ids[run_id].lock:
|
|
486
531
|
# Check if the status transition is valid
|
|
487
532
|
current_status = self.run_ids[run_id].run.status
|
|
488
533
|
if not is_valid_transition(current_status, new_status):
|
|
@@ -504,14 +549,23 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
504
549
|
)
|
|
505
550
|
return False
|
|
506
551
|
|
|
507
|
-
#
|
|
552
|
+
# Initialize heartbeat_interval and active_until
|
|
553
|
+
# when switching to starting or running
|
|
554
|
+
current = now()
|
|
508
555
|
run_record = self.run_ids[run_id]
|
|
556
|
+
if new_status.status in (Status.STARTING, Status.RUNNING):
|
|
557
|
+
run_record.heartbeat_interval = HEARTBEAT_MAX_INTERVAL
|
|
558
|
+
run_record.active_until = (
|
|
559
|
+
current.timestamp() + run_record.heartbeat_interval
|
|
560
|
+
)
|
|
561
|
+
|
|
562
|
+
# Update the run status
|
|
509
563
|
if new_status.status == Status.STARTING:
|
|
510
|
-
run_record.run.starting_at =
|
|
564
|
+
run_record.run.starting_at = current.isoformat()
|
|
511
565
|
elif new_status.status == Status.RUNNING:
|
|
512
|
-
run_record.run.running_at =
|
|
566
|
+
run_record.run.running_at = current.isoformat()
|
|
513
567
|
elif new_status.status == Status.FINISHED:
|
|
514
|
-
run_record.run.finished_at =
|
|
568
|
+
run_record.run.finished_at = current.isoformat()
|
|
515
569
|
run_record.run.status = new_status
|
|
516
570
|
return True
|
|
517
571
|
|
|
@@ -536,21 +590,62 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
536
590
|
return None
|
|
537
591
|
return self.federation_options[run_id]
|
|
538
592
|
|
|
539
|
-
def
|
|
540
|
-
|
|
593
|
+
def acknowledge_node_heartbeat(
|
|
594
|
+
self, node_id: int, heartbeat_interval: float
|
|
595
|
+
) -> bool:
|
|
596
|
+
"""Acknowledge a heartbeat received from a node, serving as a heartbeat.
|
|
541
597
|
|
|
542
|
-
|
|
543
|
-
|
|
598
|
+
A node is considered online as long as it sends heartbeats within
|
|
599
|
+
the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
|
|
600
|
+
HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
|
|
601
|
+
the node is marked as offline.
|
|
544
602
|
"""
|
|
545
603
|
with self.lock:
|
|
546
604
|
if node_id in self.node_ids:
|
|
547
605
|
self.node_ids[node_id] = (
|
|
548
|
-
time.time() +
|
|
549
|
-
|
|
606
|
+
time.time() + HEARTBEAT_PATIENCE * heartbeat_interval,
|
|
607
|
+
heartbeat_interval,
|
|
550
608
|
)
|
|
551
609
|
return True
|
|
552
610
|
return False
|
|
553
611
|
|
|
612
|
+
def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
|
|
613
|
+
"""Acknowledge a heartbeat received from a ServerApp for a given run.
|
|
614
|
+
|
|
615
|
+
A run with status `"running"` is considered alive as long as it sends heartbeats
|
|
616
|
+
within the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
|
|
617
|
+
HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before the run is
|
|
618
|
+
marked as `"completed:failed"`.
|
|
619
|
+
"""
|
|
620
|
+
with self.lock:
|
|
621
|
+
# Search for the run
|
|
622
|
+
record = self.run_ids.get(run_id)
|
|
623
|
+
|
|
624
|
+
# Check if the run_id exists
|
|
625
|
+
if record is None:
|
|
626
|
+
log(ERROR, "`run_id` is invalid")
|
|
627
|
+
return False
|
|
628
|
+
|
|
629
|
+
with record.lock:
|
|
630
|
+
# Check if runs are still active
|
|
631
|
+
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
632
|
+
|
|
633
|
+
# Check if the run is of status "running"/"starting"
|
|
634
|
+
current_status = record.run.status
|
|
635
|
+
if current_status.status not in (Status.RUNNING, Status.STARTING):
|
|
636
|
+
log(
|
|
637
|
+
ERROR,
|
|
638
|
+
'Cannot acknowledge heartbeat for run with status "%s"',
|
|
639
|
+
current_status.status,
|
|
640
|
+
)
|
|
641
|
+
return False
|
|
642
|
+
|
|
643
|
+
# Update the `active_until` and `heartbeat_interval` for the given run
|
|
644
|
+
current = now().timestamp()
|
|
645
|
+
record.active_until = current + HEARTBEAT_PATIENCE * heartbeat_interval
|
|
646
|
+
record.heartbeat_interval = heartbeat_interval
|
|
647
|
+
return True
|
|
648
|
+
|
|
554
649
|
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
|
555
650
|
"""Get the context for the specified `run_id`."""
|
|
556
651
|
return self.contexts.get(run_id)
|
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
19
|
from typing import Optional
|
|
20
|
-
from uuid import UUID
|
|
21
20
|
|
|
22
21
|
from flwr.common import Context, Message
|
|
23
22
|
from flwr.common.record import ConfigRecord
|
|
@@ -28,13 +27,13 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
28
27
|
"""Abstract LinkState."""
|
|
29
28
|
|
|
30
29
|
@abc.abstractmethod
|
|
31
|
-
def store_message_ins(self, message: Message) -> Optional[
|
|
30
|
+
def store_message_ins(self, message: Message) -> Optional[str]:
|
|
32
31
|
"""Store one Message.
|
|
33
32
|
|
|
34
33
|
Usually, the ServerAppIo API calls this to schedule instructions.
|
|
35
34
|
|
|
36
35
|
Stores the value of the `message` in the link state and, if successful,
|
|
37
|
-
returns the `message_id` (
|
|
36
|
+
returns the `message_id` (str) of the `message`. If, for any reason,
|
|
38
37
|
storing the `message` fails, `None` is returned.
|
|
39
38
|
|
|
40
39
|
Constraints
|
|
@@ -61,12 +60,12 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
61
60
|
"""
|
|
62
61
|
|
|
63
62
|
@abc.abstractmethod
|
|
64
|
-
def store_message_res(self, message: Message) -> Optional[
|
|
63
|
+
def store_message_res(self, message: Message) -> Optional[str]:
|
|
65
64
|
"""Store one Message.
|
|
66
65
|
|
|
67
66
|
Usually, the Fleet API calls this for Nodes returning results.
|
|
68
67
|
|
|
69
|
-
Stores the Message and, if successful, returns the `message_id` (
|
|
68
|
+
Stores the Message and, if successful, returns the `message_id` (str) of
|
|
70
69
|
the `message`. If storing the `message` fails, `None` is returned.
|
|
71
70
|
|
|
72
71
|
Constraints
|
|
@@ -78,7 +77,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
78
77
|
"""
|
|
79
78
|
|
|
80
79
|
@abc.abstractmethod
|
|
81
|
-
def get_message_res(self, message_ids: set[
|
|
80
|
+
def get_message_res(self, message_ids: set[str]) -> list[Message]:
|
|
82
81
|
"""Get reply Messages for the given Message IDs.
|
|
83
82
|
|
|
84
83
|
This method is typically called by the ServerAppIo API to obtain
|
|
@@ -94,7 +93,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
94
93
|
|
|
95
94
|
Parameters
|
|
96
95
|
----------
|
|
97
|
-
message_ids : set[
|
|
96
|
+
message_ids : set[str]
|
|
98
97
|
A set of Message IDs used to retrieve reply Messages responding to them.
|
|
99
98
|
|
|
100
99
|
Returns
|
|
@@ -113,22 +112,22 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
113
112
|
"""Calculate the number of reply Messages in store."""
|
|
114
113
|
|
|
115
114
|
@abc.abstractmethod
|
|
116
|
-
def delete_messages(self, message_ins_ids: set[
|
|
115
|
+
def delete_messages(self, message_ins_ids: set[str]) -> None:
|
|
117
116
|
"""Delete a Message and its reply based on provided Message IDs.
|
|
118
117
|
|
|
119
118
|
Parameters
|
|
120
119
|
----------
|
|
121
|
-
message_ins_ids : set[
|
|
120
|
+
message_ins_ids : set[str]
|
|
122
121
|
A set of Message IDs. For each ID in the set, the corresponding
|
|
123
122
|
Message and its associated reply Message will be deleted.
|
|
124
123
|
"""
|
|
125
124
|
|
|
126
125
|
@abc.abstractmethod
|
|
127
|
-
def get_message_ids_from_run_id(self, run_id: int) -> set[
|
|
126
|
+
def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
|
|
128
127
|
"""Get all instruction Message IDs for the given run_id."""
|
|
129
128
|
|
|
130
129
|
@abc.abstractmethod
|
|
131
|
-
def create_node(self,
|
|
130
|
+
def create_node(self, heartbeat_interval: float) -> int:
|
|
132
131
|
"""Create, store in the link state, and return `node_id`."""
|
|
133
132
|
|
|
134
133
|
@abc.abstractmethod
|
|
@@ -165,12 +164,16 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
165
164
|
fab_hash: Optional[str],
|
|
166
165
|
override_config: UserConfig,
|
|
167
166
|
federation_options: ConfigRecord,
|
|
167
|
+
flwr_aid: Optional[str],
|
|
168
168
|
) -> int:
|
|
169
169
|
"""Create a new run for the specified `fab_hash`."""
|
|
170
170
|
|
|
171
171
|
@abc.abstractmethod
|
|
172
|
-
def get_run_ids(self) -> set[int]:
|
|
173
|
-
"""Retrieve all run IDs.
|
|
172
|
+
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
|
173
|
+
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
|
174
|
+
|
|
175
|
+
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
|
176
|
+
"""
|
|
174
177
|
|
|
175
178
|
@abc.abstractmethod
|
|
176
179
|
def get_run(self, run_id: int) -> Optional[Run]:
|
|
@@ -267,22 +270,52 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
267
270
|
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
268
271
|
|
|
269
272
|
@abc.abstractmethod
|
|
270
|
-
def
|
|
271
|
-
|
|
273
|
+
def acknowledge_node_heartbeat(
|
|
274
|
+
self, node_id: int, heartbeat_interval: float
|
|
275
|
+
) -> bool:
|
|
276
|
+
"""Acknowledge a heartbeat received from a node.
|
|
277
|
+
|
|
278
|
+
A node is considered online as long as it sends heartbeats within
|
|
279
|
+
the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
|
|
280
|
+
HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
|
|
281
|
+
the node is marked as offline.
|
|
272
282
|
|
|
273
283
|
Parameters
|
|
274
284
|
----------
|
|
275
285
|
node_id : int
|
|
276
|
-
The `node_id` from which the
|
|
277
|
-
|
|
286
|
+
The `node_id` from which the heartbeat was received.
|
|
287
|
+
heartbeat_interval : float
|
|
288
|
+
The interval (in seconds) from the current timestamp within which the next
|
|
289
|
+
heartbeat from this node must be received. This acts as a hard deadline to
|
|
290
|
+
ensure an accurate assessment of the node's availability.
|
|
291
|
+
|
|
292
|
+
Returns
|
|
293
|
+
-------
|
|
294
|
+
is_acknowledged : bool
|
|
295
|
+
True if the heartbeat is successfully acknowledged; otherwise, False.
|
|
296
|
+
"""
|
|
297
|
+
|
|
298
|
+
@abc.abstractmethod
|
|
299
|
+
def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
|
|
300
|
+
"""Acknowledge a heartbeat received from a ServerApp for a given run.
|
|
301
|
+
|
|
302
|
+
A run with status `"running"` is considered alive as long as it sends heartbeats
|
|
303
|
+
within the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
|
|
304
|
+
HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before the run is
|
|
305
|
+
marked as `"completed:failed"`.
|
|
306
|
+
|
|
307
|
+
Parameters
|
|
308
|
+
----------
|
|
309
|
+
run_id : int
|
|
310
|
+
The `run_id` from which the heartbeat was received.
|
|
311
|
+
heartbeat_interval : float
|
|
278
312
|
The interval (in seconds) from the current timestamp within which the next
|
|
279
|
-
|
|
280
|
-
an accurate assessment of the node's availability.
|
|
313
|
+
heartbeat from the ServerApp for this run must be received.
|
|
281
314
|
|
|
282
315
|
Returns
|
|
283
316
|
-------
|
|
284
317
|
is_acknowledged : bool
|
|
285
|
-
True if the
|
|
318
|
+
True if the heartbeat is successfully acknowledged; otherwise, False.
|
|
286
319
|
"""
|
|
287
320
|
|
|
288
321
|
@abc.abstractmethod
|