flwr 1.24.0__py3-none-any.whl → 1.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +1 -1
- flwr/app/__init__.py +4 -1
- flwr/app/message_type.py +29 -0
- flwr/app/metadata.py +5 -2
- flwr/app/user_config.py +19 -0
- flwr/cli/app.py +37 -19
- flwr/cli/app_cmd/publish.py +25 -75
- flwr/cli/app_cmd/review.py +25 -66
- flwr/cli/auth_plugin/auth_plugin.py +5 -10
- flwr/cli/auth_plugin/noop_auth_plugin.py +1 -2
- flwr/cli/auth_plugin/oidc_cli_plugin.py +38 -38
- flwr/cli/build.py +15 -28
- flwr/cli/config/__init__.py +21 -0
- flwr/cli/config/ls.py +71 -0
- flwr/cli/config_migration.py +297 -0
- flwr/cli/config_utils.py +63 -156
- flwr/cli/constant.py +71 -0
- flwr/cli/federation/__init__.py +0 -2
- flwr/cli/federation/ls.py +256 -64
- flwr/cli/flower_config.py +429 -0
- flwr/cli/install.py +23 -62
- flwr/cli/log.py +23 -37
- flwr/cli/login/login.py +29 -63
- flwr/cli/ls.py +72 -61
- flwr/cli/new/new.py +98 -309
- flwr/cli/pull.py +19 -37
- flwr/cli/run/run.py +87 -100
- flwr/cli/run_utils.py +23 -5
- flwr/cli/stop.py +33 -74
- flwr/cli/supernode/ls.py +35 -62
- flwr/cli/supernode/register.py +31 -80
- flwr/cli/supernode/unregister.py +24 -70
- flwr/cli/typing.py +200 -0
- flwr/cli/utils.py +160 -412
- flwr/client/grpc_adapter_client/connection.py +2 -2
- flwr/client/grpc_rere_client/connection.py +9 -6
- flwr/client/grpc_rere_client/grpc_adapter.py +1 -1
- flwr/client/message_handler/message_handler.py +2 -1
- flwr/client/mod/centraldp_mods.py +1 -1
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/rest_client/connection.py +6 -4
- flwr/client/run_info_store.py +2 -1
- flwr/clientapp/client_app.py +2 -1
- flwr/common/__init__.py +3 -2
- flwr/common/args.py +5 -5
- flwr/common/config.py +12 -17
- flwr/common/constant.py +3 -16
- flwr/common/context.py +2 -1
- flwr/common/exit/exit.py +4 -4
- flwr/common/exit/exit_code.py +6 -0
- flwr/common/grpc.py +2 -1
- flwr/common/logger.py +1 -1
- flwr/common/message.py +1 -1
- flwr/common/retry_invoker.py +13 -5
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -2
- flwr/common/serde.py +13 -5
- flwr/common/telemetry.py +1 -1
- flwr/common/typing.py +10 -3
- flwr/compat/client/app.py +6 -9
- flwr/compat/client/grpc_client/connection.py +2 -1
- flwr/compat/common/constant.py +29 -0
- flwr/compat/server/app.py +1 -1
- flwr/proto/clientappio_pb2.py +2 -2
- flwr/proto/clientappio_pb2_grpc.py +104 -88
- flwr/proto/clientappio_pb2_grpc.pyi +140 -80
- flwr/proto/federation_pb2.py +5 -3
- flwr/proto/federation_pb2.pyi +32 -2
- flwr/proto/fleet_pb2.py +10 -10
- flwr/proto/fleet_pb2.pyi +5 -1
- flwr/proto/run_pb2.py +18 -26
- flwr/proto/run_pb2.pyi +10 -58
- flwr/proto/serverappio_pb2.py +2 -2
- flwr/proto/serverappio_pb2_grpc.py +138 -207
- flwr/proto/serverappio_pb2_grpc.pyi +189 -155
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +62 -90
- flwr/proto/simulationio_pb2_grpc.pyi +95 -55
- flwr/server/app.py +7 -13
- flwr/server/compat/grid_client_proxy.py +2 -1
- flwr/server/grid/grpc_grid.py +5 -5
- flwr/server/serverapp/app.py +11 -4
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +13 -12
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/linkstate/__init__.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +36 -10
- flwr/server/superlink/linkstate/linkstate.py +34 -21
- flwr/server/superlink/linkstate/linkstate_factory.py +16 -8
- flwr/server/superlink/linkstate/{sqlite_linkstate.py → sql_linkstate.py} +471 -516
- flwr/server/superlink/linkstate/utils.py +49 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +1 -33
- flwr/server/superlink/simulation/simulationio_servicer.py +0 -19
- flwr/server/utils/validator.py +1 -1
- flwr/server/workflow/default_workflows.py +2 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
- flwr/serverapp/strategy/bulyan.py +7 -1
- flwr/serverapp/strategy/dp_fixed_clipping.py +9 -1
- flwr/serverapp/strategy/fedavg.py +1 -1
- flwr/serverapp/strategy/fedxgb_cyclic.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -6
- flwr/simulation/run_simulation.py +3 -12
- flwr/simulation/simulationio_connection.py +3 -3
- flwr/{common → supercore}/address.py +7 -33
- flwr/supercore/app_utils.py +2 -1
- flwr/supercore/constant.py +27 -2
- flwr/supercore/corestate/{sqlite_corestate.py → sql_corestate.py} +19 -23
- flwr/supercore/credential_store/__init__.py +33 -0
- flwr/supercore/credential_store/credential_store.py +34 -0
- flwr/supercore/credential_store/file_credential_store.py +76 -0
- flwr/{common → supercore}/date.py +0 -11
- flwr/supercore/ffs/disk_ffs.py +1 -1
- flwr/supercore/object_store/object_store_factory.py +14 -6
- flwr/supercore/object_store/{sqlite_object_store.py → sql_object_store.py} +115 -117
- flwr/supercore/sql_mixin.py +315 -0
- flwr/{cli/new/templates → supercore/state}/__init__.py +2 -2
- flwr/{cli/new/templates/app/code/flwr_tune → supercore/state/alembic}/__init__.py +2 -2
- flwr/supercore/state/alembic/env.py +103 -0
- flwr/supercore/state/alembic/script.py.mako +43 -0
- flwr/supercore/state/alembic/utils.py +239 -0
- flwr/{cli/new/templates/app → supercore/state/alembic/versions}/__init__.py +2 -2
- flwr/supercore/state/alembic/versions/rev_2026_01_28_initialize_migration_of_state_tables.py +200 -0
- flwr/supercore/state/schema/README.md +121 -0
- flwr/{cli/new/templates/app/code → supercore/state/schema}/__init__.py +2 -2
- flwr/supercore/state/schema/corestate_tables.py +36 -0
- flwr/supercore/state/schema/linkstate_tables.py +152 -0
- flwr/supercore/state/schema/objectstore_tables.py +90 -0
- flwr/supercore/superexec/run_superexec.py +2 -2
- flwr/supercore/utils.py +225 -0
- flwr/superlink/federation/federation_manager.py +2 -2
- flwr/superlink/federation/noop_federation_manager.py +8 -6
- flwr/superlink/servicer/control/control_grpc.py +2 -0
- flwr/superlink/servicer/control/control_servicer.py +106 -21
- flwr/supernode/cli/flower_supernode.py +2 -1
- flwr/supernode/nodestate/in_memory_nodestate.py +62 -1
- flwr/supernode/nodestate/nodestate.py +45 -0
- flwr/supernode/runtime/run_clientapp.py +14 -14
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +13 -5
- flwr/supernode/start_client_internal.py +17 -10
- {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/METADATA +8 -8
- {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/RECORD +144 -184
- flwr/cli/federation/show.py +0 -317
- flwr/cli/new/templates/app/.gitignore.tpl +0 -163
- flwr/cli/new/templates/app/LICENSE.tpl +0 -202
- flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
- flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
- flwr/cli/new/templates/app/README.md.tpl +0 -37
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
- flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
- flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
- flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -99
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
- flwr/common/pyproject.py +0 -42
- flwr/supercore/sqlite_mixin.py +0 -159
- /flwr/{common → supercore}/version.py +0 -0
- {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/WHEEL +0 -0
- {flwr-1.24.0.dist-info → flwr-1.26.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2026 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,19 +12,22 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""SQLAlchemy-based implementation of the link state."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
# pylint: disable=too-many-lines
|
|
19
19
|
|
|
20
20
|
import json
|
|
21
|
-
import sqlite3
|
|
22
21
|
from collections.abc import Sequence
|
|
23
22
|
from datetime import datetime, timezone
|
|
24
23
|
from logging import ERROR, WARNING
|
|
25
|
-
from typing import Any
|
|
24
|
+
from typing import Any
|
|
26
25
|
|
|
27
|
-
from
|
|
26
|
+
from sqlalchemy import MetaData
|
|
27
|
+
from sqlalchemy.exc import IntegrityError
|
|
28
|
+
|
|
29
|
+
from flwr.app.user_config import UserConfig
|
|
30
|
+
from flwr.common import Context, Message, log, now
|
|
28
31
|
from flwr.common.constant import (
|
|
29
32
|
HEARTBEAT_PATIENCE,
|
|
30
33
|
MESSAGE_TTL_TOLERANCE,
|
|
@@ -35,22 +38,15 @@ from flwr.common.constant import (
|
|
|
35
38
|
Status,
|
|
36
39
|
SubStatus,
|
|
37
40
|
)
|
|
38
|
-
from flwr.common.message import make_message
|
|
39
41
|
from flwr.common.record import ConfigRecord
|
|
40
|
-
from flwr.common.
|
|
41
|
-
from flwr.
|
|
42
|
-
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
43
|
-
|
|
44
|
-
# pylint: disable=E0611
|
|
45
|
-
from flwr.proto.error_pb2 import Error as ProtoError
|
|
46
|
-
from flwr.proto.node_pb2 import NodeInfo
|
|
47
|
-
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
|
48
|
-
|
|
49
|
-
# pylint: enable=E0611
|
|
42
|
+
from flwr.common.typing import Run, RunStatus
|
|
43
|
+
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
|
50
44
|
from flwr.server.utils.validator import validate_message
|
|
51
45
|
from flwr.supercore.constant import NodeStatus
|
|
52
|
-
from flwr.supercore.corestate.
|
|
46
|
+
from flwr.supercore.corestate.sql_corestate import SqlCoreState
|
|
53
47
|
from flwr.supercore.object_store.object_store import ObjectStore
|
|
48
|
+
from flwr.supercore.state.schema.corestate_tables import create_corestate_metadata
|
|
49
|
+
from flwr.supercore.state.schema.linkstate_tables import create_linkstate_metadata
|
|
54
50
|
from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
|
|
55
51
|
from flwr.superlink.federation import FederationManager
|
|
56
52
|
|
|
@@ -63,125 +59,18 @@ from .utils import (
|
|
|
63
59
|
context_to_bytes,
|
|
64
60
|
convert_sint64_values_in_dict_to_uint64,
|
|
65
61
|
convert_uint64_values_in_dict_to_sint64,
|
|
62
|
+
dict_to_message,
|
|
66
63
|
generate_rand_int_from_bytes,
|
|
67
64
|
has_valid_sub_status,
|
|
68
65
|
is_valid_transition,
|
|
66
|
+
message_to_dict,
|
|
69
67
|
verify_found_message_replies,
|
|
70
68
|
verify_message_ids,
|
|
71
69
|
)
|
|
72
70
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
owner_aid TEXT,
|
|
77
|
-
owner_name TEXT,
|
|
78
|
-
status TEXT,
|
|
79
|
-
registered_at TEXT,
|
|
80
|
-
last_activated_at TEXT NULL,
|
|
81
|
-
last_deactivated_at TEXT NULL,
|
|
82
|
-
unregistered_at TEXT NULL,
|
|
83
|
-
online_until TIMESTAMP NULL,
|
|
84
|
-
heartbeat_interval REAL,
|
|
85
|
-
public_key BLOB UNIQUE
|
|
86
|
-
);
|
|
87
|
-
"""
|
|
88
|
-
|
|
89
|
-
SQL_CREATE_TABLE_PUBLIC_KEY = """
|
|
90
|
-
CREATE TABLE IF NOT EXISTS public_key(
|
|
91
|
-
public_key BLOB PRIMARY KEY
|
|
92
|
-
);
|
|
93
|
-
"""
|
|
94
|
-
|
|
95
|
-
SQL_CREATE_INDEX_ONLINE_UNTIL = """
|
|
96
|
-
CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
|
97
|
-
"""
|
|
98
|
-
|
|
99
|
-
SQL_CREATE_INDEX_OWNER_AID = """
|
|
100
|
-
CREATE INDEX IF NOT EXISTS idx_node_owner_aid ON node(owner_aid);
|
|
101
|
-
"""
|
|
102
|
-
|
|
103
|
-
SQL_CREATE_INDEX_NODE_STATUS = """
|
|
104
|
-
CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
|
|
105
|
-
"""
|
|
106
|
-
|
|
107
|
-
SQL_CREATE_TABLE_RUN = """
|
|
108
|
-
CREATE TABLE IF NOT EXISTS run(
|
|
109
|
-
run_id INTEGER UNIQUE,
|
|
110
|
-
fab_id TEXT,
|
|
111
|
-
fab_version TEXT,
|
|
112
|
-
fab_hash TEXT,
|
|
113
|
-
override_config TEXT,
|
|
114
|
-
pending_at TEXT,
|
|
115
|
-
starting_at TEXT,
|
|
116
|
-
running_at TEXT,
|
|
117
|
-
finished_at TEXT,
|
|
118
|
-
sub_status TEXT,
|
|
119
|
-
details TEXT,
|
|
120
|
-
federation TEXT,
|
|
121
|
-
federation_options BLOB,
|
|
122
|
-
flwr_aid TEXT
|
|
123
|
-
);
|
|
124
|
-
"""
|
|
125
|
-
|
|
126
|
-
SQL_CREATE_TABLE_LOGS = """
|
|
127
|
-
CREATE TABLE IF NOT EXISTS logs (
|
|
128
|
-
timestamp REAL,
|
|
129
|
-
run_id INTEGER,
|
|
130
|
-
node_id INTEGER,
|
|
131
|
-
log TEXT,
|
|
132
|
-
PRIMARY KEY (timestamp, run_id, node_id),
|
|
133
|
-
FOREIGN KEY (run_id) REFERENCES run(run_id)
|
|
134
|
-
);
|
|
135
|
-
"""
|
|
136
|
-
|
|
137
|
-
SQL_CREATE_TABLE_CONTEXT = """
|
|
138
|
-
CREATE TABLE IF NOT EXISTS context(
|
|
139
|
-
run_id INTEGER UNIQUE,
|
|
140
|
-
context BLOB,
|
|
141
|
-
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
142
|
-
);
|
|
143
|
-
"""
|
|
144
|
-
|
|
145
|
-
SQL_CREATE_TABLE_MESSAGE_INS = """
|
|
146
|
-
CREATE TABLE IF NOT EXISTS message_ins(
|
|
147
|
-
message_id TEXT UNIQUE,
|
|
148
|
-
group_id TEXT,
|
|
149
|
-
run_id INTEGER,
|
|
150
|
-
src_node_id INTEGER,
|
|
151
|
-
dst_node_id INTEGER,
|
|
152
|
-
reply_to_message_id TEXT,
|
|
153
|
-
created_at REAL,
|
|
154
|
-
delivered_at TEXT,
|
|
155
|
-
ttl REAL,
|
|
156
|
-
message_type TEXT,
|
|
157
|
-
content BLOB NULL,
|
|
158
|
-
error BLOB NULL,
|
|
159
|
-
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
160
|
-
);
|
|
161
|
-
"""
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
SQL_CREATE_TABLE_MESSAGE_RES = """
|
|
165
|
-
CREATE TABLE IF NOT EXISTS message_res(
|
|
166
|
-
message_id TEXT UNIQUE,
|
|
167
|
-
group_id TEXT,
|
|
168
|
-
run_id INTEGER,
|
|
169
|
-
src_node_id INTEGER,
|
|
170
|
-
dst_node_id INTEGER,
|
|
171
|
-
reply_to_message_id TEXT,
|
|
172
|
-
created_at REAL,
|
|
173
|
-
delivered_at TEXT,
|
|
174
|
-
ttl REAL,
|
|
175
|
-
message_type TEXT,
|
|
176
|
-
content BLOB NULL,
|
|
177
|
-
error BLOB NULL,
|
|
178
|
-
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
179
|
-
);
|
|
180
|
-
"""
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
184
|
-
"""SQLite-based LinkState implementation."""
|
|
71
|
+
|
|
72
|
+
class SqlLinkState(LinkState, SqlCoreState): # pylint: disable=R0904
|
|
73
|
+
"""SQLAlchemy-based LinkState implementation."""
|
|
185
74
|
|
|
186
75
|
def __init__(
|
|
187
76
|
self,
|
|
@@ -193,24 +82,21 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
193
82
|
federation_manager.linkstate = self
|
|
194
83
|
self._federation_manager = federation_manager
|
|
195
84
|
|
|
196
|
-
def
|
|
197
|
-
"""Return
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
SQL_CREATE_INDEX_OWNER_AID,
|
|
208
|
-
SQL_CREATE_INDEX_NODE_STATUS,
|
|
209
|
-
)
|
|
85
|
+
def get_metadata(self) -> MetaData:
|
|
86
|
+
"""Return combined SQLAlchemy MetaData for LinkState and CoreState tables."""
|
|
87
|
+
# Start with linkstate tables
|
|
88
|
+
metadata = create_linkstate_metadata()
|
|
89
|
+
|
|
90
|
+
# Add corestate tables (token_store)
|
|
91
|
+
corestate_metadata = create_corestate_metadata()
|
|
92
|
+
for table in corestate_metadata.tables.values():
|
|
93
|
+
table.to_metadata(metadata)
|
|
94
|
+
|
|
95
|
+
return metadata
|
|
210
96
|
|
|
211
97
|
@property
|
|
212
98
|
def federation_manager(self) -> FederationManager:
|
|
213
|
-
"""
|
|
99
|
+
"""Return the FederationManager instance."""
|
|
214
100
|
return self._federation_manager
|
|
215
101
|
|
|
216
102
|
def store_message_ins(self, message: Message) -> str | None:
|
|
@@ -238,20 +124,26 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
238
124
|
)
|
|
239
125
|
return None
|
|
240
126
|
|
|
241
|
-
with self.
|
|
127
|
+
with self.session():
|
|
242
128
|
# Validate run_id
|
|
243
|
-
query = "SELECT federation FROM run WHERE run_id =
|
|
244
|
-
rows = self.
|
|
129
|
+
query = "SELECT federation FROM run WHERE run_id = :run_id"
|
|
130
|
+
rows = self.query(query, {"run_id": data[0]["run_id"]})
|
|
245
131
|
if not rows:
|
|
246
132
|
log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
|
|
247
133
|
return None
|
|
248
134
|
federation: str = rows[0]["federation"]
|
|
249
135
|
|
|
250
136
|
# Validate destination node ID
|
|
251
|
-
query = "SELECT node_id FROM node WHERE node_id =
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
137
|
+
query = """SELECT node_id FROM node WHERE node_id = :node_id
|
|
138
|
+
AND status IN (:online, :offline)"""
|
|
139
|
+
rows = self.query(
|
|
140
|
+
query,
|
|
141
|
+
{
|
|
142
|
+
"node_id": data[0]["dst_node_id"],
|
|
143
|
+
"online": NodeStatus.ONLINE,
|
|
144
|
+
"offline": NodeStatus.OFFLINE,
|
|
145
|
+
},
|
|
146
|
+
)
|
|
255
147
|
if not rows or not self.federation_manager.has_node(
|
|
256
148
|
message.metadata.dst_node_id, federation
|
|
257
149
|
):
|
|
@@ -262,29 +154,62 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
262
154
|
)
|
|
263
155
|
return None
|
|
264
156
|
|
|
157
|
+
# Insert message
|
|
265
158
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
266
|
-
query = f"INSERT INTO message_ins VALUES({columns})
|
|
159
|
+
query = f"INSERT INTO message_ins VALUES({columns})"
|
|
267
160
|
|
|
268
161
|
# Only invalid run_id can trigger IntegrityError.
|
|
269
162
|
# This may need to be changed in the future version
|
|
270
163
|
# with more integrity checks.
|
|
271
|
-
self.
|
|
164
|
+
self.query(query, data[0])
|
|
272
165
|
|
|
273
166
|
return message.metadata.message_id
|
|
274
167
|
|
|
168
|
+
# pylint: disable-next=too-many-locals
|
|
275
169
|
def _check_stored_messages(self, message_ids: set[str]) -> None:
|
|
276
170
|
"""Check and delete the message if it's invalid."""
|
|
277
171
|
if not message_ids:
|
|
278
172
|
return
|
|
279
173
|
|
|
280
|
-
with self.
|
|
174
|
+
with self.session():
|
|
175
|
+
# Batch fetch all messages in one query
|
|
176
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
|
|
177
|
+
query = f"""
|
|
178
|
+
SELECT * FROM message_ins
|
|
179
|
+
WHERE message_id IN ({placeholders})
|
|
180
|
+
"""
|
|
181
|
+
params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
|
|
182
|
+
message_rows = self.query(query, params)
|
|
183
|
+
|
|
184
|
+
if not message_rows:
|
|
185
|
+
return
|
|
186
|
+
|
|
187
|
+
# Build message lookup dict
|
|
188
|
+
message_dict: dict[str, dict[str, Any]] = {
|
|
189
|
+
row["message_id"]: row for row in message_rows
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
# Collect unique run_ids for batch federation lookup
|
|
193
|
+
run_ids = {row["run_id"] for row in message_rows}
|
|
194
|
+
placeholders = ",".join([f":rid_{i}" for i in range(len(run_ids))])
|
|
195
|
+
query = f"""
|
|
196
|
+
SELECT run_id, federation FROM run
|
|
197
|
+
WHERE run_id IN ({placeholders})
|
|
198
|
+
"""
|
|
199
|
+
params = {f"rid_{i}": rid for i, rid in enumerate(run_ids)}
|
|
200
|
+
run_rows = self.query(query, params)
|
|
201
|
+
|
|
202
|
+
# Build run_id to federation mapping
|
|
203
|
+
run_id_to_federation: dict[int, str] = {
|
|
204
|
+
row["run_id"]: row["federation"] for row in run_rows
|
|
205
|
+
}
|
|
206
|
+
|
|
281
207
|
invalid_msg_ids: set[str] = set()
|
|
282
208
|
current_time = now().timestamp()
|
|
283
209
|
|
|
210
|
+
# Check each message for validity
|
|
284
211
|
for msg_id in message_ids:
|
|
285
|
-
|
|
286
|
-
query = "SELECT * FROM message_ins WHERE message_id = ?;"
|
|
287
|
-
message_row = self.conn.execute(query, (msg_id,)).fetchone()
|
|
212
|
+
message_row = message_dict.get(msg_id)
|
|
288
213
|
if not message_row:
|
|
289
214
|
continue
|
|
290
215
|
|
|
@@ -294,15 +219,12 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
294
219
|
invalid_msg_ids.add(msg_id)
|
|
295
220
|
continue
|
|
296
221
|
|
|
297
|
-
# Check if
|
|
298
|
-
# Get federation from run table
|
|
222
|
+
# Check if run exists and get federation
|
|
299
223
|
run_id = message_row["run_id"]
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
if not run_row: # This should not happen
|
|
224
|
+
federation = run_id_to_federation.get(run_id)
|
|
225
|
+
if not federation:
|
|
303
226
|
invalid_msg_ids.add(msg_id)
|
|
304
227
|
continue
|
|
305
|
-
federation = run_row["federation"]
|
|
306
228
|
|
|
307
229
|
# Convert sint64 to uint64 for node IDs
|
|
308
230
|
src_node_id = int64_to_uint64(message_row["src_node_id"])
|
|
@@ -327,52 +249,48 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
327
249
|
msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
|
|
328
250
|
raise AssertionError(msg)
|
|
329
251
|
|
|
330
|
-
|
|
252
|
+
params: dict[str, str | int] = {}
|
|
331
253
|
|
|
332
254
|
# Convert the uint64 value to sint64 for SQLite
|
|
333
|
-
|
|
255
|
+
params["node_id"] = uint64_to_int64(node_id)
|
|
334
256
|
|
|
335
|
-
with self.
|
|
257
|
+
with self.session():
|
|
336
258
|
# Retrieve all Messages for node_id
|
|
337
259
|
query = """
|
|
338
260
|
SELECT message_id
|
|
339
261
|
FROM message_ins
|
|
340
|
-
WHERE
|
|
341
|
-
AND
|
|
342
|
-
AND
|
|
262
|
+
WHERE dst_node_id = :node_id
|
|
263
|
+
AND delivered_at = ''
|
|
264
|
+
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
343
265
|
"""
|
|
344
266
|
|
|
345
267
|
if limit is not None:
|
|
346
268
|
query += " LIMIT :limit"
|
|
347
|
-
|
|
269
|
+
params["limit"] = limit
|
|
348
270
|
|
|
349
|
-
query
|
|
350
|
-
|
|
351
|
-
rows = self.conn.execute(query, data).fetchall()
|
|
271
|
+
rows = self.query(query, params)
|
|
352
272
|
message_ids: set[str] = {row["message_id"] for row in rows}
|
|
353
273
|
self._check_stored_messages(message_ids)
|
|
354
274
|
|
|
355
275
|
# Mark retrieved Messages as delivered
|
|
356
276
|
if rows:
|
|
357
277
|
# Prepare query
|
|
358
|
-
placeholders
|
|
359
|
-
[f":id_{i}" for i in range(len(message_ids))]
|
|
360
|
-
)
|
|
278
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
|
|
361
279
|
query = f"""
|
|
362
280
|
UPDATE message_ins
|
|
363
281
|
SET delivered_at = :delivered_at
|
|
364
282
|
WHERE message_id IN ({placeholders})
|
|
365
|
-
RETURNING
|
|
283
|
+
RETURNING *
|
|
366
284
|
"""
|
|
367
285
|
|
|
368
286
|
# Prepare data for query
|
|
369
287
|
delivered_at = now().isoformat()
|
|
370
|
-
|
|
288
|
+
params = {"delivered_at": delivered_at}
|
|
371
289
|
for index, msg_id in enumerate(message_ids):
|
|
372
|
-
|
|
290
|
+
params[f"mid_{index}"] = str(msg_id)
|
|
373
291
|
|
|
374
292
|
# Run query
|
|
375
|
-
rows = self.
|
|
293
|
+
rows = self.query(query, params)
|
|
376
294
|
|
|
377
295
|
for row in rows:
|
|
378
296
|
# Convert values from sint64 to uint64
|
|
@@ -380,7 +298,7 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
380
298
|
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
381
299
|
)
|
|
382
300
|
|
|
383
|
-
result = [dict_to_message(row) for row in rows]
|
|
301
|
+
result = [dict_to_message(dict(row)) for row in rows]
|
|
384
302
|
|
|
385
303
|
return result
|
|
386
304
|
|
|
@@ -406,8 +324,8 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
406
324
|
)
|
|
407
325
|
return None
|
|
408
326
|
|
|
409
|
-
# Ensure that the dst_node_id of the original message matches the src_node_id
|
|
410
|
-
# reply being processed.
|
|
327
|
+
# Ensure that the dst_node_id of the original message matches the src_node_id
|
|
328
|
+
# of reply being processed.
|
|
411
329
|
if (
|
|
412
330
|
msg_ins
|
|
413
331
|
and message
|
|
@@ -437,21 +355,19 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
437
355
|
return None
|
|
438
356
|
|
|
439
357
|
# Store Message
|
|
440
|
-
|
|
358
|
+
msg_dict = message_to_dict(message)
|
|
441
359
|
|
|
442
360
|
# Convert values from uint64 to sint64 for SQLite
|
|
443
361
|
convert_uint64_values_in_dict_to_sint64(
|
|
444
|
-
|
|
362
|
+
msg_dict, ["run_id", "src_node_id", "dst_node_id"]
|
|
445
363
|
)
|
|
446
364
|
|
|
447
|
-
columns = ", ".join([f":{key}" for key in
|
|
448
|
-
query = f"INSERT INTO message_res VALUES({columns})
|
|
365
|
+
columns = ", ".join([f":{key}" for key in msg_dict])
|
|
366
|
+
query = f"INSERT INTO message_res VALUES({columns})"
|
|
449
367
|
|
|
450
|
-
# Only invalid run_id can trigger IntegrityError.
|
|
451
|
-
# This may need to be changed in the future version with more integrity checks.
|
|
452
368
|
try:
|
|
453
|
-
self.query(query,
|
|
454
|
-
except
|
|
369
|
+
self.query(query, msg_dict)
|
|
370
|
+
except IntegrityError:
|
|
455
371
|
log(ERROR, "`run` is invalid")
|
|
456
372
|
return None
|
|
457
373
|
|
|
@@ -459,21 +375,21 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
459
375
|
|
|
460
376
|
def get_message_res(self, message_ids: set[str]) -> list[Message]:
|
|
461
377
|
"""Get reply Messages for the given Message IDs."""
|
|
462
|
-
# pylint: disable
|
|
378
|
+
# pylint: disable=too-many-locals
|
|
463
379
|
ret: dict[str, Message] = {}
|
|
464
380
|
|
|
465
|
-
with self.
|
|
381
|
+
with self.session():
|
|
466
382
|
# Verify Message IDs
|
|
467
383
|
self._check_stored_messages(message_ids)
|
|
468
384
|
current = now().timestamp()
|
|
385
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
|
|
469
386
|
query = f"""
|
|
470
387
|
SELECT *
|
|
471
388
|
FROM message_ins
|
|
472
|
-
WHERE message_id IN ({
|
|
389
|
+
WHERE message_id IN ({placeholders})
|
|
473
390
|
"""
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
).fetchall()
|
|
391
|
+
params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
|
|
392
|
+
rows = self.query(query, params)
|
|
477
393
|
found_message_ins_dict: dict[str, Message] = {}
|
|
478
394
|
for row in rows:
|
|
479
395
|
convert_sint64_values_in_dict_to_uint64(
|
|
@@ -493,15 +409,18 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
493
409
|
in_message = found_message_ins_dict[message_id]
|
|
494
410
|
sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
|
|
495
411
|
dst_node_ids.add(sint_node_id)
|
|
412
|
+
placeholders = ",".join([f":nid_{i}" for i in range(len(dst_node_ids))])
|
|
496
413
|
query = f"""
|
|
497
414
|
SELECT node_id, online_until
|
|
498
415
|
FROM node
|
|
499
|
-
WHERE node_id IN ({
|
|
500
|
-
AND status !=
|
|
416
|
+
WHERE node_id IN ({placeholders})
|
|
417
|
+
AND status != :unregistered
|
|
501
418
|
"""
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
419
|
+
node_params: dict[str, int | str] = {
|
|
420
|
+
f"nid_{i}": nid for i, nid in enumerate(dst_node_ids)
|
|
421
|
+
}
|
|
422
|
+
node_params["unregistered"] = NodeStatus.UNREGISTERED
|
|
423
|
+
rows = self.query(query, node_params)
|
|
505
424
|
tmp_ret_dict = check_node_availability_for_in_message(
|
|
506
425
|
inquired_in_message_ids=message_ids,
|
|
507
426
|
found_in_message_dict=found_message_ins_dict,
|
|
@@ -513,15 +432,15 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
513
432
|
ret.update(tmp_ret_dict)
|
|
514
433
|
|
|
515
434
|
# Find all reply Messages
|
|
435
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
|
|
516
436
|
query = f"""
|
|
517
437
|
SELECT *
|
|
518
438
|
FROM message_res
|
|
519
|
-
WHERE reply_to_message_id IN ({
|
|
520
|
-
AND delivered_at =
|
|
439
|
+
WHERE reply_to_message_id IN ({placeholders})
|
|
440
|
+
AND delivered_at = ''
|
|
521
441
|
"""
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
).fetchall()
|
|
442
|
+
params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
|
|
443
|
+
rows = self.query(query, params)
|
|
525
444
|
for row in rows:
|
|
526
445
|
convert_sint64_values_in_dict_to_uint64(
|
|
527
446
|
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
@@ -541,13 +460,15 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
541
460
|
message_res_ids = [
|
|
542
461
|
message_res.metadata.message_id for message_res in ret.values()
|
|
543
462
|
]
|
|
463
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_res_ids))])
|
|
544
464
|
query = f"""
|
|
545
465
|
UPDATE message_res
|
|
546
|
-
SET delivered_at =
|
|
547
|
-
WHERE message_id IN ({
|
|
466
|
+
SET delivered_at = :delivered_at
|
|
467
|
+
WHERE message_id IN ({placeholders})
|
|
548
468
|
"""
|
|
549
|
-
|
|
550
|
-
|
|
469
|
+
params = {"delivered_at": delivered_at}
|
|
470
|
+
params.update({f"mid_{i}": mid for i, mid in enumerate(message_res_ids)})
|
|
471
|
+
self.query(query, params)
|
|
551
472
|
|
|
552
473
|
return list(ret.values())
|
|
553
474
|
|
|
@@ -556,64 +477,55 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
556
477
|
|
|
557
478
|
This includes delivered but not yet deleted.
|
|
558
479
|
"""
|
|
559
|
-
query = "SELECT count(*) AS num FROM message_ins
|
|
560
|
-
rows = self.query(query)
|
|
561
|
-
|
|
562
|
-
num = cast(int, result["num"])
|
|
563
|
-
return num
|
|
480
|
+
query = "SELECT count(*) AS num FROM message_ins"
|
|
481
|
+
rows = self.query(query, {})
|
|
482
|
+
return int(rows[0]["num"])
|
|
564
483
|
|
|
565
484
|
def num_message_res(self) -> int:
|
|
566
485
|
"""Calculate the number of reply Messages in store.
|
|
567
486
|
|
|
568
487
|
This includes delivered but not yet deleted.
|
|
569
488
|
"""
|
|
570
|
-
query = "SELECT count(*) AS num FROM message_res
|
|
489
|
+
query = "SELECT count(*) AS num FROM message_res"
|
|
571
490
|
rows = self.query(query)
|
|
572
|
-
|
|
573
|
-
return result["num"]
|
|
491
|
+
return int(rows[0]["num"])
|
|
574
492
|
|
|
575
493
|
def delete_messages(self, message_ins_ids: set[str]) -> None:
|
|
576
494
|
"""Delete a Message and its reply based on provided Message IDs."""
|
|
577
495
|
if not message_ins_ids:
|
|
578
496
|
return
|
|
579
|
-
if self.conn is None:
|
|
580
|
-
raise AttributeError("LinkState not initialized")
|
|
581
497
|
|
|
582
|
-
placeholders = ",".join(["
|
|
583
|
-
|
|
498
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_ins_ids))])
|
|
499
|
+
params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ins_ids)}
|
|
584
500
|
|
|
585
501
|
# Delete Message
|
|
586
502
|
query_1 = f"""
|
|
587
503
|
DELETE FROM message_ins
|
|
588
|
-
WHERE message_id IN ({placeholders})
|
|
504
|
+
WHERE message_id IN ({placeholders})
|
|
589
505
|
"""
|
|
590
506
|
|
|
591
507
|
# Delete reply Message
|
|
592
508
|
query_2 = f"""
|
|
593
509
|
DELETE FROM message_res
|
|
594
|
-
WHERE reply_to_message_id IN ({placeholders})
|
|
510
|
+
WHERE reply_to_message_id IN ({placeholders})
|
|
595
511
|
"""
|
|
596
512
|
|
|
597
|
-
with self.
|
|
598
|
-
self.
|
|
599
|
-
self.
|
|
513
|
+
with self.session():
|
|
514
|
+
self.query(query_1, params)
|
|
515
|
+
self.query(query_2, params)
|
|
600
516
|
|
|
601
517
|
def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
|
|
602
518
|
"""Get all instruction Message IDs for the given run_id."""
|
|
603
|
-
if self.conn is None:
|
|
604
|
-
raise AttributeError("LinkState not initialized")
|
|
605
|
-
|
|
606
519
|
query = """
|
|
607
520
|
SELECT message_id
|
|
608
521
|
FROM message_ins
|
|
609
|
-
WHERE run_id = :run_id
|
|
522
|
+
WHERE run_id = :run_id
|
|
610
523
|
"""
|
|
611
|
-
|
|
612
524
|
sint64_run_id = uint64_to_int64(run_id)
|
|
613
|
-
|
|
525
|
+
params = {"run_id": sint64_run_id}
|
|
614
526
|
|
|
615
|
-
with self.
|
|
616
|
-
rows = self.
|
|
527
|
+
with self.session():
|
|
528
|
+
rows = self.query(query, params)
|
|
617
529
|
|
|
618
530
|
return {row["message_id"] for row in rows}
|
|
619
531
|
|
|
@@ -638,29 +550,31 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
638
550
|
(node_id, owner_aid, owner_name, status, registered_at, last_activated_at,
|
|
639
551
|
last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
|
|
640
552
|
public_key)
|
|
641
|
-
VALUES (
|
|
553
|
+
VALUES (:node_id, :owner_aid, :owner_name, :status, :registered_at,
|
|
554
|
+
:last_activated_at, :last_deactivated_at, :unregistered_at, :online_until,
|
|
555
|
+
:heartbeat_interval, :public_key)
|
|
642
556
|
"""
|
|
643
557
|
|
|
644
558
|
# Mark the node online until now().timestamp() + heartbeat_interval
|
|
645
559
|
try:
|
|
646
560
|
self.query(
|
|
647
561
|
query,
|
|
648
|
-
|
|
649
|
-
sint64_node_id,
|
|
650
|
-
owner_aid
|
|
651
|
-
owner_name
|
|
652
|
-
NodeStatus.REGISTERED,
|
|
653
|
-
now().isoformat(),
|
|
654
|
-
None,
|
|
655
|
-
None,
|
|
656
|
-
None,
|
|
657
|
-
None, #
|
|
658
|
-
heartbeat_interval
|
|
659
|
-
public_key
|
|
660
|
-
|
|
562
|
+
{
|
|
563
|
+
"node_id": sint64_node_id,
|
|
564
|
+
"owner_aid": owner_aid,
|
|
565
|
+
"owner_name": owner_name,
|
|
566
|
+
"status": NodeStatus.REGISTERED,
|
|
567
|
+
"registered_at": now().isoformat(),
|
|
568
|
+
"last_activated_at": None,
|
|
569
|
+
"last_deactivated_at": None,
|
|
570
|
+
"unregistered_at": None,
|
|
571
|
+
"online_until": None, # initialized with offline status
|
|
572
|
+
"heartbeat_interval": heartbeat_interval,
|
|
573
|
+
"public_key": public_key,
|
|
574
|
+
},
|
|
661
575
|
)
|
|
662
|
-
except
|
|
663
|
-
if "
|
|
576
|
+
except IntegrityError as e:
|
|
577
|
+
if "node.public_key" in str(e):
|
|
664
578
|
raise ValueError("Public key already in use.") from None
|
|
665
579
|
# Must be node ID conflict, almost impossible unless system is compromised
|
|
666
580
|
log(ERROR, "Unexpected node registration failure.")
|
|
@@ -675,21 +589,20 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
675
589
|
|
|
676
590
|
query = """
|
|
677
591
|
UPDATE node
|
|
678
|
-
SET status =
|
|
679
|
-
online_until = IIF(online_until >
|
|
680
|
-
WHERE node_id =
|
|
592
|
+
SET status = :unregistered, unregistered_at = :unregistered_at,
|
|
593
|
+
online_until = IIF(online_until > :current, :current, online_until)
|
|
594
|
+
WHERE node_id = :node_id AND status != :unregistered
|
|
595
|
+
AND owner_aid = :owner_aid
|
|
681
596
|
RETURNING node_id
|
|
682
597
|
"""
|
|
683
598
|
current = now()
|
|
684
|
-
params =
|
|
685
|
-
NodeStatus.UNREGISTERED,
|
|
686
|
-
current.isoformat(),
|
|
687
|
-
current.timestamp(),
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
owner_aid,
|
|
692
|
-
)
|
|
599
|
+
params = {
|
|
600
|
+
"unregistered": NodeStatus.UNREGISTERED,
|
|
601
|
+
"unregistered_at": current.isoformat(),
|
|
602
|
+
"current": current.timestamp(),
|
|
603
|
+
"node_id": sint64_node_id,
|
|
604
|
+
"owner_aid": owner_aid,
|
|
605
|
+
}
|
|
693
606
|
|
|
694
607
|
rows = self.query(query, params)
|
|
695
608
|
if not rows:
|
|
@@ -700,58 +613,58 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
700
613
|
|
|
701
614
|
def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
|
|
702
615
|
"""Activate the node with the specified `node_id`."""
|
|
703
|
-
|
|
704
|
-
self._check_and_tag_offline_nodes([node_id])
|
|
616
|
+
self._check_and_tag_offline_nodes([node_id])
|
|
705
617
|
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
618
|
+
# Only activate if the node is currently registered or offline
|
|
619
|
+
current_dt = now()
|
|
620
|
+
sint64_node_id = uint64_to_int64(node_id)
|
|
621
|
+
query = """
|
|
622
|
+
UPDATE node
|
|
623
|
+
SET status = :online,
|
|
624
|
+
last_activated_at = :current,
|
|
625
|
+
online_until = :online_until,
|
|
626
|
+
heartbeat_interval = :heartbeat_interval
|
|
627
|
+
WHERE node_id = :node_id AND status IN (:registered, :offline)
|
|
628
|
+
RETURNING node_id
|
|
629
|
+
"""
|
|
630
|
+
params = {
|
|
631
|
+
"online": NodeStatus.ONLINE,
|
|
632
|
+
"current": current_dt.isoformat(),
|
|
633
|
+
"online_until": current_dt.timestamp()
|
|
634
|
+
+ HEARTBEAT_PATIENCE * heartbeat_interval,
|
|
635
|
+
"heartbeat_interval": heartbeat_interval,
|
|
636
|
+
"node_id": sint64_node_id,
|
|
637
|
+
"registered": NodeStatus.REGISTERED,
|
|
638
|
+
"offline": NodeStatus.OFFLINE,
|
|
639
|
+
}
|
|
726
640
|
|
|
727
|
-
|
|
728
|
-
|
|
641
|
+
rows = self.query(query, params)
|
|
642
|
+
return len(rows) > 0
|
|
729
643
|
|
|
730
644
|
def deactivate_node(self, node_id: int) -> bool:
|
|
731
645
|
"""Deactivate the node with the specified `node_id`."""
|
|
732
|
-
|
|
733
|
-
self._check_and_tag_offline_nodes([node_id])
|
|
646
|
+
self._check_and_tag_offline_nodes([node_id])
|
|
734
647
|
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
648
|
+
# Only deactivate if the node is currently online
|
|
649
|
+
current_dt = now()
|
|
650
|
+
query = """
|
|
651
|
+
UPDATE node
|
|
652
|
+
SET status = :offline,
|
|
653
|
+
last_deactivated_at = :current_iso,
|
|
654
|
+
online_until = :current_ts
|
|
655
|
+
WHERE node_id = :node_id AND status = :online
|
|
656
|
+
RETURNING node_id
|
|
657
|
+
"""
|
|
658
|
+
params = {
|
|
659
|
+
"offline": NodeStatus.OFFLINE,
|
|
660
|
+
"current_iso": current_dt.isoformat(),
|
|
661
|
+
"current_ts": current_dt.timestamp(),
|
|
662
|
+
"node_id": uint64_to_int64(node_id),
|
|
663
|
+
"online": NodeStatus.ONLINE,
|
|
664
|
+
}
|
|
752
665
|
|
|
753
|
-
|
|
754
|
-
|
|
666
|
+
rows = self.query(query, params)
|
|
667
|
+
return len(rows) > 0
|
|
755
668
|
|
|
756
669
|
def get_nodes(self, run_id: int) -> set[int]:
|
|
757
670
|
"""Retrieve all currently stored node IDs as a set.
|
|
@@ -761,16 +674,13 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
761
674
|
If the provided `run_id` does not exist or has no matching nodes,
|
|
762
675
|
an empty `Set` MUST be returned.
|
|
763
676
|
"""
|
|
764
|
-
|
|
765
|
-
raise AttributeError("LinkState not initialized")
|
|
766
|
-
|
|
767
|
-
with self.conn:
|
|
677
|
+
with self.session():
|
|
768
678
|
# Convert the uint64 value to sint64 for SQLite
|
|
769
679
|
sint64_run_id = uint64_to_int64(run_id)
|
|
770
680
|
|
|
771
681
|
# Validate run ID
|
|
772
|
-
query = "SELECT federation FROM run WHERE run_id =
|
|
773
|
-
rows = self.
|
|
682
|
+
query = "SELECT federation FROM run WHERE run_id = :run_id"
|
|
683
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
774
684
|
if not rows:
|
|
775
685
|
return set()
|
|
776
686
|
federation: str = rows[0]["federation"]
|
|
@@ -787,23 +697,25 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
787
697
|
"""Check and tag offline nodes."""
|
|
788
698
|
# strftime will convert POSIX timestamp to ISO format
|
|
789
699
|
query = """
|
|
790
|
-
UPDATE node SET status =
|
|
700
|
+
UPDATE node SET status = :offline,
|
|
791
701
|
last_deactivated_at =
|
|
792
|
-
strftime(
|
|
793
|
-
WHERE online_until <=
|
|
702
|
+
strftime('%Y-%m-%dT%H:%M:%f+00:00', online_until, 'unixepoch')
|
|
703
|
+
WHERE online_until <= :current_time AND status = :online
|
|
794
704
|
"""
|
|
795
|
-
params =
|
|
796
|
-
NodeStatus.OFFLINE,
|
|
797
|
-
now().timestamp(),
|
|
798
|
-
NodeStatus.ONLINE,
|
|
799
|
-
|
|
705
|
+
params: dict[str, Any] = {
|
|
706
|
+
"offline": NodeStatus.OFFLINE,
|
|
707
|
+
"current_time": now().timestamp(),
|
|
708
|
+
"online": NodeStatus.ONLINE,
|
|
709
|
+
}
|
|
800
710
|
if node_ids is not None:
|
|
801
|
-
placeholders = ",".join(["
|
|
711
|
+
placeholders = ",".join([f":nid_{i}" for i in range(len(node_ids))])
|
|
802
712
|
query += f" AND node_id IN ({placeholders})"
|
|
803
|
-
params.
|
|
804
|
-
|
|
713
|
+
params.update(
|
|
714
|
+
{f"nid_{i}": uint64_to_int64(nid) for i, nid in enumerate(node_ids)}
|
|
715
|
+
)
|
|
716
|
+
self.query(query, params)
|
|
805
717
|
|
|
806
|
-
def get_node_info(
|
|
718
|
+
def get_node_info( # pylint: disable=too-many-locals
|
|
807
719
|
self,
|
|
808
720
|
*,
|
|
809
721
|
node_ids: Sequence[int] | None = None,
|
|
@@ -811,32 +723,37 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
811
723
|
statuses: Sequence[str] | None = None,
|
|
812
724
|
) -> Sequence[NodeInfo]:
|
|
813
725
|
"""Retrieve information about nodes based on the specified filters."""
|
|
814
|
-
with self.
|
|
726
|
+
with self.session():
|
|
815
727
|
self._check_and_tag_offline_nodes()
|
|
816
728
|
|
|
817
729
|
# Build the WHERE clause based on provided filters
|
|
818
730
|
conditions = []
|
|
819
|
-
params:
|
|
731
|
+
params: dict[str, Any] = {}
|
|
820
732
|
if node_ids is not None:
|
|
821
733
|
sint64_node_ids = [uint64_to_int64(node_id) for node_id in node_ids]
|
|
822
|
-
placeholders = ",".join(
|
|
734
|
+
placeholders = ",".join(
|
|
735
|
+
[f":nid_{i}" for i in range(len(sint64_node_ids))]
|
|
736
|
+
)
|
|
823
737
|
conditions.append(f"node_id IN ({placeholders})")
|
|
824
|
-
|
|
738
|
+
for i, nid in enumerate(sint64_node_ids):
|
|
739
|
+
params[f"nid_{i}"] = nid
|
|
825
740
|
if owner_aids is not None:
|
|
826
|
-
placeholders = ",".join(["
|
|
741
|
+
placeholders = ",".join([f":aid_{i}" for i in range(len(owner_aids))])
|
|
827
742
|
conditions.append(f"owner_aid IN ({placeholders})")
|
|
828
|
-
|
|
743
|
+
for i, aid in enumerate(owner_aids):
|
|
744
|
+
params[f"aid_{i}"] = aid
|
|
829
745
|
if statuses is not None:
|
|
830
|
-
placeholders = ",".join(["
|
|
746
|
+
placeholders = ",".join([f":st_{i}" for i in range(len(statuses))])
|
|
831
747
|
conditions.append(f"status IN ({placeholders})")
|
|
832
|
-
|
|
748
|
+
for i, status in enumerate(statuses):
|
|
749
|
+
params[f"st_{i}"] = status
|
|
833
750
|
|
|
834
751
|
# Construct the final query
|
|
835
752
|
query = "SELECT * FROM node"
|
|
836
753
|
if conditions:
|
|
837
754
|
query += " WHERE " + " AND ".join(conditions)
|
|
838
755
|
|
|
839
|
-
rows = self.
|
|
756
|
+
rows = self.query(query, params)
|
|
840
757
|
|
|
841
758
|
result: list[NodeInfo] = []
|
|
842
759
|
for row in rows:
|
|
@@ -846,27 +763,14 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
846
763
|
|
|
847
764
|
return result
|
|
848
765
|
|
|
849
|
-
def get_node_public_key(self, node_id: int) -> bytes:
|
|
850
|
-
"""Get `public_key` for the specified `node_id`."""
|
|
851
|
-
# Convert the uint64 value to sint64 for SQLite
|
|
852
|
-
sint64_node_id = uint64_to_int64(node_id)
|
|
853
|
-
|
|
854
|
-
# Query the public key for the given node_id
|
|
855
|
-
query = "SELECT public_key FROM node WHERE node_id = ? AND status != ?;"
|
|
856
|
-
rows = self.query(query, (sint64_node_id, NodeStatus.UNREGISTERED))
|
|
857
|
-
|
|
858
|
-
# If no result is found, return None
|
|
859
|
-
if not rows:
|
|
860
|
-
raise ValueError(f"Node ID {node_id} not found")
|
|
861
|
-
|
|
862
|
-
# Return the public key
|
|
863
|
-
return cast(bytes, rows[0]["public_key"])
|
|
864
|
-
|
|
865
766
|
def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
|
|
866
767
|
"""Get `node_id` for the specified `public_key` if it exists and is not
|
|
867
768
|
deleted."""
|
|
868
|
-
query = "SELECT node_id FROM node
|
|
869
|
-
|
|
769
|
+
query = """SELECT node_id FROM node
|
|
770
|
+
WHERE public_key = :public_key AND status != :unregistered;"""
|
|
771
|
+
rows = self.query(
|
|
772
|
+
query, {"public_key": public_key, "unregistered": NodeStatus.UNREGISTERED}
|
|
773
|
+
)
|
|
870
774
|
|
|
871
775
|
# If no result is found, return None
|
|
872
776
|
if not rows:
|
|
@@ -876,8 +780,7 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
876
780
|
node_id = int64_to_uint64(rows[0]["node_id"])
|
|
877
781
|
return node_id
|
|
878
782
|
|
|
879
|
-
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
880
|
-
def create_run(
|
|
783
|
+
def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
881
784
|
self,
|
|
882
785
|
fab_id: str | None,
|
|
883
786
|
fab_version: str | None,
|
|
@@ -894,38 +797,43 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
894
797
|
# Convert the uint64 value to sint64 for SQLite
|
|
895
798
|
sint64_run_id = uint64_to_int64(uint64_run_id)
|
|
896
799
|
|
|
897
|
-
with self.
|
|
800
|
+
with self.session():
|
|
898
801
|
# Check conflicts
|
|
899
|
-
query = "SELECT COUNT(*) FROM run WHERE run_id =
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
if row["COUNT(*)"] == 0:
|
|
802
|
+
query = "SELECT COUNT(*) as cnt FROM run WHERE run_id = :run_id"
|
|
803
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
804
|
+
if rows[0]["cnt"] == 0:
|
|
903
805
|
query = """
|
|
904
806
|
INSERT INTO run
|
|
905
807
|
(run_id, fab_id, fab_version,
|
|
906
808
|
fab_hash, override_config, federation, federation_options,
|
|
907
809
|
pending_at, starting_at, running_at, finished_at, sub_status,
|
|
908
|
-
details, flwr_aid)
|
|
909
|
-
VALUES (
|
|
810
|
+
details, flwr_aid, bytes_sent, bytes_recv, clientapp_runtime)
|
|
811
|
+
VALUES (:run_id, :fab_id, :fab_version, :fab_hash, :override_config,
|
|
812
|
+
:federation, :federation_options, :pending_at, :starting_at,
|
|
813
|
+
:running_at, :finished_at, :sub_status, :details, :flwr_aid,
|
|
814
|
+
:bytes_sent, :bytes_recv, :clientapp_runtime)
|
|
910
815
|
"""
|
|
911
816
|
override_config_json = json.dumps(override_config)
|
|
912
|
-
|
|
913
|
-
sint64_run_id,
|
|
914
|
-
fab_id
|
|
915
|
-
fab_version
|
|
916
|
-
fab_hash
|
|
917
|
-
override_config_json,
|
|
918
|
-
federation
|
|
919
|
-
configrecord_to_bytes(federation_options),
|
|
920
|
-
now().isoformat(),
|
|
921
|
-
"",
|
|
922
|
-
"",
|
|
923
|
-
"",
|
|
924
|
-
"",
|
|
925
|
-
"",
|
|
926
|
-
flwr_aid or "",
|
|
927
|
-
|
|
928
|
-
|
|
817
|
+
params = {
|
|
818
|
+
"run_id": sint64_run_id,
|
|
819
|
+
"fab_id": fab_id or "",
|
|
820
|
+
"fab_version": fab_version or "",
|
|
821
|
+
"fab_hash": fab_hash or "",
|
|
822
|
+
"override_config": override_config_json,
|
|
823
|
+
"federation": federation,
|
|
824
|
+
"federation_options": configrecord_to_bytes(federation_options),
|
|
825
|
+
"pending_at": now().isoformat(),
|
|
826
|
+
"starting_at": "",
|
|
827
|
+
"running_at": "",
|
|
828
|
+
"finished_at": "",
|
|
829
|
+
"sub_status": "",
|
|
830
|
+
"details": "",
|
|
831
|
+
"flwr_aid": flwr_aid or "",
|
|
832
|
+
"bytes_sent": 0,
|
|
833
|
+
"bytes_recv": 0,
|
|
834
|
+
"clientapp_runtime": 0.0,
|
|
835
|
+
}
|
|
836
|
+
self.query(query, params)
|
|
929
837
|
return uint64_run_id
|
|
930
838
|
log(ERROR, "Unexpected run creation failure.")
|
|
931
839
|
return 0
|
|
@@ -937,11 +845,11 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
937
845
|
"""
|
|
938
846
|
if flwr_aid:
|
|
939
847
|
rows = self.query(
|
|
940
|
-
"SELECT run_id FROM run WHERE flwr_aid =
|
|
941
|
-
|
|
848
|
+
"SELECT run_id FROM run WHERE flwr_aid = :flwr_aid",
|
|
849
|
+
{"flwr_aid": flwr_aid},
|
|
942
850
|
)
|
|
943
851
|
else:
|
|
944
|
-
rows = self.query("SELECT run_id FROM run
|
|
852
|
+
rows = self.query("SELECT run_id FROM run", {})
|
|
945
853
|
return {int64_to_uint64(row["run_id"]) for row in rows}
|
|
946
854
|
|
|
947
855
|
def get_run(self, run_id: int) -> Run | None:
|
|
@@ -951,8 +859,8 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
951
859
|
|
|
952
860
|
# Convert the uint64 value to sint64 for SQLite
|
|
953
861
|
sint64_run_id = uint64_to_int64(run_id)
|
|
954
|
-
query = "SELECT * FROM run WHERE run_id =
|
|
955
|
-
rows = self.query(query,
|
|
862
|
+
query = "SELECT * FROM run WHERE run_id = :run_id"
|
|
863
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
956
864
|
if rows:
|
|
957
865
|
row = rows[0]
|
|
958
866
|
return Run(
|
|
@@ -972,6 +880,9 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
972
880
|
),
|
|
973
881
|
flwr_aid=row["flwr_aid"],
|
|
974
882
|
federation=row["federation"],
|
|
883
|
+
bytes_sent=row["bytes_sent"],
|
|
884
|
+
bytes_recv=row["bytes_recv"],
|
|
885
|
+
clientapp_runtime=row["clientapp_runtime"],
|
|
975
886
|
)
|
|
976
887
|
log(ERROR, "`run_id` does not exist.")
|
|
977
888
|
return None
|
|
@@ -982,9 +893,10 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
982
893
|
self._cleanup_expired_tokens()
|
|
983
894
|
|
|
984
895
|
# Convert the uint64 value to sint64 for SQLite
|
|
985
|
-
|
|
986
|
-
query = f"SELECT * FROM run WHERE run_id IN ({
|
|
987
|
-
|
|
896
|
+
placeholders = ",".join([f":rid_{i}" for i in range(len(run_ids))])
|
|
897
|
+
query = f"SELECT * FROM run WHERE run_id IN ({placeholders})"
|
|
898
|
+
params = {f"rid_{i}": uint64_to_int64(rid) for i, rid in enumerate(run_ids)}
|
|
899
|
+
rows = self.query(query, params)
|
|
988
900
|
|
|
989
901
|
return {
|
|
990
902
|
# Restore uint64 run IDs
|
|
@@ -1001,11 +913,11 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1001
913
|
# Clean up expired tokens; this will flag inactive runs as needed
|
|
1002
914
|
self._cleanup_expired_tokens()
|
|
1003
915
|
|
|
1004
|
-
with self.
|
|
916
|
+
with self.session():
|
|
1005
917
|
# Convert the uint64 value to sint64 for SQLite
|
|
1006
918
|
sint64_run_id = uint64_to_int64(run_id)
|
|
1007
|
-
query = "SELECT * FROM run WHERE run_id =
|
|
1008
|
-
rows = self.
|
|
919
|
+
query = "SELECT * FROM run WHERE run_id = :run_id"
|
|
920
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
1009
921
|
|
|
1010
922
|
# Check if the run_id exists
|
|
1011
923
|
if not rows:
|
|
@@ -1040,7 +952,9 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1040
952
|
|
|
1041
953
|
# Update the status
|
|
1042
954
|
query = """
|
|
1043
|
-
UPDATE run SET %s
|
|
955
|
+
UPDATE run SET %s = :timestamp,
|
|
956
|
+
sub_status = :sub_status, details = :details
|
|
957
|
+
WHERE run_id = :run_id
|
|
1044
958
|
"""
|
|
1045
959
|
|
|
1046
960
|
# Prepare data for query
|
|
@@ -1055,33 +969,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1055
969
|
elif new_status.status == Status.FINISHED:
|
|
1056
970
|
timestamp_fld = "finished_at"
|
|
1057
971
|
|
|
1058
|
-
|
|
1059
|
-
current.isoformat(),
|
|
1060
|
-
new_status.sub_status,
|
|
1061
|
-
new_status.details,
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
self.
|
|
972
|
+
params = {
|
|
973
|
+
"timestamp": current.isoformat(),
|
|
974
|
+
"sub_status": new_status.sub_status,
|
|
975
|
+
"details": new_status.details,
|
|
976
|
+
"run_id": sint64_run_id,
|
|
977
|
+
}
|
|
978
|
+
self.query(query % timestamp_fld, params)
|
|
1065
979
|
return True
|
|
1066
980
|
|
|
1067
981
|
def get_pending_run_id(self) -> int | None:
|
|
1068
|
-
"""Get the `run_id` of a run with `Status.PENDING` status
|
|
1069
|
-
pending_run_id = None
|
|
1070
|
-
|
|
982
|
+
"""Get the `run_id` of a run with `Status.PENDING` status."""
|
|
1071
983
|
# Fetch all runs with unset `starting_at` (i.e. they are in PENDING status)
|
|
1072
|
-
query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1
|
|
1073
|
-
rows = self.query(query)
|
|
984
|
+
query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1"
|
|
985
|
+
rows = self.query(query, {})
|
|
1074
986
|
if rows:
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
return pending_run_id
|
|
987
|
+
return int64_to_uint64(rows[0]["run_id"])
|
|
988
|
+
return None
|
|
1078
989
|
|
|
1079
990
|
def get_federation_options(self, run_id: int) -> ConfigRecord | None:
|
|
1080
991
|
"""Retrieve the federation options for the specified `run_id`."""
|
|
1081
992
|
# Convert the uint64 value to sint64 for SQLite
|
|
1082
993
|
sint64_run_id = uint64_to_int64(run_id)
|
|
1083
|
-
query = "SELECT federation_options FROM run WHERE run_id =
|
|
1084
|
-
rows = self.query(query,
|
|
994
|
+
query = "SELECT federation_options FROM run WHERE run_id = :run_id"
|
|
995
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
1085
996
|
|
|
1086
997
|
# Check if the run_id exists
|
|
1087
998
|
if not rows:
|
|
@@ -1101,41 +1012,46 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1101
1012
|
HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
|
|
1102
1013
|
the node is marked as offline.
|
|
1103
1014
|
"""
|
|
1104
|
-
if self.conn is None:
|
|
1105
|
-
raise AttributeError("LinkState not initialized")
|
|
1106
|
-
|
|
1107
1015
|
sint64_node_id = uint64_to_int64(node_id)
|
|
1108
1016
|
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1017
|
+
# Check if the node exists and is not unregistered
|
|
1018
|
+
query = """
|
|
1019
|
+
SELECT status FROM node WHERE node_id = :node_id AND status != :unregistered
|
|
1020
|
+
"""
|
|
1021
|
+
rows = self.query(
|
|
1022
|
+
query, {"node_id": sint64_node_id, "unregistered": NodeStatus.UNREGISTERED}
|
|
1023
|
+
)
|
|
1024
|
+
if not rows:
|
|
1025
|
+
return False
|
|
1117
1026
|
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1027
|
+
# Construct query and params
|
|
1028
|
+
current_dt = now()
|
|
1029
|
+
query = (
|
|
1030
|
+
"UPDATE node SET online_until = :online_until, "
|
|
1031
|
+
"heartbeat_interval = :heartbeat_interval"
|
|
1032
|
+
)
|
|
1033
|
+
params: dict[str, Any] = {
|
|
1034
|
+
"online_until": current_dt.timestamp()
|
|
1035
|
+
+ HEARTBEAT_PATIENCE * heartbeat_interval,
|
|
1036
|
+
"heartbeat_interval": heartbeat_interval,
|
|
1037
|
+
}
|
|
1125
1038
|
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1039
|
+
# Set timestamp if the status changes
|
|
1040
|
+
if rows[0]["status"] != NodeStatus.ONLINE:
|
|
1041
|
+
query += ", status = :online, last_activated_at = :last_activated_at"
|
|
1042
|
+
params["online"] = NodeStatus.ONLINE
|
|
1043
|
+
params["last_activated_at"] = current_dt.isoformat()
|
|
1130
1044
|
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1045
|
+
# Execute the query, refreshing `online_until` and `heartbeat_interval`
|
|
1046
|
+
query += " WHERE node_id = :node_id"
|
|
1047
|
+
params["node_id"] = sint64_node_id
|
|
1048
|
+
self.query(query, params)
|
|
1049
|
+
return True
|
|
1136
1050
|
|
|
1137
1051
|
def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
|
|
1138
|
-
"""
|
|
1052
|
+
"""Handle cleanup of expired tokens.
|
|
1053
|
+
|
|
1054
|
+
Override in subclasses to add custom cleanup logic.
|
|
1139
1055
|
|
|
1140
1056
|
Parameters
|
|
1141
1057
|
----------
|
|
@@ -1146,28 +1062,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1146
1062
|
if not expired_records:
|
|
1147
1063
|
return
|
|
1148
1064
|
|
|
1149
|
-
with self.
|
|
1065
|
+
with self.session():
|
|
1150
1066
|
query = """
|
|
1151
1067
|
UPDATE run
|
|
1152
|
-
SET sub_status =
|
|
1153
|
-
WHERE run_id =
|
|
1068
|
+
SET sub_status = :failed, details = :details, finished_at = :finished_at
|
|
1069
|
+
WHERE run_id = :run_id
|
|
1154
1070
|
"""
|
|
1155
1071
|
data = [
|
|
1156
|
-
|
|
1157
|
-
SubStatus.FAILED,
|
|
1158
|
-
RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
1159
|
-
datetime.fromtimestamp(
|
|
1160
|
-
|
|
1161
|
-
|
|
1072
|
+
{
|
|
1073
|
+
"failed": SubStatus.FAILED,
|
|
1074
|
+
"details": RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
1075
|
+
"finished_at": datetime.fromtimestamp(
|
|
1076
|
+
active_until, tz=timezone.utc
|
|
1077
|
+
).isoformat(),
|
|
1078
|
+
"run_id": uint64_to_int64(run_id),
|
|
1079
|
+
}
|
|
1162
1080
|
for run_id, active_until in expired_records
|
|
1163
1081
|
]
|
|
1164
|
-
self.
|
|
1082
|
+
self.query(query, data)
|
|
1165
1083
|
|
|
1166
1084
|
def get_serverapp_context(self, run_id: int) -> Context | None:
|
|
1167
1085
|
"""Get the context for the specified `run_id`."""
|
|
1168
1086
|
# Retrieve context if any
|
|
1169
|
-
query = "SELECT context FROM context WHERE run_id =
|
|
1170
|
-
rows = self.query(query,
|
|
1087
|
+
query = "SELECT context FROM context WHERE run_id = :run_id"
|
|
1088
|
+
rows = self.query(query, {"run_id": uint64_to_int64(run_id)})
|
|
1171
1089
|
context = context_from_bytes(rows[0]["context"]) if rows else None
|
|
1172
1090
|
return context
|
|
1173
1091
|
|
|
@@ -1177,20 +1095,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1177
1095
|
context_bytes = context_to_bytes(context)
|
|
1178
1096
|
sint_run_id = uint64_to_int64(run_id)
|
|
1179
1097
|
|
|
1180
|
-
with self.
|
|
1098
|
+
with self.session():
|
|
1181
1099
|
# Check if any existing Context assigned to the run_id
|
|
1182
|
-
query = "SELECT COUNT(*) FROM context WHERE run_id =
|
|
1183
|
-
row = self.
|
|
1184
|
-
if row["
|
|
1100
|
+
query = "SELECT COUNT(*) as count FROM context WHERE run_id = :run_id"
|
|
1101
|
+
row = self.query(query, {"run_id": sint_run_id})[0]
|
|
1102
|
+
if row["count"] > 0:
|
|
1185
1103
|
# Update context
|
|
1186
|
-
query = "
|
|
1187
|
-
|
|
1104
|
+
query = """
|
|
1105
|
+
UPDATE context
|
|
1106
|
+
SET context = :context_bytes WHERE run_id = :run_id
|
|
1107
|
+
"""
|
|
1108
|
+
self.query(
|
|
1109
|
+
query, {"context_bytes": context_bytes, "run_id": sint_run_id}
|
|
1110
|
+
)
|
|
1188
1111
|
else:
|
|
1189
1112
|
try:
|
|
1190
1113
|
# Store context
|
|
1191
|
-
query =
|
|
1192
|
-
|
|
1193
|
-
|
|
1114
|
+
query = (
|
|
1115
|
+
"INSERT INTO context (run_id, context) "
|
|
1116
|
+
"VALUES (:run_id, :context_bytes)"
|
|
1117
|
+
)
|
|
1118
|
+
self.query(
|
|
1119
|
+
query, {"run_id": sint_run_id, "context_bytes": context_bytes}
|
|
1120
|
+
)
|
|
1121
|
+
except IntegrityError:
|
|
1194
1122
|
raise ValueError(f"Run {run_id} not found") from None
|
|
1195
1123
|
|
|
1196
1124
|
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
|
|
@@ -1201,10 +1129,19 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1201
1129
|
# Store log
|
|
1202
1130
|
try:
|
|
1203
1131
|
query = """
|
|
1204
|
-
INSERT INTO logs (timestamp, run_id, node_id, log)
|
|
1132
|
+
INSERT INTO logs (timestamp, run_id, node_id, log)
|
|
1133
|
+
VALUES (:current_ts, :run_id, :node_id, :log)
|
|
1205
1134
|
"""
|
|
1206
|
-
self.query(
|
|
1207
|
-
|
|
1135
|
+
self.query(
|
|
1136
|
+
query,
|
|
1137
|
+
{
|
|
1138
|
+
"current_ts": now().timestamp(),
|
|
1139
|
+
"run_id": sint64_run_id,
|
|
1140
|
+
"node_id": 0,
|
|
1141
|
+
"log": log_message,
|
|
1142
|
+
},
|
|
1143
|
+
)
|
|
1144
|
+
except IntegrityError:
|
|
1208
1145
|
raise ValueError(f"Run {run_id} not found") from None
|
|
1209
1146
|
|
|
1210
1147
|
def get_serverapp_log(
|
|
@@ -1214,10 +1151,10 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1214
1151
|
# Convert the uint64 value to sint64 for SQLite
|
|
1215
1152
|
sint64_run_id = uint64_to_int64(run_id)
|
|
1216
1153
|
|
|
1217
|
-
with self.
|
|
1154
|
+
with self.session():
|
|
1218
1155
|
# Check if the run_id exists
|
|
1219
|
-
query = "SELECT run_id FROM run WHERE run_id =
|
|
1220
|
-
rows = self.
|
|
1156
|
+
query = "SELECT run_id FROM run WHERE run_id = :run_id"
|
|
1157
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
1221
1158
|
if not rows:
|
|
1222
1159
|
raise ValueError(f"Run {run_id} not found")
|
|
1223
1160
|
|
|
@@ -1226,12 +1163,18 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1226
1163
|
after_timestamp = 0.0
|
|
1227
1164
|
query = """
|
|
1228
1165
|
SELECT log, timestamp FROM logs
|
|
1229
|
-
WHERE run_id =
|
|
1166
|
+
WHERE run_id = :run_id AND node_id = :node_id
|
|
1167
|
+
AND timestamp > :after_timestamp
|
|
1168
|
+
ORDER BY timestamp
|
|
1230
1169
|
"""
|
|
1231
|
-
rows = self.
|
|
1232
|
-
query,
|
|
1233
|
-
|
|
1234
|
-
|
|
1170
|
+
rows = self.query(
|
|
1171
|
+
query,
|
|
1172
|
+
{
|
|
1173
|
+
"run_id": sint64_run_id,
|
|
1174
|
+
"node_id": 0,
|
|
1175
|
+
"after_timestamp": after_timestamp,
|
|
1176
|
+
},
|
|
1177
|
+
)
|
|
1235
1178
|
latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
|
|
1236
1179
|
return "".join(row["log"] for row in rows), latest_timestamp
|
|
1237
1180
|
|
|
@@ -1240,62 +1183,74 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1240
1183
|
|
|
1241
1184
|
Return Message if valid.
|
|
1242
1185
|
"""
|
|
1243
|
-
with self.
|
|
1186
|
+
with self.session():
|
|
1244
1187
|
self._check_stored_messages({message_id})
|
|
1245
1188
|
query = """
|
|
1246
1189
|
SELECT *
|
|
1247
1190
|
FROM message_ins
|
|
1248
1191
|
WHERE message_id = :message_id
|
|
1249
1192
|
"""
|
|
1250
|
-
|
|
1251
|
-
rows: list[dict[str, Any]] = self.conn.execute(query, data).fetchall()
|
|
1193
|
+
rows = self.query(query, {"message_id": message_id})
|
|
1252
1194
|
if not rows:
|
|
1253
1195
|
# Message does not exist
|
|
1254
1196
|
return None
|
|
1255
1197
|
|
|
1256
|
-
return rows[0]
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
|
|
1268
|
-
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
)
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1198
|
+
return dict(rows[0])
|
|
1199
|
+
|
|
1200
|
+
def store_traffic(self, run_id: int, *, bytes_sent: int, bytes_recv: int) -> None:
|
|
1201
|
+
"""Store traffic data for the specified `run_id`."""
|
|
1202
|
+
# Validate non-negative values
|
|
1203
|
+
if bytes_sent < 0 or bytes_recv < 0:
|
|
1204
|
+
raise ValueError(
|
|
1205
|
+
f"Negative traffic values for run {run_id}: "
|
|
1206
|
+
f"bytes_sent={bytes_sent}, bytes_recv={bytes_recv}"
|
|
1207
|
+
)
|
|
1208
|
+
|
|
1209
|
+
if bytes_sent == 0 and bytes_recv == 0:
|
|
1210
|
+
raise ValueError(
|
|
1211
|
+
f"Both bytes_sent and bytes_recv cannot be zero for run {run_id}"
|
|
1212
|
+
)
|
|
1213
|
+
|
|
1214
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
1215
|
+
|
|
1216
|
+
with self.session():
|
|
1217
|
+
# Check if run exists, performing the update only if it does
|
|
1218
|
+
update_query = """
|
|
1219
|
+
UPDATE run
|
|
1220
|
+
SET bytes_sent = bytes_sent + :bytes_sent,
|
|
1221
|
+
bytes_recv = bytes_recv + :bytes_recv
|
|
1222
|
+
WHERE run_id = :run_id
|
|
1223
|
+
RETURNING run_id
|
|
1224
|
+
"""
|
|
1225
|
+
rows = self.query(
|
|
1226
|
+
update_query,
|
|
1227
|
+
{
|
|
1228
|
+
"bytes_sent": bytes_sent,
|
|
1229
|
+
"bytes_recv": bytes_recv,
|
|
1230
|
+
"run_id": sint64_run_id,
|
|
1231
|
+
},
|
|
1232
|
+
)
|
|
1233
|
+
|
|
1234
|
+
if not rows:
|
|
1235
|
+
raise ValueError(f"Run {run_id} not found")
|
|
1236
|
+
|
|
1237
|
+
def add_clientapp_runtime(self, run_id: int, runtime: float) -> None:
|
|
1238
|
+
"""Add ClientApp runtime to the cumulative total for the specified `run_id`."""
|
|
1239
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
1240
|
+
with self.session():
|
|
1241
|
+
# Check if run exists, performing the update only if it does
|
|
1242
|
+
update_query = """
|
|
1243
|
+
UPDATE run
|
|
1244
|
+
SET clientapp_runtime = clientapp_runtime + :runtime
|
|
1245
|
+
WHERE run_id = :run_id
|
|
1246
|
+
RETURNING run_id
|
|
1247
|
+
"""
|
|
1248
|
+
rows = self.query(
|
|
1249
|
+
update_query, {"runtime": runtime, "run_id": sint64_run_id}
|
|
1250
|
+
)
|
|
1251
|
+
|
|
1252
|
+
if not rows:
|
|
1253
|
+
raise ValueError(f"Run {run_id} not found")
|
|
1299
1254
|
|
|
1300
1255
|
|
|
1301
1256
|
def determine_run_status(row: dict[str, Any]) -> str:
|
|
@@ -1309,4 +1264,4 @@ def determine_run_status(row: dict[str, Any]) -> str:
|
|
|
1309
1264
|
return Status.STARTING
|
|
1310
1265
|
return Status.PENDING
|
|
1311
1266
|
run_id = int64_to_uint64(row["run_id"])
|
|
1312
|
-
raise
|
|
1267
|
+
raise ValueError(f"The run {run_id} does not have a valid status.")
|