flwr-nightly 1.23.0.dev20251007__py3-none-any.whl → 1.23.0.dev20251008__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (32)
  1. flwr/cli/auth_plugin/__init__.py +7 -3
  2. flwr/cli/log.py +2 -2
  3. flwr/cli/login/login.py +4 -13
  4. flwr/cli/ls.py +2 -2
  5. flwr/cli/pull.py +2 -2
  6. flwr/cli/run/run.py +2 -2
  7. flwr/cli/stop.py +2 -2
  8. flwr/cli/supernode/ls.py +2 -2
  9. flwr/cli/utils.py +28 -44
  10. flwr/client/grpc_rere_client/connection.py +6 -4
  11. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +2 -2
  12. flwr/client/rest_client/connection.py +7 -1
  13. flwr/common/constant.py +10 -0
  14. flwr/proto/fleet_pb2.py +22 -22
  15. flwr/proto/fleet_pb2.pyi +4 -1
  16. flwr/proto/node_pb2.py +1 -1
  17. flwr/server/app.py +32 -31
  18. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +8 -4
  19. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +18 -37
  20. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
  21. flwr/server/superlink/fleet/vce/vce_api.py +7 -1
  22. flwr/server/superlink/linkstate/in_memory_linkstate.py +39 -27
  23. flwr/server/superlink/linkstate/linkstate.py +1 -1
  24. flwr/server/superlink/linkstate/sqlite_linkstate.py +37 -21
  25. flwr/server/utils/validator.py +2 -3
  26. flwr/superlink/auth_plugin/__init__.py +29 -0
  27. flwr/superlink/servicer/control/control_grpc.py +9 -7
  28. flwr/superlink/servicer/control/control_servicer.py +34 -46
  29. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/METADATA +1 -1
  30. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/RECORD +32 -32
  31. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/WHEEL +0 -0
  32. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/entry_points.txt +0 -0
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import datetime
19
- from typing import Any, Callable, Optional, cast
19
+ from typing import Any, Callable, cast
20
20
 
21
21
  import grpc
22
22
  from google.protobuf.message import Message as GrpcMessage
@@ -29,10 +29,7 @@ from flwr.common.constant import (
29
29
  TIMESTAMP_HEADER,
30
30
  TIMESTAMP_TOLERANCE,
31
31
  )
32
- from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
- CreateNodeRequest,
34
- CreateNodeResponse,
35
- )
32
+ from flwr.proto.fleet_pb2 import CreateNodeRequest # pylint: disable=E0611
36
33
  from flwr.server.superlink.linkstate import LinkStateFactory
37
34
  from flwr.supercore.primitives.asymmetric import bytes_to_public_key, verify_signature
38
35
 
@@ -50,7 +47,7 @@ def _unary_unary_rpc_terminator(
50
47
  return grpc.unary_unary_rpc_method_handler(terminate)
51
48
 
52
49
 
53
- class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
50
+ class NodeAuthServerInterceptor(grpc.ServerInterceptor): # type: ignore
54
51
  """Server interceptor for node authentication.
55
52
 
56
53
  Parameters
@@ -110,50 +107,34 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
110
107
  if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
111
108
  return _unary_unary_rpc_terminator("Invalid timestamp")
112
109
 
113
- # Continue the RPC call
114
- expected_node_id = state.get_node_id(node_pk_bytes)
115
- if not handler_call_details.method.endswith("CreateNode"):
116
- # All calls, except for `CreateNode`, must provide a public key that is
117
- # already mapped to a `node_id` (in `LinkState`)
118
- if expected_node_id is None:
119
- return _unary_unary_rpc_terminator("Invalid node ID")
120
- # One of the method handlers in
110
+ # Continue the RPC call: One of the method handlers in
121
111
  # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
122
112
  method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
123
- return self._wrap_method_handler(
124
- method_handler, expected_node_id, node_pk_bytes
125
- )
113
+ return self._wrap_method_handler(method_handler, node_pk_bytes)
126
114
 
127
115
  def _wrap_method_handler(
128
116
  self,
129
117
  method_handler: grpc.RpcMethodHandler,
130
- expected_node_id: Optional[int],
131
- node_public_key: bytes,
118
+ expected_public_key: bytes,
132
119
  ) -> grpc.RpcMethodHandler:
133
120
  def _generic_method_handler(
134
121
  request: GrpcMessage,
135
122
  context: grpc.ServicerContext,
136
123
  ) -> GrpcMessage:
137
- # Verify the node ID
138
- if not isinstance(request, CreateNodeRequest):
139
- try:
140
- if request.node.node_id != expected_node_id: # type: ignore
141
- raise ValueError
142
- except (AttributeError, ValueError):
143
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
124
+ # Retrieve the public key
125
+ if isinstance(request, CreateNodeRequest):
126
+ actual_public_key = request.public_key
127
+ else:
128
+ # Note: This function runs in a different thread
129
+ # than the `intercept_service` function.
130
+ actual_public_key = self.state_factory.state().get_node_public_key(
131
+ request.node.node_id # type: ignore
132
+ )
133
+ # Verify the public key
134
+ if actual_public_key != expected_public_key:
135
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
144
136
 
145
137
  response: GrpcMessage = method_handler.unary_unary(request, context)
146
-
147
- # Set the public key after a successful CreateNode request
148
- if isinstance(response, CreateNodeResponse):
149
- state = self.state_factory.state()
150
- try:
151
- state.set_node_public_key(response.node.node_id, node_public_key)
152
- except ValueError as e:
153
- # Remove newly created node if setting the public key fails
154
- state.delete_node(response.node.node_id)
155
- context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
156
-
157
138
  return response
158
139
 
159
140
  return grpc.unary_unary_rpc_method_handler(
@@ -70,7 +70,7 @@ def create_node(
70
70
  ) -> CreateNodeResponse:
71
71
  """."""
72
72
  # Create node
73
- node_id = state.create_node(heartbeat_interval=request.heartbeat_interval)
73
+ node_id = state.create_node(request.public_key, request.heartbeat_interval)
74
74
  return CreateNodeResponse(node=Node(node_id=node_id))
75
75
 
76
76
 
@@ -16,6 +16,7 @@
16
16
 
17
17
 
18
18
  import json
19
+ import secrets
19
20
  import threading
20
21
  import time
21
22
  import traceback
@@ -53,7 +54,12 @@ def _register_nodes(
53
54
  nodes_mapping: NodeToPartitionMapping = {}
54
55
  state = state_factory.state()
55
56
  for i in range(num_nodes):
56
- node_id = state.create_node(heartbeat_interval=HEARTBEAT_MAX_INTERVAL)
57
+ node_id = state.create_node(
58
+ # No node authentication in simulation;
59
+ # use random bytes instead
60
+ secrets.token_bytes(32),
61
+ heartbeat_interval=HEARTBEAT_MAX_INTERVAL,
62
+ )
57
63
  nodes_mapping[node_id] = i
58
64
  log(DEBUG, "Registered %i nodes", len(nodes_mapping))
59
65
  return nodes_mapping
@@ -17,7 +17,6 @@
17
17
 
18
18
  import secrets
19
19
  import threading
20
- import time
21
20
  from bisect import bisect_right
22
21
  from collections import defaultdict
23
22
  from dataclasses import dataclass, field
@@ -39,6 +38,7 @@ from flwr.common.constant import (
39
38
  )
40
39
  from flwr.common.record import ConfigRecord
41
40
  from flwr.common.typing import Run, RunStatus, UserConfig
41
+ from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
42
42
  from flwr.server.superlink.linkstate.linkstate import LinkState
43
43
  from flwr.server.utils import validate_message
44
44
 
@@ -70,7 +70,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
70
70
  def __init__(self) -> None:
71
71
 
72
72
  # Map node_id to (online_until, heartbeat_interval)
73
- self.node_ids: dict[int, tuple[float, float]] = {}
73
+ self.nodes: dict[int, NodeInfo] = {}
74
74
  self.public_key_to_node_id: dict[bytes, int] = {}
75
75
  self.node_id_to_public_key: dict[int, bytes] = {}
76
76
 
@@ -114,7 +114,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
114
114
  )
115
115
  return None
116
116
  # Validate destination node ID
117
- if message.metadata.dst_node_id not in self.node_ids:
117
+ if message.metadata.dst_node_id not in self.nodes:
118
118
  log(
119
119
  ERROR,
120
120
  "Invalid destination node ID for Message: %s",
@@ -136,7 +136,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
136
136
 
137
137
  # Find Message for node_id that were not delivered yet
138
138
  message_ins_list: list[Message] = []
139
- current_time = time.time()
139
+ current_time = now().timestamp()
140
140
  with self.lock:
141
141
  for _, msg_ins in self.message_ins_store.items():
142
142
  if (
@@ -190,7 +190,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
190
190
  return None
191
191
 
192
192
  ins_metadata = msg_ins.metadata
193
- if ins_metadata.created_at + ins_metadata.ttl <= time.time():
193
+ if ins_metadata.created_at + ins_metadata.ttl <= now().timestamp():
194
194
  log(
195
195
  ERROR,
196
196
  "Failed to store Message: the message it is replying to "
@@ -238,7 +238,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
238
238
  ret: dict[str, Message] = {}
239
239
 
240
240
  with self.lock:
241
- current = time.time()
241
+ current = now().timestamp()
242
242
 
243
243
  # Verify Message IDs
244
244
  ret = verify_message_ids(
@@ -256,9 +256,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
256
256
  inquired_in_message_ids=message_ids,
257
257
  found_in_message_dict=self.message_ins_store,
258
258
  node_id_to_online_until={
259
- node_id: self.node_ids[node_id][0]
259
+ node_id: self.nodes[node_id].online_until
260
260
  for node_id in dst_node_ids
261
- if node_id in self.node_ids
261
+ if node_id in self.nodes
262
262
  },
263
263
  current_time=current,
264
264
  )
@@ -330,7 +330,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
330
330
  """
331
331
  return len(self.message_res_store)
332
332
 
333
- def create_node(self, heartbeat_interval: float) -> int:
333
+ def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
334
334
  """Create, store in the link state, and return `node_id`."""
335
335
  # Sample a random int64 as node_id
336
336
  node_id = generate_rand_int_from_bytes(
@@ -338,28 +338,40 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
338
338
  )
339
339
 
340
340
  with self.lock:
341
- if node_id in self.node_ids:
341
+ if node_id in self.nodes:
342
342
  log(ERROR, "Unexpected node registration failure.")
343
343
  return 0
344
+ if public_key in self.public_key_to_node_id:
345
+ raise ValueError("Public key already in use")
344
346
 
345
- # Mark the node online until time.time() + heartbeat_interval
346
- self.node_ids[node_id] = (
347
- time.time() + heartbeat_interval,
348
- heartbeat_interval,
347
+ # Mark the node online until now().timestamp() + heartbeat_interval
348
+ current = now()
349
+ self.nodes[node_id] = NodeInfo(
350
+ node_id=node_id,
351
+ owner_aid="", # Unused for now
352
+ status="created", # Unused for now
353
+ created_at=current.isoformat(), # Unused for now
354
+ last_activated_at=current.isoformat(), # Unused for now
355
+ last_deactivated_at="", # Unused for now
356
+ deleted_at="", # Unused for now
357
+ online_until=current.timestamp() + heartbeat_interval,
358
+ heartbeat_interval=heartbeat_interval,
349
359
  )
360
+ self.public_key_to_node_id[public_key] = node_id
361
+ self.node_id_to_public_key[node_id] = public_key
350
362
  return node_id
351
363
 
352
364
  def delete_node(self, node_id: int) -> None:
353
365
  """Delete a node."""
354
366
  with self.lock:
355
- if node_id not in self.node_ids:
367
+ if node_id not in self.nodes:
356
368
  raise ValueError(f"Node {node_id} not found")
357
369
 
358
370
  # Remove node ID <> public key mappings
359
371
  if pk := self.node_id_to_public_key.pop(node_id, None):
360
372
  del self.public_key_to_node_id[pk]
361
373
 
362
- del self.node_ids[node_id]
374
+ del self.nodes[node_id]
363
375
 
364
376
  def get_nodes(self, run_id: int) -> set[int]:
365
377
  """Return all available nodes.
@@ -372,17 +384,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
372
384
  with self.lock:
373
385
  if run_id not in self.run_ids:
374
386
  return set()
375
- current_time = time.time()
387
+ current_time = now().timestamp()
376
388
  return {
377
- node_id
378
- for node_id, (online_until, _) in self.node_ids.items()
379
- if online_until > current_time
389
+ info.node_id
390
+ for info in self.nodes.values()
391
+ if info.online_until > current_time
380
392
  }
381
393
 
382
394
  def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
383
395
  """Set `public_key` for the specified `node_id`."""
384
396
  with self.lock:
385
- if node_id not in self.node_ids:
397
+ if node_id not in self.nodes:
386
398
  raise ValueError(f"Node {node_id} not found")
387
399
 
388
400
  if public_key in self.public_key_to_node_id:
@@ -394,7 +406,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
394
406
  def get_node_public_key(self, node_id: int) -> Optional[bytes]:
395
407
  """Get `public_key` for the specified `node_id`."""
396
408
  with self.lock:
397
- if node_id not in self.node_ids:
409
+ if node_id not in self.nodes:
398
410
  raise ValueError(f"Node {node_id} not found")
399
411
 
400
412
  return self.node_id_to_public_key.get(node_id)
@@ -608,13 +620,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
608
620
  the node is marked as offline.
609
621
  """
610
622
  with self.lock:
611
- if node_id in self.node_ids:
612
- self.node_ids[node_id] = (
613
- time.time() + HEARTBEAT_PATIENCE * heartbeat_interval,
614
- heartbeat_interval,
623
+ if info := self.nodes.get(node_id):
624
+ info.online_until = (
625
+ now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval
615
626
  )
627
+ info.heartbeat_interval = heartbeat_interval
616
628
  return True
617
- return False
629
+ return False
618
630
 
619
631
  def acknowledge_app_heartbeat(self, run_id: int, heartbeat_interval: float) -> bool:
620
632
  """Acknowledge a heartbeat received from a ServerApp for a given run.
@@ -128,7 +128,7 @@ class LinkState(CoreState): # pylint: disable=R0904
128
128
  """Get all instruction Message IDs for the given run_id."""
129
129
 
130
130
  @abc.abstractmethod
131
- def create_node(self, heartbeat_interval: float) -> int:
131
+ def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
132
132
  """Create, store in the link state, and return `node_id`."""
133
133
 
134
134
  @abc.abstractmethod
@@ -21,7 +21,6 @@ import json
21
21
  import re
22
22
  import secrets
23
23
  import sqlite3
24
- import time
25
24
  from collections.abc import Sequence
26
25
  from logging import DEBUG, ERROR, WARNING
27
26
  from typing import Any, Optional, Union, cast
@@ -72,10 +71,16 @@ from .utils import (
72
71
 
73
72
  SQL_CREATE_TABLE_NODE = """
74
73
  CREATE TABLE IF NOT EXISTS node(
75
- node_id INTEGER UNIQUE,
76
- online_until REAL,
77
- heartbeat_interval REAL,
78
- public_key BLOB
74
+ node_id INTEGER UNIQUE,
75
+ owner_aid TEXT,
76
+ status TEXT,
77
+ created_at TEXT,
78
+ last_activated_at TEXT,
79
+ last_deactivated_at TEXT,
80
+ deleted_at TEXT,
81
+ online_until REAL,
82
+ heartbeat_interval REAL,
83
+ public_key BLOB UNIQUE
79
84
  );
80
85
  """
81
86
 
@@ -451,7 +456,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
451
456
  ret: dict[str, Message] = {}
452
457
 
453
458
  # Verify Message IDs
454
- current = time.time()
459
+ current = now().timestamp()
455
460
  query = f"""
456
461
  SELECT *
457
462
  FROM message_ins
@@ -597,7 +602,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
597
602
 
598
603
  return {row["message_id"] for row in rows}
599
604
 
600
- def create_node(self, heartbeat_interval: float) -> int:
605
+ def create_node(self, public_key: bytes, heartbeat_interval: float) -> int:
601
606
  """Create, store in the link state, and return `node_id`."""
602
607
  # Sample a random uint64 as node_id
603
608
  uint64_node_id = generate_rand_int_from_bytes(
@@ -607,24 +612,35 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
607
612
  # Convert the uint64 value to sint64 for SQLite
608
613
  sint64_node_id = convert_uint64_to_sint64(uint64_node_id)
609
614
 
610
- query = (
611
- "INSERT INTO node "
612
- "(node_id, online_until, heartbeat_interval, public_key) "
613
- "VALUES (?, ?, ?, ?)"
614
- )
615
+ query = """
616
+ INSERT INTO node
617
+ (node_id, owner_aid, status, created_at, last_activated_at,
618
+ last_deactivated_at, deleted_at, online_until, heartbeat_interval,
619
+ public_key)
620
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
621
+ """
615
622
 
616
- # Mark the node online util time.time() + heartbeat_interval
623
+ # Mark the node online until now().timestamp() + heartbeat_interval
617
624
  try:
618
625
  self.query(
619
626
  query,
620
627
  (
621
- sint64_node_id,
622
- time.time() + heartbeat_interval,
623
- heartbeat_interval,
624
- b"", # Initialize with an empty public key
628
+ sint64_node_id, # node_id
629
+ "", # owner_aid, unused for now
630
+ "created", # status, unused for now
631
+ now().isoformat(), # created_at, unused for now
632
+ now().isoformat(), # last_activated_at, unused for now
633
+ "", # last_deactivated_at, unused for now
634
+ "", # deleted_at, unused for now
635
+ now().timestamp() + heartbeat_interval, # online_until
636
+ heartbeat_interval, # heartbeat_interval
637
+ public_key, # public_key
625
638
  ),
626
639
  )
627
- except sqlite3.IntegrityError:
640
+ except sqlite3.IntegrityError as e:
641
+ if "UNIQUE constraint failed: node.public_key" in str(e):
642
+ raise ValueError("Public key already in use.") from None
643
+ # Must be node ID conflict, almost impossible unless system is compromised
628
644
  log(ERROR, "Unexpected node registration failure.")
629
645
  return 0
630
646
 
@@ -668,7 +684,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
668
684
 
669
685
  # Get nodes
670
686
  query = "SELECT node_id FROM node WHERE online_until > ?;"
671
- rows = self.query(query, (time.time(),))
687
+ rows = self.query(query, (now().timestamp(),))
672
688
 
673
689
  # Convert sint64 node_ids to uint64
674
690
  result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows}
@@ -1010,7 +1026,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1010
1026
  self.query(
1011
1027
  query,
1012
1028
  (
1013
- time.time() + HEARTBEAT_PATIENCE * heartbeat_interval,
1029
+ now().timestamp() + HEARTBEAT_PATIENCE * heartbeat_interval,
1014
1030
  heartbeat_interval,
1015
1031
  sint64_node_id,
1016
1032
  ),
@@ -1140,7 +1156,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1140
1156
  message_ins = rows[0]
1141
1157
  created_at = message_ins["created_at"]
1142
1158
  ttl = message_ins["ttl"]
1143
- current_time = time.time()
1159
+ current_time = now().timestamp()
1144
1160
 
1145
1161
  # Check if Message is expired
1146
1162
  if ttl is not None and created_at + ttl <= current_time:
@@ -15,10 +15,9 @@
15
15
  """Validators."""
16
16
 
17
17
 
18
- import time
19
-
20
18
  from flwr.common import Message
21
19
  from flwr.common.constant import SUPERLINK_NODE_ID
20
+ from flwr.common.date import now
22
21
 
23
22
 
24
23
  # pylint: disable-next=too-many-branches
@@ -44,7 +43,7 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
44
43
  validation_errors.append("`metadata.ttl` must be higher than zero")
45
44
 
46
45
  # Verify TTL and created_at time
47
- current_time = time.time()
46
+ current_time = now().timestamp()
48
47
  if metadata.created_at + metadata.ttl <= current_time:
49
48
  validation_errors.append("Message TTL has expired")
50
49
 
@@ -15,12 +15,41 @@
15
15
  """Account auth plugin for ControlServicer."""
16
16
 
17
17
 
18
+ from flwr.common.constant import AuthnType, AuthzType
19
+
18
20
  from .auth_plugin import ControlAuthnPlugin, ControlAuthzPlugin
19
21
  from .noop_auth_plugin import NoOpControlAuthnPlugin, NoOpControlAuthzPlugin
20
22
 
23
+ try:
24
+ from flwr.ee import get_control_authn_ee_plugins, get_control_authz_ee_plugins
25
+ except ImportError:
26
+
27
+ def get_control_authn_ee_plugins() -> dict[str, type[ControlAuthnPlugin]]:
28
+ """Return all Control API authentication plugins for EE."""
29
+ return {}
30
+
31
+ def get_control_authz_ee_plugins() -> dict[str, type[ControlAuthzPlugin]]:
32
+ """Return all Control API authorization plugins for EE."""
33
+ return {}
34
+
35
+
36
+ def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
37
+ """Return all Control API authentication plugins."""
38
+ ee_dict: dict[str, type[ControlAuthnPlugin]] = get_control_authn_ee_plugins()
39
+ return ee_dict | {AuthnType.NOOP: NoOpControlAuthnPlugin}
40
+
41
+
42
+ def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
43
+ """Return all Control API authorization plugins."""
44
+ ee_dict: dict[str, type[ControlAuthzPlugin]] = get_control_authz_ee_plugins()
45
+ return ee_dict | {AuthzType.NOOP: NoOpControlAuthzPlugin}
46
+
47
+
21
48
  __all__ = [
22
49
  "ControlAuthnPlugin",
23
50
  "ControlAuthzPlugin",
24
51
  "NoOpControlAuthnPlugin",
25
52
  "NoOpControlAuthzPlugin",
53
+ "get_control_authn_plugins",
54
+ "get_control_authz_plugins",
26
55
  ]
@@ -31,7 +31,11 @@ from flwr.supercore.ffs import FfsFactory
31
31
  from flwr.supercore.license_plugin import LicensePlugin
32
32
  from flwr.supercore.object_store import ObjectStoreFactory
33
33
  from flwr.superlink.artifact_provider import ArtifactProvider
34
- from flwr.superlink.auth_plugin import ControlAuthnPlugin, ControlAuthzPlugin
34
+ from flwr.superlink.auth_plugin import (
35
+ ControlAuthnPlugin,
36
+ ControlAuthzPlugin,
37
+ NoOpControlAuthnPlugin,
38
+ )
35
39
 
36
40
  from .control_account_auth_interceptor import ControlAccountAuthInterceptor
37
41
  from .control_event_log_interceptor import ControlEventLogInterceptor
@@ -54,8 +58,8 @@ def run_control_api_grpc(
54
58
  objectstore_factory: ObjectStoreFactory,
55
59
  certificates: Optional[tuple[bytes, bytes, bytes]],
56
60
  is_simulation: bool,
57
- authn_plugin: Optional[ControlAuthnPlugin] = None,
58
- authz_plugin: Optional[ControlAuthzPlugin] = None,
61
+ authn_plugin: ControlAuthnPlugin,
62
+ authz_plugin: ControlAuthzPlugin,
59
63
  event_log_plugin: Optional[EventLogWriterPlugin] = None,
60
64
  artifact_provider: Optional[ArtifactProvider] = None,
61
65
  ) -> grpc.Server:
@@ -72,11 +76,9 @@ def run_control_api_grpc(
72
76
  authn_plugin=authn_plugin,
73
77
  artifact_provider=artifact_provider,
74
78
  )
75
- interceptors: list[grpc.ServerInterceptor] = []
79
+ interceptors = [ControlAccountAuthInterceptor(authn_plugin, authz_plugin)]
76
80
  if license_plugin is not None:
77
81
  interceptors.append(ControlLicenseInterceptor(license_plugin))
78
- if authn_plugin is not None and authz_plugin is not None:
79
- interceptors.append(ControlAccountAuthInterceptor(authn_plugin, authz_plugin))
80
82
  # Event log interceptor must be added after account auth interceptor
81
83
  if event_log_plugin is not None:
82
84
  interceptors.append(ControlEventLogInterceptor(event_log_plugin))
@@ -90,7 +92,7 @@ def run_control_api_grpc(
90
92
  interceptors=interceptors or None,
91
93
  )
92
94
 
93
- if authn_plugin is None:
95
+ if isinstance(authn_plugin, NoOpControlAuthnPlugin):
94
96
  log(INFO, "Flower Deployment Runtime: Starting Control API on %s", address)
95
97
  else:
96
98
  log(