flwr 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +19 -0
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/auth_plugin.py +4 -5
- flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
- flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
- flwr/cli/build.py +60 -18
- flwr/cli/cli_account_auth_interceptor.py +24 -7
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +52 -9
- flwr/cli/login/login.py +7 -4
- flwr/cli/ls.py +170 -130
- flwr/cli/new/new.py +33 -50
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +10 -5
- flwr/cli/run/run.py +77 -30
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +25 -7
- flwr/cli/supernode/ls.py +16 -8
- flwr/cli/supernode/register.py +9 -4
- flwr/cli/supernode/unregister.py +5 -3
- flwr/cli/utils.py +376 -16
- flwr/client/__init__.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +6 -7
- flwr/client/grpc_rere_client/connection.py +10 -11
- flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
- flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +12 -14
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/clientapp/utils.py +3 -3
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +5 -2
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +19 -0
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +38 -21
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +18 -30
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +5 -4
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +52 -37
- flwr/compat/client/app.py +38 -37
- flwr/compat/client/grpc_client/connection.py +11 -11
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +71 -52
- flwr/proto/control_pb2.pyi +277 -111
- flwr/proto/control_pb2_grpc.py +249 -40
- flwr/proto/control_pb2_grpc.pyi +185 -52
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +14 -4
- flwr/proto/fab_pb2.pyi +59 -31
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +14 -4
- flwr/proto/fleet_pb2.pyi +137 -61
- flwr/proto/fleet_pb2_grpc.py +189 -48
- flwr/proto/fleet_pb2_grpc.pyi +175 -61
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +50 -25
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +38 -17
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +15 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
- flwr/server/superlink/linkstate/linkstate.py +59 -43
- flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
- flwr/server/superlink/linkstate/utils.py +6 -6
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +10 -11
- flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
- flwr/simulation/run_simulation.py +43 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +31 -1
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -2
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +1 -2
- flwr/supercore/object_store/sqlite_object_store.py +8 -7
- flwr/supercore/primitives/asymmetric.py +1 -1
- flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
- flwr/supercore/sqlite_mixin.py +37 -34
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/superlink/auth_plugin/auth_plugin.py +6 -9
- flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +5 -6
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +102 -18
- flwr/supernode/cli/flower_supernode.py +58 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +41 -22
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +158 -42
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.23.0.dist-info/RECORD +0 -439
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -18,16 +18,14 @@
|
|
|
18
18
|
# pylint: disable=too-many-lines
|
|
19
19
|
|
|
20
20
|
import json
|
|
21
|
-
import secrets
|
|
22
21
|
import sqlite3
|
|
23
22
|
from collections.abc import Sequence
|
|
23
|
+
from datetime import datetime, timezone
|
|
24
24
|
from logging import ERROR, WARNING
|
|
25
|
-
from typing import Any,
|
|
25
|
+
from typing import Any, cast
|
|
26
26
|
|
|
27
27
|
from flwr.common import Context, Message, Metadata, log, now
|
|
28
28
|
from flwr.common.constant import (
|
|
29
|
-
FLWR_APP_TOKEN_LENGTH,
|
|
30
|
-
HEARTBEAT_INTERVAL_INF,
|
|
31
29
|
HEARTBEAT_PATIENCE,
|
|
32
30
|
MESSAGE_TTL_TOLERANCE,
|
|
33
31
|
NODE_ID_NUM_BYTES,
|
|
@@ -51,8 +49,10 @@ from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
|
|
51
49
|
# pylint: enable=E0611
|
|
52
50
|
from flwr.server.utils.validator import validate_message
|
|
53
51
|
from flwr.supercore.constant import NodeStatus
|
|
54
|
-
from flwr.supercore.sqlite_mixin import SqliteMixin
|
|
52
|
+
from flwr.supercore.corestate.sqlite_corestate import SqliteCoreState
|
|
53
|
+
from flwr.supercore.object_store.object_store import ObjectStore
|
|
55
54
|
from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
|
|
55
|
+
from flwr.superlink.federation import FederationManager
|
|
56
56
|
|
|
57
57
|
from .linkstate import LinkState
|
|
58
58
|
from .utils import (
|
|
@@ -74,6 +74,7 @@ SQL_CREATE_TABLE_NODE = """
|
|
|
74
74
|
CREATE TABLE IF NOT EXISTS node(
|
|
75
75
|
node_id INTEGER UNIQUE,
|
|
76
76
|
owner_aid TEXT,
|
|
77
|
+
owner_name TEXT,
|
|
77
78
|
status TEXT,
|
|
78
79
|
registered_at TEXT,
|
|
79
80
|
last_activated_at TEXT NULL,
|
|
@@ -106,8 +107,6 @@ CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
|
|
|
106
107
|
SQL_CREATE_TABLE_RUN = """
|
|
107
108
|
CREATE TABLE IF NOT EXISTS run(
|
|
108
109
|
run_id INTEGER UNIQUE,
|
|
109
|
-
active_until REAL,
|
|
110
|
-
heartbeat_interval REAL,
|
|
111
110
|
fab_id TEXT,
|
|
112
111
|
fab_version TEXT,
|
|
113
112
|
fab_hash TEXT,
|
|
@@ -118,6 +117,7 @@ CREATE TABLE IF NOT EXISTS run(
|
|
|
118
117
|
finished_at TEXT,
|
|
119
118
|
sub_status TEXT,
|
|
120
119
|
details TEXT,
|
|
120
|
+
federation TEXT,
|
|
121
121
|
federation_options BLOB,
|
|
122
122
|
flwr_aid TEXT
|
|
123
123
|
);
|
|
@@ -179,20 +179,23 @@ CREATE TABLE IF NOT EXISTS message_res(
|
|
|
179
179
|
);
|
|
180
180
|
"""
|
|
181
181
|
|
|
182
|
-
SQL_CREATE_TABLE_TOKEN_STORE = """
|
|
183
|
-
CREATE TABLE IF NOT EXISTS token_store (
|
|
184
|
-
run_id INTEGER PRIMARY KEY,
|
|
185
|
-
token TEXT UNIQUE NOT NULL
|
|
186
|
-
);
|
|
187
|
-
"""
|
|
188
|
-
|
|
189
182
|
|
|
190
|
-
class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
183
|
+
class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
191
184
|
"""SQLite-based LinkState implementation."""
|
|
192
185
|
|
|
193
|
-
def
|
|
194
|
-
|
|
195
|
-
|
|
186
|
+
def __init__(
|
|
187
|
+
self,
|
|
188
|
+
database_path: str,
|
|
189
|
+
federation_manager: FederationManager,
|
|
190
|
+
object_store: ObjectStore,
|
|
191
|
+
) -> None:
|
|
192
|
+
super().__init__(database_path, object_store)
|
|
193
|
+
federation_manager.linkstate = self
|
|
194
|
+
self._federation_manager = federation_manager
|
|
195
|
+
|
|
196
|
+
def get_sql_statements(self) -> tuple[str, ...]:
|
|
197
|
+
"""Return SQL statements for LinkState tables."""
|
|
198
|
+
return super().get_sql_statements() + (
|
|
196
199
|
SQL_CREATE_TABLE_RUN,
|
|
197
200
|
SQL_CREATE_TABLE_LOGS,
|
|
198
201
|
SQL_CREATE_TABLE_CONTEXT,
|
|
@@ -200,14 +203,17 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
200
203
|
SQL_CREATE_TABLE_MESSAGE_RES,
|
|
201
204
|
SQL_CREATE_TABLE_NODE,
|
|
202
205
|
SQL_CREATE_TABLE_PUBLIC_KEY,
|
|
203
|
-
SQL_CREATE_TABLE_TOKEN_STORE,
|
|
204
206
|
SQL_CREATE_INDEX_ONLINE_UNTIL,
|
|
205
207
|
SQL_CREATE_INDEX_OWNER_AID,
|
|
206
208
|
SQL_CREATE_INDEX_NODE_STATUS,
|
|
207
|
-
log_queries=log_queries,
|
|
208
209
|
)
|
|
209
210
|
|
|
210
|
-
|
|
211
|
+
@property
|
|
212
|
+
def federation_manager(self) -> FederationManager:
|
|
213
|
+
"""Get the FederationManager instance."""
|
|
214
|
+
return self._federation_manager
|
|
215
|
+
|
|
216
|
+
def store_message_ins(self, message: Message) -> str | None:
|
|
211
217
|
"""Store one Message."""
|
|
212
218
|
# Validate message
|
|
213
219
|
errors = validate_message(message=message, is_reply_message=False)
|
|
@@ -223,12 +229,6 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
223
229
|
data[0], ["run_id", "src_node_id", "dst_node_id"]
|
|
224
230
|
)
|
|
225
231
|
|
|
226
|
-
# Validate run_id
|
|
227
|
-
query = "SELECT run_id FROM run WHERE run_id = ?;"
|
|
228
|
-
if not self.query(query, (data[0]["run_id"],)):
|
|
229
|
-
log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
|
|
230
|
-
return None
|
|
231
|
-
|
|
232
232
|
# Validate source node ID
|
|
233
233
|
if message.metadata.src_node_id != SUPERLINK_NODE_ID:
|
|
234
234
|
log(
|
|
@@ -238,28 +238,87 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
238
238
|
)
|
|
239
239
|
return None
|
|
240
240
|
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
query, (data[0]["
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
241
|
+
with self.conn:
|
|
242
|
+
# Validate run_id
|
|
243
|
+
query = "SELECT federation FROM run WHERE run_id = ?;"
|
|
244
|
+
rows = self.conn.execute(query, (data[0]["run_id"],)).fetchall()
|
|
245
|
+
if not rows:
|
|
246
|
+
log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
|
|
247
|
+
return None
|
|
248
|
+
federation: str = rows[0]["federation"]
|
|
249
|
+
|
|
250
|
+
# Validate destination node ID
|
|
251
|
+
query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
|
|
252
|
+
rows = self.conn.execute(
|
|
253
|
+
query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
|
|
254
|
+
).fetchall()
|
|
255
|
+
if not rows or not self.federation_manager.has_node(
|
|
256
|
+
message.metadata.dst_node_id, federation
|
|
257
|
+
):
|
|
258
|
+
log(
|
|
259
|
+
ERROR,
|
|
260
|
+
"Invalid destination node ID for Message: %s",
|
|
261
|
+
message.metadata.dst_node_id,
|
|
262
|
+
)
|
|
263
|
+
return None
|
|
264
|
+
|
|
265
|
+
columns = ", ".join([f":{key}" for key in data[0]])
|
|
266
|
+
query = f"INSERT INTO message_ins VALUES({columns});"
|
|
267
|
+
|
|
268
|
+
# Only invalid run_id can trigger IntegrityError.
|
|
269
|
+
# This may need to be changed in the future version
|
|
270
|
+
# with more integrity checks.
|
|
271
|
+
self.conn.execute(query, data[0])
|
|
259
272
|
|
|
260
273
|
return message.metadata.message_id
|
|
261
274
|
|
|
262
|
-
def
|
|
275
|
+
def _check_stored_messages(self, message_ids: set[str]) -> None:
|
|
276
|
+
"""Check and delete the message if it's invalid."""
|
|
277
|
+
if not message_ids:
|
|
278
|
+
return
|
|
279
|
+
|
|
280
|
+
with self.conn:
|
|
281
|
+
invalid_msg_ids: set[str] = set()
|
|
282
|
+
current_time = now().timestamp()
|
|
283
|
+
|
|
284
|
+
for msg_id in message_ids:
|
|
285
|
+
# Check if message exists
|
|
286
|
+
query = "SELECT * FROM message_ins WHERE message_id = ?;"
|
|
287
|
+
message_row = self.conn.execute(query, (msg_id,)).fetchone()
|
|
288
|
+
if not message_row:
|
|
289
|
+
continue
|
|
290
|
+
|
|
291
|
+
# Check if the message has expired
|
|
292
|
+
available_until = message_row["created_at"] + message_row["ttl"]
|
|
293
|
+
if available_until <= current_time:
|
|
294
|
+
invalid_msg_ids.add(msg_id)
|
|
295
|
+
continue
|
|
296
|
+
|
|
297
|
+
# Check if src_node_id and dst_node_id are in the federation
|
|
298
|
+
# Get federation from run table
|
|
299
|
+
run_id = message_row["run_id"]
|
|
300
|
+
query = "SELECT federation FROM run WHERE run_id = ?;"
|
|
301
|
+
run_row = self.conn.execute(query, (run_id,)).fetchone()
|
|
302
|
+
if not run_row: # This should not happen
|
|
303
|
+
invalid_msg_ids.add(msg_id)
|
|
304
|
+
continue
|
|
305
|
+
federation = run_row["federation"]
|
|
306
|
+
|
|
307
|
+
# Convert sint64 to uint64 for node IDs
|
|
308
|
+
src_node_id = int64_to_uint64(message_row["src_node_id"])
|
|
309
|
+
dst_node_id = int64_to_uint64(message_row["dst_node_id"])
|
|
310
|
+
|
|
311
|
+
# Filter nodes to check if they're in the federation
|
|
312
|
+
filtered = self.federation_manager.filter_nodes(
|
|
313
|
+
{src_node_id, dst_node_id}, federation
|
|
314
|
+
)
|
|
315
|
+
if len(filtered) != 2: # Not both nodes are in the federation
|
|
316
|
+
invalid_msg_ids.add(msg_id)
|
|
317
|
+
|
|
318
|
+
# Delete all invalid messages
|
|
319
|
+
self.delete_messages(invalid_msg_ids)
|
|
320
|
+
|
|
321
|
+
def get_message_ins(self, node_id: int, limit: int | None) -> list[Message]:
|
|
263
322
|
"""Get all Messages that have not been delivered yet."""
|
|
264
323
|
if limit is not None and limit < 1:
|
|
265
324
|
raise AssertionError("`limit` must be >= 1")
|
|
@@ -268,59 +327,64 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
268
327
|
msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
|
|
269
328
|
raise AssertionError(msg)
|
|
270
329
|
|
|
271
|
-
data: dict[str,
|
|
330
|
+
data: dict[str, str | int] = {}
|
|
272
331
|
|
|
273
332
|
# Convert the uint64 value to sint64 for SQLite
|
|
274
333
|
data["node_id"] = uint64_to_int64(node_id)
|
|
275
334
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
if limit is not None:
|
|
286
|
-
query += " LIMIT :limit"
|
|
287
|
-
data["limit"] = limit
|
|
288
|
-
|
|
289
|
-
query += ";"
|
|
290
|
-
|
|
291
|
-
rows = self.query(query, data)
|
|
292
|
-
|
|
293
|
-
if rows:
|
|
294
|
-
# Prepare query
|
|
295
|
-
message_ids = [row["message_id"] for row in rows]
|
|
296
|
-
placeholders: str = ",".join([f":id_{i}" for i in range(len(message_ids))])
|
|
297
|
-
query = f"""
|
|
298
|
-
UPDATE message_ins
|
|
299
|
-
SET delivered_at = :delivered_at
|
|
300
|
-
WHERE message_id IN ({placeholders})
|
|
301
|
-
RETURNING *;
|
|
335
|
+
with self.conn:
|
|
336
|
+
# Retrieve all Messages for node_id
|
|
337
|
+
query = """
|
|
338
|
+
SELECT message_id
|
|
339
|
+
FROM message_ins
|
|
340
|
+
WHERE dst_node_id == :node_id
|
|
341
|
+
AND delivered_at = ""
|
|
342
|
+
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
302
343
|
"""
|
|
303
344
|
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
for index, msg_id in enumerate(message_ids):
|
|
308
|
-
data[f"id_{index}"] = str(msg_id)
|
|
345
|
+
if limit is not None:
|
|
346
|
+
query += " LIMIT :limit"
|
|
347
|
+
data["limit"] = limit
|
|
309
348
|
|
|
310
|
-
|
|
311
|
-
rows = self.query(query, data)
|
|
349
|
+
query += ";"
|
|
312
350
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
351
|
+
rows = self.conn.execute(query, data).fetchall()
|
|
352
|
+
message_ids: set[str] = {row["message_id"] for row in rows}
|
|
353
|
+
self._check_stored_messages(message_ids)
|
|
354
|
+
|
|
355
|
+
# Mark retrieved Messages as delivered
|
|
356
|
+
if rows:
|
|
357
|
+
# Prepare query
|
|
358
|
+
placeholders: str = ",".join(
|
|
359
|
+
[f":id_{i}" for i in range(len(message_ids))]
|
|
360
|
+
)
|
|
361
|
+
query = f"""
|
|
362
|
+
UPDATE message_ins
|
|
363
|
+
SET delivered_at = :delivered_at
|
|
364
|
+
WHERE message_id IN ({placeholders})
|
|
365
|
+
RETURNING *;
|
|
366
|
+
"""
|
|
367
|
+
|
|
368
|
+
# Prepare data for query
|
|
369
|
+
delivered_at = now().isoformat()
|
|
370
|
+
data = {"delivered_at": delivered_at}
|
|
371
|
+
for index, msg_id in enumerate(message_ids):
|
|
372
|
+
data[f"id_{index}"] = str(msg_id)
|
|
373
|
+
|
|
374
|
+
# Run query
|
|
375
|
+
rows = self.conn.execute(query, data).fetchall()
|
|
376
|
+
|
|
377
|
+
for row in rows:
|
|
378
|
+
# Convert values from sint64 to uint64
|
|
379
|
+
convert_sint64_values_in_dict_to_uint64(
|
|
380
|
+
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
381
|
+
)
|
|
318
382
|
|
|
319
383
|
result = [dict_to_message(row) for row in rows]
|
|
320
384
|
|
|
321
385
|
return result
|
|
322
386
|
|
|
323
|
-
def store_message_res(self, message: Message) ->
|
|
387
|
+
def store_message_res(self, message: Message) -> str | None:
|
|
324
388
|
"""Store one Message."""
|
|
325
389
|
# Validate message
|
|
326
390
|
errors = validate_message(message=message, is_reply_message=True)
|
|
@@ -336,7 +400,8 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
336
400
|
ERROR,
|
|
337
401
|
"Failed to store Message reply: "
|
|
338
402
|
"The message it replies to with message_id %s does not exist or "
|
|
339
|
-
"has expired
|
|
403
|
+
"has expired, or was deleted because the target SuperNode was "
|
|
404
|
+
"removed from the federation.",
|
|
340
405
|
msg_ins_id,
|
|
341
406
|
)
|
|
342
407
|
return None
|
|
@@ -397,84 +462,92 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
397
462
|
# pylint: disable-msg=too-many-locals
|
|
398
463
|
ret: dict[str, Message] = {}
|
|
399
464
|
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
465
|
+
with self.conn:
|
|
466
|
+
# Verify Message IDs
|
|
467
|
+
self._check_stored_messages(message_ids)
|
|
468
|
+
current = now().timestamp()
|
|
469
|
+
query = f"""
|
|
470
|
+
SELECT *
|
|
471
|
+
FROM message_ins
|
|
472
|
+
WHERE message_id IN ({','.join(['?'] * len(message_ids))});
|
|
473
|
+
"""
|
|
474
|
+
rows = self.conn.execute(
|
|
475
|
+
query, tuple(str(message_id) for message_id in message_ids)
|
|
476
|
+
).fetchall()
|
|
477
|
+
found_message_ins_dict: dict[str, Message] = {}
|
|
478
|
+
for row in rows:
|
|
479
|
+
convert_sint64_values_in_dict_to_uint64(
|
|
480
|
+
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
481
|
+
)
|
|
482
|
+
found_message_ins_dict[row["message_id"]] = dict_to_message(row)
|
|
483
|
+
|
|
484
|
+
ret = verify_message_ids(
|
|
485
|
+
inquired_message_ids=message_ids,
|
|
486
|
+
found_message_ins_dict=found_message_ins_dict,
|
|
487
|
+
current_time=current,
|
|
412
488
|
)
|
|
413
|
-
found_message_ins_dict[row["message_id"]] = dict_to_message(row)
|
|
414
489
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
490
|
+
# Check node availability
|
|
491
|
+
dst_node_ids: set[int] = set()
|
|
492
|
+
for message_id in message_ids:
|
|
493
|
+
in_message = found_message_ins_dict[message_id]
|
|
494
|
+
sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
|
|
495
|
+
dst_node_ids.add(sint_node_id)
|
|
496
|
+
query = f"""
|
|
497
|
+
SELECT node_id, online_until
|
|
498
|
+
FROM node
|
|
499
|
+
WHERE node_id IN ({','.join(['?'] * len(dst_node_ids))})
|
|
500
|
+
AND status != ?
|
|
501
|
+
"""
|
|
502
|
+
rows = self.conn.execute(
|
|
503
|
+
query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,)
|
|
504
|
+
).fetchall()
|
|
505
|
+
tmp_ret_dict = check_node_availability_for_in_message(
|
|
506
|
+
inquired_in_message_ids=message_ids,
|
|
507
|
+
found_in_message_dict=found_message_ins_dict,
|
|
508
|
+
node_id_to_online_until={
|
|
509
|
+
int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
|
|
510
|
+
},
|
|
511
|
+
current_time=current,
|
|
512
|
+
)
|
|
513
|
+
ret.update(tmp_ret_dict)
|
|
420
514
|
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
current_time=current,
|
|
441
|
-
)
|
|
442
|
-
ret.update(tmp_ret_dict)
|
|
443
|
-
|
|
444
|
-
# Find all reply Messages
|
|
445
|
-
query = f"""
|
|
446
|
-
SELECT *
|
|
447
|
-
FROM message_res
|
|
448
|
-
WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
|
|
449
|
-
AND delivered_at = "";
|
|
450
|
-
"""
|
|
451
|
-
rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
|
|
452
|
-
for row in rows:
|
|
453
|
-
convert_sint64_values_in_dict_to_uint64(
|
|
454
|
-
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
515
|
+
# Find all reply Messages
|
|
516
|
+
query = f"""
|
|
517
|
+
SELECT *
|
|
518
|
+
FROM message_res
|
|
519
|
+
WHERE reply_to_message_id IN ({','.join(['?'] * len(message_ids))})
|
|
520
|
+
AND delivered_at = "";
|
|
521
|
+
"""
|
|
522
|
+
rows = self.conn.execute(
|
|
523
|
+
query, tuple(str(message_id) for message_id in message_ids)
|
|
524
|
+
).fetchall()
|
|
525
|
+
for row in rows:
|
|
526
|
+
convert_sint64_values_in_dict_to_uint64(
|
|
527
|
+
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
528
|
+
)
|
|
529
|
+
tmp_ret_dict = verify_found_message_replies(
|
|
530
|
+
inquired_message_ids=message_ids,
|
|
531
|
+
found_message_ins_dict=found_message_ins_dict,
|
|
532
|
+
found_message_res_list=[dict_to_message(row) for row in rows],
|
|
533
|
+
current_time=current,
|
|
455
534
|
)
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
UPDATE message_res
|
|
473
|
-
SET delivered_at = ?
|
|
474
|
-
WHERE message_id IN ({",".join(["?"] * len(message_res_ids))});
|
|
475
|
-
"""
|
|
476
|
-
data: list[Any] = [delivered_at] + message_res_ids
|
|
477
|
-
self.query(query, data)
|
|
535
|
+
ret.update(tmp_ret_dict)
|
|
536
|
+
|
|
537
|
+
# Mark existing reply Messages to be returned as delivered
|
|
538
|
+
delivered_at = now().isoformat()
|
|
539
|
+
for message_res in ret.values():
|
|
540
|
+
message_res.metadata.delivered_at = delivered_at
|
|
541
|
+
message_res_ids = [
|
|
542
|
+
message_res.metadata.message_id for message_res in ret.values()
|
|
543
|
+
]
|
|
544
|
+
query = f"""
|
|
545
|
+
UPDATE message_res
|
|
546
|
+
SET delivered_at = ?
|
|
547
|
+
WHERE message_id IN ({','.join(['?'] * len(message_res_ids))});
|
|
548
|
+
"""
|
|
549
|
+
data: list[Any] = [delivered_at] + message_res_ids
|
|
550
|
+
self.conn.execute(query, data)
|
|
478
551
|
|
|
479
552
|
return list(ret.values())
|
|
480
553
|
|
|
@@ -545,7 +618,11 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
545
618
|
return {row["message_id"] for row in rows}
|
|
546
619
|
|
|
547
620
|
def create_node(
|
|
548
|
-
self,
|
|
621
|
+
self,
|
|
622
|
+
owner_aid: str,
|
|
623
|
+
owner_name: str,
|
|
624
|
+
public_key: bytes,
|
|
625
|
+
heartbeat_interval: float,
|
|
549
626
|
) -> int:
|
|
550
627
|
"""Create, store in the link state, and return `node_id`."""
|
|
551
628
|
# Sample a random uint64 as node_id
|
|
@@ -558,10 +635,10 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
558
635
|
|
|
559
636
|
query = """
|
|
560
637
|
INSERT INTO node
|
|
561
|
-
(node_id, owner_aid, status, registered_at, last_activated_at,
|
|
638
|
+
(node_id, owner_aid, owner_name, status, registered_at, last_activated_at,
|
|
562
639
|
last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
|
|
563
640
|
public_key)
|
|
564
|
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
641
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
565
642
|
"""
|
|
566
643
|
|
|
567
644
|
# Mark the node online until now().timestamp() + heartbeat_interval
|
|
@@ -571,6 +648,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
571
648
|
(
|
|
572
649
|
sint64_node_id, # node_id
|
|
573
650
|
owner_aid, # owner_aid
|
|
651
|
+
owner_name, # owner_name
|
|
574
652
|
NodeStatus.REGISTERED, # status
|
|
575
653
|
now().isoformat(), # registered_at
|
|
576
654
|
None, # last_activated_at
|
|
@@ -686,23 +764,26 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
686
764
|
if self.conn is None:
|
|
687
765
|
raise AttributeError("LinkState not initialized")
|
|
688
766
|
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
767
|
+
with self.conn:
|
|
768
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
769
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
770
|
+
|
|
771
|
+
# Validate run ID
|
|
772
|
+
query = "SELECT federation FROM run WHERE run_id = ?"
|
|
773
|
+
rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
|
|
774
|
+
if not rows:
|
|
775
|
+
return set()
|
|
776
|
+
federation: str = rows[0]["federation"]
|
|
777
|
+
|
|
778
|
+
# Retrieve all online nodes
|
|
779
|
+
node_ids = {
|
|
780
|
+
node.node_id
|
|
781
|
+
for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
|
|
782
|
+
}
|
|
783
|
+
# Filter node IDs by federation
|
|
784
|
+
return self.federation_manager.filter_nodes(node_ids, federation)
|
|
785
|
+
|
|
786
|
+
def _check_and_tag_offline_nodes(self, node_ids: list[int] | None = None) -> None:
|
|
706
787
|
"""Check and tag offline nodes."""
|
|
707
788
|
# strftime will convert POSIX timestamp to ISO format
|
|
708
789
|
query = """
|
|
@@ -725,9 +806,9 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
725
806
|
def get_node_info(
|
|
726
807
|
self,
|
|
727
808
|
*,
|
|
728
|
-
node_ids:
|
|
729
|
-
owner_aids:
|
|
730
|
-
statuses:
|
|
809
|
+
node_ids: Sequence[int] | None = None,
|
|
810
|
+
owner_aids: Sequence[str] | None = None,
|
|
811
|
+
statuses: Sequence[str] | None = None,
|
|
731
812
|
) -> Sequence[NodeInfo]:
|
|
732
813
|
"""Retrieve information about nodes based on the specified filters."""
|
|
733
814
|
with self.conn:
|
|
@@ -781,7 +862,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
781
862
|
# Return the public key
|
|
782
863
|
return cast(bytes, rows[0]["public_key"])
|
|
783
864
|
|
|
784
|
-
def get_node_id_by_public_key(self, public_key: bytes) ->
|
|
865
|
+
def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
|
|
785
866
|
"""Get `node_id` for the specified `public_key` if it exists and is not
|
|
786
867
|
deleted."""
|
|
787
868
|
query = "SELECT node_id FROM node WHERE public_key = ? AND status != ?;"
|
|
@@ -798,55 +879,58 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
798
879
|
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
799
880
|
def create_run(
|
|
800
881
|
self,
|
|
801
|
-
fab_id:
|
|
802
|
-
fab_version:
|
|
803
|
-
fab_hash:
|
|
882
|
+
fab_id: str | None,
|
|
883
|
+
fab_version: str | None,
|
|
884
|
+
fab_hash: str | None,
|
|
804
885
|
override_config: UserConfig,
|
|
886
|
+
federation: str,
|
|
805
887
|
federation_options: ConfigRecord,
|
|
806
|
-
flwr_aid:
|
|
888
|
+
flwr_aid: str | None,
|
|
807
889
|
) -> int:
|
|
808
|
-
"""Create a new run
|
|
890
|
+
"""Create a new run."""
|
|
809
891
|
# Sample a random int64 as run_id
|
|
810
892
|
uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
811
893
|
|
|
812
894
|
# Convert the uint64 value to sint64 for SQLite
|
|
813
895
|
sint64_run_id = uint64_to_int64(uint64_run_id)
|
|
814
896
|
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
897
|
+
with self.conn:
|
|
898
|
+
# Check conflicts
|
|
899
|
+
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
|
|
900
|
+
# If sint64_run_id does not exist
|
|
901
|
+
row = self.conn.execute(query, (sint64_run_id,)).fetchone()
|
|
902
|
+
if row["COUNT(*)"] == 0:
|
|
903
|
+
query = """
|
|
904
|
+
INSERT INTO run
|
|
905
|
+
(run_id, fab_id, fab_version,
|
|
906
|
+
fab_hash, override_config, federation, federation_options,
|
|
907
|
+
pending_at, starting_at, running_at, finished_at, sub_status,
|
|
908
|
+
details, flwr_aid)
|
|
909
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
|
910
|
+
"""
|
|
911
|
+
override_config_json = json.dumps(override_config)
|
|
912
|
+
data = [
|
|
913
|
+
sint64_run_id, # run_id
|
|
914
|
+
fab_id, # fab_id
|
|
915
|
+
fab_version, # fab_version
|
|
916
|
+
fab_hash, # fab_hash
|
|
917
|
+
override_config_json, # override_config
|
|
918
|
+
federation, # federation
|
|
919
|
+
configrecord_to_bytes(federation_options), # federation_options
|
|
920
|
+
now().isoformat(), # pending_at
|
|
921
|
+
"", # starting_at
|
|
922
|
+
"", # running_at
|
|
923
|
+
"", # finished_at
|
|
924
|
+
"", # sub_status
|
|
925
|
+
"", # details
|
|
926
|
+
flwr_aid or "", # flwr_aid
|
|
927
|
+
]
|
|
928
|
+
self.conn.execute(query, tuple(data))
|
|
929
|
+
return uint64_run_id
|
|
846
930
|
log(ERROR, "Unexpected run creation failure.")
|
|
847
931
|
return 0
|
|
848
932
|
|
|
849
|
-
def get_run_ids(self, flwr_aid:
|
|
933
|
+
def get_run_ids(self, flwr_aid: str | None) -> set[int]:
|
|
850
934
|
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
|
851
935
|
|
|
852
936
|
Otherwise, retrieve all run IDs for the specified `flwr_aid`.
|
|
@@ -860,32 +944,10 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
860
944
|
rows = self.query("SELECT run_id FROM run;", ())
|
|
861
945
|
return {int64_to_uint64(row["run_id"]) for row in rows}
|
|
862
946
|
|
|
863
|
-
def
|
|
864
|
-
"""Check if any runs are no longer active.
|
|
865
|
-
|
|
866
|
-
Marks runs with status 'starting' or 'running' as failed
|
|
867
|
-
if they have not sent a heartbeat before `active_until`.
|
|
868
|
-
"""
|
|
869
|
-
sint_run_ids = [uint64_to_int64(run_id) for run_id in run_ids]
|
|
870
|
-
query = "UPDATE run SET finished_at = ?, sub_status = ?, details = ? "
|
|
871
|
-
query += "WHERE starting_at != '' AND finished_at = '' AND active_until < ?"
|
|
872
|
-
query += f" AND run_id IN ({','.join(['?'] * len(run_ids))});"
|
|
873
|
-
current = now()
|
|
874
|
-
self.query(
|
|
875
|
-
query,
|
|
876
|
-
(
|
|
877
|
-
current.isoformat(),
|
|
878
|
-
SubStatus.FAILED,
|
|
879
|
-
RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
880
|
-
current.timestamp(),
|
|
881
|
-
*sint_run_ids,
|
|
882
|
-
),
|
|
883
|
-
)
|
|
884
|
-
|
|
885
|
-
def get_run(self, run_id: int) -> Optional[Run]:
|
|
947
|
+
def get_run(self, run_id: int) -> Run | None:
|
|
886
948
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
887
|
-
#
|
|
888
|
-
self.
|
|
949
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
950
|
+
self._cleanup_expired_tokens()
|
|
889
951
|
|
|
890
952
|
# Convert the uint64 value to sint64 for SQLite
|
|
891
953
|
sint64_run_id = uint64_to_int64(run_id)
|
|
@@ -909,14 +971,15 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
909
971
|
details=row["details"],
|
|
910
972
|
),
|
|
911
973
|
flwr_aid=row["flwr_aid"],
|
|
974
|
+
federation=row["federation"],
|
|
912
975
|
)
|
|
913
976
|
log(ERROR, "`run_id` does not exist.")
|
|
914
977
|
return None
|
|
915
978
|
|
|
916
979
|
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
|
|
917
980
|
"""Retrieve the statuses for the specified runs."""
|
|
918
|
-
#
|
|
919
|
-
self.
|
|
981
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
982
|
+
self._cleanup_expired_tokens()
|
|
920
983
|
|
|
921
984
|
# Convert the uint64 value to sint64 for SQLite
|
|
922
985
|
sint64_run_ids = (uint64_to_int64(run_id) for run_id in set(run_ids))
|
|
@@ -935,82 +998,73 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
935
998
|
|
|
936
999
|
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
|
|
937
1000
|
"""Update the status of the run with the specified `run_id`."""
|
|
938
|
-
#
|
|
939
|
-
self.
|
|
1001
|
+
# Clean up expired tokens; this will flag inactive runs as needed
|
|
1002
|
+
self._cleanup_expired_tokens()
|
|
940
1003
|
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
1004
|
+
with self.conn:
|
|
1005
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
1006
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
1007
|
+
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
1008
|
+
rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
|
|
1009
|
+
|
|
1010
|
+
# Check if the run_id exists
|
|
1011
|
+
if not rows:
|
|
1012
|
+
log(ERROR, "`run_id` is invalid")
|
|
1013
|
+
return False
|
|
950
1014
|
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
)
|
|
958
|
-
if not is_valid_transition(current_status, new_status):
|
|
959
|
-
log(
|
|
960
|
-
ERROR,
|
|
961
|
-
'Invalid status transition: from "%s" to "%s"',
|
|
962
|
-
current_status.status,
|
|
963
|
-
new_status.status,
|
|
1015
|
+
# Check if the status transition is valid
|
|
1016
|
+
row = rows[0]
|
|
1017
|
+
current_status = RunStatus(
|
|
1018
|
+
status=determine_run_status(row),
|
|
1019
|
+
sub_status=row["sub_status"],
|
|
1020
|
+
details=row["details"],
|
|
964
1021
|
)
|
|
965
|
-
|
|
1022
|
+
if not is_valid_transition(current_status, new_status):
|
|
1023
|
+
log(
|
|
1024
|
+
ERROR,
|
|
1025
|
+
'Invalid status transition: from "%s" to "%s"',
|
|
1026
|
+
current_status.status,
|
|
1027
|
+
new_status.status,
|
|
1028
|
+
)
|
|
1029
|
+
return False
|
|
966
1030
|
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
1031
|
+
# Check if the sub-status is valid
|
|
1032
|
+
if not has_valid_sub_status(current_status):
|
|
1033
|
+
log(
|
|
1034
|
+
ERROR,
|
|
1035
|
+
'Invalid sub-status "%s" for status "%s"',
|
|
1036
|
+
current_status.sub_status,
|
|
1037
|
+
current_status.status,
|
|
1038
|
+
)
|
|
1039
|
+
return False
|
|
976
1040
|
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
1041
|
+
# Update the status
|
|
1042
|
+
query = """
|
|
1043
|
+
UPDATE run SET %s= ?, sub_status = ?, details = ? WHERE run_id = ?;
|
|
1044
|
+
"""
|
|
981
1045
|
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
timestamp_fld
|
|
1001
|
-
|
|
1002
|
-
data = (
|
|
1003
|
-
current.isoformat(),
|
|
1004
|
-
new_status.sub_status,
|
|
1005
|
-
new_status.details,
|
|
1006
|
-
active_until,
|
|
1007
|
-
heartbeat_interval,
|
|
1008
|
-
uint64_to_int64(run_id),
|
|
1009
|
-
)
|
|
1010
|
-
self.query(query % timestamp_fld, data)
|
|
1046
|
+
# Prepare data for query
|
|
1047
|
+
current = now()
|
|
1048
|
+
|
|
1049
|
+
# Determine the timestamp field based on the new status
|
|
1050
|
+
timestamp_fld = ""
|
|
1051
|
+
if new_status.status == Status.STARTING:
|
|
1052
|
+
timestamp_fld = "starting_at"
|
|
1053
|
+
elif new_status.status == Status.RUNNING:
|
|
1054
|
+
timestamp_fld = "running_at"
|
|
1055
|
+
elif new_status.status == Status.FINISHED:
|
|
1056
|
+
timestamp_fld = "finished_at"
|
|
1057
|
+
|
|
1058
|
+
data = (
|
|
1059
|
+
current.isoformat(),
|
|
1060
|
+
new_status.sub_status,
|
|
1061
|
+
new_status.details,
|
|
1062
|
+
uint64_to_int64(run_id),
|
|
1063
|
+
)
|
|
1064
|
+
self.conn.execute(query % timestamp_fld, data)
|
|
1011
1065
|
return True
|
|
1012
1066
|
|
|
1013
|
-
def get_pending_run_id(self) ->
|
|
1067
|
+
def get_pending_run_id(self) -> int | None:
|
|
1014
1068
|
"""Get the `run_id` of a run with `Status.PENDING` status, if any."""
|
|
1015
1069
|
pending_run_id = None
|
|
1016
1070
|
|
|
@@ -1022,7 +1076,7 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
1022
1076
|
|
|
1023
1077
|
return pending_run_id
|
|
1024
1078
|
|
|
1025
|
-
def get_federation_options(self, run_id: int) ->
|
|
1079
|
+
def get_federation_options(self, run_id: int) -> ConfigRecord | None:
|
|
1026
1080
|
"""Retrieve the federation options for the specified `run_id`."""
|
|
1027
1081
|
# Convert the uint64 value to sint64 for SQLite
|
|
1028
1082
|
sint64_run_id = uint64_to_int64(run_id)
|
|
@@ -1080,45 +1134,36 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
1080
1134
|
self.conn.execute(query, params)
|
|
1081
1135
|
return True
|
|
1082
1136
|
|
|
1083
|
-
def
|
|
1084
|
-
"""
|
|
1137
|
+
def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
|
|
1138
|
+
"""Transition runs with expired tokens to failed status.
|
|
1085
1139
|
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1140
|
+
Parameters
|
|
1141
|
+
----------
|
|
1142
|
+
expired_records : list[tuple[int, float]]
|
|
1143
|
+
List of tuples containing (run_id, active_until timestamp)
|
|
1144
|
+
for expired tokens.
|
|
1090
1145
|
"""
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
# Search for the run
|
|
1095
|
-
sint_run_id = uint64_to_int64(run_id)
|
|
1096
|
-
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
1097
|
-
rows = self.query(query, (sint_run_id,))
|
|
1098
|
-
|
|
1099
|
-
if not rows:
|
|
1100
|
-
log(ERROR, "`run_id` is invalid")
|
|
1101
|
-
return False
|
|
1102
|
-
|
|
1103
|
-
# Check if the run is of status "running"/"starting"
|
|
1104
|
-
row = rows[0]
|
|
1105
|
-
status = determine_run_status(row)
|
|
1106
|
-
if status not in (Status.RUNNING, Status.STARTING):
|
|
1107
|
-
log(
|
|
1108
|
-
ERROR,
|
|
1109
|
-
'Cannot acknowledge heartbeat for run with status "%s"',
|
|
1110
|
-
status,
|
|
1111
|
-
)
|
|
1112
|
-
return False
|
|
1146
|
+
if not expired_records:
|
|
1147
|
+
return
|
|
1113
1148
|
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1149
|
+
with self.conn:
|
|
1150
|
+
query = """
|
|
1151
|
+
UPDATE run
|
|
1152
|
+
SET sub_status = ?, details = ?, finished_at = ?
|
|
1153
|
+
WHERE run_id = ?;
|
|
1154
|
+
"""
|
|
1155
|
+
data = [
|
|
1156
|
+
(
|
|
1157
|
+
SubStatus.FAILED,
|
|
1158
|
+
RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
1159
|
+
datetime.fromtimestamp(active_until, tz=timezone.utc).isoformat(),
|
|
1160
|
+
uint64_to_int64(run_id),
|
|
1161
|
+
)
|
|
1162
|
+
for run_id, active_until in expired_records
|
|
1163
|
+
]
|
|
1164
|
+
self.conn.executemany(query, data)
|
|
1120
1165
|
|
|
1121
|
-
def get_serverapp_context(self, run_id: int) ->
|
|
1166
|
+
def get_serverapp_context(self, run_id: int) -> Context | None:
|
|
1122
1167
|
"""Get the context for the specified `run_id`."""
|
|
1123
1168
|
# Retrieve context if any
|
|
1124
1169
|
query = "SELECT context FROM context WHERE run_id = ?;"
|
|
@@ -1132,19 +1177,21 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
1132
1177
|
context_bytes = context_to_bytes(context)
|
|
1133
1178
|
sint_run_id = uint64_to_int64(run_id)
|
|
1134
1179
|
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1180
|
+
with self.conn:
|
|
1181
|
+
# Check if any existing Context assigned to the run_id
|
|
1182
|
+
query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
|
|
1183
|
+
row = self.conn.execute(query, (sint_run_id,)).fetchone()
|
|
1184
|
+
if row["COUNT(*)"] > 0:
|
|
1185
|
+
# Update context
|
|
1186
|
+
query = "UPDATE context SET context = ? WHERE run_id = ?;"
|
|
1187
|
+
self.conn.execute(query, (context_bytes, sint_run_id))
|
|
1188
|
+
else:
|
|
1189
|
+
try:
|
|
1190
|
+
# Store context
|
|
1191
|
+
query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
|
|
1192
|
+
self.conn.execute(query, (sint_run_id, context_bytes))
|
|
1193
|
+
except sqlite3.IntegrityError:
|
|
1194
|
+
raise ValueError(f"Run {run_id} not found") from None
|
|
1148
1195
|
|
|
1149
1196
|
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
|
|
1150
1197
|
"""Add a log entry to the ServerApp logs for the specified `run_id`."""
|
|
@@ -1161,90 +1208,52 @@ class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
|
1161
1208
|
raise ValueError(f"Run {run_id} not found") from None
|
|
1162
1209
|
|
|
1163
1210
|
def get_serverapp_log(
|
|
1164
|
-
self, run_id: int, after_timestamp:
|
|
1211
|
+
self, run_id: int, after_timestamp: float | None
|
|
1165
1212
|
) -> tuple[str, float]:
|
|
1166
1213
|
"""Get the ServerApp logs for the specified `run_id`."""
|
|
1167
1214
|
# Convert the uint64 value to sint64 for SQLite
|
|
1168
1215
|
sint64_run_id = uint64_to_int64(run_id)
|
|
1169
1216
|
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1217
|
+
with self.conn:
|
|
1218
|
+
# Check if the run_id exists
|
|
1219
|
+
query = "SELECT run_id FROM run WHERE run_id = ?;"
|
|
1220
|
+
rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
|
|
1221
|
+
if not rows:
|
|
1222
|
+
raise ValueError(f"Run {run_id} not found")
|
|
1223
|
+
|
|
1224
|
+
# Retrieve logs
|
|
1225
|
+
if after_timestamp is None:
|
|
1226
|
+
after_timestamp = 0.0
|
|
1227
|
+
query = """
|
|
1228
|
+
SELECT log, timestamp FROM logs
|
|
1229
|
+
WHERE run_id = ? AND node_id = ? AND timestamp > ?;
|
|
1230
|
+
"""
|
|
1231
|
+
rows = self.conn.execute(
|
|
1232
|
+
query, (sint64_run_id, 0, after_timestamp)
|
|
1233
|
+
).fetchall()
|
|
1234
|
+
rows.sort(key=lambda x: x["timestamp"])
|
|
1235
|
+
latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
|
|
1185
1236
|
return "".join(row["log"] for row in rows), latest_timestamp
|
|
1186
1237
|
|
|
1187
|
-
def get_valid_message_ins(self, message_id: str) ->
|
|
1238
|
+
def get_valid_message_ins(self, message_id: str) -> dict[str, Any] | None:
|
|
1188
1239
|
"""Check if the Message exists and is valid (not expired).
|
|
1189
1240
|
|
|
1190
1241
|
Return Message if valid.
|
|
1191
1242
|
"""
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
created_at = message_ins["created_at"]
|
|
1205
|
-
ttl = message_ins["ttl"]
|
|
1206
|
-
current_time = now().timestamp()
|
|
1207
|
-
|
|
1208
|
-
# Check if Message is expired
|
|
1209
|
-
if ttl is not None and created_at + ttl <= current_time:
|
|
1210
|
-
return None
|
|
1211
|
-
|
|
1212
|
-
return message_ins
|
|
1243
|
+
with self.conn:
|
|
1244
|
+
self._check_stored_messages({message_id})
|
|
1245
|
+
query = """
|
|
1246
|
+
SELECT *
|
|
1247
|
+
FROM message_ins
|
|
1248
|
+
WHERE message_id = :message_id
|
|
1249
|
+
"""
|
|
1250
|
+
data = {"message_id": message_id}
|
|
1251
|
+
rows: list[dict[str, Any]] = self.conn.execute(query, data).fetchall()
|
|
1252
|
+
if not rows:
|
|
1253
|
+
# Message does not exist
|
|
1254
|
+
return None
|
|
1213
1255
|
|
|
1214
|
-
|
|
1215
|
-
"""Create a token for the given run ID."""
|
|
1216
|
-
token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
|
|
1217
|
-
query = "INSERT INTO token_store (run_id, token) VALUES (:run_id, :token);"
|
|
1218
|
-
data = {"run_id": uint64_to_int64(run_id), "token": token}
|
|
1219
|
-
try:
|
|
1220
|
-
self.query(query, data)
|
|
1221
|
-
except sqlite3.IntegrityError:
|
|
1222
|
-
return None # Token already created for this run ID
|
|
1223
|
-
return token
|
|
1224
|
-
|
|
1225
|
-
def verify_token(self, run_id: int, token: str) -> bool:
|
|
1226
|
-
"""Verify a token for the given run ID."""
|
|
1227
|
-
query = "SELECT token FROM token_store WHERE run_id = :run_id;"
|
|
1228
|
-
data = {"run_id": uint64_to_int64(run_id)}
|
|
1229
|
-
rows = self.query(query, data)
|
|
1230
|
-
if not rows:
|
|
1231
|
-
return False
|
|
1232
|
-
return cast(str, rows[0]["token"]) == token
|
|
1233
|
-
|
|
1234
|
-
def delete_token(self, run_id: int) -> None:
|
|
1235
|
-
"""Delete the token for the given run ID."""
|
|
1236
|
-
query = "DELETE FROM token_store WHERE run_id = :run_id;"
|
|
1237
|
-
data = {"run_id": uint64_to_int64(run_id)}
|
|
1238
|
-
self.query(query, data)
|
|
1239
|
-
|
|
1240
|
-
def get_run_id_by_token(self, token: str) -> Optional[int]:
|
|
1241
|
-
"""Get the run ID associated with a given token."""
|
|
1242
|
-
query = "SELECT run_id FROM token_store WHERE token = :token;"
|
|
1243
|
-
data = {"token": token}
|
|
1244
|
-
rows = self.query(query, data)
|
|
1245
|
-
if not rows:
|
|
1246
|
-
return None
|
|
1247
|
-
return int64_to_uint64(rows[0]["run_id"])
|
|
1256
|
+
return rows[0]
|
|
1248
1257
|
|
|
1249
1258
|
|
|
1250
1259
|
def message_to_dict(message: Message) -> dict[str, Any]:
|