flwr-nightly 1.23.0.dev20251006__py3-none-any.whl → 1.23.0.dev20251008__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. flwr/cli/auth_plugin/__init__.py +7 -3
  2. flwr/cli/log.py +2 -2
  3. flwr/cli/login/login.py +4 -13
  4. flwr/cli/ls.py +2 -2
  5. flwr/cli/pull.py +2 -2
  6. flwr/cli/run/run.py +2 -2
  7. flwr/cli/stop.py +2 -2
  8. flwr/cli/supernode/ls.py +2 -2
  9. flwr/cli/utils.py +28 -44
  10. flwr/client/grpc_rere_client/connection.py +6 -6
  11. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
  12. flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
  13. flwr/client/rest_client/connection.py +7 -1
  14. flwr/common/constant.py +10 -0
  15. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  16. flwr/proto/fleet_pb2.py +22 -22
  17. flwr/proto/fleet_pb2.pyi +4 -1
  18. flwr/proto/node_pb2.py +1 -1
  19. flwr/server/app.py +33 -34
  20. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +8 -4
  21. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +19 -41
  22. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
  23. flwr/server/superlink/fleet/vce/vce_api.py +7 -1
  24. flwr/server/superlink/linkstate/in_memory_linkstate.py +39 -27
  25. flwr/server/superlink/linkstate/linkstate.py +1 -1
  26. flwr/server/superlink/linkstate/sqlite_linkstate.py +37 -21
  27. flwr/server/utils/validator.py +2 -3
  28. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
  29. flwr/supercore/primitives/__init__.py +15 -0
  30. flwr/supercore/primitives/asymmetric.py +109 -0
  31. flwr/superlink/auth_plugin/__init__.py +29 -0
  32. flwr/superlink/servicer/control/control_grpc.py +9 -7
  33. flwr/superlink/servicer/control/control_servicer.py +34 -46
  34. {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/METADATA +1 -1
  35. {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/RECORD +37 -35
  36. {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/WHEEL +0 -0
  37. {flwr_nightly-1.23.0.dev20251006.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/entry_points.txt +0 -0
flwr/server/app.py CHANGED
@@ -26,7 +26,7 @@ from collections.abc import Sequence
26
26
  from logging import DEBUG, INFO, WARN
27
27
  from pathlib import Path
28
28
  from time import sleep
29
- from typing import Any, Callable, Optional, TypeVar
29
+ from typing import Callable, Optional, TypeVar, cast
30
30
 
31
31
  import grpc
32
32
  import yaml
@@ -52,6 +52,8 @@ from flwr.common.constant import (
52
52
  TRANSPORT_TYPE_GRPC_ADAPTER,
53
53
  TRANSPORT_TYPE_GRPC_RERE,
54
54
  TRANSPORT_TYPE_REST,
55
+ AuthnType,
56
+ AuthzType,
55
57
  EventLogWriterType,
56
58
  ExecPluginType,
57
59
  )
@@ -59,9 +61,6 @@ from flwr.common.event_log_plugin import EventLogWriterPlugin
59
61
  from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
60
62
  from flwr.common.grpc import generic_create_grpc_server
61
63
  from flwr.common.logger import log
62
- from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
63
- public_key_to_bytes,
64
- )
65
64
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
66
65
  add_FleetServicer_to_server,
67
66
  )
@@ -70,13 +69,21 @@ from flwr.server.fleet_event_log_interceptor import FleetEventLogInterceptor
70
69
  from flwr.supercore.ffs import FfsFactory
71
70
  from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_no_tls
72
71
  from flwr.supercore.object_store import ObjectStoreFactory
72
+ from flwr.supercore.primitives.asymmetric import public_key_to_bytes
73
73
  from flwr.superlink.artifact_provider import ArtifactProvider
74
- from flwr.superlink.auth_plugin import ControlAuthnPlugin, ControlAuthzPlugin
74
+ from flwr.superlink.auth_plugin import (
75
+ ControlAuthnPlugin,
76
+ ControlAuthzPlugin,
77
+ get_control_authn_plugins,
78
+ get_control_authz_plugins,
79
+ )
75
80
  from flwr.superlink.servicer.control import run_control_api_grpc
76
81
 
77
82
  from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
78
83
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
79
- from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
84
+ from .superlink.fleet.grpc_rere.node_auth_server_interceptor import (
85
+ NodeAuthServerInterceptor,
86
+ )
80
87
  from .superlink.linkstate import LinkStateFactory
81
88
  from .superlink.serverappio.serverappio_grpc import run_serverappio_api_grpc
82
89
  from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
@@ -89,8 +96,6 @@ P = TypeVar("P", ControlAuthnPlugin, ControlAuthzPlugin)
89
96
  try:
90
97
  from flwr.ee import (
91
98
  add_ee_args_superlink,
92
- get_control_authn_plugins,
93
- get_control_authz_plugins,
94
99
  get_control_event_log_writer_plugins,
95
100
  get_ee_artifact_provider,
96
101
  get_fleet_event_log_writer_plugins,
@@ -101,14 +106,6 @@ except ImportError:
101
106
  def add_ee_args_superlink(parser: argparse.ArgumentParser) -> None:
102
107
  """Add EE-specific arguments to the parser."""
103
108
 
104
- def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
105
- """Return all Control API authentication plugins."""
106
- raise NotImplementedError("No authentication plugins are currently supported.")
107
-
108
- def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
109
- """Return all Control API authorization plugins."""
110
- raise NotImplementedError("No authorization plugins are currently supported.")
111
-
112
109
  def get_control_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
113
110
  """Return all Control API event log writer plugins."""
114
111
  raise NotImplementedError(
@@ -204,10 +201,9 @@ def run_superlink() -> None:
204
201
  "future release. Please use `--account-auth-config` instead.",
205
202
  )
206
203
  args.account_auth_config = cfg_path
207
- if cfg_path := getattr(args, "account_auth_config", None):
208
- authn_plugin, authz_plugin = _try_obtain_control_auth_plugins(
209
- Path(cfg_path), verify_tls_cert
210
- )
204
+ cfg_path = getattr(args, "account_auth_config", None)
205
+ authn_plugin, authz_plugin = _load_control_auth_plugins(cfg_path, verify_tls_cert)
206
+ if cfg_path is not None:
211
207
  # Enable event logging if the args.enable_event_log is True
212
208
  if args.enable_event_log:
213
209
  event_log_plugin = _try_obtain_control_event_log_writer_plugin()
@@ -328,7 +324,7 @@ def run_superlink() -> None:
328
324
  else:
329
325
  log(DEBUG, "Automatic node authentication enabled")
330
326
 
331
- interceptors = [AuthenticateServerInterceptor(state_factory, auto_auth)]
327
+ interceptors = [NodeAuthServerInterceptor(state_factory, auto_auth)]
332
328
  if getattr(args, "enable_event_log", None):
333
329
  fleet_log_plugin = _try_obtain_fleet_event_log_writer_plugin()
334
330
  if fleet_log_plugin is not None:
@@ -449,13 +445,21 @@ def _try_load_public_keys_node_authentication(
449
445
  return node_public_keys
450
446
 
451
447
 
452
- def _try_obtain_control_auth_plugins(
453
- config_path: Path, verify_tls_cert: bool
448
+ def _load_control_auth_plugins(
449
+ config_path: Optional[str], verify_tls_cert: bool
454
450
  ) -> tuple[ControlAuthnPlugin, ControlAuthzPlugin]:
455
451
  """Obtain Control API authentication and authorization plugins."""
452
+ # Load NoOp plugins if no config path is provided
453
+ if config_path is None:
454
+ config_path = ""
455
+ config = {
456
+ "authentication": {AUTHN_TYPE_YAML_KEY: AuthnType.NOOP},
457
+ "authorization": {AUTHZ_TYPE_YAML_KEY: AuthzType.NOOP},
458
+ }
456
459
  # Load YAML file
457
- with config_path.open("r", encoding="utf-8") as file:
458
- config: dict[str, Any] = yaml.safe_load(file)
460
+ else:
461
+ with Path(config_path).open("r", encoding="utf-8") as file:
462
+ config = yaml.safe_load(file)
459
463
 
460
464
  def _load_plugin(
461
465
  section: str, yaml_key: str, loader: Callable[[], dict[str, type[P]]]
@@ -465,9 +469,7 @@ def _try_obtain_control_auth_plugins(
465
469
  try:
466
470
  plugins: dict[str, type[P]] = loader()
467
471
  plugin_cls: type[P] = plugins[auth_plugin_name]
468
- return plugin_cls(
469
- account_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
470
- )
472
+ return plugin_cls(Path(cast(str, config_path)), verify_tls_cert)
471
473
  except KeyError:
472
474
  if auth_plugin_name:
473
475
  sys.exit(
@@ -475,18 +477,15 @@ def _try_obtain_control_auth_plugins(
475
477
  f"Please provide a valid {section} type in the configuration."
476
478
  )
477
479
  sys.exit(f"No {section} type is provided in the configuration.")
478
- except NotImplementedError:
479
- sys.exit(f"No {section} plugins are currently supported.")
480
480
 
481
- # Warn deprecated authn_type key
482
- if "authn_type" in config["authentication"]:
481
+ # Warn deprecated auth_type key
482
+ if authn_type := config["authentication"].pop("auth_type", None):
483
483
  log(
484
484
  WARN,
485
- "The `authn_type` key in the authentication configuration is deprecated. "
485
+ "The `auth_type` key in the authentication configuration is deprecated. "
486
486
  "Use `%s` instead.",
487
487
  AUTHN_TYPE_YAML_KEY,
488
488
  )
489
- authn_type = config["authentication"].pop("authn_type")
490
489
  config["authentication"][AUTHN_TYPE_YAML_KEY] = authn_type
491
490
 
492
491
  # Load authentication plugin
@@ -78,10 +78,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
78
78
  request.heartbeat_interval,
79
79
  )
80
80
  log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
81
- response = message_handler.create_node(
82
- request=request,
83
- state=self.state_factory.state(),
84
- )
81
+ try:
82
+ response = message_handler.create_node(
83
+ request=request,
84
+ state=self.state_factory.state(),
85
+ )
86
+ except ValueError as e:
87
+ # Public key already in use
88
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
85
89
  log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
86
90
  log(DEBUG, "[Fleet.CreateNode] Response: %s", MessageToDict(response))
87
91
  return response
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import datetime
19
- from typing import Any, Callable, Optional, cast
19
+ from typing import Any, Callable, cast
20
20
 
21
21
  import grpc
22
22
  from google.protobuf.message import Message as GrpcMessage
@@ -29,15 +29,9 @@ from flwr.common.constant import (
29
29
  TIMESTAMP_HEADER,
30
30
  TIMESTAMP_TOLERANCE,
31
31
  )
32
- from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
33
- bytes_to_public_key,
34
- verify_signature,
35
- )
36
- from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
37
- CreateNodeRequest,
38
- CreateNodeResponse,
39
- )
32
+ from flwr.proto.fleet_pb2 import CreateNodeRequest # pylint: disable=E0611
40
33
  from flwr.server.superlink.linkstate import LinkStateFactory
34
+ from flwr.supercore.primitives.asymmetric import bytes_to_public_key, verify_signature
41
35
 
42
36
  MIN_TIMESTAMP_DIFF = -SYSTEM_TIME_TOLERANCE
43
37
  MAX_TIMESTAMP_DIFF = TIMESTAMP_TOLERANCE + SYSTEM_TIME_TOLERANCE
@@ -53,7 +47,7 @@ def _unary_unary_rpc_terminator(
53
47
  return grpc.unary_unary_rpc_method_handler(terminate)
54
48
 
55
49
 
56
- class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
50
+ class NodeAuthServerInterceptor(grpc.ServerInterceptor): # type: ignore
57
51
  """Server interceptor for node authentication.
58
52
 
59
53
  Parameters
@@ -113,50 +107,34 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
113
107
  if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
114
108
  return _unary_unary_rpc_terminator("Invalid timestamp")
115
109
 
116
- # Continue the RPC call
117
- expected_node_id = state.get_node_id(node_pk_bytes)
118
- if not handler_call_details.method.endswith("CreateNode"):
119
- # All calls, except for `CreateNode`, must provide a public key that is
120
- # already mapped to a `node_id` (in `LinkState`)
121
- if expected_node_id is None:
122
- return _unary_unary_rpc_terminator("Invalid node ID")
123
- # One of the method handlers in
110
+ # Continue the RPC call: One of the method handlers in
124
111
  # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
125
112
  method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
126
- return self._wrap_method_handler(
127
- method_handler, expected_node_id, node_pk_bytes
128
- )
113
+ return self._wrap_method_handler(method_handler, node_pk_bytes)
129
114
 
130
115
  def _wrap_method_handler(
131
116
  self,
132
117
  method_handler: grpc.RpcMethodHandler,
133
- expected_node_id: Optional[int],
134
- node_public_key: bytes,
118
+ expected_public_key: bytes,
135
119
  ) -> grpc.RpcMethodHandler:
136
120
  def _generic_method_handler(
137
121
  request: GrpcMessage,
138
122
  context: grpc.ServicerContext,
139
123
  ) -> GrpcMessage:
140
- # Verify the node ID
141
- if not isinstance(request, CreateNodeRequest):
142
- try:
143
- if request.node.node_id != expected_node_id: # type: ignore
144
- raise ValueError
145
- except (AttributeError, ValueError):
146
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
124
+ # Retrieve the public key
125
+ if isinstance(request, CreateNodeRequest):
126
+ actual_public_key = request.public_key
127
+ else:
128
+ # Note: This function runs in a different thread
129
+ # than the `intercept_service` function.
130
+ actual_public_key = self.state_factory.state().get_node_public_key(
131
+ request.node.node_id # type: ignore
132
+ )
133
+ # Verify the public key
134
+ if actual_public_key != expected_public_key:
135
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
147
136
 
148
137
  response: GrpcMessage = method_handler.unary_unary(request, context)
149
-
150
- # Set the public key after a successful CreateNode request
151
- if isinstance(response, CreateNodeResponse):
152
- state = self.state_factory.state()
153
- try:
154
- state.set_node_public_key(response.node.node_id, node_public_key)
155
- except ValueError as e:
156
- # Remove newly created node if setting the public key fails
157
- state.delete_node(response.node.node_id)
158
- context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
159
-
160
138
  return response
161
139
 
162
140
  return grpc.unary_unary_rpc_method_handler(
@@ -70,7 +70,7 @@ def create_node(
70
70
  ) -> CreateNodeResponse:
71
71
  """."""
72
72
  # Create node
73
- node_id = state.create_node(heartbeat_interval=request.heartbeat_interval)
73
+ node_id = state.create_node(request.public_key, request.heartbeat_interval)
74
74
  return CreateNodeResponse(node=Node(node_id=node_id))
75
75
 
76
76
 
@@ -16,6 +16,7 @@
16
16
 
17
17
 
18
18
  import json
19
+ import secrets
19
20
  import threading
20
21
  import time
21
22
  import traceback
@@ -53,7 +54,12 @@ def _register_nodes(
53
54
  nodes_mapping: NodeToPartitionMapping = {}
54
55
  state = state_factory.state()
55
56
  for i in range(num_nodes):
56
- node_id = state.create_node(heartbeat_interval=HEARTBEAT_MAX_INTERVAL)
57
+ node_id = state.create_node(
58
+ # No node authentication in simulation;
59
+ # use random bytes instead
60
+ secrets.token_bytes(32),
61
+ heartbeat_interval=HEARTBEAT_MAX_INTERVAL,
62
+ )
57
63
  nodes_mapping[node_id] = i
58
64
  log(DEBUG, "Registered %i nodes", len(nodes_mapping))
59
65
  return nodes_mapping
@@ -17,7 +17,6 @@
17
17
 
18
18
  import secrets
19
19
  import threading
20
- import time
21
20
  from bisect import bisect_right
22
21
  from collections import defaultdict
23
22
  from dataclasses import dataclass, field
@@ -39,6 +38,7 @@ from flwr.common.constant import (
39
38
  )
40
39
  from flwr.common.record import ConfigRecord
41
40
  from flwr.common.typing import Run, RunStatus, UserConfig
41
+ from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
42
42
  from flwr.server.superlink.linkstate.linkstate import LinkState
43
43
  from flwr.server.utils import validate_message
44
44
 
@@ -70,7 +70,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
70
70
  def __init__(self) -> None:
71
71
 
72
72
  # Map node_id to (online_until, heartbeat_interval)
73
- self.node_ids: dict[int, tuple[float, float]] = {}
73
+ self.nodes: dict[int, NodeInfo] = {}
74
74
  self.public_key_to_node_id: dict[bytes, int] = {}
75
75
  self.node_id_to_public_key: dict[int, bytes] = {}
76
76
 
@@ -114,7 +114,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
114
114
  )
115
115
  return None
116
116
  # Validate destination node ID
117
- if message.metadata.dst_node_id not in self.node_ids:
117
+ if message.metadata.dst_node_id not in self.nodes:
118
118
  log(
119
119
  ERROR,
120
120
  "Invalid destination node ID for Message: %s",
@@ -136,7 +136,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
136
136
 
137
137
  # Find Message for node_id that were not delivered yet
138
138
  message_ins_list: list[Message] = []
139
- current_time = time.time()
139
+ current_time = now().timestamp()
140
140
  with self.lock:
141
141
  for _, msg_ins in self.message_ins_store.items():
142
142
  if (
@@ -190,7 +190,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
190
190
  return None
191
191
 
192
192
  ins_metadata = msg_ins.metadata
193
- if ins_metadata.created_at + ins_metadata.ttl <= time.time():
193
+ if ins_metadata.created_at + ins_metadata.ttl <= now().timestamp():
194
194
  log(
195
195
  ERROR,
196
196
  "Failed to store Message: the message it is replying to "
@@ -238,7 +238,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
238
238
  ret: dict[str, Message] = {}
239
239
 
240
240
  with self.lock:
241
- current = time.time()
241
+ current = now().timestamp()
242
242
 
243
243
  # Verify Message IDs
244
244
  ret = verify_message_ids(
@@ -256,9 +256,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
256
256
  inquired_in_message_ids=message_ids,
257
257
  found_in_message_dict=self.message_ins_store,
258
258
  node_id_to_online_until={
259
- node_id: self.node_ids[node_id][0]
259
+ node_id: self.nodes[node_id].online_until
260
260
  for node_id in dst_node_ids
261
- if node_id in self.node_ids
261
+ if node_id in self.nodes
262
262
  },
263
263
  current_time=current,
264
264
  )
@@ -330,7 +330,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
330
330
  """
331
331
  return len(self.message_res_store)
332
332
 
333
- def create_node(self, heartbeat_interval: float) -> int:
333
+ def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
334
334
  """Create, store in the link state, and return `node_id`."""
335
335
  # Sample a random int64 as node_id
336
336
  node_id = generate_rand_int_from_bytes(
@@ -338,28 +338,40 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
338
338
  )
339
339
 
340
340
  with self.lock:
341
- if node_id in self.node_ids:
341
+ if node_id in self.nodes:
342
342
  log(ERROR, "Unexpected node registration failure.")
343
343
  return 0
344
+ if public_key in self.public_key_to_node_id:
345
+ raise ValueError("Public key already in use")
344
346
 
345
- # Mark the node online until time.time() + heartbeat_interval
346
- self.node_ids[node_id] = (
347
- time.time() + heartbeat_interval,
348
- heartbeat_interval,
347
+ # Mark the node online until now().timestamp() + heartbeat_interval
348
+ current = now()
349
+ self.nodes[node_id] = NodeInfo(
350
+ node_id=node_id,
351
+ owner_aid="", # Unused for now
352
+ status="created", # Unused for now
353
+ created_at=current.isoformat(), # Unused for now
354
+ last_activated_at=current.isoformat(), # Unused for now
355
+ last_deactivated_at="", # Unused for now
356
+ deleted_at="", # Unused for now
357
+ online_until=current.timestamp() + heartbeat_interval,
358
+ heartbeat_interval=heartbeat_interval,
349
359
  )
360
+ self.public_key_to_node_id[public_key] = node_id
361
+ self.node_id_to_public_key[node_id] = public_key
350
362
  return node_id
351
363
 
352
364
  def delete_node(self, node_id: int) -> None:
353
365
  """Delete a node."""
354
366
  with self.lock:
355
- if node_id not in self.node_ids:
367
+ if node_id not in self.nodes:
356
368
  raise ValueError(f"Node {node_id} not found")
357
369
 
358
370
  # Remove node ID <> public key mappings
359
371
  if pk := self.node_id_to_public_key.pop(node_id, None):
360
372
  del self.public_key_to_node_id[pk]
361
373
 
362
- del self.node_ids[node_id]
374
+ del self.nodes[node_id]
363
375
 
364
376
  def get_nodes(self, run_id: int) -> set[int]:
365
377
  """Return all available nodes.
@@ -372,17 +384,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
372
384
  with self.lock:
373
385
  if run_id not in self.run_ids:
374
386
  return set()
375
- current_time = time.time()
387
+ current_time = now().timestamp()
376
388
  return {
377
- node_id
378
- for node_id, (online_until, _) in self.node_ids.items()
379
- if online_until > current_time
389
+ info.node_id
390
+ for info in self.nodes.values()
391
+ if info.online_until > current_time
380
392
  }
381
393
 
382
394
  def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
383
395
  """Set `public_key` for the specified `node_id`."""
384
396
  with self.lock:
385
- if node_id not in self.node_ids:
397
+ if node_id not in self.nodes:
386
398
  raise ValueError(f"Node {node_id} not found")
387
399
 
388
400
  if public_key in self.public_key_to_node_id:
@@ -394,7 +406,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
394
406
  def get_node_public_key(self, node_id: int) -> Optional[bytes]:
395
407
  """Get `public_key` for the specified `node_id`."""
396
408
  with self.lock:
397
- if node_id not in self.node_ids:
409
+ if node_id not in self.nodes:
398
410
  raise ValueError(f"Node {node_id} not found")
399
411
 
400
412
  return self.node_id_to_public_key.get(node_id)
@@ -608,13 +620,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
608
620
  the node is marked as offline.
609
621
  """
610
622
  with self.lock:
611
- if node_id in self.node_ids:
612
- self.node_ids[node_id] = (
613
- time.time() + HEARTBEAT_PATIENCE * heartbeat_interval,
614
- heartbeat_interval,
623
+ if info := self.nodes.get(node_id):
624
+ info.online_until = (
625
+ now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
615
626
  )
627
+ info.heartbeat_interval = heartbeat_interval
616
628
  return True
617
- return False
629
+ return False
618
630
 
619
631
  def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
620
632
  """Acknowledge a heartbeat received from a ServerApp for a given run.
@@ -128,7 +128,7 @@ class LinkState(CoreState): # pylint: disable=R0904
128
128
  """Get all instruction Message IDs for the given run_id."""
129
129
 
130
130
  @abc.abstractmethod
131
- def create_node(self, heartbeat_interval: float) -> int:
131
+ def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
132
132
  """Create, store in the link state, and return `node_id`."""
133
133
 
134
134
  @abc.abstractmethod
@@ -21,7 +21,6 @@ import json
21
21
  import re
22
22
  import secrets
23
23
  import sqlite3
24
- import time
25
24
  from collections.abc import Sequence
26
25
  from logging import DEBUG, ERROR, WARNING
27
26
  from typing import Any, Optional, Union, cast
@@ -72,10 +71,16 @@ from .utils import (
72
71
 
73
72
  SQL_CREATE_TABLE_NODE = """
74
73
  CREATE TABLE IF NOT EXISTS node(
75
- node_id INTEGER UNIQUE,
76
- online_until REAL,
77
- heartbeat_interval REAL,
78
- public_key BLOB
74
+ node_id INTEGER UNIQUE,
75
+ owner_aid TEXT,
76
+ status TEXT,
77
+ created_at TEXT,
78
+ last_activated_at TEXT,
79
+ last_deactivated_at TEXT,
80
+ deleted_at TEXT,
81
+ online_until REAL,
82
+ heartbeat_interval REAL,
83
+ public_key BLOB UNIQUE
79
84
  );
80
85
  """
81
86
 
@@ -451,7 +456,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
451
456
  ret: dict[str, Message] = {}
452
457
 
453
458
  # Verify Message IDs
454
- current = time.time()
459
+ current = now().timestamp()
455
460
  query = f"""
456
461
  SELECT *
457
462
  FROM message_ins
@@ -597,7 +602,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
597
602
 
598
603
  return {row["message_id"] for row in rows}
599
604
 
600
- def create_node(self, heartbeat_interval: float) -> int:
605
+ def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
601
606
  """Create, store in the link state, and return `node_id`."""
602
607
  # Sample a random uint64 as node_id
603
608
  uint64_node_id = generate_rand_int_from_bytes(
@@ -607,24 +612,35 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
607
612
  # Convert the uint64 value to sint64 for SQLite
608
613
  sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
609
614
 
610
- query = (
611
- "INSERT INTO node "
612
- "(node_id, online_until, heartbeat_interval, public_key) "
613
- "VALUES (?, ?, ?, ?)"
614
- )
615
+ query = """
616
+ INSERT INTO node
617
+ (node_id, owner_aid, status, created_at, last_activated_at,
618
+ last_deactivated_at, deleted_at, online_until, heartbeat_interval,
619
+ public_key)
620
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
621
+ """
615
622
 
616
- # Mark the node online util time.time() + heartbeat_interval
623
+ # Mark the node online until now().timestamp() + heartbeat_interval
617
624
  try:
618
625
  self.query(
619
626
  query,
620
627
  (
621
- sint64_node_id,
622
- time.time() + heartbeat_interval,
623
- heartbeat_interval,
624
- b"", # Initialize with an empty public key
628
+ sint64_node_id, # node_id
629
+ "", # owner_aid, unused for now
630
+ "created", # status, unused for now
631
+ now().isoformat(), # created_at, unused for now
632
+ now().isoformat(), # last_activated_at, unused for now
633
+ "", # last_deactivated_at, unused for now
634
+ "", # deleted_at, unused for now
635
+ now().timestamp() + heartbeat_interval, # online_until
636
+ heartbeat_interval, # heartbeat_interval
637
+ public_key, # public_key
625
638
  ),
626
639
  )
627
- except sqlite3.IntegrityError:
640
+ except sqlite3.IntegrityError as e:
641
+ if "UNIQUE constraint failed: node.public_key" in str(e):
642
+ raise ValueError("Public key already in use.") from None
643
+ # Must be node ID conflict, almost impossible unless system is compromised
628
644
  log(ERROR, "Unexpected node registration failure.")
629
645
  return 0
630
646
 
@@ -668,7 +684,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
668
684
 
669
685
  # Get nodes
670
686
  query = "SELECT node_id FROM node WHERE online_until > ?;"
671
- rows = self.query(query, (time.time(),))
687
+ rows = self.query(query, (now().timestamp(),))
672
688
 
673
689
  # Convert sint64 node_ids to uint64
674
690
  result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
@@ -1010,7 +1026,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1010
1026
  self.query(
1011
1027
  query,
1012
1028
  (
1013
- time.time() + HEARTBEAT_PATIENCE * heartbeat_interval,
1029
+ now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
1014
1030
  heartbeat_interval,
1015
1031
  sint64_node_id,
1016
1032
  ),
@@ -1140,7 +1156,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1140
1156
  message_ins = rows[0]
1141
1157
  created_at = message_ins["created_at"]
1142
1158
  ttl = message_ins["ttl"]
1143
- current_time = time.time()
1159
+ current_time = now().timestamp()
1144
1160
 
1145
1161
  # Check if Message is expired
1146
1162
  if ttl is not None and created_at + ttl <= current_time:
@@ -15,10 +15,9 @@
15
15
  """Validators."""
16
16
 
17
17
 
18
- import time
19
-
20
18
  from flwr.common import Message
21
19
  from flwr.common.constant import SUPERLINK_NODE_ID
20
+ from flwr.common.date import now
22
21
 
23
22
 
24
23
  # pylint: disable-next=too-many-branches
@@ -44,7 +43,7 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
44
43
  validation_errors.append("`metadata.ttl` must be higher than zero")
45
44
 
46
45
  # Verify TTL and created_at time
47
- current_time = time.time()
46
+ current_time = now().timestamp()
48
47
  if metadata.created_at + metadata.ttl <= current_time:
49
48
  validation_errors.append("Message TTL has expired")
50
49
 
@@ -35,8 +35,6 @@ from flwr.common import (
35
35
  )
36
36
  from flwr.common.secure_aggregation.crypto.shamir import combine_shares
37
37
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
38
- bytes_to_private_key,
39
- bytes_to_public_key,
40
38
  generate_shared_key,
41
39
  )
42
40
  from flwr.common.secure_aggregation.ndarrays_arithmetic import (
@@ -56,6 +54,10 @@ from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
56
54
  from flwr.server.client_proxy import ClientProxy
57
55
  from flwr.server.compat.legacy_context import LegacyContext
58
56
  from flwr.server.grid import Grid
57
+ from flwr.supercore.primitives.asymmetric import (
58
+ bytes_to_private_key,
59
+ bytes_to_public_key,
60
+ )
59
61
 
60
62
  from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
61
63
  from ..constant import Key as WorkflowKey