flwr-nightly 1.23.0.dev20251007__py3-none-any.whl → 1.23.0.dev20251009__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/auth_plugin/__init__.py +7 -3
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +4 -13
- flwr/cli/ls.py +2 -2
- flwr/cli/pull.py +2 -2
- flwr/cli/run/run.py +2 -2
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/create.py +137 -11
- flwr/cli/supernode/delete.py +88 -10
- flwr/cli/supernode/ls.py +2 -2
- flwr/cli/utils.py +65 -55
- flwr/client/grpc_rere_client/connection.py +6 -4
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +2 -2
- flwr/client/rest_client/connection.py +7 -1
- flwr/common/constant.py +13 -0
- flwr/proto/control_pb2.py +1 -1
- flwr/proto/control_pb2.pyi +2 -2
- flwr/proto/fleet_pb2.py +22 -22
- flwr/proto/fleet_pb2.pyi +4 -1
- flwr/proto/node_pb2.py +2 -2
- flwr/proto/node_pb2.pyi +4 -1
- flwr/server/app.py +32 -31
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +8 -4
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +18 -37
- flwr/server/superlink/fleet/message_handler/message_handler.py +5 -3
- flwr/server/superlink/fleet/vce/vce_api.py +10 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +52 -54
- flwr/server/superlink/linkstate/linkstate.py +20 -10
- flwr/server/superlink/linkstate/sqlite_linkstate.py +54 -61
- flwr/server/utils/validator.py +2 -3
- flwr/supercore/primitives/asymmetric.py +8 -0
- flwr/superlink/auth_plugin/__init__.py +29 -0
- flwr/superlink/servicer/control/control_grpc.py +9 -7
- flwr/superlink/servicer/control/control_servicer.py +89 -48
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/METADATA +1 -1
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/RECORD +38 -38
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/entry_points.txt +0 -0
@@ -17,7 +17,6 @@
|
|
17
17
|
|
18
18
|
import secrets
|
19
19
|
import threading
|
20
|
-
import time
|
21
20
|
from bisect import bisect_right
|
22
21
|
from collections import defaultdict
|
23
22
|
from dataclasses import dataclass, field
|
@@ -39,6 +38,7 @@ from flwr.common.constant import (
|
|
39
38
|
)
|
40
39
|
from flwr.common.record import ConfigRecord
|
41
40
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
41
|
+
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
42
42
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
43
43
|
from flwr.server.utils import validate_message
|
44
44
|
|
@@ -69,10 +69,10 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
69
69
|
|
70
70
|
def __init__(self) -> None:
|
71
71
|
|
72
|
-
# Map node_id to
|
73
|
-
self.
|
74
|
-
self.
|
75
|
-
self.
|
72
|
+
# Map node_id to NodeInfo
|
73
|
+
self.nodes: dict[int, NodeInfo] = {}
|
74
|
+
self.registered_node_public_keys: set[bytes] = set()
|
75
|
+
self.owner_to_node_ids: dict[str, set[int]] = {} # Quick lookup
|
76
76
|
|
77
77
|
# Map run_id to RunRecord
|
78
78
|
self.run_ids: dict[int, RunRecord] = {}
|
@@ -114,7 +114,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
114
114
|
)
|
115
115
|
return None
|
116
116
|
# Validate destination node ID
|
117
|
-
if message.metadata.dst_node_id not in self.
|
117
|
+
if message.metadata.dst_node_id not in self.nodes:
|
118
118
|
log(
|
119
119
|
ERROR,
|
120
120
|
"Invalid destination node ID for Message: %s",
|
@@ -136,7 +136,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
136
136
|
|
137
137
|
# Find Message for node_id that were not delivered yet
|
138
138
|
message_ins_list: list[Message] = []
|
139
|
-
current_time =
|
139
|
+
current_time = now().timestamp()
|
140
140
|
with self.lock:
|
141
141
|
for _, msg_ins in self.message_ins_store.items():
|
142
142
|
if (
|
@@ -190,7 +190,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
190
190
|
return None
|
191
191
|
|
192
192
|
ins_metadata = msg_ins.metadata
|
193
|
-
if ins_metadata.created_at + ins_metadata.ttl <=
|
193
|
+
if ins_metadata.created_at + ins_metadata.ttl <= now().timestamp():
|
194
194
|
log(
|
195
195
|
ERROR,
|
196
196
|
"Failed to store Message: the message it is replying to "
|
@@ -238,7 +238,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
238
238
|
ret: dict[str, Message] = {}
|
239
239
|
|
240
240
|
with self.lock:
|
241
|
-
current =
|
241
|
+
current = now().timestamp()
|
242
242
|
|
243
243
|
# Verify Message IDs
|
244
244
|
ret = verify_message_ids(
|
@@ -256,9 +256,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
256
256
|
inquired_in_message_ids=message_ids,
|
257
257
|
found_in_message_dict=self.message_ins_store,
|
258
258
|
node_id_to_online_until={
|
259
|
-
node_id: self.
|
259
|
+
node_id: self.nodes[node_id].online_until
|
260
260
|
for node_id in dst_node_ids
|
261
|
-
if node_id in self.
|
261
|
+
if node_id in self.nodes
|
262
262
|
},
|
263
263
|
current_time=current,
|
264
264
|
)
|
@@ -330,7 +330,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
330
330
|
"""
|
331
331
|
return len(self.message_res_store)
|
332
332
|
|
333
|
-
def create_node(
|
333
|
+
def create_node(
|
334
|
+
self, owner_aid: str, public_key: bytes, heartbeat_interval: float
|
335
|
+
) -> int:
|
334
336
|
"""Create, store in the link state, and return `node_id`."""
|
335
337
|
# Sample a random int64 as node_id
|
336
338
|
node_id = generate_rand_int_from_bytes(
|
@@ -338,28 +340,40 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
338
340
|
)
|
339
341
|
|
340
342
|
with self.lock:
|
341
|
-
if node_id in self.
|
343
|
+
if node_id in self.nodes:
|
342
344
|
log(ERROR, "Unexpected node registration failure.")
|
343
345
|
return 0
|
346
|
+
if public_key in self.registered_node_public_keys:
|
347
|
+
raise ValueError("Public key already in use")
|
344
348
|
|
345
|
-
# Mark the node online until
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
+
# Mark the node online until now().timestamp() + heartbeat_interval
|
350
|
+
current = now()
|
351
|
+
self.nodes[node_id] = NodeInfo(
|
352
|
+
node_id=node_id,
|
353
|
+
owner_aid=owner_aid, # Unused for now
|
354
|
+
status="created", # Unused for now
|
355
|
+
created_at=current.isoformat(), # Unused for now
|
356
|
+
last_activated_at=current.isoformat(), # Unused for now
|
357
|
+
last_deactivated_at="", # Unused for now
|
358
|
+
deleted_at="", # Unused for now
|
359
|
+
online_until=current.timestamp() + heartbeat_interval,
|
360
|
+
heartbeat_interval=heartbeat_interval,
|
361
|
+
public_key=public_key,
|
349
362
|
)
|
363
|
+
self.registered_node_public_keys.add(public_key)
|
364
|
+
self.owner_to_node_ids.setdefault(owner_aid, set()).add(node_id)
|
350
365
|
return node_id
|
351
366
|
|
352
|
-
def delete_node(self, node_id: int) -> None:
|
367
|
+
def delete_node(self, owner_aid: str, node_id: int) -> None:
|
353
368
|
"""Delete a node."""
|
354
369
|
with self.lock:
|
355
|
-
if node_id not in self.
|
356
|
-
raise ValueError(
|
357
|
-
|
358
|
-
|
359
|
-
if pk := self.node_id_to_public_key.pop(node_id, None):
|
360
|
-
del self.public_key_to_node_id[pk]
|
370
|
+
if node_id not in self.nodes or owner_aid != self.nodes[node_id].owner_aid:
|
371
|
+
raise ValueError(
|
372
|
+
f"Node ID {node_id} not found or unauthorized deletion attempt."
|
373
|
+
)
|
361
374
|
|
362
|
-
|
375
|
+
node = self.nodes.pop(node_id)
|
376
|
+
self.registered_node_public_keys.discard(node.public_key)
|
363
377
|
|
364
378
|
def get_nodes(self, run_id: int) -> set[int]:
|
365
379
|
"""Return all available nodes.
|
@@ -372,36 +386,20 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
372
386
|
with self.lock:
|
373
387
|
if run_id not in self.run_ids:
|
374
388
|
return set()
|
375
|
-
current_time =
|
389
|
+
current_time = now().timestamp()
|
376
390
|
return {
|
377
|
-
node_id
|
378
|
-
for
|
379
|
-
if online_until > current_time
|
391
|
+
info.node_id
|
392
|
+
for info in self.nodes.values()
|
393
|
+
if info.online_until > current_time
|
380
394
|
}
|
381
395
|
|
382
|
-
def
|
383
|
-
"""Set `public_key` for the specified `node_id`."""
|
384
|
-
with self.lock:
|
385
|
-
if node_id not in self.node_ids:
|
386
|
-
raise ValueError(f"Node {node_id} not found")
|
387
|
-
|
388
|
-
if public_key in self.public_key_to_node_id:
|
389
|
-
raise ValueError("Public key already in use")
|
390
|
-
|
391
|
-
self.public_key_to_node_id[public_key] = node_id
|
392
|
-
self.node_id_to_public_key[node_id] = public_key
|
393
|
-
|
394
|
-
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
396
|
+
def get_node_public_key(self, node_id: int) -> bytes:
|
395
397
|
"""Get `public_key` for the specified `node_id`."""
|
396
398
|
with self.lock:
|
397
|
-
if
|
398
|
-
raise ValueError(f"Node {node_id} not found")
|
399
|
-
|
400
|
-
return self.node_id_to_public_key.get(node_id)
|
399
|
+
if (node := self.nodes.get(node_id)) is None:
|
400
|
+
raise ValueError(f"Node ID {node_id} not found")
|
401
401
|
|
402
|
-
|
403
|
-
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
404
|
-
return self.public_key_to_node_id.get(node_public_key)
|
402
|
+
return node.public_key
|
405
403
|
|
406
404
|
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
407
405
|
def create_run(
|
@@ -608,13 +606,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
608
606
|
the node is marked as offline.
|
609
607
|
"""
|
610
608
|
with self.lock:
|
611
|
-
if
|
612
|
-
|
613
|
-
|
614
|
-
heartbeat_interval,
|
609
|
+
if info := self.nodes.get(node_id):
|
610
|
+
info.online_until = (
|
611
|
+
now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
|
615
612
|
)
|
613
|
+
info.heartbeat_interval = heartbeat_interval
|
616
614
|
return True
|
617
|
-
|
615
|
+
return False
|
618
616
|
|
619
617
|
def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
|
620
618
|
"""Acknowledge a heartbeat received from a ServerApp for a given run.
|
@@ -128,11 +128,13 @@ class LinkState(CoreState): # pylint: disable=R0904
|
|
128
128
|
"""Get all instruction Message IDs for the given run_id."""
|
129
129
|
|
130
130
|
@abc.abstractmethod
|
131
|
-
def create_node(
|
131
|
+
def create_node(
|
132
|
+
self, owner_aid: str, public_key: bytes, heartbeat_interval: float
|
133
|
+
) -> int:
|
132
134
|
"""Create, store in the link state, and return `node_id`."""
|
133
135
|
|
134
136
|
@abc.abstractmethod
|
135
|
-
def delete_node(self, node_id: int) -> None:
|
137
|
+
def delete_node(self, owner_aid: str, node_id: int) -> None:
|
136
138
|
"""Remove `node_id` from the link state."""
|
137
139
|
|
138
140
|
@abc.abstractmethod
|
@@ -146,16 +148,24 @@ class LinkState(CoreState): # pylint: disable=R0904
|
|
146
148
|
"""
|
147
149
|
|
148
150
|
@abc.abstractmethod
|
149
|
-
def
|
150
|
-
"""
|
151
|
+
def get_node_public_key(self, node_id: int) -> bytes:
|
152
|
+
"""Get `public_key` for the specified `node_id`.
|
151
153
|
|
152
|
-
|
153
|
-
|
154
|
-
|
154
|
+
Parameters
|
155
|
+
----------
|
156
|
+
node_id : int
|
157
|
+
The identifier of the node whose public key is to be retrieved.
|
155
158
|
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
+
Returns
|
160
|
+
-------
|
161
|
+
bytes
|
162
|
+
The public key associated with the specified `node_id`.
|
163
|
+
|
164
|
+
Raises
|
165
|
+
------
|
166
|
+
ValueError
|
167
|
+
If the specified `node_id` does not exist in the link state.
|
168
|
+
"""
|
159
169
|
|
160
170
|
@abc.abstractmethod
|
161
171
|
def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
@@ -21,7 +21,6 @@ import json
|
|
21
21
|
import re
|
22
22
|
import secrets
|
23
23
|
import sqlite3
|
24
|
-
import time
|
25
24
|
from collections.abc import Sequence
|
26
25
|
from logging import DEBUG, ERROR, WARNING
|
27
26
|
from typing import Any, Optional, Union, cast
|
@@ -72,10 +71,16 @@ from .utils import (
|
|
72
71
|
|
73
72
|
SQL_CREATE_TABLE_NODE = """
|
74
73
|
CREATE TABLE IF NOT EXISTS node(
|
75
|
-
node_id
|
76
|
-
|
77
|
-
|
78
|
-
|
74
|
+
node_id INTEGER UNIQUE,
|
75
|
+
owner_aid TEXT,
|
76
|
+
status TEXT,
|
77
|
+
created_at TEXT,
|
78
|
+
last_activated_at TEXT,
|
79
|
+
last_deactivated_at TEXT,
|
80
|
+
deleted_at TEXT,
|
81
|
+
online_until REAL,
|
82
|
+
heartbeat_interval REAL,
|
83
|
+
public_key BLOB UNIQUE
|
79
84
|
);
|
80
85
|
"""
|
81
86
|
|
@@ -89,6 +94,10 @@ SQL_CREATE_INDEX_ONLINE_UNTIL = """
|
|
89
94
|
CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
|
90
95
|
"""
|
91
96
|
|
97
|
+
SQL_CREATE_INDEX_OWNER_AID = """
|
98
|
+
CREATE INDEX IF NOT EXISTS idx_node_owner_aid ON node(owner_aid);
|
99
|
+
"""
|
100
|
+
|
92
101
|
SQL_CREATE_TABLE_RUN = """
|
93
102
|
CREATE TABLE IF NOT EXISTS run(
|
94
103
|
run_id INTEGER UNIQUE,
|
@@ -223,6 +232,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
223
232
|
cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
|
224
233
|
cur.execute(SQL_CREATE_TABLE_TOKEN_STORE)
|
225
234
|
cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
|
235
|
+
cur.execute(SQL_CREATE_INDEX_OWNER_AID)
|
226
236
|
res = cur.execute("SELECT name FROM sqlite_schema;")
|
227
237
|
return res.fetchall()
|
228
238
|
|
@@ -451,7 +461,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
451
461
|
ret: dict[str, Message] = {}
|
452
462
|
|
453
463
|
# Verify Message IDs
|
454
|
-
current =
|
464
|
+
current = now().timestamp()
|
455
465
|
query = f"""
|
456
466
|
SELECT *
|
457
467
|
FROM message_ins
|
@@ -597,7 +607,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
597
607
|
|
598
608
|
return {row["message_id"] for row in rows}
|
599
609
|
|
600
|
-
def create_node(
|
610
|
+
def create_node(
|
611
|
+
self, owner_aid: str, public_key: bytes, heartbeat_interval: float
|
612
|
+
) -> int:
|
601
613
|
"""Create, store in the link state, and return `node_id`."""
|
602
614
|
# Sample a random uint64 as node_id
|
603
615
|
uint64_node_id = generate_rand_int_from_bytes(
|
@@ -607,37 +619,48 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
607
619
|
# Convert the uint64 value to sint64 for SQLite
|
608
620
|
sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
|
609
621
|
|
610
|
-
query =
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
622
|
+
query = """
|
623
|
+
INSERT INTO node
|
624
|
+
(node_id, owner_aid, status, created_at, last_activated_at,
|
625
|
+
last_deactivated_at, deleted_at, online_until, heartbeat_interval,
|
626
|
+
public_key)
|
627
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
628
|
+
"""
|
615
629
|
|
616
|
-
# Mark the node online
|
630
|
+
# Mark the node online until now().timestamp() + heartbeat_interval
|
617
631
|
try:
|
618
632
|
self.query(
|
619
633
|
query,
|
620
634
|
(
|
621
|
-
sint64_node_id,
|
622
|
-
|
623
|
-
|
624
|
-
|
635
|
+
sint64_node_id, # node_id
|
636
|
+
owner_aid, # owner_aid, unused for now
|
637
|
+
"created", # status, unused for now
|
638
|
+
now().isoformat(), # created_at, unused for now
|
639
|
+
now().isoformat(), # last_activated_at, unused for now
|
640
|
+
"", # last_deactivated_at, unused for now
|
641
|
+
"", # deleted_at, unused for now
|
642
|
+
now().timestamp() + heartbeat_interval, # online_until
|
643
|
+
heartbeat_interval, # heartbeat_interval
|
644
|
+
public_key, # public_key
|
625
645
|
),
|
626
646
|
)
|
627
|
-
except sqlite3.IntegrityError:
|
647
|
+
except sqlite3.IntegrityError as e:
|
648
|
+
if "UNIQUE constraint failed: node.public_key" in str(e):
|
649
|
+
raise ValueError("Public key already in use.") from None
|
650
|
+
# Must be node ID conflict, almost impossible unless system is compromised
|
628
651
|
log(ERROR, "Unexpected node registration failure.")
|
629
652
|
return 0
|
630
653
|
|
631
654
|
# Note: we need to return the uint64 value of the node_id
|
632
655
|
return uint64_node_id
|
633
656
|
|
634
|
-
def delete_node(self, node_id: int) -> None:
|
657
|
+
def delete_node(self, owner_aid: str, node_id: int) -> None:
|
635
658
|
"""Delete a node."""
|
636
659
|
# Convert the uint64 value to sint64 for SQLite
|
637
660
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
638
661
|
|
639
|
-
query = "DELETE FROM node WHERE node_id = ?"
|
640
|
-
params = (sint64_node_id,)
|
662
|
+
query = "DELETE FROM node WHERE node_id = ? AND owner_aid = ?"
|
663
|
+
params = (sint64_node_id, owner_aid)
|
641
664
|
|
642
665
|
if self.conn is None:
|
643
666
|
raise AttributeError("LinkState is not initialized.")
|
@@ -646,7 +669,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
646
669
|
with self.conn:
|
647
670
|
rows = self.conn.execute(query, params)
|
648
671
|
if rows.rowcount < 1:
|
649
|
-
raise ValueError(
|
672
|
+
raise ValueError(
|
673
|
+
f"Node ID {node_id} not found or unauthorized deletion attempt."
|
674
|
+
)
|
650
675
|
except KeyError as exc:
|
651
676
|
log(ERROR, {"query": query, "data": params, "exception": exc})
|
652
677
|
|
@@ -668,32 +693,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
668
693
|
|
669
694
|
# Get nodes
|
670
695
|
query = "SELECT node_id FROM node WHERE online_until > ?;"
|
671
|
-
rows = self.query(query, (
|
696
|
+
rows = self.query(query, (now().timestamp(),))
|
672
697
|
|
673
698
|
# Convert sint64 node_ids to uint64
|
674
699
|
result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
|
675
700
|
return result
|
676
701
|
|
677
|
-
def
|
678
|
-
"""Set `public_key` for the specified `node_id`."""
|
679
|
-
# Convert the uint64 value to sint64 for SQLite
|
680
|
-
sint64_node_id = convert_uint64_to_sint64(node_id)
|
681
|
-
|
682
|
-
# Check if the node exists in the `node` table
|
683
|
-
query = "SELECT 1 FROM node WHERE node_id = ?"
|
684
|
-
if not self.query(query, (sint64_node_id,)):
|
685
|
-
raise ValueError(f"Node {node_id} not found")
|
686
|
-
|
687
|
-
# Check if the public key is already in use in the `node` table
|
688
|
-
query = "SELECT 1 FROM node WHERE public_key = ?"
|
689
|
-
if self.query(query, (public_key,)):
|
690
|
-
raise ValueError("Public key already in use")
|
691
|
-
|
692
|
-
# Update the `node` table to set the public key for the given node ID
|
693
|
-
query = "UPDATE node SET public_key = ? WHERE node_id = ?"
|
694
|
-
self.query(query, (public_key, sint64_node_id))
|
695
|
-
|
696
|
-
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
702
|
+
def get_node_public_key(self, node_id: int) -> bytes:
|
697
703
|
"""Get `public_key` for the specified `node_id`."""
|
698
704
|
# Convert the uint64 value to sint64 for SQLite
|
699
705
|
sint64_node_id = convert_uint64_to_sint64(node_id)
|
@@ -704,23 +710,10 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
704
710
|
|
705
711
|
# If no result is found, return None
|
706
712
|
if not rows:
|
707
|
-
raise ValueError(f"Node {node_id} not found")
|
708
|
-
|
709
|
-
# Return the public key if it is not empty, otherwise return None
|
710
|
-
return rows[0]["public_key"] or None
|
711
|
-
|
712
|
-
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
713
|
-
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
714
|
-
query = "SELECT node_id FROM node WHERE public_key = :public_key;"
|
715
|
-
row = self.query(query, {"public_key": node_public_key})
|
716
|
-
if len(row) > 0:
|
717
|
-
node_id: int = row[0]["node_id"]
|
713
|
+
raise ValueError(f"Node ID {node_id} not found")
|
718
714
|
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
return uint64_node_id
|
723
|
-
return None
|
715
|
+
# Return the public key
|
716
|
+
return cast(bytes, rows[0]["public_key"])
|
724
717
|
|
725
718
|
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
726
719
|
def create_run(
|
@@ -1010,7 +1003,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
1010
1003
|
self.query(
|
1011
1004
|
query,
|
1012
1005
|
(
|
1013
|
-
|
1006
|
+
now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
|
1014
1007
|
heartbeat_interval,
|
1015
1008
|
sint64_node_id,
|
1016
1009
|
),
|
@@ -1140,7 +1133,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
1140
1133
|
message_ins = rows[0]
|
1141
1134
|
created_at = message_ins["created_at"]
|
1142
1135
|
ttl = message_ins["ttl"]
|
1143
|
-
current_time =
|
1136
|
+
current_time = now().timestamp()
|
1144
1137
|
|
1145
1138
|
# Check if Message is expired
|
1146
1139
|
if ttl is not None and created_at + ttl <= current_time:
|
flwr/server/utils/validator.py
CHANGED
@@ -15,10 +15,9 @@
|
|
15
15
|
"""Validators."""
|
16
16
|
|
17
17
|
|
18
|
-
import time
|
19
|
-
|
20
18
|
from flwr.common import Message
|
21
19
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
20
|
+
from flwr.common.date import now
|
22
21
|
|
23
22
|
|
24
23
|
# pylint: disable-next=too-many-branches
|
@@ -44,7 +43,7 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
44
43
|
validation_errors.append("`metadata.ttl` must be higher than zero")
|
45
44
|
|
46
45
|
# Verify TTL and created_at time
|
47
|
-
current_time =
|
46
|
+
current_time = now().timestamp()
|
48
47
|
if metadata.created_at + metadata.ttl <= current_time:
|
49
48
|
validation_errors.append("Message TTL has expired")
|
50
49
|
|
@@ -107,3 +107,11 @@ def verify_signature(
|
|
107
107
|
return True
|
108
108
|
except InvalidSignature:
|
109
109
|
return False
|
110
|
+
|
111
|
+
|
112
|
+
def uses_nist_ec_curve(public_key: ec.EllipticCurvePublicKey) -> bool:
|
113
|
+
"""Return True if the provided key uses a NIST EC curve."""
|
114
|
+
return isinstance(
|
115
|
+
public_key.curve,
|
116
|
+
(ec.SECP192R1, ec.SECP224R1, ec.SECP256R1, ec.SECP384R1, ec.SECP521R1),
|
117
|
+
)
|
@@ -15,12 +15,41 @@
|
|
15
15
|
"""Account auth plugin for ControlServicer."""
|
16
16
|
|
17
17
|
|
18
|
+
from flwr.common.constant import AuthnType, AuthzType
|
19
|
+
|
18
20
|
from .auth_plugin import ControlAuthnPlugin, ControlAuthzPlugin
|
19
21
|
from .noop_auth_plugin import NoOpControlAuthnPlugin, NoOpControlAuthzPlugin
|
20
22
|
|
23
|
+
try:
|
24
|
+
from flwr.ee import get_control_authn_ee_plugins, get_control_authz_ee_plugins
|
25
|
+
except ImportError:
|
26
|
+
|
27
|
+
def get_control_authn_ee_plugins() -> dict[str, type[ControlAuthnPlugin]]:
|
28
|
+
"""Return all Control API authentication plugins for EE."""
|
29
|
+
return {}
|
30
|
+
|
31
|
+
def get_control_authz_ee_plugins() -> dict[str, type[ControlAuthzPlugin]]:
|
32
|
+
"""Return all Control API authorization plugins for EE."""
|
33
|
+
return {}
|
34
|
+
|
35
|
+
|
36
|
+
def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
|
37
|
+
"""Return all Control API authentication plugins."""
|
38
|
+
ee_dict: dict[str, type[ControlAuthnPlugin]] = get_control_authn_ee_plugins()
|
39
|
+
return ee_dict | {AuthnType.NOOP: NoOpControlAuthnPlugin}
|
40
|
+
|
41
|
+
|
42
|
+
def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
|
43
|
+
"""Return all Control API authorization plugins."""
|
44
|
+
ee_dict: dict[str, type[ControlAuthzPlugin]] = get_control_authz_ee_plugins()
|
45
|
+
return ee_dict | {AuthzType.NOOP: NoOpControlAuthzPlugin}
|
46
|
+
|
47
|
+
|
21
48
|
__all__ = [
|
22
49
|
"ControlAuthnPlugin",
|
23
50
|
"ControlAuthzPlugin",
|
24
51
|
"NoOpControlAuthnPlugin",
|
25
52
|
"NoOpControlAuthzPlugin",
|
53
|
+
"get_control_authn_plugins",
|
54
|
+
"get_control_authz_plugins",
|
26
55
|
]
|
@@ -31,7 +31,11 @@ from flwr.supercore.ffs import FfsFactory
|
|
31
31
|
from flwr.supercore.license_plugin import LicensePlugin
|
32
32
|
from flwr.supercore.object_store import ObjectStoreFactory
|
33
33
|
from flwr.superlink.artifact_provider import ArtifactProvider
|
34
|
-
from flwr.superlink.auth_plugin import
|
34
|
+
from flwr.superlink.auth_plugin import (
|
35
|
+
ControlAuthnPlugin,
|
36
|
+
ControlAuthzPlugin,
|
37
|
+
NoOpControlAuthnPlugin,
|
38
|
+
)
|
35
39
|
|
36
40
|
from .control_account_auth_interceptor import ControlAccountAuthInterceptor
|
37
41
|
from .control_event_log_interceptor import ControlEventLogInterceptor
|
@@ -54,8 +58,8 @@ def run_control_api_grpc(
|
|
54
58
|
objectstore_factory: ObjectStoreFactory,
|
55
59
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
56
60
|
is_simulation: bool,
|
57
|
-
authn_plugin:
|
58
|
-
authz_plugin:
|
61
|
+
authn_plugin: ControlAuthnPlugin,
|
62
|
+
authz_plugin: ControlAuthzPlugin,
|
59
63
|
event_log_plugin: Optional[EventLogWriterPlugin] = None,
|
60
64
|
artifact_provider: Optional[ArtifactProvider] = None,
|
61
65
|
) -> grpc.Server:
|
@@ -72,11 +76,9 @@ def run_control_api_grpc(
|
|
72
76
|
authn_plugin=authn_plugin,
|
73
77
|
artifact_provider=artifact_provider,
|
74
78
|
)
|
75
|
-
interceptors
|
79
|
+
interceptors = [ControlAccountAuthInterceptor(authn_plugin, authz_plugin)]
|
76
80
|
if license_plugin is not None:
|
77
81
|
interceptors.append(ControlLicenseInterceptor(license_plugin))
|
78
|
-
if authn_plugin is not None and authz_plugin is not None:
|
79
|
-
interceptors.append(ControlAccountAuthInterceptor(authn_plugin, authz_plugin))
|
80
82
|
# Event log interceptor must be added after account auth interceptor
|
81
83
|
if event_log_plugin is not None:
|
82
84
|
interceptors.append(ControlEventLogInterceptor(event_log_plugin))
|
@@ -90,7 +92,7 @@ def run_control_api_grpc(
|
|
90
92
|
interceptors=interceptors or None,
|
91
93
|
)
|
92
94
|
|
93
|
-
if authn_plugin
|
95
|
+
if isinstance(authn_plugin, NoOpControlAuthnPlugin):
|
94
96
|
log(INFO, "Flower Deployment Runtime: Starting Control API on %s", address)
|
95
97
|
else:
|
96
98
|
log(
|