flwr 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +19 -0
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/auth_plugin.py +4 -5
- flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
- flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
- flwr/cli/build.py +60 -18
- flwr/cli/cli_account_auth_interceptor.py +24 -7
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +52 -9
- flwr/cli/login/login.py +7 -4
- flwr/cli/ls.py +170 -130
- flwr/cli/new/new.py +33 -50
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +10 -5
- flwr/cli/run/run.py +77 -30
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +25 -7
- flwr/cli/supernode/ls.py +16 -8
- flwr/cli/supernode/register.py +9 -4
- flwr/cli/supernode/unregister.py +5 -3
- flwr/cli/utils.py +376 -16
- flwr/client/__init__.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +6 -7
- flwr/client/grpc_rere_client/connection.py +10 -11
- flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
- flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +12 -14
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/clientapp/utils.py +3 -3
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +5 -2
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +19 -0
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +38 -21
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +18 -30
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +5 -4
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +52 -37
- flwr/compat/client/app.py +38 -37
- flwr/compat/client/grpc_client/connection.py +11 -11
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +71 -52
- flwr/proto/control_pb2.pyi +277 -111
- flwr/proto/control_pb2_grpc.py +249 -40
- flwr/proto/control_pb2_grpc.pyi +185 -52
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +14 -4
- flwr/proto/fab_pb2.pyi +59 -31
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +14 -4
- flwr/proto/fleet_pb2.pyi +137 -61
- flwr/proto/fleet_pb2_grpc.py +189 -48
- flwr/proto/fleet_pb2_grpc.pyi +175 -61
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +50 -25
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +38 -17
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +15 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
- flwr/server/superlink/linkstate/linkstate.py +59 -43
- flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
- flwr/server/superlink/linkstate/utils.py +6 -6
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +10 -11
- flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
- flwr/simulation/run_simulation.py +43 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +31 -1
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -2
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +1 -2
- flwr/supercore/object_store/sqlite_object_store.py +8 -7
- flwr/supercore/primitives/asymmetric.py +1 -1
- flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
- flwr/supercore/sqlite_mixin.py +37 -34
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/superlink/auth_plugin/auth_plugin.py +6 -9
- flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +5 -6
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +102 -18
- flwr/supernode/cli/flower_supernode.py +58 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +41 -22
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +158 -42
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.23.0.dist-info/RECORD +0 -439
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -15,7 +15,6 @@
|
|
|
15
15
|
"""In-memory LinkState implementation."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import secrets
|
|
19
18
|
import threading
|
|
20
19
|
from bisect import bisect_right
|
|
21
20
|
from collections import defaultdict
|
|
@@ -23,12 +22,9 @@ from collections.abc import Sequence
|
|
|
23
22
|
from dataclasses import dataclass, field
|
|
24
23
|
from datetime import datetime, timezone
|
|
25
24
|
from logging import ERROR, WARNING
|
|
26
|
-
from typing import Optional
|
|
27
25
|
|
|
28
26
|
from flwr.common import Context, Message, log, now
|
|
29
27
|
from flwr.common.constant import (
|
|
30
|
-
FLWR_APP_TOKEN_LENGTH,
|
|
31
|
-
HEARTBEAT_INTERVAL_INF,
|
|
32
28
|
HEARTBEAT_PATIENCE,
|
|
33
29
|
MESSAGE_TTL_TOLERANCE,
|
|
34
30
|
NODE_ID_NUM_BYTES,
|
|
@@ -44,6 +40,9 @@ from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
|
|
44
40
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
|
45
41
|
from flwr.server.utils import validate_message
|
|
46
42
|
from flwr.supercore.constant import NodeStatus
|
|
43
|
+
from flwr.supercore.corestate.in_memory_corestate import InMemoryCoreState
|
|
44
|
+
from flwr.supercore.object_store.object_store import ObjectStore
|
|
45
|
+
from flwr.superlink.federation import FederationManager
|
|
47
46
|
|
|
48
47
|
from .utils import (
|
|
49
48
|
check_node_availability_for_in_message,
|
|
@@ -60,17 +59,18 @@ class RunRecord: # pylint: disable=R0902
|
|
|
60
59
|
"""The record of a specific run, including its status and timestamps."""
|
|
61
60
|
|
|
62
61
|
run: Run
|
|
63
|
-
active_until: float = 0.0
|
|
64
|
-
heartbeat_interval: float = 0.0
|
|
65
62
|
logs: list[tuple[float, str]] = field(default_factory=list)
|
|
66
63
|
log_lock: threading.Lock = field(default_factory=threading.Lock)
|
|
67
64
|
lock: threading.RLock = field(default_factory=threading.RLock)
|
|
68
65
|
|
|
69
66
|
|
|
70
|
-
class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
67
|
+
class InMemoryLinkState(LinkState, InMemoryCoreState): # pylint: disable=R0902,R0904
|
|
71
68
|
"""In-memory LinkState implementation."""
|
|
72
69
|
|
|
73
|
-
def __init__(
|
|
70
|
+
def __init__(
|
|
71
|
+
self, federation_manager: FederationManager, object_store: ObjectStore
|
|
72
|
+
) -> None:
|
|
73
|
+
super().__init__(object_store)
|
|
74
74
|
|
|
75
75
|
# Map node_id to NodeInfo
|
|
76
76
|
self.nodes: dict[int, NodeInfo] = {}
|
|
@@ -85,19 +85,21 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
85
85
|
self.message_res_store: dict[str, Message] = {}
|
|
86
86
|
self.message_ins_id_to_message_res_id: dict[str, str] = {}
|
|
87
87
|
|
|
88
|
-
# Store run ID to token mapping and token to run ID mapping
|
|
89
|
-
self.token_store: dict[int, str] = {}
|
|
90
|
-
self.token_to_run_id: dict[str, int] = {}
|
|
91
|
-
self.lock_token_store = threading.Lock()
|
|
92
|
-
|
|
93
88
|
# Map flwr_aid to run_ids for O(1) reverse index lookup
|
|
94
89
|
self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
|
|
95
90
|
|
|
96
91
|
self.node_public_keys: set[bytes] = set()
|
|
97
92
|
|
|
98
93
|
self.lock = threading.RLock()
|
|
94
|
+
federation_manager.linkstate = self
|
|
95
|
+
self._federation_manager = federation_manager
|
|
96
|
+
|
|
97
|
+
@property
|
|
98
|
+
def federation_manager(self) -> FederationManager:
|
|
99
|
+
"""Get the FederationManager instance."""
|
|
100
|
+
return self._federation_manager
|
|
99
101
|
|
|
100
|
-
def store_message_ins(self, message: Message) ->
|
|
102
|
+
def store_message_ins(self, message: Message) -> str | None:
|
|
101
103
|
"""Store one Message."""
|
|
102
104
|
# Validate message
|
|
103
105
|
errors = validate_message(message, is_reply_message=False)
|
|
@@ -108,6 +110,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
108
110
|
if message.metadata.run_id not in self.run_ids:
|
|
109
111
|
log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
|
|
110
112
|
return None
|
|
113
|
+
federation = self.run_ids[message.metadata.run_id].run.federation
|
|
111
114
|
# Validate source node ID
|
|
112
115
|
if message.metadata.src_node_id != SUPERLINK_NODE_ID:
|
|
113
116
|
log(
|
|
@@ -118,10 +121,14 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
118
121
|
return None
|
|
119
122
|
# Validate destination node ID
|
|
120
123
|
dst_node = self.nodes.get(message.metadata.dst_node_id)
|
|
121
|
-
if
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
124
|
+
if (
|
|
125
|
+
# Node must exist
|
|
126
|
+
dst_node is None
|
|
127
|
+
# Node must be online or offline
|
|
128
|
+
or dst_node.status not in (NodeStatus.ONLINE, NodeStatus.OFFLINE)
|
|
129
|
+
# Node must belong to the same federation
|
|
130
|
+
or not self.federation_manager.has_node(dst_node.node_id, federation)
|
|
131
|
+
):
|
|
125
132
|
log(
|
|
126
133
|
ERROR,
|
|
127
134
|
"Invalid destination node ID for Message: %s",
|
|
@@ -136,21 +143,50 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
136
143
|
# Return the new message_id
|
|
137
144
|
return message_id
|
|
138
145
|
|
|
139
|
-
def
|
|
146
|
+
def _check_stored_messages(self, message_ids: set[str]) -> None:
|
|
147
|
+
"""Check and delete the message if it's invalid."""
|
|
148
|
+
with self.lock:
|
|
149
|
+
invalid_msg_ids: set[str] = set()
|
|
150
|
+
current = now().timestamp()
|
|
151
|
+
for msg_id in message_ids:
|
|
152
|
+
if not (message := self.message_ins_store.get(msg_id)):
|
|
153
|
+
continue
|
|
154
|
+
|
|
155
|
+
# Check if the message has expired
|
|
156
|
+
available_until = message.metadata.created_at + message.metadata.ttl
|
|
157
|
+
if available_until <= current:
|
|
158
|
+
invalid_msg_ids.add(msg_id)
|
|
159
|
+
continue
|
|
160
|
+
|
|
161
|
+
# Check if the destination node and the source node are still in the
|
|
162
|
+
# same federation
|
|
163
|
+
src_node_id = message.metadata.src_node_id
|
|
164
|
+
dst_node_id = message.metadata.dst_node_id
|
|
165
|
+
filtered = self.federation_manager.filter_nodes(
|
|
166
|
+
{src_node_id, dst_node_id},
|
|
167
|
+
self.run_ids[message.metadata.run_id].run.federation,
|
|
168
|
+
)
|
|
169
|
+
if len(filtered) != 2: # Not both nodes are in the federation
|
|
170
|
+
invalid_msg_ids.add(msg_id)
|
|
171
|
+
|
|
172
|
+
# Delete all invalid messages
|
|
173
|
+
self.delete_messages(invalid_msg_ids)
|
|
174
|
+
|
|
175
|
+
def get_message_ins(self, node_id: int, limit: int | None) -> list[Message]:
|
|
140
176
|
"""Get all Messages that have not been delivered yet."""
|
|
141
177
|
if limit is not None and limit < 1:
|
|
142
178
|
raise AssertionError("`limit` must be >= 1")
|
|
143
179
|
|
|
144
180
|
# Find Message for node_id that were not delivered yet
|
|
145
181
|
message_ins_list: list[Message] = []
|
|
146
|
-
current_time = now().timestamp()
|
|
147
182
|
with self.lock:
|
|
148
|
-
for
|
|
183
|
+
for msg_id in list(self.message_ins_store.keys()):
|
|
184
|
+
self._check_stored_messages({msg_id})
|
|
185
|
+
|
|
149
186
|
if (
|
|
150
|
-
msg_ins.
|
|
187
|
+
(msg_ins := self.message_ins_store.get(msg_id))
|
|
188
|
+
and msg_ins.metadata.dst_node_id == node_id
|
|
151
189
|
and msg_ins.metadata.delivered_at == ""
|
|
152
|
-
and msg_ins.metadata.created_at + msg_ins.metadata.ttl
|
|
153
|
-
> current_time
|
|
154
190
|
):
|
|
155
191
|
message_ins_list.append(msg_ins)
|
|
156
192
|
if limit and len(message_ins_list) == limit:
|
|
@@ -165,7 +201,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
165
201
|
return message_ins_list
|
|
166
202
|
|
|
167
203
|
# pylint: disable=R0911
|
|
168
|
-
def store_message_res(self, message: Message) ->
|
|
204
|
+
def store_message_res(self, message: Message) -> str | None:
|
|
169
205
|
"""Store one Message."""
|
|
170
206
|
# Validate message
|
|
171
207
|
errors = validate_message(message, is_reply_message=True)
|
|
@@ -177,6 +213,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
177
213
|
with self.lock:
|
|
178
214
|
# Check if the Message it is replying to exists and is valid
|
|
179
215
|
msg_ins_id = res_metadata.reply_to_message_id
|
|
216
|
+
self._check_stored_messages({msg_ins_id})
|
|
180
217
|
msg_ins = self.message_ins_store.get(msg_ins_id)
|
|
181
218
|
|
|
182
219
|
# Ensure that dst_node_id of original Message matches the src_node_id of
|
|
@@ -196,22 +233,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
196
233
|
)
|
|
197
234
|
return None
|
|
198
235
|
|
|
199
|
-
ins_metadata = msg_ins.metadata
|
|
200
|
-
if ins_metadata.created_at + ins_metadata.ttl <= now().timestamp():
|
|
201
|
-
log(
|
|
202
|
-
ERROR,
|
|
203
|
-
"Failed to store Message: the message it is replying to "
|
|
204
|
-
"(with ID %s) has expired",
|
|
205
|
-
msg_ins_id,
|
|
206
|
-
)
|
|
207
|
-
return None
|
|
208
|
-
|
|
209
236
|
# Fail if the Message TTL exceeds the
|
|
210
237
|
# expiration time of the Message it replies to.
|
|
211
238
|
# Condition: ins_metadata.created_at + ins_metadata.ttl ≥
|
|
212
239
|
# res_metadata.created_at + res_metadata.ttl
|
|
213
240
|
# A small tolerance is introduced to account
|
|
214
241
|
# for floating-point precision issues.
|
|
242
|
+
ins_metadata = msg_ins.metadata
|
|
215
243
|
max_allowed_ttl = (
|
|
216
244
|
ins_metadata.created_at + ins_metadata.ttl - res_metadata.created_at
|
|
217
245
|
)
|
|
@@ -245,6 +273,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
245
273
|
ret: dict[str, Message] = {}
|
|
246
274
|
|
|
247
275
|
with self.lock:
|
|
276
|
+
self._check_stored_messages(message_ids)
|
|
248
277
|
current = now().timestamp()
|
|
249
278
|
|
|
250
279
|
# Verify Message IDs
|
|
@@ -339,7 +368,11 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
339
368
|
return len(self.message_res_store)
|
|
340
369
|
|
|
341
370
|
def create_node(
|
|
342
|
-
self,
|
|
371
|
+
self,
|
|
372
|
+
owner_aid: str,
|
|
373
|
+
owner_name: str,
|
|
374
|
+
public_key: bytes,
|
|
375
|
+
heartbeat_interval: float,
|
|
343
376
|
) -> int:
|
|
344
377
|
"""Create, store in the link state, and return `node_id`."""
|
|
345
378
|
# Sample a random int64 as node_id
|
|
@@ -358,6 +391,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
358
391
|
self.nodes[node_id] = NodeInfo(
|
|
359
392
|
node_id=node_id,
|
|
360
393
|
owner_aid=owner_aid,
|
|
394
|
+
owner_name=owner_name,
|
|
361
395
|
status=NodeStatus.REGISTERED,
|
|
362
396
|
registered_at=now().isoformat(),
|
|
363
397
|
last_activated_at=None,
|
|
@@ -442,17 +476,19 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
442
476
|
with self.lock:
|
|
443
477
|
if run_id not in self.run_ids:
|
|
444
478
|
return set()
|
|
445
|
-
|
|
479
|
+
federation = self.run_ids[run_id].run.federation
|
|
480
|
+
node_ids = {
|
|
446
481
|
node.node_id
|
|
447
482
|
for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
|
|
448
483
|
}
|
|
484
|
+
return self.federation_manager.filter_nodes(node_ids, federation)
|
|
449
485
|
|
|
450
486
|
def get_node_info(
|
|
451
487
|
self,
|
|
452
488
|
*,
|
|
453
|
-
node_ids:
|
|
454
|
-
owner_aids:
|
|
455
|
-
statuses:
|
|
489
|
+
node_ids: Sequence[int] | None = None,
|
|
490
|
+
owner_aids: Sequence[str] | None = None,
|
|
491
|
+
statuses: Sequence[str] | None = None,
|
|
456
492
|
) -> Sequence[NodeInfo]:
|
|
457
493
|
"""Retrieve information about nodes based on the specified filters."""
|
|
458
494
|
with self.lock:
|
|
@@ -468,9 +504,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
468
504
|
result.append(node)
|
|
469
505
|
return result
|
|
470
506
|
|
|
471
|
-
def _check_and_tag_offline_nodes(
|
|
472
|
-
self, node_ids: Optional[list[int]] = None
|
|
473
|
-
) -> None:
|
|
507
|
+
def _check_and_tag_offline_nodes(self, node_ids: list[int] | None = None) -> None:
|
|
474
508
|
with self.lock:
|
|
475
509
|
# Set all nodes of "online" status to "offline" if they've offline
|
|
476
510
|
current_ts = now().timestamp()
|
|
@@ -493,7 +527,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
493
527
|
raise ValueError(f"Node ID {node_id} not found")
|
|
494
528
|
return node.public_key
|
|
495
529
|
|
|
496
|
-
def get_node_id_by_public_key(self, public_key: bytes) ->
|
|
530
|
+
def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
|
|
497
531
|
"""Get `node_id` for the specified `public_key` if it exists and is not
|
|
498
532
|
deleted."""
|
|
499
533
|
with self.lock:
|
|
@@ -510,14 +544,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
510
544
|
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
511
545
|
def create_run(
|
|
512
546
|
self,
|
|
513
|
-
fab_id:
|
|
514
|
-
fab_version:
|
|
515
|
-
fab_hash:
|
|
547
|
+
fab_id: str | None,
|
|
548
|
+
fab_version: str | None,
|
|
549
|
+
fab_hash: str | None,
|
|
516
550
|
override_config: UserConfig,
|
|
551
|
+
federation: str,
|
|
517
552
|
federation_options: ConfigRecord,
|
|
518
|
-
flwr_aid:
|
|
553
|
+
flwr_aid: str | None,
|
|
519
554
|
) -> int:
|
|
520
|
-
"""Create a new run
|
|
555
|
+
"""Create a new run."""
|
|
521
556
|
# Sample a random int64 as run_id
|
|
522
557
|
with self.lock:
|
|
523
558
|
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
@@ -540,6 +575,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
540
575
|
details="",
|
|
541
576
|
),
|
|
542
577
|
flwr_aid=flwr_aid if flwr_aid else "",
|
|
578
|
+
federation=federation,
|
|
543
579
|
),
|
|
544
580
|
)
|
|
545
581
|
self.run_ids[run_id] = run_record
|
|
@@ -553,7 +589,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
553
589
|
log(ERROR, "Unexpected run creation failure.")
|
|
554
590
|
return 0
|
|
555
591
|
|
|
556
|
-
def get_run_ids(self, flwr_aid:
|
|
592
|
+
def get_run_ids(self, flwr_aid: str | None) -> set[int]:
|
|
557
593
|
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
|
558
594
|
|
|
559
595
|
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
|
@@ -564,30 +600,10 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
564
600
|
return set(self.flwr_aid_to_run_ids.get(flwr_aid, ()))
|
|
565
601
|
return set(self.run_ids.keys())
|
|
566
602
|
|
|
567
|
-
def
|
|
568
|
-
"""Check if any runs are no longer active.
|
|
569
|
-
|
|
570
|
-
Marks runs with status 'starting' or 'running' as failed
|
|
571
|
-
if they have not sent a heartbeat before `active_until`.
|
|
572
|
-
"""
|
|
573
|
-
current = now()
|
|
574
|
-
for record in (self.run_ids.get(run_id) for run_id in run_ids):
|
|
575
|
-
if record is None:
|
|
576
|
-
continue
|
|
577
|
-
with record.lock:
|
|
578
|
-
if record.run.status.status in (Status.STARTING, Status.RUNNING):
|
|
579
|
-
if record.active_until < current.timestamp():
|
|
580
|
-
record.run.status = RunStatus(
|
|
581
|
-
status=Status.FINISHED,
|
|
582
|
-
sub_status=SubStatus.FAILED,
|
|
583
|
-
details=RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
584
|
-
)
|
|
585
|
-
record.run.finished_at = now().isoformat()
|
|
586
|
-
|
|
587
|
-
def get_run(self, run_id: int) -> Optional[Run]:
|
|
603
|
+
def get_run(self, run_id: int) -> Run | None:
|
|
588
604
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
589
|
-
#
|
|
590
|
-
self.
|
|
605
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
606
|
+
self._cleanup_expired_tokens()
|
|
591
607
|
|
|
592
608
|
with self.lock:
|
|
593
609
|
if run_id not in self.run_ids:
|
|
@@ -597,8 +613,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
597
613
|
|
|
598
614
|
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
|
|
599
615
|
"""Retrieve the statuses for the specified runs."""
|
|
600
|
-
#
|
|
601
|
-
self.
|
|
616
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
617
|
+
self._cleanup_expired_tokens()
|
|
602
618
|
|
|
603
619
|
with self.lock:
|
|
604
620
|
return {
|
|
@@ -609,8 +625,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
609
625
|
|
|
610
626
|
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
|
|
611
627
|
"""Update the status of the run with the specified `run_id`."""
|
|
612
|
-
#
|
|
613
|
-
self.
|
|
628
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
629
|
+
self._cleanup_expired_tokens()
|
|
614
630
|
|
|
615
631
|
with self.lock:
|
|
616
632
|
# Check if the run_id exists
|
|
@@ -640,17 +656,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
640
656
|
)
|
|
641
657
|
return False
|
|
642
658
|
|
|
643
|
-
#
|
|
644
|
-
# when switching to starting or running
|
|
659
|
+
# Update the run status
|
|
645
660
|
current = now()
|
|
646
661
|
run_record = self.run_ids[run_id]
|
|
647
|
-
if new_status.status in (Status.STARTING, Status.RUNNING):
|
|
648
|
-
run_record.heartbeat_interval = HEARTBEAT_INTERVAL_INF
|
|
649
|
-
run_record.active_until = (
|
|
650
|
-
current.timestamp() + run_record.heartbeat_interval
|
|
651
|
-
)
|
|
652
|
-
|
|
653
|
-
# Update the run status
|
|
654
662
|
if new_status.status == Status.STARTING:
|
|
655
663
|
run_record.run.starting_at = current.isoformat()
|
|
656
664
|
elif new_status.status == Status.RUNNING:
|
|
@@ -660,7 +668,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
660
668
|
run_record.run.status = new_status
|
|
661
669
|
return True
|
|
662
670
|
|
|
663
|
-
def get_pending_run_id(self) ->
|
|
671
|
+
def get_pending_run_id(self) -> int | None:
|
|
664
672
|
"""Get the `run_id` of a run with `Status.PENDING` status, if any."""
|
|
665
673
|
pending_run_id = None
|
|
666
674
|
|
|
@@ -673,7 +681,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
673
681
|
|
|
674
682
|
return pending_run_id
|
|
675
683
|
|
|
676
|
-
def get_federation_options(self, run_id: int) ->
|
|
684
|
+
def get_federation_options(self, run_id: int) -> ConfigRecord | None:
|
|
677
685
|
"""Retrieve the federation options for the specified `run_id`."""
|
|
678
686
|
with self.lock:
|
|
679
687
|
if run_id not in self.run_ids:
|
|
@@ -710,44 +718,28 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
710
718
|
return True
|
|
711
719
|
return False
|
|
712
720
|
|
|
713
|
-
def
|
|
714
|
-
"""
|
|
721
|
+
def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
|
|
722
|
+
"""Transition runs with expired tokens to failed status.
|
|
715
723
|
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
724
|
+
Parameters
|
|
725
|
+
----------
|
|
726
|
+
expired_records : list[tuple[int, float]]
|
|
727
|
+
List of tuples containing (run_id, active_until timestamp)
|
|
728
|
+
for expired tokens.
|
|
720
729
|
"""
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
with record.lock:
|
|
731
|
-
# Check if runs are still active
|
|
732
|
-
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
733
|
-
|
|
734
|
-
# Check if the run is of status "running"/"starting"
|
|
735
|
-
current_status = record.run.status
|
|
736
|
-
if current_status.status not in (Status.RUNNING, Status.STARTING):
|
|
737
|
-
log(
|
|
738
|
-
ERROR,
|
|
739
|
-
'Cannot acknowledge heartbeat for run with status "%s"',
|
|
740
|
-
current_status.status,
|
|
730
|
+
for run_id, active_until in expired_records:
|
|
731
|
+
if not (run_record := self.run_ids.get(run_id)):
|
|
732
|
+
continue
|
|
733
|
+
with run_record.lock:
|
|
734
|
+
run_record.run.status = RunStatus(
|
|
735
|
+
status=Status.FINISHED,
|
|
736
|
+
sub_status=SubStatus.FAILED,
|
|
737
|
+
details=RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
741
738
|
)
|
|
742
|
-
|
|
739
|
+
active_until_dt = datetime.fromtimestamp(active_until, tz=timezone.utc)
|
|
740
|
+
run_record.run.finished_at = active_until_dt.isoformat()
|
|
743
741
|
|
|
744
|
-
|
|
745
|
-
current = now().timestamp()
|
|
746
|
-
record.active_until = current + HEARTBEAT_PATIENCE * heartbeat_interval
|
|
747
|
-
record.heartbeat_interval = heartbeat_interval
|
|
748
|
-
return True
|
|
749
|
-
|
|
750
|
-
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
|
742
|
+
def get_serverapp_context(self, run_id: int) -> Context | None:
|
|
751
743
|
"""Get the context for the specified `run_id`."""
|
|
752
744
|
return self.contexts.get(run_id)
|
|
753
745
|
|
|
@@ -766,7 +758,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
766
758
|
run.logs.append((now().timestamp(), log_message))
|
|
767
759
|
|
|
768
760
|
def get_serverapp_log(
|
|
769
|
-
self, run_id: int, after_timestamp: Optional[float]
|
|
761
|
+
self, run_id: int, after_timestamp: float | None
|
|
770
762
|
) -> tuple[str, float]:
|
|
771
763
|
"""Get the serverapp logs for the specified `run_id`."""
|
|
772
764
|
if run_id not in self.run_ids:
|
|
@@ -779,30 +771,3 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
779
771
|
index = bisect_right(run.logs, (after_timestamp, ""))
|
|
780
772
|
latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
|
|
781
773
|
return "".join(log for _, log in run.logs[index:]), latest_timestamp
|
|
782
|
-
|
|
783
|
-
def create_token(self, run_id: int) -> Optional[str]:
|
|
784
|
-
"""Create a token for the given run ID."""
|
|
785
|
-
token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
|
|
786
|
-
with self.lock_token_store:
|
|
787
|
-
if run_id in self.token_store:
|
|
788
|
-
return None # Token already created for this run ID
|
|
789
|
-
self.token_store[run_id] = token
|
|
790
|
-
self.token_to_run_id[token] = run_id
|
|
791
|
-
return token
|
|
792
|
-
|
|
793
|
-
def verify_token(self, run_id: int, token: str) -> bool:
|
|
794
|
-
"""Verify a token for the given run ID."""
|
|
795
|
-
with self.lock_token_store:
|
|
796
|
-
return self.token_store.get(run_id) == token
|
|
797
|
-
|
|
798
|
-
def delete_token(self, run_id: int) -> None:
|
|
799
|
-
"""Delete the token for the given run ID."""
|
|
800
|
-
with self.lock_token_store:
|
|
801
|
-
token = self.token_store.pop(run_id, None)
|
|
802
|
-
if token is not None:
|
|
803
|
-
self.token_to_run_id.pop(token, None)
|
|
804
|
-
|
|
805
|
-
def get_run_id_by_token(self, token: str) -> Optional[int]:
|
|
806
|
-
"""Get the run ID associated with a given token."""
|
|
807
|
-
with self.lock_token_store:
|
|
808
|
-
return self.token_to_run_id.get(token)
|