flwr 1.25.0__py3-none-any.whl → 1.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +1 -1
- flwr/app/__init__.py +4 -1
- flwr/app/message_type.py +29 -0
- flwr/app/metadata.py +5 -2
- flwr/app/user_config.py +19 -0
- flwr/cli/app.py +37 -19
- flwr/cli/app_cmd/publish.py +25 -75
- flwr/cli/app_cmd/review.py +18 -69
- flwr/cli/auth_plugin/auth_plugin.py +5 -10
- flwr/cli/auth_plugin/noop_auth_plugin.py +1 -2
- flwr/cli/auth_plugin/oidc_cli_plugin.py +38 -38
- flwr/cli/build.py +15 -28
- flwr/cli/config/__init__.py +21 -0
- flwr/cli/config/ls.py +71 -0
- flwr/cli/config_migration.py +297 -0
- flwr/cli/config_utils.py +63 -156
- flwr/cli/constant.py +71 -0
- flwr/cli/federation/__init__.py +0 -2
- flwr/cli/federation/ls.py +256 -64
- flwr/cli/flower_config.py +429 -0
- flwr/cli/install.py +23 -62
- flwr/cli/log.py +23 -37
- flwr/cli/login/login.py +29 -63
- flwr/cli/ls.py +28 -58
- flwr/cli/new/new.py +9 -29
- flwr/cli/pull.py +19 -37
- flwr/cli/run/run.py +85 -93
- flwr/cli/run_utils.py +1 -1
- flwr/cli/stop.py +32 -73
- flwr/cli/supernode/ls.py +25 -57
- flwr/cli/supernode/register.py +31 -80
- flwr/cli/supernode/unregister.py +24 -70
- flwr/cli/typing.py +200 -0
- flwr/cli/utils.py +160 -275
- flwr/client/grpc_rere_client/connection.py +3 -3
- flwr/client/grpc_rere_client/grpc_adapter.py +1 -1
- flwr/client/message_handler/message_handler.py +2 -1
- flwr/client/mod/centraldp_mods.py +1 -1
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/run_info_store.py +2 -1
- flwr/clientapp/client_app.py +2 -1
- flwr/common/__init__.py +3 -2
- flwr/common/args.py +5 -5
- flwr/common/config.py +12 -17
- flwr/common/constant.py +3 -16
- flwr/common/context.py +2 -1
- flwr/common/exit/exit.py +4 -4
- flwr/common/exit/exit_code.py +6 -0
- flwr/common/grpc.py +2 -1
- flwr/common/logger.py +1 -1
- flwr/common/message.py +1 -1
- flwr/common/retry_invoker.py +13 -5
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -2
- flwr/common/serde.py +7 -5
- flwr/common/telemetry.py +1 -1
- flwr/common/typing.py +4 -3
- flwr/compat/client/app.py +6 -9
- flwr/compat/client/grpc_client/connection.py +2 -1
- flwr/compat/common/constant.py +29 -0
- flwr/compat/server/app.py +1 -1
- flwr/proto/clientappio_pb2.py +2 -2
- flwr/proto/clientappio_pb2_grpc.py +104 -88
- flwr/proto/clientappio_pb2_grpc.pyi +140 -80
- flwr/proto/federation_pb2.py +5 -3
- flwr/proto/federation_pb2.pyi +32 -2
- flwr/proto/run_pb2.py +5 -13
- flwr/proto/run_pb2.pyi +0 -57
- flwr/proto/serverappio_pb2.py +2 -2
- flwr/proto/serverappio_pb2_grpc.py +138 -207
- flwr/proto/serverappio_pb2_grpc.pyi +189 -155
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +62 -90
- flwr/proto/simulationio_pb2_grpc.pyi +95 -55
- flwr/server/app.py +6 -13
- flwr/server/compat/grid_client_proxy.py +2 -1
- flwr/server/grid/grpc_grid.py +5 -5
- flwr/server/serverapp/app.py +11 -4
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +13 -12
- flwr/server/superlink/fleet/message_handler/message_handler.py +6 -5
- flwr/server/superlink/linkstate/__init__.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +2 -10
- flwr/server/superlink/linkstate/linkstate.py +2 -21
- flwr/server/superlink/linkstate/linkstate_factory.py +16 -8
- flwr/server/superlink/linkstate/{sqlite_linkstate.py → sql_linkstate.py} +432 -534
- flwr/server/superlink/linkstate/utils.py +49 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +1 -33
- flwr/server/superlink/simulation/simulationio_servicer.py +0 -19
- flwr/server/utils/validator.py +1 -1
- flwr/server/workflow/default_workflows.py +2 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -1
- flwr/serverapp/strategy/bulyan.py +7 -1
- flwr/serverapp/strategy/dp_fixed_clipping.py +9 -1
- flwr/serverapp/strategy/fedavg.py +1 -1
- flwr/serverapp/strategy/fedxgb_cyclic.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -6
- flwr/simulation/run_simulation.py +3 -12
- flwr/simulation/simulationio_connection.py +3 -3
- flwr/{common → supercore}/address.py +7 -33
- flwr/supercore/app_utils.py +2 -1
- flwr/supercore/constant.py +24 -2
- flwr/supercore/corestate/{sqlite_corestate.py → sql_corestate.py} +19 -23
- flwr/supercore/credential_store/__init__.py +33 -0
- flwr/supercore/credential_store/credential_store.py +34 -0
- flwr/supercore/credential_store/file_credential_store.py +76 -0
- flwr/{common → supercore}/date.py +0 -11
- flwr/supercore/ffs/disk_ffs.py +1 -1
- flwr/supercore/object_store/object_store_factory.py +14 -6
- flwr/supercore/object_store/{sqlite_object_store.py → sql_object_store.py} +115 -117
- flwr/supercore/sql_mixin.py +315 -0
- flwr/supercore/state/__init__.py +15 -0
- flwr/supercore/state/alembic/__init__.py +15 -0
- flwr/supercore/state/alembic/env.py +103 -0
- flwr/supercore/state/alembic/script.py.mako +43 -0
- flwr/supercore/state/alembic/utils.py +239 -0
- flwr/supercore/state/alembic/versions/__init__.py +15 -0
- flwr/supercore/state/alembic/versions/rev_2026_01_28_initialize_migration_of_state_tables.py +200 -0
- flwr/supercore/state/schema/README.md +121 -0
- flwr/supercore/state/schema/__init__.py +15 -0
- flwr/supercore/state/schema/corestate_tables.py +36 -0
- flwr/supercore/state/schema/linkstate_tables.py +152 -0
- flwr/supercore/state/schema/objectstore_tables.py +90 -0
- flwr/supercore/superexec/run_superexec.py +2 -2
- flwr/supercore/utils.py +36 -1
- flwr/superlink/federation/federation_manager.py +2 -2
- flwr/superlink/federation/noop_federation_manager.py +8 -6
- flwr/superlink/servicer/control/control_servicer.py +19 -17
- flwr/supernode/cli/flower_supernode.py +2 -1
- flwr/supernode/runtime/run_clientapp.py +14 -14
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +10 -8
- flwr/supernode/start_client_internal.py +10 -6
- {flwr-1.25.0.dist-info → flwr-1.26.0.dist-info}/METADATA +7 -5
- {flwr-1.25.0.dist-info → flwr-1.26.0.dist-info}/RECORD +137 -116
- flwr/cli/federation/show.py +0 -318
- flwr/common/pyproject.py +0 -42
- flwr/supercore/sqlite_mixin.py +0 -159
- /flwr/{common → supercore}/version.py +0 -0
- {flwr-1.25.0.dist-info → flwr-1.26.0.dist-info}/WHEEL +0 -0
- {flwr-1.25.0.dist-info → flwr-1.26.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2026 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,19 +12,22 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""SQLAlchemy-based implementation of the link state."""
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
# pylint: disable=too-many-lines
|
|
19
19
|
|
|
20
20
|
import json
|
|
21
|
-
import sqlite3
|
|
22
21
|
from collections.abc import Sequence
|
|
23
22
|
from datetime import datetime, timezone
|
|
24
23
|
from logging import ERROR, WARNING
|
|
25
|
-
from typing import Any
|
|
24
|
+
from typing import Any
|
|
26
25
|
|
|
27
|
-
from
|
|
26
|
+
from sqlalchemy import MetaData
|
|
27
|
+
from sqlalchemy.exc import IntegrityError
|
|
28
|
+
|
|
29
|
+
from flwr.app.user_config import UserConfig
|
|
30
|
+
from flwr.common import Context, Message, log, now
|
|
28
31
|
from flwr.common.constant import (
|
|
29
32
|
HEARTBEAT_PATIENCE,
|
|
30
33
|
MESSAGE_TTL_TOLERANCE,
|
|
@@ -35,22 +38,15 @@ from flwr.common.constant import (
|
|
|
35
38
|
Status,
|
|
36
39
|
SubStatus,
|
|
37
40
|
)
|
|
38
|
-
from flwr.common.message import make_message
|
|
39
41
|
from flwr.common.record import ConfigRecord
|
|
40
|
-
from flwr.common.
|
|
41
|
-
from flwr.
|
|
42
|
-
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
43
|
-
|
|
44
|
-
# pylint: disable=E0611
|
|
45
|
-
from flwr.proto.error_pb2 import Error as ProtoError
|
|
46
|
-
from flwr.proto.node_pb2 import NodeInfo
|
|
47
|
-
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
|
48
|
-
|
|
49
|
-
# pylint: enable=E0611
|
|
42
|
+
from flwr.common.typing import Run, RunStatus
|
|
43
|
+
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
|
50
44
|
from flwr.server.utils.validator import validate_message
|
|
51
45
|
from flwr.supercore.constant import NodeStatus
|
|
52
|
-
from flwr.supercore.corestate.
|
|
46
|
+
from flwr.supercore.corestate.sql_corestate import SqlCoreState
|
|
53
47
|
from flwr.supercore.object_store.object_store import ObjectStore
|
|
48
|
+
from flwr.supercore.state.schema.corestate_tables import create_corestate_metadata
|
|
49
|
+
from flwr.supercore.state.schema.linkstate_tables import create_linkstate_metadata
|
|
54
50
|
from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
|
|
55
51
|
from flwr.superlink.federation import FederationManager
|
|
56
52
|
|
|
@@ -63,128 +59,18 @@ from .utils import (
|
|
|
63
59
|
context_to_bytes,
|
|
64
60
|
convert_sint64_values_in_dict_to_uint64,
|
|
65
61
|
convert_uint64_values_in_dict_to_sint64,
|
|
62
|
+
dict_to_message,
|
|
66
63
|
generate_rand_int_from_bytes,
|
|
67
64
|
has_valid_sub_status,
|
|
68
65
|
is_valid_transition,
|
|
66
|
+
message_to_dict,
|
|
69
67
|
verify_found_message_replies,
|
|
70
68
|
verify_message_ids,
|
|
71
69
|
)
|
|
72
70
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
owner_aid TEXT,
|
|
77
|
-
owner_name TEXT,
|
|
78
|
-
status TEXT,
|
|
79
|
-
registered_at TEXT,
|
|
80
|
-
last_activated_at TEXT NULL,
|
|
81
|
-
last_deactivated_at TEXT NULL,
|
|
82
|
-
unregistered_at TEXT NULL,
|
|
83
|
-
online_until TIMESTAMP NULL,
|
|
84
|
-
heartbeat_interval REAL,
|
|
85
|
-
public_key BLOB UNIQUE
|
|
86
|
-
);
|
|
87
|
-
"""
|
|
88
|
-
|
|
89
|
-
SQL_CREATE_TABLE_PUBLIC_KEY = """
|
|
90
|
-
CREATE TABLE IF NOT EXISTS public_key(
|
|
91
|
-
public_key BLOB PRIMARY KEY
|
|
92
|
-
);
|
|
93
|
-
"""
|
|
94
|
-
|
|
95
|
-
SQL_CREATE_INDEX_ONLINE_UNTIL = """
|
|
96
|
-
CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
|
97
|
-
"""
|
|
98
|
-
|
|
99
|
-
SQL_CREATE_INDEX_OWNER_AID = """
|
|
100
|
-
CREATE INDEX IF NOT EXISTS idx_node_owner_aid ON node(owner_aid);
|
|
101
|
-
"""
|
|
102
|
-
|
|
103
|
-
SQL_CREATE_INDEX_NODE_STATUS = """
|
|
104
|
-
CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
|
|
105
|
-
"""
|
|
106
|
-
|
|
107
|
-
SQL_CREATE_TABLE_RUN = """
|
|
108
|
-
CREATE TABLE IF NOT EXISTS run(
|
|
109
|
-
run_id INTEGER UNIQUE,
|
|
110
|
-
fab_id TEXT,
|
|
111
|
-
fab_version TEXT,
|
|
112
|
-
fab_hash TEXT,
|
|
113
|
-
override_config TEXT,
|
|
114
|
-
pending_at TEXT,
|
|
115
|
-
starting_at TEXT,
|
|
116
|
-
running_at TEXT,
|
|
117
|
-
finished_at TEXT,
|
|
118
|
-
sub_status TEXT,
|
|
119
|
-
details TEXT,
|
|
120
|
-
federation TEXT,
|
|
121
|
-
federation_options BLOB,
|
|
122
|
-
flwr_aid TEXT,
|
|
123
|
-
bytes_sent INTEGER DEFAULT 0,
|
|
124
|
-
bytes_recv INTEGER DEFAULT 0,
|
|
125
|
-
clientapp_runtime REAL DEFAULT 0.0
|
|
126
|
-
);
|
|
127
|
-
"""
|
|
128
|
-
|
|
129
|
-
SQL_CREATE_TABLE_LOGS = """
|
|
130
|
-
CREATE TABLE IF NOT EXISTS logs (
|
|
131
|
-
timestamp REAL,
|
|
132
|
-
run_id INTEGER,
|
|
133
|
-
node_id INTEGER,
|
|
134
|
-
log TEXT,
|
|
135
|
-
PRIMARY KEY (timestamp, run_id, node_id),
|
|
136
|
-
FOREIGN KEY (run_id) REFERENCES run(run_id)
|
|
137
|
-
);
|
|
138
|
-
"""
|
|
139
|
-
|
|
140
|
-
SQL_CREATE_TABLE_CONTEXT = """
|
|
141
|
-
CREATE TABLE IF NOT EXISTS context(
|
|
142
|
-
run_id INTEGER UNIQUE,
|
|
143
|
-
context BLOB,
|
|
144
|
-
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
145
|
-
);
|
|
146
|
-
"""
|
|
147
|
-
|
|
148
|
-
SQL_CREATE_TABLE_MESSAGE_INS = """
|
|
149
|
-
CREATE TABLE IF NOT EXISTS message_ins(
|
|
150
|
-
message_id TEXT UNIQUE,
|
|
151
|
-
group_id TEXT,
|
|
152
|
-
run_id INTEGER,
|
|
153
|
-
src_node_id INTEGER,
|
|
154
|
-
dst_node_id INTEGER,
|
|
155
|
-
reply_to_message_id TEXT,
|
|
156
|
-
created_at REAL,
|
|
157
|
-
delivered_at TEXT,
|
|
158
|
-
ttl REAL,
|
|
159
|
-
message_type TEXT,
|
|
160
|
-
content BLOB NULL,
|
|
161
|
-
error BLOB NULL,
|
|
162
|
-
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
163
|
-
);
|
|
164
|
-
"""
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
SQL_CREATE_TABLE_MESSAGE_RES = """
|
|
168
|
-
CREATE TABLE IF NOT EXISTS message_res(
|
|
169
|
-
message_id TEXT UNIQUE,
|
|
170
|
-
group_id TEXT,
|
|
171
|
-
run_id INTEGER,
|
|
172
|
-
src_node_id INTEGER,
|
|
173
|
-
dst_node_id INTEGER,
|
|
174
|
-
reply_to_message_id TEXT,
|
|
175
|
-
created_at REAL,
|
|
176
|
-
delivered_at TEXT,
|
|
177
|
-
ttl REAL,
|
|
178
|
-
message_type TEXT,
|
|
179
|
-
content BLOB NULL,
|
|
180
|
-
error BLOB NULL,
|
|
181
|
-
FOREIGN KEY(run_id) REFERENCES run(run_id)
|
|
182
|
-
);
|
|
183
|
-
"""
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
187
|
-
"""SQLite-based LinkState implementation."""
|
|
71
|
+
|
|
72
|
+
class SqlLinkState(LinkState, SqlCoreState): # pylint: disable=R0904
|
|
73
|
+
"""SQLAlchemy-based LinkState implementation."""
|
|
188
74
|
|
|
189
75
|
def __init__(
|
|
190
76
|
self,
|
|
@@ -196,24 +82,21 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
196
82
|
federation_manager.linkstate = self
|
|
197
83
|
self._federation_manager = federation_manager
|
|
198
84
|
|
|
199
|
-
def
|
|
200
|
-
"""Return
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
SQL_CREATE_INDEX_OWNER_AID,
|
|
211
|
-
SQL_CREATE_INDEX_NODE_STATUS,
|
|
212
|
-
)
|
|
85
|
+
def get_metadata(self) -> MetaData:
|
|
86
|
+
"""Return combined SQLAlchemy MetaData for LinkState and CoreState tables."""
|
|
87
|
+
# Start with linkstate tables
|
|
88
|
+
metadata = create_linkstate_metadata()
|
|
89
|
+
|
|
90
|
+
# Add corestate tables (token_store)
|
|
91
|
+
corestate_metadata = create_corestate_metadata()
|
|
92
|
+
for table in corestate_metadata.tables.values():
|
|
93
|
+
table.to_metadata(metadata)
|
|
94
|
+
|
|
95
|
+
return metadata
|
|
213
96
|
|
|
214
97
|
@property
|
|
215
98
|
def federation_manager(self) -> FederationManager:
|
|
216
|
-
"""
|
|
99
|
+
"""Return the FederationManager instance."""
|
|
217
100
|
return self._federation_manager
|
|
218
101
|
|
|
219
102
|
def store_message_ins(self, message: Message) -> str | None:
|
|
@@ -241,20 +124,26 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
241
124
|
)
|
|
242
125
|
return None
|
|
243
126
|
|
|
244
|
-
with self.
|
|
127
|
+
with self.session():
|
|
245
128
|
# Validate run_id
|
|
246
|
-
query = "SELECT federation FROM run WHERE run_id =
|
|
247
|
-
rows = self.
|
|
129
|
+
query = "SELECT federation FROM run WHERE run_id = :run_id"
|
|
130
|
+
rows = self.query(query, {"run_id": data[0]["run_id"]})
|
|
248
131
|
if not rows:
|
|
249
132
|
log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
|
|
250
133
|
return None
|
|
251
134
|
federation: str = rows[0]["federation"]
|
|
252
135
|
|
|
253
136
|
# Validate destination node ID
|
|
254
|
-
query = "SELECT node_id FROM node WHERE node_id =
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
137
|
+
query = """SELECT node_id FROM node WHERE node_id = :node_id
|
|
138
|
+
AND status IN (:online, :offline)"""
|
|
139
|
+
rows = self.query(
|
|
140
|
+
query,
|
|
141
|
+
{
|
|
142
|
+
"node_id": data[0]["dst_node_id"],
|
|
143
|
+
"online": NodeStatus.ONLINE,
|
|
144
|
+
"offline": NodeStatus.OFFLINE,
|
|
145
|
+
},
|
|
146
|
+
)
|
|
258
147
|
if not rows or not self.federation_manager.has_node(
|
|
259
148
|
message.metadata.dst_node_id, federation
|
|
260
149
|
):
|
|
@@ -265,29 +154,62 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
265
154
|
)
|
|
266
155
|
return None
|
|
267
156
|
|
|
157
|
+
# Insert message
|
|
268
158
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
269
|
-
query = f"INSERT INTO message_ins VALUES({columns})
|
|
159
|
+
query = f"INSERT INTO message_ins VALUES({columns})"
|
|
270
160
|
|
|
271
161
|
# Only invalid run_id can trigger IntegrityError.
|
|
272
162
|
# This may need to be changed in the future version
|
|
273
163
|
# with more integrity checks.
|
|
274
|
-
self.
|
|
164
|
+
self.query(query, data[0])
|
|
275
165
|
|
|
276
166
|
return message.metadata.message_id
|
|
277
167
|
|
|
168
|
+
# pylint: disable-next=too-many-locals
|
|
278
169
|
def _check_stored_messages(self, message_ids: set[str]) -> None:
|
|
279
170
|
"""Check and delete the message if it's invalid."""
|
|
280
171
|
if not message_ids:
|
|
281
172
|
return
|
|
282
173
|
|
|
283
|
-
with self.
|
|
174
|
+
with self.session():
|
|
175
|
+
# Batch fetch all messages in one query
|
|
176
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
|
|
177
|
+
query = f"""
|
|
178
|
+
SELECT * FROM message_ins
|
|
179
|
+
WHERE message_id IN ({placeholders})
|
|
180
|
+
"""
|
|
181
|
+
params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
|
|
182
|
+
message_rows = self.query(query, params)
|
|
183
|
+
|
|
184
|
+
if not message_rows:
|
|
185
|
+
return
|
|
186
|
+
|
|
187
|
+
# Build message lookup dict
|
|
188
|
+
message_dict: dict[str, dict[str, Any]] = {
|
|
189
|
+
row["message_id"]: row for row in message_rows
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
# Collect unique run_ids for batch federation lookup
|
|
193
|
+
run_ids = {row["run_id"] for row in message_rows}
|
|
194
|
+
placeholders = ",".join([f":rid_{i}" for i in range(len(run_ids))])
|
|
195
|
+
query = f"""
|
|
196
|
+
SELECT run_id, federation FROM run
|
|
197
|
+
WHERE run_id IN ({placeholders})
|
|
198
|
+
"""
|
|
199
|
+
params = {f"rid_{i}": rid for i, rid in enumerate(run_ids)}
|
|
200
|
+
run_rows = self.query(query, params)
|
|
201
|
+
|
|
202
|
+
# Build run_id to federation mapping
|
|
203
|
+
run_id_to_federation: dict[int, str] = {
|
|
204
|
+
row["run_id"]: row["federation"] for row in run_rows
|
|
205
|
+
}
|
|
206
|
+
|
|
284
207
|
invalid_msg_ids: set[str] = set()
|
|
285
208
|
current_time = now().timestamp()
|
|
286
209
|
|
|
210
|
+
# Check each message for validity
|
|
287
211
|
for msg_id in message_ids:
|
|
288
|
-
|
|
289
|
-
query = "SELECT * FROM message_ins WHERE message_id = ?;"
|
|
290
|
-
message_row = self.conn.execute(query, (msg_id,)).fetchone()
|
|
212
|
+
message_row = message_dict.get(msg_id)
|
|
291
213
|
if not message_row:
|
|
292
214
|
continue
|
|
293
215
|
|
|
@@ -297,15 +219,12 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
297
219
|
invalid_msg_ids.add(msg_id)
|
|
298
220
|
continue
|
|
299
221
|
|
|
300
|
-
# Check if
|
|
301
|
-
# Get federation from run table
|
|
222
|
+
# Check if run exists and get federation
|
|
302
223
|
run_id = message_row["run_id"]
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
if not run_row: # This should not happen
|
|
224
|
+
federation = run_id_to_federation.get(run_id)
|
|
225
|
+
if not federation:
|
|
306
226
|
invalid_msg_ids.add(msg_id)
|
|
307
227
|
continue
|
|
308
|
-
federation = run_row["federation"]
|
|
309
228
|
|
|
310
229
|
# Convert sint64 to uint64 for node IDs
|
|
311
230
|
src_node_id = int64_to_uint64(message_row["src_node_id"])
|
|
@@ -330,52 +249,48 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
330
249
|
msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
|
|
331
250
|
raise AssertionError(msg)
|
|
332
251
|
|
|
333
|
-
|
|
252
|
+
params: dict[str, str | int] = {}
|
|
334
253
|
|
|
335
254
|
# Convert the uint64 value to sint64 for SQLite
|
|
336
|
-
|
|
255
|
+
params["node_id"] = uint64_to_int64(node_id)
|
|
337
256
|
|
|
338
|
-
with self.
|
|
257
|
+
with self.session():
|
|
339
258
|
# Retrieve all Messages for node_id
|
|
340
259
|
query = """
|
|
341
260
|
SELECT message_id
|
|
342
261
|
FROM message_ins
|
|
343
|
-
WHERE
|
|
344
|
-
AND
|
|
345
|
-
AND
|
|
262
|
+
WHERE dst_node_id = :node_id
|
|
263
|
+
AND delivered_at = ''
|
|
264
|
+
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
346
265
|
"""
|
|
347
266
|
|
|
348
267
|
if limit is not None:
|
|
349
268
|
query += " LIMIT :limit"
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
query += ";"
|
|
269
|
+
params["limit"] = limit
|
|
353
270
|
|
|
354
|
-
rows = self.
|
|
271
|
+
rows = self.query(query, params)
|
|
355
272
|
message_ids: set[str] = {row["message_id"] for row in rows}
|
|
356
273
|
self._check_stored_messages(message_ids)
|
|
357
274
|
|
|
358
275
|
# Mark retrieved Messages as delivered
|
|
359
276
|
if rows:
|
|
360
277
|
# Prepare query
|
|
361
|
-
placeholders
|
|
362
|
-
[f":id_{i}" for i in range(len(message_ids))]
|
|
363
|
-
)
|
|
278
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
|
|
364
279
|
query = f"""
|
|
365
280
|
UPDATE message_ins
|
|
366
281
|
SET delivered_at = :delivered_at
|
|
367
282
|
WHERE message_id IN ({placeholders})
|
|
368
|
-
RETURNING
|
|
283
|
+
RETURNING *
|
|
369
284
|
"""
|
|
370
285
|
|
|
371
286
|
# Prepare data for query
|
|
372
287
|
delivered_at = now().isoformat()
|
|
373
|
-
|
|
288
|
+
params = {"delivered_at": delivered_at}
|
|
374
289
|
for index, msg_id in enumerate(message_ids):
|
|
375
|
-
|
|
290
|
+
params[f"mid_{index}"] = str(msg_id)
|
|
376
291
|
|
|
377
292
|
# Run query
|
|
378
|
-
rows = self.
|
|
293
|
+
rows = self.query(query, params)
|
|
379
294
|
|
|
380
295
|
for row in rows:
|
|
381
296
|
# Convert values from sint64 to uint64
|
|
@@ -383,7 +298,7 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
383
298
|
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
384
299
|
)
|
|
385
300
|
|
|
386
|
-
result = [dict_to_message(row) for row in rows]
|
|
301
|
+
result = [dict_to_message(dict(row)) for row in rows]
|
|
387
302
|
|
|
388
303
|
return result
|
|
389
304
|
|
|
@@ -409,8 +324,8 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
409
324
|
)
|
|
410
325
|
return None
|
|
411
326
|
|
|
412
|
-
# Ensure that the dst_node_id of the original message matches the src_node_id
|
|
413
|
-
# reply being processed.
|
|
327
|
+
# Ensure that the dst_node_id of the original message matches the src_node_id
|
|
328
|
+
# of reply being processed.
|
|
414
329
|
if (
|
|
415
330
|
msg_ins
|
|
416
331
|
and message
|
|
@@ -440,21 +355,19 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
440
355
|
return None
|
|
441
356
|
|
|
442
357
|
# Store Message
|
|
443
|
-
|
|
358
|
+
msg_dict = message_to_dict(message)
|
|
444
359
|
|
|
445
360
|
# Convert values from uint64 to sint64 for SQLite
|
|
446
361
|
convert_uint64_values_in_dict_to_sint64(
|
|
447
|
-
|
|
362
|
+
msg_dict, ["run_id", "src_node_id", "dst_node_id"]
|
|
448
363
|
)
|
|
449
364
|
|
|
450
|
-
columns = ", ".join([f":{key}" for key in
|
|
451
|
-
query = f"INSERT INTO message_res VALUES({columns})
|
|
365
|
+
columns = ", ".join([f":{key}" for key in msg_dict])
|
|
366
|
+
query = f"INSERT INTO message_res VALUES({columns})"
|
|
452
367
|
|
|
453
|
-
# Only invalid run_id can trigger IntegrityError.
|
|
454
|
-
# This may need to be changed in the future version with more integrity checks.
|
|
455
368
|
try:
|
|
456
|
-
self.query(query,
|
|
457
|
-
except
|
|
369
|
+
self.query(query, msg_dict)
|
|
370
|
+
except IntegrityError:
|
|
458
371
|
log(ERROR, "`run` is invalid")
|
|
459
372
|
return None
|
|
460
373
|
|
|
@@ -462,21 +375,21 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
462
375
|
|
|
463
376
|
def get_message_res(self, message_ids: set[str]) -> list[Message]:
|
|
464
377
|
"""Get reply Messages for the given Message IDs."""
|
|
465
|
-
# pylint: disable
|
|
378
|
+
# pylint: disable=too-many-locals
|
|
466
379
|
ret: dict[str, Message] = {}
|
|
467
380
|
|
|
468
|
-
with self.
|
|
381
|
+
with self.session():
|
|
469
382
|
# Verify Message IDs
|
|
470
383
|
self._check_stored_messages(message_ids)
|
|
471
384
|
current = now().timestamp()
|
|
385
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
|
|
472
386
|
query = f"""
|
|
473
387
|
SELECT *
|
|
474
388
|
FROM message_ins
|
|
475
|
-
WHERE message_id IN ({
|
|
389
|
+
WHERE message_id IN ({placeholders})
|
|
476
390
|
"""
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
).fetchall()
|
|
391
|
+
params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
|
|
392
|
+
rows = self.query(query, params)
|
|
480
393
|
found_message_ins_dict: dict[str, Message] = {}
|
|
481
394
|
for row in rows:
|
|
482
395
|
convert_sint64_values_in_dict_to_uint64(
|
|
@@ -496,15 +409,18 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
496
409
|
in_message = found_message_ins_dict[message_id]
|
|
497
410
|
sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
|
|
498
411
|
dst_node_ids.add(sint_node_id)
|
|
412
|
+
placeholders = ",".join([f":nid_{i}" for i in range(len(dst_node_ids))])
|
|
499
413
|
query = f"""
|
|
500
414
|
SELECT node_id, online_until
|
|
501
415
|
FROM node
|
|
502
|
-
WHERE node_id IN ({
|
|
503
|
-
AND status !=
|
|
416
|
+
WHERE node_id IN ({placeholders})
|
|
417
|
+
AND status != :unregistered
|
|
504
418
|
"""
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
419
|
+
node_params: dict[str, int | str] = {
|
|
420
|
+
f"nid_{i}": nid for i, nid in enumerate(dst_node_ids)
|
|
421
|
+
}
|
|
422
|
+
node_params["unregistered"] = NodeStatus.UNREGISTERED
|
|
423
|
+
rows = self.query(query, node_params)
|
|
508
424
|
tmp_ret_dict = check_node_availability_for_in_message(
|
|
509
425
|
inquired_in_message_ids=message_ids,
|
|
510
426
|
found_in_message_dict=found_message_ins_dict,
|
|
@@ -516,15 +432,15 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
516
432
|
ret.update(tmp_ret_dict)
|
|
517
433
|
|
|
518
434
|
# Find all reply Messages
|
|
435
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_ids))])
|
|
519
436
|
query = f"""
|
|
520
437
|
SELECT *
|
|
521
438
|
FROM message_res
|
|
522
|
-
WHERE reply_to_message_id IN ({
|
|
523
|
-
AND delivered_at =
|
|
439
|
+
WHERE reply_to_message_id IN ({placeholders})
|
|
440
|
+
AND delivered_at = ''
|
|
524
441
|
"""
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
).fetchall()
|
|
442
|
+
params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ids)}
|
|
443
|
+
rows = self.query(query, params)
|
|
528
444
|
for row in rows:
|
|
529
445
|
convert_sint64_values_in_dict_to_uint64(
|
|
530
446
|
row, ["run_id", "src_node_id", "dst_node_id"]
|
|
@@ -544,13 +460,15 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
544
460
|
message_res_ids = [
|
|
545
461
|
message_res.metadata.message_id for message_res in ret.values()
|
|
546
462
|
]
|
|
463
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_res_ids))])
|
|
547
464
|
query = f"""
|
|
548
465
|
UPDATE message_res
|
|
549
|
-
SET delivered_at =
|
|
550
|
-
WHERE message_id IN ({
|
|
466
|
+
SET delivered_at = :delivered_at
|
|
467
|
+
WHERE message_id IN ({placeholders})
|
|
551
468
|
"""
|
|
552
|
-
|
|
553
|
-
|
|
469
|
+
params = {"delivered_at": delivered_at}
|
|
470
|
+
params.update({f"mid_{i}": mid for i, mid in enumerate(message_res_ids)})
|
|
471
|
+
self.query(query, params)
|
|
554
472
|
|
|
555
473
|
return list(ret.values())
|
|
556
474
|
|
|
@@ -559,64 +477,55 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
559
477
|
|
|
560
478
|
This includes delivered but not yet deleted.
|
|
561
479
|
"""
|
|
562
|
-
query = "SELECT count(*) AS num FROM message_ins
|
|
563
|
-
rows = self.query(query)
|
|
564
|
-
|
|
565
|
-
num = cast(int, result["num"])
|
|
566
|
-
return num
|
|
480
|
+
query = "SELECT count(*) AS num FROM message_ins"
|
|
481
|
+
rows = self.query(query, {})
|
|
482
|
+
return int(rows[0]["num"])
|
|
567
483
|
|
|
568
484
|
def num_message_res(self) -> int:
|
|
569
485
|
"""Calculate the number of reply Messages in store.
|
|
570
486
|
|
|
571
487
|
This includes delivered but not yet deleted.
|
|
572
488
|
"""
|
|
573
|
-
query = "SELECT count(*) AS num FROM message_res
|
|
489
|
+
query = "SELECT count(*) AS num FROM message_res"
|
|
574
490
|
rows = self.query(query)
|
|
575
|
-
|
|
576
|
-
return result["num"]
|
|
491
|
+
return int(rows[0]["num"])
|
|
577
492
|
|
|
578
493
|
def delete_messages(self, message_ins_ids: set[str]) -> None:
|
|
579
494
|
"""Delete a Message and its reply based on provided Message IDs."""
|
|
580
495
|
if not message_ins_ids:
|
|
581
496
|
return
|
|
582
|
-
if self.conn is None:
|
|
583
|
-
raise AttributeError("LinkState not initialized")
|
|
584
497
|
|
|
585
|
-
placeholders = ",".join(["
|
|
586
|
-
|
|
498
|
+
placeholders = ",".join([f":mid_{i}" for i in range(len(message_ins_ids))])
|
|
499
|
+
params = {f"mid_{i}": str(mid) for i, mid in enumerate(message_ins_ids)}
|
|
587
500
|
|
|
588
501
|
# Delete Message
|
|
589
502
|
query_1 = f"""
|
|
590
503
|
DELETE FROM message_ins
|
|
591
|
-
WHERE message_id IN ({placeholders})
|
|
504
|
+
WHERE message_id IN ({placeholders})
|
|
592
505
|
"""
|
|
593
506
|
|
|
594
507
|
# Delete reply Message
|
|
595
508
|
query_2 = f"""
|
|
596
509
|
DELETE FROM message_res
|
|
597
|
-
WHERE reply_to_message_id IN ({placeholders})
|
|
510
|
+
WHERE reply_to_message_id IN ({placeholders})
|
|
598
511
|
"""
|
|
599
512
|
|
|
600
|
-
with self.
|
|
601
|
-
self.
|
|
602
|
-
self.
|
|
513
|
+
with self.session():
|
|
514
|
+
self.query(query_1, params)
|
|
515
|
+
self.query(query_2, params)
|
|
603
516
|
|
|
604
517
|
def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
|
|
605
518
|
"""Get all instruction Message IDs for the given run_id."""
|
|
606
|
-
if self.conn is None:
|
|
607
|
-
raise AttributeError("LinkState not initialized")
|
|
608
|
-
|
|
609
519
|
query = """
|
|
610
520
|
SELECT message_id
|
|
611
521
|
FROM message_ins
|
|
612
|
-
WHERE run_id = :run_id
|
|
522
|
+
WHERE run_id = :run_id
|
|
613
523
|
"""
|
|
614
|
-
|
|
615
524
|
sint64_run_id = uint64_to_int64(run_id)
|
|
616
|
-
|
|
525
|
+
params = {"run_id": sint64_run_id}
|
|
617
526
|
|
|
618
|
-
with self.
|
|
619
|
-
rows = self.
|
|
527
|
+
with self.session():
|
|
528
|
+
rows = self.query(query, params)
|
|
620
529
|
|
|
621
530
|
return {row["message_id"] for row in rows}
|
|
622
531
|
|
|
@@ -641,29 +550,31 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
641
550
|
(node_id, owner_aid, owner_name, status, registered_at, last_activated_at,
|
|
642
551
|
last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
|
|
643
552
|
public_key)
|
|
644
|
-
VALUES (
|
|
553
|
+
VALUES (:node_id, :owner_aid, :owner_name, :status, :registered_at,
|
|
554
|
+
:last_activated_at, :last_deactivated_at, :unregistered_at, :online_until,
|
|
555
|
+
:heartbeat_interval, :public_key)
|
|
645
556
|
"""
|
|
646
557
|
|
|
647
558
|
# Mark the node online until now().timestamp() + heartbeat_interval
|
|
648
559
|
try:
|
|
649
560
|
self.query(
|
|
650
561
|
query,
|
|
651
|
-
|
|
652
|
-
sint64_node_id,
|
|
653
|
-
owner_aid
|
|
654
|
-
owner_name
|
|
655
|
-
NodeStatus.REGISTERED,
|
|
656
|
-
now().isoformat(),
|
|
657
|
-
None,
|
|
658
|
-
None,
|
|
659
|
-
None,
|
|
660
|
-
None, #
|
|
661
|
-
heartbeat_interval
|
|
662
|
-
public_key
|
|
663
|
-
|
|
562
|
+
{
|
|
563
|
+
"node_id": sint64_node_id,
|
|
564
|
+
"owner_aid": owner_aid,
|
|
565
|
+
"owner_name": owner_name,
|
|
566
|
+
"status": NodeStatus.REGISTERED,
|
|
567
|
+
"registered_at": now().isoformat(),
|
|
568
|
+
"last_activated_at": None,
|
|
569
|
+
"last_deactivated_at": None,
|
|
570
|
+
"unregistered_at": None,
|
|
571
|
+
"online_until": None, # initialized with offline status
|
|
572
|
+
"heartbeat_interval": heartbeat_interval,
|
|
573
|
+
"public_key": public_key,
|
|
574
|
+
},
|
|
664
575
|
)
|
|
665
|
-
except
|
|
666
|
-
if "
|
|
576
|
+
except IntegrityError as e:
|
|
577
|
+
if "node.public_key" in str(e):
|
|
667
578
|
raise ValueError("Public key already in use.") from None
|
|
668
579
|
# Must be node ID conflict, almost impossible unless system is compromised
|
|
669
580
|
log(ERROR, "Unexpected node registration failure.")
|
|
@@ -678,21 +589,20 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
678
589
|
|
|
679
590
|
query = """
|
|
680
591
|
UPDATE node
|
|
681
|
-
SET status =
|
|
682
|
-
online_until = IIF(online_until >
|
|
683
|
-
WHERE node_id =
|
|
592
|
+
SET status = :unregistered, unregistered_at = :unregistered_at,
|
|
593
|
+
online_until = IIF(online_until > :current, :current, online_until)
|
|
594
|
+
WHERE node_id = :node_id AND status != :unregistered
|
|
595
|
+
AND owner_aid = :owner_aid
|
|
684
596
|
RETURNING node_id
|
|
685
597
|
"""
|
|
686
598
|
current = now()
|
|
687
|
-
params =
|
|
688
|
-
NodeStatus.UNREGISTERED,
|
|
689
|
-
current.isoformat(),
|
|
690
|
-
current.timestamp(),
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
owner_aid,
|
|
695
|
-
)
|
|
599
|
+
params = {
|
|
600
|
+
"unregistered": NodeStatus.UNREGISTERED,
|
|
601
|
+
"unregistered_at": current.isoformat(),
|
|
602
|
+
"current": current.timestamp(),
|
|
603
|
+
"node_id": sint64_node_id,
|
|
604
|
+
"owner_aid": owner_aid,
|
|
605
|
+
}
|
|
696
606
|
|
|
697
607
|
rows = self.query(query, params)
|
|
698
608
|
if not rows:
|
|
@@ -703,58 +613,58 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
703
613
|
|
|
704
614
|
def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
|
|
705
615
|
"""Activate the node with the specified `node_id`."""
|
|
706
|
-
|
|
707
|
-
self._check_and_tag_offline_nodes([node_id])
|
|
616
|
+
self._check_and_tag_offline_nodes([node_id])
|
|
708
617
|
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
618
|
+
# Only activate if the node is currently registered or offline
|
|
619
|
+
current_dt = now()
|
|
620
|
+
sint64_node_id = uint64_to_int64(node_id)
|
|
621
|
+
query = """
|
|
622
|
+
UPDATE node
|
|
623
|
+
SET status = :online,
|
|
624
|
+
last_activated_at = :current,
|
|
625
|
+
online_until = :online_until,
|
|
626
|
+
heartbeat_interval = :heartbeat_interval
|
|
627
|
+
WHERE node_id = :node_id AND status IN (:registered, :offline)
|
|
628
|
+
RETURNING node_id
|
|
629
|
+
"""
|
|
630
|
+
params = {
|
|
631
|
+
"online": NodeStatus.ONLINE,
|
|
632
|
+
"current": current_dt.isoformat(),
|
|
633
|
+
"online_until": current_dt.timestamp()
|
|
634
|
+
+ HEARTBEAT_PATIENCE * heartbeat_interval,
|
|
635
|
+
"heartbeat_interval": heartbeat_interval,
|
|
636
|
+
"node_id": sint64_node_id,
|
|
637
|
+
"registered": NodeStatus.REGISTERED,
|
|
638
|
+
"offline": NodeStatus.OFFLINE,
|
|
639
|
+
}
|
|
729
640
|
|
|
730
|
-
|
|
731
|
-
|
|
641
|
+
rows = self.query(query, params)
|
|
642
|
+
return len(rows) > 0
|
|
732
643
|
|
|
733
644
|
def deactivate_node(self, node_id: int) -> bool:
|
|
734
645
|
"""Deactivate the node with the specified `node_id`."""
|
|
735
|
-
|
|
736
|
-
self._check_and_tag_offline_nodes([node_id])
|
|
646
|
+
self._check_and_tag_offline_nodes([node_id])
|
|
737
647
|
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
648
|
+
# Only deactivate if the node is currently online
|
|
649
|
+
current_dt = now()
|
|
650
|
+
query = """
|
|
651
|
+
UPDATE node
|
|
652
|
+
SET status = :offline,
|
|
653
|
+
last_deactivated_at = :current_iso,
|
|
654
|
+
online_until = :current_ts
|
|
655
|
+
WHERE node_id = :node_id AND status = :online
|
|
656
|
+
RETURNING node_id
|
|
657
|
+
"""
|
|
658
|
+
params = {
|
|
659
|
+
"offline": NodeStatus.OFFLINE,
|
|
660
|
+
"current_iso": current_dt.isoformat(),
|
|
661
|
+
"current_ts": current_dt.timestamp(),
|
|
662
|
+
"node_id": uint64_to_int64(node_id),
|
|
663
|
+
"online": NodeStatus.ONLINE,
|
|
664
|
+
}
|
|
755
665
|
|
|
756
|
-
|
|
757
|
-
|
|
666
|
+
rows = self.query(query, params)
|
|
667
|
+
return len(rows) > 0
|
|
758
668
|
|
|
759
669
|
def get_nodes(self, run_id: int) -> set[int]:
|
|
760
670
|
"""Retrieve all currently stored node IDs as a set.
|
|
@@ -764,16 +674,13 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
764
674
|
If the provided `run_id` does not exist or has no matching nodes,
|
|
765
675
|
an empty `Set` MUST be returned.
|
|
766
676
|
"""
|
|
767
|
-
|
|
768
|
-
raise AttributeError("LinkState not initialized")
|
|
769
|
-
|
|
770
|
-
with self.conn:
|
|
677
|
+
with self.session():
|
|
771
678
|
# Convert the uint64 value to sint64 for SQLite
|
|
772
679
|
sint64_run_id = uint64_to_int64(run_id)
|
|
773
680
|
|
|
774
681
|
# Validate run ID
|
|
775
|
-
query = "SELECT federation FROM run WHERE run_id =
|
|
776
|
-
rows = self.
|
|
682
|
+
query = "SELECT federation FROM run WHERE run_id = :run_id"
|
|
683
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
777
684
|
if not rows:
|
|
778
685
|
return set()
|
|
779
686
|
federation: str = rows[0]["federation"]
|
|
@@ -790,23 +697,25 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
790
697
|
"""Check and tag offline nodes."""
|
|
791
698
|
# strftime will convert POSIX timestamp to ISO format
|
|
792
699
|
query = """
|
|
793
|
-
UPDATE node SET status =
|
|
700
|
+
UPDATE node SET status = :offline,
|
|
794
701
|
last_deactivated_at =
|
|
795
|
-
strftime(
|
|
796
|
-
WHERE online_until <=
|
|
702
|
+
strftime('%Y-%m-%dT%H:%M:%f+00:00', online_until, 'unixepoch')
|
|
703
|
+
WHERE online_until <= :current_time AND status = :online
|
|
797
704
|
"""
|
|
798
|
-
params =
|
|
799
|
-
NodeStatus.OFFLINE,
|
|
800
|
-
now().timestamp(),
|
|
801
|
-
NodeStatus.ONLINE,
|
|
802
|
-
|
|
705
|
+
params: dict[str, Any] = {
|
|
706
|
+
"offline": NodeStatus.OFFLINE,
|
|
707
|
+
"current_time": now().timestamp(),
|
|
708
|
+
"online": NodeStatus.ONLINE,
|
|
709
|
+
}
|
|
803
710
|
if node_ids is not None:
|
|
804
|
-
placeholders = ",".join(["
|
|
711
|
+
placeholders = ",".join([f":nid_{i}" for i in range(len(node_ids))])
|
|
805
712
|
query += f" AND node_id IN ({placeholders})"
|
|
806
|
-
params.
|
|
807
|
-
|
|
713
|
+
params.update(
|
|
714
|
+
{f"nid_{i}": uint64_to_int64(nid) for i, nid in enumerate(node_ids)}
|
|
715
|
+
)
|
|
716
|
+
self.query(query, params)
|
|
808
717
|
|
|
809
|
-
def get_node_info(
|
|
718
|
+
def get_node_info( # pylint: disable=too-many-locals
|
|
810
719
|
self,
|
|
811
720
|
*,
|
|
812
721
|
node_ids: Sequence[int] | None = None,
|
|
@@ -814,32 +723,37 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
814
723
|
statuses: Sequence[str] | None = None,
|
|
815
724
|
) -> Sequence[NodeInfo]:
|
|
816
725
|
"""Retrieve information about nodes based on the specified filters."""
|
|
817
|
-
with self.
|
|
726
|
+
with self.session():
|
|
818
727
|
self._check_and_tag_offline_nodes()
|
|
819
728
|
|
|
820
729
|
# Build the WHERE clause based on provided filters
|
|
821
730
|
conditions = []
|
|
822
|
-
params:
|
|
731
|
+
params: dict[str, Any] = {}
|
|
823
732
|
if node_ids is not None:
|
|
824
733
|
sint64_node_ids = [uint64_to_int64(node_id) for node_id in node_ids]
|
|
825
|
-
placeholders = ",".join(
|
|
734
|
+
placeholders = ",".join(
|
|
735
|
+
[f":nid_{i}" for i in range(len(sint64_node_ids))]
|
|
736
|
+
)
|
|
826
737
|
conditions.append(f"node_id IN ({placeholders})")
|
|
827
|
-
|
|
738
|
+
for i, nid in enumerate(sint64_node_ids):
|
|
739
|
+
params[f"nid_{i}"] = nid
|
|
828
740
|
if owner_aids is not None:
|
|
829
|
-
placeholders = ",".join(["
|
|
741
|
+
placeholders = ",".join([f":aid_{i}" for i in range(len(owner_aids))])
|
|
830
742
|
conditions.append(f"owner_aid IN ({placeholders})")
|
|
831
|
-
|
|
743
|
+
for i, aid in enumerate(owner_aids):
|
|
744
|
+
params[f"aid_{i}"] = aid
|
|
832
745
|
if statuses is not None:
|
|
833
|
-
placeholders = ",".join(["
|
|
746
|
+
placeholders = ",".join([f":st_{i}" for i in range(len(statuses))])
|
|
834
747
|
conditions.append(f"status IN ({placeholders})")
|
|
835
|
-
|
|
748
|
+
for i, status in enumerate(statuses):
|
|
749
|
+
params[f"st_{i}"] = status
|
|
836
750
|
|
|
837
751
|
# Construct the final query
|
|
838
752
|
query = "SELECT * FROM node"
|
|
839
753
|
if conditions:
|
|
840
754
|
query += " WHERE " + " AND ".join(conditions)
|
|
841
755
|
|
|
842
|
-
rows = self.
|
|
756
|
+
rows = self.query(query, params)
|
|
843
757
|
|
|
844
758
|
result: list[NodeInfo] = []
|
|
845
759
|
for row in rows:
|
|
@@ -849,27 +763,14 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
849
763
|
|
|
850
764
|
return result
|
|
851
765
|
|
|
852
|
-
def get_node_public_key(self, node_id: int) -> bytes:
|
|
853
|
-
"""Get `public_key` for the specified `node_id`."""
|
|
854
|
-
# Convert the uint64 value to sint64 for SQLite
|
|
855
|
-
sint64_node_id = uint64_to_int64(node_id)
|
|
856
|
-
|
|
857
|
-
# Query the public key for the given node_id
|
|
858
|
-
query = "SELECT public_key FROM node WHERE node_id = ? AND status != ?;"
|
|
859
|
-
rows = self.query(query, (sint64_node_id, NodeStatus.UNREGISTERED))
|
|
860
|
-
|
|
861
|
-
# If no result is found, return None
|
|
862
|
-
if not rows:
|
|
863
|
-
raise ValueError(f"Node ID {node_id} not found")
|
|
864
|
-
|
|
865
|
-
# Return the public key
|
|
866
|
-
return cast(bytes, rows[0]["public_key"])
|
|
867
|
-
|
|
868
766
|
def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
|
|
869
767
|
"""Get `node_id` for the specified `public_key` if it exists and is not
|
|
870
768
|
deleted."""
|
|
871
|
-
query = "SELECT node_id FROM node
|
|
872
|
-
|
|
769
|
+
query = """SELECT node_id FROM node
|
|
770
|
+
WHERE public_key = :public_key AND status != :unregistered;"""
|
|
771
|
+
rows = self.query(
|
|
772
|
+
query, {"public_key": public_key, "unregistered": NodeStatus.UNREGISTERED}
|
|
773
|
+
)
|
|
873
774
|
|
|
874
775
|
# If no result is found, return None
|
|
875
776
|
if not rows:
|
|
@@ -879,8 +780,7 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
879
780
|
node_id = int64_to_uint64(rows[0]["node_id"])
|
|
880
781
|
return node_id
|
|
881
782
|
|
|
882
|
-
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
883
|
-
def create_run(
|
|
783
|
+
def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
884
784
|
self,
|
|
885
785
|
fab_id: str | None,
|
|
886
786
|
fab_version: str | None,
|
|
@@ -897,41 +797,43 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
897
797
|
# Convert the uint64 value to sint64 for SQLite
|
|
898
798
|
sint64_run_id = uint64_to_int64(uint64_run_id)
|
|
899
799
|
|
|
900
|
-
with self.
|
|
800
|
+
with self.session():
|
|
901
801
|
# Check conflicts
|
|
902
|
-
query = "SELECT COUNT(*) FROM run WHERE run_id =
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
if row["COUNT(*)"] == 0:
|
|
802
|
+
query = "SELECT COUNT(*) as cnt FROM run WHERE run_id = :run_id"
|
|
803
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
804
|
+
if rows[0]["cnt"] == 0:
|
|
906
805
|
query = """
|
|
907
806
|
INSERT INTO run
|
|
908
807
|
(run_id, fab_id, fab_version,
|
|
909
808
|
fab_hash, override_config, federation, federation_options,
|
|
910
809
|
pending_at, starting_at, running_at, finished_at, sub_status,
|
|
911
810
|
details, flwr_aid, bytes_sent, bytes_recv, clientapp_runtime)
|
|
912
|
-
VALUES (
|
|
811
|
+
VALUES (:run_id, :fab_id, :fab_version, :fab_hash, :override_config,
|
|
812
|
+
:federation, :federation_options, :pending_at, :starting_at,
|
|
813
|
+
:running_at, :finished_at, :sub_status, :details, :flwr_aid,
|
|
814
|
+
:bytes_sent, :bytes_recv, :clientapp_runtime)
|
|
913
815
|
"""
|
|
914
816
|
override_config_json = json.dumps(override_config)
|
|
915
|
-
|
|
916
|
-
sint64_run_id,
|
|
917
|
-
fab_id
|
|
918
|
-
fab_version
|
|
919
|
-
fab_hash
|
|
920
|
-
override_config_json,
|
|
921
|
-
federation
|
|
922
|
-
configrecord_to_bytes(federation_options),
|
|
923
|
-
now().isoformat(),
|
|
924
|
-
"",
|
|
925
|
-
"",
|
|
926
|
-
"",
|
|
927
|
-
"",
|
|
928
|
-
"",
|
|
929
|
-
flwr_aid or "",
|
|
930
|
-
0,
|
|
931
|
-
0,
|
|
932
|
-
0,
|
|
933
|
-
|
|
934
|
-
self.
|
|
817
|
+
params = {
|
|
818
|
+
"run_id": sint64_run_id,
|
|
819
|
+
"fab_id": fab_id or "",
|
|
820
|
+
"fab_version": fab_version or "",
|
|
821
|
+
"fab_hash": fab_hash or "",
|
|
822
|
+
"override_config": override_config_json,
|
|
823
|
+
"federation": federation,
|
|
824
|
+
"federation_options": configrecord_to_bytes(federation_options),
|
|
825
|
+
"pending_at": now().isoformat(),
|
|
826
|
+
"starting_at": "",
|
|
827
|
+
"running_at": "",
|
|
828
|
+
"finished_at": "",
|
|
829
|
+
"sub_status": "",
|
|
830
|
+
"details": "",
|
|
831
|
+
"flwr_aid": flwr_aid or "",
|
|
832
|
+
"bytes_sent": 0,
|
|
833
|
+
"bytes_recv": 0,
|
|
834
|
+
"clientapp_runtime": 0.0,
|
|
835
|
+
}
|
|
836
|
+
self.query(query, params)
|
|
935
837
|
return uint64_run_id
|
|
936
838
|
log(ERROR, "Unexpected run creation failure.")
|
|
937
839
|
return 0
|
|
@@ -943,11 +845,11 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
943
845
|
"""
|
|
944
846
|
if flwr_aid:
|
|
945
847
|
rows = self.query(
|
|
946
|
-
"SELECT run_id FROM run WHERE flwr_aid =
|
|
947
|
-
|
|
848
|
+
"SELECT run_id FROM run WHERE flwr_aid = :flwr_aid",
|
|
849
|
+
{"flwr_aid": flwr_aid},
|
|
948
850
|
)
|
|
949
851
|
else:
|
|
950
|
-
rows = self.query("SELECT run_id FROM run
|
|
852
|
+
rows = self.query("SELECT run_id FROM run", {})
|
|
951
853
|
return {int64_to_uint64(row["run_id"]) for row in rows}
|
|
952
854
|
|
|
953
855
|
def get_run(self, run_id: int) -> Run | None:
|
|
@@ -957,8 +859,8 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
957
859
|
|
|
958
860
|
# Convert the uint64 value to sint64 for SQLite
|
|
959
861
|
sint64_run_id = uint64_to_int64(run_id)
|
|
960
|
-
query = "SELECT * FROM run WHERE run_id =
|
|
961
|
-
rows = self.query(query,
|
|
862
|
+
query = "SELECT * FROM run WHERE run_id = :run_id"
|
|
863
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
962
864
|
if rows:
|
|
963
865
|
row = rows[0]
|
|
964
866
|
return Run(
|
|
@@ -991,9 +893,10 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
991
893
|
self._cleanup_expired_tokens()
|
|
992
894
|
|
|
993
895
|
# Convert the uint64 value to sint64 for SQLite
|
|
994
|
-
|
|
995
|
-
query = f"SELECT * FROM run WHERE run_id IN ({
|
|
996
|
-
|
|
896
|
+
placeholders = ",".join([f":rid_{i}" for i in range(len(run_ids))])
|
|
897
|
+
query = f"SELECT * FROM run WHERE run_id IN ({placeholders})"
|
|
898
|
+
params = {f"rid_{i}": uint64_to_int64(rid) for i, rid in enumerate(run_ids)}
|
|
899
|
+
rows = self.query(query, params)
|
|
997
900
|
|
|
998
901
|
return {
|
|
999
902
|
# Restore uint64 run IDs
|
|
@@ -1010,11 +913,11 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1010
913
|
# Clean up expired tokens; this will flag inactive runs as needed
|
|
1011
914
|
self._cleanup_expired_tokens()
|
|
1012
915
|
|
|
1013
|
-
with self.
|
|
916
|
+
with self.session():
|
|
1014
917
|
# Convert the uint64 value to sint64 for SQLite
|
|
1015
918
|
sint64_run_id = uint64_to_int64(run_id)
|
|
1016
|
-
query = "SELECT * FROM run WHERE run_id =
|
|
1017
|
-
rows = self.
|
|
919
|
+
query = "SELECT * FROM run WHERE run_id = :run_id"
|
|
920
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
1018
921
|
|
|
1019
922
|
# Check if the run_id exists
|
|
1020
923
|
if not rows:
|
|
@@ -1049,7 +952,9 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1049
952
|
|
|
1050
953
|
# Update the status
|
|
1051
954
|
query = """
|
|
1052
|
-
UPDATE run SET %s
|
|
955
|
+
UPDATE run SET %s = :timestamp,
|
|
956
|
+
sub_status = :sub_status, details = :details
|
|
957
|
+
WHERE run_id = :run_id
|
|
1053
958
|
"""
|
|
1054
959
|
|
|
1055
960
|
# Prepare data for query
|
|
@@ -1064,33 +969,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1064
969
|
elif new_status.status == Status.FINISHED:
|
|
1065
970
|
timestamp_fld = "finished_at"
|
|
1066
971
|
|
|
1067
|
-
|
|
1068
|
-
current.isoformat(),
|
|
1069
|
-
new_status.sub_status,
|
|
1070
|
-
new_status.details,
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
self.
|
|
972
|
+
params = {
|
|
973
|
+
"timestamp": current.isoformat(),
|
|
974
|
+
"sub_status": new_status.sub_status,
|
|
975
|
+
"details": new_status.details,
|
|
976
|
+
"run_id": sint64_run_id,
|
|
977
|
+
}
|
|
978
|
+
self.query(query % timestamp_fld, params)
|
|
1074
979
|
return True
|
|
1075
980
|
|
|
1076
981
|
def get_pending_run_id(self) -> int | None:
|
|
1077
|
-
"""Get the `run_id` of a run with `Status.PENDING` status
|
|
1078
|
-
pending_run_id = None
|
|
1079
|
-
|
|
982
|
+
"""Get the `run_id` of a run with `Status.PENDING` status."""
|
|
1080
983
|
# Fetch all runs with unset `starting_at` (i.e. they are in PENDING status)
|
|
1081
|
-
query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1
|
|
1082
|
-
rows = self.query(query)
|
|
984
|
+
query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1"
|
|
985
|
+
rows = self.query(query, {})
|
|
1083
986
|
if rows:
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
return pending_run_id
|
|
987
|
+
return int64_to_uint64(rows[0]["run_id"])
|
|
988
|
+
return None
|
|
1087
989
|
|
|
1088
990
|
def get_federation_options(self, run_id: int) -> ConfigRecord | None:
|
|
1089
991
|
"""Retrieve the federation options for the specified `run_id`."""
|
|
1090
992
|
# Convert the uint64 value to sint64 for SQLite
|
|
1091
993
|
sint64_run_id = uint64_to_int64(run_id)
|
|
1092
|
-
query = "SELECT federation_options FROM run WHERE run_id =
|
|
1093
|
-
rows = self.query(query,
|
|
994
|
+
query = "SELECT federation_options FROM run WHERE run_id = :run_id"
|
|
995
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
1094
996
|
|
|
1095
997
|
# Check if the run_id exists
|
|
1096
998
|
if not rows:
|
|
@@ -1110,41 +1012,46 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1110
1012
|
HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
|
|
1111
1013
|
the node is marked as offline.
|
|
1112
1014
|
"""
|
|
1113
|
-
if self.conn is None:
|
|
1114
|
-
raise AttributeError("LinkState not initialized")
|
|
1115
|
-
|
|
1116
1015
|
sint64_node_id = uint64_to_int64(node_id)
|
|
1117
1016
|
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1017
|
+
# Check if the node exists and is not unregistered
|
|
1018
|
+
query = """
|
|
1019
|
+
SELECT status FROM node WHERE node_id = :node_id AND status != :unregistered
|
|
1020
|
+
"""
|
|
1021
|
+
rows = self.query(
|
|
1022
|
+
query, {"node_id": sint64_node_id, "unregistered": NodeStatus.UNREGISTERED}
|
|
1023
|
+
)
|
|
1024
|
+
if not rows:
|
|
1025
|
+
return False
|
|
1126
1026
|
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1027
|
+
# Construct query and params
|
|
1028
|
+
current_dt = now()
|
|
1029
|
+
query = (
|
|
1030
|
+
"UPDATE node SET online_until = :online_until, "
|
|
1031
|
+
"heartbeat_interval = :heartbeat_interval"
|
|
1032
|
+
)
|
|
1033
|
+
params: dict[str, Any] = {
|
|
1034
|
+
"online_until": current_dt.timestamp()
|
|
1035
|
+
+ HEARTBEAT_PATIENCE * heartbeat_interval,
|
|
1036
|
+
"heartbeat_interval": heartbeat_interval,
|
|
1037
|
+
}
|
|
1134
1038
|
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1039
|
+
# Set timestamp if the status changes
|
|
1040
|
+
if rows[0]["status"] != NodeStatus.ONLINE:
|
|
1041
|
+
query += ", status = :online, last_activated_at = :last_activated_at"
|
|
1042
|
+
params["online"] = NodeStatus.ONLINE
|
|
1043
|
+
params["last_activated_at"] = current_dt.isoformat()
|
|
1139
1044
|
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
|
|
1045
|
+
# Execute the query, refreshing `online_until` and `heartbeat_interval`
|
|
1046
|
+
query += " WHERE node_id = :node_id"
|
|
1047
|
+
params["node_id"] = sint64_node_id
|
|
1048
|
+
self.query(query, params)
|
|
1049
|
+
return True
|
|
1145
1050
|
|
|
1146
1051
|
def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
|
|
1147
|
-
"""
|
|
1052
|
+
"""Handle cleanup of expired tokens.
|
|
1053
|
+
|
|
1054
|
+
Override in subclasses to add custom cleanup logic.
|
|
1148
1055
|
|
|
1149
1056
|
Parameters
|
|
1150
1057
|
----------
|
|
@@ -1155,28 +1062,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1155
1062
|
if not expired_records:
|
|
1156
1063
|
return
|
|
1157
1064
|
|
|
1158
|
-
with self.
|
|
1065
|
+
with self.session():
|
|
1159
1066
|
query = """
|
|
1160
1067
|
UPDATE run
|
|
1161
|
-
SET sub_status =
|
|
1162
|
-
WHERE run_id =
|
|
1068
|
+
SET sub_status = :failed, details = :details, finished_at = :finished_at
|
|
1069
|
+
WHERE run_id = :run_id
|
|
1163
1070
|
"""
|
|
1164
1071
|
data = [
|
|
1165
|
-
|
|
1166
|
-
SubStatus.FAILED,
|
|
1167
|
-
RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
1168
|
-
datetime.fromtimestamp(
|
|
1169
|
-
|
|
1170
|
-
|
|
1072
|
+
{
|
|
1073
|
+
"failed": SubStatus.FAILED,
|
|
1074
|
+
"details": RUN_FAILURE_DETAILS_NO_HEARTBEAT,
|
|
1075
|
+
"finished_at": datetime.fromtimestamp(
|
|
1076
|
+
active_until, tz=timezone.utc
|
|
1077
|
+
).isoformat(),
|
|
1078
|
+
"run_id": uint64_to_int64(run_id),
|
|
1079
|
+
}
|
|
1171
1080
|
for run_id, active_until in expired_records
|
|
1172
1081
|
]
|
|
1173
|
-
self.
|
|
1082
|
+
self.query(query, data)
|
|
1174
1083
|
|
|
1175
1084
|
def get_serverapp_context(self, run_id: int) -> Context | None:
|
|
1176
1085
|
"""Get the context for the specified `run_id`."""
|
|
1177
1086
|
# Retrieve context if any
|
|
1178
|
-
query = "SELECT context FROM context WHERE run_id =
|
|
1179
|
-
rows = self.query(query,
|
|
1087
|
+
query = "SELECT context FROM context WHERE run_id = :run_id"
|
|
1088
|
+
rows = self.query(query, {"run_id": uint64_to_int64(run_id)})
|
|
1180
1089
|
context = context_from_bytes(rows[0]["context"]) if rows else None
|
|
1181
1090
|
return context
|
|
1182
1091
|
|
|
@@ -1186,20 +1095,30 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1186
1095
|
context_bytes = context_to_bytes(context)
|
|
1187
1096
|
sint_run_id = uint64_to_int64(run_id)
|
|
1188
1097
|
|
|
1189
|
-
with self.
|
|
1098
|
+
with self.session():
|
|
1190
1099
|
# Check if any existing Context assigned to the run_id
|
|
1191
|
-
query = "SELECT COUNT(*) FROM context WHERE run_id =
|
|
1192
|
-
row = self.
|
|
1193
|
-
if row["
|
|
1100
|
+
query = "SELECT COUNT(*) as count FROM context WHERE run_id = :run_id"
|
|
1101
|
+
row = self.query(query, {"run_id": sint_run_id})[0]
|
|
1102
|
+
if row["count"] > 0:
|
|
1194
1103
|
# Update context
|
|
1195
|
-
query = "
|
|
1196
|
-
|
|
1104
|
+
query = """
|
|
1105
|
+
UPDATE context
|
|
1106
|
+
SET context = :context_bytes WHERE run_id = :run_id
|
|
1107
|
+
"""
|
|
1108
|
+
self.query(
|
|
1109
|
+
query, {"context_bytes": context_bytes, "run_id": sint_run_id}
|
|
1110
|
+
)
|
|
1197
1111
|
else:
|
|
1198
1112
|
try:
|
|
1199
1113
|
# Store context
|
|
1200
|
-
query =
|
|
1201
|
-
|
|
1202
|
-
|
|
1114
|
+
query = (
|
|
1115
|
+
"INSERT INTO context (run_id, context) "
|
|
1116
|
+
"VALUES (:run_id, :context_bytes)"
|
|
1117
|
+
)
|
|
1118
|
+
self.query(
|
|
1119
|
+
query, {"run_id": sint_run_id, "context_bytes": context_bytes}
|
|
1120
|
+
)
|
|
1121
|
+
except IntegrityError:
|
|
1203
1122
|
raise ValueError(f"Run {run_id} not found") from None
|
|
1204
1123
|
|
|
1205
1124
|
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
|
|
@@ -1210,10 +1129,19 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1210
1129
|
# Store log
|
|
1211
1130
|
try:
|
|
1212
1131
|
query = """
|
|
1213
|
-
INSERT INTO logs (timestamp, run_id, node_id, log)
|
|
1132
|
+
INSERT INTO logs (timestamp, run_id, node_id, log)
|
|
1133
|
+
VALUES (:current_ts, :run_id, :node_id, :log)
|
|
1214
1134
|
"""
|
|
1215
|
-
self.query(
|
|
1216
|
-
|
|
1135
|
+
self.query(
|
|
1136
|
+
query,
|
|
1137
|
+
{
|
|
1138
|
+
"current_ts": now().timestamp(),
|
|
1139
|
+
"run_id": sint64_run_id,
|
|
1140
|
+
"node_id": 0,
|
|
1141
|
+
"log": log_message,
|
|
1142
|
+
},
|
|
1143
|
+
)
|
|
1144
|
+
except IntegrityError:
|
|
1217
1145
|
raise ValueError(f"Run {run_id} not found") from None
|
|
1218
1146
|
|
|
1219
1147
|
def get_serverapp_log(
|
|
@@ -1223,10 +1151,10 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1223
1151
|
# Convert the uint64 value to sint64 for SQLite
|
|
1224
1152
|
sint64_run_id = uint64_to_int64(run_id)
|
|
1225
1153
|
|
|
1226
|
-
with self.
|
|
1154
|
+
with self.session():
|
|
1227
1155
|
# Check if the run_id exists
|
|
1228
|
-
query = "SELECT run_id FROM run WHERE run_id =
|
|
1229
|
-
rows = self.
|
|
1156
|
+
query = "SELECT run_id FROM run WHERE run_id = :run_id"
|
|
1157
|
+
rows = self.query(query, {"run_id": sint64_run_id})
|
|
1230
1158
|
if not rows:
|
|
1231
1159
|
raise ValueError(f"Run {run_id} not found")
|
|
1232
1160
|
|
|
@@ -1235,12 +1163,18 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1235
1163
|
after_timestamp = 0.0
|
|
1236
1164
|
query = """
|
|
1237
1165
|
SELECT log, timestamp FROM logs
|
|
1238
|
-
WHERE run_id =
|
|
1166
|
+
WHERE run_id = :run_id AND node_id = :node_id
|
|
1167
|
+
AND timestamp > :after_timestamp
|
|
1168
|
+
ORDER BY timestamp
|
|
1239
1169
|
"""
|
|
1240
|
-
rows = self.
|
|
1241
|
-
query,
|
|
1242
|
-
|
|
1243
|
-
|
|
1170
|
+
rows = self.query(
|
|
1171
|
+
query,
|
|
1172
|
+
{
|
|
1173
|
+
"run_id": sint64_run_id,
|
|
1174
|
+
"node_id": 0,
|
|
1175
|
+
"after_timestamp": after_timestamp,
|
|
1176
|
+
},
|
|
1177
|
+
)
|
|
1244
1178
|
latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
|
|
1245
1179
|
return "".join(row["log"] for row in rows), latest_timestamp
|
|
1246
1180
|
|
|
@@ -1249,20 +1183,19 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1249
1183
|
|
|
1250
1184
|
Return Message if valid.
|
|
1251
1185
|
"""
|
|
1252
|
-
with self.
|
|
1186
|
+
with self.session():
|
|
1253
1187
|
self._check_stored_messages({message_id})
|
|
1254
1188
|
query = """
|
|
1255
1189
|
SELECT *
|
|
1256
1190
|
FROM message_ins
|
|
1257
1191
|
WHERE message_id = :message_id
|
|
1258
1192
|
"""
|
|
1259
|
-
|
|
1260
|
-
rows: list[dict[str, Any]] = self.conn.execute(query, data).fetchall()
|
|
1193
|
+
rows = self.query(query, {"message_id": message_id})
|
|
1261
1194
|
if not rows:
|
|
1262
1195
|
# Message does not exist
|
|
1263
1196
|
return None
|
|
1264
1197
|
|
|
1265
|
-
return rows[0]
|
|
1198
|
+
return dict(rows[0])
|
|
1266
1199
|
|
|
1267
1200
|
def store_traffic(self, run_id: int, *, bytes_sent: int, bytes_recv: int) -> None:
|
|
1268
1201
|
"""Store traffic data for the specified `run_id`."""
|
|
@@ -1280,18 +1213,23 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1280
1213
|
|
|
1281
1214
|
sint64_run_id = uint64_to_int64(run_id)
|
|
1282
1215
|
|
|
1283
|
-
with self.
|
|
1216
|
+
with self.session():
|
|
1284
1217
|
# Check if run exists, performing the update only if it does
|
|
1285
1218
|
update_query = """
|
|
1286
1219
|
UPDATE run
|
|
1287
|
-
SET bytes_sent = bytes_sent +
|
|
1288
|
-
bytes_recv = bytes_recv +
|
|
1289
|
-
WHERE run_id =
|
|
1290
|
-
RETURNING run_id
|
|
1220
|
+
SET bytes_sent = bytes_sent + :bytes_sent,
|
|
1221
|
+
bytes_recv = bytes_recv + :bytes_recv
|
|
1222
|
+
WHERE run_id = :run_id
|
|
1223
|
+
RETURNING run_id
|
|
1291
1224
|
"""
|
|
1292
|
-
rows = self.
|
|
1293
|
-
update_query,
|
|
1294
|
-
|
|
1225
|
+
rows = self.query(
|
|
1226
|
+
update_query,
|
|
1227
|
+
{
|
|
1228
|
+
"bytes_sent": bytes_sent,
|
|
1229
|
+
"bytes_recv": bytes_recv,
|
|
1230
|
+
"run_id": sint64_run_id,
|
|
1231
|
+
},
|
|
1232
|
+
)
|
|
1295
1233
|
|
|
1296
1234
|
if not rows:
|
|
1297
1235
|
raise ValueError(f"Run {run_id} not found")
|
|
@@ -1299,62 +1237,22 @@ class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
|
|
|
1299
1237
|
def add_clientapp_runtime(self, run_id: int, runtime: float) -> None:
|
|
1300
1238
|
"""Add ClientApp runtime to the cumulative total for the specified `run_id`."""
|
|
1301
1239
|
sint64_run_id = uint64_to_int64(run_id)
|
|
1302
|
-
with self.
|
|
1240
|
+
with self.session():
|
|
1303
1241
|
# Check if run exists, performing the update only if it does
|
|
1304
1242
|
update_query = """
|
|
1305
1243
|
UPDATE run
|
|
1306
|
-
SET clientapp_runtime = clientapp_runtime +
|
|
1307
|
-
WHERE run_id =
|
|
1308
|
-
RETURNING run_id
|
|
1244
|
+
SET clientapp_runtime = clientapp_runtime + :runtime
|
|
1245
|
+
WHERE run_id = :run_id
|
|
1246
|
+
RETURNING run_id
|
|
1309
1247
|
"""
|
|
1310
|
-
rows = self.
|
|
1248
|
+
rows = self.query(
|
|
1249
|
+
update_query, {"runtime": runtime, "run_id": sint64_run_id}
|
|
1250
|
+
)
|
|
1311
1251
|
|
|
1312
1252
|
if not rows:
|
|
1313
1253
|
raise ValueError(f"Run {run_id} not found")
|
|
1314
1254
|
|
|
1315
1255
|
|
|
1316
|
-
def message_to_dict(message: Message) -> dict[str, Any]:
|
|
1317
|
-
"""Transform Message to dict."""
|
|
1318
|
-
result = {
|
|
1319
|
-
"message_id": message.metadata.message_id,
|
|
1320
|
-
"group_id": message.metadata.group_id,
|
|
1321
|
-
"run_id": message.metadata.run_id,
|
|
1322
|
-
"src_node_id": message.metadata.src_node_id,
|
|
1323
|
-
"dst_node_id": message.metadata.dst_node_id,
|
|
1324
|
-
"reply_to_message_id": message.metadata.reply_to_message_id,
|
|
1325
|
-
"created_at": message.metadata.created_at,
|
|
1326
|
-
"delivered_at": message.metadata.delivered_at,
|
|
1327
|
-
"ttl": message.metadata.ttl,
|
|
1328
|
-
"message_type": message.metadata.message_type,
|
|
1329
|
-
"content": None,
|
|
1330
|
-
"error": None,
|
|
1331
|
-
}
|
|
1332
|
-
|
|
1333
|
-
if message.has_content():
|
|
1334
|
-
result["content"] = recorddict_to_proto(message.content).SerializeToString()
|
|
1335
|
-
else:
|
|
1336
|
-
result["error"] = error_to_proto(message.error).SerializeToString()
|
|
1337
|
-
|
|
1338
|
-
return result
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
def dict_to_message(message_dict: dict[str, Any]) -> Message:
|
|
1342
|
-
"""Transform dict to Message."""
|
|
1343
|
-
content, error = None, None
|
|
1344
|
-
if (b_content := message_dict.pop("content")) is not None:
|
|
1345
|
-
content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
|
|
1346
|
-
if (b_error := message_dict.pop("error")) is not None:
|
|
1347
|
-
error = error_from_proto(ProtoError.FromString(b_error))
|
|
1348
|
-
|
|
1349
|
-
# Metadata constructor doesn't allow passing created_at. We set it later
|
|
1350
|
-
metadata = Metadata(
|
|
1351
|
-
**{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
|
|
1352
|
-
)
|
|
1353
|
-
msg = make_message(metadata=metadata, content=content, error=error)
|
|
1354
|
-
msg.metadata.delivered_at = message_dict["delivered_at"]
|
|
1355
|
-
return msg
|
|
1356
|
-
|
|
1357
|
-
|
|
1358
1256
|
def determine_run_status(row: dict[str, Any]) -> str:
|
|
1359
1257
|
"""Determine the status of the run based on timestamp fields."""
|
|
1360
1258
|
if row["pending_at"]:
|
|
@@ -1366,4 +1264,4 @@ def determine_run_status(row: dict[str, Any]) -> str:
|
|
|
1366
1264
|
return Status.STARTING
|
|
1367
1265
|
return Status.PENDING
|
|
1368
1266
|
run_id = int64_to_uint64(row["run_id"])
|
|
1369
|
-
raise
|
|
1267
|
+
raise ValueError(f"The run {run_id} does not have a valid status.")
|