flwr-nightly 1.23.0.dev20251007__py3-none-any.whl → 1.23.0.dev20251008__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/auth_plugin/__init__.py +7 -3
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +4 -13
- flwr/cli/ls.py +2 -2
- flwr/cli/pull.py +2 -2
- flwr/cli/run/run.py +2 -2
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/ls.py +2 -2
- flwr/cli/utils.py +28 -44
- flwr/client/grpc_rere_client/connection.py +6 -4
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +2 -2
- flwr/client/rest_client/connection.py +7 -1
- flwr/common/constant.py +10 -0
- flwr/proto/fleet_pb2.py +22 -22
- flwr/proto/fleet_pb2.pyi +4 -1
- flwr/proto/node_pb2.py +1 -1
- flwr/server/app.py +32 -31
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +8 -4
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +18 -37
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +7 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +39 -27
- flwr/server/superlink/linkstate/linkstate.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +37 -21
- flwr/server/utils/validator.py +2 -3
- flwr/superlink/auth_plugin/__init__.py +29 -0
- flwr/superlink/servicer/control/control_grpc.py +9 -7
- flwr/superlink/servicer/control/control_servicer.py +34 -46
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/METADATA +1 -1
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/RECORD +32 -32
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/entry_points.txt +0 -0
flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py}
RENAMED
@@ -16,7 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import datetime
|
19
|
-
from typing import Any, Callable,
|
19
|
+
from typing import Any, Callable, cast
|
20
20
|
|
21
21
|
import grpc
|
22
22
|
from google.protobuf.message import Message as GrpcMessage
|
@@ -29,10 +29,7 @@ from flwr.common.constant import (
|
|
29
29
|
TIMESTAMP_HEADER,
|
30
30
|
TIMESTAMP_TOLERANCE,
|
31
31
|
)
|
32
|
-
from flwr.proto.fleet_pb2 import
|
33
|
-
CreateNodeRequest,
|
34
|
-
CreateNodeResponse,
|
35
|
-
)
|
32
|
+
from flwr.proto.fleet_pb2 import CreateNodeRequest # pylint: disable=E0611
|
36
33
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
37
34
|
from flwr.supercore.primitives.asymmetric import bytes_to_public_key, verify_signature
|
38
35
|
|
@@ -50,7 +47,7 @@ def _unary_unary_rpc_terminator(
|
|
50
47
|
return grpc.unary_unary_rpc_method_handler(terminate)
|
51
48
|
|
52
49
|
|
53
|
-
class
|
50
|
+
class NodeAuthServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
54
51
|
"""Server interceptor for node authentication.
|
55
52
|
|
56
53
|
Parameters
|
@@ -110,50 +107,34 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
110
107
|
if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
|
111
108
|
return _unary_unary_rpc_terminator("Invalid timestamp")
|
112
109
|
|
113
|
-
# Continue the RPC call
|
114
|
-
expected_node_id = state.get_node_id(node_pk_bytes)
|
115
|
-
if not handler_call_details.method.endswith("CreateNode"):
|
116
|
-
# All calls, except for `CreateNode`, must provide a public key that is
|
117
|
-
# already mapped to a `node_id` (in `LinkState`)
|
118
|
-
if expected_node_id is None:
|
119
|
-
return _unary_unary_rpc_terminator("Invalid node ID")
|
120
|
-
# One of the method handlers in
|
110
|
+
# Continue the RPC call: One of the method handlers in
|
121
111
|
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
122
112
|
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
123
|
-
return self._wrap_method_handler(
|
124
|
-
method_handler, expected_node_id, node_pk_bytes
|
125
|
-
)
|
113
|
+
return self._wrap_method_handler(method_handler, node_pk_bytes)
|
126
114
|
|
127
115
|
def _wrap_method_handler(
|
128
116
|
self,
|
129
117
|
method_handler: grpc.RpcMethodHandler,
|
130
|
-
|
131
|
-
node_public_key: bytes,
|
118
|
+
expected_public_key: bytes,
|
132
119
|
) -> grpc.RpcMethodHandler:
|
133
120
|
def _generic_method_handler(
|
134
121
|
request: GrpcMessage,
|
135
122
|
context: grpc.ServicerContext,
|
136
123
|
) -> GrpcMessage:
|
137
|
-
#
|
138
|
-
if
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
124
|
+
# Retrieve the public key
|
125
|
+
if isinstance(request, CreateNodeRequest):
|
126
|
+
actual_public_key = request.public_key
|
127
|
+
else:
|
128
|
+
# Note: This function runs in a different thread
|
129
|
+
# than the `intercept_service` function.
|
130
|
+
actual_public_key = self.state_factory.state().get_node_public_key(
|
131
|
+
request.node.node_id # type: ignore
|
132
|
+
)
|
133
|
+
# Verify the public key
|
134
|
+
if actual_public_key != expected_public_key:
|
135
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
|
144
136
|
|
145
137
|
response: GrpcMessage = method_handler.unary_unary(request, context)
|
146
|
-
|
147
|
-
# Set the public key after a successful CreateNode request
|
148
|
-
if isinstance(response, CreateNodeResponse):
|
149
|
-
state = self.state_factory.state()
|
150
|
-
try:
|
151
|
-
state.set_node_public_key(response.node.node_id, node_public_key)
|
152
|
-
except ValueError as e:
|
153
|
-
# Remove newly created node if setting the public key fails
|
154
|
-
state.delete_node(response.node.node_id)
|
155
|
-
context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
|
156
|
-
|
157
138
|
return response
|
158
139
|
|
159
140
|
return grpc.unary_unary_rpc_method_handler(
|
@@ -70,7 +70,7 @@ def create_node(
|
|
70
70
|
) -> CreateNodeResponse:
|
71
71
|
"""."""
|
72
72
|
# Create node
|
73
|
-
node_id = state.create_node(
|
73
|
+
node_id = state.create_node(request.public_key, request.heartbeat_interval)
|
74
74
|
return CreateNodeResponse(node=Node(node_id=node_id))
|
75
75
|
|
76
76
|
|
@@ -16,6 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import json
|
19
|
+
import secrets
|
19
20
|
import threading
|
20
21
|
import time
|
21
22
|
import traceback
|
@@ -53,7 +54,12 @@ def _register_nodes(
|
|
53
54
|
nodes_mapping: NodeToPartitionMapping = {}
|
54
55
|
state = state_factory.state()
|
55
56
|
for i in range(num_nodes):
|
56
|
-
node_id = state.create_node(
|
57
|
+
node_id = state.create_node(
|
58
|
+
# No node authentication in simulation;
|
59
|
+
# use random bytes instead
|
60
|
+
secrets.token_bytes(32),
|
61
|
+
heartbeat_interval=HEARTBEAT_MAX_INTERVAL,
|
62
|
+
)
|
57
63
|
nodes_mapping[node_id] = i
|
58
64
|
log(DEBUG, "Registered %i nodes", len(nodes_mapping))
|
59
65
|
return nodes_mapping
|
@@ -17,7 +17,6 @@
|
|
17
17
|
|
18
18
|
import secrets
|
19
19
|
import threading
|
20
|
-
import time
|
21
20
|
from bisect import bisect_right
|
22
21
|
from collections import defaultdict
|
23
22
|
from dataclasses import dataclass, field
|
@@ -39,6 +38,7 @@ from flwr.common.constant import (
|
|
39
38
|
)
|
40
39
|
from flwr.common.record import ConfigRecord
|
41
40
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
41
|
+
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
42
42
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
43
43
|
from flwr.server.utils import validate_message
|
44
44
|
|
@@ -70,7 +70,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
70
70
|
def __init__(self) -> None:
|
71
71
|
|
72
72
|
# Map node_id to (online_until, heartbeat_interval)
|
73
|
-
self.
|
73
|
+
self.nodes: dict[int, NodeInfo] = {}
|
74
74
|
self.public_key_to_node_id: dict[bytes, int] = {}
|
75
75
|
self.node_id_to_public_key: dict[int, bytes] = {}
|
76
76
|
|
@@ -114,7 +114,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
114
114
|
)
|
115
115
|
return None
|
116
116
|
# Validate destination node ID
|
117
|
-
if message.metadata.dst_node_id not in self.
|
117
|
+
if message.metadata.dst_node_id not in self.nodes:
|
118
118
|
log(
|
119
119
|
ERROR,
|
120
120
|
"Invalid destination node ID for Message: %s",
|
@@ -136,7 +136,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
136
136
|
|
137
137
|
# Find Message for node_id that were not delivered yet
|
138
138
|
message_ins_list: list[Message] = []
|
139
|
-
current_time =
|
139
|
+
current_time = now().timestamp()
|
140
140
|
with self.lock:
|
141
141
|
for _, msg_ins in self.message_ins_store.items():
|
142
142
|
if (
|
@@ -190,7 +190,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
190
190
|
return None
|
191
191
|
|
192
192
|
ins_metadata = msg_ins.metadata
|
193
|
-
if ins_metadata.created_at + ins_metadata.ttl <=
|
193
|
+
if ins_metadata.created_at + ins_metadata.ttl <= now().timestamp():
|
194
194
|
log(
|
195
195
|
ERROR,
|
196
196
|
"Failed to store Message: the message it is replying to "
|
@@ -238,7 +238,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
238
238
|
ret: dict[str, Message] = {}
|
239
239
|
|
240
240
|
with self.lock:
|
241
|
-
current =
|
241
|
+
current = now().timestamp()
|
242
242
|
|
243
243
|
# Verify Message IDs
|
244
244
|
ret = verify_message_ids(
|
@@ -256,9 +256,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
256
256
|
inquired_in_message_ids=message_ids,
|
257
257
|
found_in_message_dict=self.message_ins_store,
|
258
258
|
node_id_to_online_until={
|
259
|
-
node_id: self.
|
259
|
+
node_id: self.nodes[node_id].online_until
|
260
260
|
for node_id in dst_node_ids
|
261
|
-
if node_id in self.
|
261
|
+
if node_id in self.nodes
|
262
262
|
},
|
263
263
|
current_time=current,
|
264
264
|
)
|
@@ -330,7 +330,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
330
330
|
"""
|
331
331
|
return len(self.message_res_store)
|
332
332
|
|
333
|
-
def create_node(self, heartbeat_interval: float) -> int:
|
333
|
+
def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
|
334
334
|
"""Create, store in the link state, and return `node_id`."""
|
335
335
|
# Sample a random int64 as node_id
|
336
336
|
node_id = generate_rand_int_from_bytes(
|
@@ -338,28 +338,40 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
338
338
|
)
|
339
339
|
|
340
340
|
with self.lock:
|
341
|
-
if node_id in self.
|
341
|
+
if node_id in self.nodes:
|
342
342
|
log(ERROR, "Unexpected node registration failure.")
|
343
343
|
return 0
|
344
|
+
if public_key in self.public_key_to_node_id:
|
345
|
+
raise ValueError("Public key already in use")
|
344
346
|
|
345
|
-
# Mark the node online until
|
346
|
-
|
347
|
-
|
348
|
-
|
347
|
+
# Mark the node online until now().timestamp() + heartbeat_interval
|
348
|
+
current = now()
|
349
|
+
self.nodes[node_id] = NodeInfo(
|
350
|
+
node_id=node_id,
|
351
|
+
owner_aid="", # Unused for now
|
352
|
+
status="created", # Unused for now
|
353
|
+
created_at=current.isoformat(), # Unused for now
|
354
|
+
last_activated_at=current.isoformat(), # Unused for now
|
355
|
+
last_deactivated_at="", # Unused for now
|
356
|
+
deleted_at="", # Unused for now
|
357
|
+
online_until=current.timestamp() + heartbeat_interval,
|
358
|
+
heartbeat_interval=heartbeat_interval,
|
349
359
|
)
|
360
|
+
self.public_key_to_node_id[public_key] = node_id
|
361
|
+
self.node_id_to_public_key[node_id] = public_key
|
350
362
|
return node_id
|
351
363
|
|
352
364
|
def delete_node(self, node_id: int) -> None:
|
353
365
|
"""Delete a node."""
|
354
366
|
with self.lock:
|
355
|
-
if node_id not in self.
|
367
|
+
if node_id not in self.nodes:
|
356
368
|
raise ValueError(f"Node {node_id} not found")
|
357
369
|
|
358
370
|
# Remove node ID <> public key mappings
|
359
371
|
if pk := self.node_id_to_public_key.pop(node_id, None):
|
360
372
|
del self.public_key_to_node_id[pk]
|
361
373
|
|
362
|
-
del self.
|
374
|
+
del self.nodes[node_id]
|
363
375
|
|
364
376
|
def get_nodes(self, run_id: int) -> set[int]:
|
365
377
|
"""Return all available nodes.
|
@@ -372,17 +384,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
372
384
|
with self.lock:
|
373
385
|
if run_id not in self.run_ids:
|
374
386
|
return set()
|
375
|
-
current_time =
|
387
|
+
current_time = now().timestamp()
|
376
388
|
return {
|
377
|
-
node_id
|
378
|
-
for
|
379
|
-
if online_until > current_time
|
389
|
+
info.node_id
|
390
|
+
for info in self.nodes.values()
|
391
|
+
if info.online_until > current_time
|
380
392
|
}
|
381
393
|
|
382
394
|
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
383
395
|
"""Set `public_key` for the specified `node_id`."""
|
384
396
|
with self.lock:
|
385
|
-
if node_id not in self.
|
397
|
+
if node_id not in self.nodes:
|
386
398
|
raise ValueError(f"Node {node_id} not found")
|
387
399
|
|
388
400
|
if public_key in self.public_key_to_node_id:
|
@@ -394,7 +406,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
394
406
|
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
395
407
|
"""Get `public_key` for the specified `node_id`."""
|
396
408
|
with self.lock:
|
397
|
-
if node_id not in self.
|
409
|
+
if node_id not in self.nodes:
|
398
410
|
raise ValueError(f"Node {node_id} not found")
|
399
411
|
|
400
412
|
return self.node_id_to_public_key.get(node_id)
|
@@ -608,13 +620,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
608
620
|
the node is marked as offline.
|
609
621
|
"""
|
610
622
|
with self.lock:
|
611
|
-
if
|
612
|
-
|
613
|
-
|
614
|
-
heartbeat_interval,
|
623
|
+
if info := self.nodes.get(node_id):
|
624
|
+
info.online_until = (
|
625
|
+
now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
|
615
626
|
)
|
627
|
+
info.heartbeat_interval = heartbeat_interval
|
616
628
|
return True
|
617
|
-
|
629
|
+
return False
|
618
630
|
|
619
631
|
def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
|
620
632
|
"""Acknowledge a heartbeat received from a ServerApp for a given run.
|
@@ -128,7 +128,7 @@ class LinkState(CoreState): # pylint: disable=R0904
|
|
128
128
|
"""Get all instruction Message IDs for the given run_id."""
|
129
129
|
|
130
130
|
@abc.abstractmethod
|
131
|
-
def create_node(self, heartbeat_interval: float) -> int:
|
131
|
+
def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
|
132
132
|
"""Create, store in the link state, and return `node_id`."""
|
133
133
|
|
134
134
|
@abc.abstractmethod
|
@@ -21,7 +21,6 @@ import json
|
|
21
21
|
import re
|
22
22
|
import secrets
|
23
23
|
import sqlite3
|
24
|
-
import time
|
25
24
|
from collections.abc import Sequence
|
26
25
|
from logging import DEBUG, ERROR, WARNING
|
27
26
|
from typing import Any, Optional, Union, cast
|
@@ -72,10 +71,16 @@ from .utils import (
|
|
72
71
|
|
73
72
|
SQL_CREATE_TABLE_NODE = """
|
74
73
|
CREATE TABLE IF NOT EXISTS node(
|
75
|
-
node_id
|
76
|
-
|
77
|
-
|
78
|
-
|
74
|
+
node_id INTEGER UNIQUE,
|
75
|
+
owner_aid TEXT,
|
76
|
+
status TEXT,
|
77
|
+
created_at TEXT,
|
78
|
+
last_activated_at TEXT,
|
79
|
+
last_deactivated_at TEXT,
|
80
|
+
deleted_at TEXT,
|
81
|
+
online_until REAL,
|
82
|
+
heartbeat_interval REAL,
|
83
|
+
public_key BLOB UNIQUE
|
79
84
|
);
|
80
85
|
"""
|
81
86
|
|
@@ -451,7 +456,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
451
456
|
ret: dict[str, Message] = {}
|
452
457
|
|
453
458
|
# Verify Message IDs
|
454
|
-
current =
|
459
|
+
current = now().timestamp()
|
455
460
|
query = f"""
|
456
461
|
SELECT *
|
457
462
|
FROM message_ins
|
@@ -597,7 +602,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
597
602
|
|
598
603
|
return {row["message_id"] for row in rows}
|
599
604
|
|
600
|
-
def create_node(self, heartbeat_interval: float) -> int:
|
605
|
+
def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
|
601
606
|
"""Create, store in the link state, and return `node_id`."""
|
602
607
|
# Sample a random uint64 as node_id
|
603
608
|
uint64_node_id = generate_rand_int_from_bytes(
|
@@ -607,24 +612,35 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
607
612
|
# Convert the uint64 value to sint64 for SQLite
|
608
613
|
sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
|
609
614
|
|
610
|
-
query =
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
+
query = """
|
616
|
+
INSERT INTO node
|
617
|
+
(node_id, owner_aid, status, created_at, last_activated_at,
|
618
|
+
last_deactivated_at, deleted_at, online_until, heartbeat_interval,
|
619
|
+
public_key)
|
620
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
621
|
+
"""
|
615
622
|
|
616
|
-
# Mark the node online
|
623
|
+
# Mark the node online until now().timestamp() + heartbeat_interval
|
617
624
|
try:
|
618
625
|
self.query(
|
619
626
|
query,
|
620
627
|
(
|
621
|
-
sint64_node_id,
|
622
|
-
|
623
|
-
|
624
|
-
|
628
|
+
sint64_node_id, # node_id
|
629
|
+
"", # owner_aid, unused for now
|
630
|
+
"created", # status, unused for now
|
631
|
+
now().isoformat(), # created_at, unused for now
|
632
|
+
now().isoformat(), # last_activated_at, unused for now
|
633
|
+
"", # last_deactivated_at, unused for now
|
634
|
+
"", # deleted_at, unused for now
|
635
|
+
now().timestamp() + heartbeat_interval, # online_until
|
636
|
+
heartbeat_interval, # heartbeat_interval
|
637
|
+
public_key, # public_key
|
625
638
|
),
|
626
639
|
)
|
627
|
-
except sqlite3.IntegrityError:
|
640
|
+
except sqlite3.IntegrityError as e:
|
641
|
+
if "UNIQUE constraint failed: node.public_key" in str(e):
|
642
|
+
raise ValueError("Public key already in use.") from None
|
643
|
+
# Must be node ID conflict, almost impossible unless system is compromised
|
628
644
|
log(ERROR, "Unexpected node registration failure.")
|
629
645
|
return 0
|
630
646
|
|
@@ -668,7 +684,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
668
684
|
|
669
685
|
# Get nodes
|
670
686
|
query = "SELECT node_id FROM node WHERE online_until > ?;"
|
671
|
-
rows = self.query(query, (
|
687
|
+
rows = self.query(query, (now().timestamp(),))
|
672
688
|
|
673
689
|
# Convert sint64 node_ids to uint64
|
674
690
|
result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
|
@@ -1010,7 +1026,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
1010
1026
|
self.query(
|
1011
1027
|
query,
|
1012
1028
|
(
|
1013
|
-
|
1029
|
+
now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
|
1014
1030
|
heartbeat_interval,
|
1015
1031
|
sint64_node_id,
|
1016
1032
|
),
|
@@ -1140,7 +1156,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
1140
1156
|
message_ins = rows[0]
|
1141
1157
|
created_at = message_ins["created_at"]
|
1142
1158
|
ttl = message_ins["ttl"]
|
1143
|
-
current_time =
|
1159
|
+
current_time = now().timestamp()
|
1144
1160
|
|
1145
1161
|
# Check if Message is expired
|
1146
1162
|
if ttl is not None and created_at + ttl <= current_time:
|
flwr/server/utils/validator.py
CHANGED
@@ -15,10 +15,9 @@
|
|
15
15
|
"""Validators."""
|
16
16
|
|
17
17
|
|
18
|
-
import time
|
19
|
-
|
20
18
|
from flwr.common import Message
|
21
19
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
20
|
+
from flwr.common.date import now
|
22
21
|
|
23
22
|
|
24
23
|
# pylint: disable-next=too-many-branches
|
@@ -44,7 +43,7 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
44
43
|
validation_errors.append("`metadata.ttl` must be higher than zero")
|
45
44
|
|
46
45
|
# Verify TTL and created_at time
|
47
|
-
current_time =
|
46
|
+
current_time = now().timestamp()
|
48
47
|
if metadata.created_at + metadata.ttl <= current_time:
|
49
48
|
validation_errors.append("Message TTL has expired")
|
50
49
|
|
@@ -15,12 +15,41 @@
|
|
15
15
|
"""Account auth plugin for ControlServicer."""
|
16
16
|
|
17
17
|
|
18
|
+
from flwr.common.constant import AuthnType, AuthzType
|
19
|
+
|
18
20
|
from .auth_plugin import ControlAuthnPlugin, ControlAuthzPlugin
|
19
21
|
from .noop_auth_plugin import NoOpControlAuthnPlugin, NoOpControlAuthzPlugin
|
20
22
|
|
23
|
+
try:
|
24
|
+
from flwr.ee import get_control_authn_ee_plugins, get_control_authz_ee_plugins
|
25
|
+
except ImportError:
|
26
|
+
|
27
|
+
def get_control_authn_ee_plugins() -> dict[str, type[ControlAuthnPlugin]]:
|
28
|
+
"""Return all Control API authentication plugins for EE."""
|
29
|
+
return {}
|
30
|
+
|
31
|
+
def get_control_authz_ee_plugins() -> dict[str, type[ControlAuthzPlugin]]:
|
32
|
+
"""Return all Control API authorization plugins for EE."""
|
33
|
+
return {}
|
34
|
+
|
35
|
+
|
36
|
+
def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
|
37
|
+
"""Return all Control API authentication plugins."""
|
38
|
+
ee_dict: dict[str, type[ControlAuthnPlugin]] = get_control_authn_ee_plugins()
|
39
|
+
return ee_dict | {AuthnType.NOOP: NoOpControlAuthnPlugin}
|
40
|
+
|
41
|
+
|
42
|
+
def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
|
43
|
+
"""Return all Control API authorization plugins."""
|
44
|
+
ee_dict: dict[str, type[ControlAuthzPlugin]] = get_control_authz_ee_plugins()
|
45
|
+
return ee_dict | {AuthzType.NOOP: NoOpControlAuthzPlugin}
|
46
|
+
|
47
|
+
|
21
48
|
__all__ = [
|
22
49
|
"ControlAuthnPlugin",
|
23
50
|
"ControlAuthzPlugin",
|
24
51
|
"NoOpControlAuthnPlugin",
|
25
52
|
"NoOpControlAuthzPlugin",
|
53
|
+
"get_control_authn_plugins",
|
54
|
+
"get_control_authz_plugins",
|
26
55
|
]
|
@@ -31,7 +31,11 @@ from flwr.supercore.ffs import FfsFactory
|
|
31
31
|
from flwr.supercore.license_plugin import LicensePlugin
|
32
32
|
from flwr.supercore.object_store import ObjectStoreFactory
|
33
33
|
from flwr.superlink.artifact_provider import ArtifactProvider
|
34
|
-
from flwr.superlink.auth_plugin import
|
34
|
+
from flwr.superlink.auth_plugin import (
|
35
|
+
ControlAuthnPlugin,
|
36
|
+
ControlAuthzPlugin,
|
37
|
+
NoOpControlAuthnPlugin,
|
38
|
+
)
|
35
39
|
|
36
40
|
from .control_account_auth_interceptor import ControlAccountAuthInterceptor
|
37
41
|
from .control_event_log_interceptor import ControlEventLogInterceptor
|
@@ -54,8 +58,8 @@ def run_control_api_grpc(
|
|
54
58
|
objectstore_factory: ObjectStoreFactory,
|
55
59
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
56
60
|
is_simulation: bool,
|
57
|
-
authn_plugin:
|
58
|
-
authz_plugin:
|
61
|
+
authn_plugin: ControlAuthnPlugin,
|
62
|
+
authz_plugin: ControlAuthzPlugin,
|
59
63
|
event_log_plugin: Optional[EventLogWriterPlugin] = None,
|
60
64
|
artifact_provider: Optional[ArtifactProvider] = None,
|
61
65
|
) -> grpc.Server:
|
@@ -72,11 +76,9 @@ def run_control_api_grpc(
|
|
72
76
|
authn_plugin=authn_plugin,
|
73
77
|
artifact_provider=artifact_provider,
|
74
78
|
)
|
75
|
-
interceptors
|
79
|
+
interceptors = [ControlAccountAuthInterceptor(authn_plugin, authz_plugin)]
|
76
80
|
if license_plugin is not None:
|
77
81
|
interceptors.append(ControlLicenseInterceptor(license_plugin))
|
78
|
-
if authn_plugin is not None and authz_plugin is not None:
|
79
|
-
interceptors.append(ControlAccountAuthInterceptor(authn_plugin, authz_plugin))
|
80
82
|
# Event log interceptor must be added after account auth interceptor
|
81
83
|
if event_log_plugin is not None:
|
82
84
|
interceptors.append(ControlEventLogInterceptor(event_log_plugin))
|
@@ -90,7 +92,7 @@ def run_control_api_grpc(
|
|
90
92
|
interceptors=interceptors or None,
|
91
93
|
)
|
92
94
|
|
93
|
-
if authn_plugin
|
95
|
+
if isinstance(authn_plugin, NoOpControlAuthnPlugin):
|
94
96
|
log(INFO, "Flower Deployment Runtime: Starting Control API on %s", address)
|
95
97
|
else:
|
96
98
|
log(
|