flwr 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +19 -0
- flwr/cli/{new/templates → app_cmd}/__init__.py +9 -1
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +262 -0
- flwr/cli/auth_plugin/auth_plugin.py +4 -5
- flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
- flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
- flwr/cli/build.py +60 -18
- flwr/cli/cli_account_auth_interceptor.py +24 -7
- flwr/cli/config_utils.py +101 -13
- flwr/cli/{new/templates/app/code/flwr_tune → federation}/__init__.py +10 -1
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +318 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +52 -9
- flwr/cli/login/login.py +7 -4
- flwr/cli/ls.py +211 -130
- flwr/cli/new/new.py +123 -331
- flwr/cli/pull.py +10 -5
- flwr/cli/run/run.py +71 -29
- flwr/cli/run_utils.py +148 -0
- flwr/cli/stop.py +26 -8
- flwr/cli/supernode/ls.py +25 -12
- flwr/cli/supernode/register.py +9 -4
- flwr/cli/supernode/unregister.py +5 -3
- flwr/cli/utils.py +239 -16
- flwr/client/__init__.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +8 -9
- flwr/client/grpc_rere_client/connection.py +16 -14
- flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
- flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +18 -18
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/clientapp/utils.py +3 -3
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +5 -2
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +19 -0
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +38 -21
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +18 -30
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +11 -4
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +58 -37
- flwr/compat/client/app.py +38 -37
- flwr/compat/client/grpc_client/connection.py +11 -11
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +71 -52
- flwr/proto/control_pb2.pyi +277 -111
- flwr/proto/control_pb2_grpc.py +249 -40
- flwr/proto/control_pb2_grpc.pyi +185 -52
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +14 -4
- flwr/proto/fab_pb2.pyi +59 -31
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +24 -14
- flwr/proto/fleet_pb2.pyi +141 -61
- flwr/proto/fleet_pb2_grpc.py +189 -48
- flwr/proto/fleet_pb2_grpc.pyi +175 -61
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +50 -25
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +158 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +39 -17
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +75 -30
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +15 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +148 -149
- flwr/server/superlink/linkstate/linkstate.py +91 -43
- flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +502 -436
- flwr/server/superlink/linkstate/utils.py +6 -6
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +10 -11
- flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
- flwr/simulation/run_simulation.py +43 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +34 -1
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -2
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +1 -2
- flwr/supercore/object_store/sqlite_object_store.py +8 -7
- flwr/supercore/primitives/asymmetric.py +1 -1
- flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
- flwr/supercore/sqlite_mixin.py +37 -34
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/supercore/utils.py +190 -0
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/superlink/auth_plugin/auth_plugin.py +6 -9
- flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
- flwr/{cli/new/templates/app → superlink/federation}/__init__.py +10 -1
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +7 -6
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +190 -23
- flwr/supernode/cli/flower_supernode.py +58 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +121 -49
- flwr/supernode/nodestate/nodestate.py +52 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +41 -22
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +46 -10
- flwr/supernode/start_client_internal.py +165 -46
- {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/METADATA +9 -11
- flwr-1.25.0.dist-info/RECORD +393 -0
- flwr/cli/new/templates/app/.gitignore.tpl +0 -163
- flwr/cli/new/templates/app/LICENSE.tpl +0 -202
- flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
- flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
- flwr/cli/new/templates/app/README.md.tpl +0 -37
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.py +0 -15
- flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
- flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
- flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
- flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -98
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.23.0.dist-info/RECORD +0 -439
- {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/WHEEL +0 -0
- {flwr-1.23.0.dist-info → flwr-1.25.0.dist-info}/entry_points.txt +0 -0
|
@@ -18,16 +18,14 @@
|
|
|
18
18
|
# pylint: disable=too-many-lines
|
|
19
19
|
|
|
20
20
|
import json
|
|
21
|
-
import secrets
|
|
22
21
|
import sqlite3
|
|
23
22
|
from collections.abc import Sequence
|
|
23
|
+
from datetime import datetime, timezone
|
|
24
24
|
from logging import ERROR, WARNING
|
|
25
|
-
from typing import Any,
|
|
25
|
+
from typing import Any, cast
|
|
26
26
|
|
|
27
27
|
from flwr.common import Context, Message, Metadata, log, now
|
|
28
28
|
from flwr.common.constant import (
|
|
29
|
-
FLWR_APP_TOKEN_LENGTH,
|
|
30
|
-
HEARTBEAT_INTERVAL_INF,
|
|
31
29
|
HEARTBEAT_PATIENCE,
|
|
32
30
|
MESSAGE_TTL_TOLERANCE,
|
|
33
31
|
NODE_ID_NUM_BYTES,
|
|
@@ -51,8 +49,10 @@ from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
|
|
51
49
|
# pylint: enable=E0611
|
|
52
50
|
from flwr.server.utils.validator import validate_message
|
|
53
51
|
from flwr.supercore.constant import NodeStatus
|
|
54
|
-
from flwr.supercore.
|
|
52
|
+
from flwr.supercore.corestate.sqlite_corestate import SqliteCoreState
|
|
53
|
+
from flwr.supercore.object_store.object_store import ObjectStore
|
|
55
54
|
from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
|
|
55
|
+
from flwr.superlink.federation import FederationManager
|
|
56
56
|
|
|
57
57
|
from .linkstate import LinkState
|
|
58
58
|
from .utils import (
|
|
@@ -74,6 +74,7 @@ SQL_CREATE_TABLE_NODE = """
|
|
|
74
74
|
CREATE TABLE IF NOT EXISTS node(
|
|
75
75
|
node_id INTEGER UNIQUE,
|
|
76
76
|
owner_aid TEXT,
|
|
77
|
+
owner_name TEXT,
|
|
77
78
|
status TEXT,
|
|
78
79
|
registered_at TEXT,
|
|
79
80
|
last_activated_at TEXT NULL,
|
|
@@ -106,8 +107,6 @@ CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
|
|
|
106
107
|
SQL_CREATE_TABLE_RUN = """
|
|
107
108
|
CREATE TABLE IF NOT EXISTS run(
|
|
108
109
|
run_id INTEGER UNIQUE,
|
|
109
|
-
active_until REAL,
|
|
110
|
-
heartbeat_interval REAL,
|
|
111
110
|
fab_id TEXT,
|
|
112
111
|
fab_version TEXT,
|
|
113
112
|
fab_hash TEXT,
|
|
@@ -118,8 +117,12 @@ CREATE TABLE IF NOT EXISTS run(
|
|
|
118
117
|
finished_at TEXT,
|
|
119
118
|
sub_status TEXT,
|
|
120
119
|
details TEXT,
|
|
120
|
+
federation TEXT,
|
|
121
121
|
federation_options BLOB,
|
|
122
|
-
flwr_aid TEXT
|
|
122
|
+
flwr_aid TEXT,
|
|
123
|
+
bytes_sent INTEGER DEFAULT 0,
|
|
124
|
+
bytes_recv INTEGER DEFAULT 0,
|
|
125
|
+
clientapp_runtime REAL DEFAULT 0.0
|
|
123
126
|
);
|
|
124
127
|
"""
|
|
125
128
|
|
|
@@ -179,20 +182,23 @@ CREATE TABLE IF NOT EXISTS message_res(
|
|
|
179
182
|
);
|
|
180
183
|
"""
|
|
181
184
|
|
|
182
|
-
SQL_CREATE_TABLE_TOKEN_STORE = """
|
|
183
|
-
CREATE TABLE IF NOT EXISTS token_store (
|
|
184
|
-
run_id INTEGER PRIMARY KEY,
|
|
185
|
-
token TEXT UNIQUE NOT NULL
|
|
186
|
-
);
|
|
187
|
-
"""
|
|
188
|
-
|
|
189
185
|
|
|
190
|
-
class SqliteLinkState(LinkState,
|
|
186
|
+
class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
191
187
|
"""SQLite-based LinkState implementation."""
|
|
192
188
|
|
|
193
|
-
def
|
|
194
|
-
|
|
195
|
-
|
|
189
|
+
def __init__(
|
|
190
|
+
self,
|
|
191
|
+
database_path: str,
|
|
192
|
+
federation_manager: FederationManager,
|
|
193
|
+
object_store: ObjectStore,
|
|
194
|
+
) -> None:
|
|
195
|
+
super().__init__(database_path, object_store)
|
|
196
|
+
federation_manager.linkstate = self
|
|
197
|
+
self._federation_manager = federation_manager
|
|
198
|
+
|
|
199
|
+
def get_sql_statements(self) -> tuple[str, ...]:
|
|
200
|
+
"""Return SQL statements for LinkState tables."""
|
|
201
|
+
return super().get_sql_statements() + (
|
|
196
202
|
SQL_CREATE_TABLE_RUN,
|
|
197
203
|
SQL_CREATE_TABLE_LOGS,
|
|
198
204
|
SQL_CREATE_TABLE_CONTEXT,
|
|
@@ -200,14 +206,17 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
200
206
|
SQL_CREATE_TABLE_MESSAGE_RES,
|
|
201
207
|
SQL_CREATE_TABLE_NODE,
|
|
202
208
|
SQL_CREATE_TABLE_PUBLIC_KEY,
|
|
203
|
-
SQL_CREATE_TABLE_TOKEN_STORE,
|
|
204
209
|
SQL_CREATE_INDEX_ONLINE_UNTIL,
|
|
205
210
|
SQL_CREATE_INDEX_OWNER_AID,
|
|
206
211
|
SQL_CREATE_INDEX_NODE_STATUS,
|
|
207
|
-
log_queries=log_queries,
|
|
208
212
|
)
|
|
209
213
|
|
|
210
|
-
|
|
214
|
+
@property
|
|
215
|
+
def federation_manager(self) -> FederationManager:
|
|
216
|
+
"""Get the FederationManager instance."""
|
|
217
|
+
return self._federation_manager
|
|
218
|
+
|
|
219
|
+
def store_message_ins(self, message: Message) -> str | None:
|
|
211
220
|
"""Store one Message."""
|
|
212
221
|
# Validate message
|
|
213
222
|
errors = validate_message(message=message, is_reply_message=False)
|
|
@@ -223,12 +232,6 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
223
232
|
data[0], ["run_id", "src_node_id", "dst_node_id"]
|
|
224
233
|
)
|
|
225
234
|
|
|
226
|
-
# Validate run_id
|
|
227
|
-
query = "SELECT run_id FROM run WHERE run_id = ?;"
|
|
228
|
-
if not self.query(query, (data[0]["run_id"],)):
|
|
229
|
-
log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
|
|
230
|
-
return None
|
|
231
|
-
|
|
232
235
|
# Validate source node ID
|
|
233
236
|
if message.metadata.src_node_id != SUPERLINK_NODE_ID:
|
|
234
237
|
log(
|
|
@@ -238,28 +241,87 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
238
241
|
)
|
|
239
242
|
return None
|
|
240
243
|
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
query, (data[0]["
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
244
|
+
with self.conn:
|
|
245
|
+
# Validate run_id
|
|
246
|
+
query = "SELECT federation FROM run WHERE run_id = ?;"
|
|
247
|
+
rows = self.conn.execute(query, (data[0]["run_id"],)).fetchall()
|
|
248
|
+
if not rows:
|
|
249
|
+
log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
|
|
250
|
+
return None
|
|
251
|
+
federation: str = rows[0]["federation"]
|
|
252
|
+
|
|
253
|
+
# Validate destination node ID
|
|
254
|
+
query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
|
|
255
|
+
rows = self.conn.execute(
|
|
256
|
+
query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
|
|
257
|
+
).fetchall()
|
|
258
|
+
if not rows or not self.federation_manager.has_node(
|
|
259
|
+
message.metadata.dst_node_id, federation
|
|
260
|
+
):
|
|
261
|
+
log(
|
|
262
|
+
ERROR,
|
|
263
|
+
"Invalid destination node ID for Message: %s",
|
|
264
|
+
message.metadata.dst_node_id,
|
|
265
|
+
)
|
|
266
|
+
return None
|
|
267
|
+
|
|
268
|
+
columns = ", ".join([f":{key}" for key in data[0]])
|
|
269
|
+
query = f"INSERT INTO message_ins VALUES({columns});"
|
|
270
|
+
|
|
271
|
+
# Only invalid run_id can trigger IntegrityError.
|
|
272
|
+
# This may need to be changed in the future version
|
|
273
|
+
# with more integrity checks.
|
|
274
|
+
self.conn.execute(query, data[0])
|
|
259
275
|
|
|
260
276
|
return message.metadata.message_id
|
|
261
277
|
|
|
262
|
-
def
|
|
278
|
+
def _check_stored_messages(self, message_ids: set[str]) -> None:
|
|
279
|
+
"""Check and delete the message if it's invalid."""
|
|
280
|
+
if not message_ids:
|
|
281
|
+
return
|
|
282
|
+
|
|
283
|
+
with self.conn:
|
|
284
|
+
invalid_msg_ids: set[str] = set()
|
|
285
|
+
current_time = now().timestamp()
|
|
286
|
+
|
|
287
|
+
for msg_id in message_ids:
|
|
288
|
+
# Check if message exists
|
|
289
|
+
query = "SELECT * FROM message_ins WHERE message_id = ?;"
|
|
290
|
+
message_row = self.conn.execute(query, (msg_id,)).fetchone()
|
|
291
|
+
if not message_row:
|
|
292
|
+
continue
|
|
293
|
+
|
|
294
|
+
# Check if the message has expired
|
|
295
|
+
available_until = message_row["created_at"] + message_row["ttl"]
|
|
296
|
+
if available_until <= current_time:
|
|
297
|
+
invalid_msg_ids.add(msg_id)
|
|
298
|
+
continue
|
|
299
|
+
|
|
300
|
+
# Check if src_node_id and dst_node_id are in the federation
|
|
301
|
+
# Get federation from run table
|
|
302
|
+
run_id = message_row["run_id"]
|
|
303
|
+
query = "SELECT federation FROM run WHERE run_id = ?;"
|
|
304
|
+
run_row = self.conn.execute(query, (run_id,)).fetchone()
|
|
305
|
+
if not run_row: # This should not happen
|
|
306
|
+
invalid_msg_ids.add(msg_id)
|
|
307
|
+
continue
|
|
308
|
+
federation = run_row["federation"]
|
|
309
|
+
|
|
310
|
+
# Convert sint64 to uint64 for node IDs
|
|
311
|
+
src_node_id = int64_to_uint64(message_row["src_node_id"])
|
|
312
|
+
dst_node_id = int64_to_uint64(message_row["dst_node_id"])
|
|
313
|
+
|
|
314
|
+
# Filter nodes to check if they're in the federation
|
|
315
|
+
filtered = self.federation_manager.filter_nodes(
|
|
316
|
+
{src_node_id, dst_node_id}, federation
|
|
317
|
+
)
|
|
318
|
+
if len(filtered) != 2: # Not both nodes are in the federation
|
|
319
|
+
invalid_msg_ids.add(msg_id)
|
|
320
|
+
|
|
321
|
+
# Delete all invalid messages
|
|
322
|
+
self.delete_messages(invalid_msg_ids)
|
|
323
|
+
|
|
324
|
+
def get_message_ins(self, node_id: int, limit: int | None) -> list[Message]:
|
|
263
325
|
"""Get all Messages that have not been delivered yet."""
|
|
264
326
|
if limit is not None and limit < 1:
|
|
265
327
|
raise AssertionError("`limit` must be >= 1")
|
|
@@ -268,59 +330,64 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
268
330
|
msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
|
|
269
331
|
raise AssertionError(msg)
|
|
270
332
|
|
|
271
|
-
data: dict[str,
|
|
333
|
+
data: dict[str, str | int] = {}
|
|
272
334
|
|
|
273
335
|
# Convert the uint64 value to sint64 for SQLite
|
|
274
336
|
data["node_id"] = uint64_to_int64(node_id)
|
|
275
337
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
if limit is not None:
|
|
286
|
-
query += " LIMIT :limit"
|
|
287
|
-
data["limit"] = limit
|
|
288
|
-
|
|
289
|
-
query += ";"
|
|
290
|
-
|
|
291
|
-
rows = self.query(query, data)
|
|
292
|
-
|
|
293
|
-
if rows:
|
|
294
|
-
# Prepare query
|
|
295
|
-
message_ids = [row["message_id"] for row in rows]
|
|
296
|
-
placeholders: str = ",".join([f":id_{i}" for i in range(len(message_ids))])
|
|
297
|
-
query = f"""
|
|
298
|
-
UPDATE message_ins
|
|
299
|
-
SET delivered_at = :delivered_at
|
|
300
|
-
WHERE message_id IN ({placeholders})
|
|
301
|
-
RETURNING *;
|
|
338
|
+
with self.conn:
|
|
339
|
+
# Retrieve all Messages for node_id
|
|
340
|
+
query = """
|
|
341
|
+
SELECT message_id
|
|
342
|
+
FROM message_ins
|
|
343
|
+
WHERE dst_node_id == :node_id
|
|
344
|
+
AND delivered_at = ""
|
|
345
|
+
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
302
346
|
"""
|
|
303
347
|
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
for index, msg_id in enumerate(message_ids):
|
|
308
|
-
data[f"id_{index}"] = str(msg_id)
|
|
348
|
+
if limit is not None:
|
|
349
|
+
query += " LIMIT :limit"
|
|
350
|
+
data["limit"] = limit
|
|
309
351
|
|
|
310
|
-
|
|
311
|
-
rows = self.query(query, data)
|
|
352
|
+
query += ";"
|
|
312
353
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
354
|
+
rows = self.conn.execute(query, data).fetchall()
|
|
355
|
+
message_ids: set[str] = {row["message_id"] for row in rows}
|
|
356
|
+
self._check_stored_messages(message_ids)
|
|
357
|
+
|
|
358
|
+
# Mark retrieved Messages as delivered
|
|
359
|
+
if rows:
|
|
360
|
+
# Prepare query
|
|
361
|
+
placeholders: str = ",".join(
|
|
362
|
+
[f":id_{i}" for i in range(len(message_ids))]
|
|
363
|
+
)
|
|
364
|
+
query = f"""
|
|
365
|
+
UPDATE message_ins
|
|
366
|
+
SET delivered_at = :delivered_at
|
|
367
|
+
WHERE message_id IN ({placeholders})
|
|
368
|
+
RETURNING *;
|
|
369
|
+
"""
|
|
370
|
+
|
|
371
|
+
# Prepare data for query
|
|
372
|
+
delivered_at = now().isoformat()
|
|
373
|
+
data = {"delivered_at": delivered_at}
|
|
374
|
+
for index, msg_id in enumerate(message_ids):
|
|
375
|
+
data[f"id_{index}"] = str(msg_id)
|
|
376
|
+
|
|
377
|
+
# Run query
|
|
378
|
+
rows = self.conn.execute(query, data).fetchall()
|
|
379
|
+
|
|
380
|
+
for row in rows:
|
|
381
|
+
# Convert values from sint64 to uint64
|
|
382
|
+
convert_sint64_values_in_dict_to_uint64(
|
|
383
|
+
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
384
|
+
)
|
|
318
385
|
|
|
319
386
|
result = [dict_to_message(row) for row in rows]
|
|
320
387
|
|
|
321
388
|
return result
|
|
322
389
|
|
|
323
|
-
def store_message_res(self, message: Message) ->
|
|
390
|
+
def store_message_res(self, message: Message) -> str | None:
|
|
324
391
|
"""Store one Message."""
|
|
325
392
|
# Validate message
|
|
326
393
|
errors = validate_message(message=message, is_reply_message=True)
|
|
@@ -336,7 +403,8 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
336
403
|
ERROR,
|
|
337
404
|
"Failed to store Message reply: "
|
|
338
405
|
"The message it replies to with message_id %s does not exist or "
|
|
339
|
-
"has expired
|
|
406
|
+
"has expired, or was deleted because the target SuperNode was "
|
|
407
|
+
"removed from the federation.",
|
|
340
408
|
msg_ins_id,
|
|
341
409
|
)
|
|
342
410
|
return None
|
|
@@ -397,84 +465,92 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
397
465
|
# pylint: disable-msg=too-many-locals
|
|
398
466
|
ret: dict[str, Message] = {}
|
|
399
467
|
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
468
|
+
with self.conn:
|
|
469
|
+
# Verify Message IDs
|
|
470
|
+
self._check_stored_messages(message_ids)
|
|
471
|
+
current = now().timestamp()
|
|
472
|
+
query = f"""
|
|
473
|
+
SELECT *
|
|
474
|
+
FROM message_ins
|
|
475
|
+
WHERE message_id IN ({','.join(['?'] * len(message_ids))});
|
|
476
|
+
"""
|
|
477
|
+
rows = self.conn.execute(
|
|
478
|
+
query, tuple(str(message_id) for message_id in message_ids)
|
|
479
|
+
).fetchall()
|
|
480
|
+
found_message_ins_dict: dict[str, Message] = {}
|
|
481
|
+
for row in rows:
|
|
482
|
+
convert_sint64_values_in_dict_to_uint64(
|
|
483
|
+
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
484
|
+
)
|
|
485
|
+
found_message_ins_dict[row["message_id"]] = dict_to_message(row)
|
|
486
|
+
|
|
487
|
+
ret = verify_message_ids(
|
|
488
|
+
inquired_message_ids=message_ids,
|
|
489
|
+
found_message_ins_dict=found_message_ins_dict,
|
|
490
|
+
current_time=current,
|
|
412
491
|
)
|
|
413
|
-
found_message_ins_dict[row["message_id"]] = dict_to_message(row)
|
|
414
492
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
493
|
+
# Check node availability
|
|
494
|
+
dst_node_ids: set[int] = set()
|
|
495
|
+
for message_id in message_ids:
|
|
496
|
+
in_message = found_message_ins_dict[message_id]
|
|
497
|
+
sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
|
|
498
|
+
dst_node_ids.add(sint_node_id)
|
|
499
|
+
query = f"""
|
|
500
|
+
SELECT node_id, online_until
|
|
501
|
+
FROM node
|
|
502
|
+
WHERE node_id IN ({','.join(['?'] * len(dst_node_ids))})
|
|
503
|
+
AND status != ?
|
|
504
|
+
"""
|
|
505
|
+
rows = self.conn.execute(
|
|
506
|
+
query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,)
|
|
507
|
+
).fetchall()
|
|
508
|
+
tmp_ret_dict = check_node_availability_for_in_message(
|
|
509
|
+
inquired_in_message_ids=message_ids,
|
|
510
|
+
found_in_message_dict=found_message_ins_dict,
|
|
511
|
+
node_id_to_online_until={
|
|
512
|
+
int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
|
|
513
|
+
},
|
|
514
|
+
current_time=current,
|
|
515
|
+
)
|
|
516
|
+
ret.update(tmp_ret_dict)
|
|
420
517
|
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
current_time=current,
|
|
441
|
-
)
|
|
442
|
-
ret.update(tmp_ret_dict)
|
|
443
|
-
|
|
444
|
-
# Find all reply Messages
|
|
445
|
-
query = f"""
|
|
446
|
-
SELECT *
|
|
447
|
-
FROM message_res
|
|
448
|
-
WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
|
|
449
|
-
AND delivered_at = "";
|
|
450
|
-
"""
|
|
451
|
-
rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
|
|
452
|
-
for row in rows:
|
|
453
|
-
convert_sint64_values_in_dict_to_uint64(
|
|
454
|
-
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
518
|
+
# Find all reply Messages
|
|
519
|
+
query = f"""
|
|
520
|
+
SELECT *
|
|
521
|
+
FROM message_res
|
|
522
|
+
WHERE reply_to_message_id IN ({','.join(['?'] * len(message_ids))})
|
|
523
|
+
AND delivered_at = "";
|
|
524
|
+
"""
|
|
525
|
+
rows = self.conn.execute(
|
|
526
|
+
query, tuple(str(message_id) for message_id in message_ids)
|
|
527
|
+
).fetchall()
|
|
528
|
+
for row in rows:
|
|
529
|
+
convert_sint64_values_in_dict_to_uint64(
|
|
530
|
+
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
531
|
+
)
|
|
532
|
+
tmp_ret_dict = verify_found_message_replies(
|
|
533
|
+
inquired_message_ids=message_ids,
|
|
534
|
+
found_message_ins_dict=found_message_ins_dict,
|
|
535
|
+
found_message_res_list=[dict_to_message(row) for row in rows],
|
|
536
|
+
current_time=current,
|
|
455
537
|
)
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
UPDATE message_res
|
|
473
|
-
SET delivered_at = ?
|
|
474
|
-
WHERE message_id IN ({",".join(["?"] * len(message_res_ids))});
|
|
475
|
-
"""
|
|
476
|
-
data: list[Any] = [delivered_at] + message_res_ids
|
|
477
|
-
self.query(query, data)
|
|
538
|
+
ret.update(tmp_ret_dict)
|
|
539
|
+
|
|
540
|
+
# Mark existing reply Messages to be returned as delivered
|
|
541
|
+
delivered_at = now().isoformat()
|
|
542
|
+
for message_res in ret.values():
|
|
543
|
+
message_res.metadata.delivered_at = delivered_at
|
|
544
|
+
message_res_ids = [
|
|
545
|
+
message_res.metadata.message_id for message_res in ret.values()
|
|
546
|
+
]
|
|
547
|
+
query = f"""
|
|
548
|
+
UPDATE message_res
|
|
549
|
+
SET delivered_at = ?
|
|
550
|
+
WHERE message_id IN ({','.join(['?'] * len(message_res_ids))});
|
|
551
|
+
"""
|
|
552
|
+
data: list[Any] = [delivered_at] + message_res_ids
|
|
553
|
+
self.conn.execute(query, data)
|
|
478
554
|
|
|
479
555
|
return list(ret.values())
|
|
480
556
|
|
|
@@ -545,7 +621,11 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
545
621
|
return {row["message_id"] for row in rows}
|
|
546
622
|
|
|
547
623
|
def create_node(
|
|
548
|
-
self,
|
|
624
|
+
self,
|
|
625
|
+
owner_aid: str,
|
|
626
|
+
owner_name: str,
|
|
627
|
+
public_key: bytes,
|
|
628
|
+
heartbeat_interval: float,
|
|
549
629
|
) -> int:
|
|
550
630
|
"""Create, store in the link state, and return `node_id`."""
|
|
551
631
|
# Sample a random uint64 as node_id
|
|
@@ -558,10 +638,10 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
558
638
|
|
|
559
639
|
query = """
|
|
560
640
|
INSERT INTO node
|
|
561
|
-
(node_id, owner_aid, status, registered_at, last_activated_at,
|
|
641
|
+
(node_id, owner_aid, owner_name, status, registered_at, last_activated_at,
|
|
562
642
|
last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
|
|
563
643
|
public_key)
|
|
564
|
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
644
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
565
645
|
"""
|
|
566
646
|
|
|
567
647
|
# Mark the node online until now().timestamp() + heartbeat_interval
|
|
@@ -571,6 +651,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
571
651
|
(
|
|
572
652
|
sint64_node_id, # node_id
|
|
573
653
|
owner_aid, # owner_aid
|
|
654
|
+
owner_name, # owner_name
|
|
574
655
|
NodeStatus.REGISTERED, # status
|
|
575
656
|
now().isoformat(), # registered_at
|
|
576
657
|
None, # last_activated_at
|
|
@@ -686,23 +767,26 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
686
767
|
if self.conn is None:
|
|
687
768
|
raise AttributeError("LinkState not initialized")
|
|
688
769
|
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
770
|
+
with self.conn:
|
|
771
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
772
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
773
|
+
|
|
774
|
+
# Validate run ID
|
|
775
|
+
query = "SELECT federation FROM run WHERE run_id = ?"
|
|
776
|
+
rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
|
|
777
|
+
if not rows:
|
|
778
|
+
return set()
|
|
779
|
+
federation: str = rows[0]["federation"]
|
|
780
|
+
|
|
781
|
+
# Retrieve all online nodes
|
|
782
|
+
node_ids = {
|
|
783
|
+
node.node_id
|
|
784
|
+
for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
|
|
785
|
+
}
|
|
786
|
+
# Filter node IDs by federation
|
|
787
|
+
return self.federation_manager.filter_nodes(node_ids, federation)
|
|
788
|
+
|
|
789
|
+
def _check_and_tag_offline_nodes(self, node_ids: list[int] | None = None) -> None:
|
|
706
790
|
"""Check and tag offline nodes."""
|
|
707
791
|
# strftime will convert POSIX timestamp to ISO format
|
|
708
792
|
query = """
|
|
@@ -725,9 +809,9 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
725
809
|
def get_node_info(
|
|
726
810
|
self,
|
|
727
811
|
*,
|
|
728
|
-
node_ids:
|
|
729
|
-
owner_aids:
|
|
730
|
-
statuses:
|
|
812
|
+
node_ids: Sequence[int] | None = None,
|
|
813
|
+
owner_aids: Sequence[str] | None = None,
|
|
814
|
+
statuses: Sequence[str] | None = None,
|
|
731
815
|
) -> Sequence[NodeInfo]:
|
|
732
816
|
"""Retrieve information about nodes based on the specified filters."""
|
|
733
817
|
with self.conn:
|
|
@@ -781,7 +865,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
781
865
|
# Return the public key
|
|
782
866
|
return cast(bytes, rows[0]["public_key"])
|
|
783
867
|
|
|
784
|
-
def get_node_id_by_public_key(self, public_key: bytes) ->
|
|
868
|
+
def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
|
|
785
869
|
"""Get `node_id` for the specified `public_key` if it exists and is not
|
|
786
870
|
deleted."""
|
|
787
871
|
query = "SELECT node_id FROM node WHERE public_key = ? AND status != ?;"
|
|
@@ -798,55 +882,61 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
798
882
|
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
799
883
|
def create_run(
|
|
800
884
|
self,
|
|
801
|
-
fab_id:
|
|
802
|
-
fab_version:
|
|
803
|
-
fab_hash:
|
|
885
|
+
fab_id: str | None,
|
|
886
|
+
fab_version: str | None,
|
|
887
|
+
fab_hash: str | None,
|
|
804
888
|
override_config: UserConfig,
|
|
889
|
+
federation: str,
|
|
805
890
|
federation_options: ConfigRecord,
|
|
806
|
-
flwr_aid:
|
|
891
|
+
flwr_aid: str | None,
|
|
807
892
|
) -> int:
|
|
808
|
-
"""Create a new run
|
|
893
|
+
"""Create a new run."""
|
|
809
894
|
# Sample a random int64 as run_id
|
|
810
895
|
uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
811
896
|
|
|
812
897
|
# Convert the uint64 value to sint64 for SQLite
|
|
813
898
|
sint64_run_id = uint64_to_int64(uint64_run_id)
|
|
814
899
|
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
900
|
+
with self.conn:
|
|
901
|
+
# Check conflicts
|
|
902
|
+
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
|
|
903
|
+
# If sint64_run_id does not exist
|
|
904
|
+
row = self.conn.execute(query, (sint64_run_id,)).fetchone()
|
|
905
|
+
if row["COUNT(*)"] == 0:
|
|
906
|
+
query = """
|
|
907
|
+
INSERT INTO run
|
|
908
|
+
(run_id, fab_id, fab_version,
|
|
909
|
+
fab_hash, override_config, federation, federation_options,
|
|
910
|
+
pending_at, starting_at, running_at, finished_at, sub_status,
|
|
911
|
+
details, flwr_aid, bytes_sent, bytes_recv, clientapp_runtime)
|
|
912
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
|
913
|
+
"""
|
|
914
|
+
override_config_json = json.dumps(override_config)
|
|
915
|
+
data = [
|
|
916
|
+
sint64_run_id, # run_id
|
|
917
|
+
fab_id, # fab_id
|
|
918
|
+
fab_version, # fab_version
|
|
919
|
+
fab_hash, # fab_hash
|
|
920
|
+
override_config_json, # override_config
|
|
921
|
+
federation, # federation
|
|
922
|
+
configrecord_to_bytes(federation_options), # federation_options
|
|
923
|
+
now().isoformat(), # pending_at
|
|
924
|
+
"", # starting_at
|
|
925
|
+
"", # running_at
|
|
926
|
+
"", # finished_at
|
|
927
|
+
"", # sub_status
|
|
928
|
+
"", # details
|
|
929
|
+
flwr_aid or "", # flwr_aid
|
|
930
|
+
0, # bytes_sent
|
|
931
|
+
0, # bytes_recv
|
|
932
|
+
0, # clientapp_runtime
|
|
933
|
+
]
|
|
934
|
+
self.conn.execute(query, tuple(data))
|
|
935
|
+
return uint64_run_id
|
|
846
936
|
log(ERROR, "Unexpected run creation failure.")
|
|
847
937
|
return 0
|
|
848
938
|
|
|
849
|
-
def get_run_ids(self, flwr_aid:
|
|
939
|
+
def get_run_ids(self, flwr_aid: str | None) -> set[int]:
|
|
850
940
|
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
|
851
941
|
|
|
852
942
|
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
|
@@ -860,32 +950,10 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
860
950
|
rows = self.query("SELECT run_id FROM run;", ())
|
|
861
951
|
return {int64_to_uint64(row["run_id"]) for row in rows}
|
|
862
952
|
|
|
863
|
-
def
|
|
864
|
-
"""Check if any runs are no longer active.
|
|
865
|
-
|
|
866
|
-
Marks runs with status 'starting' or 'running' as failed
|
|
867
|
-
if they have not sent a heartbeat before `active_until`.
|
|
868
|
-
"""
|
|
869
|
-
sint_run_ids = [uint64_to_int64(run_id) for run_id in run_ids]
|
|
870
|
-
query = "UPDATE run SET finished_at = ?, sub_status = ?, details = ? "
|
|
871
|
-
query += "WHERE starting_at != '' AND finished_at = '' AND active_until < ?"
|
|
872
|
-
query += f" AND run_id IN ({','.join(['?'] * len(run_ids))});"
|
|
873
|
-
current = now()
|
|
874
|
-
self.query(
|
|
875
|
-
query,
|
|
876
|
-
(
|
|
877
|
-
current.isoformat(),
|
|
878
|
-
SubStatus.FAILED,
|
|
879
|
-
RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
880
|
-
current.timestamp(),
|
|
881
|
-
*sint_run_ids,
|
|
882
|
-
),
|
|
883
|
-
)
|
|
884
|
-
|
|
885
|
-
def get_run(self, run_id: int) -> Optional[Run]:
|
|
953
|
+
def get_run(self, run_id: int) -> Run | None:
|
|
886
954
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
887
|
-
#
|
|
888
|
-
self.
|
|
955
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
956
|
+
self._cleanup_expired_tokens()
|
|
889
957
|
|
|
890
958
|
# Convert the uint64 value to sint64 for SQLite
|
|
891
959
|
sint64_run_id = uint64_to_int64(run_id)
|
|
@@ -909,14 +977,18 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
909
977
|
details=row["details"],
|
|
910
978
|
),
|
|
911
979
|
flwr_aid=row["flwr_aid"],
|
|
980
|
+
federation=row["federation"],
|
|
981
|
+
bytes_sent=row["bytes_sent"],
|
|
982
|
+
bytes_recv=row["bytes_recv"],
|
|
983
|
+
clientapp_runtime=row["clientapp_runtime"],
|
|
912
984
|
)
|
|
913
985
|
log(ERROR, "`run_id` does not exist.")
|
|
914
986
|
return None
|
|
915
987
|
|
|
916
988
|
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
|
|
917
989
|
"""Retrieve the statuses for the specified runs."""
|
|
918
|
-
#
|
|
919
|
-
self.
|
|
990
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
991
|
+
self._cleanup_expired_tokens()
|
|
920
992
|
|
|
921
993
|
# Convert the uint64 value to sint64 for SQLite
|
|
922
994
|
sint64_run_ids = (uint64_to_int64(run_id) for run_id in set(run_ids))
|
|
@@ -935,82 +1007,73 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
935
1007
|
|
|
936
1008
|
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
|
|
937
1009
|
"""Update the status of the run with the specified `run_id`."""
|
|
938
|
-
#
|
|
939
|
-
self.
|
|
1010
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
1011
|
+
self._cleanup_expired_tokens()
|
|
940
1012
|
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
1013
|
+
with self.conn:
|
|
1014
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
1015
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
1016
|
+
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
1017
|
+
rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
|
|
1018
|
+
|
|
1019
|
+
# Check if the run_id exists
|
|
1020
|
+
if not rows:
|
|
1021
|
+
log(ERROR, "`run_id` is invalid")
|
|
1022
|
+
return False
|
|
950
1023
|
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
)
|
|
958
|
-
if not is_valid_transition(current_status, new_status):
|
|
959
|
-
log(
|
|
960
|
-
ERROR,
|
|
961
|
-
'Invalid status transition: from "%s" to "%s"',
|
|
962
|
-
current_status.status,
|
|
963
|
-
new_status.status,
|
|
1024
|
+
# Check if the status transition is valid
|
|
1025
|
+
row = rows[0]
|
|
1026
|
+
current_status = RunStatus(
|
|
1027
|
+
status=determine_run_status(row),
|
|
1028
|
+
sub_status=row["sub_status"],
|
|
1029
|
+
details=row["details"],
|
|
964
1030
|
)
|
|
965
|
-
|
|
1031
|
+
if not is_valid_transition(current_status, new_status):
|
|
1032
|
+
log(
|
|
1033
|
+
ERROR,
|
|
1034
|
+
'Invalid status transition: from "%s" to "%s"',
|
|
1035
|
+
current_status.status,
|
|
1036
|
+
new_status.status,
|
|
1037
|
+
)
|
|
1038
|
+
return False
|
|
966
1039
|
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
1040
|
+
# Check if the sub-status is valid
|
|
1041
|
+
if not has_valid_sub_status(current_status):
|
|
1042
|
+
log(
|
|
1043
|
+
ERROR,
|
|
1044
|
+
'Invalid sub-status "%s" for status "%s"',
|
|
1045
|
+
current_status.sub_status,
|
|
1046
|
+
current_status.status,
|
|
1047
|
+
)
|
|
1048
|
+
return False
|
|
976
1049
|
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
1050
|
+
# Update the status
|
|
1051
|
+
query = """
|
|
1052
|
+
UPDATE run SET %s= ?, sub_status = ?, details = ? WHERE run_id = ?;
|
|
1053
|
+
"""
|
|
981
1054
|
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
timestamp_fld
|
|
1001
|
-
|
|
1002
|
-
data = (
|
|
1003
|
-
current.isoformat(),
|
|
1004
|
-
new_status.sub_status,
|
|
1005
|
-
new_status.details,
|
|
1006
|
-
active_until,
|
|
1007
|
-
heartbeat_interval,
|
|
1008
|
-
uint64_to_int64(run_id),
|
|
1009
|
-
)
|
|
1010
|
-
self.query(query % timestamp_fld, data)
|
|
1055
|
+
# Prepare data for query
|
|
1056
|
+
current = now()
|
|
1057
|
+
|
|
1058
|
+
# Determine the timestamp field based on the new status
|
|
1059
|
+
timestamp_fld = ""
|
|
1060
|
+
if new_status.status == Status.STARTING:
|
|
1061
|
+
timestamp_fld = "starting_at"
|
|
1062
|
+
elif new_status.status == Status.RUNNING:
|
|
1063
|
+
timestamp_fld = "running_at"
|
|
1064
|
+
elif new_status.status == Status.FINISHED:
|
|
1065
|
+
timestamp_fld = "finished_at"
|
|
1066
|
+
|
|
1067
|
+
data = (
|
|
1068
|
+
current.isoformat(),
|
|
1069
|
+
new_status.sub_status,
|
|
1070
|
+
new_status.details,
|
|
1071
|
+
uint64_to_int64(run_id),
|
|
1072
|
+
)
|
|
1073
|
+
self.conn.execute(query % timestamp_fld, data)
|
|
1011
1074
|
return True
|
|
1012
1075
|
|
|
1013
|
-
def get_pending_run_id(self) ->
|
|
1076
|
+
def get_pending_run_id(self) -> int | None:
|
|
1014
1077
|
"""Get the `run_id` of a run with `Status.PENDING` status, if any."""
|
|
1015
1078
|
pending_run_id = None
|
|
1016
1079
|
|
|
@@ -1022,7 +1085,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
1022
1085
|
|
|
1023
1086
|
return pending_run_id
|
|
1024
1087
|
|
|
1025
|
-
def get_federation_options(self, run_id: int) ->
|
|
1088
|
+
def get_federation_options(self, run_id: int) -> ConfigRecord | None:
|
|
1026
1089
|
"""Retrieve the federation options for the specified `run_id`."""
|
|
1027
1090
|
# Convert the uint64 value to sint64 for SQLite
|
|
1028
1091
|
sint64_run_id = uint64_to_int64(run_id)
|
|
@@ -1080,45 +1143,36 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
1080
1143
|
self.conn.execute(query, params)
|
|
1081
1144
|
return True
|
|
1082
1145
|
|
|
1083
|
-
def
|
|
1084
|
-
"""
|
|
1146
|
+
def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
|
|
1147
|
+
"""Transition runs with expired tokens to failed status.
|
|
1085
1148
|
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1149
|
+
Parameters
|
|
1150
|
+
----------
|
|
1151
|
+
expired_records : list[tuple[int, float]]
|
|
1152
|
+
List of tuples containing (run_id, active_until timestamp)
|
|
1153
|
+
for expired tokens.
|
|
1090
1154
|
"""
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
# Search for the run
|
|
1095
|
-
sint_run_id = uint64_to_int64(run_id)
|
|
1096
|
-
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
1097
|
-
rows = self.query(query, (sint_run_id,))
|
|
1098
|
-
|
|
1099
|
-
if not rows:
|
|
1100
|
-
log(ERROR, "`run_id` is invalid")
|
|
1101
|
-
return False
|
|
1102
|
-
|
|
1103
|
-
# Check if the run is of status "running"/"starting"
|
|
1104
|
-
row = rows[0]
|
|
1105
|
-
status = determine_run_status(row)
|
|
1106
|
-
if status not in (Status.RUNNING, Status.STARTING):
|
|
1107
|
-
log(
|
|
1108
|
-
ERROR,
|
|
1109
|
-
'Cannot acknowledge heartbeat for run with status "%s"',
|
|
1110
|
-
status,
|
|
1111
|
-
)
|
|
1112
|
-
return False
|
|
1155
|
+
if not expired_records:
|
|
1156
|
+
return
|
|
1113
1157
|
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1158
|
+
with self.conn:
|
|
1159
|
+
query = """
|
|
1160
|
+
UPDATE run
|
|
1161
|
+
SET sub_status = ?, details = ?, finished_at = ?
|
|
1162
|
+
WHERE run_id = ?;
|
|
1163
|
+
"""
|
|
1164
|
+
data = [
|
|
1165
|
+
(
|
|
1166
|
+
SubStatus.FAILED,
|
|
1167
|
+
RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
1168
|
+
datetime.fromtimestamp(active_until, tz=timezone.utc).isoformat(),
|
|
1169
|
+
uint64_to_int64(run_id),
|
|
1170
|
+
)
|
|
1171
|
+
for run_id, active_until in expired_records
|
|
1172
|
+
]
|
|
1173
|
+
self.conn.executemany(query, data)
|
|
1120
1174
|
|
|
1121
|
-
def get_serverapp_context(self, run_id: int) ->
|
|
1175
|
+
def get_serverapp_context(self, run_id: int) -> Context | None:
|
|
1122
1176
|
"""Get the context for the specified `run_id`."""
|
|
1123
1177
|
# Retrieve context if any
|
|
1124
1178
|
query = "SELECT context FROM context WHERE run_id = ?;"
|
|
@@ -1132,19 +1186,21 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
1132
1186
|
context_bytes = context_to_bytes(context)
|
|
1133
1187
|
sint_run_id = uint64_to_int64(run_id)
|
|
1134
1188
|
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1189
|
+
with self.conn:
|
|
1190
|
+
# Check if any existing Context assigned to the run_id
|
|
1191
|
+
query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
|
|
1192
|
+
row = self.conn.execute(query, (sint_run_id,)).fetchone()
|
|
1193
|
+
if row["COUNT(*)"] > 0:
|
|
1194
|
+
# Update context
|
|
1195
|
+
query = "UPDATE context SET context = ? WHERE run_id = ?;"
|
|
1196
|
+
self.conn.execute(query, (context_bytes, sint_run_id))
|
|
1197
|
+
else:
|
|
1198
|
+
try:
|
|
1199
|
+
# Store context
|
|
1200
|
+
query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
|
|
1201
|
+
self.conn.execute(query, (sint_run_id, context_bytes))
|
|
1202
|
+
except sqlite3.IntegrityError:
|
|
1203
|
+
raise ValueError(f"Run {run_id} not found") from None
|
|
1148
1204
|
|
|
1149
1205
|
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
|
|
1150
1206
|
"""Add a log entry to the ServerApp logs for the specified `run_id`."""
|
|
@@ -1161,90 +1217,100 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
1161
1217
|
raise ValueError(f"Run {run_id} not found") from None
|
|
1162
1218
|
|
|
1163
1219
|
def get_serverapp_log(
|
|
1164
|
-
self, run_id: int, after_timestamp:
|
|
1220
|
+
self, run_id: int, after_timestamp: float | None
|
|
1165
1221
|
) -> tuple[str, float]:
|
|
1166
1222
|
"""Get the ServerApp logs for the specified `run_id`."""
|
|
1167
1223
|
# Convert the uint64 value to sint64 for SQLite
|
|
1168
1224
|
sint64_run_id = uint64_to_int64(run_id)
|
|
1169
1225
|
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1226
|
+
with self.conn:
|
|
1227
|
+
# Check if the run_id exists
|
|
1228
|
+
query = "SELECT run_id FROM run WHERE run_id = ?;"
|
|
1229
|
+
rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
|
|
1230
|
+
if not rows:
|
|
1231
|
+
raise ValueError(f"Run {run_id} not found")
|
|
1232
|
+
|
|
1233
|
+
# Retrieve logs
|
|
1234
|
+
if after_timestamp is None:
|
|
1235
|
+
after_timestamp = 0.0
|
|
1236
|
+
query = """
|
|
1237
|
+
SELECT log, timestamp FROM logs
|
|
1238
|
+
WHERE run_id = ? AND node_id = ? AND timestamp > ?;
|
|
1239
|
+
"""
|
|
1240
|
+
rows = self.conn.execute(
|
|
1241
|
+
query, (sint64_run_id, 0, after_timestamp)
|
|
1242
|
+
).fetchall()
|
|
1243
|
+
rows.sort(key=lambda x: x["timestamp"])
|
|
1244
|
+
latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
|
|
1185
1245
|
return "".join(row["log"] for row in rows), latest_timestamp
|
|
1186
1246
|
|
|
1187
|
-
def get_valid_message_ins(self, message_id: str) ->
|
|
1247
|
+
def get_valid_message_ins(self, message_id: str) -> dict[str, Any] | None:
|
|
1188
1248
|
"""Check if the Message exists and is valid (not expired).
|
|
1189
1249
|
|
|
1190
1250
|
Return Message if valid.
|
|
1191
1251
|
"""
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1252
|
+
with self.conn:
|
|
1253
|
+
self._check_stored_messages({message_id})
|
|
1254
|
+
query = """
|
|
1255
|
+
SELECT *
|
|
1256
|
+
FROM message_ins
|
|
1257
|
+
WHERE message_id = :message_id
|
|
1258
|
+
"""
|
|
1259
|
+
data = {"message_id": message_id}
|
|
1260
|
+
rows: list[dict[str, Any]] = self.conn.execute(query, data).fetchall()
|
|
1261
|
+
if not rows:
|
|
1262
|
+
# Message does not exist
|
|
1263
|
+
return None
|
|
1264
|
+
|
|
1265
|
+
return rows[0]
|
|
1266
|
+
|
|
1267
|
+
def store_traffic(self, run_id: int, *, bytes_sent: int, bytes_recv: int) -> None:
|
|
1268
|
+
"""Store traffic data for the specified `run_id`."""
|
|
1269
|
+
# Validate non-negative values
|
|
1270
|
+
if bytes_sent < 0 or bytes_recv < 0:
|
|
1271
|
+
raise ValueError(
|
|
1272
|
+
f"Negative traffic values for run {run_id}: "
|
|
1273
|
+
f"bytes_sent={bytes_sent}, bytes_recv={bytes_recv}"
|
|
1274
|
+
)
|
|
1202
1275
|
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1276
|
+
if bytes_sent == 0 and bytes_recv == 0:
|
|
1277
|
+
raise ValueError(
|
|
1278
|
+
f"Both bytes_sent and bytes_recv cannot be zero for run {run_id}"
|
|
1279
|
+
)
|
|
1207
1280
|
|
|
1208
|
-
|
|
1209
|
-
if ttl is not None and created_at + ttl <= current_time:
|
|
1210
|
-
return None
|
|
1281
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
1211
1282
|
|
|
1212
|
-
|
|
1283
|
+
with self.conn:
|
|
1284
|
+
# Check if run exists, performing the update only if it does
|
|
1285
|
+
update_query = """
|
|
1286
|
+
UPDATE run
|
|
1287
|
+
SET bytes_sent = bytes_sent + ?,
|
|
1288
|
+
bytes_recv = bytes_recv + ?
|
|
1289
|
+
WHERE run_id = ?
|
|
1290
|
+
RETURNING run_id;
|
|
1291
|
+
"""
|
|
1292
|
+
rows = self.conn.execute(
|
|
1293
|
+
update_query, (bytes_sent, bytes_recv, sint64_run_id)
|
|
1294
|
+
).fetchall()
|
|
1213
1295
|
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
return cast(str, rows[0]["token"]) == token
|
|
1233
|
-
|
|
1234
|
-
def delete_token(self, run_id: int) -> None:
|
|
1235
|
-
"""Delete the token for the given run ID."""
|
|
1236
|
-
query = "DELETE FROM token_store WHERE run_id = :run_id;"
|
|
1237
|
-
data = {"run_id": uint64_to_int64(run_id)}
|
|
1238
|
-
self.query(query, data)
|
|
1239
|
-
|
|
1240
|
-
def get_run_id_by_token(self, token: str) -> Optional[int]:
|
|
1241
|
-
"""Get the run ID associated with a given token."""
|
|
1242
|
-
query = "SELECT run_id FROM token_store WHERE token = :token;"
|
|
1243
|
-
data = {"token": token}
|
|
1244
|
-
rows = self.query(query, data)
|
|
1245
|
-
if not rows:
|
|
1246
|
-
return None
|
|
1247
|
-
return int64_to_uint64(rows[0]["run_id"])
|
|
1296
|
+
if not rows:
|
|
1297
|
+
raise ValueError(f"Run {run_id} not found")
|
|
1298
|
+
|
|
1299
|
+
def add_clientapp_runtime(self, run_id: int, runtime: float) -> None:
|
|
1300
|
+
"""Add ClientApp runtime to the cumulative total for the specified `run_id`."""
|
|
1301
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
1302
|
+
with self.conn:
|
|
1303
|
+
# Check if run exists, performing the update only if it does
|
|
1304
|
+
update_query = """
|
|
1305
|
+
UPDATE run
|
|
1306
|
+
SET clientapp_runtime = clientapp_runtime + ?
|
|
1307
|
+
WHERE run_id = ?
|
|
1308
|
+
RETURNING run_id;
|
|
1309
|
+
"""
|
|
1310
|
+
rows = self.conn.execute(update_query, (runtime, sint64_run_id)).fetchall()
|
|
1311
|
+
|
|
1312
|
+
if not rows:
|
|
1313
|
+
raise ValueError(f"Run {run_id} not found")
|
|
1248
1314
|
|
|
1249
1315
|
|
|
1250
1316
|
def message_to_dict(message: Message) -> dict[str, Any]:
|