flwr 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +34 -1
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +94 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
- flwr/cli/build.py +166 -53
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +54 -11
- flwr/cli/login/login.py +41 -27
- flwr/cli/ls.py +177 -133
- flwr/cli/new/new.py +175 -40
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +12 -7
- flwr/cli/run/run.py +82 -31
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +27 -9
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +268 -0
- flwr/cli/supernode/register.py +190 -0
- flwr/cli/supernode/unregister.py +140 -0
- flwr/cli/utils.py +464 -81
- flwr/client/__init__.py +2 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +12 -15
- flwr/client/grpc_rere_client/connection.py +68 -41
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +94 -51
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +1 -2
- flwr/{client → clientapp}/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/{client/clientapp → clientapp}/utils.py +4 -4
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +56 -13
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +39 -10
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +48 -31
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +6 -6
- flwr/common/record/arrayrecord.py +18 -21
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +9 -6
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +59 -43
- flwr/compat/client/app.py +39 -38
- flwr/compat/client/grpc_client/connection.py +13 -13
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +72 -40
- flwr/proto/control_pb2.pyi +319 -87
- flwr/proto/control_pb2_grpc.py +339 -28
- flwr/proto/control_pb2_grpc.pyi +209 -37
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +24 -10
- flwr/proto/fab_pb2.pyi +68 -20
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +45 -27
- flwr/proto/fleet_pb2.pyi +186 -70
- flwr/proto/fleet_pb2_grpc.py +277 -66
- flwr/proto/fleet_pb2_grpc.pyi +201 -55
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +16 -4
- flwr/proto/node_pb2.pyi +77 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +173 -127
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
- flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
- flwr/server/superlink/fleet/vce/vce_api.py +32 -13
- flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
- flwr/server/superlink/linkstate/linkstate.py +161 -62
- flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
- flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
- flwr/server/superlink/linkstate/utils.py +9 -60
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +11 -12
- flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
- flwr/simulation/run_simulation.py +44 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +52 -0
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -6
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +27 -8
- flwr/supercore/object_store/sqlite_object_store.py +253 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
- flwr/supercore/sqlite_mixin.py +159 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +88 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +18 -17
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +239 -63
- flwr/supernode/cli/flower_supernode.py +74 -26
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +43 -24
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +175 -51
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.22.0.dist-info/RECORD +0 -428
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -15,19 +15,16 @@
|
|
|
15
15
|
"""In-memory LinkState implementation."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import secrets
|
|
19
18
|
import threading
|
|
20
|
-
import time
|
|
21
19
|
from bisect import bisect_right
|
|
22
20
|
from collections import defaultdict
|
|
21
|
+
from collections.abc import Sequence
|
|
23
22
|
from dataclasses import dataclass, field
|
|
23
|
+
from datetime import datetime, timezone
|
|
24
24
|
from logging import ERROR, WARNING
|
|
25
|
-
from typing import Optional
|
|
26
25
|
|
|
27
26
|
from flwr.common import Context, Message, log, now
|
|
28
27
|
from flwr.common.constant import (
|
|
29
|
-
FLWR_APP_TOKEN_LENGTH,
|
|
30
|
-
HEARTBEAT_MAX_INTERVAL,
|
|
31
28
|
HEARTBEAT_PATIENCE,
|
|
32
29
|
MESSAGE_TTL_TOLERANCE,
|
|
33
30
|
NODE_ID_NUM_BYTES,
|
|
@@ -39,8 +36,13 @@ from flwr.common.constant import (
|
|
|
39
36
|
)
|
|
40
37
|
from flwr.common.record import ConfigRecord
|
|
41
38
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
39
|
+
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
|
42
40
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
43
41
|
from flwr.server.utils import validate_message
|
|
42
|
+
from flwr.supercore.constant import NodeStatus
|
|
43
|
+
from flwr.supercore.corestate.in_memory_corestate import InMemoryCoreState
|
|
44
|
+
from flwr.supercore.object_store.object_store import ObjectStore
|
|
45
|
+
from flwr.superlink.federation import FederationManager
|
|
44
46
|
|
|
45
47
|
from .utils import (
|
|
46
48
|
check_node_availability_for_in_message,
|
|
@@ -57,22 +59,23 @@ class RunRecord: # pylint: disable=R0902
|
|
|
57
59
|
"""The record of a specific run, including its status and timestamps."""
|
|
58
60
|
|
|
59
61
|
run: Run
|
|
60
|
-
active_until: float = 0.0
|
|
61
|
-
heartbeat_interval: float = 0.0
|
|
62
62
|
logs: list[tuple[float, str]] = field(default_factory=list)
|
|
63
63
|
log_lock: threading.Lock = field(default_factory=threading.Lock)
|
|
64
64
|
lock: threading.RLock = field(default_factory=threading.RLock)
|
|
65
65
|
|
|
66
66
|
|
|
67
|
-
class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
67
|
+
class InMemoryLinkState(LinkState, InMemoryCoreState): # pylint: disable=R0902,R0904
|
|
68
68
|
"""In-memory LinkState implementation."""
|
|
69
69
|
|
|
70
|
-
def __init__(
|
|
70
|
+
def __init__(
|
|
71
|
+
self, federation_manager: FederationManager, object_store: ObjectStore
|
|
72
|
+
) -> None:
|
|
73
|
+
super().__init__(object_store)
|
|
71
74
|
|
|
72
|
-
# Map node_id to
|
|
73
|
-
self.
|
|
74
|
-
self.
|
|
75
|
-
self.
|
|
75
|
+
# Map node_id to NodeInfo
|
|
76
|
+
self.nodes: dict[int, NodeInfo] = {}
|
|
77
|
+
self.node_public_key_to_node_id: dict[bytes, int] = {}
|
|
78
|
+
self.owner_to_node_ids: dict[str, set[int]] = {} # Quick lookup
|
|
76
79
|
|
|
77
80
|
# Map run_id to RunRecord
|
|
78
81
|
self.run_ids: dict[int, RunRecord] = {}
|
|
@@ -82,19 +85,21 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
82
85
|
self.message_res_store: dict[str, Message] = {}
|
|
83
86
|
self.message_ins_id_to_message_res_id: dict[str, str] = {}
|
|
84
87
|
|
|
85
|
-
# Store run ID to token mapping and token to run ID mapping
|
|
86
|
-
self.token_store: dict[int, str] = {}
|
|
87
|
-
self.token_to_run_id: dict[str, int] = {}
|
|
88
|
-
self.lock_token_store = threading.Lock()
|
|
89
|
-
|
|
90
88
|
# Map flwr_aid to run_ids for O(1) reverse index lookup
|
|
91
89
|
self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
|
|
92
90
|
|
|
93
91
|
self.node_public_keys: set[bytes] = set()
|
|
94
92
|
|
|
95
93
|
self.lock = threading.RLock()
|
|
94
|
+
federation_manager.linkstate = self
|
|
95
|
+
self._federation_manager = federation_manager
|
|
96
|
+
|
|
97
|
+
@property
|
|
98
|
+
def federation_manager(self) -> FederationManager:
|
|
99
|
+
"""Get the FederationManager instance."""
|
|
100
|
+
return self._federation_manager
|
|
96
101
|
|
|
97
|
-
def store_message_ins(self, message: Message) ->
|
|
102
|
+
def store_message_ins(self, message: Message) -> str | None:
|
|
98
103
|
"""Store one Message."""
|
|
99
104
|
# Validate message
|
|
100
105
|
errors = validate_message(message, is_reply_message=False)
|
|
@@ -105,6 +110,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
105
110
|
if message.metadata.run_id not in self.run_ids:
|
|
106
111
|
log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
|
|
107
112
|
return None
|
|
113
|
+
federation = self.run_ids[message.metadata.run_id].run.federation
|
|
108
114
|
# Validate source node ID
|
|
109
115
|
if message.metadata.src_node_id != SUPERLINK_NODE_ID:
|
|
110
116
|
log(
|
|
@@ -114,7 +120,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
114
120
|
)
|
|
115
121
|
return None
|
|
116
122
|
# Validate destination node ID
|
|
117
|
-
|
|
123
|
+
dst_node = self.nodes.get(message.metadata.dst_node_id)
|
|
124
|
+
if (
|
|
125
|
+
# Node must exist
|
|
126
|
+
dst_node is None
|
|
127
|
+
# Node must be online or offline
|
|
128
|
+
or dst_node.status not in (NodeStatus.ONLINE, NodeStatus.OFFLINE)
|
|
129
|
+
# Node must belong to the same federation
|
|
130
|
+
or not self.federation_manager.has_node(dst_node.node_id, federation)
|
|
131
|
+
):
|
|
118
132
|
log(
|
|
119
133
|
ERROR,
|
|
120
134
|
"Invalid destination node ID for Message: %s",
|
|
@@ -129,21 +143,50 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
129
143
|
# Return the new message_id
|
|
130
144
|
return message_id
|
|
131
145
|
|
|
132
|
-
def
|
|
146
|
+
def _check_stored_messages(self, message_ids: set[str]) -> None:
|
|
147
|
+
"""Check and delete the message if it's invalid."""
|
|
148
|
+
with self.lock:
|
|
149
|
+
invalid_msg_ids: set[str] = set()
|
|
150
|
+
current = now().timestamp()
|
|
151
|
+
for msg_id in message_ids:
|
|
152
|
+
if not (message := self.message_ins_store.get(msg_id)):
|
|
153
|
+
continue
|
|
154
|
+
|
|
155
|
+
# Check if the message has expired
|
|
156
|
+
available_until = message.metadata.created_at + message.metadata.ttl
|
|
157
|
+
if available_until <= current:
|
|
158
|
+
invalid_msg_ids.add(msg_id)
|
|
159
|
+
continue
|
|
160
|
+
|
|
161
|
+
# Check if the destination node and the source node are still in the
|
|
162
|
+
# same federation
|
|
163
|
+
src_node_id = message.metadata.src_node_id
|
|
164
|
+
dst_node_id = message.metadata.dst_node_id
|
|
165
|
+
filtered = self.federation_manager.filter_nodes(
|
|
166
|
+
{src_node_id, dst_node_id},
|
|
167
|
+
self.run_ids[message.metadata.run_id].run.federation,
|
|
168
|
+
)
|
|
169
|
+
if len(filtered) != 2: # Not both nodes are in the federation
|
|
170
|
+
invalid_msg_ids.add(msg_id)
|
|
171
|
+
|
|
172
|
+
# Delete all invalid messages
|
|
173
|
+
self.delete_messages(invalid_msg_ids)
|
|
174
|
+
|
|
175
|
+
def get_message_ins(self, node_id: int, limit: int | None) -> list[Message]:
|
|
133
176
|
"""Get all Messages that have not been delivered yet."""
|
|
134
177
|
if limit is not None and limit < 1:
|
|
135
178
|
raise AssertionError("`limit` must be >= 1")
|
|
136
179
|
|
|
137
180
|
# Find Message for node_id that were not delivered yet
|
|
138
181
|
message_ins_list: list[Message] = []
|
|
139
|
-
current_time = time.time()
|
|
140
182
|
with self.lock:
|
|
141
|
-
for
|
|
183
|
+
for msg_id in list(self.message_ins_store.keys()):
|
|
184
|
+
self._check_stored_messages({msg_id})
|
|
185
|
+
|
|
142
186
|
if (
|
|
143
|
-
msg_ins.
|
|
187
|
+
(msg_ins := self.message_ins_store.get(msg_id))
|
|
188
|
+
and msg_ins.metadata.dst_node_id == node_id
|
|
144
189
|
and msg_ins.metadata.delivered_at == ""
|
|
145
|
-
and msg_ins.metadata.created_at + msg_ins.metadata.ttl
|
|
146
|
-
> current_time
|
|
147
190
|
):
|
|
148
191
|
message_ins_list.append(msg_ins)
|
|
149
192
|
if limit and len(message_ins_list) == limit:
|
|
@@ -158,7 +201,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
158
201
|
return message_ins_list
|
|
159
202
|
|
|
160
203
|
# pylint: disable=R0911
|
|
161
|
-
def store_message_res(self, message: Message) ->
|
|
204
|
+
def store_message_res(self, message: Message) -> str | None:
|
|
162
205
|
"""Store one Message."""
|
|
163
206
|
# Validate message
|
|
164
207
|
errors = validate_message(message, is_reply_message=True)
|
|
@@ -170,6 +213,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
170
213
|
with self.lock:
|
|
171
214
|
# Check if the Message it is replying to exists and is valid
|
|
172
215
|
msg_ins_id = res_metadata.reply_to_message_id
|
|
216
|
+
self._check_stored_messages({msg_ins_id})
|
|
173
217
|
msg_ins = self.message_ins_store.get(msg_ins_id)
|
|
174
218
|
|
|
175
219
|
# Ensure that dst_node_id of original Message matches the src_node_id of
|
|
@@ -189,22 +233,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
189
233
|
)
|
|
190
234
|
return None
|
|
191
235
|
|
|
192
|
-
ins_metadata = msg_ins.metadata
|
|
193
|
-
if ins_metadata.created_at + ins_metadata.ttl <= time.time():
|
|
194
|
-
log(
|
|
195
|
-
ERROR,
|
|
196
|
-
"Failed to store Message: the message it is replying to "
|
|
197
|
-
"(with ID %s) has expired",
|
|
198
|
-
msg_ins_id,
|
|
199
|
-
)
|
|
200
|
-
return None
|
|
201
|
-
|
|
202
236
|
# Fail if the Message TTL exceeds the
|
|
203
237
|
# expiration time of the Message it replies to.
|
|
204
238
|
# Condition: ins_metadata.created_at + ins_metadata.ttl ≥
|
|
205
239
|
# res_metadata.created_at + res_metadata.ttl
|
|
206
240
|
# A small tolerance is introduced to account
|
|
207
241
|
# for floating-point precision issues.
|
|
242
|
+
ins_metadata = msg_ins.metadata
|
|
208
243
|
max_allowed_ttl = (
|
|
209
244
|
ins_metadata.created_at + ins_metadata.ttl - res_metadata.created_at
|
|
210
245
|
)
|
|
@@ -238,7 +273,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
238
273
|
ret: dict[str, Message] = {}
|
|
239
274
|
|
|
240
275
|
with self.lock:
|
|
241
|
-
|
|
276
|
+
self._check_stored_messages(message_ids)
|
|
277
|
+
current = now().timestamp()
|
|
242
278
|
|
|
243
279
|
# Verify Message IDs
|
|
244
280
|
ret = verify_message_ids(
|
|
@@ -256,9 +292,10 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
256
292
|
inquired_in_message_ids=message_ids,
|
|
257
293
|
found_in_message_dict=self.message_ins_store,
|
|
258
294
|
node_id_to_online_until={
|
|
259
|
-
node_id: self.
|
|
295
|
+
node_id: self.nodes[node_id].online_until
|
|
260
296
|
for node_id in dst_node_ids
|
|
261
|
-
if node_id in self.
|
|
297
|
+
if node_id in self.nodes
|
|
298
|
+
and self.nodes[node_id].status != NodeStatus.UNREGISTERED
|
|
262
299
|
},
|
|
263
300
|
current_time=current,
|
|
264
301
|
)
|
|
@@ -330,7 +367,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
330
367
|
"""
|
|
331
368
|
return len(self.message_res_store)
|
|
332
369
|
|
|
333
|
-
def create_node(
|
|
370
|
+
def create_node(
|
|
371
|
+
self,
|
|
372
|
+
owner_aid: str,
|
|
373
|
+
owner_name: str,
|
|
374
|
+
public_key: bytes,
|
|
375
|
+
heartbeat_interval: float,
|
|
376
|
+
) -> int:
|
|
334
377
|
"""Create, store in the link state, and return `node_id`."""
|
|
335
378
|
# Sample a random int64 as node_id
|
|
336
379
|
node_id = generate_rand_int_from_bytes(
|
|
@@ -338,28 +381,89 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
338
381
|
)
|
|
339
382
|
|
|
340
383
|
with self.lock:
|
|
341
|
-
if node_id in self.
|
|
384
|
+
if node_id in self.nodes:
|
|
342
385
|
log(ERROR, "Unexpected node registration failure.")
|
|
343
386
|
return 0
|
|
387
|
+
if public_key in self.node_public_key_to_node_id:
|
|
388
|
+
raise ValueError("Public key already in use")
|
|
344
389
|
|
|
345
|
-
#
|
|
346
|
-
self.
|
|
347
|
-
|
|
348
|
-
|
|
390
|
+
# The node is not activated upon creation
|
|
391
|
+
self.nodes[node_id] = NodeInfo(
|
|
392
|
+
node_id=node_id,
|
|
393
|
+
owner_aid=owner_aid,
|
|
394
|
+
owner_name=owner_name,
|
|
395
|
+
status=NodeStatus.REGISTERED,
|
|
396
|
+
registered_at=now().isoformat(),
|
|
397
|
+
last_activated_at=None,
|
|
398
|
+
last_deactivated_at=None,
|
|
399
|
+
unregistered_at=None,
|
|
400
|
+
online_until=None,
|
|
401
|
+
heartbeat_interval=heartbeat_interval,
|
|
402
|
+
public_key=public_key,
|
|
349
403
|
)
|
|
404
|
+
self.node_public_key_to_node_id[public_key] = node_id
|
|
405
|
+
self.owner_to_node_ids.setdefault(owner_aid, set()).add(node_id)
|
|
350
406
|
return node_id
|
|
351
407
|
|
|
352
|
-
def delete_node(self, node_id: int) -> None:
|
|
408
|
+
def delete_node(self, owner_aid: str, node_id: int) -> None:
|
|
353
409
|
"""Delete a node."""
|
|
354
410
|
with self.lock:
|
|
355
|
-
if
|
|
356
|
-
|
|
411
|
+
if (
|
|
412
|
+
not (node := self.nodes.get(node_id))
|
|
413
|
+
or node.status == NodeStatus.UNREGISTERED
|
|
414
|
+
or owner_aid != self.nodes[node_id].owner_aid
|
|
415
|
+
):
|
|
416
|
+
raise ValueError(
|
|
417
|
+
f"Node ID {node_id} already unregistered, not found or "
|
|
418
|
+
"the request was unauthorized."
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
node.status = NodeStatus.UNREGISTERED
|
|
422
|
+
current = now()
|
|
423
|
+
node.unregistered_at = current.isoformat()
|
|
424
|
+
# Set online_until to current timestamp on deletion, if it is in the future
|
|
425
|
+
node.online_until = min(node.online_until, current.timestamp())
|
|
426
|
+
|
|
427
|
+
def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
|
|
428
|
+
"""Activate the node with the specified `node_id`."""
|
|
429
|
+
with self.lock:
|
|
430
|
+
self._check_and_tag_offline_nodes(node_ids=[node_id])
|
|
431
|
+
|
|
432
|
+
# Check if the node exists
|
|
433
|
+
if not (node := self.nodes.get(node_id)):
|
|
434
|
+
return False
|
|
435
|
+
|
|
436
|
+
# Only activate if the node is currently registered or offline
|
|
437
|
+
current_dt = now()
|
|
438
|
+
if node.status in (NodeStatus.REGISTERED, NodeStatus.OFFLINE):
|
|
439
|
+
node.status = NodeStatus.ONLINE
|
|
440
|
+
node.last_activated_at = current_dt.isoformat()
|
|
441
|
+
node.online_until = (
|
|
442
|
+
current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
|
|
443
|
+
)
|
|
444
|
+
node.heartbeat_interval = heartbeat_interval
|
|
445
|
+
return True
|
|
446
|
+
return False
|
|
447
|
+
|
|
448
|
+
def deactivate_node(self, node_id: int) -> bool:
|
|
449
|
+
"""Deactivate the node with the specified `node_id`."""
|
|
450
|
+
with self.lock:
|
|
451
|
+
self._check_and_tag_offline_nodes(node_ids=[node_id])
|
|
452
|
+
|
|
453
|
+
# Check if the node exists
|
|
454
|
+
if not (node := self.nodes.get(node_id)):
|
|
455
|
+
return False
|
|
357
456
|
|
|
358
|
-
#
|
|
359
|
-
|
|
360
|
-
|
|
457
|
+
# Only deactivate if the node is currently online
|
|
458
|
+
current_dt = now()
|
|
459
|
+
if node.status == NodeStatus.ONLINE:
|
|
460
|
+
node.status = NodeStatus.OFFLINE
|
|
461
|
+
node.last_deactivated_at = current_dt.isoformat()
|
|
361
462
|
|
|
362
|
-
|
|
463
|
+
# Set online_until to current timestamp
|
|
464
|
+
node.online_until = current_dt.timestamp()
|
|
465
|
+
return True
|
|
466
|
+
return False
|
|
363
467
|
|
|
364
468
|
def get_nodes(self, run_id: int) -> set[int]:
|
|
365
469
|
"""Return all available nodes.
|
|
@@ -372,48 +476,83 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
372
476
|
with self.lock:
|
|
373
477
|
if run_id not in self.run_ids:
|
|
374
478
|
return set()
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
node_id
|
|
378
|
-
for
|
|
379
|
-
if online_until > current_time
|
|
479
|
+
federation = self.run_ids[run_id].run.federation
|
|
480
|
+
node_ids = {
|
|
481
|
+
node.node_id
|
|
482
|
+
for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
|
|
380
483
|
}
|
|
484
|
+
return self.federation_manager.filter_nodes(node_ids, federation)
|
|
381
485
|
|
|
382
|
-
def
|
|
383
|
-
|
|
486
|
+
def get_node_info(
|
|
487
|
+
self,
|
|
488
|
+
*,
|
|
489
|
+
node_ids: Sequence[int] | None = None,
|
|
490
|
+
owner_aids: Sequence[str] | None = None,
|
|
491
|
+
statuses: Sequence[str] | None = None,
|
|
492
|
+
) -> Sequence[NodeInfo]:
|
|
493
|
+
"""Retrieve information about nodes based on the specified filters."""
|
|
384
494
|
with self.lock:
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
495
|
+
self._check_and_tag_offline_nodes()
|
|
496
|
+
result = []
|
|
497
|
+
for node_id in self.nodes.keys() if node_ids is None else node_ids:
|
|
498
|
+
if (node := self.nodes.get(node_id)) is None:
|
|
499
|
+
continue
|
|
500
|
+
if owner_aids is not None and node.owner_aid not in owner_aids:
|
|
501
|
+
continue
|
|
502
|
+
if statuses is not None and node.status not in statuses:
|
|
503
|
+
continue
|
|
504
|
+
result.append(node)
|
|
505
|
+
return result
|
|
506
|
+
|
|
507
|
+
def _check_and_tag_offline_nodes(self, node_ids: list[int] | None = None) -> None:
|
|
508
|
+
with self.lock:
|
|
509
|
+
# Set all nodes of "online" status to "offline" if they've offline
|
|
510
|
+
current_ts = now().timestamp()
|
|
511
|
+
for node_id in node_ids or self.nodes.keys():
|
|
512
|
+
if (node := self.nodes.get(node_id)) is None:
|
|
513
|
+
continue
|
|
514
|
+
if node.status == NodeStatus.ONLINE:
|
|
515
|
+
if node.online_until <= current_ts:
|
|
516
|
+
node.status = NodeStatus.OFFLINE
|
|
517
|
+
node.last_deactivated_at = datetime.fromtimestamp(
|
|
518
|
+
node.online_until, tz=timezone.utc
|
|
519
|
+
).isoformat()
|
|
520
|
+
|
|
521
|
+
def get_node_public_key(self, node_id: int) -> bytes:
|
|
395
522
|
"""Get `public_key` for the specified `node_id`."""
|
|
396
523
|
with self.lock:
|
|
397
|
-
if
|
|
398
|
-
|
|
524
|
+
if (
|
|
525
|
+
node := self.nodes.get(node_id)
|
|
526
|
+
) is None or node.status == NodeStatus.UNREGISTERED:
|
|
527
|
+
raise ValueError(f"Node ID {node_id} not found")
|
|
528
|
+
return node.public_key
|
|
529
|
+
|
|
530
|
+
def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
|
|
531
|
+
"""Get `node_id` for the specified `public_key` if it exists and is not
|
|
532
|
+
deleted."""
|
|
533
|
+
with self.lock:
|
|
534
|
+
node_id = self.node_public_key_to_node_id.get(public_key)
|
|
399
535
|
|
|
400
|
-
|
|
536
|
+
if node_id is None:
|
|
537
|
+
return None
|
|
401
538
|
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
539
|
+
node_info = self.nodes[node_id]
|
|
540
|
+
if node_info.status == NodeStatus.UNREGISTERED:
|
|
541
|
+
return None
|
|
542
|
+
return node_id
|
|
405
543
|
|
|
406
544
|
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
407
545
|
def create_run(
|
|
408
546
|
self,
|
|
409
|
-
fab_id:
|
|
410
|
-
fab_version:
|
|
411
|
-
fab_hash:
|
|
547
|
+
fab_id: str | None,
|
|
548
|
+
fab_version: str | None,
|
|
549
|
+
fab_hash: str | None,
|
|
412
550
|
override_config: UserConfig,
|
|
551
|
+
federation: str,
|
|
413
552
|
federation_options: ConfigRecord,
|
|
414
|
-
flwr_aid:
|
|
553
|
+
flwr_aid: str | None,
|
|
415
554
|
) -> int:
|
|
416
|
-
"""Create a new run
|
|
555
|
+
"""Create a new run."""
|
|
417
556
|
# Sample a random int64 as run_id
|
|
418
557
|
with self.lock:
|
|
419
558
|
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
@@ -436,6 +575,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
436
575
|
details="",
|
|
437
576
|
),
|
|
438
577
|
flwr_aid=flwr_aid if flwr_aid else "",
|
|
578
|
+
federation=federation,
|
|
439
579
|
),
|
|
440
580
|
)
|
|
441
581
|
self.run_ids[run_id] = run_record
|
|
@@ -449,27 +589,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
449
589
|
log(ERROR, "Unexpected run creation failure.")
|
|
450
590
|
return 0
|
|
451
591
|
|
|
452
|
-
def
|
|
453
|
-
"""Clear stored `node_public_keys` in the link state if any."""
|
|
454
|
-
with self.lock:
|
|
455
|
-
self.node_public_keys.clear()
|
|
456
|
-
|
|
457
|
-
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
458
|
-
"""Store a set of `node_public_keys` in the link state."""
|
|
459
|
-
with self.lock:
|
|
460
|
-
self.node_public_keys.update(public_keys)
|
|
461
|
-
|
|
462
|
-
def store_node_public_key(self, public_key: bytes) -> None:
|
|
463
|
-
"""Store a `node_public_key` in the link state."""
|
|
464
|
-
with self.lock:
|
|
465
|
-
self.node_public_keys.add(public_key)
|
|
466
|
-
|
|
467
|
-
def get_node_public_keys(self) -> set[bytes]:
|
|
468
|
-
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
469
|
-
with self.lock:
|
|
470
|
-
return self.node_public_keys.copy()
|
|
471
|
-
|
|
472
|
-
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
|
592
|
+
def get_run_ids(self, flwr_aid: str | None) -> set[int]:
|
|
473
593
|
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
|
474
594
|
|
|
475
595
|
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
|
@@ -480,30 +600,10 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
480
600
|
return set(self.flwr_aid_to_run_ids.get(flwr_aid, ()))
|
|
481
601
|
return set(self.run_ids.keys())
|
|
482
602
|
|
|
483
|
-
def
|
|
484
|
-
"""Check if any runs are no longer active.
|
|
485
|
-
|
|
486
|
-
Marks runs with status 'starting' or 'running' as failed
|
|
487
|
-
if they have not sent a heartbeat before `active_until`.
|
|
488
|
-
"""
|
|
489
|
-
current = now()
|
|
490
|
-
for record in (self.run_ids.get(run_id) for run_id in run_ids):
|
|
491
|
-
if record is None:
|
|
492
|
-
continue
|
|
493
|
-
with record.lock:
|
|
494
|
-
if record.run.status.status in (Status.STARTING, Status.RUNNING):
|
|
495
|
-
if record.active_until < current.timestamp():
|
|
496
|
-
record.run.status = RunStatus(
|
|
497
|
-
status=Status.FINISHED,
|
|
498
|
-
sub_status=SubStatus.FAILED,
|
|
499
|
-
details=RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
500
|
-
)
|
|
501
|
-
record.run.finished_at = now().isoformat()
|
|
502
|
-
|
|
503
|
-
def get_run(self, run_id: int) -> Optional[Run]:
|
|
603
|
+
def get_run(self, run_id: int) -> Run | None:
|
|
504
604
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
505
|
-
#
|
|
506
|
-
self.
|
|
605
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
606
|
+
self._cleanup_expired_tokens()
|
|
507
607
|
|
|
508
608
|
with self.lock:
|
|
509
609
|
if run_id not in self.run_ids:
|
|
@@ -513,8 +613,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
513
613
|
|
|
514
614
|
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
|
|
515
615
|
"""Retrieve the statuses for the specified runs."""
|
|
516
|
-
#
|
|
517
|
-
self.
|
|
616
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
617
|
+
self._cleanup_expired_tokens()
|
|
518
618
|
|
|
519
619
|
with self.lock:
|
|
520
620
|
return {
|
|
@@ -525,8 +625,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
525
625
|
|
|
526
626
|
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
|
|
527
627
|
"""Update the status of the run with the specified `run_id`."""
|
|
528
|
-
#
|
|
529
|
-
self.
|
|
628
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
629
|
+
self._cleanup_expired_tokens()
|
|
530
630
|
|
|
531
631
|
with self.lock:
|
|
532
632
|
# Check if the run_id exists
|
|
@@ -556,17 +656,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
556
656
|
)
|
|
557
657
|
return False
|
|
558
658
|
|
|
559
|
-
#
|
|
560
|
-
# when switching to starting or running
|
|
659
|
+
# Update the run status
|
|
561
660
|
current = now()
|
|
562
661
|
run_record = self.run_ids[run_id]
|
|
563
|
-
if new_status.status in (Status.STARTING, Status.RUNNING):
|
|
564
|
-
run_record.heartbeat_interval = HEARTBEAT_MAX_INTERVAL
|
|
565
|
-
run_record.active_until = (
|
|
566
|
-
current.timestamp() + run_record.heartbeat_interval
|
|
567
|
-
)
|
|
568
|
-
|
|
569
|
-
# Update the run status
|
|
570
662
|
if new_status.status == Status.STARTING:
|
|
571
663
|
run_record.run.starting_at = current.isoformat()
|
|
572
664
|
elif new_status.status == Status.RUNNING:
|
|
@@ -576,7 +668,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
576
668
|
run_record.run.status = new_status
|
|
577
669
|
return True
|
|
578
670
|
|
|
579
|
-
def get_pending_run_id(self) ->
|
|
671
|
+
def get_pending_run_id(self) -> int | None:
|
|
580
672
|
"""Get the `run_id` of a run with `Status.PENDING` status, if any."""
|
|
581
673
|
pending_run_id = None
|
|
582
674
|
|
|
@@ -589,7 +681,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
589
681
|
|
|
590
682
|
return pending_run_id
|
|
591
683
|
|
|
592
|
-
def get_federation_options(self, run_id: int) ->
|
|
684
|
+
def get_federation_options(self, run_id: int) -> ConfigRecord | None:
|
|
593
685
|
"""Retrieve the federation options for the specified `run_id`."""
|
|
594
686
|
with self.lock:
|
|
595
687
|
if run_id not in self.run_ids:
|
|
@@ -608,52 +700,46 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
608
700
|
the node is marked as offline.
|
|
609
701
|
"""
|
|
610
702
|
with self.lock:
|
|
611
|
-
if
|
|
612
|
-
self.
|
|
613
|
-
|
|
614
|
-
|
|
703
|
+
if (
|
|
704
|
+
node := self.nodes.get(node_id)
|
|
705
|
+
) and node.status != NodeStatus.UNREGISTERED:
|
|
706
|
+
current_dt = now()
|
|
707
|
+
|
|
708
|
+
# Set timestamp if the status changes
|
|
709
|
+
if node.status != NodeStatus.ONLINE: # offline or registered
|
|
710
|
+
node.status = NodeStatus.ONLINE
|
|
711
|
+
node.last_activated_at = current_dt.isoformat()
|
|
712
|
+
|
|
713
|
+
# Refresh `online_until` and `heartbeat_interval`
|
|
714
|
+
node.online_until = (
|
|
715
|
+
current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
|
|
615
716
|
)
|
|
717
|
+
node.heartbeat_interval = heartbeat_interval
|
|
616
718
|
return True
|
|
617
|
-
|
|
719
|
+
return False
|
|
618
720
|
|
|
619
|
-
def
|
|
620
|
-
"""
|
|
721
|
+
def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
|
|
722
|
+
"""Transition runs with expired tokens to failed status.
|
|
621
723
|
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
724
|
+
Parameters
|
|
725
|
+
----------
|
|
726
|
+
expired_records : list[tuple[int, float]]
|
|
727
|
+
List of tuples containing (run_id, active_until timestamp)
|
|
728
|
+
for expired tokens.
|
|
626
729
|
"""
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
with record.lock:
|
|
637
|
-
# Check if runs are still active
|
|
638
|
-
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
639
|
-
|
|
640
|
-
# Check if the run is of status "running"/"starting"
|
|
641
|
-
current_status = record.run.status
|
|
642
|
-
if current_status.status not in (Status.RUNNING, Status.STARTING):
|
|
643
|
-
log(
|
|
644
|
-
ERROR,
|
|
645
|
-
'Cannot acknowledge heartbeat for run with status "%s"',
|
|
646
|
-
current_status.status,
|
|
730
|
+
for run_id, active_until in expired_records:
|
|
731
|
+
if not (run_record := self.run_ids.get(run_id)):
|
|
732
|
+
continue
|
|
733
|
+
with run_record.lock:
|
|
734
|
+
run_record.run.status = RunStatus(
|
|
735
|
+
status=Status.FINISHED,
|
|
736
|
+
sub_status=SubStatus.FAILED,
|
|
737
|
+
details=RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
647
738
|
)
|
|
648
|
-
|
|
739
|
+
active_until_dt = datetime.fromtimestamp(active_until, tz=timezone.utc)
|
|
740
|
+
run_record.run.finished_at = active_until_dt.isoformat()
|
|
649
741
|
|
|
650
|
-
|
|
651
|
-
current = now().timestamp()
|
|
652
|
-
record.active_until = current + HEARTBEAT_PATIENCE * heartbeat_interval
|
|
653
|
-
record.heartbeat_interval = heartbeat_interval
|
|
654
|
-
return True
|
|
655
|
-
|
|
656
|
-
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
|
742
|
+
def get_serverapp_context(self, run_id: int) -> Context | None:
|
|
657
743
|
"""Get the context for the specified `run_id`."""
|
|
658
744
|
return self.contexts.get(run_id)
|
|
659
745
|
|
|
@@ -672,7 +758,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
672
758
|
run.logs.append((now().timestamp(), log_message))
|
|
673
759
|
|
|
674
760
|
def get_serverapp_log(
|
|
675
|
-
self, run_id: int, after_timestamp:
|
|
761
|
+
self, run_id: int, after_timestamp: float | None
|
|
676
762
|
) -> tuple[str, float]:
|
|
677
763
|
"""Get the serverapp logs for the specified `run_id`."""
|
|
678
764
|
if run_id not in self.run_ids:
|
|
@@ -685,30 +771,3 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
685
771
|
index = bisect_right(run.logs, (after_timestamp, ""))
|
|
686
772
|
latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
|
|
687
773
|
return "".join(log for _, log in run.logs[index:]), latest_timestamp
|
|
688
|
-
|
|
689
|
-
def create_token(self, run_id: int) -> Optional[str]:
|
|
690
|
-
"""Create a token for the given run ID."""
|
|
691
|
-
token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
|
|
692
|
-
with self.lock_token_store:
|
|
693
|
-
if run_id in self.token_store:
|
|
694
|
-
return None # Token already created for this run ID
|
|
695
|
-
self.token_store[run_id] = token
|
|
696
|
-
self.token_to_run_id[token] = run_id
|
|
697
|
-
return token
|
|
698
|
-
|
|
699
|
-
def verify_token(self, run_id: int, token: str) -> bool:
|
|
700
|
-
"""Verify a token for the given run ID."""
|
|
701
|
-
with self.lock_token_store:
|
|
702
|
-
return self.token_store.get(run_id) == token
|
|
703
|
-
|
|
704
|
-
def delete_token(self, run_id: int) -> None:
|
|
705
|
-
"""Delete the token for the given run ID."""
|
|
706
|
-
with self.lock_token_store:
|
|
707
|
-
token = self.token_store.pop(run_id, None)
|
|
708
|
-
if token is not None:
|
|
709
|
-
self.token_to_run_id.pop(token, None)
|
|
710
|
-
|
|
711
|
-
def get_run_id_by_token(self, token: str) -> Optional[int]:
|
|
712
|
-
"""Get the run ID associated with a given token."""
|
|
713
|
-
with self.lock_token_store:
|
|
714
|
-
return self.token_to_run_id.get(token)
|