flwr 1.14.0__py3-none-any.whl → 1.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/auth_plugin/__init__.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
- flwr/cli/cli_user_auth_interceptor.py +6 -2
- flwr/cli/config_utils.py +24 -147
- flwr/cli/constant.py +27 -0
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +18 -3
- flwr/cli/login/login.py +43 -8
- flwr/cli/ls.py +14 -5
- flwr/cli/new/templates/app/README.md.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/run/run.py +21 -11
- flwr/cli/stop.py +13 -4
- flwr/cli/utils.py +54 -40
- flwr/client/app.py +36 -48
- flwr/client/clientapp/app.py +19 -25
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/grpc_client/connection.py +1 -12
- flwr/client/grpc_rere_client/client_interceptor.py +19 -119
- flwr/client/grpc_rere_client/connection.py +46 -36
- flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
- flwr/client/message_handler/task_handler.py +0 -17
- flwr/client/rest_client/connection.py +34 -26
- flwr/client/supernode/app.py +18 -72
- flwr/common/args.py +25 -47
- flwr/common/auth_plugin/auth_plugin.py +34 -23
- flwr/common/config.py +166 -16
- flwr/common/constant.py +24 -9
- flwr/common/differential_privacy.py +2 -1
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +32 -30
- flwr/common/grpc.py +167 -4
- flwr/common/logger.py +26 -7
- flwr/common/object_ref.py +0 -14
- flwr/common/record/recordset.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/serde.py +6 -4
- flwr/common/typing.py +20 -0
- flwr/proto/clientappio_pb2.py +1 -1
- flwr/proto/error_pb2.py +1 -1
- flwr/proto/exec_pb2.py +13 -25
- flwr/proto/exec_pb2.pyi +27 -54
- flwr/proto/fab_pb2.py +1 -1
- flwr/proto/fleet_pb2.py +31 -31
- flwr/proto/fleet_pb2.pyi +23 -23
- flwr/proto/fleet_pb2_grpc.py +30 -30
- flwr/proto/fleet_pb2_grpc.pyi +20 -20
- flwr/proto/grpcadapter_pb2.py +1 -1
- flwr/proto/log_pb2.py +1 -1
- flwr/proto/message_pb2.py +1 -1
- flwr/proto/node_pb2.py +3 -3
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/recordset_pb2.py +1 -1
- flwr/proto/run_pb2.py +1 -1
- flwr/proto/serverappio_pb2.py +24 -25
- flwr/proto/serverappio_pb2.pyi +26 -32
- flwr/proto/serverappio_pb2_grpc.py +28 -28
- flwr/proto/serverappio_pb2_grpc.pyi +16 -16
- flwr/proto/simulationio_pb2.py +1 -1
- flwr/proto/task_pb2.py +1 -1
- flwr/proto/transport_pb2.py +1 -1
- flwr/server/app.py +116 -128
- flwr/server/compat/app_utils.py +0 -1
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +32 -27
- flwr/server/driver/inmemory_driver.py +2 -1
- flwr/server/serverapp/app.py +12 -10
- flwr/server/superlink/driver/serverappio_grpc.py +1 -1
- flwr/server/superlink/driver/serverappio_servicer.py +74 -48
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -24
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +110 -168
- flwr/server/superlink/fleet/message_handler/message_handler.py +37 -24
- flwr/server/superlink/fleet/rest_rere/rest_api.py +16 -18
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +45 -75
- flwr/server/superlink/linkstate/linkstate.py +17 -38
- flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -145
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +4 -6
- flwr/simulation/legacy_app.py +4 -2
- flwr/simulation/run_simulation.py +1 -1
- flwr/simulation/simulationio_connection.py +2 -1
- flwr/superexec/exec_grpc.py +1 -1
- flwr/superexec/exec_servicer.py +23 -2
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/METADATA +8 -8
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/RECORD +103 -97
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/LICENSE +0 -0
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/WHEEL +0 -0
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/entry_points.txt +0 -0
|
@@ -31,6 +31,7 @@ from flwr.common.constant import (
|
|
|
31
31
|
MESSAGE_TTL_TOLERANCE,
|
|
32
32
|
NODE_ID_NUM_BYTES,
|
|
33
33
|
RUN_ID_NUM_BYTES,
|
|
34
|
+
SUPERLINK_NODE_ID,
|
|
34
35
|
Status,
|
|
35
36
|
)
|
|
36
37
|
from flwr.common.record import ConfigsRecord
|
|
@@ -70,16 +71,9 @@ CREATE TABLE IF NOT EXISTS node(
|
|
|
70
71
|
);
|
|
71
72
|
"""
|
|
72
73
|
|
|
73
|
-
SQL_CREATE_TABLE_CREDENTIAL = """
|
|
74
|
-
CREATE TABLE IF NOT EXISTS credential(
|
|
75
|
-
private_key BLOB PRIMARY KEY,
|
|
76
|
-
public_key BLOB
|
|
77
|
-
);
|
|
78
|
-
"""
|
|
79
|
-
|
|
80
74
|
SQL_CREATE_TABLE_PUBLIC_KEY = """
|
|
81
75
|
CREATE TABLE IF NOT EXISTS public_key(
|
|
82
|
-
public_key
|
|
76
|
+
public_key BLOB PRIMARY KEY
|
|
83
77
|
);
|
|
84
78
|
"""
|
|
85
79
|
|
|
@@ -128,9 +122,7 @@ CREATE TABLE IF NOT EXISTS task_ins(
|
|
|
128
122
|
task_id TEXT UNIQUE,
|
|
129
123
|
group_id TEXT,
|
|
130
124
|
run_id INTEGER,
|
|
131
|
-
producer_anonymous BOOLEAN,
|
|
132
125
|
producer_node_id INTEGER,
|
|
133
|
-
consumer_anonymous BOOLEAN,
|
|
134
126
|
consumer_node_id INTEGER,
|
|
135
127
|
created_at REAL,
|
|
136
128
|
delivered_at TEXT,
|
|
@@ -148,9 +140,7 @@ CREATE TABLE IF NOT EXISTS task_res(
|
|
|
148
140
|
task_id TEXT UNIQUE,
|
|
149
141
|
group_id TEXT,
|
|
150
142
|
run_id INTEGER,
|
|
151
|
-
producer_anonymous BOOLEAN,
|
|
152
143
|
producer_node_id INTEGER,
|
|
153
|
-
consumer_anonymous BOOLEAN,
|
|
154
144
|
consumer_node_id INTEGER,
|
|
155
145
|
created_at REAL,
|
|
156
146
|
delivered_at TEXT,
|
|
@@ -211,7 +201,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
211
201
|
cur.execute(SQL_CREATE_TABLE_TASK_INS)
|
|
212
202
|
cur.execute(SQL_CREATE_TABLE_TASK_RES)
|
|
213
203
|
cur.execute(SQL_CREATE_TABLE_NODE)
|
|
214
|
-
cur.execute(SQL_CREATE_TABLE_CREDENTIAL)
|
|
215
204
|
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
|
216
205
|
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
|
217
206
|
res = cur.execute("SELECT name FROM sqlite_schema;")
|
|
@@ -263,11 +252,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
263
252
|
|
|
264
253
|
Constraints
|
|
265
254
|
-----------
|
|
266
|
-
If `task_ins.task.consumer.anonymous` is `True`, then
|
|
267
|
-
`task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
|
|
268
255
|
|
|
269
|
-
|
|
270
|
-
`task_ins.task.consumer.node_id` MUST be set (not 0)
|
|
256
|
+
`task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
|
271
257
|
"""
|
|
272
258
|
# Validate task
|
|
273
259
|
errors = validate_task_ins_or_res(task_ins)
|
|
@@ -292,7 +278,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
292
278
|
log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
|
|
293
279
|
return None
|
|
294
280
|
# Validate source node ID
|
|
295
|
-
if task_ins.task.producer.node_id !=
|
|
281
|
+
if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
|
|
296
282
|
log(
|
|
297
283
|
ERROR,
|
|
298
284
|
"Invalid source node ID for TaskIns: %s",
|
|
@@ -301,14 +287,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
301
287
|
return None
|
|
302
288
|
# Validate destination node ID
|
|
303
289
|
query = "SELECT node_id FROM node WHERE node_id = ?;"
|
|
304
|
-
if not
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
return None
|
|
290
|
+
if not self.query(query, (data[0]["consumer_node_id"],)):
|
|
291
|
+
log(
|
|
292
|
+
ERROR,
|
|
293
|
+
"Invalid destination node ID for TaskIns: %s",
|
|
294
|
+
task_ins.task.consumer.node_id,
|
|
295
|
+
)
|
|
296
|
+
return None
|
|
312
297
|
|
|
313
298
|
columns = ", ".join([f":{key}" for key in data[0]])
|
|
314
299
|
query = f"INSERT INTO task_ins VALUES({columns});"
|
|
@@ -319,25 +304,18 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
319
304
|
|
|
320
305
|
return task_id
|
|
321
306
|
|
|
322
|
-
def get_task_ins(
|
|
323
|
-
|
|
324
|
-
) -> list[TaskIns]:
|
|
325
|
-
"""Get undelivered TaskIns for one node (either anonymous or with ID).
|
|
307
|
+
def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
|
|
308
|
+
"""Get undelivered TaskIns for one node.
|
|
326
309
|
|
|
327
310
|
Usually, the Fleet API calls this for Nodes planning to work on one or more
|
|
328
311
|
TaskIns.
|
|
329
312
|
|
|
330
313
|
Constraints
|
|
331
314
|
-----------
|
|
332
|
-
|
|
315
|
+
Retrieve all TaskIns where
|
|
333
316
|
|
|
334
317
|
1. the `task_ins.task.consumer.node_id` equals `node_id` AND
|
|
335
|
-
2. the `task_ins.task.
|
|
336
|
-
3. the `task_ins.task.delivered_at` equals `""`.
|
|
337
|
-
|
|
338
|
-
If `node_id` is `None`, retrieve all TaskIns where the
|
|
339
|
-
`task_ins.task.consumer.node_id` equals `0` and
|
|
340
|
-
`task_ins.task.consumer.anonymous` is set to `True`.
|
|
318
|
+
2. the `task_ins.task.delivered_at` equals `""`.
|
|
341
319
|
|
|
342
320
|
`delivered_at` MUST BE set (i.e., not `""`) otherwise the TaskIns MUST not be in
|
|
343
321
|
the result.
|
|
@@ -348,38 +326,23 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
348
326
|
if limit is not None and limit < 1:
|
|
349
327
|
raise AssertionError("`limit` must be >= 1")
|
|
350
328
|
|
|
351
|
-
if node_id ==
|
|
352
|
-
msg =
|
|
353
|
-
"`node_id` must be >= 1"
|
|
354
|
-
"\n\n For requesting anonymous tasks use `node_id` equal `None`"
|
|
355
|
-
)
|
|
329
|
+
if node_id == SUPERLINK_NODE_ID:
|
|
330
|
+
msg = f"`node_id` must be != {SUPERLINK_NODE_ID}"
|
|
356
331
|
raise AssertionError(msg)
|
|
357
332
|
|
|
358
333
|
data: dict[str, Union[str, int]] = {}
|
|
359
334
|
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
query = """
|
|
363
|
-
SELECT task_id
|
|
364
|
-
FROM task_ins
|
|
365
|
-
WHERE consumer_anonymous == 1
|
|
366
|
-
AND consumer_node_id == 0
|
|
367
|
-
AND delivered_at = ""
|
|
368
|
-
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
369
|
-
"""
|
|
370
|
-
else:
|
|
371
|
-
# Convert the uint64 value to sint64 for SQLite
|
|
372
|
-
data["node_id"] = convert_uint64_to_sint64(node_id)
|
|
335
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
336
|
+
data["node_id"] = convert_uint64_to_sint64(node_id)
|
|
373
337
|
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
"""
|
|
338
|
+
# Retrieve all TaskIns for node_id
|
|
339
|
+
query = """
|
|
340
|
+
SELECT task_id
|
|
341
|
+
FROM task_ins
|
|
342
|
+
WHERE consumer_node_id == :node_id
|
|
343
|
+
AND delivered_at = ""
|
|
344
|
+
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
345
|
+
"""
|
|
383
346
|
|
|
384
347
|
if limit is not None:
|
|
385
348
|
query += " LIMIT :limit"
|
|
@@ -429,11 +392,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
429
392
|
|
|
430
393
|
Constraints
|
|
431
394
|
-----------
|
|
432
|
-
|
|
433
|
-
`task_res.task.consumer.node_id` MUST NOT be set (equal 0).
|
|
434
|
-
|
|
435
|
-
If `task_res.task.consumer.anonymous` is `False`, then
|
|
436
|
-
`task_res.task.consumer.node_id` MUST be set (not 0)
|
|
395
|
+
`task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
|
|
437
396
|
"""
|
|
438
397
|
# Validate task
|
|
439
398
|
errors = validate_task_ins_or_res(task_res)
|
|
@@ -459,7 +418,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
459
418
|
if (
|
|
460
419
|
task_ins
|
|
461
420
|
and task_res
|
|
462
|
-
and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous)
|
|
463
421
|
and convert_sint64_to_uint64(task_ins["consumer_node_id"])
|
|
464
422
|
!= task_res.task.producer.node_id
|
|
465
423
|
):
|
|
@@ -635,23 +593,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
635
593
|
|
|
636
594
|
return {UUID(row["task_id"]) for row in rows}
|
|
637
595
|
|
|
638
|
-
def create_node(
|
|
639
|
-
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
640
|
-
) -> int:
|
|
596
|
+
def create_node(self, ping_interval: float) -> int:
|
|
641
597
|
"""Create, store in the link state, and return `node_id`."""
|
|
642
598
|
# Sample a random uint64 as node_id
|
|
643
|
-
uint64_node_id = generate_rand_int_from_bytes(
|
|
599
|
+
uint64_node_id = generate_rand_int_from_bytes(
|
|
600
|
+
NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
|
|
601
|
+
)
|
|
644
602
|
|
|
645
603
|
# Convert the uint64 value to sint64 for SQLite
|
|
646
604
|
sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
|
|
647
605
|
|
|
648
|
-
query = "SELECT node_id FROM node WHERE public_key = :public_key;"
|
|
649
|
-
row = self.query(query, {"public_key": public_key})
|
|
650
|
-
|
|
651
|
-
if len(row) > 0:
|
|
652
|
-
log(ERROR, "Unexpected node registration failure.")
|
|
653
|
-
return 0
|
|
654
|
-
|
|
655
606
|
query = (
|
|
656
607
|
"INSERT INTO node "
|
|
657
608
|
"(node_id, online_until, ping_interval, public_key) "
|
|
@@ -665,7 +616,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
665
616
|
sint64_node_id,
|
|
666
617
|
time.time() + ping_interval,
|
|
667
618
|
ping_interval,
|
|
668
|
-
|
|
619
|
+
b"", # Initialize with an empty public key
|
|
669
620
|
),
|
|
670
621
|
)
|
|
671
622
|
except sqlite3.IntegrityError:
|
|
@@ -675,7 +626,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
675
626
|
# Note: we need to return the uint64 value of the node_id
|
|
676
627
|
return uint64_node_id
|
|
677
628
|
|
|
678
|
-
def delete_node(self, node_id: int
|
|
629
|
+
def delete_node(self, node_id: int) -> None:
|
|
679
630
|
"""Delete a node."""
|
|
680
631
|
# Convert the uint64 value to sint64 for SQLite
|
|
681
632
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
@@ -683,10 +634,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
683
634
|
query = "DELETE FROM node WHERE node_id = ?"
|
|
684
635
|
params = (sint64_node_id,)
|
|
685
636
|
|
|
686
|
-
if public_key is not None:
|
|
687
|
-
query += " AND public_key = ?"
|
|
688
|
-
params += (public_key,) # type: ignore
|
|
689
|
-
|
|
690
637
|
if self.conn is None:
|
|
691
638
|
raise AttributeError("LinkState is not initialized.")
|
|
692
639
|
|
|
@@ -694,7 +641,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
694
641
|
with self.conn:
|
|
695
642
|
rows = self.conn.execute(query, params)
|
|
696
643
|
if rows.rowcount < 1:
|
|
697
|
-
raise ValueError("
|
|
644
|
+
raise ValueError(f"Node {node_id} not found")
|
|
698
645
|
except KeyError as exc:
|
|
699
646
|
log(ERROR, {"query": query, "data": params, "exception": exc})
|
|
700
647
|
|
|
@@ -722,6 +669,41 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
722
669
|
result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
|
|
723
670
|
return result
|
|
724
671
|
|
|
672
|
+
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
|
673
|
+
"""Set `public_key` for the specified `node_id`."""
|
|
674
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
675
|
+
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
676
|
+
|
|
677
|
+
# Check if the node exists in the `node` table
|
|
678
|
+
query = "SELECT 1 FROM node WHERE node_id = ?"
|
|
679
|
+
if not self.query(query, (sint64_node_id,)):
|
|
680
|
+
raise ValueError(f"Node {node_id} not found")
|
|
681
|
+
|
|
682
|
+
# Check if the public key is already in use in the `node` table
|
|
683
|
+
query = "SELECT 1 FROM node WHERE public_key = ?"
|
|
684
|
+
if self.query(query, (public_key,)):
|
|
685
|
+
raise ValueError("Public key already in use")
|
|
686
|
+
|
|
687
|
+
# Update the `node` table to set the public key for the given node ID
|
|
688
|
+
query = "UPDATE node SET public_key = ? WHERE node_id = ?"
|
|
689
|
+
self.query(query, (public_key, sint64_node_id))
|
|
690
|
+
|
|
691
|
+
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
|
692
|
+
"""Get `public_key` for the specified `node_id`."""
|
|
693
|
+
# Convert the uint64 value to sint64 for SQLite
|
|
694
|
+
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
695
|
+
|
|
696
|
+
# Query the public key for the given node_id
|
|
697
|
+
query = "SELECT public_key FROM node WHERE node_id = ?"
|
|
698
|
+
rows = self.query(query, (sint64_node_id,))
|
|
699
|
+
|
|
700
|
+
# If no result is found, return None
|
|
701
|
+
if not rows:
|
|
702
|
+
raise ValueError(f"Node {node_id} not found")
|
|
703
|
+
|
|
704
|
+
# Return the public key if it is not empty, otherwise return None
|
|
705
|
+
return rows[0]["public_key"] or None
|
|
706
|
+
|
|
725
707
|
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
|
726
708
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
|
727
709
|
query = "SELECT node_id FROM node WHERE public_key = :public_key;"
|
|
@@ -783,46 +765,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
783
765
|
log(ERROR, "Unexpected run creation failure.")
|
|
784
766
|
return 0
|
|
785
767
|
|
|
786
|
-
def
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
"""Store `server_private_key` and `server_public_key` in the link state."""
|
|
790
|
-
query = "SELECT COUNT(*) FROM credential"
|
|
791
|
-
count = self.query(query)[0]["COUNT(*)"]
|
|
792
|
-
if count < 1:
|
|
793
|
-
query = (
|
|
794
|
-
"INSERT OR REPLACE INTO credential (private_key, public_key) "
|
|
795
|
-
"VALUES (:private_key, :public_key)"
|
|
796
|
-
)
|
|
797
|
-
self.query(query, {"private_key": private_key, "public_key": public_key})
|
|
798
|
-
else:
|
|
799
|
-
raise RuntimeError("Server private and public key already set")
|
|
800
|
-
|
|
801
|
-
def get_server_private_key(self) -> Optional[bytes]:
|
|
802
|
-
"""Retrieve `server_private_key` in urlsafe bytes."""
|
|
803
|
-
query = "SELECT private_key FROM credential"
|
|
804
|
-
rows = self.query(query)
|
|
805
|
-
try:
|
|
806
|
-
private_key: Optional[bytes] = rows[0]["private_key"]
|
|
807
|
-
except IndexError:
|
|
808
|
-
private_key = None
|
|
809
|
-
return private_key
|
|
810
|
-
|
|
811
|
-
def get_server_public_key(self) -> Optional[bytes]:
|
|
812
|
-
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
813
|
-
query = "SELECT public_key FROM credential"
|
|
814
|
-
rows = self.query(query)
|
|
815
|
-
try:
|
|
816
|
-
public_key: Optional[bytes] = rows[0]["public_key"]
|
|
817
|
-
except IndexError:
|
|
818
|
-
public_key = None
|
|
819
|
-
return public_key
|
|
820
|
-
|
|
821
|
-
def clear_supernode_auth_keys_and_credentials(self) -> None:
|
|
822
|
-
"""Clear stored `node_public_keys` and credentials in the link state if any."""
|
|
823
|
-
queries = ["DELETE FROM public_key;", "DELETE FROM credential;"]
|
|
824
|
-
for query in queries:
|
|
825
|
-
self.query(query)
|
|
768
|
+
def clear_supernode_auth_keys(self) -> None:
|
|
769
|
+
"""Clear stored `node_public_keys` in the link state if any."""
|
|
770
|
+
self.query("DELETE FROM public_key;")
|
|
826
771
|
|
|
827
772
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
|
828
773
|
"""Store a set of `node_public_keys` in the link state."""
|
|
@@ -982,17 +927,16 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
982
927
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
983
928
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
|
984
929
|
|
|
985
|
-
#
|
|
986
|
-
query = "
|
|
987
|
-
|
|
988
|
-
self.query(
|
|
989
|
-
query, (time.time() + ping_interval, ping_interval, sint64_node_id)
|
|
990
|
-
)
|
|
991
|
-
return True
|
|
992
|
-
except sqlite3.IntegrityError:
|
|
993
|
-
log(ERROR, "`node_id` does not exist.")
|
|
930
|
+
# Check if the node exists in the `node` table
|
|
931
|
+
query = "SELECT 1 FROM node WHERE node_id = ?"
|
|
932
|
+
if not self.query(query, (sint64_node_id,)):
|
|
994
933
|
return False
|
|
995
934
|
|
|
935
|
+
# Update `online_until` and `ping_interval` for the given `node_id`
|
|
936
|
+
query = "UPDATE node SET online_until = ?, ping_interval = ? WHERE node_id = ?"
|
|
937
|
+
self.query(query, (time.time() + ping_interval, ping_interval, sint64_node_id))
|
|
938
|
+
return True
|
|
939
|
+
|
|
996
940
|
def get_serverapp_context(self, run_id: int) -> Optional[Context]:
|
|
997
941
|
"""Get the context for the specified `run_id`."""
|
|
998
942
|
# Retrieve context if any
|
|
@@ -1105,9 +1049,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]:
|
|
|
1105
1049
|
"task_id": task_msg.task_id,
|
|
1106
1050
|
"group_id": task_msg.group_id,
|
|
1107
1051
|
"run_id": task_msg.run_id,
|
|
1108
|
-
"producer_anonymous": task_msg.task.producer.anonymous,
|
|
1109
1052
|
"producer_node_id": task_msg.task.producer.node_id,
|
|
1110
|
-
"consumer_anonymous": task_msg.task.consumer.anonymous,
|
|
1111
1053
|
"consumer_node_id": task_msg.task.consumer.node_id,
|
|
1112
1054
|
"created_at": task_msg.task.created_at,
|
|
1113
1055
|
"delivered_at": task_msg.task.delivered_at,
|
|
@@ -1126,9 +1068,7 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]:
|
|
|
1126
1068
|
"task_id": task_msg.task_id,
|
|
1127
1069
|
"group_id": task_msg.group_id,
|
|
1128
1070
|
"run_id": task_msg.run_id,
|
|
1129
|
-
"producer_anonymous": task_msg.task.producer.anonymous,
|
|
1130
1071
|
"producer_node_id": task_msg.task.producer.node_id,
|
|
1131
|
-
"consumer_anonymous": task_msg.task.consumer.anonymous,
|
|
1132
1072
|
"consumer_node_id": task_msg.task.consumer.node_id,
|
|
1133
1073
|
"created_at": task_msg.task.created_at,
|
|
1134
1074
|
"delivered_at": task_msg.task.delivered_at,
|
|
@@ -1153,11 +1093,9 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns:
|
|
|
1153
1093
|
task=Task(
|
|
1154
1094
|
producer=Node(
|
|
1155
1095
|
node_id=task_dict["producer_node_id"],
|
|
1156
|
-
anonymous=task_dict["producer_anonymous"],
|
|
1157
1096
|
),
|
|
1158
1097
|
consumer=Node(
|
|
1159
1098
|
node_id=task_dict["consumer_node_id"],
|
|
1160
|
-
anonymous=task_dict["consumer_anonymous"],
|
|
1161
1099
|
),
|
|
1162
1100
|
created_at=task_dict["created_at"],
|
|
1163
1101
|
delivered_at=task_dict["delivered_at"],
|
|
@@ -1183,11 +1121,9 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
|
1183
1121
|
task=Task(
|
|
1184
1122
|
producer=Node(
|
|
1185
1123
|
node_id=task_dict["producer_node_id"],
|
|
1186
|
-
anonymous=task_dict["producer_anonymous"],
|
|
1187
1124
|
),
|
|
1188
1125
|
consumer=Node(
|
|
1189
1126
|
node_id=task_dict["consumer_node_id"],
|
|
1190
|
-
anonymous=task_dict["consumer_anonymous"],
|
|
1191
1127
|
),
|
|
1192
1128
|
created_at=task_dict["created_at"],
|
|
1193
1129
|
delivered_at=task_dict["delivered_at"],
|
|
@@ -21,7 +21,7 @@ from typing import Optional, Union
|
|
|
21
21
|
from uuid import UUID, uuid4
|
|
22
22
|
|
|
23
23
|
from flwr.common import ConfigsRecord, Context, log, now, serde
|
|
24
|
-
from flwr.common.constant import ErrorCode, Status, SubStatus
|
|
24
|
+
from flwr.common.constant import SUPERLINK_NODE_ID, ErrorCode, Status, SubStatus
|
|
25
25
|
from flwr.common.typing import RunStatus
|
|
26
26
|
|
|
27
27
|
# pylint: disable=E0611
|
|
@@ -60,9 +60,19 @@ REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
|
60
60
|
)
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def generate_rand_int_from_bytes(
|
|
64
|
-
|
|
65
|
-
|
|
63
|
+
def generate_rand_int_from_bytes(
|
|
64
|
+
num_bytes: int, exclude: Optional[list[int]] = None
|
|
65
|
+
) -> int:
|
|
66
|
+
"""Generate a random unsigned integer from `num_bytes` bytes.
|
|
67
|
+
|
|
68
|
+
If `exclude` is set, this function guarantees such number is not returned.
|
|
69
|
+
"""
|
|
70
|
+
num = int.from_bytes(urandom(num_bytes), "little", signed=False)
|
|
71
|
+
|
|
72
|
+
if exclude:
|
|
73
|
+
while num in exclude:
|
|
74
|
+
num = int.from_bytes(urandom(num_bytes), "little", signed=False)
|
|
75
|
+
return num
|
|
66
76
|
|
|
67
77
|
|
|
68
78
|
def convert_uint64_to_sint64(u: int) -> int:
|
|
@@ -246,8 +256,8 @@ def create_taskres_for_unavailable_taskins(taskins_id: Union[str, UUID]) -> Task
|
|
|
246
256
|
run_id=0, # Unknown run ID
|
|
247
257
|
task=Task(
|
|
248
258
|
# This function is only called by SuperLink, and thus it's the producer.
|
|
249
|
-
producer=Node(node_id=
|
|
250
|
-
consumer=Node(node_id=
|
|
259
|
+
producer=Node(node_id=SUPERLINK_NODE_ID),
|
|
260
|
+
consumer=Node(node_id=SUPERLINK_NODE_ID),
|
|
251
261
|
created_at=current_time,
|
|
252
262
|
ttl=0,
|
|
253
263
|
ancestry=[str(taskins_id)],
|
|
@@ -285,8 +295,8 @@ def create_taskres_for_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
|
|
|
285
295
|
run_id=ref_taskins.run_id,
|
|
286
296
|
task=Task(
|
|
287
297
|
# This function is only called by SuperLink, and thus it's the producer.
|
|
288
|
-
producer=Node(node_id=
|
|
289
|
-
consumer=Node(node_id=
|
|
298
|
+
producer=Node(node_id=SUPERLINK_NODE_ID),
|
|
299
|
+
consumer=Node(node_id=SUPERLINK_NODE_ID),
|
|
290
300
|
created_at=current_time,
|
|
291
301
|
ttl=ttl,
|
|
292
302
|
ancestry=[ref_taskins.task_id],
|
|
@@ -21,6 +21,7 @@ from typing import Optional
|
|
|
21
21
|
import grpc
|
|
22
22
|
|
|
23
23
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
24
|
+
from flwr.common.grpc import generic_create_grpc_server
|
|
24
25
|
from flwr.common.logger import log
|
|
25
26
|
from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611
|
|
26
27
|
add_SimulationIoServicer_to_server,
|
|
@@ -28,7 +29,6 @@ from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611
|
|
|
28
29
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
29
30
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
30
31
|
|
|
31
|
-
from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
32
32
|
from .simulationio_servicer import SimulationIoServicer
|
|
33
33
|
|
|
34
34
|
|
flwr/server/utils/validator.py
CHANGED
|
@@ -18,6 +18,7 @@
|
|
|
18
18
|
import time
|
|
19
19
|
from typing import Union
|
|
20
20
|
|
|
21
|
+
from flwr.common.constant import SUPERLINK_NODE_ID
|
|
21
22
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
22
23
|
|
|
23
24
|
|
|
@@ -58,24 +59,14 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str
|
|
|
58
59
|
# Task producer
|
|
59
60
|
if not tasks_ins_res.task.HasField("producer"):
|
|
60
61
|
validation_errors.append("`producer` does not set field `producer`")
|
|
61
|
-
if tasks_ins_res.task.producer.node_id !=
|
|
62
|
-
validation_errors.append("`producer.node_id` is not
|
|
63
|
-
if not tasks_ins_res.task.producer.anonymous:
|
|
64
|
-
validation_errors.append("`producer` is not anonymous")
|
|
62
|
+
if tasks_ins_res.task.producer.node_id != SUPERLINK_NODE_ID:
|
|
63
|
+
validation_errors.append(f"`producer.node_id` is not {SUPERLINK_NODE_ID}")
|
|
65
64
|
|
|
66
65
|
# Task consumer
|
|
67
66
|
if not tasks_ins_res.task.HasField("consumer"):
|
|
68
67
|
validation_errors.append("`consumer` does not set field `consumer`")
|
|
69
|
-
if
|
|
70
|
-
|
|
71
|
-
and tasks_ins_res.task.consumer.node_id != 0
|
|
72
|
-
):
|
|
73
|
-
validation_errors.append("anonymous consumers MUST NOT set a `node_id`")
|
|
74
|
-
if (
|
|
75
|
-
not tasks_ins_res.task.consumer.anonymous
|
|
76
|
-
and tasks_ins_res.task.consumer.node_id == 0
|
|
77
|
-
):
|
|
78
|
-
validation_errors.append("non-anonymous consumer MUST provide a `node_id`")
|
|
68
|
+
if tasks_ins_res.task.consumer.node_id == SUPERLINK_NODE_ID:
|
|
69
|
+
validation_errors.append("consumer MUST provide a valid `node_id`")
|
|
79
70
|
|
|
80
71
|
# Content check
|
|
81
72
|
if tasks_ins_res.task.task_type == "":
|
|
@@ -95,30 +86,14 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str
|
|
|
95
86
|
# Task producer
|
|
96
87
|
if not tasks_ins_res.task.HasField("producer"):
|
|
97
88
|
validation_errors.append("`producer` does not set field `producer`")
|
|
98
|
-
if
|
|
99
|
-
|
|
100
|
-
and tasks_ins_res.task.producer.node_id != 0
|
|
101
|
-
):
|
|
102
|
-
validation_errors.append("anonymous producers MUST NOT set a `node_id`")
|
|
103
|
-
if (
|
|
104
|
-
not tasks_ins_res.task.producer.anonymous
|
|
105
|
-
and tasks_ins_res.task.producer.node_id == 0
|
|
106
|
-
):
|
|
107
|
-
validation_errors.append("non-anonymous producer MUST provide a `node_id`")
|
|
89
|
+
if tasks_ins_res.task.producer.node_id == SUPERLINK_NODE_ID:
|
|
90
|
+
validation_errors.append("producer MUST provide a valid `node_id`")
|
|
108
91
|
|
|
109
92
|
# Task consumer
|
|
110
93
|
if not tasks_ins_res.task.HasField("consumer"):
|
|
111
94
|
validation_errors.append("`consumer` does not set field `consumer`")
|
|
112
|
-
if
|
|
113
|
-
|
|
114
|
-
and tasks_ins_res.task.consumer.node_id != 0
|
|
115
|
-
):
|
|
116
|
-
validation_errors.append("anonymous consumers MUST NOT set a `node_id`")
|
|
117
|
-
if (
|
|
118
|
-
not tasks_ins_res.task.consumer.anonymous
|
|
119
|
-
and tasks_ins_res.task.consumer.node_id == 0
|
|
120
|
-
):
|
|
121
|
-
validation_errors.append("non-anonymous consumer MUST provide a `node_id`")
|
|
95
|
+
if tasks_ins_res.task.consumer.node_id != SUPERLINK_NODE_ID:
|
|
96
|
+
validation_errors.append(f"consumer is not {SUPERLINK_NODE_ID}")
|
|
122
97
|
|
|
123
98
|
# Content check
|
|
124
99
|
if tasks_ins_res.task.task_type == "":
|
flwr/simulation/app.py
CHANGED
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
|
-
import sys
|
|
20
19
|
from logging import DEBUG, ERROR, INFO
|
|
21
20
|
from queue import Queue
|
|
22
21
|
from time import sleep
|
|
@@ -39,6 +38,7 @@ from flwr.common.constant import (
|
|
|
39
38
|
Status,
|
|
40
39
|
SubStatus,
|
|
41
40
|
)
|
|
41
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
|
42
42
|
from flwr.common.logger import (
|
|
43
43
|
log,
|
|
44
44
|
mirror_output_to_queue,
|
|
@@ -81,12 +81,10 @@ def flwr_simulation() -> None:
|
|
|
81
81
|
log(INFO, "Starting Flower Simulation")
|
|
82
82
|
|
|
83
83
|
if not args.insecure:
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
"`flwr-simulation` does not support TLS yet. "
|
|
87
|
-
"Please use the '--insecure' flag.",
|
|
84
|
+
flwr_exit(
|
|
85
|
+
ExitCode.COMMON_TLS_NOT_SUPPORTED,
|
|
86
|
+
"`flwr-simulation` does not support TLS yet. ",
|
|
88
87
|
)
|
|
89
|
-
sys.exit(1)
|
|
90
88
|
|
|
91
89
|
log(
|
|
92
90
|
DEBUG,
|
flwr/simulation/legacy_app.py
CHANGED
|
@@ -29,7 +29,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
|
|
|
29
29
|
|
|
30
30
|
from flwr.client import ClientFnExt
|
|
31
31
|
from flwr.common import EventType, event
|
|
32
|
-
from flwr.common.constant import NODE_ID_NUM_BYTES
|
|
32
|
+
from flwr.common.constant import NODE_ID_NUM_BYTES, SUPERLINK_NODE_ID
|
|
33
33
|
from flwr.common.logger import (
|
|
34
34
|
log,
|
|
35
35
|
set_logger_propagation,
|
|
@@ -87,7 +87,9 @@ def _create_node_id_to_partition_mapping(
|
|
|
87
87
|
nodes_mapping: NodeToPartitionMapping = {} # {node-id; partition-id}
|
|
88
88
|
for i in range(num_clients):
|
|
89
89
|
while True:
|
|
90
|
-
node_id = generate_rand_int_from_bytes(
|
|
90
|
+
node_id = generate_rand_int_from_bytes(
|
|
91
|
+
NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
|
|
92
|
+
)
|
|
91
93
|
if node_id not in nodes_mapping:
|
|
92
94
|
break
|
|
93
95
|
nodes_mapping[node_id] = i
|
|
@@ -350,7 +350,7 @@ def _main_loop(
|
|
|
350
350
|
# Initialize Driver
|
|
351
351
|
driver = InMemoryDriver(state_factory=state_factory)
|
|
352
352
|
driver.set_run(run_id=run.run_id)
|
|
353
|
-
output_context_queue:
|
|
353
|
+
output_context_queue: Queue[Context] = Queue()
|
|
354
354
|
|
|
355
355
|
# Get and run ServerApp thread
|
|
356
356
|
serverapp_th = run_serverapp_th(
|
|
@@ -21,7 +21,7 @@ from typing import Optional, cast
|
|
|
21
21
|
import grpc
|
|
22
22
|
|
|
23
23
|
from flwr.common.constant import SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS
|
|
24
|
-
from flwr.common.grpc import create_channel
|
|
24
|
+
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
25
25
|
from flwr.common.logger import log
|
|
26
26
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
27
27
|
from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611
|
|
@@ -73,6 +73,7 @@ class SimulationIoConnection:
|
|
|
73
73
|
insecure=(self._cert is None),
|
|
74
74
|
root_certificates=self._cert,
|
|
75
75
|
)
|
|
76
|
+
self._channel.subscribe(on_channel_state_change)
|
|
76
77
|
self._grpc_stub = SimulationIoStub(self._channel)
|
|
77
78
|
_wrap_stub(self._grpc_stub, self._retry_invoker)
|
|
78
79
|
log(DEBUG, "[SimulationIO] Connected to %s", self._addr)
|
flwr/superexec/exec_grpc.py
CHANGED
|
@@ -23,11 +23,11 @@ import grpc
|
|
|
23
23
|
|
|
24
24
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
25
25
|
from flwr.common.auth_plugin import ExecAuthPlugin
|
|
26
|
+
from flwr.common.grpc import generic_create_grpc_server
|
|
26
27
|
from flwr.common.logger import log
|
|
27
28
|
from flwr.common.typing import UserConfig
|
|
28
29
|
from flwr.proto.exec_pb2_grpc import add_ExecServicer_to_server
|
|
29
30
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
30
|
-
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
31
31
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
32
32
|
from flwr.superexec.exec_user_auth_interceptor import ExecUserAuthInterceptor
|
|
33
33
|
|