flwr-nightly 1.26.0.dev20260122__py3-none-any.whl → 1.26.0.dev20260126__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- flwr/cli/app_cmd/publish.py +18 -44
- flwr/cli/app_cmd/review.py +8 -25
- flwr/cli/auth_plugin/oidc_cli_plugin.py +3 -6
- flwr/cli/build.py +8 -19
- flwr/cli/config/ls.py +8 -13
- flwr/cli/config_utils.py +19 -171
- flwr/cli/federation/ls.py +3 -7
- flwr/cli/flower_config.py +28 -47
- flwr/cli/install.py +18 -57
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +8 -21
- flwr/cli/ls.py +3 -7
- flwr/cli/new/new.py +9 -29
- flwr/cli/pull.py +3 -7
- flwr/cli/run/run.py +6 -15
- flwr/cli/stop.py +5 -17
- flwr/cli/supernode/register.py +6 -22
- flwr/cli/supernode/unregister.py +3 -13
- flwr/cli/utils.py +66 -169
- flwr/common/config.py +5 -9
- flwr/common/constant.py +2 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +4 -4
- flwr/server/superlink/linkstate/__init__.py +0 -2
- flwr/server/superlink/linkstate/sql_linkstate.py +38 -10
- flwr/supercore/object_store/object_store_factory.py +4 -4
- flwr/supercore/object_store/sql_object_store.py +171 -6
- flwr/superlink/servicer/control/control_servicer.py +11 -12
- {flwr_nightly-1.26.0.dev20260122.dist-info → flwr_nightly-1.26.0.dev20260126.dist-info}/METADATA +2 -2
- {flwr_nightly-1.26.0.dev20260122.dist-info → flwr_nightly-1.26.0.dev20260126.dist-info}/RECORD +31 -35
- flwr/server/superlink/linkstate/sqlite_linkstate.py +0 -1302
- flwr/supercore/corestate/sqlite_corestate.py +0 -157
- flwr/supercore/object_store/sqlite_object_store.py +0 -253
- flwr/supercore/sqlite_mixin.py +0 -156
- {flwr_nightly-1.26.0.dev20260122.dist-info → flwr_nightly-1.26.0.dev20260126.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.26.0.dev20260122.dist-info → flwr_nightly-1.26.0.dev20260126.dist-info}/entry_points.txt +0 -0
flwr/server/superlink/linkstate/sqlite_linkstate.py
@@ -1,1302 +0,0 @@
-# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""SQLite based implemenation of the link state."""
-
-
-# pylint: disable=too-many-lines
-
-import json
-import sqlite3
-from collections.abc import Sequence
-from datetime import datetime, timezone
-from logging import ERROR, WARNING
-from typing import Any, cast
-
-from flwr.app.user_config import UserConfig
-from flwr.common import Context, Message, log, now
-from flwr.common.constant import (
-    HEARTBEAT_PATIENCE,
-    MESSAGE_TTL_TOLERANCE,
-    NODE_ID_NUM_BYTES,
-    RUN_FAILURE_DETAILS_NO_HEARTBEAT,
-    RUN_ID_NUM_BYTES,
-    SUPERLINK_NODE_ID,
-    Status,
-    SubStatus,
-)
-from flwr.common.record import ConfigRecord
-from flwr.common.typing import Run, RunStatus
-
-# pylint: disable=E0611
-from flwr.proto.node_pb2 import NodeInfo
-
-# pylint: enable=E0611
-from flwr.server.utils.validator import validate_message
-from flwr.supercore.constant import NodeStatus
-from flwr.supercore.corestate.sqlite_corestate import SqliteCoreState
-from flwr.supercore.object_store.object_store import ObjectStore
-from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
-from flwr.superlink.federation import FederationManager
-
-from .linkstate import LinkState
-from .utils import (
-    check_node_availability_for_in_message,
-    configrecord_from_bytes,
-    configrecord_to_bytes,
-    context_from_bytes,
-    context_to_bytes,
-    convert_sint64_values_in_dict_to_uint64,
-    convert_uint64_values_in_dict_to_sint64,
-    dict_to_message,
-    generate_rand_int_from_bytes,
-    has_valid_sub_status,
-    is_valid_transition,
-    message_to_dict,
-    verify_found_message_replies,
-    verify_message_ids,
-)
-
-SQL_CREATE_TABLE_NODE = """
-CREATE TABLE IF NOT EXISTS node(
-    node_id INTEGER UNIQUE,
-    owner_aid TEXT,
-    owner_name TEXT,
-    status TEXT,
-    registered_at TEXT,
-    last_activated_at TEXT NULL,
-    last_deactivated_at TEXT NULL,
-    unregistered_at TEXT NULL,
-    online_until TIMESTAMP NULL,
-    heartbeat_interval REAL,
-    public_key BLOB UNIQUE
-);
-"""
-
-SQL_CREATE_INDEX_ONLINE_UNTIL = """
-CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
-"""
-
-SQL_CREATE_INDEX_OWNER_AID = """
-CREATE INDEX IF NOT EXISTS idx_node_owner_aid ON node(owner_aid);
-"""
-
-SQL_CREATE_INDEX_NODE_STATUS = """
-CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
-"""
-
-SQL_CREATE_TABLE_RUN = """
-CREATE TABLE IF NOT EXISTS run(
-    run_id INTEGER UNIQUE,
-    fab_id TEXT,
-    fab_version TEXT,
-    fab_hash TEXT,
-    override_config TEXT,
-    pending_at TEXT,
-    starting_at TEXT,
-    running_at TEXT,
-    finished_at TEXT,
-    sub_status TEXT,
-    details TEXT,
-    federation TEXT,
-    federation_options BLOB,
-    flwr_aid TEXT,
-    bytes_sent INTEGER DEFAULT 0,
-    bytes_recv INTEGER DEFAULT 0,
-    clientapp_runtime REAL DEFAULT 0.0
-);
-"""
-
-SQL_CREATE_TABLE_LOGS = """
-CREATE TABLE IF NOT EXISTS logs (
-    timestamp REAL,
-    run_id INTEGER,
-    node_id INTEGER,
-    log TEXT,
-    PRIMARY KEY (timestamp, run_id, node_id),
-    FOREIGN KEY (run_id) REFERENCES run(run_id)
-);
-"""
-
-SQL_CREATE_TABLE_CONTEXT = """
-CREATE TABLE IF NOT EXISTS context(
-    run_id INTEGER UNIQUE,
-    context BLOB,
-    FOREIGN KEY(run_id) REFERENCES run(run_id)
-);
-"""
-
-SQL_CREATE_TABLE_MESSAGE_INS = """
-CREATE TABLE IF NOT EXISTS message_ins(
-    message_id TEXT UNIQUE,
-    group_id TEXT,
-    run_id INTEGER,
-    src_node_id INTEGER,
-    dst_node_id INTEGER,
-    reply_to_message_id TEXT,
-    created_at REAL,
-    delivered_at TEXT,
-    ttl REAL,
-    message_type TEXT,
-    content BLOB NULL,
-    error BLOB NULL,
-    FOREIGN KEY(run_id) REFERENCES run(run_id)
-);
-"""
-
-
-SQL_CREATE_TABLE_MESSAGE_RES = """
-CREATE TABLE IF NOT EXISTS message_res(
-    message_id TEXT UNIQUE,
-    group_id TEXT,
-    run_id INTEGER,
-    src_node_id INTEGER,
-    dst_node_id INTEGER,
-    reply_to_message_id TEXT,
-    created_at REAL,
-    delivered_at TEXT,
-    ttl REAL,
-    message_type TEXT,
-    content BLOB NULL,
-    error BLOB NULL,
-    FOREIGN KEY(run_id) REFERENCES run(run_id)
-);
-"""
-
-
-class SqliteLinkState(LinkState, SqliteCoreState):  # pylint: disable=R0904
-    """SQLite-based LinkState implementation."""
-
-    def __init__(
-        self,
-        database_path: str,
-        federation_manager: FederationManager,
-        object_store: ObjectStore,
-    ) -> None:
-        super().__init__(database_path, object_store)
-        federation_manager.linkstate = self
-        self._federation_manager = federation_manager
-
-    def get_sql_statements(self) -> tuple[str, ...]:
-        """Return SQL statements for LinkState tables."""
-        return super().get_sql_statements() + (
-            SQL_CREATE_TABLE_RUN,
-            SQL_CREATE_TABLE_LOGS,
-            SQL_CREATE_TABLE_CONTEXT,
-            SQL_CREATE_TABLE_MESSAGE_INS,
-            SQL_CREATE_TABLE_MESSAGE_RES,
-            SQL_CREATE_TABLE_NODE,
-            SQL_CREATE_INDEX_ONLINE_UNTIL,
-            SQL_CREATE_INDEX_OWNER_AID,
-            SQL_CREATE_INDEX_NODE_STATUS,
-        )
-
-    @property
-    def federation_manager(self) -> FederationManager:
-        """Get the FederationManager instance."""
-        return self._federation_manager
-
-    def store_message_ins(self, message: Message) -> str | None:
-        """Store one Message."""
-        # Validate message
-        errors = validate_message(message=message, is_reply_message=False)
-        if any(errors):
-            log(ERROR, errors)
-            return None
-
-        # Store Message
-        data = (message_to_dict(message),)
-
-        # Convert values from uint64 to sint64 for SQLite
-        convert_uint64_values_in_dict_to_sint64(
-            data[0], ["run_id", "src_node_id", "dst_node_id"]
-        )
-
-        # Validate source node ID
-        if message.metadata.src_node_id != SUPERLINK_NODE_ID:
-            log(
-                ERROR,
-                "Invalid source node ID for Message: %s",
-                message.metadata.src_node_id,
-            )
-            return None
-
-        with self.conn:
-            # Validate run_id
-            query = "SELECT federation FROM run WHERE run_id = ?;"
-            rows = self.conn.execute(query, (data[0]["run_id"],)).fetchall()
-            if not rows:
-                log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
-                return None
-            federation: str = rows[0]["federation"]
-
-            # Validate destination node ID
-            query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
-            rows = self.conn.execute(
-                query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
-            ).fetchall()
-            if not rows or not self.federation_manager.has_node(
-                message.metadata.dst_node_id, federation
-            ):
-                log(
-                    ERROR,
-                    "Invalid destination node ID for Message: %s",
-                    message.metadata.dst_node_id,
-                )
-                return None
-
-            columns = ", ".join([f":{key}" for key in data[0]])
-            query = f"INSERT INTO message_ins VALUES({columns});"
-
-            # Only invalid run_id can trigger IntegrityError.
-            # This may need to be changed in the future version
-            # with more integrity checks.
-            self.conn.execute(query, data[0])
-
-        return message.metadata.message_id
-
-    def _check_stored_messages(self, message_ids: set[str]) -> None:
-        """Check and delete the message if it's invalid."""
-        if not message_ids:
-            return
-
-        with self.conn:
-            invalid_msg_ids: set[str] = set()
-            current_time = now().timestamp()
-
-            for msg_id in message_ids:
-                # Check if message exists
-                query = "SELECT * FROM message_ins WHERE message_id = ?;"
-                message_row = self.conn.execute(query, (msg_id,)).fetchone()
-                if not message_row:
-                    continue
-
-                # Check if the message has expired
-                available_until = message_row["created_at"] + message_row["ttl"]
-                if available_until <= current_time:
-                    invalid_msg_ids.add(msg_id)
-                    continue
-
-                # Check if src_node_id and dst_node_id are in the federation
-                # Get federation from run table
-                run_id = message_row["run_id"]
-                query = "SELECT federation FROM run WHERE run_id = ?;"
-                run_row = self.conn.execute(query, (run_id,)).fetchone()
-                if not run_row:  # This should not happen
-                    invalid_msg_ids.add(msg_id)
-                    continue
-                federation = run_row["federation"]
-
-                # Convert sint64 to uint64 for node IDs
-                src_node_id = int64_to_uint64(message_row["src_node_id"])
-                dst_node_id = int64_to_uint64(message_row["dst_node_id"])
-
-                # Filter nodes to check if they're in the federation
-                filtered = self.federation_manager.filter_nodes(
-                    {src_node_id, dst_node_id}, federation
-                )
-                if len(filtered) != 2:  # Not both nodes are in the federation
-                    invalid_msg_ids.add(msg_id)
-
-            # Delete all invalid messages
-            self.delete_messages(invalid_msg_ids)
-
-    def get_message_ins(self, node_id: int, limit: int | None) -> list[Message]:
-        """Get all Messages that have not been delivered yet."""
-        if limit is not None and limit < 1:
-            raise AssertionError("`limit` must be >= 1")
-
-        if node_id == SUPERLINK_NODE_ID:
-            msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
-            raise AssertionError(msg)
-
-        data: dict[str, str | int] = {}
-
-        # Convert the uint64 value to sint64 for SQLite
-        data["node_id"] = uint64_to_int64(node_id)
-
-        with self.conn:
-            # Retrieve all Messages for node_id
-            query = """
-                SELECT message_id
-                FROM message_ins
-                WHERE dst_node_id == :node_id
-                AND delivered_at = ""
-                AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
-            """
-
-            if limit is not None:
-                query += " LIMIT :limit"
-                data["limit"] = limit
-
-            query += ";"
-
-            rows = self.conn.execute(query, data).fetchall()
-            message_ids: set[str] = {row["message_id"] for row in rows}
-            self._check_stored_messages(message_ids)
-
-            # Mark retrieved Messages as delivered
-            if rows:
-                # Prepare query
-                placeholders: str = ",".join(
-                    [f":id_{i}" for i in range(len(message_ids))]
-                )
-                query = f"""
-                    UPDATE message_ins
-                    SET delivered_at = :delivered_at
-                    WHERE message_id IN ({placeholders})
-                    RETURNING *;
-                """
-
-                # Prepare data for query
-                delivered_at = now().isoformat()
-                data = {"delivered_at": delivered_at}
-                for index, msg_id in enumerate(message_ids):
-                    data[f"id_{index}"] = str(msg_id)
-
-                # Run query
-                rows = self.conn.execute(query, data).fetchall()
-
-        for row in rows:
-            # Convert values from sint64 to uint64
-            convert_sint64_values_in_dict_to_uint64(
-                row, ["run_id", "src_node_id", "dst_node_id"]
-            )
-
-        result = [dict_to_message(row) for row in rows]
-
-        return result
-
-    def store_message_res(self, message: Message) -> str | None:
-        """Store one Message."""
-        # Validate message
-        errors = validate_message(message=message, is_reply_message=True)
-        if any(errors):
-            log(ERROR, errors)
-            return None
-
-        res_metadata = message.metadata
-        msg_ins_id = res_metadata.reply_to_message_id
-        msg_ins = self.get_valid_message_ins(msg_ins_id)
-        if msg_ins is None:
-            log(
-                ERROR,
-                "Failed to store Message reply: "
-                "The message it replies to with message_id %s does not exist or "
-                "has expired, or was deleted because the target SuperNode was "
-                "removed from the federation.",
-                msg_ins_id,
-            )
-            return None
-
-        # Ensure that the dst_node_id of the original message matches the src_node_id of
-        # reply being processed.
-        if (
-            msg_ins
-            and message
-            and int64_to_uint64(msg_ins["dst_node_id"]) != res_metadata.src_node_id
-        ):
-            return None
-
-        # Fail if the Message TTL exceeds the
-        # expiration time of the Message it replies to.
-        # Condition: ins_metadata.created_at + ins_metadata.ttl ≥
-        #            res_metadata.created_at + res_metadata.ttl
-        # A small tolerance is introduced to account
-        # for floating-point precision issues.
-        max_allowed_ttl = (
-            msg_ins["created_at"] + msg_ins["ttl"] - res_metadata.created_at
-        )
-        if res_metadata.ttl and (
-            res_metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
-        ):
-            log(
-                WARNING,
-                "Received Message with TTL %.2f exceeding the allowed maximum "
-                "TTL %.2f.",
-                res_metadata.ttl,
-                max_allowed_ttl,
-            )
-            return None
-
-        # Store Message
-        data = (message_to_dict(message),)
-
-        # Convert values from uint64 to sint64 for SQLite
-        convert_uint64_values_in_dict_to_sint64(
-            data[0], ["run_id", "src_node_id", "dst_node_id"]
-        )
-
-        columns = ", ".join([f":{key}" for key in data[0]])
-        query = f"INSERT INTO message_res VALUES({columns});"
-
-        # Only invalid run_id can trigger IntegrityError.
-        # This may need to be changed in the future version with more integrity checks.
-        try:
-            self.query(query, data)
-        except sqlite3.IntegrityError:
-            log(ERROR, "`run` is invalid")
-            return None
-
-        return message.metadata.message_id
-
-    def get_message_res(self, message_ids: set[str]) -> list[Message]:
-        """Get reply Messages for the given Message IDs."""
-        # pylint: disable-msg=too-many-locals
-        ret: dict[str, Message] = {}
-
-        with self.conn:
-            # Verify Message IDs
-            self._check_stored_messages(message_ids)
-            current = now().timestamp()
-            query = f"""
-                SELECT *
-                FROM message_ins
-                WHERE message_id IN ({','.join(['?'] * len(message_ids))});
-            """
-            rows = self.conn.execute(
-                query, tuple(str(message_id) for message_id in message_ids)
-            ).fetchall()
-            found_message_ins_dict: dict[str, Message] = {}
-            for row in rows:
-                convert_sint64_values_in_dict_to_uint64(
-                    row, ["run_id", "src_node_id", "dst_node_id"]
-                )
-                found_message_ins_dict[row["message_id"]] = dict_to_message(row)
-
-            ret = verify_message_ids(
-                inquired_message_ids=message_ids,
-                found_message_ins_dict=found_message_ins_dict,
-                current_time=current,
-            )
-
-            # Check node availability
-            dst_node_ids: set[int] = set()
-            for message_id in message_ids:
-                in_message = found_message_ins_dict[message_id]
-                sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
-                dst_node_ids.add(sint_node_id)
-            query = f"""
-                SELECT node_id, online_until
-                FROM node
-                WHERE node_id IN ({','.join(['?'] * len(dst_node_ids))})
-                AND status != ?
-            """
-            rows = self.conn.execute(
-                query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,)
-            ).fetchall()
-            tmp_ret_dict = check_node_availability_for_in_message(
-                inquired_in_message_ids=message_ids,
-                found_in_message_dict=found_message_ins_dict,
-                node_id_to_online_until={
-                    int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
-                },
-                current_time=current,
-            )
-            ret.update(tmp_ret_dict)
-
-            # Find all reply Messages
-            query = f"""
-                SELECT *
-                FROM message_res
-                WHERE reply_to_message_id IN ({','.join(['?'] * len(message_ids))})
-                AND delivered_at = "";
-            """
-            rows = self.conn.execute(
-                query, tuple(str(message_id) for message_id in message_ids)
-            ).fetchall()
-            for row in rows:
-                convert_sint64_values_in_dict_to_uint64(
-                    row, ["run_id", "src_node_id", "dst_node_id"]
-                )
-            tmp_ret_dict = verify_found_message_replies(
-                inquired_message_ids=message_ids,
-                found_message_ins_dict=found_message_ins_dict,
-                found_message_res_list=[dict_to_message(row) for row in rows],
-                current_time=current,
-            )
-            ret.update(tmp_ret_dict)
-
-            # Mark existing reply Messages to be returned as delivered
-            delivered_at = now().isoformat()
-            for message_res in ret.values():
-                message_res.metadata.delivered_at = delivered_at
-            message_res_ids = [
-                message_res.metadata.message_id for message_res in ret.values()
-            ]
-            query = f"""
-                UPDATE message_res
-                SET delivered_at = ?
-                WHERE message_id IN ({','.join(['?'] * len(message_res_ids))});
-            """
-            data: list[Any] = [delivered_at] + message_res_ids
-            self.conn.execute(query, data)
-
-        return list(ret.values())
-
-    def num_message_ins(self) -> int:
-        """Calculate the number of instruction Messages in store.
-
-        This includes delivered but not yet deleted.
-        """
-        query = "SELECT count(*) AS num FROM message_ins;"
-        rows = self.query(query)
-        result = rows[0]
-        num = cast(int, result["num"])
-        return num
-
-    def num_message_res(self) -> int:
-        """Calculate the number of reply Messages in store.
-
-        This includes delivered but not yet deleted.
-        """
-        query = "SELECT count(*) AS num FROM message_res;"
-        rows = self.query(query)
-        result: dict[str, int] = rows[0]
-        return result["num"]
-
-    def delete_messages(self, message_ins_ids: set[str]) -> None:
-        """Delete a Message and its reply based on provided Message IDs."""
-        if not message_ins_ids:
-            return
-        if self.conn is None:
-            raise AttributeError("LinkState not initialized")
-
-        placeholders = ",".join(["?"] * len(message_ins_ids))
-        data = tuple(str(message_id) for message_id in message_ins_ids)
-
-        # Delete Message
-        query_1 = f"""
-            DELETE FROM message_ins
-            WHERE message_id IN ({placeholders});
-        """
-
-        # Delete reply Message
-        query_2 = f"""
-            DELETE FROM message_res
-            WHERE reply_to_message_id IN ({placeholders});
-        """
-
-        with self.conn:
-            self.conn.execute(query_1, data)
-            self.conn.execute(query_2, data)
-
-    def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
-        """Get all instruction Message IDs for the given run_id."""
-        if self.conn is None:
-            raise AttributeError("LinkState not initialized")
-
-        query = """
-            SELECT message_id
-            FROM message_ins
-            WHERE run_id = :run_id;
-        """
-
-        sint64_run_id = uint64_to_int64(run_id)
-        data = {"run_id": sint64_run_id}
-
-        with self.conn:
-            rows = self.conn.execute(query, data).fetchall()
-
-        return {row["message_id"] for row in rows}
-
-    def create_node(
-        self,
-        owner_aid: str,
-        owner_name: str,
-        public_key: bytes,
-        heartbeat_interval: float,
-    ) -> int:
-        """Create, store in the link state, and return `node_id`."""
-        # Sample a random uint64 as node_id
-        uint64_node_id = generate_rand_int_from_bytes(
-            NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
-        )
-
-        # Convert the uint64 value to sint64 for SQLite
-        sint64_node_id = uint64_to_int64(uint64_node_id)
-
-        query = """
-            INSERT INTO node
-            (node_id, owner_aid, owner_name, status, registered_at, last_activated_at,
-            last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
-            public_key)
-            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
-        """
-
-        # Mark the node online until now().timestamp() + heartbeat_interval
-        try:
-            self.query(
-                query,
-                (
-                    sint64_node_id,  # node_id
-                    owner_aid,  # owner_aid
-                    owner_name,  # owner_name
-                    NodeStatus.REGISTERED,  # status
-                    now().isoformat(),  # registered_at
-                    None,  # last_activated_at
-                    None,  # last_deactivated_at
-                    None,  # unregistered_at
-                    None,  # online_until, initialized with offline status
-                    heartbeat_interval,  # heartbeat_interval
-                    public_key,  # public_key
-                ),
-            )
-        except sqlite3.IntegrityError as e:
-            if "UNIQUE constraint failed: node.public_key" in str(e):
-                raise ValueError("Public key already in use.") from None
-            # Must be node ID conflict, almost impossible unless system is compromised
-            log(ERROR, "Unexpected node registration failure.")
-            return 0
-
-        # Note: we need to return the uint64 value of the node_id
-        return uint64_node_id
-
-    def delete_node(self, owner_aid: str, node_id: int) -> None:
-        """Delete a node."""
-        sint64_node_id = uint64_to_int64(node_id)
-
-        query = """
-            UPDATE node
-            SET status = ?, unregistered_at = ?,
-                online_until = IIF(online_until > ?, ?, online_until)
-            WHERE node_id = ? AND status != ? AND owner_aid = ?
-            RETURNING node_id
-        """
-        current = now()
-        params = (
-            NodeStatus.UNREGISTERED,
-            current.isoformat(),
-            current.timestamp(),
-            current.timestamp(),
-            sint64_node_id,
-            NodeStatus.UNREGISTERED,
-            owner_aid,
-        )
-
-        rows = self.query(query, params)
-        if not rows:
-            raise ValueError(
-                f"Node {node_id} already deleted, not found or unauthorized "
-                "deletion attempt."
-            )
-
-    def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
-        """Activate the node with the specified `node_id`."""
-        with self.conn:
-            self._check_and_tag_offline_nodes([node_id])
-
-            # Only activate if the node is currently registered or offline
-            current_dt = now()
-            query = """
-                UPDATE node
-                SET status = ?,
-                    last_activated_at = ?,
-                    online_until = ?,
-                    heartbeat_interval = ?
-                WHERE node_id = ? AND status in (?, ?)
-                RETURNING node_id
-            """
-            params = (
-                NodeStatus.ONLINE,
-                current_dt.isoformat(),
-                current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
-                heartbeat_interval,
-                uint64_to_int64(node_id),
-                NodeStatus.REGISTERED,
-                NodeStatus.OFFLINE,
-            )
-
-            row = self.conn.execute(query, params).fetchone()
-            return row is not None
-
-    def deactivate_node(self, node_id: int) -> bool:
-        """Deactivate the node with the specified `node_id`."""
-        with self.conn:
-            self._check_and_tag_offline_nodes([node_id])
-
-            # Only deactivate if the node is currently online
-            current_dt = now()
-            query = """
-                UPDATE node
-                SET status = ?,
-                    last_deactivated_at = ?,
-                    online_until = ?
-                WHERE node_id = ? AND status = ?
-                RETURNING node_id
-            """
-            params = (
-                NodeStatus.OFFLINE,
-                current_dt.isoformat(),
-                current_dt.timestamp(),
-                uint64_to_int64(node_id),
-                NodeStatus.ONLINE,
-            )
-
-            row = self.conn.execute(query, params).fetchone()
-            return row is not None
-
-    def get_nodes(self, run_id: int) -> set[int]:
-        """Retrieve all currently stored node IDs as a set.
-
-        Constraints
-        -----------
-        If the provided `run_id` does not exist or has no matching nodes,
-        an empty `Set` MUST be returned.
-        """
-        if self.conn is None:
-            raise AttributeError("LinkState not initialized")
-
-        with self.conn:
-            # Convert the uint64 value to sint64 for SQLite
-            sint64_run_id = uint64_to_int64(run_id)
-
-            # Validate run ID
-            query = "SELECT federation FROM run WHERE run_id = ?"
-            rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
-            if not rows:
-                return set()
-            federation: str = rows[0]["federation"]
-
-            # Retrieve all online nodes
-            node_ids = {
-                node.node_id
-                for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
-            }
-            # Filter node IDs by federation
-            return self.federation_manager.filter_nodes(node_ids, federation)
-
-    def _check_and_tag_offline_nodes(self, node_ids: list[int] | None = None) -> None:
-        """Check and tag offline nodes."""
-        # strftime will convert POSIX timestamp to ISO format
-        query = """
-            UPDATE node SET status = ?,
-                last_deactivated_at =
-                    strftime("%Y-%m-%dT%H:%M:%f+00:00", online_until, "unixepoch")
-            WHERE online_until <= ? AND status == ?
-        """
-        params = [
-            NodeStatus.OFFLINE,
-            now().timestamp(),
-            NodeStatus.ONLINE,
-        ]
-        if node_ids is not None:
-            placeholders = ",".join(["?"] * len(node_ids))
-            query += f" AND node_id IN ({placeholders})"
-            params.extend(uint64_to_int64(node_id) for node_id in node_ids)
-        self.conn.execute(query, params)
-
-    def get_node_info(
-        self,
-        *,
-        node_ids: Sequence[int] | None = None,
-        owner_aids: Sequence[str] | None = None,
-        statuses: Sequence[str] | None = None,
-    ) -> Sequence[NodeInfo]:
-        """Retrieve information about nodes based on the specified filters."""
-        with self.conn:
-            self._check_and_tag_offline_nodes()
-
-            # Build the WHERE clause based on provided filters
-            conditions = []
-            params: list[Any] = []
-            if node_ids is not None:
-                sint64_node_ids = [uint64_to_int64(node_id) for node_id in node_ids]
-                placeholders = ",".join(["?"] * len(sint64_node_ids))
-                conditions.append(f"node_id IN ({placeholders})")
-                params.extend(sint64_node_ids)
-            if owner_aids is not None:
-                placeholders = ",".join(["?"] * len(owner_aids))
-                conditions.append(f"owner_aid IN ({placeholders})")
-                params.extend(owner_aids)
-            if statuses is not None:
-                placeholders = ",".join(["?"] * len(statuses))
-                conditions.append(f"status IN ({placeholders})")
-                params.extend(statuses)
-
-            # Construct the final query
-            query = "SELECT * FROM node"
-            if conditions:
-                query += " WHERE " + " AND ".join(conditions)
-
-            rows = self.conn.execute(query, params).fetchall()
-
-        result: list[NodeInfo] = []
-        for row in rows:
-            # Convert sint64 node_id to uint64
-            row["node_id"] = int64_to_uint64(row["node_id"])
-            result.append(NodeInfo(**row))
-
-        return result
-
-    def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
-        """Get `node_id` for the specified `public_key` if it exists and is not
-        deleted."""
-        query = "SELECT node_id FROM node WHERE public_key = ? AND status != ?;"
-        rows = self.query(query, (public_key, NodeStatus.UNREGISTERED))
-
-        # If no result is found, return None
-        if not rows:
-            return None
-
-        # Convert sint64 node_id to uint64
-        node_id = int64_to_uint64(rows[0]["node_id"])
-        return node_id
-
-    # pylint: disable=too-many-arguments,too-many-positional-arguments
-    def create_run(
-        self,
-        fab_id: str | None,
-        fab_version: str | None,
-        fab_hash: str | None,
-        override_config: UserConfig,
-        federation: str,
-        federation_options: ConfigRecord,
-        flwr_aid: str | None,
-    ) -> int:
-        """Create a new run."""
-        # Sample a random int64 as run_id
-        uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
-
-        # Convert the uint64 value to sint64 for SQLite
-        sint64_run_id = uint64_to_int64(uint64_run_id)
-
-        with self.conn:
-            # Check conflicts
-            query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
-            # If sint64_run_id does not exist
-            row = self.conn.execute(query, (sint64_run_id,)).fetchone()
-            if row["COUNT(*)"] == 0:
-                query = """
-                    INSERT INTO run
-                    (run_id, fab_id, fab_version,
-                    fab_hash, override_config, federation, federation_options,
-                    pending_at, starting_at, running_at, finished_at, sub_status,
-                    details, flwr_aid, bytes_sent, bytes_recv, clientapp_runtime)
-                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
-                """
-                override_config_json = json.dumps(override_config)
-                data = [
-                    sint64_run_id,  # run_id
-                    fab_id,  # fab_id
-                    fab_version,  # fab_version
-                    fab_hash,  # fab_hash
-                    override_config_json,  # override_config
-                    federation,  # federation
-                    configrecord_to_bytes(federation_options),  # federation_options
-                    now().isoformat(),  # pending_at
-                    "",  # starting_at
-                    "",  # running_at
-                    "",  # finished_at
-                    "",  # sub_status
-                    "",  # details
-                    flwr_aid or "",  # flwr_aid
-                    0,  # bytes_sent
-                    0,  # bytes_recv
-                    0,  # clientapp_runtime
-                ]
-                self.conn.execute(query, tuple(data))
-                return uint64_run_id
-        log(ERROR, "Unexpected run creation failure.")
-        return 0
-
-    def get_run_ids(self, flwr_aid: str | None) -> set[int]:
-        """Retrieve all run IDs if `flwr_aid` is not specified.
-
-        Otherwise, retrieve all run IDs for the specified `flwr_aid`.
-        """
-        if flwr_aid:
-            rows = self.query(
-                "SELECT run_id FROM run WHERE flwr_aid = ?;",
-                (flwr_aid,),
-            )
-        else:
-            rows = self.query("SELECT run_id FROM run;", ())
-        return {int64_to_uint64(row["run_id"]) for row in rows}
-
-    def get_run(self, run_id: int) -> Run | None:
-        """Retrieve information about the run with the specified `run_id`."""
-        # Clean up expired tokens; this will flag inactive runs as needed
-        self._cleanup_expired_tokens()
-
-        # Convert the uint64 value to sint64 for SQLite
-        sint64_run_id = uint64_to_int64(run_id)
-        query = "SELECT * FROM run WHERE run_id = ?;"
-        rows = self.query(query, (sint64_run_id,))
-        if rows:
-            row = rows[0]
-            return Run(
-                run_id=int64_to_uint64(row["run_id"]),
-                fab_id=row["fab_id"],
-                fab_version=row["fab_version"],
-                fab_hash=row["fab_hash"],
-                override_config=json.loads(row["override_config"]),
-                pending_at=row["pending_at"],
-                starting_at=row["starting_at"],
-                running_at=row["running_at"],
-                finished_at=row["finished_at"],
-                status=RunStatus(
-                    status=determine_run_status(row),
-                    sub_status=row["sub_status"],
-                    details=row["details"],
-                ),
-                flwr_aid=row["flwr_aid"],
-                federation=row["federation"],
-                bytes_sent=row["bytes_sent"],
-                bytes_recv=row["bytes_recv"],
-                clientapp_runtime=row["clientapp_runtime"],
-            )
-        log(ERROR, "`run_id` does not exist.")
-        return None
-
-    def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
-        """Retrieve the statuses for the specified runs."""
-        # Clean up expired tokens; this will flag inactive runs as needed
-        self._cleanup_expired_tokens()
-
-        # Convert the uint64 value to sint64 for SQLite
-        sint64_run_ids = (uint64_to_int64(run_id) for run_id in set(run_ids))
-        query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
-        rows = self.query(query, tuple(sint64_run_ids))
-
-        return {
-            # Restore uint64 run IDs
-            int64_to_uint64(row["run_id"]): RunStatus(
-                status=determine_run_status(row),
-                sub_status=row["sub_status"],
-                details=row["details"],
-            )
-            for row in rows
-        }
-
-    def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
-        """Update the status of the run with the specified `run_id`."""
-        # Clean up expired tokens; this will flag inactive runs as needed
-        self._cleanup_expired_tokens()
-
-        with self.conn:
-            # Convert the uint64 value to sint64 for SQLite
-            sint64_run_id = uint64_to_int64(run_id)
-            query = "SELECT * FROM run WHERE run_id = ?;"
-            rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
-
-            # Check if the run_id exists
-            if not rows:
-                log(ERROR, "`run_id` is invalid")
-                return False
-
-            # Check if the status transition is valid
-            row = rows[0]
-            current_status = RunStatus(
-                status=determine_run_status(row),
-                sub_status=row["sub_status"],
-                details=row["details"],
-            )
-            if not is_valid_transition(current_status, new_status):
-                log(
-                    ERROR,
-                    'Invalid status transition: from "%s" to "%s"',
-                    current_status.status,
-                    new_status.status,
-                )
-                return False
-
-            # Check if the sub-status is valid
-            if not has_valid_sub_status(current_status):
-                log(
-                    ERROR,
-                    'Invalid sub-status "%s" for status "%s"',
-                    current_status.sub_status,
-                    current_status.status,
-                )
-                return False
-
-            # Update the status
-            query = """
-                UPDATE run SET %s= ?, sub_status = ?, details = ? WHERE run_id = ?;
-            """
-
-            # Prepare data for query
-            current = now()
-
-            # Determine the timestamp field based on the new status
-            timestamp_fld = ""
-            if new_status.status == Status.STARTING:
-                timestamp_fld = "starting_at"
-            elif new_status.status == Status.RUNNING:
-                timestamp_fld = "running_at"
-            elif new_status.status == Status.FINISHED:
-                timestamp_fld = "finished_at"
-
-            data = (
-                current.isoformat(),
-                new_status.sub_status,
-                new_status.details,
-                uint64_to_int64(run_id),
-            )
-            self.conn.execute(query % timestamp_fld, data)
-            return True
-
-    def get_pending_run_id(self) -> int | None:
-        """Get the `run_id` of a run with `Status.PENDING` status, if any."""
-        pending_run_id = None
-
-        # Fetch all runs with unset `starting_at` (i.e. they are in PENDING status)
-        query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
-        rows = self.query(query)
-        if rows:
-            pending_run_id = int64_to_uint64(rows[0]["run_id"])
-
-        return pending_run_id
-
-    def get_federation_options(self, run_id: int) -> ConfigRecord | None:
-        """Retrieve the federation options for the specified `run_id`."""
-        # Convert the uint64 value to sint64 for SQLite
-        sint64_run_id = uint64_to_int64(run_id)
-        query = "SELECT federation_options FROM run WHERE run_id = ?;"
-        rows = self.query(query, (sint64_run_id,))
-
-        # Check if the run_id exists
-        if not rows:
-            log(ERROR, "`run_id` is invalid")
-            return None
-
-        row = rows[0]
-        return configrecord_from_bytes(row["federation_options"])
-
-    def acknowledge_node_heartbeat(
-        self, node_id: int, heartbeat_interval: float
-    ) -> bool:
-        """Acknowledge a heartbeat received from a node, serving as a heartbeat.
-
-        A node is considered online as long as it sends heartbeats within
-        the tolerated interval: HEARTBEAT_PATIENCE × heartbeat_interval.
-        HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
-        the node is marked as offline.
-        """
-        if self.conn is None:
-            raise AttributeError("LinkState not initialized")
-
-        sint64_node_id = uint64_to_int64(node_id)
-
-        with self.conn:
-            # Check if node exists and not deleted
-            query = "SELECT status FROM node WHERE node_id = ? AND status != ?"
-            row = self.conn.execute(
-                query, (sint64_node_id, NodeStatus.UNREGISTERED)
-            ).fetchone()
-            if row is None:
-                return False
-
-            # Construct query and params
-            current_dt = now()
-            query = "UPDATE node SET online_until = ?, heartbeat_interval = ?"
-            params: list[Any] = [
-                current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
-                heartbeat_interval,
-            ]
-
-            # Set timestamp if the status changes
-            if row["status"] != NodeStatus.ONLINE:
-                query += ", status = ?, last_activated_at = ?"
-                params += [NodeStatus.ONLINE, current_dt.isoformat()]
-
-            # Execute the query, refreshing `online_until` and `heartbeat_interval`
-            query += " WHERE node_id = ?"
-            params += [sint64_node_id]
-            self.conn.execute(query, params)
-            return True
-
-    def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
-        """Transition runs with expired tokens to failed status.
-
-        Parameters
-        ----------
-        expired_records : list[tuple[int, float]]
-            List of tuples containing (run_id, active_until timestamp)
-            for expired tokens.
-        """
-        if not expired_records:
-            return
-
-        with self.conn:
-            query = """
-                UPDATE run
-                SET sub_status = ?, details = ?, finished_at = ?
-                WHERE run_id = ?;
-            """
-            data = [
-                (
-                    SubStatus.FAILED,
-                    RUN_FAILURE_DETAILS_NO_HEARTBEAT,
-                    datetime.fromtimestamp(active_until, tz=timezone.utc).isoformat(),
-                    uint64_to_int64(run_id),
-                )
-                for run_id, active_until in expired_records
-            ]
-            self.conn.executemany(query, data)
-
-    def get_serverapp_context(self, run_id: int) -> Context | None:
-        """Get the context for the specified `run_id`."""
-        # Retrieve context if any
-        query = "SELECT context FROM context WHERE run_id = ?;"
-        rows = self.query(query, (uint64_to_int64(run_id),))
-        context = context_from_bytes(rows[0]["context"]) if rows else None
-        return context
-
-    def set_serverapp_context(self, run_id: int, context: Context) -> None:
-        """Set the context for the specified `run_id`."""
-        # Convert context to bytes
-        context_bytes = context_to_bytes(context)
-        sint_run_id = uint64_to_int64(run_id)
-
-        with self.conn:
-            # Check if any existing Context assigned to the run_id
-            query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
-            row = self.conn.execute(query, (sint_run_id,)).fetchone()
-            if row["COUNT(*)"] > 0:
-                # Update context
-                query = "UPDATE context SET context = ? WHERE run_id = ?;"
-                self.conn.execute(query, (context_bytes, sint_run_id))
-            else:
-                try:
-                    # Store context
-                    query = "INSERT INTO context (run_id, context) VALUES (?, ?);"
-                    self.conn.execute(query, (sint_run_id, context_bytes))
-                except sqlite3.IntegrityError:
-                    raise ValueError(f"Run {run_id} not found") from None
-
-    def add_serverapp_log(self, run_id: int, log_message: str) -> None:
-        """Add a log entry to the ServerApp logs for the specified `run_id`."""
-        # Convert the uint64 value to sint64 for SQLite
-        sint64_run_id = uint64_to_int64(run_id)
-
-        # Store log
-        try:
-            query = """
-                INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?);
-            """
-            self.query(query, (now().timestamp(), sint64_run_id, 0, log_message))
-        except sqlite3.IntegrityError:
-            raise ValueError(f"Run {run_id} not found") from None
-
-    def get_serverapp_log(
-        self, run_id: int, after_timestamp: float | None
-    ) -> tuple[str, float]:
-        """Get the ServerApp logs for the specified `run_id`."""
-        # Convert the uint64 value to sint64 for SQLite
-        sint64_run_id = uint64_to_int64(run_id)
-
-        with self.conn:
-            # Check if the run_id exists
-            query = "SELECT run_id FROM run WHERE run_id = ?;"
-            rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
-            if not rows:
-                raise ValueError(f"Run {run_id} not found")
-
-            # Retrieve logs
-            if after_timestamp is None:
-                after_timestamp = 0.0
-            query = """
-                SELECT log, timestamp FROM logs
-                WHERE run_id = ? AND node_id = ? AND timestamp > ?;
-            """
-            rows = self.conn.execute(
-                query, (sint64_run_id, 0, after_timestamp)
-            ).fetchall()
-            rows.sort(key=lambda x: x["timestamp"])
-            latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
-            return "".join(row["log"] for row in rows), latest_timestamp
-
-    def get_valid_message_ins(self, message_id: str) -> dict[str, Any] | None:
-        """Check if the Message exists and is valid (not expired).
-
-        Return Message if valid.
-        """
-        with self.conn:
-            self._check_stored_messages({message_id})
-            query = """
-                SELECT *
-                FROM message_ins
-                WHERE message_id = :message_id
-            """
-            data = {"message_id": message_id}
-            rows: list[dict[str, Any]] = self.conn.execute(query, data).fetchall()
-            if not rows:
-                # Message does not exist
-                return None
-
-        return rows[0]
-
-    def store_traffic(self, run_id: int, *, bytes_sent: int, bytes_recv: int) -> None:
-        """Store traffic data for the specified `run_id`."""
-        # Validate non-negative values
-        if bytes_sent < 0 or bytes_recv < 0:
-            raise ValueError(
-                f"Negative traffic values for run {run_id}: "
-                f"bytes_sent={bytes_sent}, bytes_recv={bytes_recv}"
-            )
-
-        if bytes_sent == 0 and bytes_recv == 0:
-            raise ValueError(
-                f"Both bytes_sent and bytes_recv cannot be zero for run {run_id}"
-            )
-
-        sint64_run_id = uint64_to_int64(run_id)
-
-        with self.conn:
-            # Check if run exists, performing the update only if it does
-            update_query = """
-                UPDATE run
-                SET bytes_sent = bytes_sent + ?,
-                    bytes_recv = bytes_recv + ?
-                WHERE run_id = ?
-                RETURNING run_id;
-            """
-            rows = self.conn.execute(
-                update_query, (bytes_sent, bytes_recv, sint64_run_id)
-            ).fetchall()
-
-            if not rows:
-                raise ValueError(f"Run {run_id} not found")
-
-    def add_clientapp_runtime(self, run_id: int, runtime: float) -> None:
-        """Add ClientApp runtime to the cumulative total for the specified `run_id`."""
-        sint64_run_id = uint64_to_int64(run_id)
-        with self.conn:
-            # Check if run exists, performing the update only if it does
-            update_query = """
-                UPDATE run
-                SET clientapp_runtime = clientapp_runtime + ?
-                WHERE run_id = ?
-                RETURNING run_id;
-            """
-            rows = self.conn.execute(update_query, (runtime, sint64_run_id)).fetchall()
-
-            if not rows:
-                raise ValueError(f"Run {run_id} not found")
-
-
-def determine_run_status(row: dict[str, Any]) -> str:
-    """Determine the status of the run based on timestamp fields."""
-    if row["pending_at"]:
-        if row["finished_at"]:
-            return Status.FINISHED
-        if row["starting_at"]:
-            if row["running_at"]:
-                return Status.RUNNING
-            return Status.STARTING
-        return Status.PENDING
-    run_id = int64_to_uint64(row["run_id"])
-    raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")