flwr-nightly 1.23.0.dev20251006__py3-none-any.whl → 1.23.0.dev20251008__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/auth_plugin/__init__.py +7 -3
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +4 -13
- flwr/cli/ls.py +2 -2
- flwr/cli/pull.py +2 -2
- flwr/cli/run/run.py +2 -2
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/ls.py +2 -2
- flwr/cli/utils.py +28 -44
- flwr/client/grpc_rere_client/connection.py +6 -6
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
- flwr/client/rest_client/connection.py +7 -1
- flwr/common/constant.py +10 -0
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/proto/fleet_pb2.py +22 -22
- flwr/proto/fleet_pb2.pyi +4 -1
- flwr/proto/node_pb2.py +1 -1
- flwr/server/app.py +33 -34
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +8 -4
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +19 -41
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +7 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +39 -27
- flwr/server/superlink/linkstate/linkstate.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +37 -21
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
- flwr/supercore/primitives/__init__.py +15 -0
- flwr/supercore/primitives/asymmetric.py +109 -0
- flwr/superlink/auth_plugin/__init__.py +29 -0
- flwr/superlink/servicer/control/control_grpc.py +9 -7
- flwr/superlink/servicer/control/control_servicer.py +34 -46
- {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/METADATA +1 -1
- {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/RECORD +37 -35
- {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/entry_points.txt +0 -0
flwr/server/app.py
CHANGED
@@ -26,7 +26,7 @@ from collections.abc import Sequence
|
|
26
26
|
from logging import DEBUG, INFO, WARN
|
27
27
|
from pathlib import Path
|
28
28
|
from time import sleep
|
29
|
-
from typing import
|
29
|
+
from typing import Callable, Optional, TypeVar, cast
|
30
30
|
|
31
31
|
import grpc
|
32
32
|
import yaml
|
@@ -52,6 +52,8 @@ from flwr.common.constant import (
|
|
52
52
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
53
53
|
TRANSPORT_TYPE_GRPC_RERE,
|
54
54
|
TRANSPORT_TYPE_REST,
|
55
|
+
AuthnType,
|
56
|
+
AuthzType,
|
55
57
|
EventLogWriterType,
|
56
58
|
ExecPluginType,
|
57
59
|
)
|
@@ -59,9 +61,6 @@ from flwr.common.event_log_plugin import EventLogWriterPlugin
|
|
59
61
|
from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
|
60
62
|
from flwr.common.grpc import generic_create_grpc_server
|
61
63
|
from flwr.common.logger import log
|
62
|
-
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
63
|
-
public_key_to_bytes,
|
64
|
-
)
|
65
64
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
66
65
|
add_FleetServicer_to_server,
|
67
66
|
)
|
@@ -70,13 +69,21 @@ from flwr.server.fleet_event_log_interceptor import FleetEventLogInterceptor
|
|
70
69
|
from flwr.supercore.ffs import FfsFactory
|
71
70
|
from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_no_tls
|
72
71
|
from flwr.supercore.object_store import ObjectStoreFactory
|
72
|
+
from flwr.supercore.primitives.asymmetric import public_key_to_bytes
|
73
73
|
from flwr.superlink.artifact_provider import ArtifactProvider
|
74
|
-
from flwr.superlink.auth_plugin import
|
74
|
+
from flwr.superlink.auth_plugin import (
|
75
|
+
ControlAuthnPlugin,
|
76
|
+
ControlAuthzPlugin,
|
77
|
+
get_control_authn_plugins,
|
78
|
+
get_control_authz_plugins,
|
79
|
+
)
|
75
80
|
from flwr.superlink.servicer.control import run_control_api_grpc
|
76
81
|
|
77
82
|
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
78
83
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
79
|
-
from .superlink.fleet.grpc_rere.
|
84
|
+
from .superlink.fleet.grpc_rere.node_auth_server_interceptor import (
|
85
|
+
NodeAuthServerInterceptor,
|
86
|
+
)
|
80
87
|
from .superlink.linkstate import LinkStateFactory
|
81
88
|
from .superlink.serverappio.serverappio_grpc import run_serverappio_api_grpc
|
82
89
|
from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
|
@@ -89,8 +96,6 @@ P = TypeVar("P", ControlAuthnPlugin, ControlAuthzPlugin)
|
|
89
96
|
try:
|
90
97
|
from flwr.ee import (
|
91
98
|
add_ee_args_superlink,
|
92
|
-
get_control_authn_plugins,
|
93
|
-
get_control_authz_plugins,
|
94
99
|
get_control_event_log_writer_plugins,
|
95
100
|
get_ee_artifact_provider,
|
96
101
|
get_fleet_event_log_writer_plugins,
|
@@ -101,14 +106,6 @@ except ImportError:
|
|
101
106
|
def add_ee_args_superlink(parser: argparse.ArgumentParser) -> None:
|
102
107
|
"""Add EE-specific arguments to the parser."""
|
103
108
|
|
104
|
-
def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
|
105
|
-
"""Return all Control API authentication plugins."""
|
106
|
-
raise NotImplementedError("No authentication plugins are currently supported.")
|
107
|
-
|
108
|
-
def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
|
109
|
-
"""Return all Control API authorization plugins."""
|
110
|
-
raise NotImplementedError("No authorization plugins are currently supported.")
|
111
|
-
|
112
109
|
def get_control_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
|
113
110
|
"""Return all Control API event log writer plugins."""
|
114
111
|
raise NotImplementedError(
|
@@ -204,10 +201,9 @@ def run_superlink() -> None:
|
|
204
201
|
"future release. Please use `--account-auth-config` instead.",
|
205
202
|
)
|
206
203
|
args.account_auth_config = cfg_path
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
)
|
204
|
+
cfg_path = getattr(args, "account_auth_config", None)
|
205
|
+
authn_plugin, authz_plugin = _load_control_auth_plugins(cfg_path, verify_tls_cert)
|
206
|
+
if cfg_path is not None:
|
211
207
|
# Enable event logging if the args.enable_event_log is True
|
212
208
|
if args.enable_event_log:
|
213
209
|
event_log_plugin = _try_obtain_control_event_log_writer_plugin()
|
@@ -328,7 +324,7 @@ def run_superlink() -> None:
|
|
328
324
|
else:
|
329
325
|
log(DEBUG, "Automatic node authentication enabled")
|
330
326
|
|
331
|
-
interceptors = [
|
327
|
+
interceptors = [NodeAuthServerInterceptor(state_factory, auto_auth)]
|
332
328
|
if getattr(args, "enable_event_log", None):
|
333
329
|
fleet_log_plugin = _try_obtain_fleet_event_log_writer_plugin()
|
334
330
|
if fleet_log_plugin is not None:
|
@@ -449,13 +445,21 @@ def _try_load_public_keys_node_authentication(
|
|
449
445
|
return node_public_keys
|
450
446
|
|
451
447
|
|
452
|
-
def
|
453
|
-
config_path:
|
448
|
+
def _load_control_auth_plugins(
|
449
|
+
config_path: Optional[str], verify_tls_cert: bool
|
454
450
|
) -> tuple[ControlAuthnPlugin, ControlAuthzPlugin]:
|
455
451
|
"""Obtain Control API authentication and authorization plugins."""
|
452
|
+
# Load NoOp plugins if no config path is provided
|
453
|
+
if config_path is None:
|
454
|
+
config_path = ""
|
455
|
+
config = {
|
456
|
+
"authentication": {AUTHN_TYPE_YAML_KEY: AuthnType.NOOP},
|
457
|
+
"authorization": {AUTHZ_TYPE_YAML_KEY: AuthzType.NOOP},
|
458
|
+
}
|
456
459
|
# Load YAML file
|
457
|
-
|
458
|
-
|
460
|
+
else:
|
461
|
+
with Path(config_path).open("r", encoding="utf-8") as file:
|
462
|
+
config = yaml.safe_load(file)
|
459
463
|
|
460
464
|
def _load_plugin(
|
461
465
|
section: str, yaml_key: str, loader: Callable[[], dict[str, type[P]]]
|
@@ -465,9 +469,7 @@ def _try_obtain_control_auth_plugins(
|
|
465
469
|
try:
|
466
470
|
plugins: dict[str, type[P]] = loader()
|
467
471
|
plugin_cls: type[P] = plugins[auth_plugin_name]
|
468
|
-
return plugin_cls(
|
469
|
-
account_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
|
470
|
-
)
|
472
|
+
return plugin_cls(Path(cast(str, config_path)), verify_tls_cert)
|
471
473
|
except KeyError:
|
472
474
|
if auth_plugin_name:
|
473
475
|
sys.exit(
|
@@ -475,18 +477,15 @@ def _try_obtain_control_auth_plugins(
|
|
475
477
|
f"Please provide a valid {section} type in the configuration."
|
476
478
|
)
|
477
479
|
sys.exit(f"No {section} type is provided in the configuration.")
|
478
|
-
except NotImplementedError:
|
479
|
-
sys.exit(f"No {section} plugins are currently supported.")
|
480
480
|
|
481
|
-
# Warn deprecated
|
482
|
-
if
|
481
|
+
# Warn deprecated auth_type key
|
482
|
+
if authn_type := config["authentication"].pop("auth_type", None):
|
483
483
|
log(
|
484
484
|
WARN,
|
485
|
-
"The `
|
485
|
+
"The `auth_type` key in the authentication configuration is deprecated. "
|
486
486
|
"Use `%s` instead.",
|
487
487
|
AUTHN_TYPE_YAML_KEY,
|
488
488
|
)
|
489
|
-
authn_type = config["authentication"].pop("authn_type")
|
490
489
|
config["authentication"][AUTHN_TYPE_YAML_KEY] = authn_type
|
491
490
|
|
492
491
|
# Load authentication plugin
|
@@ -78,10 +78,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
78
78
|
request.heartbeat_interval,
|
79
79
|
)
|
80
80
|
log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
81
|
+
try:
|
82
|
+
response = message_handler.create_node(
|
83
|
+
request=request,
|
84
|
+
state=self.state_factory.state(),
|
85
|
+
)
|
86
|
+
except ValueError as e:
|
87
|
+
# Public key already in use
|
88
|
+
context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
|
85
89
|
log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
|
86
90
|
log(DEBUG, "[Fleet.CreateNode] Response: %s", MessageToDict(response))
|
87
91
|
return response
|
flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py}
RENAMED
@@ -16,7 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import datetime
|
19
|
-
from typing import Any, Callable,
|
19
|
+
from typing import Any, Callable, cast
|
20
20
|
|
21
21
|
import grpc
|
22
22
|
from google.protobuf.message import Message as GrpcMessage
|
@@ -29,15 +29,9 @@ from flwr.common.constant import (
|
|
29
29
|
TIMESTAMP_HEADER,
|
30
30
|
TIMESTAMP_TOLERANCE,
|
31
31
|
)
|
32
|
-
from flwr.
|
33
|
-
bytes_to_public_key,
|
34
|
-
verify_signature,
|
35
|
-
)
|
36
|
-
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
37
|
-
CreateNodeRequest,
|
38
|
-
CreateNodeResponse,
|
39
|
-
)
|
32
|
+
from flwr.proto.fleet_pb2 import CreateNodeRequest # pylint: disable=E0611
|
40
33
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
34
|
+
from flwr.supercore.primitives.asymmetric import bytes_to_public_key, verify_signature
|
41
35
|
|
42
36
|
MIN_TIMESTAMP_DIFF = -SYSTEM_TIME_TOLERANCE
|
43
37
|
MAX_TIMESTAMP_DIFF = TIMESTAMP_TOLERANCE + SYSTEM_TIME_TOLERANCE
|
@@ -53,7 +47,7 @@ def _unary_unary_rpc_terminator(
|
|
53
47
|
return grpc.unary_unary_rpc_method_handler(terminate)
|
54
48
|
|
55
49
|
|
56
|
-
class
|
50
|
+
class NodeAuthServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
57
51
|
"""Server interceptor for node authentication.
|
58
52
|
|
59
53
|
Parameters
|
@@ -113,50 +107,34 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
113
107
|
if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
|
114
108
|
return _unary_unary_rpc_terminator("Invalid timestamp")
|
115
109
|
|
116
|
-
# Continue the RPC call
|
117
|
-
expected_node_id = state.get_node_id(node_pk_bytes)
|
118
|
-
if not handler_call_details.method.endswith("CreateNode"):
|
119
|
-
# All calls, except for `CreateNode`, must provide a public key that is
|
120
|
-
# already mapped to a `node_id` (in `LinkState`)
|
121
|
-
if expected_node_id is None:
|
122
|
-
return _unary_unary_rpc_terminator("Invalid node ID")
|
123
|
-
# One of the method handlers in
|
110
|
+
# Continue the RPC call: One of the method handlers in
|
124
111
|
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
125
112
|
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
126
|
-
return self._wrap_method_handler(
|
127
|
-
method_handler, expected_node_id, node_pk_bytes
|
128
|
-
)
|
113
|
+
return self._wrap_method_handler(method_handler, node_pk_bytes)
|
129
114
|
|
130
115
|
def _wrap_method_handler(
|
131
116
|
self,
|
132
117
|
method_handler: grpc.RpcMethodHandler,
|
133
|
-
|
134
|
-
node_public_key: bytes,
|
118
|
+
expected_public_key: bytes,
|
135
119
|
) -> grpc.RpcMethodHandler:
|
136
120
|
def _generic_method_handler(
|
137
121
|
request: GrpcMessage,
|
138
122
|
context: grpc.ServicerContext,
|
139
123
|
) -> GrpcMessage:
|
140
|
-
#
|
141
|
-
if
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
124
|
+
# Retrieve the public key
|
125
|
+
if isinstance(request, CreateNodeRequest):
|
126
|
+
actual_public_key = request.public_key
|
127
|
+
else:
|
128
|
+
# Note: This function runs in a different thread
|
129
|
+
# than the `intercept_service` function.
|
130
|
+
actual_public_key = self.state_factory.state().get_node_public_key(
|
131
|
+
request.node.node_id # type: ignore
|
132
|
+
)
|
133
|
+
# Verify the public key
|
134
|
+
if actual_public_key != expected_public_key:
|
135
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
|
147
136
|
|
148
137
|
response: GrpcMessage = method_handler.unary_unary(request, context)
|
149
|
-
|
150
|
-
# Set the public key after a successful CreateNode request
|
151
|
-
if isinstance(response, CreateNodeResponse):
|
152
|
-
state = self.state_factory.state()
|
153
|
-
try:
|
154
|
-
state.set_node_public_key(response.node.node_id, node_public_key)
|
155
|
-
except ValueError as e:
|
156
|
-
# Remove newly created node if setting the public key fails
|
157
|
-
state.delete_node(response.node.node_id)
|
158
|
-
context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
|
159
|
-
|
160
138
|
return response
|
161
139
|
|
162
140
|
return grpc.unary_unary_rpc_method_handler(
|
@@ -70,7 +70,7 @@ def create_node(
|
|
70
70
|
) -> CreateNodeResponse:
|
71
71
|
"""."""
|
72
72
|
# Create node
|
73
|
-
node_id = state.create_node(
|
73
|
+
node_id = state.create_node(request.public_key, request.heartbeat_interval)
|
74
74
|
return CreateNodeResponse(node=Node(node_id=node_id))
|
75
75
|
|
76
76
|
|
@@ -16,6 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import json
|
19
|
+
import secrets
|
19
20
|
import threading
|
20
21
|
import time
|
21
22
|
import traceback
|
@@ -53,7 +54,12 @@ def _register_nodes(
|
|
53
54
|
nodes_mapping: NodeToPartitionMapping = {}
|
54
55
|
state = state_factory.state()
|
55
56
|
for i in range(num_nodes):
|
56
|
-
node_id = state.create_node(
|
57
|
+
node_id = state.create_node(
|
58
|
+
# No node authentication in simulation;
|
59
|
+
# use random bytes instead
|
60
|
+
secrets.token_bytes(32),
|
61
|
+
heartbeat_interval=HEARTBEAT_MAX_INTERVAL,
|
62
|
+
)
|
57
63
|
nodes_mapping[node_id] = i
|
58
64
|
log(DEBUG, "Registered %i nodes", len(nodes_mapping))
|
59
65
|
return nodes_mapping
|
@@ -17,7 +17,6 @@
|
|
17
17
|
|
18
18
|
import secrets
|
19
19
|
import threading
|
20
|
-
import time
|
21
20
|
from bisect import bisect_right
|
22
21
|
from collections import defaultdict
|
23
22
|
from dataclasses import dataclass, field
|
@@ -39,6 +38,7 @@ from flwr.common.constant import (
|
|
39
38
|
)
|
40
39
|
from flwr.common.record import ConfigRecord
|
41
40
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
41
|
+
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
42
42
|
from flwr.server.superlink.linkstate.linkstate import LinkState
|
43
43
|
from flwr.server.utils import validate_message
|
44
44
|
|
@@ -70,7 +70,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
70
70
|
def __init__(self) -> None:
|
71
71
|
|
72
72
|
# Map node_id to (online_until, heartbeat_interval)
|
73
|
-
self.
|
73
|
+
self.nodes: dict[int, NodeInfo] = {}
|
74
74
|
self.public_key_to_node_id: dict[bytes, int] = {}
|
75
75
|
self.node_id_to_public_key: dict[int, bytes] = {}
|
76
76
|
|
@@ -114,7 +114,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
114
114
|
)
|
115
115
|
return None
|
116
116
|
# Validate destination node ID
|
117
|
-
if message.metadata.dst_node_id not in self.
|
117
|
+
if message.metadata.dst_node_id not in self.nodes:
|
118
118
|
log(
|
119
119
|
ERROR,
|
120
120
|
"Invalid destination node ID for Message: %s",
|
@@ -136,7 +136,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
136
136
|
|
137
137
|
# Find Message for node_id that were not delivered yet
|
138
138
|
message_ins_list: list[Message] = []
|
139
|
-
current_time =
|
139
|
+
current_time = now().timestamp()
|
140
140
|
with self.lock:
|
141
141
|
for _, msg_ins in self.message_ins_store.items():
|
142
142
|
if (
|
@@ -190,7 +190,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
190
190
|
return None
|
191
191
|
|
192
192
|
ins_metadata = msg_ins.metadata
|
193
|
-
if ins_metadata.created_at + ins_metadata.ttl <=
|
193
|
+
if ins_metadata.created_at + ins_metadata.ttl <= now().timestamp():
|
194
194
|
log(
|
195
195
|
ERROR,
|
196
196
|
"Failed to store Message: the message it is replying to "
|
@@ -238,7 +238,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
238
238
|
ret: dict[str, Message] = {}
|
239
239
|
|
240
240
|
with self.lock:
|
241
|
-
current =
|
241
|
+
current = now().timestamp()
|
242
242
|
|
243
243
|
# Verify Message IDs
|
244
244
|
ret = verify_message_ids(
|
@@ -256,9 +256,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
256
256
|
inquired_in_message_ids=message_ids,
|
257
257
|
found_in_message_dict=self.message_ins_store,
|
258
258
|
node_id_to_online_until={
|
259
|
-
node_id: self.
|
259
|
+
node_id: self.nodes[node_id].online_until
|
260
260
|
for node_id in dst_node_ids
|
261
|
-
if node_id in self.
|
261
|
+
if node_id in self.nodes
|
262
262
|
},
|
263
263
|
current_time=current,
|
264
264
|
)
|
@@ -330,7 +330,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
330
330
|
"""
|
331
331
|
return len(self.message_res_store)
|
332
332
|
|
333
|
-
def create_node(self, heartbeat_interval: float) -> int:
|
333
|
+
def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
|
334
334
|
"""Create, store in the link state, and return `node_id`."""
|
335
335
|
# Sample a random int64 as node_id
|
336
336
|
node_id = generate_rand_int_from_bytes(
|
@@ -338,28 +338,40 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
338
338
|
)
|
339
339
|
|
340
340
|
with self.lock:
|
341
|
-
if node_id in self.
|
341
|
+
if node_id in self.nodes:
|
342
342
|
log(ERROR, "Unexpected node registration failure.")
|
343
343
|
return 0
|
344
|
+
if public_key in self.public_key_to_node_id:
|
345
|
+
raise ValueError("Public key already in use")
|
344
346
|
|
345
|
-
# Mark the node online until
|
346
|
-
|
347
|
-
|
348
|
-
|
347
|
+
# Mark the node online until now().timestamp() + heartbeat_interval
|
348
|
+
current = now()
|
349
|
+
self.nodes[node_id] = NodeInfo(
|
350
|
+
node_id=node_id,
|
351
|
+
owner_aid="", # Unused for now
|
352
|
+
status="created", # Unused for now
|
353
|
+
created_at=current.isoformat(), # Unused for now
|
354
|
+
last_activated_at=current.isoformat(), # Unused for now
|
355
|
+
last_deactivated_at="", # Unused for now
|
356
|
+
deleted_at="", # Unused for now
|
357
|
+
online_until=current.timestamp() + heartbeat_interval,
|
358
|
+
heartbeat_interval=heartbeat_interval,
|
349
359
|
)
|
360
|
+
self.public_key_to_node_id[public_key] = node_id
|
361
|
+
self.node_id_to_public_key[node_id] = public_key
|
350
362
|
return node_id
|
351
363
|
|
352
364
|
def delete_node(self, node_id: int) -> None:
|
353
365
|
"""Delete a node."""
|
354
366
|
with self.lock:
|
355
|
-
if node_id not in self.
|
367
|
+
if node_id not in self.nodes:
|
356
368
|
raise ValueError(f"Node {node_id} not found")
|
357
369
|
|
358
370
|
# Remove node ID <> public key mappings
|
359
371
|
if pk := self.node_id_to_public_key.pop(node_id, None):
|
360
372
|
del self.public_key_to_node_id[pk]
|
361
373
|
|
362
|
-
del self.
|
374
|
+
del self.nodes[node_id]
|
363
375
|
|
364
376
|
def get_nodes(self, run_id: int) -> set[int]:
|
365
377
|
"""Return all available nodes.
|
@@ -372,17 +384,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
372
384
|
with self.lock:
|
373
385
|
if run_id not in self.run_ids:
|
374
386
|
return set()
|
375
|
-
current_time =
|
387
|
+
current_time = now().timestamp()
|
376
388
|
return {
|
377
|
-
node_id
|
378
|
-
for
|
379
|
-
if online_until > current_time
|
389
|
+
info.node_id
|
390
|
+
for info in self.nodes.values()
|
391
|
+
if info.online_until > current_time
|
380
392
|
}
|
381
393
|
|
382
394
|
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
383
395
|
"""Set `public_key` for the specified `node_id`."""
|
384
396
|
with self.lock:
|
385
|
-
if node_id not in self.
|
397
|
+
if node_id not in self.nodes:
|
386
398
|
raise ValueError(f"Node {node_id} not found")
|
387
399
|
|
388
400
|
if public_key in self.public_key_to_node_id:
|
@@ -394,7 +406,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
394
406
|
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
395
407
|
"""Get `public_key` for the specified `node_id`."""
|
396
408
|
with self.lock:
|
397
|
-
if node_id not in self.
|
409
|
+
if node_id not in self.nodes:
|
398
410
|
raise ValueError(f"Node {node_id} not found")
|
399
411
|
|
400
412
|
return self.node_id_to_public_key.get(node_id)
|
@@ -608,13 +620,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
608
620
|
the node is marked as offline.
|
609
621
|
"""
|
610
622
|
with self.lock:
|
611
|
-
if
|
612
|
-
|
613
|
-
|
614
|
-
heartbeat_interval,
|
623
|
+
if info := self.nodes.get(node_id):
|
624
|
+
info.online_until = (
|
625
|
+
now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
|
615
626
|
)
|
627
|
+
info.heartbeat_interval = heartbeat_interval
|
616
628
|
return True
|
617
|
-
|
629
|
+
return False
|
618
630
|
|
619
631
|
def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
|
620
632
|
"""Acknowledge a heartbeat received from a ServerApp for a given run.
|
@@ -128,7 +128,7 @@ class LinkState(CoreState): # pylint: disable=R0904
|
|
128
128
|
"""Get all instruction Message IDs for the given run_id."""
|
129
129
|
|
130
130
|
@abc.abstractmethod
|
131
|
-
def create_node(self, heartbeat_interval: float) -> int:
|
131
|
+
def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
|
132
132
|
"""Create, store in the link state, and return `node_id`."""
|
133
133
|
|
134
134
|
@abc.abstractmethod
|
@@ -21,7 +21,6 @@ import json
|
|
21
21
|
import re
|
22
22
|
import secrets
|
23
23
|
import sqlite3
|
24
|
-
import time
|
25
24
|
from collections.abc import Sequence
|
26
25
|
from logging import DEBUG, ERROR, WARNING
|
27
26
|
from typing import Any, Optional, Union, cast
|
@@ -72,10 +71,16 @@ from .utils import (
|
|
72
71
|
|
73
72
|
SQL_CREATE_TABLE_NODE = """
|
74
73
|
CREATE TABLE IF NOT EXISTS node(
|
75
|
-
node_id
|
76
|
-
|
77
|
-
|
78
|
-
|
74
|
+
node_id INTEGER UNIQUE,
|
75
|
+
owner_aid TEXT,
|
76
|
+
status TEXT,
|
77
|
+
created_at TEXT,
|
78
|
+
last_activated_at TEXT,
|
79
|
+
last_deactivated_at TEXT,
|
80
|
+
deleted_at TEXT,
|
81
|
+
online_until REAL,
|
82
|
+
heartbeat_interval REAL,
|
83
|
+
public_key BLOB UNIQUE
|
79
84
|
);
|
80
85
|
"""
|
81
86
|
|
@@ -451,7 +456,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
451
456
|
ret: dict[str, Message] = {}
|
452
457
|
|
453
458
|
# Verify Message IDs
|
454
|
-
current =
|
459
|
+
current = now().timestamp()
|
455
460
|
query = f"""
|
456
461
|
SELECT *
|
457
462
|
FROM message_ins
|
@@ -597,7 +602,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
597
602
|
|
598
603
|
return {row["message_id"] for row in rows}
|
599
604
|
|
600
|
-
def create_node(self, heartbeat_interval: float) -> int:
|
605
|
+
def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
|
601
606
|
"""Create, store in the link state, and return `node_id`."""
|
602
607
|
# Sample a random uint64 as node_id
|
603
608
|
uint64_node_id = generate_rand_int_from_bytes(
|
@@ -607,24 +612,35 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
607
612
|
# Convert the uint64 value to sint64 for SQLite
|
608
613
|
sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
|
609
614
|
|
610
|
-
query =
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
+
query = """
|
616
|
+
INSERT INTO node
|
617
|
+
(node_id, owner_aid, status, created_at, last_activated_at,
|
618
|
+
last_deactivated_at, deleted_at, online_until, heartbeat_interval,
|
619
|
+
public_key)
|
620
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
621
|
+
"""
|
615
622
|
|
616
|
-
# Mark the node online
|
623
|
+
# Mark the node online until now().timestamp() + heartbeat_interval
|
617
624
|
try:
|
618
625
|
self.query(
|
619
626
|
query,
|
620
627
|
(
|
621
|
-
sint64_node_id,
|
622
|
-
|
623
|
-
|
624
|
-
|
628
|
+
sint64_node_id, # node_id
|
629
|
+
"", # owner_aid, unused for now
|
630
|
+
"created", # status, unused for now
|
631
|
+
now().isoformat(), # created_at, unused for now
|
632
|
+
now().isoformat(), # last_activated_at, unused for now
|
633
|
+
"", # last_deactivated_at, unused for now
|
634
|
+
"", # deleted_at, unused for now
|
635
|
+
now().timestamp() + heartbeat_interval, # online_until
|
636
|
+
heartbeat_interval, # heartbeat_interval
|
637
|
+
public_key, # public_key
|
625
638
|
),
|
626
639
|
)
|
627
|
-
except sqlite3.IntegrityError:
|
640
|
+
except sqlite3.IntegrityError as e:
|
641
|
+
if "UNIQUE constraint failed: node.public_key" in str(e):
|
642
|
+
raise ValueError("Public key already in use.") from None
|
643
|
+
# Must be node ID conflict, almost impossible unless system is compromised
|
628
644
|
log(ERROR, "Unexpected node registration failure.")
|
629
645
|
return 0
|
630
646
|
|
@@ -668,7 +684,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
668
684
|
|
669
685
|
# Get nodes
|
670
686
|
query = "SELECT node_id FROM node WHERE online_until > ?;"
|
671
|
-
rows = self.query(query, (
|
687
|
+
rows = self.query(query, (now().timestamp(),))
|
672
688
|
|
673
689
|
# Convert sint64 node_ids to uint64
|
674
690
|
result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
|
@@ -1010,7 +1026,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
1010
1026
|
self.query(
|
1011
1027
|
query,
|
1012
1028
|
(
|
1013
|
-
|
1029
|
+
now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
|
1014
1030
|
heartbeat_interval,
|
1015
1031
|
sint64_node_id,
|
1016
1032
|
),
|
@@ -1140,7 +1156,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
1140
1156
|
message_ins = rows[0]
|
1141
1157
|
created_at = message_ins["created_at"]
|
1142
1158
|
ttl = message_ins["ttl"]
|
1143
|
-
current_time =
|
1159
|
+
current_time = now().timestamp()
|
1144
1160
|
|
1145
1161
|
# Check if Message is expired
|
1146
1162
|
if ttl is not None and created_at + ttl <= current_time:
|
flwr/server/utils/validator.py
CHANGED
@@ -15,10 +15,9 @@
|
|
15
15
|
"""Validators."""
|
16
16
|
|
17
17
|
|
18
|
-
import time
|
19
|
-
|
20
18
|
from flwr.common import Message
|
21
19
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
20
|
+
from flwr.common.date import now
|
22
21
|
|
23
22
|
|
24
23
|
# pylint: disable-next=too-many-branches
|
@@ -44,7 +43,7 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
44
43
|
validation_errors.append("`metadata.ttl` must be higher than zero")
|
45
44
|
|
46
45
|
# Verify TTL and created_at time
|
47
|
-
current_time =
|
46
|
+
current_time = now().timestamp()
|
48
47
|
if metadata.created_at + metadata.ttl <= current_time:
|
49
48
|
validation_errors.append("Message TTL has expired")
|
50
49
|
|
@@ -35,8 +35,6 @@ from flwr.common import (
|
|
35
35
|
)
|
36
36
|
from flwr.common.secure_aggregation.crypto.shamir import combine_shares
|
37
37
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
38
|
-
bytes_to_private_key,
|
39
|
-
bytes_to_public_key,
|
40
38
|
generate_shared_key,
|
41
39
|
)
|
42
40
|
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
|
@@ -56,6 +54,10 @@ from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
|
|
56
54
|
from flwr.server.client_proxy import ClientProxy
|
57
55
|
from flwr.server.compat.legacy_context import LegacyContext
|
58
56
|
from flwr.server.grid import Grid
|
57
|
+
from flwr.supercore.primitives.asymmetric import (
|
58
|
+
bytes_to_private_key,
|
59
|
+
bytes_to_public_key,
|
60
|
+
)
|
59
61
|
|
60
62
|
from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
|
61
63
|
from ..constant import Key as WorkflowKey
|