flwr-nightly 1.23.0.dev20251006__py3-none-any.whl → 1.23.0.dev20251008__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. flwr/cli/auth_plugin/__init__.py +7 -3
  2. flwr/cli/log.py +2 -2
  3. flwr/cli/login/login.py +4 -13
  4. flwr/cli/ls.py +2 -2
  5. flwr/cli/pull.py +2 -2
  6. flwr/cli/run/run.py +2 -2
  7. flwr/cli/stop.py +2 -2
  8. flwr/cli/supernode/ls.py +2 -2
  9. flwr/cli/utils.py +28 -44
  10. flwr/client/grpc_rere_client/connection.py +6 -6
  11. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
  12. flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
  13. flwr/client/rest_client/connection.py +7 -1
  14. flwr/common/constant.py +10 -0
  15. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  16. flwr/proto/fleet_pb2.py +22 -22
  17. flwr/proto/fleet_pb2.pyi +4 -1
  18. flwr/proto/node_pb2.py +1 -1
  19. flwr/server/app.py +33 -34
  20. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +8 -4
  21. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +19 -41
  22. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
  23. flwr/server/superlink/fleet/vce/vce_api.py +7 -1
  24. flwr/server/superlink/linkstate/in_memory_linkstate.py +39 -27
  25. flwr/server/superlink/linkstate/linkstate.py +1 -1
  26. flwr/server/superlink/linkstate/sqlite_linkstate.py +37 -21
  27. flwr/server/utils/validator.py +2 -3
  28. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
  29. flwr/supercore/primitives/__init__.py +15 -0
  30. flwr/supercore/primitives/asymmetric.py +109 -0
  31. flwr/superlink/auth_plugin/__init__.py +29 -0
  32. flwr/superlink/servicer/control/control_grpc.py +9 -7
  33. flwr/superlink/servicer/control/control_servicer.py +34 -46
  34. {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/METADATA +1 -1
  35. {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/RECORD +37 -35
  36. {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/WHEEL +0 -0
  37. {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/entry_points.txt +0 -0
flwr/cli/auth_plugin/__init__.py CHANGED
@@ -22,9 +22,13 @@ from .noop_auth_plugin import NoOpCliAuthPlugin
22
22
  from .oidc_cli_plugin import OidcCliPlugin
23
23
 
24
24
 
25
- def get_cli_auth_plugins() -> dict[str, type[CliAuthPlugin]]:
25
+ def get_cli_plugin_class(authn_type: str) -> type[CliAuthPlugin]:
26
26
  """Return all CLI authentication plugins."""
27
- return {AuthnType.NOOP: NoOpCliAuthPlugin, AuthnType.OIDC: OidcCliPlugin}
27
+ if authn_type == AuthnType.NOOP:
28
+ return NoOpCliAuthPlugin
29
+ if authn_type == AuthnType.OIDC:
30
+ return OidcCliPlugin
31
+ raise ValueError(f"Unsupported authentication type: {authn_type}")
28
32
 
29
33
 
30
34
  __all__ = [
@@ -32,5 +36,5 @@ __all__ = [
32
36
  "LoginError",
33
37
  "NoOpCliAuthPlugin",
34
38
  "OidcCliPlugin",
35
- "get_cli_auth_plugins",
39
+ "get_cli_plugin_class",
36
40
  ]
flwr/cli/log.py CHANGED
@@ -35,7 +35,7 @@ from flwr.common.logger import log as logger
35
35
  from flwr.proto.control_pb2 import StreamLogsRequest # pylint: disable=E0611
36
36
  from flwr.proto.control_pb2_grpc import ControlStub
37
37
 
38
- from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
38
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
39
39
 
40
40
 
41
41
  class AllLogsRetrieved(BaseException):
@@ -186,7 +186,7 @@ def _log_with_control_api(
186
186
  run_id: int,
187
187
  stream: bool,
188
188
  ) -> None:
189
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
189
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
190
190
  channel = init_channel(app, federation_config, auth_plugin)
191
191
 
192
192
  if stream:
flwr/cli/login/login.py CHANGED
@@ -20,7 +20,7 @@ from typing import Annotated, Optional
20
20
 
21
21
  import typer
22
22
 
23
- from flwr.cli.auth_plugin import LoginError
23
+ from flwr.cli.auth_plugin import LoginError, NoOpCliAuthPlugin
24
24
  from flwr.cli.config_utils import (
25
25
  exit_if_no_address,
26
26
  get_insecure_flag,
@@ -40,7 +40,7 @@ from ..utils import (
40
40
  account_auth_enabled,
41
41
  flwr_cli_grpc_exc_handler,
42
42
  init_channel,
43
- try_obtain_cli_auth_plugin,
43
+ load_cli_auth_plugin,
44
44
  )
45
45
 
46
46
 
@@ -95,7 +95,7 @@ def login( # pylint: disable=R0914
95
95
  )
96
96
  raise typer.Exit(code=1)
97
97
 
98
- channel = init_channel(app, federation_config, None)
98
+ channel = init_channel(app, federation_config, NoOpCliAuthPlugin(Path()))
99
99
  stub = ControlStub(channel)
100
100
 
101
101
  login_request = GetLoginDetailsRequest()
@@ -104,16 +104,7 @@ def login( # pylint: disable=R0914
104
104
 
105
105
  # Get the auth plugin
106
106
  authn_type = login_response.authn_type
107
- auth_plugin = try_obtain_cli_auth_plugin(
108
- app, federation, federation_config, authn_type
109
- )
110
- if auth_plugin is None:
111
- typer.secho(
112
- f'❌ Authentication type "{authn_type}" not found',
113
- fg=typer.colors.RED,
114
- bold=True,
115
- )
116
- raise typer.Exit(code=1)
107
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config, authn_type)
117
108
 
118
109
  # Login
119
110
  details = AccountAuthLoginDetails(
flwr/cli/ls.py CHANGED
@@ -44,7 +44,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
44
44
  )
45
45
  from flwr.proto.control_pb2_grpc import ControlStub
46
46
 
47
- from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
47
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
48
48
 
49
49
  _RunListType = tuple[int, str, str, str, str, str, str, str, str]
50
50
 
@@ -127,7 +127,7 @@ def ls( # pylint: disable=too-many-locals, too-many-branches, R0913, R0917
127
127
  raise ValueError(
128
128
  "The options '--runs' and '--run-id' are mutually exclusive."
129
129
  )
130
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
130
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
131
131
  channel = init_channel(app, federation_config, auth_plugin)
132
132
  stub = ControlStub(channel)
133
133
 
flwr/cli/pull.py CHANGED
@@ -34,7 +34,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
34
34
  )
35
35
  from flwr.proto.control_pb2_grpc import ControlStub
36
36
 
37
- from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
37
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
38
38
 
39
39
 
40
40
  def pull( # pylint: disable=R0914
@@ -74,7 +74,7 @@ def pull( # pylint: disable=R0914
74
74
  channel = None
75
75
  try:
76
76
 
77
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
77
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
78
78
  channel = init_channel(app, federation_config, auth_plugin)
79
79
  stub = ControlStub(channel)
80
80
  with flwr_cli_grpc_exc_handler():
flwr/cli/run/run.py CHANGED
@@ -45,7 +45,7 @@ from flwr.proto.control_pb2 import StartRunRequest # pylint: disable=E0611
45
45
  from flwr.proto.control_pb2_grpc import ControlStub
46
46
 
47
47
  from ..log import start_stream
48
- from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
48
+ from ..utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
49
49
 
50
50
  CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
51
51
 
@@ -148,7 +148,7 @@ def _run_with_control_api(
148
148
  ) -> None:
149
149
  channel = None
150
150
  try:
151
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
151
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
152
152
  channel = init_channel(app, federation_config, auth_plugin)
153
153
  stub = ControlStub(channel)
154
154
 
flwr/cli/stop.py CHANGED
@@ -38,7 +38,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
38
38
  )
39
39
  from flwr.proto.control_pb2_grpc import ControlStub
40
40
 
41
- from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
41
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
42
42
 
43
43
 
44
44
  def stop( # pylint: disable=R0914
@@ -89,7 +89,7 @@ def stop( # pylint: disable=R0914
89
89
  exit_if_no_address(federation_config, "stop")
90
90
  channel = None
91
91
  try:
92
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
92
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
93
93
  channel = init_channel(app, federation_config, auth_plugin)
94
94
  stub = ControlStub(channel) # pylint: disable=unused-variable # noqa: F841
95
95
 
flwr/cli/supernode/ls.py CHANGED
@@ -42,7 +42,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
42
42
  from flwr.proto.control_pb2_grpc import ControlStub
43
43
  from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
44
44
 
45
- from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
45
+ from ..utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
46
46
 
47
47
  _NodeListType = tuple[int, str, str, str, str, str, str, str]
48
48
 
@@ -94,7 +94,7 @@ def ls( # pylint: disable=R0914
94
94
  exit_if_no_address(federation_config, f"supernode {command_name}")
95
95
  channel = None
96
96
  try:
97
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
97
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
98
98
  channel = init_channel(app, federation_config, auth_plugin)
99
99
  stub = ControlStub(channel)
100
100
  typer.echo("📄 Listing all nodes...")
flwr/cli/utils.py CHANGED
@@ -34,6 +34,7 @@ from flwr.common.constant import (
34
34
  NO_ARTIFACT_PROVIDER_MESSAGE,
35
35
  PULL_UNFINISHED_RUN_MESSAGE,
36
36
  RUN_ID_NOT_FOUND_MESSAGE,
37
+ AuthnType,
37
38
  )
38
39
  from flwr.common.grpc import (
39
40
  GRPC_MAX_MESSAGE_LENGTH,
@@ -41,7 +42,7 @@ from flwr.common.grpc import (
41
42
  on_channel_state_change,
42
43
  )
43
44
 
44
- from .auth_plugin import CliAuthPlugin, get_cli_auth_plugins
45
+ from .auth_plugin import CliAuthPlugin, get_cli_plugin_class
45
46
  from .cli_account_auth_interceptor import CliAccountAuthInterceptor
46
47
  from .config_utils import validate_certificate_in_federation_config
47
48
 
@@ -230,71 +231,54 @@ def account_auth_enabled(federation_config: dict[str, Any]) -> bool:
230
231
  return enabled
231
232
 
232
233
 
233
- def try_obtain_cli_auth_plugin(
234
+ def retrieve_authn_type(config_path: Path) -> str:
235
+ """Retrieve the auth type from the config file or return NOOP if not found."""
236
+ try:
237
+ with config_path.open("r", encoding="utf-8") as file:
238
+ json_file = json.load(file)
239
+ authn_type: str = json_file[AUTHN_TYPE_JSON_KEY]
240
+ return authn_type
241
+ except (FileNotFoundError, KeyError):
242
+ return AuthnType.NOOP
243
+
244
+
245
+ def load_cli_auth_plugin(
234
246
  root_dir: Path,
235
247
  federation: str,
236
248
  federation_config: dict[str, Any],
237
249
  authn_type: Optional[str] = None,
238
- ) -> Optional[CliAuthPlugin]:
250
+ ) -> CliAuthPlugin:
239
251
  """Load the CLI-side account auth plugin for the given authn type."""
240
- # Check if account auth is enabled
241
- if not account_auth_enabled(federation_config):
242
- return None
243
-
252
+ # Find the path to the account auth config file
244
253
  config_path = get_account_auth_config_path(root_dir, federation)
245
254
 
246
- # Get the authn type from the config if not provided
247
- # authn_type will be None for all CLI commands except login
255
+ # Determine the auth type if not provided
256
+ # Only `flwr login` command can provide `authn_type` explicitly, as it can query the
257
+ # SuperLink for the auth type.
248
258
  if authn_type is None:
249
- try:
250
- with config_path.open("r", encoding="utf-8") as file:
251
- json_file = json.load(file)
252
- authn_type = json_file[AUTHN_TYPE_JSON_KEY]
253
- except (FileNotFoundError, KeyError):
254
- typer.secho(
255
- "❌ Missing or invalid credentials for account authentication. "
256
- "Please run `flwr login` to authenticate.",
257
- fg=typer.colors.RED,
258
- bold=True,
259
- )
260
- raise typer.Exit(code=1) from None
259
+ authn_type = AuthnType.NOOP
260
+ if account_auth_enabled(federation_config):
261
+ authn_type = retrieve_authn_type(config_path)
261
262
 
262
263
  # Retrieve auth plugin class and instantiate it
263
264
  try:
264
- all_plugins: dict[str, type[CliAuthPlugin]] = get_cli_auth_plugins()
265
- auth_plugin_class = all_plugins[authn_type]
265
+ auth_plugin_class = get_cli_plugin_class(authn_type)
266
266
  return auth_plugin_class(config_path)
267
- except KeyError:
267
+ except ValueError:
268
268
  typer.echo(f"❌ Unknown account authentication type: {authn_type}")
269
269
  raise typer.Exit(code=1) from None
270
- except ImportError:
271
- typer.echo("❌ No authentication plugins are currently supported.")
272
- raise typer.Exit(code=1) from None
273
270
 
274
271
 
275
272
  def init_channel(
276
- app: Path, federation_config: dict[str, Any], auth_plugin: Optional[CliAuthPlugin]
273
+ app: Path, federation_config: dict[str, Any], auth_plugin: CliAuthPlugin
277
274
  ) -> grpc.Channel:
278
275
  """Initialize gRPC channel to the Control API."""
279
276
  insecure, root_certificates_bytes = validate_certificate_in_federation_config(
280
277
  app, federation_config
281
278
  )
282
279
 
283
- # Initialize the CLI-side account auth interceptor
284
- interceptors: list[grpc.UnaryUnaryClientInterceptor] = []
285
- if auth_plugin is not None:
286
- # Check if TLS is enabled. If not, raise an error
287
- if insecure:
288
- typer.secho(
289
- "❌ Account authentication requires TLS to be enabled. "
290
- "Remove `insecure = true` from the federation configuration.",
291
- fg=typer.colors.RED,
292
- bold=True,
293
- )
294
- raise typer.Exit(code=1)
295
-
296
- auth_plugin.load_tokens()
297
- interceptors.append(CliAccountAuthInterceptor(auth_plugin))
280
+ # Load tokens
281
+ auth_plugin.load_tokens()
298
282
 
299
283
  # Create the gRPC channel
300
284
  channel = create_channel(
@@ -302,7 +286,7 @@ def init_channel(
302
286
  insecure=insecure,
303
287
  root_certificates=root_certificates_bytes,
304
288
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
305
- interceptors=interceptors or None,
289
+ interceptors=[CliAccountAuthInterceptor(auth_plugin)],
306
290
  )
307
291
  channel.subscribe(on_channel_state_change)
308
292
  return channel
flwr/client/grpc_rere_client/connection.py CHANGED
@@ -36,9 +36,6 @@ from flwr.common.inflatable_protobuf_utils import (
36
36
  from flwr.common.logger import log
37
37
  from flwr.common.message import Message, remove_content_from_message
38
38
  from flwr.common.retry_invoker import RetryInvoker, _wrap_stub
39
- from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
40
- generate_key_pairs,
41
- )
42
39
  from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
43
40
  from flwr.common.typing import Fab, Run
44
41
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
@@ -58,9 +55,10 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
58
55
  from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
59
56
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
60
57
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
58
+ from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
61
59
 
62
- from .client_interceptor import AuthenticateClientInterceptor
63
60
  from .grpc_adapter import GrpcAdapter
61
+ from .node_auth_client_interceptor import NodeAuthClientInterceptor
64
62
 
65
63
 
66
64
  @contextmanager
@@ -140,8 +138,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
140
138
 
141
139
  # Always configure auth interceptor, with either user-provided or generated keys
142
140
  interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
143
- AuthenticateClientInterceptor(*authentication_keys),
141
+ NodeAuthClientInterceptor(*authentication_keys),
144
142
  ]
143
+ node_pk = public_key_to_bytes(authentication_keys[1])
145
144
  channel = create_channel(
146
145
  server_address=server_address,
147
146
  insecure=insecure,
@@ -201,7 +200,8 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
201
200
  """Set create_node."""
202
201
  # Call FleetAPI
203
202
  create_node_request = CreateNodeRequest(
204
- heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
203
+ public_key=node_pk,
204
+ heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
205
205
  )
206
206
  create_node_response = stub.CreateNode(request=create_node_request)
207
207
 
flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} CHANGED
@@ -23,14 +23,11 @@ from google.protobuf.message import Message as GrpcMessage
23
23
 
24
24
  from flwr.common import now
25
25
  from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
26
- from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
27
- public_key_to_bytes,
28
- sign_message,
29
- )
26
+ from flwr.supercore.primitives.asymmetric import public_key_to_bytes, sign_message
30
27
 
31
28
 
32
- class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
33
- """Client interceptor for client authentication."""
29
+ class NodeAuthClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
30
+ """Client interceptor for node authentication."""
34
31
 
35
32
  def __init__(
36
33
  self,
flwr/client/mod/secure_aggregation/secaggplus_mod.py CHANGED
@@ -35,14 +35,9 @@ from flwr.common.constant import MessageType
35
35
  from flwr.common.logger import log
36
36
  from flwr.common.secure_aggregation.crypto.shamir import create_shares
37
37
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
38
- bytes_to_private_key,
39
- bytes_to_public_key,
40
38
  decrypt,
41
39
  encrypt,
42
- generate_key_pairs,
43
40
  generate_shared_key,
44
- private_key_to_bytes,
45
- public_key_to_bytes,
46
41
  )
47
42
  from flwr.common.secure_aggregation.ndarrays_arithmetic import (
48
43
  factor_combine,
@@ -64,6 +59,13 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
64
59
  share_keys_plaintext_separate,
65
60
  )
66
61
  from flwr.common.typing import ConfigRecordValues
62
+ from flwr.supercore.primitives.asymmetric import (
63
+ bytes_to_private_key,
64
+ bytes_to_public_key,
65
+ generate_key_pairs,
66
+ private_key_to_bytes,
67
+ public_key_to_bytes,
68
+ )
67
69
 
68
70
 
69
71
  @dataclass
flwr/client/rest_client/connection.py CHANGED
@@ -15,6 +15,7 @@
15
15
  """Contextmanager for a REST request-response channel to the Flower server."""
16
16
 
17
17
 
18
+ import secrets
18
19
  from collections.abc import Iterator
19
20
  from contextlib import contextmanager
20
21
  from logging import ERROR, WARN
@@ -292,7 +293,12 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
292
293
 
293
294
  def create_node() -> Optional[int]:
294
295
  """Set create_node."""
295
- req = CreateNodeRequest(heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL)
296
+ req = CreateNodeRequest(
297
+ # REST does not support node authentication;
298
+ # random bytes are used instead
299
+ public_key=secrets.token_bytes(32),
300
+ heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
301
+ )
296
302
 
297
303
  # Send the request
298
304
  res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
flwr/common/constant.py CHANGED
@@ -256,6 +256,16 @@ class AuthnType:
256
256
  raise TypeError(f"{cls.__name__} cannot be instantiated.")
257
257
 
258
258
 
259
+ class AuthzType:
260
+ """Account authorization types."""
261
+
262
+ NOOP = "noop"
263
+
264
+ def __new__(cls) -> AuthzType:
265
+ """Prevent instantiation."""
266
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
267
+
268
+
259
269
  class EventLogWriterType:
260
270
  """Event log writer types."""
261
271
 
flwr/common/secure_aggregation/crypto/symmetric_encryption.py CHANGED
@@ -16,57 +16,14 @@
16
16
 
17
17
 
18
18
  import base64
19
- from typing import cast
20
19
 
21
20
  from cryptography.exceptions import InvalidSignature
22
21
  from cryptography.fernet import Fernet
23
- from cryptography.hazmat.primitives import hashes, hmac, serialization
22
+ from cryptography.hazmat.primitives import hashes, hmac
24
23
  from cryptography.hazmat.primitives.asymmetric import ec
25
24
  from cryptography.hazmat.primitives.kdf.hkdf import HKDF
26
25
 
27
26
 
28
- def generate_key_pairs() -> (
29
- tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
30
- ):
31
- """Generate private and public key pairs with Cryptography."""
32
- private_key = ec.generate_private_key(ec.SECP384R1())
33
- public_key = private_key.public_key()
34
- return private_key, public_key
35
-
36
-
37
- def private_key_to_bytes(private_key: ec.EllipticCurvePrivateKey) -> bytes:
38
- """Serialize private key to bytes."""
39
- return private_key.private_bytes(
40
- encoding=serialization.Encoding.PEM,
41
- format=serialization.PrivateFormat.PKCS8,
42
- encryption_algorithm=serialization.NoEncryption(),
43
- )
44
-
45
-
46
- def bytes_to_private_key(private_key_bytes: bytes) -> ec.EllipticCurvePrivateKey:
47
- """Deserialize private key from bytes."""
48
- return cast(
49
- ec.EllipticCurvePrivateKey,
50
- serialization.load_pem_private_key(data=private_key_bytes, password=None),
51
- )
52
-
53
-
54
- def public_key_to_bytes(public_key: ec.EllipticCurvePublicKey) -> bytes:
55
- """Serialize public key to bytes."""
56
- return public_key.public_bytes(
57
- encoding=serialization.Encoding.PEM,
58
- format=serialization.PublicFormat.SubjectPublicKeyInfo,
59
- )
60
-
61
-
62
- def bytes_to_public_key(public_key_bytes: bytes) -> ec.EllipticCurvePublicKey:
63
- """Deserialize public key from bytes."""
64
- return cast(
65
- ec.EllipticCurvePublicKey,
66
- serialization.load_pem_public_key(data=public_key_bytes),
67
- )
68
-
69
-
70
27
  def generate_shared_key(
71
28
  private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey
72
29
  ) -> bytes:
@@ -117,48 +74,3 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool:
117
74
  return True
118
75
  except InvalidSignature:
119
76
  return False
120
-
121
-
122
- def sign_message(private_key: ec.EllipticCurvePrivateKey, message: bytes) -> bytes:
123
- """Sign a message using the provided EC private key.
124
-
125
- Parameters
126
- ----------
127
- private_key : ec.EllipticCurvePrivateKey
128
- The EC private key to sign the message with.
129
- message : bytes
130
- The message to be signed.
131
-
132
- Returns
133
- -------
134
- bytes
135
- The signature of the message.
136
- """
137
- signature = private_key.sign(message, ec.ECDSA(hashes.SHA256()))
138
- return signature
139
-
140
-
141
- def verify_signature(
142
- public_key: ec.EllipticCurvePublicKey, message: bytes, signature: bytes
143
- ) -> bool:
144
- """Verify a signature against a message using the provided EC public key.
145
-
146
- Parameters
147
- ----------
148
- public_key : ec.EllipticCurvePublicKey
149
- The EC public key to verify the signature.
150
- message : bytes
151
- The original message.
152
- signature : bytes
153
- The signature to verify.
154
-
155
- Returns
156
- -------
157
- bool
158
- True if the signature is valid, False otherwise.
159
- """
160
- try:
161
- public_key.verify(signature, message, ec.ECDSA(hashes.SHA256()))
162
- return True
163
- except InvalidSignature:
164
- return False
flwr/proto/fleet_pb2.py CHANGED
@@ -19,7 +19,7 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
19
19
  from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2
20
20
 
21
21
 
22
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"/\n\x11\x43reateNodeRequest\x12\x1a\n\x12heartbeat_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"J\n\x13PullMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x13\n\x0bmessage_ids\x18\x02 \x03(\t\"\xa2\x01\n\x14PullMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\x97\x01\n\x13PushMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\xc9\x01\n\x14PushMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12>\n\x07results\x18\x02 \x03(\x0b\x32-.flwr.proto.PushMessagesResponse.ResultsEntry\x12\x17\n\x0fobjects_to_push\x18\x03 \x03(\t\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xca\x06\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12\x62\n\x11SendNodeHeartbeat\x12$.flwr.proto.SendNodeHeartbeatRequest\x1a%.flwr.proto.SendNodeHeartbeatResponse\"\x00\x12S\n\x0cPullMessages\x12\x1f.flwr.proto.PullMessagesRequest\x1a 
.flwr.proto.PullMessagesResponse\"\x00\x12S\n\x0cPushMessages\x12\x1f.flwr.proto.PushMessagesRequest\x1a .flwr.proto.PushMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x12q\n\x16\x43onfirmMessageReceived\x12).flwr.proto.ConfirmMessageReceivedRequest\x1a*.flwr.proto.ConfirmMessageReceivedResponse\"\x00\x62\x06proto3')
22
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"C\n\x11\x43reateNodeRequest\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x1a\n\x12heartbeat_interval\x18\x02 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"J\n\x13PullMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x13\n\x0bmessage_ids\x18\x02 \x03(\t\"\xa2\x01\n\x14PullMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\x97\x01\n\x13PushMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\xc9\x01\n\x14PushMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12>\n\x07results\x18\x02 \x03(\x0b\x32-.flwr.proto.PushMessagesResponse.ResultsEntry\x12\x17\n\x0fobjects_to_push\x18\x03 \x03(\t\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 
\x01(\x04\x32\xca\x06\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12\x62\n\x11SendNodeHeartbeat\x12$.flwr.proto.SendNodeHeartbeatRequest\x1a%.flwr.proto.SendNodeHeartbeatResponse\"\x00\x12S\n\x0cPullMessages\x12\x1f.flwr.proto.PullMessagesRequest\x1a .flwr.proto.PullMessagesResponse\"\x00\x12S\n\x0cPushMessages\x12\x1f.flwr.proto.PushMessagesRequest\x1a .flwr.proto.PushMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x12q\n\x16\x43onfirmMessageReceived\x12).flwr.proto.ConfirmMessageReceivedRequest\x1a*.flwr.proto.ConfirmMessageReceivedResponse\"\x00\x62\x06proto3')
23
23
 
24
24
  _globals = globals()
25
25
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -29,25 +29,25 @@ if _descriptor._USE_C_DESCRIPTORS == False:
29
29
  _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._options = None
30
30
  _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001'
31
31
  _globals['_CREATENODEREQUEST']._serialized_start=159
32
- _globals['_CREATENODEREQUEST']._serialized_end=206
33
- _globals['_CREATENODERESPONSE']._serialized_start=208
34
- _globals['_CREATENODERESPONSE']._serialized_end=260
35
- _globals['_DELETENODEREQUEST']._serialized_start=262
36
- _globals['_DELETENODEREQUEST']._serialized_end=313
37
- _globals['_DELETENODERESPONSE']._serialized_start=315
38
- _globals['_DELETENODERESPONSE']._serialized_end=335
39
- _globals['_PULLMESSAGESREQUEST']._serialized_start=337
40
- _globals['_PULLMESSAGESREQUEST']._serialized_end=411
41
- _globals['_PULLMESSAGESRESPONSE']._serialized_start=414
42
- _globals['_PULLMESSAGESRESPONSE']._serialized_end=576
43
- _globals['_PUSHMESSAGESREQUEST']._serialized_start=579
44
- _globals['_PUSHMESSAGESREQUEST']._serialized_end=730
45
- _globals['_PUSHMESSAGESRESPONSE']._serialized_start=733
46
- _globals['_PUSHMESSAGESRESPONSE']._serialized_end=934
47
- _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=888
48
- _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=934
49
- _globals['_RECONNECT']._serialized_start=936
50
- _globals['_RECONNECT']._serialized_end=966
51
- _globals['_FLEET']._serialized_start=969
52
- _globals['_FLEET']._serialized_end=1811
32
+ _globals['_CREATENODEREQUEST']._serialized_end=226
33
+ _globals['_CREATENODERESPONSE']._serialized_start=228
34
+ _globals['_CREATENODERESPONSE']._serialized_end=280
35
+ _globals['_DELETENODEREQUEST']._serialized_start=282
36
+ _globals['_DELETENODEREQUEST']._serialized_end=333
37
+ _globals['_DELETENODERESPONSE']._serialized_start=335
38
+ _globals['_DELETENODERESPONSE']._serialized_end=355
39
+ _globals['_PULLMESSAGESREQUEST']._serialized_start=357
40
+ _globals['_PULLMESSAGESREQUEST']._serialized_end=431
41
+ _globals['_PULLMESSAGESRESPONSE']._serialized_start=434
42
+ _globals['_PULLMESSAGESRESPONSE']._serialized_end=596
43
+ _globals['_PUSHMESSAGESREQUEST']._serialized_start=599
44
+ _globals['_PUSHMESSAGESREQUEST']._serialized_end=750
45
+ _globals['_PUSHMESSAGESRESPONSE']._serialized_start=753
46
+ _globals['_PUSHMESSAGESRESPONSE']._serialized_end=954
47
+ _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=908
48
+ _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=954
49
+ _globals['_RECONNECT']._serialized_start=956
50
+ _globals['_RECONNECT']._serialized_end=986
51
+ _globals['_FLEET']._serialized_start=989
52
+ _globals['_FLEET']._serialized_end=1831
53
53
  # @@protoc_insertion_point(module_scope)
flwr/proto/fleet_pb2.pyi CHANGED
@@ -16,13 +16,16 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
16
16
  class CreateNodeRequest(google.protobuf.message.Message):
17
17
  """CreateNode messages"""
18
18
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
19
+ PUBLIC_KEY_FIELD_NUMBER: builtins.int
19
20
  HEARTBEAT_INTERVAL_FIELD_NUMBER: builtins.int
21
+ public_key: builtins.bytes
20
22
  heartbeat_interval: builtins.float
21
23
  def __init__(self,
22
24
  *,
25
+ public_key: builtins.bytes = ...,
23
26
  heartbeat_interval: builtins.float = ...,
24
27
  ) -> None: ...
25
- def ClearField(self, field_name: typing_extensions.Literal["heartbeat_interval",b"heartbeat_interval"]) -> None: ...
28
+ def ClearField(self, field_name: typing_extensions.Literal["heartbeat_interval",b"heartbeat_interval","public_key",b"public_key"]) -> None: ...
26
29
  global___CreateNodeRequest = CreateNodeRequest
27
30
 
28
31
  class CreateNodeResponse(google.protobuf.message.Message):
flwr/proto/node_pb2.py CHANGED
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
14
14
 
15
15
 
16
16
 
17
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\xd0\x01\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x11\n\towner_aid\x18\x02 \x01(\t\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\ncreated_at\x18\x04 \x01(\t\x12\x19\n\x11last_activated_at\x18\x05 \x01(\t\x12\x1b\n\x13last_deactivated_at\x18\x06 \x01(\t\x12\x12\n\ndeleted_at\x18\x07 \x01(\t\x12\x14\n\x0conline_until\x18\x08 \x01(\x02\x12\x1a\n\x12heartbeat_interval\x18\t \x01(\x02\x62\x06proto3')
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\xd0\x01\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x11\n\towner_aid\x18\x02 \x01(\t\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\ncreated_at\x18\x04 \x01(\t\x12\x19\n\x11last_activated_at\x18\x05 \x01(\t\x12\x1b\n\x13last_deactivated_at\x18\x06 \x01(\t\x12\x12\n\ndeleted_at\x18\x07 \x01(\t\x12\x14\n\x0conline_until\x18\x08 \x01(\x01\x12\x1a\n\x12heartbeat_interval\x18\t \x01(\x01\x62\x06proto3')
18
18
 
19
19
  _globals = globals()
20
20
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)