flwr 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +17 -1
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +95 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
- flwr/cli/build.py +118 -47
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +34 -23
- flwr/cli/ls.py +13 -9
- flwr/cli/new/new.py +196 -42
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/run/run.py +11 -7
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +260 -0
- flwr/cli/supernode/register.py +185 -0
- flwr/cli/supernode/unregister.py +138 -0
- flwr/cli/utils.py +109 -69
- flwr/client/__init__.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +6 -8
- flwr/client/grpc_rere_client/connection.py +59 -31
- flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
- flwr/client/rest_client/connection.py +82 -37
- flwr/clientapp/__init__.py +1 -2
- flwr/clientapp/mod/__init__.py +4 -1
- flwr/clientapp/mod/centraldp_mods.py +156 -40
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/{client/clientapp → clientapp}/utils.py +1 -1
- flwr/common/constant.py +56 -13
- flwr/common/exit/exit_code.py +24 -10
- flwr/common/inflatable_utils.py +10 -10
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +10 -1
- flwr/common/record/typeddict.py +12 -0
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/serde.py +4 -2
- flwr/common/typing.py +7 -6
- flwr/compat/client/app.py +1 -1
- flwr/compat/client/grpc_client/connection.py +2 -2
- flwr/proto/control_pb2.py +48 -31
- flwr/proto/control_pb2.pyi +95 -5
- flwr/proto/control_pb2_grpc.py +136 -0
- flwr/proto/control_pb2_grpc.pyi +52 -0
- flwr/proto/fab_pb2.py +11 -7
- flwr/proto/fab_pb2.pyi +21 -1
- flwr/proto/fleet_pb2.py +31 -23
- flwr/proto/fleet_pb2.pyi +63 -23
- flwr/proto/fleet_pb2_grpc.py +98 -28
- flwr/proto/fleet_pb2_grpc.pyi +45 -13
- flwr/proto/node_pb2.py +3 -1
- flwr/proto/node_pb2.pyi +48 -0
- flwr/server/app.py +152 -114
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
- flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +18 -5
- flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
- flwr/server/superlink/linkstate/linkstate.py +107 -24
- flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
- flwr/server/superlink/linkstate/utils.py +3 -54
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
- flwr/serverapp/strategy/__init__.py +26 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
- flwr/serverapp/strategy/fedadagrad.py +0 -3
- flwr/serverapp/strategy/fedadam.py +0 -3
- flwr/serverapp/strategy/fedavg.py +89 -64
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +0 -3
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/strategy_utils.py +48 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +28 -32
- flwr/supercore/cli/flower_superexec.py +26 -1
- flwr/supercore/constant.py +41 -0
- flwr/supercore/object_store/in_memory_object_store.py +0 -4
- flwr/supercore/object_store/object_store_factory.py +26 -6
- flwr/supercore/object_store/sqlite_object_store.py +252 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
- flwr/supercore/sqlite_mixin.py +156 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
- flwr/supercore/superexec/run_superexec.py +16 -2
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +91 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
- flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
- flwr/superlink/servicer/control/control_grpc.py +16 -11
- flwr/superlink/servicer/control/control_servicer.py +207 -58
- flwr/supernode/cli/flower_supernode.py +19 -26
- flwr/supernode/runtime/run_clientapp.py +2 -2
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +17 -9
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/serverapp/dp_fixed_clipping.py +0 -352
- flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
- /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
- /flwr/{client → clientapp}/client_app.py +0 -0
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
|
@@ -18,18 +18,16 @@
|
|
|
18
18
|
# pylint: disable=too-many-lines
|
|
19
19
|
|
|
20
20
|
import json
|
|
21
|
-
import re
|
|
22
21
|
import secrets
|
|
23
22
|
import sqlite3
|
|
24
|
-
import time
|
|
25
23
|
from collections.abc import Sequence
|
|
26
|
-
from logging import
|
|
24
|
+
from logging import ERROR, WARNING
|
|
27
25
|
from typing import Any, Optional, Union, cast
|
|
28
26
|
|
|
29
27
|
from flwr.common import Context, Message, Metadata, log, now
|
|
30
28
|
from flwr.common.constant import (
|
|
31
29
|
FLWR_APP_TOKEN_LENGTH,
|
|
32
|
-
|
|
30
|
+
HEARTBEAT_INTERVAL_INF,
|
|
33
31
|
HEARTBEAT_PATIENCE,
|
|
34
32
|
MESSAGE_TTL_TOLERANCE,
|
|
35
33
|
NODE_ID_NUM_BYTES,
|
|
@@ -47,10 +45,14 @@ from flwr.common.typing import Run, RunStatus, UserConfig
|
|
|
47
45
|
|
|
48
46
|
# pylint: disable=E0611
|
|
49
47
|
from flwr.proto.error_pb2 import Error as ProtoError
|
|
48
|
+
from flwr.proto.node_pb2 import NodeInfo
|
|
50
49
|
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
|
51
50
|
|
|
52
51
|
# pylint: enable=E0611
|
|
53
52
|
from flwr.server.utils.validator import validate_message
|
|
53
|
+
from flwr.supercore.constant import NodeStatus
|
|
54
|
+
from flwr.supercore.sqlite_mixin import SqliteMixin
|
|
55
|
+
from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
|
|
54
56
|
|
|
55
57
|
from .linkstate import LinkState
|
|
56
58
|
from .utils import (
|
|
@@ -59,9 +61,7 @@ from .utils import (
|
|
|
59
61
|
configrecord_to_bytes,
|
|
60
62
|
context_from_bytes,
|
|
61
63
|
context_to_bytes,
|
|
62
|
-
convert_sint64_to_uint64,
|
|
63
64
|
convert_sint64_values_in_dict_to_uint64,
|
|
64
|
-
convert_uint64_to_sint64,
|
|
65
65
|
convert_uint64_values_in_dict_to_sint64,
|
|
66
66
|
generate_rand_int_from_bytes,
|
|
67
67
|
has_valid_sub_status,
|
|
@@ -72,10 +72,16 @@ from .utils import (
|
|
|
72
72
|
|
|
73
73
|
SQL_CREATE_TABLE_NODE = """
|
|
74
74
|
CREATE TABLE IF NOT EXISTS node(
|
|
75
|
-
node_id
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
75
|
+
node_id INTEGER UNIQUE,
|
|
76
|
+
owner_aid TEXT,
|
|
77
|
+
status TEXT,
|
|
78
|
+
registered_at TEXT,
|
|
79
|
+
last_activated_at TEXT NULL,
|
|
80
|
+
last_deactivated_at TEXT NULL,
|
|
81
|
+
unregistered_at TEXT NULL,
|
|
82
|
+
online_until TIMESTAMP NULL,
|
|
83
|
+
heartbeat_interval REAL,
|
|
84
|
+
public_key BLOB UNIQUE
|
|
79
85
|
);
|
|
80
86
|
"""
|
|
81
87
|
|
|
@@ -89,6 +95,14 @@ SQL_CREATE_INDEX_ONLINE_UNTIL = """
|
|
|
89
95
|
CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
|
90
96
|
"""
|
|
91
97
|
|
|
98
|
+
SQL_CREATE_INDEX_OWNER_AID = """
|
|
99
|
+
CREATE INDEX IF NOT EXISTS idx_node_owner_aid ON node(owner_aid);
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
SQL_CREATE_INDEX_NODE_STATUS = """
|
|
103
|
+
CREATE INDEX IF NOT EXISTS idx_node_status ON node(status);
|
|
104
|
+
"""
|
|
105
|
+
|
|
92
106
|
SQL_CREATE_TABLE_RUN = """
|
|
93
107
|
CREATE TABLE IF NOT EXISTS run(
|
|
94
108
|
run_id INTEGER UNIQUE,
|
|
@@ -172,94 +186,26 @@ CREATE TABLE IF NOT EXISTS token_store (
|
|
|
172
186
|
);
|
|
173
187
|
"""
|
|
174
188
|
|
|
175
|
-
DictOrTuple = Union[tuple[Any, ...], dict[str, Any]]
|
|
176
189
|
|
|
177
|
-
|
|
178
|
-
class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
190
|
+
class SqliteLinkState(LinkState, SqliteMixin): # pylint: disable=R0904
|
|
179
191
|
"""SQLite-based LinkState implementation."""
|
|
180
192
|
|
|
181
|
-
def __init__(
|
|
182
|
-
self,
|
|
183
|
-
database_path: str,
|
|
184
|
-
) -> None:
|
|
185
|
-
"""Initialize an SqliteLinkState.
|
|
186
|
-
|
|
187
|
-
Parameters
|
|
188
|
-
----------
|
|
189
|
-
database : (path-like object)
|
|
190
|
-
The path to the database file to be opened. Pass ":memory:" to open
|
|
191
|
-
a connection to a database that is in RAM, instead of on disk.
|
|
192
|
-
"""
|
|
193
|
-
self.database_path = database_path
|
|
194
|
-
self.conn: Optional[sqlite3.Connection] = None
|
|
195
|
-
|
|
196
193
|
def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
|
|
197
|
-
"""
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
if log_queries:
|
|
213
|
-
self.conn.set_trace_callback(lambda query: log(DEBUG, query))
|
|
214
|
-
cur = self.conn.cursor()
|
|
215
|
-
|
|
216
|
-
# Create each table if not exists queries
|
|
217
|
-
cur.execute(SQL_CREATE_TABLE_RUN)
|
|
218
|
-
cur.execute(SQL_CREATE_TABLE_LOGS)
|
|
219
|
-
cur.execute(SQL_CREATE_TABLE_CONTEXT)
|
|
220
|
-
cur.execute(SQL_CREATE_TABLE_MESSAGE_INS)
|
|
221
|
-
cur.execute(SQL_CREATE_TABLE_MESSAGE_RES)
|
|
222
|
-
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
223
|
-
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
|
224
|
-
cur.execute(SQL_CREATE_TABLE_TOKEN_STORE)
|
|
225
|
-
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
226
|
-
res = cur.execute("SELECT name FROM sqlite_schema;")
|
|
227
|
-
return res.fetchall()
|
|
228
|
-
|
|
229
|
-
def query(
|
|
230
|
-
self,
|
|
231
|
-
query: str,
|
|
232
|
-
data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None,
|
|
233
|
-
) -> list[dict[str, Any]]:
|
|
234
|
-
"""Execute a SQL query."""
|
|
235
|
-
if self.conn is None:
|
|
236
|
-
raise AttributeError("LinkState is not initialized.")
|
|
237
|
-
|
|
238
|
-
if data is None:
|
|
239
|
-
data = []
|
|
240
|
-
|
|
241
|
-
# Clean up whitespace to make the logs nicer
|
|
242
|
-
query = re.sub(r"\s+", " ", query)
|
|
243
|
-
|
|
244
|
-
try:
|
|
245
|
-
with self.conn:
|
|
246
|
-
if (
|
|
247
|
-
len(data) > 0
|
|
248
|
-
and isinstance(data, (tuple, list))
|
|
249
|
-
and isinstance(data[0], (tuple, dict))
|
|
250
|
-
):
|
|
251
|
-
rows = self.conn.executemany(query, data)
|
|
252
|
-
else:
|
|
253
|
-
rows = self.conn.execute(query, data)
|
|
254
|
-
|
|
255
|
-
# Extract results before committing to support
|
|
256
|
-
# INSERT/UPDATE ... RETURNING
|
|
257
|
-
# style queries
|
|
258
|
-
result = rows.fetchall()
|
|
259
|
-
except KeyError as exc:
|
|
260
|
-
log(ERROR, {"query": query, "data": data, "exception": exc})
|
|
261
|
-
|
|
262
|
-
return result
|
|
194
|
+
"""Connect to the DB, enable FK support, and create tables if needed."""
|
|
195
|
+
return self._ensure_initialized(
|
|
196
|
+
SQL_CREATE_TABLE_RUN,
|
|
197
|
+
SQL_CREATE_TABLE_LOGS,
|
|
198
|
+
SQL_CREATE_TABLE_CONTEXT,
|
|
199
|
+
SQL_CREATE_TABLE_MESSAGE_INS,
|
|
200
|
+
SQL_CREATE_TABLE_MESSAGE_RES,
|
|
201
|
+
SQL_CREATE_TABLE_NODE,
|
|
202
|
+
SQL_CREATE_TABLE_PUBLIC_KEY,
|
|
203
|
+
SQL_CREATE_TABLE_TOKEN_STORE,
|
|
204
|
+
SQL_CREATE_INDEX_ONLINE_UNTIL,
|
|
205
|
+
SQL_CREATE_INDEX_OWNER_AID,
|
|
206
|
+
SQL_CREATE_INDEX_NODE_STATUS,
|
|
207
|
+
log_queries=log_queries,
|
|
208
|
+
)
|
|
263
209
|
|
|
264
210
|
def store_message_ins(self, message: Message) -> Optional[str]:
|
|
265
211
|
"""Store one Message."""
|
|
@@ -293,8 +239,10 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
293
239
|
return None
|
|
294
240
|
|
|
295
241
|
# Validate destination node ID
|
|
296
|
-
query = "SELECT node_id FROM node WHERE node_id =
|
|
297
|
-
if not self.query(
|
|
242
|
+
query = "SELECT node_id FROM node WHERE node_id = ? AND status IN (?, ?);"
|
|
243
|
+
if not self.query(
|
|
244
|
+
query, (data[0]["dst_node_id"], NodeStatus.ONLINE, NodeStatus.OFFLINE)
|
|
245
|
+
):
|
|
298
246
|
log(
|
|
299
247
|
ERROR,
|
|
300
248
|
"Invalid destination node ID for Message: %s",
|
|
@@ -323,7 +271,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
323
271
|
data: dict[str, Union[str, int]] = {}
|
|
324
272
|
|
|
325
273
|
# Convert the uint64 value to sint64 for SQLite
|
|
326
|
-
data["node_id"] =
|
|
274
|
+
data["node_id"] = uint64_to_int64(node_id)
|
|
327
275
|
|
|
328
276
|
# Retrieve all Messages for node_id
|
|
329
277
|
query = """
|
|
@@ -398,8 +346,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
398
346
|
if (
|
|
399
347
|
msg_ins
|
|
400
348
|
and message
|
|
401
|
-
and
|
|
402
|
-
!= res_metadata.src_node_id
|
|
349
|
+
and int64_to_uint64(msg_ins["dst_node_id"]) != res_metadata.src_node_id
|
|
403
350
|
):
|
|
404
351
|
return None
|
|
405
352
|
|
|
@@ -451,7 +398,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
451
398
|
ret: dict[str, Message] = {}
|
|
452
399
|
|
|
453
400
|
# Verify Message IDs
|
|
454
|
-
current =
|
|
401
|
+
current = now().timestamp()
|
|
455
402
|
query = f"""
|
|
456
403
|
SELECT *
|
|
457
404
|
FROM message_ins
|
|
@@ -475,20 +422,20 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
475
422
|
dst_node_ids: set[int] = set()
|
|
476
423
|
for message_id in message_ids:
|
|
477
424
|
in_message = found_message_ins_dict[message_id]
|
|
478
|
-
sint_node_id =
|
|
425
|
+
sint_node_id = uint64_to_int64(in_message.metadata.dst_node_id)
|
|
479
426
|
dst_node_ids.add(sint_node_id)
|
|
480
427
|
query = f"""
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
428
|
+
SELECT node_id, online_until
|
|
429
|
+
FROM node
|
|
430
|
+
WHERE node_id IN ({",".join(["?"] * len(dst_node_ids))})
|
|
431
|
+
AND status != ?
|
|
432
|
+
"""
|
|
433
|
+
rows = self.query(query, tuple(dst_node_ids) + (NodeStatus.UNREGISTERED,))
|
|
486
434
|
tmp_ret_dict = check_node_availability_for_in_message(
|
|
487
435
|
inquired_in_message_ids=message_ids,
|
|
488
436
|
found_in_message_dict=found_message_ins_dict,
|
|
489
437
|
node_id_to_online_until={
|
|
490
|
-
|
|
491
|
-
for row in rows
|
|
438
|
+
int64_to_uint64(row["node_id"]): row["online_until"] for row in rows
|
|
492
439
|
},
|
|
493
440
|
current_time=current,
|
|
494
441
|
)
|
|
@@ -589,7 +536,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
589
536
|
WHERE run_id = :run_id;
|
|
590
537
|
"""
|
|
591
538
|
|
|
592
|
-
sint64_run_id =
|
|
539
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
593
540
|
data = {"run_id": sint64_run_id}
|
|
594
541
|
|
|
595
542
|
with self.conn:
|
|
@@ -597,7 +544,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
597
544
|
|
|
598
545
|
return {row["message_id"] for row in rows}
|
|
599
546
|
|
|
600
|
-
def create_node(
|
|
547
|
+
def create_node(
|
|
548
|
+
self, owner_aid: str, public_key: bytes, heartbeat_interval: float
|
|
549
|
+
) -> int:
|
|
601
550
|
"""Create, store in the link state, and return `node_id`."""
|
|
602
551
|
# Sample a random uint64 as node_id
|
|
603
552
|
uint64_node_id = generate_rand_int_from_bytes(
|
|
@@ -605,50 +554,126 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
605
554
|
)
|
|
606
555
|
|
|
607
556
|
# Convert the uint64 value to sint64 for SQLite
|
|
608
|
-
sint64_node_id =
|
|
557
|
+
sint64_node_id = uint64_to_int64(uint64_node_id)
|
|
609
558
|
|
|
610
|
-
query =
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
559
|
+
query = """
|
|
560
|
+
INSERT INTO node
|
|
561
|
+
(node_id, owner_aid, status, registered_at, last_activated_at,
|
|
562
|
+
last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
|
|
563
|
+
public_key)
|
|
564
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
565
|
+
"""
|
|
615
566
|
|
|
616
|
-
# Mark the node online
|
|
567
|
+
# Mark the node online until now().timestamp() + heartbeat_interval
|
|
617
568
|
try:
|
|
618
569
|
self.query(
|
|
619
570
|
query,
|
|
620
571
|
(
|
|
621
|
-
sint64_node_id,
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
572
|
+
sint64_node_id, # node_id
|
|
573
|
+
owner_aid, # owner_aid
|
|
574
|
+
NodeStatus.REGISTERED, # status
|
|
575
|
+
now().isoformat(), # registered_at
|
|
576
|
+
None, # last_activated_at
|
|
577
|
+
None, # last_deactivated_at
|
|
578
|
+
None, # unregistered_at
|
|
579
|
+
None, # online_until, initialized with offline status
|
|
580
|
+
heartbeat_interval, # heartbeat_interval
|
|
581
|
+
public_key, # public_key
|
|
625
582
|
),
|
|
626
583
|
)
|
|
627
|
-
except sqlite3.IntegrityError:
|
|
584
|
+
except sqlite3.IntegrityError as e:
|
|
585
|
+
if "UNIQUE constraint failed: node.public_key" in str(e):
|
|
586
|
+
raise ValueError("Public key already in use.") from None
|
|
587
|
+
# Must be node ID conflict, almost impossible unless system is compromised
|
|
628
588
|
log(ERROR, "Unexpected node registration failure.")
|
|
629
589
|
return 0
|
|
630
590
|
|
|
631
591
|
# Note: we need to return the uint64 value of the node_id
|
|
632
592
|
return uint64_node_id
|
|
633
593
|
|
|
634
|
-
def delete_node(self, node_id: int) -> None:
|
|
594
|
+
def delete_node(self, owner_aid: str, node_id: int) -> None:
|
|
635
595
|
"""Delete a node."""
|
|
636
|
-
|
|
637
|
-
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
596
|
+
sint64_node_id = uint64_to_int64(node_id)
|
|
638
597
|
|
|
639
|
-
query = "
|
|
640
|
-
|
|
598
|
+
query = """
|
|
599
|
+
UPDATE node
|
|
600
|
+
SET status = ?, unregistered_at = ?,
|
|
601
|
+
online_until = IIF(online_until > ?, ?, online_until)
|
|
602
|
+
WHERE node_id = ? AND status != ? AND owner_aid = ?
|
|
603
|
+
RETURNING node_id
|
|
604
|
+
"""
|
|
605
|
+
current = now()
|
|
606
|
+
params = (
|
|
607
|
+
NodeStatus.UNREGISTERED,
|
|
608
|
+
current.isoformat(),
|
|
609
|
+
current.timestamp(),
|
|
610
|
+
current.timestamp(),
|
|
611
|
+
sint64_node_id,
|
|
612
|
+
NodeStatus.UNREGISTERED,
|
|
613
|
+
owner_aid,
|
|
614
|
+
)
|
|
641
615
|
|
|
642
|
-
|
|
643
|
-
|
|
616
|
+
rows = self.query(query, params)
|
|
617
|
+
if not rows:
|
|
618
|
+
raise ValueError(
|
|
619
|
+
f"Node {node_id} already deleted, not found or unauthorized "
|
|
620
|
+
"deletion attempt."
|
|
621
|
+
)
|
|
644
622
|
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
623
|
+
def activate_node(self, node_id: int, heartbeat_interval: float) -> bool:
|
|
624
|
+
"""Activate the node with the specified `node_id`."""
|
|
625
|
+
with self.conn:
|
|
626
|
+
self._check_and_tag_offline_nodes([node_id])
|
|
627
|
+
|
|
628
|
+
# Only activate if the node is currently registered or offline
|
|
629
|
+
current_dt = now()
|
|
630
|
+
query = """
|
|
631
|
+
UPDATE node
|
|
632
|
+
SET status = ?,
|
|
633
|
+
last_activated_at = ?,
|
|
634
|
+
online_until = ?,
|
|
635
|
+
heartbeat_interval = ?
|
|
636
|
+
WHERE node_id = ? AND status in (?, ?)
|
|
637
|
+
RETURNING node_id
|
|
638
|
+
"""
|
|
639
|
+
params = (
|
|
640
|
+
NodeStatus.ONLINE,
|
|
641
|
+
current_dt.isoformat(),
|
|
642
|
+
current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
|
|
643
|
+
heartbeat_interval,
|
|
644
|
+
uint64_to_int64(node_id),
|
|
645
|
+
NodeStatus.REGISTERED,
|
|
646
|
+
NodeStatus.OFFLINE,
|
|
647
|
+
)
|
|
648
|
+
|
|
649
|
+
row = self.conn.execute(query, params).fetchone()
|
|
650
|
+
return row is not None
|
|
651
|
+
|
|
652
|
+
def deactivate_node(self, node_id: int) -> bool:
|
|
653
|
+
"""Deactivate the node with the specified `node_id`."""
|
|
654
|
+
with self.conn:
|
|
655
|
+
self._check_and_tag_offline_nodes([node_id])
|
|
656
|
+
|
|
657
|
+
# Only deactivate if the node is currently online
|
|
658
|
+
current_dt = now()
|
|
659
|
+
query = """
|
|
660
|
+
UPDATE node
|
|
661
|
+
SET status = ?,
|
|
662
|
+
last_deactivated_at = ?,
|
|
663
|
+
online_until = ?
|
|
664
|
+
WHERE node_id = ? AND status = ?
|
|
665
|
+
RETURNING node_id
|
|
666
|
+
"""
|
|
667
|
+
params = (
|
|
668
|
+
NodeStatus.OFFLINE,
|
|
669
|
+
current_dt.isoformat(),
|
|
670
|
+
current_dt.timestamp(),
|
|
671
|
+
uint64_to_int64(node_id),
|
|
672
|
+
NodeStatus.ONLINE,
|
|
673
|
+
)
|
|
674
|
+
|
|
675
|
+
row = self.conn.execute(query, params).fetchone()
|
|
676
|
+
return row is not None
|
|
652
677
|
|
|
653
678
|
def get_nodes(self, run_id: int) -> set[int]:
|
|
654
679
|
"""Retrieve all currently stored node IDs as a set.
|
|
@@ -658,69 +683,117 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
658
683
|
If the provided `run_id` does not exist or has no matching nodes,
|
|
659
684
|
an empty `Set` MUST be returned.
|
|
660
685
|
"""
|
|
686
|
+
if self.conn is None:
|
|
687
|
+
raise AttributeError("LinkState not initialized")
|
|
688
|
+
|
|
661
689
|
# Convert the uint64 value to sint64 for SQLite
|
|
662
|
-
sint64_run_id =
|
|
690
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
663
691
|
|
|
664
692
|
# Validate run ID
|
|
665
|
-
query = "SELECT COUNT(*) FROM run WHERE run_id =
|
|
666
|
-
|
|
693
|
+
query = "SELECT COUNT(*) FROM run WHERE run_id = ?"
|
|
694
|
+
rows = self.query(query, (sint64_run_id,))
|
|
695
|
+
if rows[0]["COUNT(*)"] == 0:
|
|
667
696
|
return set()
|
|
668
697
|
|
|
669
|
-
#
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
# Convert sint64 node_ids to uint64
|
|
674
|
-
result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
|
|
675
|
-
return result
|
|
676
|
-
|
|
677
|
-
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
|
678
|
-
"""Set `public_key` for the specified `node_id`."""
|
|
679
|
-
# Convert the uint64 value to sint64 for SQLite
|
|
680
|
-
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
681
|
-
|
|
682
|
-
# Check if the node exists in the `node` table
|
|
683
|
-
query = "SELECT 1 FROM node WHERE node_id = ?"
|
|
684
|
-
if not self.query(query, (sint64_node_id,)):
|
|
685
|
-
raise ValueError(f"Node {node_id} not found")
|
|
686
|
-
|
|
687
|
-
# Check if the public key is already in use in the `node` table
|
|
688
|
-
query = "SELECT 1 FROM node WHERE public_key = ?"
|
|
689
|
-
if self.query(query, (public_key,)):
|
|
690
|
-
raise ValueError("Public key already in use")
|
|
698
|
+
# Retrieve all online nodes
|
|
699
|
+
return {
|
|
700
|
+
node.node_id for node in self.get_node_info(statuses=[NodeStatus.ONLINE])
|
|
701
|
+
}
|
|
691
702
|
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
703
|
+
def _check_and_tag_offline_nodes(
|
|
704
|
+
self, node_ids: Optional[list[int]] = None
|
|
705
|
+
) -> None:
|
|
706
|
+
"""Check and tag offline nodes."""
|
|
707
|
+
# strftime will convert POSIX timestamp to ISO format
|
|
708
|
+
query = """
|
|
709
|
+
UPDATE node SET status = ?,
|
|
710
|
+
last_deactivated_at =
|
|
711
|
+
strftime("%Y-%m-%dT%H:%M:%f+00:00", online_until, "unixepoch")
|
|
712
|
+
WHERE online_until <= ? AND status == ?
|
|
713
|
+
"""
|
|
714
|
+
params = [
|
|
715
|
+
NodeStatus.OFFLINE,
|
|
716
|
+
now().timestamp(),
|
|
717
|
+
NodeStatus.ONLINE,
|
|
718
|
+
]
|
|
719
|
+
if node_ids is not None:
|
|
720
|
+
placeholders = ",".join(["?"] * len(node_ids))
|
|
721
|
+
query += f" AND node_id IN ({placeholders})"
|
|
722
|
+
params.extend(uint64_to_int64(node_id) for node_id in node_ids)
|
|
723
|
+
self.conn.execute(query, params)
|
|
695
724
|
|
|
696
|
-
def
|
|
725
|
+
def get_node_info(
|
|
726
|
+
self,
|
|
727
|
+
*,
|
|
728
|
+
node_ids: Optional[Sequence[int]] = None,
|
|
729
|
+
owner_aids: Optional[Sequence[str]] = None,
|
|
730
|
+
statuses: Optional[Sequence[str]] = None,
|
|
731
|
+
) -> Sequence[NodeInfo]:
|
|
732
|
+
"""Retrieve information about nodes based on the specified filters."""
|
|
733
|
+
with self.conn:
|
|
734
|
+
self._check_and_tag_offline_nodes()
|
|
735
|
+
|
|
736
|
+
# Build the WHERE clause based on provided filters
|
|
737
|
+
conditions = []
|
|
738
|
+
params: list[Any] = []
|
|
739
|
+
if node_ids is not None:
|
|
740
|
+
sint64_node_ids = [uint64_to_int64(node_id) for node_id in node_ids]
|
|
741
|
+
placeholders = ",".join(["?"] * len(sint64_node_ids))
|
|
742
|
+
conditions.append(f"node_id IN ({placeholders})")
|
|
743
|
+
params.extend(sint64_node_ids)
|
|
744
|
+
if owner_aids is not None:
|
|
745
|
+
placeholders = ",".join(["?"] * len(owner_aids))
|
|
746
|
+
conditions.append(f"owner_aid IN ({placeholders})")
|
|
747
|
+
params.extend(owner_aids)
|
|
748
|
+
if statuses is not None:
|
|
749
|
+
placeholders = ",".join(["?"] * len(statuses))
|
|
750
|
+
conditions.append(f"status IN ({placeholders})")
|
|
751
|
+
params.extend(statuses)
|
|
752
|
+
|
|
753
|
+
# Construct the final query
|
|
754
|
+
query = "SELECT * FROM node"
|
|
755
|
+
if conditions:
|
|
756
|
+
query += " WHERE " + " AND ".join(conditions)
|
|
757
|
+
|
|
758
|
+
rows = self.conn.execute(query, params).fetchall()
|
|
759
|
+
|
|
760
|
+
result: list[NodeInfo] = []
|
|
761
|
+
for row in rows:
|
|
762
|
+
# Convert sint64 node_id to uint64
|
|
763
|
+
row["node_id"] = int64_to_uint64(row["node_id"])
|
|
764
|
+
result.append(NodeInfo(**row))
|
|
765
|
+
|
|
766
|
+
return result
|
|
767
|
+
|
|
768
|
+
def get_node_public_key(self, node_id: int) -> bytes:
|
|
697
769
|
"""Get `public_key` for the specified `node_id`."""
|
|
698
770
|
# Convert the uint64 value to sint64 for SQLite
|
|
699
|
-
sint64_node_id =
|
|
771
|
+
sint64_node_id = uint64_to_int64(node_id)
|
|
700
772
|
|
|
701
773
|
# Query the public key for the given node_id
|
|
702
|
-
query = "SELECT public_key FROM node WHERE node_id = ?"
|
|
703
|
-
rows = self.query(query, (sint64_node_id,))
|
|
774
|
+
query = "SELECT public_key FROM node WHERE node_id = ? AND status != ?;"
|
|
775
|
+
rows = self.query(query, (sint64_node_id, NodeStatus.UNREGISTERED))
|
|
704
776
|
|
|
705
777
|
# If no result is found, return None
|
|
706
778
|
if not rows:
|
|
707
|
-
raise ValueError(f"Node {node_id} not found")
|
|
779
|
+
raise ValueError(f"Node ID {node_id} not found")
|
|
708
780
|
|
|
709
|
-
# Return the public key
|
|
710
|
-
return rows[0]["public_key"]
|
|
781
|
+
# Return the public key
|
|
782
|
+
return cast(bytes, rows[0]["public_key"])
|
|
711
783
|
|
|
712
|
-
def
|
|
713
|
-
"""
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
node_id: int = row[0]["node_id"]
|
|
784
|
+
def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
|
|
785
|
+
"""Get `node_id` for the specified `public_key` if it exists and is not
|
|
786
|
+
deleted."""
|
|
787
|
+
query = "SELECT node_id FROM node WHERE public_key = ? AND status != ?;"
|
|
788
|
+
rows = self.query(query, (public_key, NodeStatus.UNREGISTERED))
|
|
718
789
|
|
|
719
|
-
|
|
720
|
-
|
|
790
|
+
# If no result is found, return None
|
|
791
|
+
if not rows:
|
|
792
|
+
return None
|
|
721
793
|
|
|
722
|
-
|
|
723
|
-
|
|
794
|
+
# Convert sint64 node_id to uint64
|
|
795
|
+
node_id = int64_to_uint64(rows[0]["node_id"])
|
|
796
|
+
return node_id
|
|
724
797
|
|
|
725
798
|
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
726
799
|
def create_run(
|
|
@@ -737,7 +810,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
737
810
|
uint64_run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
|
|
738
811
|
|
|
739
812
|
# Convert the uint64 value to sint64 for SQLite
|
|
740
|
-
sint64_run_id =
|
|
813
|
+
sint64_run_id = uint64_to_int64(uint64_run_id)
|
|
741
814
|
|
|
742
815
|
# Check conflicts
|
|
743
816
|
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
|
|
@@ -773,28 +846,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
773
846
|
log(ERROR, "Unexpected run creation failure.")
|
|
774
847
|
return 0
|
|
775
848
|
|
|
776
|
-
def clear_supernode_auth_keys(self) -> None:
|
|
777
|
-
"""Clear stored `node_public_keys` in the link state if any."""
|
|
778
|
-
self.query("DELETE FROM public_key;")
|
|
779
|
-
|
|
780
|
-
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
781
|
-
"""Store a set of `node_public_keys` in the link state."""
|
|
782
|
-
query = "INSERT INTO public_key (public_key) VALUES (?)"
|
|
783
|
-
data = [(key,) for key in public_keys]
|
|
784
|
-
self.query(query, data)
|
|
785
|
-
|
|
786
|
-
def store_node_public_key(self, public_key: bytes) -> None:
|
|
787
|
-
"""Store a `node_public_key` in the link state."""
|
|
788
|
-
query = "INSERT INTO public_key (public_key) VALUES (:public_key)"
|
|
789
|
-
self.query(query, {"public_key": public_key})
|
|
790
|
-
|
|
791
|
-
def get_node_public_keys(self) -> set[bytes]:
|
|
792
|
-
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
793
|
-
query = "SELECT public_key FROM public_key"
|
|
794
|
-
rows = self.query(query)
|
|
795
|
-
result: set[bytes] = {row["public_key"] for row in rows}
|
|
796
|
-
return result
|
|
797
|
-
|
|
798
849
|
def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
|
|
799
850
|
"""Retrieve all run IDs if `flwr_aid` is not specified.
|
|
800
851
|
|
|
@@ -807,7 +858,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
807
858
|
)
|
|
808
859
|
else:
|
|
809
860
|
rows = self.query("SELECT run_id FROM run;", ())
|
|
810
|
-
return {
|
|
861
|
+
return {int64_to_uint64(row["run_id"]) for row in rows}
|
|
811
862
|
|
|
812
863
|
def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
|
|
813
864
|
"""Check if any runs are no longer active.
|
|
@@ -815,7 +866,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
815
866
|
Marks runs with status 'starting' or 'running' as failed
|
|
816
867
|
if they have not sent a heartbeat before `active_until`.
|
|
817
868
|
"""
|
|
818
|
-
sint_run_ids = [
|
|
869
|
+
sint_run_ids = [uint64_to_int64(run_id) for run_id in run_ids]
|
|
819
870
|
query = "UPDATE run SET finished_at = ?, sub_status = ?, details = ? "
|
|
820
871
|
query += "WHERE starting_at != '' AND finished_at = '' AND active_until < ?"
|
|
821
872
|
query += f" AND run_id IN ({','.join(['?'] * len(run_ids))});"
|
|
@@ -837,13 +888,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
837
888
|
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
838
889
|
|
|
839
890
|
# Convert the uint64 value to sint64 for SQLite
|
|
840
|
-
sint64_run_id =
|
|
891
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
841
892
|
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
842
893
|
rows = self.query(query, (sint64_run_id,))
|
|
843
894
|
if rows:
|
|
844
895
|
row = rows[0]
|
|
845
896
|
return Run(
|
|
846
|
-
run_id=
|
|
897
|
+
run_id=int64_to_uint64(row["run_id"]),
|
|
847
898
|
fab_id=row["fab_id"],
|
|
848
899
|
fab_version=row["fab_version"],
|
|
849
900
|
fab_hash=row["fab_hash"],
|
|
@@ -868,13 +919,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
868
919
|
self._check_and_tag_inactive_run(run_ids=run_ids)
|
|
869
920
|
|
|
870
921
|
# Convert the uint64 value to sint64 for SQLite
|
|
871
|
-
sint64_run_ids = (
|
|
922
|
+
sint64_run_ids = (uint64_to_int64(run_id) for run_id in set(run_ids))
|
|
872
923
|
query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});"
|
|
873
924
|
rows = self.query(query, tuple(sint64_run_ids))
|
|
874
925
|
|
|
875
926
|
return {
|
|
876
927
|
# Restore uint64 run IDs
|
|
877
|
-
|
|
928
|
+
int64_to_uint64(row["run_id"]): RunStatus(
|
|
878
929
|
status=determine_run_status(row),
|
|
879
930
|
sub_status=row["sub_status"],
|
|
880
931
|
details=row["details"],
|
|
@@ -888,7 +939,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
888
939
|
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
889
940
|
|
|
890
941
|
# Convert the uint64 value to sint64 for SQLite
|
|
891
|
-
sint64_run_id =
|
|
942
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
892
943
|
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
893
944
|
rows = self.query(query, (sint64_run_id,))
|
|
894
945
|
|
|
@@ -933,7 +984,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
933
984
|
# when switching to starting or running
|
|
934
985
|
current = now()
|
|
935
986
|
if new_status.status in (Status.STARTING, Status.RUNNING):
|
|
936
|
-
heartbeat_interval =
|
|
987
|
+
heartbeat_interval = HEARTBEAT_INTERVAL_INF
|
|
937
988
|
active_until = current.timestamp() + heartbeat_interval
|
|
938
989
|
else:
|
|
939
990
|
heartbeat_interval = 0
|
|
@@ -954,7 +1005,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
954
1005
|
new_status.details,
|
|
955
1006
|
active_until,
|
|
956
1007
|
heartbeat_interval,
|
|
957
|
-
|
|
1008
|
+
uint64_to_int64(run_id),
|
|
958
1009
|
)
|
|
959
1010
|
self.query(query % timestamp_fld, data)
|
|
960
1011
|
return True
|
|
@@ -967,14 +1018,14 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
967
1018
|
query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;"
|
|
968
1019
|
rows = self.query(query)
|
|
969
1020
|
if rows:
|
|
970
|
-
pending_run_id =
|
|
1021
|
+
pending_run_id = int64_to_uint64(rows[0]["run_id"])
|
|
971
1022
|
|
|
972
1023
|
return pending_run_id
|
|
973
1024
|
|
|
974
1025
|
def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
|
|
975
1026
|
"""Retrieve the federation options for the specified `run_id`."""
|
|
976
1027
|
# Convert the uint64 value to sint64 for SQLite
|
|
977
|
-
sint64_run_id =
|
|
1028
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
978
1029
|
query = "SELECT federation_options FROM run WHERE run_id = ?;"
|
|
979
1030
|
rows = self.query(query, (sint64_run_id,))
|
|
980
1031
|
|
|
@@ -996,26 +1047,38 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
996
1047
|
HEARTBEAT_PATIENCE = N allows for N-1 missed heartbeat before
|
|
997
1048
|
the node is marked as offline.
|
|
998
1049
|
"""
|
|
999
|
-
|
|
1050
|
+
if self.conn is None:
|
|
1051
|
+
raise AttributeError("LinkState not initialized")
|
|
1000
1052
|
|
|
1001
|
-
|
|
1002
|
-
query = "SELECT 1 FROM node WHERE node_id = ?"
|
|
1003
|
-
if not self.query(query, (sint64_node_id,)):
|
|
1004
|
-
return False
|
|
1053
|
+
sint64_node_id = uint64_to_int64(node_id)
|
|
1005
1054
|
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
"
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1055
|
+
with self.conn:
|
|
1056
|
+
# Check if node exists and not deleted
|
|
1057
|
+
query = "SELECT status FROM node WHERE node_id = ? AND status != ?"
|
|
1058
|
+
row = self.conn.execute(
|
|
1059
|
+
query, (sint64_node_id, NodeStatus.UNREGISTERED)
|
|
1060
|
+
).fetchone()
|
|
1061
|
+
if row is None:
|
|
1062
|
+
return False
|
|
1063
|
+
|
|
1064
|
+
# Construct query and params
|
|
1065
|
+
current_dt = now()
|
|
1066
|
+
query = "UPDATE node SET online_until = ?, heartbeat_interval = ?"
|
|
1067
|
+
params: list[Any] = [
|
|
1068
|
+
current_dt.timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
|
|
1014
1069
|
heartbeat_interval,
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1070
|
+
]
|
|
1071
|
+
|
|
1072
|
+
# Set timestamp if the status changes
|
|
1073
|
+
if row["status"] != NodeStatus.ONLINE:
|
|
1074
|
+
query += ", status = ?, last_activated_at = ?"
|
|
1075
|
+
params += [NodeStatus.ONLINE, current_dt.isoformat()]
|
|
1076
|
+
|
|
1077
|
+
# Execute the query, refreshing `online_until` and `heartbeat_interval`
|
|
1078
|
+
query += " WHERE node_id = ?"
|
|
1079
|
+
params += [sint64_node_id]
|
|
1080
|
+
self.conn.execute(query, params)
|
|
1081
|
+
return True
|
|
1019
1082
|
|
|
1020
1083
|
def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
|
|
1021
1084
|
"""Acknowledge a heartbeat received from a ServerApp for a given run.
|
|
@@ -1029,7 +1092,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1029
1092
|
self._check_and_tag_inactive_run(run_ids={run_id})
|
|
1030
1093
|
|
|
1031
1094
|
# Search for the run
|
|
1032
|
-
sint_run_id =
|
|
1095
|
+
sint_run_id = uint64_to_int64(run_id)
|
|
1033
1096
|
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
1034
1097
|
rows = self.query(query, (sint_run_id,))
|
|
1035
1098
|
|
|
@@ -1059,7 +1122,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1059
1122
|
"""Get the context for the specified `run_id`."""
|
|
1060
1123
|
# Retrieve context if any
|
|
1061
1124
|
query = "SELECT context FROM context WHERE run_id = ?;"
|
|
1062
|
-
rows = self.query(query, (
|
|
1125
|
+
rows = self.query(query, (uint64_to_int64(run_id),))
|
|
1063
1126
|
context = context_from_bytes(rows[0]["context"]) if rows else None
|
|
1064
1127
|
return context
|
|
1065
1128
|
|
|
@@ -1067,7 +1130,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1067
1130
|
"""Set the context for the specified `run_id`."""
|
|
1068
1131
|
# Convert context to bytes
|
|
1069
1132
|
context_bytes = context_to_bytes(context)
|
|
1070
|
-
sint_run_id =
|
|
1133
|
+
sint_run_id = uint64_to_int64(run_id)
|
|
1071
1134
|
|
|
1072
1135
|
# Check if any existing Context assigned to the run_id
|
|
1073
1136
|
query = "SELECT COUNT(*) FROM context WHERE run_id = ?;"
|
|
@@ -1086,7 +1149,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1086
1149
|
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
|
|
1087
1150
|
"""Add a log entry to the ServerApp logs for the specified `run_id`."""
|
|
1088
1151
|
# Convert the uint64 value to sint64 for SQLite
|
|
1089
|
-
sint64_run_id =
|
|
1152
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
1090
1153
|
|
|
1091
1154
|
# Store log
|
|
1092
1155
|
try:
|
|
@@ -1102,7 +1165,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1102
1165
|
) -> tuple[str, float]:
|
|
1103
1166
|
"""Get the ServerApp logs for the specified `run_id`."""
|
|
1104
1167
|
# Convert the uint64 value to sint64 for SQLite
|
|
1105
|
-
sint64_run_id =
|
|
1168
|
+
sint64_run_id = uint64_to_int64(run_id)
|
|
1106
1169
|
|
|
1107
1170
|
# Check if the run_id exists
|
|
1108
1171
|
query = "SELECT run_id FROM run WHERE run_id = ?;"
|
|
@@ -1140,7 +1203,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1140
1203
|
message_ins = rows[0]
|
|
1141
1204
|
created_at = message_ins["created_at"]
|
|
1142
1205
|
ttl = message_ins["ttl"]
|
|
1143
|
-
current_time =
|
|
1206
|
+
current_time = now().timestamp()
|
|
1144
1207
|
|
|
1145
1208
|
# Check if Message is expired
|
|
1146
1209
|
if ttl is not None and created_at + ttl <= current_time:
|
|
@@ -1152,7 +1215,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1152
1215
|
"""Create a token for the given run ID."""
|
|
1153
1216
|
token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token
|
|
1154
1217
|
query = "INSERT INTO token_store (run_id, token) VALUES (:run_id, :token);"
|
|
1155
|
-
data = {"run_id":
|
|
1218
|
+
data = {"run_id": uint64_to_int64(run_id), "token": token}
|
|
1156
1219
|
try:
|
|
1157
1220
|
self.query(query, data)
|
|
1158
1221
|
except sqlite3.IntegrityError:
|
|
@@ -1162,7 +1225,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1162
1225
|
def verify_token(self, run_id: int, token: str) -> bool:
|
|
1163
1226
|
"""Verify a token for the given run ID."""
|
|
1164
1227
|
query = "SELECT token FROM token_store WHERE run_id = :run_id;"
|
|
1165
|
-
data = {"run_id":
|
|
1228
|
+
data = {"run_id": uint64_to_int64(run_id)}
|
|
1166
1229
|
rows = self.query(query, data)
|
|
1167
1230
|
if not rows:
|
|
1168
1231
|
return False
|
|
@@ -1171,7 +1234,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1171
1234
|
def delete_token(self, run_id: int) -> None:
|
|
1172
1235
|
"""Delete the token for the given run ID."""
|
|
1173
1236
|
query = "DELETE FROM token_store WHERE run_id = :run_id;"
|
|
1174
|
-
data = {"run_id":
|
|
1237
|
+
data = {"run_id": uint64_to_int64(run_id)}
|
|
1175
1238
|
self.query(query, data)
|
|
1176
1239
|
|
|
1177
1240
|
def get_run_id_by_token(self, token: str) -> Optional[int]:
|
|
@@ -1181,19 +1244,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1181
1244
|
rows = self.query(query, data)
|
|
1182
1245
|
if not rows:
|
|
1183
1246
|
return None
|
|
1184
|
-
return
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
def dict_factory(
|
|
1188
|
-
cursor: sqlite3.Cursor,
|
|
1189
|
-
row: sqlite3.Row,
|
|
1190
|
-
) -> dict[str, Any]:
|
|
1191
|
-
"""Turn SQLite results into dicts.
|
|
1192
|
-
|
|
1193
|
-
Less efficent for retrival of large amounts of data but easier to use.
|
|
1194
|
-
"""
|
|
1195
|
-
fields = [column[0] for column in cursor.description]
|
|
1196
|
-
return dict(zip(fields, row))
|
|
1247
|
+
return int64_to_uint64(rows[0]["run_id"])
|
|
1197
1248
|
|
|
1198
1249
|
|
|
1199
1250
|
def message_to_dict(message: Message) -> dict[str, Any]:
|
|
@@ -1248,5 +1299,5 @@ def determine_run_status(row: dict[str, Any]) -> str:
|
|
|
1248
1299
|
return Status.RUNNING
|
|
1249
1300
|
return Status.STARTING
|
|
1250
1301
|
return Status.PENDING
|
|
1251
|
-
run_id =
|
|
1302
|
+
run_id = int64_to_uint64(row["run_id"])
|
|
1252
1303
|
raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.")
|