flwr-nightly 1.23.0.dev20251007__py3-none-any.whl → 1.23.0.dev20251009__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. flwr/cli/auth_plugin/__init__.py +7 -3
  2. flwr/cli/log.py +2 -2
  3. flwr/cli/login/login.py +4 -13
  4. flwr/cli/ls.py +2 -2
  5. flwr/cli/pull.py +2 -2
  6. flwr/cli/run/run.py +2 -2
  7. flwr/cli/stop.py +2 -2
  8. flwr/cli/supernode/create.py +137 -11
  9. flwr/cli/supernode/delete.py +88 -10
  10. flwr/cli/supernode/ls.py +2 -2
  11. flwr/cli/utils.py +65 -55
  12. flwr/client/grpc_rere_client/connection.py +6 -4
  13. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +2 -2
  14. flwr/client/rest_client/connection.py +7 -1
  15. flwr/common/constant.py +13 -0
  16. flwr/proto/control_pb2.py +1 -1
  17. flwr/proto/control_pb2.pyi +2 -2
  18. flwr/proto/fleet_pb2.py +22 -22
  19. flwr/proto/fleet_pb2.pyi +4 -1
  20. flwr/proto/node_pb2.py +2 -2
  21. flwr/proto/node_pb2.pyi +4 -1
  22. flwr/server/app.py +32 -31
  23. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +8 -4
  24. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +18 -37
  25. flwr/server/superlink/fleet/message_handler/message_handler.py +5 -3
  26. flwr/server/superlink/fleet/vce/vce_api.py +10 -1
  27. flwr/server/superlink/linkstate/in_memory_linkstate.py +52 -54
  28. flwr/server/superlink/linkstate/linkstate.py +20 -10
  29. flwr/server/superlink/linkstate/sqlite_linkstate.py +54 -61
  30. flwr/server/utils/validator.py +2 -3
  31. flwr/supercore/primitives/asymmetric.py +8 -0
  32. flwr/superlink/auth_plugin/__init__.py +29 -0
  33. flwr/superlink/servicer/control/control_grpc.py +9 -7
  34. flwr/superlink/servicer/control/control_servicer.py +89 -48
  35. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/RECORD +38 -38
  37. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/WHEEL +0 -0
  38. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/entry_points.txt +0 -0
@@ -17,7 +17,6 @@
17
17
 
18
18
  import secrets
19
19
  import threading
20
- import time
21
20
  from bisect import bisect_right
22
21
  from collections import defaultdict
23
22
  from dataclasses import dataclass, field
@@ -39,6 +38,7 @@ from flwr.common.constant import (
39
38
  )
40
39
  from flwr.common.record import ConfigRecord
41
40
  from flwr.common.typing import Run, RunStatus, UserConfig
41
+ from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
42
42
  from flwr.server.superlink.linkstate.linkstate import LinkState
43
43
  from flwr.server.utils import validate_message
44
44
 
@@ -69,10 +69,10 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
69
69
 
70
70
  def __init__(self) -> None:
71
71
 
72
- # Map node_id to (online_until, heartbeat_interval)
73
- self.node_ids: dict[int, tuple[float, float]] = {}
74
- self.public_key_to_node_id: dict[bytes, int] = {}
75
- self.node_id_to_public_key: dict[int, bytes] = {}
72
+ # Map node_id to NodeInfo
73
+ self.nodes: dict[int, NodeInfo] = {}
74
+ self.registered_node_public_keys: set[bytes] = set()
75
+ self.owner_to_node_ids: dict[str, set[int]] = {} # Quick lookup
76
76
 
77
77
  # Map run_id to RunRecord
78
78
  self.run_ids: dict[int, RunRecord] = {}
@@ -114,7 +114,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
114
114
  )
115
115
  return None
116
116
  # Validate destination node ID
117
- if message.metadata.dst_node_id not in self.node_ids:
117
+ if message.metadata.dst_node_id not in self.nodes:
118
118
  log(
119
119
  ERROR,
120
120
  "Invalid destination node ID for Message: %s",
@@ -136,7 +136,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
136
136
 
137
137
  # Find Message for node_id that were not delivered yet
138
138
  message_ins_list: list[Message] = []
139
- current_time = time.time()
139
+ current_time = now().timestamp()
140
140
  with self.lock:
141
141
  for _, msg_ins in self.message_ins_store.items():
142
142
  if (
@@ -190,7 +190,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
190
190
  return None
191
191
 
192
192
  ins_metadata = msg_ins.metadata
193
- if ins_metadata.created_at + ins_metadata.ttl <= time.time():
193
+ if ins_metadata.created_at + ins_metadata.ttl <= now().timestamp():
194
194
  log(
195
195
  ERROR,
196
196
  "Failed to store Message: the message it is replying to "
@@ -238,7 +238,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
238
238
  ret: dict[str, Message] = {}
239
239
 
240
240
  with self.lock:
241
- current = time.time()
241
+ current = now().timestamp()
242
242
 
243
243
  # Verify Message IDs
244
244
  ret = verify_message_ids(
@@ -256,9 +256,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
256
256
  inquired_in_message_ids=message_ids,
257
257
  found_in_message_dict=self.message_ins_store,
258
258
  node_id_to_online_until={
259
- node_id: self.node_ids[node_id][0]
259
+ node_id: self.nodes[node_id].online_until
260
260
  for node_id in dst_node_ids
261
- if node_id in self.node_ids
261
+ if node_id in self.nodes
262
262
  },
263
263
  current_time=current,
264
264
  )
@@ -330,7 +330,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
330
330
  """
331
331
  return len(self.message_res_store)
332
332
 
333
- def create_node(self, heartbeat_interval: float) -> int:
333
+ def create_node(
334
+ self, owner_aid: str, public_key: bytes, heartbeat_interval: float
335
+ ) -> int:
334
336
  """Create, store in the link state, and return `node_id`."""
335
337
  # Sample a random int64 as node_id
336
338
  node_id = generate_rand_int_from_bytes(
@@ -338,28 +340,40 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
338
340
  )
339
341
 
340
342
  with self.lock:
341
- if node_id in self.node_ids:
343
+ if node_id in self.nodes:
342
344
  log(ERROR, "Unexpected node registration failure.")
343
345
  return 0
346
+ if public_key in self.registered_node_public_keys:
347
+ raise ValueError("Public key already in use")
344
348
 
345
- # Mark the node online until time.time() + heartbeat_interval
346
- self.node_ids[node_id] = (
347
- time.time() + heartbeat_interval,
348
- heartbeat_interval,
349
+ # Mark the node online until now().timestamp() + heartbeat_interval
350
+ current = now()
351
+ self.nodes[node_id] = NodeInfo(
352
+ node_id=node_id,
353
+ owner_aid=owner_aid, # Unused for now
354
+ status="created", # Unused for now
355
+ created_at=current.isoformat(), # Unused for now
356
+ last_activated_at=current.isoformat(), # Unused for now
357
+ last_deactivated_at="", # Unused for now
358
+ deleted_at="", # Unused for now
359
+ online_until=current.timestamp() + heartbeat_interval,
360
+ heartbeat_interval=heartbeat_interval,
361
+ public_key=public_key,
349
362
  )
363
+ self.registered_node_public_keys.add(public_key)
364
+ self.owner_to_node_ids.setdefault(owner_aid, set()).add(node_id)
350
365
  return node_id
351
366
 
352
- def delete_node(self, node_id: int) -> None:
367
+ def delete_node(self, owner_aid: str, node_id: int) -> None:
353
368
  """Delete a node."""
354
369
  with self.lock:
355
- if node_id not in self.node_ids:
356
- raise ValueError(f"Node {node_id} not found")
357
-
358
- # Remove node ID <> public key mappings
359
- if pk := self.node_id_to_public_key.pop(node_id, None):
360
- del self.public_key_to_node_id[pk]
370
+ if node_id not in self.nodes or owner_aid != self.nodes[node_id].owner_aid:
371
+ raise ValueError(
372
+ f"Node ID {node_id} not found or unauthorized deletion attempt."
373
+ )
361
374
 
362
- del self.node_ids[node_id]
375
+ node = self.nodes.pop(node_id)
376
+ self.registered_node_public_keys.discard(node.public_key)
363
377
 
364
378
  def get_nodes(self, run_id: int) -> set[int]:
365
379
  """Return all available nodes.
@@ -372,36 +386,20 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
372
386
  with self.lock:
373
387
  if run_id not in self.run_ids:
374
388
  return set()
375
- current_time = time.time()
389
+ current_time = now().timestamp()
376
390
  return {
377
- node_id
378
- for node_id, (online_until, _) in self.node_ids.items()
379
- if online_until > current_time
391
+ info.node_id
392
+ for info in self.nodes.values()
393
+ if info.online_until > current_time
380
394
  }
381
395
 
382
- def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
383
- """Set `public_key` for the specified `node_id`."""
384
- with self.lock:
385
- if node_id not in self.node_ids:
386
- raise ValueError(f"Node {node_id} not found")
387
-
388
- if public_key in self.public_key_to_node_id:
389
- raise ValueError("Public key already in use")
390
-
391
- self.public_key_to_node_id[public_key] = node_id
392
- self.node_id_to_public_key[node_id] = public_key
393
-
394
- def get_node_public_key(self, node_id: int) -> Optional[bytes]:
396
+ def get_node_public_key(self, node_id: int) -> bytes:
395
397
  """Get `public_key` for the specified `node_id`."""
396
398
  with self.lock:
397
- if node_id not in self.node_ids:
398
- raise ValueError(f"Node {node_id} not found")
399
-
400
- return self.node_id_to_public_key.get(node_id)
399
+ if (node := self.nodes.get(node_id)) is None:
400
+ raise ValueError(f"Node ID {node_id} not found")
401
401
 
402
- def get_node_id(self, node_public_key: bytes) -> Optional[int]:
403
- """Retrieve stored `node_id` filtered by `node_public_keys`."""
404
- return self.public_key_to_node_id.get(node_public_key)
402
+ return node.public_key
405
403
 
406
404
  # pylint: disable=too-many-arguments,too-many-positional-arguments
407
405
  def create_run(
@@ -608,13 +606,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
608
606
  the node is marked as offline.
609
607
  """
610
608
  with self.lock:
611
- if node_id in self.node_ids:
612
- self.node_ids[node_id] = (
613
- time.time() + HEARTBEAT_PATIENCE * heartbeat_interval,
614
- heartbeat_interval,
609
+ if info := self.nodes.get(node_id):
610
+ info.online_until = (
611
+ now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
615
612
  )
613
+ info.heartbeat_interval = heartbeat_interval
616
614
  return True
617
- return False
615
+ return False
618
616
 
619
617
  def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
620
618
  """Acknowledge a heartbeat received from a ServerApp for a given run.
@@ -128,11 +128,13 @@ class LinkState(CoreState): # pylint: disable=R0904
128
128
  """Get all instruction Message IDs for the given run_id."""
129
129
 
130
130
  @abc.abstractmethod
131
- def create_node(self, heartbeat_interval: float) -> int:
131
+ def create_node(
132
+ self, owner_aid: str, public_key: bytes, heartbeat_interval: float
133
+ ) -> int:
132
134
  """Create, store in the link state, and return `node_id`."""
133
135
 
134
136
  @abc.abstractmethod
135
- def delete_node(self, node_id: int) -> None:
137
+ def delete_node(self, owner_aid: str, node_id: int) -> None:
136
138
  """Remove `node_id` from the link state."""
137
139
 
138
140
  @abc.abstractmethod
@@ -146,16 +148,24 @@ class LinkState(CoreState): # pylint: disable=R0904
146
148
  """
147
149
 
148
150
  @abc.abstractmethod
149
- def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
150
- """Set `public_key` for the specified `node_id`."""
151
+ def get_node_public_key(self, node_id: int) -> bytes:
152
+ """Get `public_key` for the specified `node_id`.
151
153
 
152
- @abc.abstractmethod
153
- def get_node_public_key(self, node_id: int) -> Optional[bytes]:
154
- """Get `public_key` for the specified `node_id`."""
154
+ Parameters
155
+ ----------
156
+ node_id : int
157
+ The identifier of the node whose public key is to be retrieved.
155
158
 
156
- @abc.abstractmethod
157
- def get_node_id(self, node_public_key: bytes) -> Optional[int]:
158
- """Retrieve stored `node_id` filtered by `node_public_keys`."""
159
+ Returns
160
+ -------
161
+ bytes
162
+ The public key associated with the specified `node_id`.
163
+
164
+ Raises
165
+ ------
166
+ ValueError
167
+ If the specified `node_id` does not exist in the link state.
168
+ """
159
169
 
160
170
  @abc.abstractmethod
161
171
  def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments
@@ -21,7 +21,6 @@ import json
21
21
  import re
22
22
  import secrets
23
23
  import sqlite3
24
- import time
25
24
  from collections.abc import Sequence
26
25
  from logging import DEBUG, ERROR, WARNING
27
26
  from typing import Any, Optional, Union, cast
@@ -72,10 +71,16 @@ from .utils import (
72
71
 
73
72
  SQL_CREATE_TABLE_NODE = """
74
73
  CREATE TABLE IF NOT EXISTS node(
75
- node_id INTEGER UNIQUE,
76
- online_until REAL,
77
- heartbeat_interval REAL,
78
- public_key BLOB
74
+ node_id INTEGER UNIQUE,
75
+ owner_aid TEXT,
76
+ status TEXT,
77
+ created_at TEXT,
78
+ last_activated_at TEXT,
79
+ last_deactivated_at TEXT,
80
+ deleted_at TEXT,
81
+ online_until REAL,
82
+ heartbeat_interval REAL,
83
+ public_key BLOB UNIQUE
79
84
  );
80
85
  """
81
86
 
@@ -89,6 +94,10 @@ SQL_CREATE_INDEX_ONLINE_UNTIL = """
89
94
  CREATE INDEX IF NOT EXISTS idx_online_until ON node (online_until);
90
95
  """
91
96
 
97
+ SQL_CREATE_INDEX_OWNER_AID = """
98
+ CREATE INDEX IF NOT EXISTS idx_node_owner_aid ON node(owner_aid);
99
+ """
100
+
92
101
  SQL_CREATE_TABLE_RUN = """
93
102
  CREATE TABLE IF NOT EXISTS run(
94
103
  run_id INTEGER UNIQUE,
@@ -223,6 +232,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
223
232
  cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
224
233
  cur.execute(SQL_CREATE_TABLE_TOKEN_STORE)
225
234
  cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
235
+ cur.execute(SQL_CREATE_INDEX_OWNER_AID)
226
236
  res = cur.execute("SELECT name FROM sqlite_schema;")
227
237
  return res.fetchall()
228
238
 
@@ -451,7 +461,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
451
461
  ret: dict[str, Message] = {}
452
462
 
453
463
  # Verify Message IDs
454
- current = time.time()
464
+ current = now().timestamp()
455
465
  query = f"""
456
466
  SELECT *
457
467
  FROM message_ins
@@ -597,7 +607,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
597
607
 
598
608
  return {row["message_id"] for row in rows}
599
609
 
600
- def create_node(self, heartbeat_interval: float) -> int:
610
+ def create_node(
611
+ self, owner_aid: str, public_key: bytes, heartbeat_interval: float
612
+ ) -> int:
601
613
  """Create, store in the link state, and return `node_id`."""
602
614
  # Sample a random uint64 as node_id
603
615
  uint64_node_id = generate_rand_int_from_bytes(
@@ -607,37 +619,48 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
607
619
  # Convert the uint64 value to sint64 for SQLite
608
620
  sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
609
621
 
610
- query = (
611
- "INSERT INTO node "
612
- "(node_id, online_until, heartbeat_interval, public_key) "
613
- "VALUES (?, ?, ?, ?)"
614
- )
622
+ query = """
623
+ INSERT INTO node
624
+ (node_id, owner_aid, status, created_at, last_activated_at,
625
+ last_deactivated_at, deleted_at, online_until, heartbeat_interval,
626
+ public_key)
627
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
628
+ """
615
629
 
616
- # Mark the node online util time.time() + heartbeat_interval
630
+ # Mark the node online until now().timestamp() + heartbeat_interval
617
631
  try:
618
632
  self.query(
619
633
  query,
620
634
  (
621
- sint64_node_id,
622
- time.time() + heartbeat_interval,
623
- heartbeat_interval,
624
- b"", # Initialize with an empty public key
635
+ sint64_node_id, # node_id
636
+ owner_aid, # owner_aid, unused for now
637
+ "created", # status, unused for now
638
+ now().isoformat(), # created_at, unused for now
639
+ now().isoformat(), # last_activated_at, unused for now
640
+ "", # last_deactivated_at, unused for now
641
+ "", # deleted_at, unused for now
642
+ now().timestamp() + heartbeat_interval, # online_until
643
+ heartbeat_interval, # heartbeat_interval
644
+ public_key, # public_key
625
645
  ),
626
646
  )
627
- except sqlite3.IntegrityError:
647
+ except sqlite3.IntegrityError as e:
648
+ if "UNIQUE constraint failed: node.public_key" in str(e):
649
+ raise ValueError("Public key already in use.") from None
650
+ # Must be node ID conflict, almost impossible unless system is compromised
628
651
  log(ERROR, "Unexpected node registration failure.")
629
652
  return 0
630
653
 
631
654
  # Note: we need to return the uint64 value of the node_id
632
655
  return uint64_node_id
633
656
 
634
- def delete_node(self, node_id: int) -> None:
657
+ def delete_node(self, owner_aid: str, node_id: int) -> None:
635
658
  """Delete a node."""
636
659
  # Convert the uint64 value to sint64 for SQLite
637
660
  sint64_node_id = convert_uint64_to_sint64(node_id)
638
661
 
639
- query = "DELETE FROM node WHERE node_id = ?"
640
- params = (sint64_node_id,)
662
+ query = "DELETE FROM node WHERE node_id = ? AND owner_aid = ?"
663
+ params = (sint64_node_id, owner_aid)
641
664
 
642
665
  if self.conn is None:
643
666
  raise AttributeError("LinkState is not initialized.")
@@ -646,7 +669,9 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
646
669
  with self.conn:
647
670
  rows = self.conn.execute(query, params)
648
671
  if rows.rowcount < 1:
649
- raise ValueError(f"Node {node_id} not found")
672
+ raise ValueError(
673
+ f"Node ID {node_id} not found or unauthorized deletion attempt."
674
+ )
650
675
  except KeyError as exc:
651
676
  log(ERROR, {"query": query, "data": params, "exception": exc})
652
677
 
@@ -668,32 +693,13 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
668
693
 
669
694
  # Get nodes
670
695
  query = "SELECT node_id FROM node WHERE online_until > ?;"
671
- rows = self.query(query, (time.time(),))
696
+ rows = self.query(query, (now().timestamp(),))
672
697
 
673
698
  # Convert sint64 node_ids to uint64
674
699
  result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
675
700
  return result
676
701
 
677
- def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
678
- """Set `public_key` for the specified `node_id`."""
679
- # Convert the uint64 value to sint64 for SQLite
680
- sint64_node_id = convert_uint64_to_sint64(node_id)
681
-
682
- # Check if the node exists in the `node` table
683
- query = "SELECT 1 FROM node WHERE node_id = ?"
684
- if not self.query(query, (sint64_node_id,)):
685
- raise ValueError(f"Node {node_id} not found")
686
-
687
- # Check if the public key is already in use in the `node` table
688
- query = "SELECT 1 FROM node WHERE public_key = ?"
689
- if self.query(query, (public_key,)):
690
- raise ValueError("Public key already in use")
691
-
692
- # Update the `node` table to set the public key for the given node ID
693
- query = "UPDATE node SET public_key = ? WHERE node_id = ?"
694
- self.query(query, (public_key, sint64_node_id))
695
-
696
- def get_node_public_key(self, node_id: int) -> Optional[bytes]:
702
+ def get_node_public_key(self, node_id: int) -> bytes:
697
703
  """Get `public_key` for the specified `node_id`."""
698
704
  # Convert the uint64 value to sint64 for SQLite
699
705
  sint64_node_id = convert_uint64_to_sint64(node_id)
@@ -704,23 +710,10 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
704
710
 
705
711
  # If no result is found, return None
706
712
  if not rows:
707
- raise ValueError(f"Node {node_id} not found")
708
-
709
- # Return the public key if it is not empty, otherwise return None
710
- return rows[0]["public_key"] or None
711
-
712
- def get_node_id(self, node_public_key: bytes) -> Optional[int]:
713
- """Retrieve stored `node_id` filtered by `node_public_keys`."""
714
- query = "SELECT node_id FROM node WHERE public_key = :public_key;"
715
- row = self.query(query, {"public_key": node_public_key})
716
- if len(row) > 0:
717
- node_id: int = row[0]["node_id"]
713
+ raise ValueError(f"Node ID {node_id} not found")
718
714
 
719
- # Convert the sint64 value to uint64 after reading from SQLite
720
- uint64_node_id = convert_sint64_to_uint64(node_id)
721
-
722
- return uint64_node_id
723
- return None
715
+ # Return the public key
716
+ return cast(bytes, rows[0]["public_key"])
724
717
 
725
718
  # pylint: disable=too-many-arguments,too-many-positional-arguments
726
719
  def create_run(
@@ -1010,7 +1003,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1010
1003
  self.query(
1011
1004
  query,
1012
1005
  (
1013
- time.time() + HEARTBEAT_PATIENCE * heartbeat_interval,
1006
+ now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
1014
1007
  heartbeat_interval,
1015
1008
  sint64_node_id,
1016
1009
  ),
@@ -1140,7 +1133,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1140
1133
  message_ins = rows[0]
1141
1134
  created_at = message_ins["created_at"]
1142
1135
  ttl = message_ins["ttl"]
1143
- current_time = time.time()
1136
+ current_time = now().timestamp()
1144
1137
 
1145
1138
  # Check if Message is expired
1146
1139
  if ttl is not None and created_at + ttl <= current_time:
@@ -15,10 +15,9 @@
15
15
  """Validators."""
16
16
 
17
17
 
18
- import time
19
-
20
18
  from flwr.common import Message
21
19
  from flwr.common.constant import SUPERLINK_NODE_ID
20
+ from flwr.common.date import now
22
21
 
23
22
 
24
23
  # pylint: disable-next=too-many-branches
@@ -44,7 +43,7 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
44
43
  validation_errors.append("`metadata.ttl` must be higher than zero")
45
44
 
46
45
  # Verify TTL and created_at time
47
- current_time = time.time()
46
+ current_time = now().timestamp()
48
47
  if metadata.created_at + metadata.ttl <= current_time:
49
48
  validation_errors.append("Message TTL has expired")
50
49
 
@@ -107,3 +107,11 @@ def verify_signature(
107
107
  return True
108
108
  except InvalidSignature:
109
109
  return False
110
+
111
+
112
+ def uses_nist_ec_curve(public_key: ec.EllipticCurvePublicKey) -> bool:
113
+ """Return True if the provided key uses a NIST EC curve."""
114
+ return isinstance(
115
+ public_key.curve,
116
+ (ec.SECP192R1, ec.SECP224R1, ec.SECP256R1, ec.SECP384R1, ec.SECP521R1),
117
+ )
@@ -15,12 +15,41 @@
15
15
  """Account auth plugin for ControlServicer."""
16
16
 
17
17
 
18
+ from flwr.common.constant import AuthnType, AuthzType
19
+
18
20
  from .auth_plugin import ControlAuthnPlugin, ControlAuthzPlugin
19
21
  from .noop_auth_plugin import NoOpControlAuthnPlugin, NoOpControlAuthzPlugin
20
22
 
23
+ try:
24
+ from flwr.ee import get_control_authn_ee_plugins, get_control_authz_ee_plugins
25
+ except ImportError:
26
+
27
+ def get_control_authn_ee_plugins() -> dict[str, type[ControlAuthnPlugin]]:
28
+ """Return all Control API authentication plugins for EE."""
29
+ return {}
30
+
31
+ def get_control_authz_ee_plugins() -> dict[str, type[ControlAuthzPlugin]]:
32
+ """Return all Control API authorization plugins for EE."""
33
+ return {}
34
+
35
+
36
+ def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
37
+ """Return all Control API authentication plugins."""
38
+ ee_dict: dict[str, type[ControlAuthnPlugin]] = get_control_authn_ee_plugins()
39
+ return ee_dict | {AuthnType.NOOP: NoOpControlAuthnPlugin}
40
+
41
+
42
+ def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
43
+ """Return all Control API authorization plugins."""
44
+ ee_dict: dict[str, type[ControlAuthzPlugin]] = get_control_authz_ee_plugins()
45
+ return ee_dict | {AuthzType.NOOP: NoOpControlAuthzPlugin}
46
+
47
+
21
48
  __all__ = [
22
49
  "ControlAuthnPlugin",
23
50
  "ControlAuthzPlugin",
24
51
  "NoOpControlAuthnPlugin",
25
52
  "NoOpControlAuthzPlugin",
53
+ "get_control_authn_plugins",
54
+ "get_control_authz_plugins",
26
55
  ]
@@ -31,7 +31,11 @@ from flwr.supercore.ffs import FfsFactory
31
31
  from flwr.supercore.license_plugin import LicensePlugin
32
32
  from flwr.supercore.object_store import ObjectStoreFactory
33
33
  from flwr.superlink.artifact_provider import ArtifactProvider
34
- from flwr.superlink.auth_plugin import ControlAuthnPlugin, ControlAuthzPlugin
34
+ from flwr.superlink.auth_plugin import (
35
+ ControlAuthnPlugin,
36
+ ControlAuthzPlugin,
37
+ NoOpControlAuthnPlugin,
38
+ )
35
39
 
36
40
  from .control_account_auth_interceptor import ControlAccountAuthInterceptor
37
41
  from .control_event_log_interceptor import ControlEventLogInterceptor
@@ -54,8 +58,8 @@ def run_control_api_grpc(
54
58
  objectstore_factory: ObjectStoreFactory,
55
59
  certificates: Optional[tuple[bytes, bytes, bytes]],
56
60
  is_simulation: bool,
57
- authn_plugin: Optional[ControlAuthnPlugin] = None,
58
- authz_plugin: Optional[ControlAuthzPlugin] = None,
61
+ authn_plugin: ControlAuthnPlugin,
62
+ authz_plugin: ControlAuthzPlugin,
59
63
  event_log_plugin: Optional[EventLogWriterPlugin] = None,
60
64
  artifact_provider: Optional[ArtifactProvider] = None,
61
65
  ) -> grpc.Server:
@@ -72,11 +76,9 @@ def run_control_api_grpc(
72
76
  authn_plugin=authn_plugin,
73
77
  artifact_provider=artifact_provider,
74
78
  )
75
- interceptors: list[grpc.ServerInterceptor] = []
79
+ interceptors = [ControlAccountAuthInterceptor(authn_plugin, authz_plugin)]
76
80
  if license_plugin is not None:
77
81
  interceptors.append(ControlLicenseInterceptor(license_plugin))
78
- if authn_plugin is not None and authz_plugin is not None:
79
- interceptors.append(ControlAccountAuthInterceptor(authn_plugin, authz_plugin))
80
82
  # Event log interceptor must be added after account auth interceptor
81
83
  if event_log_plugin is not None:
82
84
  interceptors.append(ControlEventLogInterceptor(event_log_plugin))
@@ -90,7 +92,7 @@ def run_control_api_grpc(
90
92
  interceptors=interceptors or None,
91
93
  )
92
94
 
93
- if authn_plugin is None:
95
+ if isinstance(authn_plugin, NoOpControlAuthnPlugin):
94
96
  log(INFO, "Flower Deployment Runtime: Starting Control API on %s", address)
95
97
  else:
96
98
  log(