flwr-nightly 1.23.0.dev20251006__py3-none-any.whl → 1.23.0.dev20251008__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/auth_plugin/__init__.py +7 -3
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +4 -13
- flwr/cli/ls.py +2 -2
- flwr/cli/pull.py +2 -2
- flwr/cli/run/run.py +2 -2
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/ls.py +2 -2
- flwr/cli/utils.py +28 -44
- flwr/client/grpc_rere_client/connection.py +6 -6
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
- flwr/client/rest_client/connection.py +7 -1
- flwr/common/constant.py +10 -0
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/proto/fleet_pb2.py +22 -22
- flwr/proto/fleet_pb2.pyi +4 -1
- flwr/proto/node_pb2.py +1 -1
- flwr/server/app.py +33 -34
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +8 -4
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +19 -41
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +7 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +39 -27
- flwr/server/superlink/linkstate/linkstate.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +37 -21
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
- flwr/supercore/primitives/__init__.py +15 -0
- flwr/supercore/primitives/asymmetric.py +109 -0
- flwr/superlink/auth_plugin/__init__.py +29 -0
- flwr/superlink/servicer/control/control_grpc.py +9 -7
- flwr/superlink/servicer/control/control_servicer.py +34 -46
- {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/METADATA +1 -1
- {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/RECORD +37 -35
- {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/entry_points.txt +0 -0
flwr/cli/auth_plugin/__init__.py
CHANGED
@@ -22,9 +22,13 @@ from .noop_auth_plugin import NoOpCliAuthPlugin
|
|
22
22
|
from .oidc_cli_plugin import OidcCliPlugin
|
23
23
|
|
24
24
|
|
25
|
-
def
|
25
|
+
def get_cli_plugin_class(authn_type: str) -> type[CliAuthPlugin]:
|
26
26
|
"""Return all CLI authentication plugins."""
|
27
|
-
|
27
|
+
if authn_type == AuthnType.NOOP:
|
28
|
+
return NoOpCliAuthPlugin
|
29
|
+
if authn_type == AuthnType.OIDC:
|
30
|
+
return OidcCliPlugin
|
31
|
+
raise ValueError(f"Unsupported authentication type: {authn_type}")
|
28
32
|
|
29
33
|
|
30
34
|
__all__ = [
|
@@ -32,5 +36,5 @@ __all__ = [
|
|
32
36
|
"LoginError",
|
33
37
|
"NoOpCliAuthPlugin",
|
34
38
|
"OidcCliPlugin",
|
35
|
-
"
|
39
|
+
"get_cli_plugin_class",
|
36
40
|
]
|
flwr/cli/log.py
CHANGED
@@ -35,7 +35,7 @@ from flwr.common.logger import log as logger
|
|
35
35
|
from flwr.proto.control_pb2 import StreamLogsRequest # pylint: disable=E0611
|
36
36
|
from flwr.proto.control_pb2_grpc import ControlStub
|
37
37
|
|
38
|
-
from .utils import flwr_cli_grpc_exc_handler, init_channel,
|
38
|
+
from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
|
39
39
|
|
40
40
|
|
41
41
|
class AllLogsRetrieved(BaseException):
|
@@ -186,7 +186,7 @@ def _log_with_control_api(
|
|
186
186
|
run_id: int,
|
187
187
|
stream: bool,
|
188
188
|
) -> None:
|
189
|
-
auth_plugin =
|
189
|
+
auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
|
190
190
|
channel = init_channel(app, federation_config, auth_plugin)
|
191
191
|
|
192
192
|
if stream:
|
flwr/cli/login/login.py
CHANGED
@@ -20,7 +20,7 @@ from typing import Annotated, Optional
|
|
20
20
|
|
21
21
|
import typer
|
22
22
|
|
23
|
-
from flwr.cli.auth_plugin import LoginError
|
23
|
+
from flwr.cli.auth_plugin import LoginError, NoOpCliAuthPlugin
|
24
24
|
from flwr.cli.config_utils import (
|
25
25
|
exit_if_no_address,
|
26
26
|
get_insecure_flag,
|
@@ -40,7 +40,7 @@ from ..utils import (
|
|
40
40
|
account_auth_enabled,
|
41
41
|
flwr_cli_grpc_exc_handler,
|
42
42
|
init_channel,
|
43
|
-
|
43
|
+
load_cli_auth_plugin,
|
44
44
|
)
|
45
45
|
|
46
46
|
|
@@ -95,7 +95,7 @@ def login( # pylint: disable=R0914
|
|
95
95
|
)
|
96
96
|
raise typer.Exit(code=1)
|
97
97
|
|
98
|
-
channel = init_channel(app, federation_config,
|
98
|
+
channel = init_channel(app, federation_config, NoOpCliAuthPlugin(Path()))
|
99
99
|
stub = ControlStub(channel)
|
100
100
|
|
101
101
|
login_request = GetLoginDetailsRequest()
|
@@ -104,16 +104,7 @@ def login( # pylint: disable=R0914
|
|
104
104
|
|
105
105
|
# Get the auth plugin
|
106
106
|
authn_type = login_response.authn_type
|
107
|
-
auth_plugin =
|
108
|
-
app, federation, federation_config, authn_type
|
109
|
-
)
|
110
|
-
if auth_plugin is None:
|
111
|
-
typer.secho(
|
112
|
-
f'❌ Authentication type "{authn_type}" not found',
|
113
|
-
fg=typer.colors.RED,
|
114
|
-
bold=True,
|
115
|
-
)
|
116
|
-
raise typer.Exit(code=1)
|
107
|
+
auth_plugin = load_cli_auth_plugin(app, federation, federation_config, authn_type)
|
117
108
|
|
118
109
|
# Login
|
119
110
|
details = AccountAuthLoginDetails(
|
flwr/cli/ls.py
CHANGED
@@ -44,7 +44,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
44
44
|
)
|
45
45
|
from flwr.proto.control_pb2_grpc import ControlStub
|
46
46
|
|
47
|
-
from .utils import flwr_cli_grpc_exc_handler, init_channel,
|
47
|
+
from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
|
48
48
|
|
49
49
|
_RunListType = tuple[int, str, str, str, str, str, str, str, str]
|
50
50
|
|
@@ -127,7 +127,7 @@ def ls( # pylint: disable=too-many-locals, too-many-branches, R0913, R0917
|
|
127
127
|
raise ValueError(
|
128
128
|
"The options '--runs' and '--run-id' are mutually exclusive."
|
129
129
|
)
|
130
|
-
auth_plugin =
|
130
|
+
auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
|
131
131
|
channel = init_channel(app, federation_config, auth_plugin)
|
132
132
|
stub = ControlStub(channel)
|
133
133
|
|
flwr/cli/pull.py
CHANGED
@@ -34,7 +34,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
34
34
|
)
|
35
35
|
from flwr.proto.control_pb2_grpc import ControlStub
|
36
36
|
|
37
|
-
from .utils import flwr_cli_grpc_exc_handler, init_channel,
|
37
|
+
from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
|
38
38
|
|
39
39
|
|
40
40
|
def pull( # pylint: disable=R0914
|
@@ -74,7 +74,7 @@ def pull( # pylint: disable=R0914
|
|
74
74
|
channel = None
|
75
75
|
try:
|
76
76
|
|
77
|
-
auth_plugin =
|
77
|
+
auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
|
78
78
|
channel = init_channel(app, federation_config, auth_plugin)
|
79
79
|
stub = ControlStub(channel)
|
80
80
|
with flwr_cli_grpc_exc_handler():
|
flwr/cli/run/run.py
CHANGED
@@ -45,7 +45,7 @@ from flwr.proto.control_pb2 import StartRunRequest # pylint: disable=E0611
|
|
45
45
|
from flwr.proto.control_pb2_grpc import ControlStub
|
46
46
|
|
47
47
|
from ..log import start_stream
|
48
|
-
from ..utils import flwr_cli_grpc_exc_handler, init_channel,
|
48
|
+
from ..utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
|
49
49
|
|
50
50
|
CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
|
51
51
|
|
@@ -148,7 +148,7 @@ def _run_with_control_api(
|
|
148
148
|
) -> None:
|
149
149
|
channel = None
|
150
150
|
try:
|
151
|
-
auth_plugin =
|
151
|
+
auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
|
152
152
|
channel = init_channel(app, federation_config, auth_plugin)
|
153
153
|
stub = ControlStub(channel)
|
154
154
|
|
flwr/cli/stop.py
CHANGED
@@ -38,7 +38,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
38
38
|
)
|
39
39
|
from flwr.proto.control_pb2_grpc import ControlStub
|
40
40
|
|
41
|
-
from .utils import flwr_cli_grpc_exc_handler, init_channel,
|
41
|
+
from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
|
42
42
|
|
43
43
|
|
44
44
|
def stop( # pylint: disable=R0914
|
@@ -89,7 +89,7 @@ def stop( # pylint: disable=R0914
|
|
89
89
|
exit_if_no_address(federation_config, "stop")
|
90
90
|
channel = None
|
91
91
|
try:
|
92
|
-
auth_plugin =
|
92
|
+
auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
|
93
93
|
channel = init_channel(app, federation_config, auth_plugin)
|
94
94
|
stub = ControlStub(channel) # pylint: disable=unused-variable # noqa: F841
|
95
95
|
|
flwr/cli/supernode/ls.py
CHANGED
@@ -42,7 +42,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
|
|
42
42
|
from flwr.proto.control_pb2_grpc import ControlStub
|
43
43
|
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
|
44
44
|
|
45
|
-
from ..utils import flwr_cli_grpc_exc_handler, init_channel,
|
45
|
+
from ..utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
|
46
46
|
|
47
47
|
_NodeListType = tuple[int, str, str, str, str, str, str, str]
|
48
48
|
|
@@ -94,7 +94,7 @@ def ls( # pylint: disable=R0914
|
|
94
94
|
exit_if_no_address(federation_config, f"supernode {command_name}")
|
95
95
|
channel = None
|
96
96
|
try:
|
97
|
-
auth_plugin =
|
97
|
+
auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
|
98
98
|
channel = init_channel(app, federation_config, auth_plugin)
|
99
99
|
stub = ControlStub(channel)
|
100
100
|
typer.echo("📄 Listing all nodes...")
|
flwr/cli/utils.py
CHANGED
@@ -34,6 +34,7 @@ from flwr.common.constant import (
|
|
34
34
|
NO_ARTIFACT_PROVIDER_MESSAGE,
|
35
35
|
PULL_UNFINISHED_RUN_MESSAGE,
|
36
36
|
RUN_ID_NOT_FOUND_MESSAGE,
|
37
|
+
AuthnType,
|
37
38
|
)
|
38
39
|
from flwr.common.grpc import (
|
39
40
|
GRPC_MAX_MESSAGE_LENGTH,
|
@@ -41,7 +42,7 @@ from flwr.common.grpc import (
|
|
41
42
|
on_channel_state_change,
|
42
43
|
)
|
43
44
|
|
44
|
-
from .auth_plugin import CliAuthPlugin,
|
45
|
+
from .auth_plugin import CliAuthPlugin, get_cli_plugin_class
|
45
46
|
from .cli_account_auth_interceptor import CliAccountAuthInterceptor
|
46
47
|
from .config_utils import validate_certificate_in_federation_config
|
47
48
|
|
@@ -230,71 +231,54 @@ def account_auth_enabled(federation_config: dict[str, Any]) -> bool:
|
|
230
231
|
return enabled
|
231
232
|
|
232
233
|
|
233
|
-
def
|
234
|
+
def retrieve_authn_type(config_path: Path) -> str:
|
235
|
+
"""Retrieve the auth type from the config file or return NOOP if not found."""
|
236
|
+
try:
|
237
|
+
with config_path.open("r", encoding="utf-8") as file:
|
238
|
+
json_file = json.load(file)
|
239
|
+
authn_type: str = json_file[AUTHN_TYPE_JSON_KEY]
|
240
|
+
return authn_type
|
241
|
+
except (FileNotFoundError, KeyError):
|
242
|
+
return AuthnType.NOOP
|
243
|
+
|
244
|
+
|
245
|
+
def load_cli_auth_plugin(
|
234
246
|
root_dir: Path,
|
235
247
|
federation: str,
|
236
248
|
federation_config: dict[str, Any],
|
237
249
|
authn_type: Optional[str] = None,
|
238
|
-
) ->
|
250
|
+
) -> CliAuthPlugin:
|
239
251
|
"""Load the CLI-side account auth plugin for the given authn type."""
|
240
|
-
#
|
241
|
-
if not account_auth_enabled(federation_config):
|
242
|
-
return None
|
243
|
-
|
252
|
+
# Find the path to the account auth config file
|
244
253
|
config_path = get_account_auth_config_path(root_dir, federation)
|
245
254
|
|
246
|
-
#
|
247
|
-
#
|
255
|
+
# Determine the auth type if not provided
|
256
|
+
# Only `flwr login` command can provide `authn_type` explicitly, as it can query the
|
257
|
+
# SuperLink for the auth type.
|
248
258
|
if authn_type is None:
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
authn_type = json_file[AUTHN_TYPE_JSON_KEY]
|
253
|
-
except (FileNotFoundError, KeyError):
|
254
|
-
typer.secho(
|
255
|
-
"❌ Missing or invalid credentials for account authentication. "
|
256
|
-
"Please run `flwr login` to authenticate.",
|
257
|
-
fg=typer.colors.RED,
|
258
|
-
bold=True,
|
259
|
-
)
|
260
|
-
raise typer.Exit(code=1) from None
|
259
|
+
authn_type = AuthnType.NOOP
|
260
|
+
if account_auth_enabled(federation_config):
|
261
|
+
authn_type = retrieve_authn_type(config_path)
|
261
262
|
|
262
263
|
# Retrieve auth plugin class and instantiate it
|
263
264
|
try:
|
264
|
-
|
265
|
-
auth_plugin_class = all_plugins[authn_type]
|
265
|
+
auth_plugin_class = get_cli_plugin_class(authn_type)
|
266
266
|
return auth_plugin_class(config_path)
|
267
|
-
except
|
267
|
+
except ValueError:
|
268
268
|
typer.echo(f"❌ Unknown account authentication type: {authn_type}")
|
269
269
|
raise typer.Exit(code=1) from None
|
270
|
-
except ImportError:
|
271
|
-
typer.echo("❌ No authentication plugins are currently supported.")
|
272
|
-
raise typer.Exit(code=1) from None
|
273
270
|
|
274
271
|
|
275
272
|
def init_channel(
|
276
|
-
app: Path, federation_config: dict[str, Any], auth_plugin:
|
273
|
+
app: Path, federation_config: dict[str, Any], auth_plugin: CliAuthPlugin
|
277
274
|
) -> grpc.Channel:
|
278
275
|
"""Initialize gRPC channel to the Control API."""
|
279
276
|
insecure, root_certificates_bytes = validate_certificate_in_federation_config(
|
280
277
|
app, federation_config
|
281
278
|
)
|
282
279
|
|
283
|
-
#
|
284
|
-
|
285
|
-
if auth_plugin is not None:
|
286
|
-
# Check if TLS is enabled. If not, raise an error
|
287
|
-
if insecure:
|
288
|
-
typer.secho(
|
289
|
-
"❌ Account authentication requires TLS to be enabled. "
|
290
|
-
"Remove `insecure = true` from the federation configuration.",
|
291
|
-
fg=typer.colors.RED,
|
292
|
-
bold=True,
|
293
|
-
)
|
294
|
-
raise typer.Exit(code=1)
|
295
|
-
|
296
|
-
auth_plugin.load_tokens()
|
297
|
-
interceptors.append(CliAccountAuthInterceptor(auth_plugin))
|
280
|
+
# Load tokens
|
281
|
+
auth_plugin.load_tokens()
|
298
282
|
|
299
283
|
# Create the gRPC channel
|
300
284
|
channel = create_channel(
|
@@ -302,7 +286,7 @@ def init_channel(
|
|
302
286
|
insecure=insecure,
|
303
287
|
root_certificates=root_certificates_bytes,
|
304
288
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
305
|
-
interceptors=
|
289
|
+
interceptors=[CliAccountAuthInterceptor(auth_plugin)],
|
306
290
|
)
|
307
291
|
channel.subscribe(on_channel_state_change)
|
308
292
|
return channel
|
@@ -36,9 +36,6 @@ from flwr.common.inflatable_protobuf_utils import (
|
|
36
36
|
from flwr.common.logger import log
|
37
37
|
from flwr.common.message import Message, remove_content_from_message
|
38
38
|
from flwr.common.retry_invoker import RetryInvoker, _wrap_stub
|
39
|
-
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
40
|
-
generate_key_pairs,
|
41
|
-
)
|
42
39
|
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
43
40
|
from flwr.common.typing import Fab, Run
|
44
41
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
@@ -58,9 +55,10 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
58
55
|
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
59
56
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
60
57
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
58
|
+
from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
|
61
59
|
|
62
|
-
from .client_interceptor import AuthenticateClientInterceptor
|
63
60
|
from .grpc_adapter import GrpcAdapter
|
61
|
+
from .node_auth_client_interceptor import NodeAuthClientInterceptor
|
64
62
|
|
65
63
|
|
66
64
|
@contextmanager
|
@@ -140,8 +138,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
140
138
|
|
141
139
|
# Always configure auth interceptor, with either user-provided or generated keys
|
142
140
|
interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
|
143
|
-
|
141
|
+
NodeAuthClientInterceptor(*authentication_keys),
|
144
142
|
]
|
143
|
+
node_pk = public_key_to_bytes(authentication_keys[1])
|
145
144
|
channel = create_channel(
|
146
145
|
server_address=server_address,
|
147
146
|
insecure=insecure,
|
@@ -201,7 +200,8 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
201
200
|
"""Set create_node."""
|
202
201
|
# Call FleetAPI
|
203
202
|
create_node_request = CreateNodeRequest(
|
204
|
-
|
203
|
+
public_key=node_pk,
|
204
|
+
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
|
205
205
|
)
|
206
206
|
create_node_response = stub.CreateNode(request=create_node_request)
|
207
207
|
|
@@ -23,14 +23,11 @@ from google.protobuf.message import Message as GrpcMessage
|
|
23
23
|
|
24
24
|
from flwr.common import now
|
25
25
|
from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
|
26
|
-
from flwr.
|
27
|
-
public_key_to_bytes,
|
28
|
-
sign_message,
|
29
|
-
)
|
26
|
+
from flwr.supercore.primitives.asymmetric import public_key_to_bytes, sign_message
|
30
27
|
|
31
28
|
|
32
|
-
class
|
33
|
-
"""Client interceptor for
|
29
|
+
class NodeAuthClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
|
30
|
+
"""Client interceptor for node authentication."""
|
34
31
|
|
35
32
|
def __init__(
|
36
33
|
self,
|
@@ -35,14 +35,9 @@ from flwr.common.constant import MessageType
|
|
35
35
|
from flwr.common.logger import log
|
36
36
|
from flwr.common.secure_aggregation.crypto.shamir import create_shares
|
37
37
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
38
|
-
bytes_to_private_key,
|
39
|
-
bytes_to_public_key,
|
40
38
|
decrypt,
|
41
39
|
encrypt,
|
42
|
-
generate_key_pairs,
|
43
40
|
generate_shared_key,
|
44
|
-
private_key_to_bytes,
|
45
|
-
public_key_to_bytes,
|
46
41
|
)
|
47
42
|
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
|
48
43
|
factor_combine,
|
@@ -64,6 +59,13 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
|
|
64
59
|
share_keys_plaintext_separate,
|
65
60
|
)
|
66
61
|
from flwr.common.typing import ConfigRecordValues
|
62
|
+
from flwr.supercore.primitives.asymmetric import (
|
63
|
+
bytes_to_private_key,
|
64
|
+
bytes_to_public_key,
|
65
|
+
generate_key_pairs,
|
66
|
+
private_key_to_bytes,
|
67
|
+
public_key_to_bytes,
|
68
|
+
)
|
67
69
|
|
68
70
|
|
69
71
|
@dataclass
|
@@ -15,6 +15,7 @@
|
|
15
15
|
"""Contextmanager for a REST request-response channel to the Flower server."""
|
16
16
|
|
17
17
|
|
18
|
+
import secrets
|
18
19
|
from collections.abc import Iterator
|
19
20
|
from contextlib import contextmanager
|
20
21
|
from logging import ERROR, WARN
|
@@ -292,7 +293,12 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
292
293
|
|
293
294
|
def create_node() -> Optional[int]:
|
294
295
|
"""Set create_node."""
|
295
|
-
req = CreateNodeRequest(
|
296
|
+
req = CreateNodeRequest(
|
297
|
+
# REST does not support node authentication;
|
298
|
+
# random bytes are used instead
|
299
|
+
public_key=secrets.token_bytes(32),
|
300
|
+
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
|
301
|
+
)
|
296
302
|
|
297
303
|
# Send the request
|
298
304
|
res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
|
flwr/common/constant.py
CHANGED
@@ -256,6 +256,16 @@ class AuthnType:
|
|
256
256
|
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
257
257
|
|
258
258
|
|
259
|
+
class AuthzType:
|
260
|
+
"""Account authorization types."""
|
261
|
+
|
262
|
+
NOOP = "noop"
|
263
|
+
|
264
|
+
def __new__(cls) -> AuthzType:
|
265
|
+
"""Prevent instantiation."""
|
266
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
267
|
+
|
268
|
+
|
259
269
|
class EventLogWriterType:
|
260
270
|
"""Event log writer types."""
|
261
271
|
|
@@ -16,57 +16,14 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import base64
|
19
|
-
from typing import cast
|
20
19
|
|
21
20
|
from cryptography.exceptions import InvalidSignature
|
22
21
|
from cryptography.fernet import Fernet
|
23
|
-
from cryptography.hazmat.primitives import hashes, hmac
|
22
|
+
from cryptography.hazmat.primitives import hashes, hmac
|
24
23
|
from cryptography.hazmat.primitives.asymmetric import ec
|
25
24
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
26
25
|
|
27
26
|
|
28
|
-
def generate_key_pairs() -> (
|
29
|
-
tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
30
|
-
):
|
31
|
-
"""Generate private and public key pairs with Cryptography."""
|
32
|
-
private_key = ec.generate_private_key(ec.SECP384R1())
|
33
|
-
public_key = private_key.public_key()
|
34
|
-
return private_key, public_key
|
35
|
-
|
36
|
-
|
37
|
-
def private_key_to_bytes(private_key: ec.EllipticCurvePrivateKey) -> bytes:
|
38
|
-
"""Serialize private key to bytes."""
|
39
|
-
return private_key.private_bytes(
|
40
|
-
encoding=serialization.Encoding.PEM,
|
41
|
-
format=serialization.PrivateFormat.PKCS8,
|
42
|
-
encryption_algorithm=serialization.NoEncryption(),
|
43
|
-
)
|
44
|
-
|
45
|
-
|
46
|
-
def bytes_to_private_key(private_key_bytes: bytes) -> ec.EllipticCurvePrivateKey:
|
47
|
-
"""Deserialize private key from bytes."""
|
48
|
-
return cast(
|
49
|
-
ec.EllipticCurvePrivateKey,
|
50
|
-
serialization.load_pem_private_key(data=private_key_bytes, password=None),
|
51
|
-
)
|
52
|
-
|
53
|
-
|
54
|
-
def public_key_to_bytes(public_key: ec.EllipticCurvePublicKey) -> bytes:
|
55
|
-
"""Serialize public key to bytes."""
|
56
|
-
return public_key.public_bytes(
|
57
|
-
encoding=serialization.Encoding.PEM,
|
58
|
-
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
59
|
-
)
|
60
|
-
|
61
|
-
|
62
|
-
def bytes_to_public_key(public_key_bytes: bytes) -> ec.EllipticCurvePublicKey:
|
63
|
-
"""Deserialize public key from bytes."""
|
64
|
-
return cast(
|
65
|
-
ec.EllipticCurvePublicKey,
|
66
|
-
serialization.load_pem_public_key(data=public_key_bytes),
|
67
|
-
)
|
68
|
-
|
69
|
-
|
70
27
|
def generate_shared_key(
|
71
28
|
private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey
|
72
29
|
) -> bytes:
|
@@ -117,48 +74,3 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
|
|
117
74
|
return True
|
118
75
|
except InvalidSignature:
|
119
76
|
return False
|
120
|
-
|
121
|
-
|
122
|
-
def sign_message(private_key: ec.EllipticCurvePrivateKey, message: bytes) -> bytes:
|
123
|
-
"""Sign a message using the provided EC private key.
|
124
|
-
|
125
|
-
Parameters
|
126
|
-
----------
|
127
|
-
private_key : ec.EllipticCurvePrivateKey
|
128
|
-
The EC private key to sign the message with.
|
129
|
-
message : bytes
|
130
|
-
The message to be signed.
|
131
|
-
|
132
|
-
Returns
|
133
|
-
-------
|
134
|
-
bytes
|
135
|
-
The signature of the message.
|
136
|
-
"""
|
137
|
-
signature = private_key.sign(message, ec.ECDSA(hashes.SHA256()))
|
138
|
-
return signature
|
139
|
-
|
140
|
-
|
141
|
-
def verify_signature(
|
142
|
-
public_key: ec.EllipticCurvePublicKey, message: bytes, signature: bytes
|
143
|
-
) -> bool:
|
144
|
-
"""Verify a signature against a message using the provided EC public key.
|
145
|
-
|
146
|
-
Parameters
|
147
|
-
----------
|
148
|
-
public_key : ec.EllipticCurvePublicKey
|
149
|
-
The EC public key to verify the signature.
|
150
|
-
message : bytes
|
151
|
-
The original message.
|
152
|
-
signature : bytes
|
153
|
-
The signature to verify.
|
154
|
-
|
155
|
-
Returns
|
156
|
-
-------
|
157
|
-
bool
|
158
|
-
True if the signature is valid, False otherwise.
|
159
|
-
"""
|
160
|
-
try:
|
161
|
-
public_key.verify(signature, message, ec.ECDSA(hashes.SHA256()))
|
162
|
-
return True
|
163
|
-
except InvalidSignature:
|
164
|
-
return False
|
flwr/proto/fleet_pb2.py
CHANGED
@@ -19,7 +19,7 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
|
|
19
19
|
from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2
|
20
20
|
|
21
21
|
|
22
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"
|
22
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"C\n\x11\x43reateNodeRequest\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x1a\n\x12heartbeat_interval\x18\x02 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"J\n\x13PullMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x13\n\x0bmessage_ids\x18\x02 \x03(\t\"\xa2\x01\n\x14PullMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\x97\x01\n\x13PushMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\xc9\x01\n\x14PushMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12>\n\x07results\x18\x02 \x03(\x0b\x32-.flwr.proto.PushMessagesResponse.ResultsEntry\x12\x17\n\x0fobjects_to_push\x18\x03 \x03(\t\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xca\x06\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12\x62\n\x11SendNodeHeartbeat\x12$.flwr.proto.SendNodeHeartbeatRequest\x1a%.flwr.proto.SendNodeHeartbeatResponse\"\x00\x12S\n\x0cPullMessages\x12\x1f.flwr.proto.PullMessagesRequest\x1a .flwr.proto.PullMessagesResponse\"\x00\x12S\n\x0cPushMessages\x12\x1f.flwr.proto.PushMessagesRequest\x1a .flwr.proto.PushMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x12q\n\x16\x43onfirmMessageReceived\x12).flwr.proto.ConfirmMessageReceivedRequest\x1a*.flwr.proto.ConfirmMessageReceivedResponse\"\x00\x62\x06proto3')
|
23
23
|
|
24
24
|
_globals = globals()
|
25
25
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
@@ -29,25 +29,25 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
29
29
|
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._options = None
|
30
30
|
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001'
|
31
31
|
_globals['_CREATENODEREQUEST']._serialized_start=159
|
32
|
-
_globals['_CREATENODEREQUEST']._serialized_end=
|
33
|
-
_globals['_CREATENODERESPONSE']._serialized_start=
|
34
|
-
_globals['_CREATENODERESPONSE']._serialized_end=
|
35
|
-
_globals['_DELETENODEREQUEST']._serialized_start=
|
36
|
-
_globals['_DELETENODEREQUEST']._serialized_end=
|
37
|
-
_globals['_DELETENODERESPONSE']._serialized_start=
|
38
|
-
_globals['_DELETENODERESPONSE']._serialized_end=
|
39
|
-
_globals['_PULLMESSAGESREQUEST']._serialized_start=
|
40
|
-
_globals['_PULLMESSAGESREQUEST']._serialized_end=
|
41
|
-
_globals['_PULLMESSAGESRESPONSE']._serialized_start=
|
42
|
-
_globals['_PULLMESSAGESRESPONSE']._serialized_end=
|
43
|
-
_globals['_PUSHMESSAGESREQUEST']._serialized_start=
|
44
|
-
_globals['_PUSHMESSAGESREQUEST']._serialized_end=
|
45
|
-
_globals['_PUSHMESSAGESRESPONSE']._serialized_start=
|
46
|
-
_globals['_PUSHMESSAGESRESPONSE']._serialized_end=
|
47
|
-
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=
|
48
|
-
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=
|
49
|
-
_globals['_RECONNECT']._serialized_start=
|
50
|
-
_globals['_RECONNECT']._serialized_end=
|
51
|
-
_globals['_FLEET']._serialized_start=
|
52
|
-
_globals['_FLEET']._serialized_end=
|
32
|
+
_globals['_CREATENODEREQUEST']._serialized_end=226
|
33
|
+
_globals['_CREATENODERESPONSE']._serialized_start=228
|
34
|
+
_globals['_CREATENODERESPONSE']._serialized_end=280
|
35
|
+
_globals['_DELETENODEREQUEST']._serialized_start=282
|
36
|
+
_globals['_DELETENODEREQUEST']._serialized_end=333
|
37
|
+
_globals['_DELETENODERESPONSE']._serialized_start=335
|
38
|
+
_globals['_DELETENODERESPONSE']._serialized_end=355
|
39
|
+
_globals['_PULLMESSAGESREQUEST']._serialized_start=357
|
40
|
+
_globals['_PULLMESSAGESREQUEST']._serialized_end=431
|
41
|
+
_globals['_PULLMESSAGESRESPONSE']._serialized_start=434
|
42
|
+
_globals['_PULLMESSAGESRESPONSE']._serialized_end=596
|
43
|
+
_globals['_PUSHMESSAGESREQUEST']._serialized_start=599
|
44
|
+
_globals['_PUSHMESSAGESREQUEST']._serialized_end=750
|
45
|
+
_globals['_PUSHMESSAGESRESPONSE']._serialized_start=753
|
46
|
+
_globals['_PUSHMESSAGESRESPONSE']._serialized_end=954
|
47
|
+
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=908
|
48
|
+
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=954
|
49
|
+
_globals['_RECONNECT']._serialized_start=956
|
50
|
+
_globals['_RECONNECT']._serialized_end=986
|
51
|
+
_globals['_FLEET']._serialized_start=989
|
52
|
+
_globals['_FLEET']._serialized_end=1831
|
53
53
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/fleet_pb2.pyi
CHANGED
@@ -16,13 +16,16 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
|
16
16
|
class CreateNodeRequest(google.protobuf.message.Message):
|
17
17
|
"""CreateNode messages"""
|
18
18
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
19
|
+
PUBLIC_KEY_FIELD_NUMBER: builtins.int
|
19
20
|
HEARTBEAT_INTERVAL_FIELD_NUMBER: builtins.int
|
21
|
+
public_key: builtins.bytes
|
20
22
|
heartbeat_interval: builtins.float
|
21
23
|
def __init__(self,
|
22
24
|
*,
|
25
|
+
public_key: builtins.bytes = ...,
|
23
26
|
heartbeat_interval: builtins.float = ...,
|
24
27
|
) -> None: ...
|
25
|
-
def ClearField(self, field_name: typing_extensions.Literal["heartbeat_interval",b"heartbeat_interval"]) -> None: ...
|
28
|
+
def ClearField(self, field_name: typing_extensions.Literal["heartbeat_interval",b"heartbeat_interval","public_key",b"public_key"]) -> None: ...
|
26
29
|
global___CreateNodeRequest = CreateNodeRequest
|
27
30
|
|
28
31
|
class CreateNodeResponse(google.protobuf.message.Message):
|
flwr/proto/node_pb2.py
CHANGED
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
|
|
14
14
|
|
15
15
|
|
16
16
|
|
17
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\xd0\x01\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x11\n\towner_aid\x18\x02 \x01(\t\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\ncreated_at\x18\x04 \x01(\t\x12\x19\n\x11last_activated_at\x18\x05 \x01(\t\x12\x1b\n\x13last_deactivated_at\x18\x06 \x01(\t\x12\x12\n\ndeleted_at\x18\x07 \x01(\t\x12\x14\n\x0conline_until\x18\x08 \x01(\
|
17
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\xd0\x01\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x11\n\towner_aid\x18\x02 \x01(\t\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\ncreated_at\x18\x04 \x01(\t\x12\x19\n\x11last_activated_at\x18\x05 \x01(\t\x12\x1b\n\x13last_deactivated_at\x18\x06 \x01(\t\x12\x12\n\ndeleted_at\x18\x07 \x01(\t\x12\x14\n\x0conline_until\x18\x08 \x01(\x01\x12\x1a\n\x12heartbeat_interval\x18\t \x01(\x01\x62\x06proto3')
|
18
18
|
|
19
19
|
_globals = globals()
|
20
20
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|