flwr 1.13.1__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +5 -0
- flwr/cli/auth_plugin/__init__.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
- flwr/cli/build.py +1 -0
- flwr/cli/cli_user_auth_interceptor.py +90 -0
- flwr/cli/config_utils.py +43 -149
- flwr/cli/constant.py +27 -0
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +2 -1
- flwr/cli/log.py +34 -37
- flwr/cli/login/__init__.py +22 -0
- flwr/cli/login/login.py +116 -0
- flwr/cli/ls.py +214 -106
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +2 -1
- flwr/cli/new/templates/app/.gitignore.tpl +3 -0
- flwr/cli/new/templates/app/README.md.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +3 -4
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +103 -43
- flwr/cli/stop.py +139 -0
- flwr/cli/utils.py +186 -8
- flwr/client/app.py +49 -50
- flwr/client/client.py +1 -32
- flwr/client/clientapp/app.py +23 -26
- flwr/client/clientapp/utils.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +2 -13
- flwr/client/grpc_rere_client/client_interceptor.py +19 -119
- flwr/client/grpc_rere_client/connection.py +59 -43
- flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
- flwr/client/message_handler/message_handler.py +1 -2
- flwr/client/message_handler/task_handler.py +0 -17
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/numpy_client.py +0 -44
- flwr/client/rest_client/connection.py +37 -29
- flwr/client/supernode/app.py +20 -74
- flwr/common/address.py +1 -0
- flwr/common/args.py +26 -47
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +122 -0
- flwr/common/config.py +169 -17
- flwr/common/constant.py +38 -9
- flwr/common/differential_privacy.py +2 -1
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +24 -10
- flwr/common/grpc.py +167 -4
- flwr/common/logger.py +66 -7
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/record/recordset.py +1 -1
- flwr/common/retry_invoker.py +77 -0
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/serde.py +6 -4
- flwr/common/telemetry.py +15 -4
- flwr/common/typing.py +32 -0
- flwr/common/version.py +1 -0
- flwr/proto/clientappio_pb2.py +1 -1
- flwr/proto/error_pb2.py +1 -1
- flwr/proto/exec_pb2.py +27 -15
- flwr/proto/exec_pb2.pyi +80 -2
- flwr/proto/exec_pb2_grpc.py +102 -0
- flwr/proto/exec_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +5 -5
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/proto/fleet_pb2.py +31 -31
- flwr/proto/fleet_pb2.pyi +23 -23
- flwr/proto/fleet_pb2_grpc.py +30 -30
- flwr/proto/fleet_pb2_grpc.pyi +20 -20
- flwr/proto/grpcadapter_pb2.py +1 -1
- flwr/proto/log_pb2.py +1 -1
- flwr/proto/message_pb2.py +1 -1
- flwr/proto/node_pb2.py +3 -3
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/recordset_pb2.py +1 -1
- flwr/proto/run_pb2.py +1 -1
- flwr/proto/serverappio_pb2.py +24 -25
- flwr/proto/serverappio_pb2.pyi +32 -32
- flwr/proto/serverappio_pb2_grpc.py +62 -28
- flwr/proto/serverappio_pb2_grpc.pyi +29 -16
- flwr/proto/simulationio_pb2.py +3 -3
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/proto/task_pb2.py +1 -1
- flwr/proto/transport_pb2.py +1 -1
- flwr/server/app.py +152 -112
- flwr/server/compat/app_utils.py +7 -2
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +38 -85
- flwr/server/driver/inmemory_driver.py +7 -2
- flwr/server/run_serverapp.py +8 -9
- flwr/server/serverapp/app.py +37 -13
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +2 -1
- flwr/server/superlink/driver/serverappio_servicer.py +148 -63
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -87
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +56 -35
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +99 -169
- flwr/server/superlink/fleet/message_handler/message_handler.py +69 -29
- flwr/server/superlink/fleet/rest_rere/rest_api.py +20 -19
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +60 -99
- flwr/server/superlink/linkstate/linkstate.py +30 -36
- flwr/server/superlink/linkstate/sqlite_linkstate.py +105 -188
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +20 -10
- flwr/simulation/legacy_app.py +4 -2
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +36 -22
- flwr/simulation/simulationio_connection.py +5 -1
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +20 -2
- flwr/superexec/exec_servicer.py +97 -2
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +1 -0
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/METADATA +14 -13
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/RECORD +150 -144
- flwr/proto/common_pb2.py +0 -36
- flwr/proto/common_pb2.pyi +0 -121
- flwr/proto/common_pb2_grpc.py +0 -4
- flwr/proto/common_pb2_grpc.pyi +0 -4
- flwr/proto/control_pb2.py +0 -27
- flwr/proto/control_pb2.pyi +0 -7
- flwr/proto/control_pb2_grpc.py +0 -135
- flwr/proto/control_pb2_grpc.pyi +0 -53
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/LICENSE +0 -0
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/WHEEL +0 -0
- {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/entry_points.txt +0 -0
|
@@ -14,12 +14,12 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""SQLite based implemenation of the link state."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
# pylint: disable=too-many-lines
|
|
18
19
|
|
|
19
20
|
import json
|
|
20
21
|
import re
|
|
21
22
|
import sqlite3
|
|
22
|
-
import threading
|
|
23
23
|
import time
|
|
24
24
|
from collections.abc import Sequence
|
|
25
25
|
from logging import DEBUG, ERROR, WARNING
|
|
@@ -31,6 +31,7 @@ from flwr.common.constant import (
|
|
|
31
31
|
MESSAGE_TTL_TOLERANCE,
|
|
32
32
|
NODE_ID_NUM_BYTES,
|
|
33
33
|
RUN_ID_NUM_BYTES,
|
|
34
|
+
SUPERLINK_NODE_ID,
|
|
34
35
|
Status,
|
|
35
36
|
)
|
|
36
37
|
from flwr.common.record import ConfigsRecord
|
|
@@ -70,16 +71,9 @@ CREATE TABLE IF NOT EXISTS node(
|
|
|
70
71
|
);
|
|
71
72
|
"""
|
|
72
73
|
|
|
73
|
-
SQL_CREATE_TABLE_CREDENTIAL = """
|
|
74
|
-
CREATE TABLE IF NOT EXISTS credential(
|
|
75
|
-
private_key BLOB PRIMARY KEY,
|
|
76
|
-
public_key BLOB
|
|
77
|
-
);
|
|
78
|
-
"""
|
|
79
|
-
|
|
80
74
|
SQL_CREATE_TABLE_PUBLIC_KEY = """
|
|
81
75
|
CREATE TABLE IF NOT EXISTS public_key(
|
|
82
|
-
public_key
|
|
76
|
+
public_key BLOB PRIMARY KEY
|
|
83
77
|
);
|
|
84
78
|
"""
|
|
85
79
|
|
|
@@ -128,9 +122,7 @@ CREATE TABLE IF NOT EXISTS task_ins(
|
|
|
128
122
|
task_id TEXT UNIQUE,
|
|
129
123
|
group_id TEXT,
|
|
130
124
|
run_id INTEGER,
|
|
131
|
-
producer_anonymous BOOLEAN,
|
|
132
125
|
producer_node_id INTEGER,
|
|
133
|
-
consumer_anonymous BOOLEAN,
|
|
134
126
|
consumer_node_id INTEGER,
|
|
135
127
|
created_at REAL,
|
|
136
128
|
delivered_at TEXT,
|
|
@@ -148,9 +140,7 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
148
140
|
task_id TEXT UNIQUE,
|
|
149
141
|
group_id TEXT,
|
|
150
142
|
run_id INTEGER,
|
|
151
|
-
producer_anonymous BOOLEAN,
|
|
152
143
|
producer_node_id INTEGER,
|
|
153
|
-
consumer_anonymous BOOLEAN,
|
|
154
144
|
consumer_node_id INTEGER,
|
|
155
145
|
created_at REAL,
|
|
156
146
|
delivered_at TEXT,
|
|
@@ -183,7 +173,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
183
173
|
"""
|
|
184
174
|
self.database_path = database_path
|
|
185
175
|
self.conn: Optional[sqlite3.Connection] = None
|
|
186
|
-
self.lock = threading.RLock()
|
|
187
176
|
|
|
188
177
|
def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
|
|
189
178
|
"""Create tables if they don't exist yet.
|
|
@@ -212,11 +201,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
212
201
|
cur.execute(SQL_CREATE_TABLE_TASK_INS)
|
|
213
202
|
cur.execute(SQL_CREATE_TABLE_TASK_RES)
|
|
214
203
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
215
|
-
cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
|
|
216
204
|
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
|
217
205
|
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
218
206
|
res = cur.execute("SELECT name FROM sqlite_schema;")
|
|
219
|
-
|
|
220
207
|
return res.fetchall()
|
|
221
208
|
|
|
222
209
|
def query(
|
|
@@ -265,11 +252,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
265
252
|
|
|
266
253
|
Constraints
|
|
267
254
|
-----------
|
|
268
|
-
If `task_ins.task.consumer.anonymous` is `True`, then
|
|
269
|
-
`task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
|
|
270
255
|
|
|
271
|
-
|
|
272
|
-
`task_ins.task.consumer.node_id` MUST be set (not 0)
|
|
256
|
+
`task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
|
273
257
|
"""
|
|
274
258
|
# Validate task
|
|
275
259
|
errors = validate_task_ins_or_res(task_ins)
|
|
@@ -294,7 +278,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
294
278
|
log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
|
|
295
279
|
return None
|
|
296
280
|
# Validate source node ID
|
|
297
|
-
if task_ins.task.producer.node_id !=
|
|
281
|
+
if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
|
|
298
282
|
log(
|
|
299
283
|
ERROR,
|
|
300
284
|
"Invalid source node ID for TaskIns: %s",
|
|
@@ -303,14 +287,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
303
287
|
return None
|
|
304
288
|
# Validate destination node ID
|
|
305
289
|
query = "SELECT node_id FROM node WHERE node_id = ?;"
|
|
306
|
-
if not
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
return None
|
|
290
|
+
if not self.query(query, (data[0]["consumer_node_id"],)):
|
|
291
|
+
log(
|
|
292
|
+
ERROR,
|
|
293
|
+
"Invalid destination node ID for TaskIns: %s",
|
|
294
|
+
task_ins.task.consumer.node_id,
|
|
295
|
+
)
|
|
296
|
+
return None
|
|
314
297
|
|
|
315
298
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
316
299
|
query = f"INSERT INTO task_ins VALUES({columns});"
|
|
@@ -321,25 +304,18 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
321
304
|
|
|
322
305
|
return task_id
|
|
323
306
|
|
|
324
|
-
def get_task_ins(
|
|
325
|
-
|
|
326
|
-
) -> list[TaskIns]:
|
|
327
|
-
"""Get undelivered TaskIns for one node (either anonymous or with ID).
|
|
307
|
+
def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
|
|
308
|
+
"""Get undelivered TaskIns for one node.
|
|
328
309
|
|
|
329
310
|
Usually, the Fleet API calls this for Nodes planning to work on one or more
|
|
330
311
|
TaskIns.
|
|
331
312
|
|
|
332
313
|
Constraints
|
|
333
314
|
-----------
|
|
334
|
-
|
|
315
|
+
Retrieve all TaskIns where
|
|
335
316
|
|
|
336
317
|
1. the `task_ins.task.consumer.node_id` equals `node_id` AND
|
|
337
|
-
2. the `task_ins.task.
|
|
338
|
-
3. the `task_ins.task.delivered_at` equals `""`.
|
|
339
|
-
|
|
340
|
-
If `node_id` is `None`, retrieve all TaskIns where the
|
|
341
|
-
`task_ins.task.consumer.node_id` equals `0` and
|
|
342
|
-
`task_ins.task.consumer.anonymous` is set to `True`.
|
|
318
|
+
2. the `task_ins.task.delivered_at` equals `""`.
|
|
343
319
|
|
|
344
320
|
`delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
|
|
345
321
|
the result.
|
|
@@ -350,38 +326,23 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
350
326
|
if limit is not None and limit < 1:
|
|
351
327
|
raise AssertionError("`limit` must be >= 1")
|
|
352
328
|
|
|
353
|
-
if node_id ==
|
|
354
|
-
msg =
|
|
355
|
-
"`node_id` must be >= 1"
|
|
356
|
-
"\n\n For requesting anonymous tasks use `node_id` equal `None`"
|
|
357
|
-
)
|
|
329
|
+
if node_id == SUPERLINK_NODE_ID:
|
|
330
|
+
msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
|
|
358
331
|
raise AssertionError(msg)
|
|
359
332
|
|
|
360
333
|
data: dict[str, Union[str, int]] = {}
|
|
361
334
|
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
query = """
|
|
365
|
-
SELECT task_id
|
|
366
|
-
FROM task_ins
|
|
367
|
-
WHERE consumer_anonymous == 1
|
|
368
|
-
AND consumer_node_id == 0
|
|
369
|
-
AND delivered_at = ""
|
|
370
|
-
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
371
|
-
"""
|
|
372
|
-
else:
|
|
373
|
-
# Convert the uint64 value to sint64 for SQLite
|
|
374
|
-
data["node_id"] = convert_uint64_to_sint64(node_id)
|
|
335
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
336
|
+
data["node_id"] = convert_uint64_to_sint64(node_id)
|
|
375
337
|
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
"""
|
|
338
|
+
# Retrieve all TaskIns for node_id
|
|
339
|
+
query = """
|
|
340
|
+
SELECT task_id
|
|
341
|
+
FROM task_ins
|
|
342
|
+
WHERE consumer_node_id == :node_id
|
|
343
|
+
AND delivered_at = ""
|
|
344
|
+
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
345
|
+
"""
|
|
385
346
|
|
|
386
347
|
if limit is not None:
|
|
387
348
|
query += " LIMIT :limit"
|
|
@@ -431,11 +392,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
431
392
|
|
|
432
393
|
Constraints
|
|
433
394
|
-----------
|
|
434
|
-
|
|
435
|
-
`task_res.task.consumer.node_id` MUST NOT be set (equal 0).
|
|
436
|
-
|
|
437
|
-
If `task_res.task.consumer.anonymous` is `False`, then
|
|
438
|
-
`task_res.task.consumer.node_id` MUST be set (not 0)
|
|
395
|
+
`task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
|
439
396
|
"""
|
|
440
397
|
# Validate task
|
|
441
398
|
errors = validate_task_ins_or_res(task_res)
|
|
@@ -461,7 +418,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
461
418
|
if (
|
|
462
419
|
task_ins
|
|
463
420
|
and task_res
|
|
464
|
-
and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous)
|
|
465
421
|
and convert_sint64_to_uint64(task_ins["consumer_node_id"])
|
|
466
422
|
!= task_res.task.producer.node_id
|
|
467
423
|
):
|
|
@@ -569,9 +525,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
569
525
|
data: list[Any] = [delivered_at] + task_res_ids
|
|
570
526
|
self.query(query, data)
|
|
571
527
|
|
|
572
|
-
# Cleanup
|
|
573
|
-
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
574
|
-
|
|
575
528
|
return list(ret.values())
|
|
576
529
|
|
|
577
530
|
def num_task_ins(self) -> int:
|
|
@@ -595,86 +548,61 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
595
548
|
result: dict[str, int] = rows[0]
|
|
596
549
|
return result["num"]
|
|
597
550
|
|
|
598
|
-
def delete_tasks(self,
|
|
599
|
-
"""Delete
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
551
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
552
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
|
553
|
+
if not task_ins_ids:
|
|
554
|
+
return
|
|
555
|
+
if self.conn is None:
|
|
556
|
+
raise AttributeError("LinkState not initialized")
|
|
603
557
|
|
|
604
|
-
placeholders = ",".join([
|
|
605
|
-
data =
|
|
558
|
+
placeholders = ",".join(["?"] * len(task_ins_ids))
|
|
559
|
+
data = tuple(str(task_id) for task_id in task_ins_ids)
|
|
606
560
|
|
|
607
|
-
#
|
|
561
|
+
# Delete task_ins
|
|
608
562
|
query_1 = f"""
|
|
609
563
|
DELETE FROM task_ins
|
|
610
|
-
WHERE
|
|
611
|
-
AND task_id IN (
|
|
612
|
-
SELECT ancestry
|
|
613
|
-
FROM task_res
|
|
614
|
-
WHERE ancestry IN ({placeholders})
|
|
615
|
-
AND delivered_at != ''
|
|
616
|
-
);
|
|
564
|
+
WHERE task_id IN ({placeholders});
|
|
617
565
|
"""
|
|
618
566
|
|
|
619
|
-
#
|
|
567
|
+
# Delete task_res
|
|
620
568
|
query_2 = f"""
|
|
621
569
|
DELETE FROM task_res
|
|
622
|
-
WHERE ancestry IN ({placeholders})
|
|
623
|
-
AND delivered_at != '';
|
|
570
|
+
WHERE ancestry IN ({placeholders});
|
|
624
571
|
"""
|
|
625
572
|
|
|
626
|
-
if self.conn is None:
|
|
627
|
-
raise AttributeError("LinkState not intitialized")
|
|
628
|
-
|
|
629
573
|
with self.conn:
|
|
630
574
|
self.conn.execute(query_1, data)
|
|
631
575
|
self.conn.execute(query_2, data)
|
|
632
576
|
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
636
|
-
"""Delete tasks based on a set of TaskIns IDs."""
|
|
637
|
-
if not task_ids:
|
|
638
|
-
return
|
|
577
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
578
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
639
579
|
if self.conn is None:
|
|
640
580
|
raise AttributeError("LinkState not initialized")
|
|
641
581
|
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
query_1 = f"""
|
|
647
|
-
DELETE FROM task_ins
|
|
648
|
-
WHERE task_id IN ({placeholders});
|
|
582
|
+
query = """
|
|
583
|
+
SELECT task_id
|
|
584
|
+
FROM task_ins
|
|
585
|
+
WHERE run_id = :run_id;
|
|
649
586
|
"""
|
|
650
587
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
DELETE FROM task_res
|
|
654
|
-
WHERE ancestry IN ({placeholders});
|
|
655
|
-
"""
|
|
588
|
+
sint64_run_id = convert_uint64_to_sint64(run_id)
|
|
589
|
+
data = {"run_id": sint64_run_id}
|
|
656
590
|
|
|
657
591
|
with self.conn:
|
|
658
|
-
self.conn.execute(
|
|
659
|
-
self.conn.execute(query_2, data)
|
|
592
|
+
rows = self.conn.execute(query, data).fetchall()
|
|
660
593
|
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
) -> int:
|
|
594
|
+
return {UUID(row["task_id"]) for row in rows}
|
|
595
|
+
|
|
596
|
+
def create_node(self, ping_interval: float) -> int:
|
|
664
597
|
"""Create, store in the link state, and return `node_id`."""
|
|
665
598
|
# Sample a random uint64 as node_id
|
|
666
|
-
uint64_node_id = generate_rand_int_from_bytes(
|
|
599
|
+
uint64_node_id = generate_rand_int_from_bytes(
|
|
600
|
+
NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
|
|
601
|
+
)
|
|
667
602
|
|
|
668
603
|
# Convert the uint64 value to sint64 for SQLite
|
|
669
604
|
sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
|
|
670
605
|
|
|
671
|
-
query = "SELECT node_id FROM node WHERE public_key = :public_key;"
|
|
672
|
-
row = self.query(query, {"public_key": public_key})
|
|
673
|
-
|
|
674
|
-
if len(row) > 0:
|
|
675
|
-
log(ERROR, "Unexpected node registration failure.")
|
|
676
|
-
return 0
|
|
677
|
-
|
|
678
606
|
query = (
|
|
679
607
|
"INSERT INTO node "
|
|
680
608
|
"(node_id, online_until, ping_interval, public_key) "
|
|
@@ -688,7 +616,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
688
616
|
sint64_node_id,
|
|
689
617
|
time.time() + ping_interval,
|
|
690
618
|
ping_interval,
|
|
691
|
-
|
|
619
|
+
b"", # Initialize with an empty public key
|
|
692
620
|
),
|
|
693
621
|
)
|
|
694
622
|
except sqlite3.IntegrityError:
|
|
@@ -698,7 +626,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
698
626
|
# Note: we need to return the uint64 value of the node_id
|
|
699
627
|
return uint64_node_id
|
|
700
628
|
|
|
701
|
-
def delete_node(self, node_id: int
|
|
629
|
+
def delete_node(self, node_id: int) -> None:
|
|
702
630
|
"""Delete a node."""
|
|
703
631
|
# Convert the uint64 value to sint64 for SQLite
|
|
704
632
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
@@ -706,10 +634,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
706
634
|
query = "DELETE FROM node WHERE node_id = ?"
|
|
707
635
|
params = (sint64_node_id,)
|
|
708
636
|
|
|
709
|
-
if public_key is not None:
|
|
710
|
-
query += " AND public_key = ?"
|
|
711
|
-
params += (public_key,) # type: ignore
|
|
712
|
-
|
|
713
637
|
if self.conn is None:
|
|
714
638
|
raise AttributeError("LinkState is not initialized.")
|
|
715
639
|
|
|
@@ -717,7 +641,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
717
641
|
with self.conn:
|
|
718
642
|
rows = self.conn.execute(query, params)
|
|
719
643
|
if rows.rowcount < 1:
|
|
720
|
-
raise ValueError("
|
|
644
|
+
raise ValueError(f"Node {node_id} not found")
|
|
721
645
|
except KeyError as exc:
|
|
722
646
|
log(ERROR, {"query": query, "data": params, "exception": exc})
|
|
723
647
|
|
|
@@ -745,6 +669,41 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
745
669
|
result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
|
|
746
670
|
return result
|
|
747
671
|
|
|
672
|
+
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
|
673
|
+
"""Set `public_key` for the specified `node_id`."""
|
|
674
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
675
|
+
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
676
|
+
|
|
677
|
+
# Check if the node exists in the `node` table
|
|
678
|
+
query = "SELECT 1 FROM node WHERE node_id = ?"
|
|
679
|
+
if not self.query(query, (sint64_node_id,)):
|
|
680
|
+
raise ValueError(f"Node {node_id} not found")
|
|
681
|
+
|
|
682
|
+
# Check if the public key is already in use in the `node` table
|
|
683
|
+
query = "SELECT 1 FROM node WHERE public_key = ?"
|
|
684
|
+
if self.query(query, (public_key,)):
|
|
685
|
+
raise ValueError("Public key already in use")
|
|
686
|
+
|
|
687
|
+
# Update the `node` table to set the public key for the given node ID
|
|
688
|
+
query = "UPDATE node SET public_key = ? WHERE node_id = ?"
|
|
689
|
+
self.query(query, (public_key, sint64_node_id))
|
|
690
|
+
|
|
691
|
+
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
|
692
|
+
"""Get `public_key` for the specified `node_id`."""
|
|
693
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
694
|
+
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
695
|
+
|
|
696
|
+
# Query the public key for the given node_id
|
|
697
|
+
query = "SELECT public_key FROM node WHERE node_id = ?"
|
|
698
|
+
rows = self.query(query, (sint64_node_id,))
|
|
699
|
+
|
|
700
|
+
# If no result is found, return None
|
|
701
|
+
if not rows:
|
|
702
|
+
raise ValueError(f"Node {node_id} not found")
|
|
703
|
+
|
|
704
|
+
# Return the public key if it is not empty, otherwise return None
|
|
705
|
+
return rows[0]["public_key"] or None
|
|
706
|
+
|
|
748
707
|
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
|
749
708
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
|
750
709
|
query = "SELECT node_id FROM node WHERE public_key = :public_key;"
|
|
@@ -784,8 +743,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
784
743
|
"federation_options, pending_at, starting_at, running_at, finished_at, "
|
|
785
744
|
"sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
|
|
786
745
|
)
|
|
787
|
-
if fab_hash:
|
|
788
|
-
fab_id, fab_version = "", ""
|
|
789
746
|
override_config_json = json.dumps(override_config)
|
|
790
747
|
data = [
|
|
791
748
|
sint64_run_id,
|
|
@@ -808,40 +765,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
808
765
|
log(ERROR, "Unexpected run creation failure.")
|
|
809
766
|
return 0
|
|
810
767
|
|
|
811
|
-
def
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
"""Store `server_private_key` and `server_public_key` in the link state."""
|
|
815
|
-
query = "SELECT COUNT(*) FROM credential"
|
|
816
|
-
count = self.query(query)[0]["COUNT(*)"]
|
|
817
|
-
if count < 1:
|
|
818
|
-
query = (
|
|
819
|
-
"INSERT OR REPLACE INTO credential (private_key, public_key) "
|
|
820
|
-
"VALUES (:private_key, :public_key)"
|
|
821
|
-
)
|
|
822
|
-
self.query(query, {"private_key": private_key, "public_key": public_key})
|
|
823
|
-
else:
|
|
824
|
-
raise RuntimeError("Server private and public key already set")
|
|
825
|
-
|
|
826
|
-
def get_server_private_key(self) -> Optional[bytes]:
|
|
827
|
-
"""Retrieve `server_private_key` in urlsafe bytes."""
|
|
828
|
-
query = "SELECT private_key FROM credential"
|
|
829
|
-
rows = self.query(query)
|
|
830
|
-
try:
|
|
831
|
-
private_key: Optional[bytes] = rows[0]["private_key"]
|
|
832
|
-
except IndexError:
|
|
833
|
-
private_key = None
|
|
834
|
-
return private_key
|
|
835
|
-
|
|
836
|
-
def get_server_public_key(self) -> Optional[bytes]:
|
|
837
|
-
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
838
|
-
query = "SELECT public_key FROM credential"
|
|
839
|
-
rows = self.query(query)
|
|
840
|
-
try:
|
|
841
|
-
public_key: Optional[bytes] = rows[0]["public_key"]
|
|
842
|
-
except IndexError:
|
|
843
|
-
public_key = None
|
|
844
|
-
return public_key
|
|
768
|
+
def clear_supernode_auth_keys(self) -> None:
|
|
769
|
+
"""Clear stored `node_public_keys` in the link state if any."""
|
|
770
|
+
self.query("DELETE FROM public_key;")
|
|
845
771
|
|
|
846
772
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
847
773
|
"""Store a set of `node_public_keys` in the link state."""
|
|
@@ -1001,17 +927,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
1001
927
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
1002
928
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
1003
929
|
|
|
1004
|
-
#
|
|
1005
|
-
query = "
|
|
1006
|
-
|
|
1007
|
-
self.query(
|
|
1008
|
-
query, (time.time() + ping_interval, ping_interval, sint64_node_id)
|
|
1009
|
-
)
|
|
1010
|
-
return True
|
|
1011
|
-
except sqlite3.IntegrityError:
|
|
1012
|
-
log(ERROR, "`node_id` does not exist.")
|
|
930
|
+
# Check if the node exists in the `node` table
|
|
931
|
+
query = "SELECT 1 FROM node WHERE node_id = ?"
|
|
932
|
+
if not self.query(query, (sint64_node_id,)):
|
|
1013
933
|
return False
|
|
1014
934
|
|
|
935
|
+
# Update `online_until` and `ping_interval` for the given `node_id`
|
|
936
|
+
query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
|
|
937
|
+
self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
|
|
938
|
+
return True
|
|
939
|
+
|
|
1015
940
|
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
|
1016
941
|
"""Get the context for the specified `run_id`."""
|
|
1017
942
|
# Retrieve context if any
|
|
@@ -1124,9 +1049,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
|
|
|
1124
1049
|
"task_id": task_msg.task_id,
|
|
1125
1050
|
"group_id": task_msg.group_id,
|
|
1126
1051
|
"run_id": task_msg.run_id,
|
|
1127
|
-
"producer_anonymous": task_msg.task.producer.anonymous,
|
|
1128
1052
|
"producer_node_id": task_msg.task.producer.node_id,
|
|
1129
|
-
"consumer_anonymous": task_msg.task.consumer.anonymous,
|
|
1130
1053
|
"consumer_node_id": task_msg.task.consumer.node_id,
|
|
1131
1054
|
"created_at": task_msg.task.created_at,
|
|
1132
1055
|
"delivered_at": task_msg.task.delivered_at,
|
|
@@ -1145,9 +1068,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
|
|
|
1145
1068
|
"task_id": task_msg.task_id,
|
|
1146
1069
|
"group_id": task_msg.group_id,
|
|
1147
1070
|
"run_id": task_msg.run_id,
|
|
1148
|
-
"producer_anonymous": task_msg.task.producer.anonymous,
|
|
1149
1071
|
"producer_node_id": task_msg.task.producer.node_id,
|
|
1150
|
-
"consumer_anonymous": task_msg.task.consumer.anonymous,
|
|
1151
1072
|
"consumer_node_id": task_msg.task.consumer.node_id,
|
|
1152
1073
|
"created_at": task_msg.task.created_at,
|
|
1153
1074
|
"delivered_at": task_msg.task.delivered_at,
|
|
@@ -1172,11 +1093,9 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
|
1172
1093
|
task=Task(
|
|
1173
1094
|
producer=Node(
|
|
1174
1095
|
node_id=task_dict["producer_node_id"],
|
|
1175
|
-
anonymous=task_dict["producer_anonymous"],
|
|
1176
1096
|
),
|
|
1177
1097
|
consumer=Node(
|
|
1178
1098
|
node_id=task_dict["consumer_node_id"],
|
|
1179
|
-
anonymous=task_dict["consumer_anonymous"],
|
|
1180
1099
|
),
|
|
1181
1100
|
created_at=task_dict["created_at"],
|
|
1182
1101
|
delivered_at=task_dict["delivered_at"],
|
|
@@ -1202,11 +1121,9 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
|
1202
1121
|
task=Task(
|
|
1203
1122
|
producer=Node(
|
|
1204
1123
|
node_id=task_dict["producer_node_id"],
|
|
1205
|
-
anonymous=task_dict["producer_anonymous"],
|
|
1206
1124
|
),
|
|
1207
1125
|
consumer=Node(
|
|
1208
1126
|
node_id=task_dict["consumer_node_id"],
|
|
1209
|
-
anonymous=task_dict["consumer_anonymous"],
|
|
1210
1127
|
),
|
|
1211
1128
|
created_at=task_dict["created_at"],
|
|
1212
1129
|
delivered_at=task_dict["delivered_at"],
|
|
@@ -21,7 +21,7 @@ from typing import Optional, Union
|
|
|
21
21
|
from uuid import UUID, uuid4
|
|
22
22
|
|
|
23
23
|
from flwr.common import ConfigsRecord, Context, log, now, serde
|
|
24
|
-
from flwr.common.constant import ErrorCode, Status, SubStatus
|
|
24
|
+
from flwr.common.constant import SUPERLINK_NODE_ID, ErrorCode, Status, SubStatus
|
|
25
25
|
from flwr.common.typing import RunStatus
|
|
26
26
|
|
|
27
27
|
# pylint: disable=E0611
|
|
@@ -60,9 +60,19 @@ REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
|
60
60
|
)
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def generate_rand_int_from_bytes(
|
|
64
|
-
|
|
65
|
-
|
|
63
|
+
def generate_rand_int_from_bytes(
|
|
64
|
+
num_bytes: int, exclude: Optional[list[int]] = None
|
|
65
|
+
) -> int:
|
|
66
|
+
"""Generate a random unsigned integer from `num_bytes` bytes.
|
|
67
|
+
|
|
68
|
+
If `exclude` is set, this function guarantees such number is not returned.
|
|
69
|
+
"""
|
|
70
|
+
num = int.from_bytes(urandom(num_bytes), "little", signed=False)
|
|
71
|
+
|
|
72
|
+
if exclude:
|
|
73
|
+
while num in exclude:
|
|
74
|
+
num = int.from_bytes(urandom(num_bytes), "little", signed=False)
|
|
75
|
+
return num
|
|
66
76
|
|
|
67
77
|
|
|
68
78
|
def convert_uint64_to_sint64(u: int) -> int:
|
|
@@ -246,8 +256,8 @@ def create_taskres_for_unavailable_taskins(taskins_id: Union[str, UUID]) -> Task
|
|
|
246
256
|
run_id=0, # Unknown run ID
|
|
247
257
|
task=Task(
|
|
248
258
|
# This function is only called by SuperLink, and thus it's the producer.
|
|
249
|
-
producer=Node(node_id=
|
|
250
|
-
consumer=Node(node_id=
|
|
259
|
+
producer=Node(node_id=SUPERLINK_NODE_ID),
|
|
260
|
+
consumer=Node(node_id=SUPERLINK_NODE_ID),
|
|
251
261
|
created_at=current_time,
|
|
252
262
|
ttl=0,
|
|
253
263
|
ancestry=[str(taskins_id)],
|
|
@@ -285,8 +295,8 @@ def create_taskres_for_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
|
|
|
285
295
|
run_id=ref_taskins.run_id,
|
|
286
296
|
task=Task(
|
|
287
297
|
# This function is only called by SuperLink, and thus it's the producer.
|
|
288
|
-
producer=Node(node_id=
|
|
289
|
-
consumer=Node(node_id=
|
|
298
|
+
producer=Node(node_id=SUPERLINK_NODE_ID),
|
|
299
|
+
consumer=Node(node_id=SUPERLINK_NODE_ID),
|
|
290
300
|
created_at=current_time,
|
|
291
301
|
ttl=ttl,
|
|
292
302
|
ancestry=[ref_taskins.task_id],
|
|
@@ -21,6 +21,7 @@ from typing import Optional
|
|
|
21
21
|
import grpc
|
|
22
22
|
|
|
23
23
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
24
|
+
from flwr.common.grpc import generic_create_grpc_server
|
|
24
25
|
from flwr.common.logger import log
|
|
25
26
|
from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611
|
|
26
27
|
add_SimulationIoServicer_to_server,
|
|
@@ -28,7 +29,6 @@ from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611
|
|
|
28
29
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
29
30
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
30
31
|
|
|
31
|
-
from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
32
32
|
from .simulationio_servicer import SimulationIoServicer
|
|
33
33
|
|
|
34
34
|
|