flwr-nightly 1.26.0.dev20260122__py3-none-any.whl → 1.26.0.dev20260126__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. flwr/cli/app_cmd/publish.py +18 -44
  2. flwr/cli/app_cmd/review.py +8 -25
  3. flwr/cli/auth_plugin/oidc_cli_plugin.py +3 -6
  4. flwr/cli/build.py +8 -19
  5. flwr/cli/config/ls.py +8 -13
  6. flwr/cli/config_utils.py +19 -171
  7. flwr/cli/federation/ls.py +3 -7
  8. flwr/cli/flower_config.py +28 -47
  9. flwr/cli/install.py +18 -57
  10. flwr/cli/log.py +2 -2
  11. flwr/cli/login/login.py +8 -21
  12. flwr/cli/ls.py +3 -7
  13. flwr/cli/new/new.py +9 -29
  14. flwr/cli/pull.py +3 -7
  15. flwr/cli/run/run.py +6 -15
  16. flwr/cli/stop.py +5 -17
  17. flwr/cli/supernode/register.py +6 -22
  18. flwr/cli/supernode/unregister.py +3 -13
  19. flwr/cli/utils.py +66 -169
  20. flwr/common/config.py +5 -9
  21. flwr/common/constant.py +2 -0
  22. flwr/server/superlink/fleet/message_handler/message_handler.py +4 -4
  23. flwr/server/superlink/linkstate/__init__.py +0 -2
  24. flwr/server/superlink/linkstate/sql_linkstate.py +38 -10
  25. flwr/supercore/object_store/object_store_factory.py +4 -4
  26. flwr/supercore/object_store/sql_object_store.py +171 -6
  27. flwr/superlink/servicer/control/control_servicer.py +11 -12
  28. {flwr_nightly-1.26.0.dev20260122.dist-info → flwr_nightly-1.26.0.dev20260126.dist-info}/METADATA +2 -2
  29. {flwr_nightly-1.26.0.dev20260122.dist-info → flwr_nightly-1.26.0.dev20260126.dist-info}/RECORD +31 -35
  30. flwr/server/superlink/linkstate/sqlite_linkstate.py +0 -1302
  31. flwr/supercore/corestate/sqlite_corestate.py +0 -157
  32. flwr/supercore/object_store/sqlite_object_store.py +0 -253
  33. flwr/supercore/sqlite_mixin.py +0 -156
  34. {flwr_nightly-1.26.0.dev20260122.dist-info → flwr_nightly-1.26.0.dev20260126.dist-info}/WHEEL +0 -0
  35. {flwr_nightly-1.26.0.dev20260122.dist-info → flwr_nightly-1.26.0.dev20260126.dist-info}/entry_points.txt +0 -0
@@ -1,1302 +0,0 @@
1
- # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """SQLite based implemenation of the link state."""
16
-
17
-
18
- # pylint: disable=too-many-lines
19
-
20
- import json
21
- import sqlite3
22
- from collections.abc import Sequence
23
- from datetime import datetime, timezone
24
- from logging import ERROR, WARNING
25
- from typing import Any, cast
26
-
27
- from flwr.app.user_config import UserConfig
28
- from flwr.common import Context, Message, log, now
29
- from flwr.common.constant import (
30
- HEARTBEAT_PATIENCE,
31
- MESSAGE_TTL_TOLERANCE,
32
- NODE_ID_NUM_BYTES,
33
- RUN_FAILURE_DETAILS_NO_HEARTBEAT,
34
- RUN_ID_NUM_BYTES,
35
- SUPERLINK_NODE_ID,
36
- Status,
37
- SubStatus,
38
- )
39
- from flwr.common.record import ConfigRecord
40
- from flwr.common.typing import Run, RunStatus
41
-
42
- # pylint: disable=E0611
43
- from flwr.proto.node_pb2 import NodeInfo
44
-
45
- # pylint: enable=E0611
46
- from flwr.server.utils.validator import validate_message
47
- from flwr.supercore.constant import NodeStatus
48
- from flwr.supercore.corestate.sqlite_corestate import SqliteCoreState
49
- from flwr.supercore.object_store.object_store import ObjectStore
50
- from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
51
- from flwr.superlink.federation import FederationManager
52
-
53
- from .linkstate import LinkState
54
- from .utils import (
55
- check_node_availability_for_in_message,
56
- configrecord_from_bytes,
57
- configrecord_to_bytes,
58
- context_from_bytes,
59
- context_to_bytes,
60
- convert_sint64_values_in_dict_to_uint64,
61
- convert_uint64_values_in_dict_to_sint64,
62
- dict_to_message,
63
- generate_rand_int_from_bytes,
64
- has_valid_sub_status,
65
- is_valid_transition,
66
- message_to_dict,
67
- verify_found_message_replies,
68
- verify_message_ids,
69
- )
70
-
71
# Schema statements returned by `SqliteLinkState.get_sql_statements`
# (presumably executed by the parent class when the database is initialized).
# Every statement uses `IF NOT EXISTS`, so re-running it is harmless.
# Node and run IDs are uint64 values stored as sint64 INTEGERs; the methods of
# `SqliteLinkState` convert them with `uint64_to_int64` / `int64_to_uint64`.

# SuperNodes and their lifecycle. `online_until` holds a POSIX timestamp that
# is compared against `now().timestamp()`; the `*_at` columns hold ISO-8601
# strings produced by `now().isoformat()`.
SQL_CREATE_TABLE_NODE = """
CREATE TABLE IF NOT EXISTS node(
    node_id INTEGER UNIQUE,
    owner_aid TEXT,
    owner_name TEXT,
    status TEXT,
    registered_at TEXT,
    last_activated_at TEXT NULL,
    last_deactivated_at TEXT NULL,
    unregistered_at TEXT NULL,
    online_until TIMESTAMP NULL,
    heartbeat_interval REAL,
    public_key BLOB UNIQUE
);
"""

# Index for queries that filter on `online_until`
# (e.g. the offline-tagging UPDATE).
SQL_CREATE_INDEX_ONLINE_UNTIL = """
CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
"""

# Index for queries that filter nodes by owner.
SQL_CREATE_INDEX_OWNER_AID = """
CREATE INDEX IF NOT EXISTS idx_node_owner_aid ON node(owner_aid);
"""

# Index for queries that filter nodes by status.
SQL_CREATE_INDEX_NODE_STATUS = """
CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
"""

# Runs. Timestamp columns are ISO-8601 strings, or "" when the run has not
# reached that stage yet (see `create_run`). `override_config` is JSON and
# `federation_options` is a serialized ConfigRecord.
SQL_CREATE_TABLE_RUN = """
CREATE TABLE IF NOT EXISTS run(
    run_id INTEGER UNIQUE,
    fab_id TEXT,
    fab_version TEXT,
    fab_hash TEXT,
    override_config TEXT,
    pending_at TEXT,
    starting_at TEXT,
    running_at TEXT,
    finished_at TEXT,
    sub_status TEXT,
    details TEXT,
    federation TEXT,
    federation_options BLOB,
    flwr_aid TEXT,
    bytes_sent INTEGER DEFAULT 0,
    bytes_recv INTEGER DEFAULT 0,
    clientapp_runtime REAL DEFAULT 0.0
);
"""

# Log lines, keyed by (timestamp, run_id, node_id).
SQL_CREATE_TABLE_LOGS = """
CREATE TABLE IF NOT EXISTS logs (
    timestamp REAL,
    run_id INTEGER,
    node_id INTEGER,
    log TEXT,
    PRIMARY KEY (timestamp, run_id, node_id),
    FOREIGN KEY (run_id) REFERENCES run(run_id)
);
"""

# At most one serialized Context per run (`run_id` is UNIQUE).
SQL_CREATE_TABLE_CONTEXT = """
CREATE TABLE IF NOT EXISTS context(
    run_id INTEGER UNIQUE,
    context BLOB,
    FOREIGN KEY(run_id) REFERENCES run(run_id)
);
"""

# Instruction Messages. `delivered_at` is the empty string until the message
# is handed out, then an ISO-8601 string; `created_at + ttl` is compared
# against POSIX timestamps to detect expiry.
SQL_CREATE_TABLE_MESSAGE_INS = """
CREATE TABLE IF NOT EXISTS message_ins(
    message_id TEXT UNIQUE,
    group_id TEXT,
    run_id INTEGER,
    src_node_id INTEGER,
    dst_node_id INTEGER,
    reply_to_message_id TEXT,
    created_at REAL,
    delivered_at TEXT,
    ttl REAL,
    message_type TEXT,
    content BLOB NULL,
    error BLOB NULL,
    FOREIGN KEY(run_id) REFERENCES run(run_id)
);
"""


# Reply Messages; same layout as `message_ins`. `reply_to_message_id` links a
# reply to its instruction Message.
SQL_CREATE_TABLE_MESSAGE_RES = """
CREATE TABLE IF NOT EXISTS message_res(
    message_id TEXT UNIQUE,
    group_id TEXT,
    run_id INTEGER,
    src_node_id INTEGER,
    dst_node_id INTEGER,
    reply_to_message_id TEXT,
    created_at REAL,
    delivered_at TEXT,
    ttl REAL,
    message_type TEXT,
    content BLOB NULL,
    error BLOB NULL,
    FOREIGN KEY(run_id) REFERENCES run(run_id)
);
"""
176
-
177
-
178
- class SqliteLinkState(LinkState, SqliteCoreState): # pylint: disable=R0904
179
- """SQLite-based LinkState implementation."""
180
-
181
- def __init__(
182
- self,
183
- database_path: str,
184
- federation_manager: FederationManager,
185
- object_store: ObjectStore,
186
- ) -> None:
187
- super().__init__(database_path, object_store)
188
- federation_manager.linkstate = self
189
- self._federation_manager = federation_manager
190
-
191
- def get_sql_statements(self) -> tuple[str, ...]:
192
- """Return SQL statements for LinkState tables."""
193
- return super().get_sql_statements() + (
194
- SQL_CREATE_TABLE_RUN,
195
- SQL_CREATE_TABLE_LOGS,
196
- SQL_CREATE_TABLE_CONTEXT,
197
- SQL_CREATE_TABLE_MESSAGE_INS,
198
- SQL_CREATE_TABLE_MESSAGE_RES,
199
- SQL_CREATE_TABLE_NODE,
200
- SQL_CREATE_INDEX_ONLINE_UNTIL,
201
- SQL_CREATE_INDEX_OWNER_AID,
202
- SQL_CREATE_INDEX_NODE_STATUS,
203
- )
204
-
205
- @property
206
- def federation_manager(self) -> FederationManager:
207
- """Get the FederationManager instance."""
208
- return self._federation_manager
209
-
210
- def store_message_ins(self, message: Message) -> str | None:
211
- """Store one Message."""
212
- # Validate message
213
- errors = validate_message(message=message, is_reply_message=False)
214
- if any(errors):
215
- log(ERROR, errors)
216
- return None
217
-
218
- # Store Message
219
- data = (message_to_dict(message),)
220
-
221
- # Convert values from uint64 to sint64 for SQLite
222
- convert_uint64_values_in_dict_to_sint64(
223
- data[0], ["run_id", "src_node_id", "dst_node_id"]
224
- )
225
-
226
- # Validate source node ID
227
- if message.metadata.src_node_id != SUPERLINK_NODE_ID:
228
- log(
229
- ERROR,
230
- "Invalid source node ID for Message: %s",
231
- message.metadata.src_node_id,
232
- )
233
- return None
234
-
235
- with self.conn:
236
- # Validate run_id
237
- query = "SELECT federation FROM run WHERE run_id = ?;"
238
- rows = self.conn.execute(query, (data[0]["run_id"],)).fetchall()
239
- if not rows:
240
- log(ERROR, "Invalid run ID for Message: %s", message.metadata.run_id)
241
- return None
242
- federation: str = rows[0]["federation"]
243
-
244
- # Validate destination node ID
245
- query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
246
- rows = self.conn.execute(
247
- query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
248
- ).fetchall()
249
- if not rows or not self.federation_manager.has_node(
250
- message.metadata.dst_node_id, federation
251
- ):
252
- log(
253
- ERROR,
254
- "Invalid destination node ID for Message: %s",
255
- message.metadata.dst_node_id,
256
- )
257
- return None
258
-
259
- columns = ", ".join([f":{key}" for key in data[0]])
260
- query = f"INSERT INTO message_ins VALUES({columns});"
261
-
262
- # Only invalid run_id can trigger IntegrityError.
263
- # This may need to be changed in the future version
264
- # with more integrity checks.
265
- self.conn.execute(query, data[0])
266
-
267
- return message.metadata.message_id
268
-
269
- def _check_stored_messages(self, message_ids: set[str]) -> None:
270
- """Check and delete the message if it's invalid."""
271
- if not message_ids:
272
- return
273
-
274
- with self.conn:
275
- invalid_msg_ids: set[str] = set()
276
- current_time = now().timestamp()
277
-
278
- for msg_id in message_ids:
279
- # Check if message exists
280
- query = "SELECT * FROM message_ins WHERE message_id = ?;"
281
- message_row = self.conn.execute(query, (msg_id,)).fetchone()
282
- if not message_row:
283
- continue
284
-
285
- # Check if the message has expired
286
- available_until = message_row["created_at"] + message_row["ttl"]
287
- if available_until <= current_time:
288
- invalid_msg_ids.add(msg_id)
289
- continue
290
-
291
- # Check if src_node_id and dst_node_id are in the federation
292
- # Get federation from run table
293
- run_id = message_row["run_id"]
294
- query = "SELECT federation FROM run WHERE run_id = ?;"
295
- run_row = self.conn.execute(query, (run_id,)).fetchone()
296
- if not run_row: # This should not happen
297
- invalid_msg_ids.add(msg_id)
298
- continue
299
- federation = run_row["federation"]
300
-
301
- # Convert sint64 to uint64 for node IDs
302
- src_node_id = int64_to_uint64(message_row["src_node_id"])
303
- dst_node_id = int64_to_uint64(message_row["dst_node_id"])
304
-
305
- # Filter nodes to check if they're in the federation
306
- filtered = self.federation_manager.filter_nodes(
307
- {src_node_id, dst_node_id}, federation
308
- )
309
- if len(filtered) != 2: # Not both nodes are in the federation
310
- invalid_msg_ids.add(msg_id)
311
-
312
- # Delete all invalid messages
313
- self.delete_messages(invalid_msg_ids)
314
-
315
- def get_message_ins(self, node_id: int, limit: int | None) -> list[Message]:
316
- """Get all Messages that have not been delivered yet."""
317
- if limit is not None and limit < 1:
318
- raise AssertionError("`limit` must be >= 1")
319
-
320
- if node_id == SUPERLINK_NODE_ID:
321
- msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
322
- raise AssertionError(msg)
323
-
324
- data: dict[str, str | int] = {}
325
-
326
- # Convert the uint64 value to sint64 for SQLite
327
- data["node_id"] = uint64_to_int64(node_id)
328
-
329
- with self.conn:
330
- # Retrieve all Messages for node_id
331
- query = """
332
- SELECT message_id
333
- FROM message_ins
334
- WHERE dst_node_id == :node_id
335
- AND delivered_at = ""
336
- AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
337
- """
338
-
339
- if limit is not None:
340
- query += " LIMIT :limit"
341
- data["limit"] = limit
342
-
343
- query += ";"
344
-
345
- rows = self.conn.execute(query, data).fetchall()
346
- message_ids: set[str] = {row["message_id"] for row in rows}
347
- self._check_stored_messages(message_ids)
348
-
349
- # Mark retrieved Messages as delivered
350
- if rows:
351
- # Prepare query
352
- placeholders: str = ",".join(
353
- [f":id_{i}" for i in range(len(message_ids))]
354
- )
355
- query = f"""
356
- UPDATE message_ins
357
- SET delivered_at = :delivered_at
358
- WHERE message_id IN ({placeholders})
359
- RETURNING *;
360
- """
361
-
362
- # Prepare data for query
363
- delivered_at = now().isoformat()
364
- data = {"delivered_at": delivered_at}
365
- for index, msg_id in enumerate(message_ids):
366
- data[f"id_{index}"] = str(msg_id)
367
-
368
- # Run query
369
- rows = self.conn.execute(query, data).fetchall()
370
-
371
- for row in rows:
372
- # Convert values from sint64 to uint64
373
- convert_sint64_values_in_dict_to_uint64(
374
- row, ["run_id", "src_node_id", "dst_node_id"]
375
- )
376
-
377
- result = [dict_to_message(row) for row in rows]
378
-
379
- return result
380
-
381
- def store_message_res(self, message: Message) -> str | None:
382
- """Store one Message."""
383
- # Validate message
384
- errors = validate_message(message=message, is_reply_message=True)
385
- if any(errors):
386
- log(ERROR, errors)
387
- return None
388
-
389
- res_metadata = message.metadata
390
- msg_ins_id = res_metadata.reply_to_message_id
391
- msg_ins = self.get_valid_message_ins(msg_ins_id)
392
- if msg_ins is None:
393
- log(
394
- ERROR,
395
- "Failed to store Message reply: "
396
- "The message it replies to with message_id %s does not exist or "
397
- "has expired, or was deleted because the target SuperNode was "
398
- "removed from the federation.",
399
- msg_ins_id,
400
- )
401
- return None
402
-
403
- # Ensure that the dst_node_id of the original message matches the src_node_id of
404
- # reply being processed.
405
- if (
406
- msg_ins
407
- and message
408
- and int64_to_uint64(msg_ins["dst_node_id"]) != res_metadata.src_node_id
409
- ):
410
- return None
411
-
412
- # Fail if the Message TTL exceeds the
413
- # expiration time of the Message it replies to.
414
- # Condition: ins_metadata.created_at + ins_metadata.ttl ≥
415
- # res_metadata.created_at + res_metadata.ttl
416
- # A small tolerance is introduced to account
417
- # for floating-point precision issues.
418
- max_allowed_ttl = (
419
- msg_ins["created_at"] + msg_ins["ttl"] - res_metadata.created_at
420
- )
421
- if res_metadata.ttl and (
422
- res_metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
423
- ):
424
- log(
425
- WARNING,
426
- "Received Message with TTL %.2f exceeding the allowed maximum "
427
- "TTL %.2f.",
428
- res_metadata.ttl,
429
- max_allowed_ttl,
430
- )
431
- return None
432
-
433
- # Store Message
434
- data = (message_to_dict(message),)
435
-
436
- # Convert values from uint64 to sint64 for SQLite
437
- convert_uint64_values_in_dict_to_sint64(
438
- data[0], ["run_id", "src_node_id", "dst_node_id"]
439
- )
440
-
441
- columns = ", ".join([f":{key}" for key in data[0]])
442
- query = f"INSERT INTO message_res VALUES({columns});"
443
-
444
- # Only invalid run_id can trigger IntegrityError.
445
- # This may need to be changed in the future version with more integrity checks.
446
- try:
447
- self.query(query, data)
448
- except sqlite3.IntegrityError:
449
- log(ERROR, "`run` is invalid")
450
- return None
451
-
452
- return message.metadata.message_id
453
-
454
- def get_message_res(self, message_ids: set[str]) -> list[Message]:
455
- """Get reply Messages for the given Message IDs."""
456
- # pylint: disable-msg=too-many-locals
457
- ret: dict[str, Message] = {}
458
-
459
- with self.conn:
460
- # Verify Message IDs
461
- self._check_stored_messages(message_ids)
462
- current = now().timestamp()
463
- query = f"""
464
- SELECT *
465
- FROM message_ins
466
- WHERE message_id IN ({','.join(['?'] * len(message_ids))});
467
- """
468
- rows = self.conn.execute(
469
- query, tuple(str(message_id) for message_id in message_ids)
470
- ).fetchall()
471
- found_message_ins_dict: dict[str, Message] = {}
472
- for row in rows:
473
- convert_sint64_values_in_dict_to_uint64(
474
- row, ["run_id", "src_node_id", "dst_node_id"]
475
- )
476
- found_message_ins_dict[row["message_id"]] = dict_to_message(row)
477
-
478
- ret = verify_message_ids(
479
- inquired_message_ids=message_ids,
480
- found_message_ins_dict=found_message_ins_dict,
481
- current_time=current,
482
- )
483
-
484
- # Check node availability
485
- dst_node_ids: set[int] = set()
486
- for message_id in message_ids:
487
- in_message = found_message_ins_dict[message_id]
488
- sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
489
- dst_node_ids.add(sint_node_id)
490
- query = f"""
491
- SELECT node_id, online_until
492
- FROM node
493
- WHERE node_id IN ({','.join(['?'] * len(dst_node_ids))})
494
- AND status != ?
495
- """
496
- rows = self.conn.execute(
497
- query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,)
498
- ).fetchall()
499
- tmp_ret_dict = check_node_availability_for_in_message(
500
- inquired_in_message_ids=message_ids,
501
- found_in_message_dict=found_message_ins_dict,
502
- node_id_to_online_until={
503
- int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
504
- },
505
- current_time=current,
506
- )
507
- ret.update(tmp_ret_dict)
508
-
509
- # Find all reply Messages
510
- query = f"""
511
- SELECT *
512
- FROM message_res
513
- WHERE reply_to_message_id IN ({','.join(['?'] * len(message_ids))})
514
- AND delivered_at = "";
515
- """
516
- rows = self.conn.execute(
517
- query, tuple(str(message_id) for message_id in message_ids)
518
- ).fetchall()
519
- for row in rows:
520
- convert_sint64_values_in_dict_to_uint64(
521
- row, ["run_id", "src_node_id", "dst_node_id"]
522
- )
523
- tmp_ret_dict = verify_found_message_replies(
524
- inquired_message_ids=message_ids,
525
- found_message_ins_dict=found_message_ins_dict,
526
- found_message_res_list=[dict_to_message(row) for row in rows],
527
- current_time=current,
528
- )
529
- ret.update(tmp_ret_dict)
530
-
531
- # Mark existing reply Messages to be returned as delivered
532
- delivered_at = now().isoformat()
533
- for message_res in ret.values():
534
- message_res.metadata.delivered_at = delivered_at
535
- message_res_ids = [
536
- message_res.metadata.message_id for message_res in ret.values()
537
- ]
538
- query = f"""
539
- UPDATE message_res
540
- SET delivered_at = ?
541
- WHERE message_id IN ({','.join(['?'] * len(message_res_ids))});
542
- """
543
- data: list[Any] = [delivered_at] + message_res_ids
544
- self.conn.execute(query, data)
545
-
546
- return list(ret.values())
547
-
548
- def num_message_ins(self) -> int:
549
- """Calculate the number of instruction Messages in store.
550
-
551
- This includes delivered but not yet deleted.
552
- """
553
- query = "SELECT count(*) AS num FROM message_ins;"
554
- rows = self.query(query)
555
- result = rows[0]
556
- num = cast(int, result["num"])
557
- return num
558
-
559
- def num_message_res(self) -> int:
560
- """Calculate the number of reply Messages in store.
561
-
562
- This includes delivered but not yet deleted.
563
- """
564
- query = "SELECT count(*) AS num FROM message_res;"
565
- rows = self.query(query)
566
- result: dict[str, int] = rows[0]
567
- return result["num"]
568
-
569
- def delete_messages(self, message_ins_ids: set[str]) -> None:
570
- """Delete a Message and its reply based on provided Message IDs."""
571
- if not message_ins_ids:
572
- return
573
- if self.conn is None:
574
- raise AttributeError("LinkState not initialized")
575
-
576
- placeholders = ",".join(["?"] * len(message_ins_ids))
577
- data = tuple(str(message_id) for message_id in message_ins_ids)
578
-
579
- # Delete Message
580
- query_1 = f"""
581
- DELETE FROM message_ins
582
- WHERE message_id IN ({placeholders});
583
- """
584
-
585
- # Delete reply Message
586
- query_2 = f"""
587
- DELETE FROM message_res
588
- WHERE reply_to_message_id IN ({placeholders});
589
- """
590
-
591
- with self.conn:
592
- self.conn.execute(query_1, data)
593
- self.conn.execute(query_2, data)
594
-
595
- def get_message_ids_from_run_id(self, run_id: int) -> set[str]:
596
- """Get all instruction Message IDs for the given run_id."""
597
- if self.conn is None:
598
- raise AttributeError("LinkState not initialized")
599
-
600
- query = """
601
- SELECT message_id
602
- FROM message_ins
603
- WHERE run_id = :run_id;
604
- """
605
-
606
- sint64_run_id = uint64_to_int64(run_id)
607
- data = {"run_id": sint64_run_id}
608
-
609
- with self.conn:
610
- rows = self.conn.execute(query, data).fetchall()
611
-
612
- return {row["message_id"] for row in rows}
613
-
614
- def create_node(
615
- self,
616
- owner_aid: str,
617
- owner_name: str,
618
- public_key: bytes,
619
- heartbeat_interval: float,
620
- ) -> int:
621
- """Create, store in the link state, and return `node_id`."""
622
- # Sample a random uint64 as node_id
623
- uint64_node_id = generate_rand_int_from_bytes(
624
- NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
625
- )
626
-
627
- # Convert the uint64 value to sint64 for SQLite
628
- sint64_node_id = uint64_to_int64(uint64_node_id)
629
-
630
- query = """
631
- INSERT INTO node
632
- (node_id, owner_aid, owner_name, status, registered_at, last_activated_at,
633
- last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
634
- public_key)
635
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
636
- """
637
-
638
- # Mark the node online until now().timestamp() + heartbeat_interval
639
- try:
640
- self.query(
641
- query,
642
- (
643
- sint64_node_id, # node_id
644
- owner_aid, # owner_aid
645
- owner_name, # owner_name
646
- NodeStatus.REGISTERED, # status
647
- now().isoformat(), # registered_at
648
- None, # last_activated_at
649
- None, # last_deactivated_at
650
- None, # unregistered_at
651
- None, # online_until, initialized with offline status
652
- heartbeat_interval, # heartbeat_interval
653
- public_key, # public_key
654
- ),
655
- )
656
- except sqlite3.IntegrityError as e:
657
- if "UNIQUE constraint failed: node.public_key" in str(e):
658
- raise ValueError("Public key already in use.") from None
659
- # Must be node ID conflict, almost impossible unless system is compromised
660
- log(ERROR, "Unexpected node registration failure.")
661
- return 0
662
-
663
- # Note: we need to return the uint64 value of the node_id
664
- return uint64_node_id
665
-
666
- def delete_node(self, owner_aid: str, node_id: int) -> None:
667
- """Delete a node."""
668
- sint64_node_id = uint64_to_int64(node_id)
669
-
670
- query = """
671
- UPDATE node
672
- SET status = ?, unregistered_at = ?,
673
- online_until = IIF(online_until > ?, ?, online_until)
674
- WHERE node_id = ? AND status != ? AND owner_aid = ?
675
- RETURNING node_id
676
- """
677
- current = now()
678
- params = (
679
- NodeStatus.UNREGISTERED,
680
- current.isoformat(),
681
- current.timestamp(),
682
- current.timestamp(),
683
- sint64_node_id,
684
- NodeStatus.UNREGISTERED,
685
- owner_aid,
686
- )
687
-
688
- rows = self.query(query, params)
689
- if not rows:
690
- raise ValueError(
691
- f"Node {node_id} already deleted, not found or unauthorized "
692
- "deletion attempt."
693
- )
694
-
695
- def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
696
- """Activate the node with the specified `node_id`."""
697
- with self.conn:
698
- self._check_and_tag_offline_nodes([node_id])
699
-
700
- # Only activate if the node is currently registered or offline
701
- current_dt = now()
702
- query = """
703
- UPDATE node
704
- SET status = ?,
705
- last_activated_at = ?,
706
- online_until = ?,
707
- heartbeat_interval = ?
708
- WHERE node_id = ? AND status in (?, ?)
709
- RETURNING node_id
710
- """
711
- params = (
712
- NodeStatus.ONLINE,
713
- current_dt.isoformat(),
714
- current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
715
- heartbeat_interval,
716
- uint64_to_int64(node_id),
717
- NodeStatus.REGISTERED,
718
- NodeStatus.OFFLINE,
719
- )
720
-
721
- row = self.conn.execute(query, params).fetchone()
722
- return row is not None
723
-
724
- def deactivate_node(self, node_id: int) -> bool:
725
- """Deactivate the node with the specified `node_id`."""
726
- with self.conn:
727
- self._check_and_tag_offline_nodes([node_id])
728
-
729
- # Only deactivate if the node is currently online
730
- current_dt = now()
731
- query = """
732
- UPDATE node
733
- SET status = ?,
734
- last_deactivated_at = ?,
735
- online_until = ?
736
- WHERE node_id = ? AND status = ?
737
- RETURNING node_id
738
- """
739
- params = (
740
- NodeStatus.OFFLINE,
741
- current_dt.isoformat(),
742
- current_dt.timestamp(),
743
- uint64_to_int64(node_id),
744
- NodeStatus.ONLINE,
745
- )
746
-
747
- row = self.conn.execute(query, params).fetchone()
748
- return row is not None
749
-
750
- def get_nodes(self, run_id: int) -> set[int]:
751
- """Retrieve all currently stored node IDs as a set.
752
-
753
- Constraints
754
- -----------
755
- If the provided `run_id` does not exist or has no matching nodes,
756
- an empty `Set` MUST be returned.
757
- """
758
- if self.conn is None:
759
- raise AttributeError("LinkState not initialized")
760
-
761
- with self.conn:
762
- # Convert the uint64 value to sint64 for SQLite
763
- sint64_run_id = uint64_to_int64(run_id)
764
-
765
- # Validate run ID
766
- query = "SELECT federation FROM run WHERE run_id = ?"
767
- rows = self.conn.execute(query, (sint64_run_id,)).fetchall()
768
- if not rows:
769
- return set()
770
- federation: str = rows[0]["federation"]
771
-
772
- # Retrieve all online nodes
773
- node_ids = {
774
- node.node_id
775
- for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
776
- }
777
- # Filter node IDs by federation
778
- return self.federation_manager.filter_nodes(node_ids, federation)
779
-
780
- def _check_and_tag_offline_nodes(self, node_ids: list[int] | None = None) -> None:
781
- """Check and tag offline nodes."""
782
- # strftime will convert POSIX timestamp to ISO format
783
- query = """
784
- UPDATE node SET status = ?,
785
- last_deactivated_at =
786
- strftime("%Y-%m-%dT%H:%M:%f+00:00", online_until, "unixepoch")
787
- WHERE online_until <= ? AND status == ?
788
- """
789
- params = [
790
- NodeStatus.OFFLINE,
791
- now().timestamp(),
792
- NodeStatus.ONLINE,
793
- ]
794
- if node_ids is not None:
795
- placeholders = ",".join(["?"] * len(node_ids))
796
- query += f" AND node_id IN ({placeholders})"
797
- params.extend(uint64_to_int64(node_id) for node_id in node_ids)
798
- self.conn.execute(query, params)
799
-
800
- def get_node_info(
801
- self,
802
- *,
803
- node_ids: Sequence[int] | None = None,
804
- owner_aids: Sequence[str] | None = None,
805
- statuses: Sequence[str] | None = None,
806
- ) -> Sequence[NodeInfo]:
807
- """Retrieve information about nodes based on the specified filters."""
808
- with self.conn:
809
- self._check_and_tag_offline_nodes()
810
-
811
- # Build the WHERE clause based on provided filters
812
- conditions = []
813
- params: list[Any] = []
814
- if node_ids is not None:
815
- sint64_node_ids = [uint64_to_int64(node_id) for node_id in node_ids]
816
- placeholders = ",".join(["?"] * len(sint64_node_ids))
817
- conditions.append(f"node_id IN ({placeholders})")
818
- params.extend(sint64_node_ids)
819
- if owner_aids is not None:
820
- placeholders = ",".join(["?"] * len(owner_aids))
821
- conditions.append(f"owner_aid IN ({placeholders})")
822
- params.extend(owner_aids)
823
- if statuses is not None:
824
- placeholders = ",".join(["?"] * len(statuses))
825
- conditions.append(f"status IN ({placeholders})")
826
- params.extend(statuses)
827
-
828
- # Construct the final query
829
- query = "SELECT * FROM node"
830
- if conditions:
831
- query += " WHERE " + " AND ".join(conditions)
832
-
833
- rows = self.conn.execute(query, params).fetchall()
834
-
835
- result: list[NodeInfo] = []
836
- for row in rows:
837
- # Convert sint64 node_id to uint64
838
- row["node_id"] = int64_to_uint64(row["node_id"])
839
- result.append(NodeInfo(**row))
840
-
841
- return result
842
-
843
- def get_node_id_by_public_key(self, public_key: bytes) -> int | None:
844
- """Get `node_id` for the specified `public_key` if it exists and is not
845
- deleted."""
846
- query = "SELECT node_id FROM node WHERE public_key = ? AND status != ?;"
847
- rows = self.query(query, (public_key, NodeStatus.UNREGISTERED))
848
-
849
- # If no result is found, return None
850
- if not rows:
851
- return None
852
-
853
- # Convert sint64 node_id to uint64
854
- node_id = int64_to_uint64(rows[0]["node_id"])
855
- return node_id
856
-
857
- # pylint: disable=too-many-arguments,too-many-positional-arguments
858
- def create_run(
859
- self,
860
- fab_id: str | None,
861
- fab_version: str | None,
862
- fab_hash: str | None,
863
- override_config: UserConfig,
864
- federation: str,
865
- federation_options: ConfigRecord,
866
- flwr_aid: str | None,
867
- ) -> int:
868
- """Create a new run."""
869
- # Sample a random int64 as run_id
870
- uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
871
-
872
- # Convert the uint64 value to sint64 for SQLite
873
- sint64_run_id = uint64_to_int64(uint64_run_id)
874
-
875
- with self.conn:
876
- # Check conflicts
877
- query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
878
- # If sint64_run_id does not exist
879
- row = self.conn.execute(query, (sint64_run_id,)).fetchone()
880
- if row["COUNT(*)"] == 0:
881
- query = """
882
- INSERT INTO run
883
- (run_id, fab_id, fab_version,
884
- fab_hash, override_config, federation, federation_options,
885
- pending_at, starting_at, running_at, finished_at, sub_status,
886
- details, flwr_aid, bytes_sent, bytes_recv, clientapp_runtime)
887
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
888
- """
889
- override_config_json = json.dumps(override_config)
890
- data = [
891
- sint64_run_id, # run_id
892
- fab_id, # fab_id
893
- fab_version, # fab_version
894
- fab_hash, # fab_hash
895
- override_config_json, # override_config
896
- federation, # federation
897
- configrecord_to_bytes(federation_options), # federation_options
898
- now().isoformat(), # pending_at
899
- "", # starting_at
900
- "", # running_at
901
- "", # finished_at
902
- "", # sub_status
903
- "", # details
904
- flwr_aid or "", # flwr_aid
905
- 0, # bytes_sent
906
- 0, # bytes_recv
907
- 0, # clientapp_runtime
908
- ]
909
- self.conn.execute(query, tuple(data))
910
- return uint64_run_id
911
- log(ERROR, "Unexpected run creation failure.")
912
- return 0
913
-
914
def get_run_ids(self, flwr_aid: str | None) -> set[int]:
    """Return the IDs of stored runs, optionally restricted to one account.

    When `flwr_aid` is truthy only the runs whose `flwr_aid` column matches
    it are returned; otherwise every run ID in the `run` table is returned.
    IDs are stored as signed 64-bit integers and are converted back to their
    unsigned form before being returned.
    """
    if not flwr_aid:
        rows = self.query("SELECT run_id FROM run;", ())
    else:
        rows = self.query(
            "SELECT run_id FROM run WHERE flwr_aid = ?;",
            (flwr_aid,),
        )

    run_ids: set[int] = set()
    for row in rows:
        run_ids.add(int64_to_uint64(row["run_id"]))
    return run_ids
927
-
928
def get_run(self, run_id: int) -> Run | None:
    """Look up a single run by its ID.

    Expired tokens are purged first (which flags runs that have become
    inactive), so the status reported in the returned `Run` is current.

    Returns the matching `Run`, or None after logging an error when no run
    with `run_id` is stored.
    """
    self._cleanup_expired_tokens()

    # Run IDs are stored in SQLite as signed 64-bit integers
    matches = self.query(
        "SELECT * FROM run WHERE run_id = ?;", (uint64_to_int64(run_id),)
    )
    if not matches:
        log(ERROR, "`run_id` does not exist.")
        return None

    record = matches[0]
    return Run(
        run_id=int64_to_uint64(record["run_id"]),
        fab_id=record["fab_id"],
        fab_version=record["fab_version"],
        fab_hash=record["fab_hash"],
        override_config=json.loads(record["override_config"]),
        pending_at=record["pending_at"],
        starting_at=record["starting_at"],
        running_at=record["running_at"],
        finished_at=record["finished_at"],
        status=RunStatus(
            status=determine_run_status(record),
            sub_status=record["sub_status"],
            details=record["details"],
        ),
        flwr_aid=record["flwr_aid"],
        federation=record["federation"],
        bytes_sent=record["bytes_sent"],
        bytes_recv=record["bytes_recv"],
        clientapp_runtime=record["clientapp_runtime"],
    )
962
-
963
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
    """Retrieve the statuses for the specified runs.

    Expired tokens are cleaned up first; this flags inactive runs as needed
    so the returned statuses are current.

    Parameters
    ----------
    run_ids : set[int]
        Unsigned 64-bit IDs of the runs to look up. Any iterable of ints is
        accepted; duplicates are ignored.

    Returns
    -------
    dict[int, RunStatus]
        Mapping from (unsigned) run ID to its status. IDs that are not
        stored are omitted from the result.
    """
    # Clean up expired tokens; this will flag inactive runs as needed
    self._cleanup_expired_tokens()

    # De-duplicate once and derive the number of `?` placeholders from the
    # very same tuple that is bound, so the two counts can never disagree
    # (e.g. when a caller passes a list containing duplicates).
    sint64_run_ids = tuple(uint64_to_int64(run_id) for run_id in set(run_ids))
    if not sint64_run_ids:
        # Nothing to look up; also avoids relying on SQLite's `IN ()` extension
        return {}

    placeholders = ",".join("?" * len(sint64_run_ids))
    query = f"SELECT * FROM run WHERE run_id IN ({placeholders});"
    rows = self.query(query, sint64_run_ids)

    return {
        # Restore uint64 run IDs
        int64_to_uint64(row["run_id"]): RunStatus(
            status=determine_run_status(row),
            sub_status=row["sub_status"],
            details=row["details"],
        )
        for row in rows
    }
982
-
983
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
    """Update the status of the run with the specified `run_id`.

    Expired tokens are cleaned up first so that runs which have become
    inactive are already flagged before the transition is checked.

    Parameters
    ----------
    run_id : int
        The unsigned 64-bit ID of the run to update.
    new_status : RunStatus
        The status to transition to. Its `sub_status` and `details` are
        written to the run, and the timestamp column that corresponds to
        `new_status.status` is set to the current time.

    Returns
    -------
    bool
        True if the run was updated. False if the run does not exist, the
        transition from the current status is not allowed, `new_status`
        carries a sub-status that is invalid for its status, or
        `new_status.status` has no associated timestamp column.
    """
    # Clean up expired tokens; this will flag inactive runs as needed
    self._cleanup_expired_tokens()

    with self.conn:
        # Convert the uint64 value to sint64 for SQLite
        sint64_run_id = uint64_to_int64(run_id)
        query = "SELECT * FROM run WHERE run_id = ?;"
        rows = self.conn.execute(query, (sint64_run_id,)).fetchall()

        # Check if the run_id exists
        if not rows:
            log(ERROR, "`run_id` is invalid")
            return False

        # Check if the status transition is valid
        row = rows[0]
        current_status = RunStatus(
            status=determine_run_status(row),
            sub_status=row["sub_status"],
            details=row["details"],
        )
        if not is_valid_transition(current_status, new_status):
            log(
                ERROR,
                'Invalid status transition: from "%s" to "%s"',
                current_status.status,
                new_status.status,
            )
            return False

        # Validate the sub-status that is about to be written. It is
        # `new_status.sub_status` that ends up in the `sub_status` column;
        # checking the stored status instead would let e.g. a FINISHED
        # status with an empty sub-status through.
        if not has_valid_sub_status(new_status):
            log(
                ERROR,
                'Invalid sub-status "%s" for status "%s"',
                new_status.sub_status,
                new_status.status,
            )
            return False

        # Determine the timestamp field based on the new status
        timestamp_fld = ""
        if new_status.status == Status.STARTING:
            timestamp_fld = "starting_at"
        elif new_status.status == Status.RUNNING:
            timestamp_fld = "running_at"
        elif new_status.status == Status.FINISHED:
            timestamp_fld = "finished_at"

        if not timestamp_fld:
            # Without a matching column the UPDATE below would be malformed
            log(ERROR, 'No timestamp field for status "%s"', new_status.status)
            return False

        # `timestamp_fld` is one of the fixed column names above, so
        # interpolating it into the SQL text is safe.
        query = (
            f"UPDATE run SET {timestamp_fld} = ?, sub_status = ?, details = ? "
            "WHERE run_id = ?;"
        )
        data = (
            now().isoformat(),
            new_status.sub_status,
            new_status.details,
            sint64_run_id,
        )
        self.conn.execute(query, data)
        return True
1050
-
1051
def get_pending_run_id(self) -> int | None:
    """Return the ID of one run that is still in `Status.PENDING`, if any.

    A run counts as pending while its `starting_at` column is still empty.
    At most one such run is fetched. Returns None when no run is pending.
    """
    rows = self.query("SELECT * FROM run WHERE starting_at = '' LIMIT 1;")
    if not rows:
        return None
    # Stored IDs are signed 64-bit; hand back the unsigned form
    return int64_to_uint64(rows[0]["run_id"])
1062
-
1063
def get_federation_options(self, run_id: int) -> ConfigRecord | None:
    """Return the federation options stored for `run_id`.

    The options are kept serialized in the `federation_options` column and
    are decoded with `configrecord_from_bytes`. Returns None after logging
    an error when no run with `run_id` exists.
    """
    rows = self.query(
        "SELECT federation_options FROM run WHERE run_id = ?;",
        (uint64_to_int64(run_id),),
    )
    if rows:
        return configrecord_from_bytes(rows[0]["federation_options"])

    log(ERROR, "`run_id` is invalid")
    return None
1077
-
1078
def acknowledge_node_heartbeat(
    self, node_id: int, heartbeat_interval: float
) -> bool:
    """Record a heartbeat from `node_id` and extend its online window.

    The node's `online_until` is pushed to
    ``now + HEARTBEAT_PATIENCE * heartbeat_interval``, so with
    HEARTBEAT_PATIENCE = N the node may miss N-1 heartbeats before it is
    considered offline. If the node was not already ONLINE, its status is
    set to ONLINE and `last_activated_at` is stamped with the current time.

    Returns True when the node exists and is not unregistered (and was
    therefore updated), False otherwise.

    Raises AttributeError if the database connection is not initialized.
    """
    if self.conn is None:
        raise AttributeError("LinkState not initialized")

    sint64_node_id = uint64_to_int64(node_id)

    with self.conn:
        # Only registered (i.e. not UNREGISTERED) nodes can acknowledge
        node_row = self.conn.execute(
            "SELECT status FROM node WHERE node_id = ? AND status != ?",
            (sint64_node_id, NodeStatus.UNREGISTERED),
        ).fetchone()
        if node_row is None:
            return False

        current_dt = now()
        assignments = ["online_until = ?", "heartbeat_interval = ?"]
        params: list[Any] = [
            current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
            heartbeat_interval,
        ]

        # Flip to ONLINE (and stamp the activation time) only on a change
        if node_row["status"] != NodeStatus.ONLINE:
            assignments.extend(["status = ?", "last_activated_at = ?"])
            params.extend([NodeStatus.ONLINE, current_dt.isoformat()])

        params.append(sint64_node_id)
        update_sql = (
            "UPDATE node SET " + ", ".join(assignments) + " WHERE node_id = ?"
        )
        self.conn.execute(update_sql, params)
        return True
1120
-
1121
- def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
1122
- """Transition runs with expired tokens to failed status.
1123
-
1124
- Parameters
1125
- ----------
1126
- expired_records : list[tuple[int, float]]
1127
- List of tuples containing (run_id, active_until timestamp)
1128
- for expired tokens.
1129
- """
1130
- if not expired_records:
1131
- return
1132
-
1133
- with self.conn:
1134
- query = """
1135
- UPDATE run
1136
- SET sub_status = ?, details = ?, finished_at = ?
1137
- WHERE run_id = ?;
1138
- """
1139
- data = [
1140
- (
1141
- SubStatus.FAILED,
1142
- RUN_FAILURE_DETAILS_NO_HEARTBEAT,
1143
- datetime.fromtimestamp(active_until, tz=timezone.utc).isoformat(),
1144
- uint64_to_int64(run_id),
1145
- )
1146
- for run_id, active_until in expired_records
1147
- ]
1148
- self.conn.executemany(query, data)
1149
-
1150
def get_serverapp_context(self, run_id: int) -> Context | None:
    """Return the ServerApp context stored for `run_id`, or None if unset."""
    rows = self.query(
        "SELECT context FROM context WHERE run_id = ?;",
        (uint64_to_int64(run_id),),
    )
    if not rows:
        return None
    # Contexts are persisted in serialized form
    return context_from_bytes(rows[0]["context"])
1157
-
1158
def set_serverapp_context(self, run_id: int, context: Context) -> None:
    """Store `context` as the ServerApp context of `run_id`.

    If a context row already exists for the run it is overwritten;
    otherwise a new row is inserted.

    Raises ValueError when the insert fails with `sqlite3.IntegrityError`,
    which is interpreted as "no such run" (presumably a foreign-key
    constraint on `run_id` — confirm against the schema).
    """
    context_bytes = context_to_bytes(context)
    sint_run_id = uint64_to_int64(run_id)

    with self.conn:
        count_row = self.conn.execute(
            "SELECT COUNT(*) FROM context WHERE run_id = ?;", (sint_run_id,)
        ).fetchone()

        if count_row["COUNT(*)"] > 0:
            # A context already exists for this run: overwrite it
            self.conn.execute(
                "UPDATE context SET context = ? WHERE run_id = ?;",
                (context_bytes, sint_run_id),
            )
            return

        try:
            self.conn.execute(
                "INSERT INTO context (run_id, context) VALUES (?, ?);",
                (sint_run_id, context_bytes),
            )
        except sqlite3.IntegrityError:
            raise ValueError(f"Run {run_id} not found") from None
1179
-
1180
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
    """Append `log_message` to the ServerApp logs of `run_id`.

    The entry is written to the `logs` table with the current timestamp
    and `node_id` 0 (the value used here for ServerApp entries).

    Raises ValueError when the insert fails with `sqlite3.IntegrityError`,
    which is interpreted as "no such run".
    """
    sint64_run_id = uint64_to_int64(run_id)
    insert_sql = """
    INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?);
    """

    try:
        self.query(
            insert_sql, (now().timestamp(), sint64_run_id, 0, log_message)
        )
    except sqlite3.IntegrityError:
        raise ValueError(f"Run {run_id} not found") from None
1193
-
1194
def get_serverapp_log(
    self, run_id: int, after_timestamp: float | None
) -> tuple[str, float]:
    """Return the ServerApp log text of `run_id` newer than `after_timestamp`.

    Only entries with `node_id` 0 and a timestamp strictly greater than
    `after_timestamp` (0.0 when None) are included.

    Returns
    -------
    tuple[str, float]
        The matching log messages concatenated in timestamp order, and the
        timestamp of the newest returned entry (0.0 if none matched).

    Raises
    ------
    ValueError
        If no run with `run_id` exists.
    """
    sint64_run_id = uint64_to_int64(run_id)
    since = 0.0 if after_timestamp is None else after_timestamp

    with self.conn:
        run_rows = self.conn.execute(
            "SELECT run_id FROM run WHERE run_id = ?;", (sint64_run_id,)
        ).fetchall()
        if not run_rows:
            raise ValueError(f"Run {run_id} not found")

        log_query = """
        SELECT log, timestamp FROM logs
        WHERE run_id = ? AND node_id = ? AND timestamp > ?;
        """
        entries = self.conn.execute(
            log_query, (sint64_run_id, 0, since)
        ).fetchall()

    ordered = sorted(entries, key=lambda entry: entry["timestamp"])
    if not ordered:
        return "", 0.0
    return "".join(entry["log"] for entry in ordered), ordered[-1]["timestamp"]
1221
-
1222
- def get_valid_message_ins(self, message_id: str) -> dict[str, Any] | None:
1223
- """Check if the Message exists and is valid (not expired).
1224
-
1225
- Return Message if valid.
1226
- """
1227
- with self.conn:
1228
- self._check_stored_messages({message_id})
1229
- query = """
1230
- SELECT *
1231
- FROM message_ins
1232
- WHERE message_id = :message_id
1233
- """
1234
- data = {"message_id": message_id}
1235
- rows: list[dict[str, Any]] = self.conn.execute(query, data).fetchall()
1236
- if not rows:
1237
- # Message does not exist
1238
- return None
1239
-
1240
- return rows[0]
1241
-
1242
def store_traffic(self, run_id: int, *, bytes_sent: int, bytes_recv: int) -> None:
    """Add `bytes_sent` / `bytes_recv` to the traffic counters of `run_id`.

    Raises
    ------
    ValueError
        If either value is negative, if both are zero, or if no run with
        `run_id` exists.
    """
    negative = bytes_sent < 0 or bytes_recv < 0
    if negative:
        raise ValueError(
            f"Negative traffic values for run {run_id}: "
            f"bytes_sent={bytes_sent}, bytes_recv={bytes_recv}"
        )
    if not (bytes_sent or bytes_recv):
        raise ValueError(
            f"Both bytes_sent and bytes_recv cannot be zero for run {run_id}"
        )

    sint64_run_id = uint64_to_int64(run_id)
    update_query = """
    UPDATE run
    SET bytes_sent = bytes_sent + ?,
        bytes_recv = bytes_recv + ?
    WHERE run_id = ?
    RETURNING run_id;
    """

    with self.conn:
        # RETURNING yields a row only if the UPDATE matched an existing run
        updated = self.conn.execute(
            update_query, (bytes_sent, bytes_recv, sint64_run_id)
        ).fetchall()
        if not updated:
            raise ValueError(f"Run {run_id} not found")
1273
-
1274
def add_clientapp_runtime(self, run_id: int, runtime: float) -> None:
    """Add `runtime` to the cumulative ClientApp runtime of `run_id`.

    Raises ValueError if no run with `run_id` exists.
    """
    sint64_run_id = uint64_to_int64(run_id)
    update_query = """
    UPDATE run
    SET clientapp_runtime = clientapp_runtime + ?
    WHERE run_id = ?
    RETURNING run_id;
    """

    with self.conn:
        # RETURNING yields a row only if the UPDATE matched an existing run
        updated = self.conn.execute(
            update_query, (runtime, sint64_run_id)
        ).fetchall()
        if not updated:
            raise ValueError(f"Run {run_id} not found")
1289
-
1290
-
1291
def determine_run_status(row: dict[str, Any]) -> str:
    """Derive a run's status from which of its timestamp columns are set.

    - `finished_at` set                      -> Status.FINISHED
    - `starting_at` and `running_at` set     -> Status.RUNNING
    - only `starting_at` set                 -> Status.STARTING
    - neither `starting_at` nor `finished_at` -> Status.PENDING

    Raises
    ------
    sqlite3.IntegrityError
        If `pending_at` is empty, since every stored run must have one.
    """
    if not row["pending_at"]:
        run_id = int64_to_uint64(row["run_id"])
        raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")

    if row["finished_at"]:
        return Status.FINISHED
    if not row["starting_at"]:
        return Status.PENDING
    if row["running_at"]:
        return Status.RUNNING
    return Status.STARTING