flwr-nightly 1.23.0.dev20251016__py3-none-any.whl → 1.23.0.dev20251020__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. See the package registry's notice for this release for more details.

flwr/proto/node_pb2.py CHANGED
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
14
14
 
15
15
 
16
16
 
17
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\xc6\x02\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x11\n\towner_aid\x18\x02 \x01(\t\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\ncreated_at\x18\x04 \x01(\t\x12\x1e\n\x11last_activated_at\x18\x05 \x01(\tH\x00\x88\x01\x01\x12 \n\x13last_deactivated_at\x18\x06 \x01(\tH\x01\x88\x01\x01\x12\x17\n\ndeleted_at\x18\x07 \x01(\tH\x02\x88\x01\x01\x12\x19\n\x0conline_until\x18\x08 \x01(\x01H\x03\x88\x01\x01\x12\x1a\n\x12heartbeat_interval\x18\t \x01(\x01\x12\x12\n\npublic_key\x18\n \x01(\x0c\x42\x14\n\x12_last_activated_atB\x16\n\x14_last_deactivated_atB\r\n\x0b_deleted_atB\x0f\n\r_online_untilb\x06proto3')
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\xd3\x02\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x11\n\towner_aid\x18\x02 \x01(\t\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x15\n\rregistered_at\x18\x04 \x01(\t\x12\x1e\n\x11last_activated_at\x18\x05 \x01(\tH\x00\x88\x01\x01\x12 \n\x13last_deactivated_at\x18\x06 \x01(\tH\x01\x88\x01\x01\x12\x1c\n\x0funregistered_at\x18\x07 \x01(\tH\x02\x88\x01\x01\x12\x19\n\x0conline_until\x18\x08 \x01(\x01H\x03\x88\x01\x01\x12\x1a\n\x12heartbeat_interval\x18\t \x01(\x01\x12\x12\n\npublic_key\x18\n \x01(\x0c\x42\x14\n\x12_last_activated_atB\x16\n\x14_last_deactivated_atB\x12\n\x10_unregistered_atB\x0f\n\r_online_untilb\x06proto3')
18
18
 
19
19
  _globals = globals()
20
20
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -24,5 +24,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
24
24
  _globals['_NODE']._serialized_start=37
25
25
  _globals['_NODE']._serialized_end=60
26
26
  _globals['_NODEINFO']._serialized_start=63
27
- _globals['_NODEINFO']._serialized_end=389
27
+ _globals['_NODEINFO']._serialized_end=402
28
28
  # @@protoc_insertion_point(module_scope)
flwr/proto/node_pb2.pyi CHANGED
@@ -26,20 +26,20 @@ class NodeInfo(google.protobuf.message.Message):
26
26
  NODE_ID_FIELD_NUMBER: builtins.int
27
27
  OWNER_AID_FIELD_NUMBER: builtins.int
28
28
  STATUS_FIELD_NUMBER: builtins.int
29
- CREATED_AT_FIELD_NUMBER: builtins.int
29
+ REGISTERED_AT_FIELD_NUMBER: builtins.int
30
30
  LAST_ACTIVATED_AT_FIELD_NUMBER: builtins.int
31
31
  LAST_DEACTIVATED_AT_FIELD_NUMBER: builtins.int
32
- DELETED_AT_FIELD_NUMBER: builtins.int
32
+ UNREGISTERED_AT_FIELD_NUMBER: builtins.int
33
33
  ONLINE_UNTIL_FIELD_NUMBER: builtins.int
34
34
  HEARTBEAT_INTERVAL_FIELD_NUMBER: builtins.int
35
35
  PUBLIC_KEY_FIELD_NUMBER: builtins.int
36
36
  node_id: builtins.int
37
37
  owner_aid: typing.Text
38
38
  status: typing.Text
39
- created_at: typing.Text
39
+ registered_at: typing.Text
40
40
  last_activated_at: typing.Text
41
41
  last_deactivated_at: typing.Text
42
- deleted_at: typing.Text
42
+ unregistered_at: typing.Text
43
43
  online_until: builtins.float
44
44
  heartbeat_interval: builtins.float
45
45
  public_key: builtins.bytes
@@ -48,22 +48,22 @@ class NodeInfo(google.protobuf.message.Message):
48
48
  node_id: builtins.int = ...,
49
49
  owner_aid: typing.Text = ...,
50
50
  status: typing.Text = ...,
51
- created_at: typing.Text = ...,
51
+ registered_at: typing.Text = ...,
52
52
  last_activated_at: typing.Optional[typing.Text] = ...,
53
53
  last_deactivated_at: typing.Optional[typing.Text] = ...,
54
- deleted_at: typing.Optional[typing.Text] = ...,
54
+ unregistered_at: typing.Optional[typing.Text] = ...,
55
55
  online_until: typing.Optional[builtins.float] = ...,
56
56
  heartbeat_interval: builtins.float = ...,
57
57
  public_key: builtins.bytes = ...,
58
58
  ) -> None: ...
59
- def HasField(self, field_name: typing_extensions.Literal["_deleted_at",b"_deleted_at","_last_activated_at",b"_last_activated_at","_last_deactivated_at",b"_last_deactivated_at","_online_until",b"_online_until","deleted_at",b"deleted_at","last_activated_at",b"last_activated_at","last_deactivated_at",b"last_deactivated_at","online_until",b"online_until"]) -> builtins.bool: ...
60
- def ClearField(self, field_name: typing_extensions.Literal["_deleted_at",b"_deleted_at","_last_activated_at",b"_last_activated_at","_last_deactivated_at",b"_last_deactivated_at","_online_until",b"_online_until","created_at",b"created_at","deleted_at",b"deleted_at","heartbeat_interval",b"heartbeat_interval","last_activated_at",b"last_activated_at","last_deactivated_at",b"last_deactivated_at","node_id",b"node_id","online_until",b"online_until","owner_aid",b"owner_aid","public_key",b"public_key","status",b"status"]) -> None: ...
61
- @typing.overload
62
- def WhichOneof(self, oneof_group: typing_extensions.Literal["_deleted_at",b"_deleted_at"]) -> typing.Optional[typing_extensions.Literal["deleted_at"]]: ...
59
+ def HasField(self, field_name: typing_extensions.Literal["_last_activated_at",b"_last_activated_at","_last_deactivated_at",b"_last_deactivated_at","_online_until",b"_online_until","_unregistered_at",b"_unregistered_at","last_activated_at",b"last_activated_at","last_deactivated_at",b"last_deactivated_at","online_until",b"online_until","unregistered_at",b"unregistered_at"]) -> builtins.bool: ...
60
+ def ClearField(self, field_name: typing_extensions.Literal["_last_activated_at",b"_last_activated_at","_last_deactivated_at",b"_last_deactivated_at","_online_until",b"_online_until","_unregistered_at",b"_unregistered_at","heartbeat_interval",b"heartbeat_interval","last_activated_at",b"last_activated_at","last_deactivated_at",b"last_deactivated_at","node_id",b"node_id","online_until",b"online_until","owner_aid",b"owner_aid","public_key",b"public_key","registered_at",b"registered_at","status",b"status","unregistered_at",b"unregistered_at"]) -> None: ...
63
61
  @typing.overload
64
62
  def WhichOneof(self, oneof_group: typing_extensions.Literal["_last_activated_at",b"_last_activated_at"]) -> typing.Optional[typing_extensions.Literal["last_activated_at"]]: ...
65
63
  @typing.overload
66
64
  def WhichOneof(self, oneof_group: typing_extensions.Literal["_last_deactivated_at",b"_last_deactivated_at"]) -> typing.Optional[typing_extensions.Literal["last_deactivated_at"]]: ...
67
65
  @typing.overload
68
66
  def WhichOneof(self, oneof_group: typing_extensions.Literal["_online_until",b"_online_until"]) -> typing.Optional[typing_extensions.Literal["online_until"]]: ...
67
+ @typing.overload
68
+ def WhichOneof(self, oneof_group: typing_extensions.Literal["_unregistered_at",b"_unregistered_at"]) -> typing.Optional[typing_extensions.Literal["unregistered_at"]]: ...
69
69
  global___NodeInfo = NodeInfo
flwr/server/app.py CHANGED
@@ -71,8 +71,8 @@ from flwr.superlink.artifact_provider import ArtifactProvider
71
71
  from flwr.superlink.auth_plugin import (
72
72
  ControlAuthnPlugin,
73
73
  ControlAuthzPlugin,
74
- get_control_authn_plugins,
75
- get_control_authz_plugins,
74
+ NoOpControlAuthnPlugin,
75
+ NoOpControlAuthzPlugin,
76
76
  )
77
77
  from flwr.superlink.servicer.control import run_control_api_grpc
78
78
 
@@ -93,6 +93,8 @@ P = TypeVar("P", ControlAuthnPlugin, ControlAuthzPlugin)
93
93
  try:
94
94
  from flwr.ee import (
95
95
  add_ee_args_superlink,
96
+ get_control_authn_ee_plugins,
97
+ get_control_authz_ee_plugins,
96
98
  get_control_event_log_writer_plugins,
97
99
  get_ee_artifact_provider,
98
100
  get_fleet_event_log_writer_plugins,
@@ -119,6 +121,26 @@ except ImportError:
119
121
  "No event log writer plugins are currently supported."
120
122
  )
121
123
 
124
+ def get_control_authn_ee_plugins() -> dict[str, type[ControlAuthnPlugin]]:
125
+ """Return all Control API authentication plugins for EE."""
126
+ return {}
127
+
128
+ def get_control_authz_ee_plugins() -> dict[str, type[ControlAuthzPlugin]]:
129
+ """Return all Control API authorization plugins for EE."""
130
+ return {}
131
+
132
+
133
+ def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
134
+ """Return all Control API authentication plugins."""
135
+ ee_dict: dict[str, type[ControlAuthnPlugin]] = get_control_authn_ee_plugins()
136
+ return ee_dict | {AuthnType.NOOP: NoOpControlAuthnPlugin}
137
+
138
+
139
+ def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
140
+ """Return all Control API authorization plugins."""
141
+ ee_dict: dict[str, type[ControlAuthzPlugin]] = get_control_authz_ee_plugins()
142
+ return ee_dict | {AuthzType.NOOP: NoOpControlAuthzPlugin}
143
+
122
144
 
123
145
  # pylint: disable=too-many-branches, too-many-locals, too-many-statements
124
146
  def run_superlink() -> None:
@@ -213,6 +235,17 @@ def run_superlink() -> None:
213
235
 
214
236
  # If supernode authentication is disabled, warn users
215
237
  enable_supernode_auth: bool = args.enable_supernode_auth
238
+ if enable_supernode_auth and args.insecure:
239
+ url_v = f"https://flower.ai/docs/framework/v{package_version}/en/"
240
+ page = "how-to-authenticate-supernodes.html"
241
+ flwr_exit(
242
+ ExitCode.SUPERLINK_INVALID_ARGS,
243
+ "The `--enable-supernode-auth` flag requires encrypted TLS communications. "
244
+ "Please provide TLS certificates using the `--ssl-certfile`, "
245
+ "`--ssl-keyfile` and `--ssl-ca-certfile` arguments to your SuperLink. "
246
+ "Please refer to the Flower documentation for more information: "
247
+ f"{url_v}{page}",
248
+ )
216
249
  if not enable_supernode_auth:
217
250
  log(
218
251
  WARN,
@@ -15,6 +15,7 @@
15
15
  """Fleet API gRPC request-response servicer."""
16
16
 
17
17
 
18
+ import threading
18
19
  from logging import DEBUG, ERROR, INFO
19
20
 
20
21
  import grpc
@@ -53,6 +54,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
53
54
  from flwr.server.superlink.fleet.message_handler import message_handler
54
55
  from flwr.server.superlink.linkstate import LinkStateFactory
55
56
  from flwr.server.superlink.utils import abort_grpc_context
57
+ from flwr.supercore.constant import NodeStatus
56
58
  from flwr.supercore.ffs import FfsFactory
57
59
  from flwr.supercore.object_store import ObjectStoreFactory
58
60
 
@@ -71,6 +73,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
71
73
  self.ffs_factory = ffs_factory
72
74
  self.objectstore_factory = objectstore_factory
73
75
  self.enable_supernode_auth = enable_supernode_auth
76
+ self.lock = threading.Lock()
74
77
 
75
78
  def CreateNode(
76
79
  self, request: CreateNodeRequest, context: grpc.ServicerContext
@@ -88,8 +91,31 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
88
91
 
89
92
  # Check if public key is already in use
90
93
  if node_id := state.get_node_id_by_public_key(request.public_key):
91
- # Prepare response with existing node_id
92
- response = CreateNodeResponse(node=Node(node_id=node_id))
94
+
95
+ # Ensure only one request that requires checking the node state
96
+ # is processed at a time. This avoids race conditions when two
97
+ # SuperNodes try to connect at the same time with the same
98
+ # public key.
99
+ with self.lock:
100
+ node_info = state.get_node_info(node_ids=[node_id])[0]
101
+ if node_info.status == NodeStatus.ONLINE:
102
+ # Node is already active
103
+ log(
104
+ ERROR,
105
+ "Public key already in use (node_id=%s)",
106
+ node_id,
107
+ )
108
+ raise ValueError(
109
+ "Public key already in use by an active SuperNode"
110
+ )
111
+
112
+ # Prepare response with existing node_id
113
+ response = CreateNodeResponse(node=Node(node_id=node_id))
114
+ # Awknowledge heartbeat to mark node as online
115
+ state.acknowledge_node_heartbeat(
116
+ node_id=node_id,
117
+ heartbeat_interval=request.heartbeat_interval,
118
+ )
93
119
  else:
94
120
  if self.enable_supernode_auth:
95
121
  # When SuperNode authentication is enabled,
@@ -353,11 +353,11 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
353
353
  self.nodes[node_id] = NodeInfo(
354
354
  node_id=node_id,
355
355
  owner_aid=owner_aid,
356
- status=NodeStatus.CREATED,
357
- created_at=now().isoformat(),
356
+ status=NodeStatus.REGISTERED,
357
+ registered_at=now().isoformat(),
358
358
  last_activated_at=None,
359
359
  last_deactivated_at=None,
360
- deleted_at=None,
360
+ unregistered_at=None,
361
361
  online_until=None,
362
362
  heartbeat_interval=heartbeat_interval,
363
363
  public_key=public_key,
@@ -371,16 +371,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
371
371
  with self.lock:
372
372
  if (
373
373
  not (node := self.nodes.get(node_id))
374
- or node.status == NodeStatus.DELETED
374
+ or node.status == NodeStatus.UNREGISTERED
375
375
  or owner_aid != self.nodes[node_id].owner_aid
376
376
  ):
377
377
  raise ValueError(
378
- f"Node ID {node_id} already deleted, not found or unauthorized "
379
- "deletion attempt."
378
+ f"Node ID {node_id} already unregistered, not found or "
379
+ "the request was unauthorized."
380
380
  )
381
381
 
382
- node.status = NodeStatus.DELETED
383
- node.deleted_at = now().isoformat()
382
+ node.status = NodeStatus.UNREGISTERED
383
+ node.unregistered_at = now().isoformat()
384
384
 
385
385
  def get_nodes(self, run_id: int) -> set[int]:
386
386
  """Return all available nodes.
@@ -436,14 +436,23 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
436
436
  with self.lock:
437
437
  if (
438
438
  node := self.nodes.get(node_id)
439
- ) is None or node.status == NodeStatus.DELETED:
439
+ ) is None or node.status == NodeStatus.UNREGISTERED:
440
440
  raise ValueError(f"Node ID {node_id} not found")
441
441
  return node.public_key
442
442
 
443
443
  def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
444
- """Get `node_id` for the specified `public_key`."""
444
+ """Get `node_id` for the specified `public_key` if it exists and is not
445
+ deleted."""
445
446
  with self.lock:
446
- return self.node_public_key_to_node_id.get(public_key)
447
+ node_id = self.node_public_key_to_node_id.get(public_key)
448
+
449
+ if node_id is None:
450
+ return None
451
+
452
+ node_info = self.nodes[node_id]
453
+ if node_info.status == NodeStatus.UNREGISTERED:
454
+ return None
455
+ return node_id
447
456
 
448
457
  # pylint: disable=too-many-arguments,too-many-positional-arguments
449
458
  def create_run(
@@ -630,11 +639,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
630
639
  the node is marked as offline.
631
640
  """
632
641
  with self.lock:
633
- if (node := self.nodes.get(node_id)) and node.status != NodeStatus.DELETED:
642
+ if (
643
+ node := self.nodes.get(node_id)
644
+ ) and node.status != NodeStatus.UNREGISTERED:
634
645
  current_dt = now()
635
646
 
636
647
  # Set timestamp if the status changes
637
- if node.status != NodeStatus.ONLINE: # offline or created
648
+ if node.status != NodeStatus.ONLINE: # offline or registered
638
649
  node.status = NodeStatus.ONLINE
639
650
  node.last_activated_at = current_dt.isoformat()
640
651
 
@@ -151,7 +151,7 @@ class LinkState(CoreState): # pylint: disable=R0904
151
151
 
152
152
  @abc.abstractmethod
153
153
  def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
154
- """Get `node_id` for the specified `public_key`.
154
+ """Get `node_id` for the specified `public_key` if it exists and is not deleted.
155
155
 
156
156
  Parameters
157
157
  ----------
@@ -161,7 +161,8 @@ class LinkState(CoreState): # pylint: disable=R0904
161
161
  Returns
162
162
  -------
163
163
  Optional[int]
164
- The `node_id` associated with the specified `public_key`.
164
+ The `node_id` associated with the specified `public_key` if it exists
165
+ and is not deleted; otherwise, `None`.
165
166
  """
166
167
 
167
168
  @abc.abstractmethod
@@ -76,10 +76,10 @@ CREATE TABLE IF NOT EXISTS node(
76
76
  node_id INTEGER UNIQUE,
77
77
  owner_aid TEXT,
78
78
  status TEXT,
79
- created_at TEXT,
79
+ registered_at TEXT,
80
80
  last_activated_at TEXT NULL,
81
81
  last_deactivated_at TEXT NULL,
82
- deleted_at TEXT NULL,
82
+ unregistered_at TEXT NULL,
83
83
  online_until TIMESTAMP NULL,
84
84
  heartbeat_interval REAL,
85
85
  public_key BLOB UNIQUE
@@ -623,8 +623,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
623
623
 
624
624
  query = """
625
625
  INSERT INTO node
626
- (node_id, owner_aid, status, created_at, last_activated_at,
627
- last_deactivated_at, deleted_at, online_until, heartbeat_interval,
626
+ (node_id, owner_aid, status, registered_at, last_activated_at,
627
+ last_deactivated_at, unregistered_at, online_until, heartbeat_interval,
628
628
  public_key)
629
629
  VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
630
630
  """
@@ -636,11 +636,11 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
636
636
  (
637
637
  sint64_node_id, # node_id
638
638
  owner_aid, # owner_aid
639
- NodeStatus.CREATED, # status
640
- now().isoformat(), # created_at
639
+ NodeStatus.REGISTERED, # status
640
+ now().isoformat(), # registered_at
641
641
  None, # last_activated_at
642
642
  None, # last_deactivated_at
643
- None, # deleted_at
643
+ None, # unregistered_at
644
644
  None, # online_until, initialized with offline status
645
645
  heartbeat_interval, # heartbeat_interval
646
646
  public_key, # public_key
@@ -662,15 +662,15 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
662
662
 
663
663
  query = """
664
664
  UPDATE node
665
- SET status = ?, deleted_at = ?
665
+ SET status = ?, unregistered_at = ?
666
666
  WHERE node_id = ? AND status != ? AND owner_aid = ?
667
667
  RETURNING node_id
668
668
  """
669
669
  params = (
670
- NodeStatus.DELETED,
670
+ NodeStatus.UNREGISTERED,
671
671
  now().isoformat(),
672
672
  sint64_node_id,
673
- NodeStatus.DELETED,
673
+ NodeStatus.UNREGISTERED,
674
674
  owner_aid,
675
675
  )
676
676
 
@@ -775,7 +775,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
775
775
 
776
776
  # Query the public key for the given node_id
777
777
  query = "SELECT public_key FROM node WHERE node_id = ? AND status != ?;"
778
- rows = self.query(query, (sint64_node_id, NodeStatus.DELETED))
778
+ rows = self.query(query, (sint64_node_id, NodeStatus.UNREGISTERED))
779
779
 
780
780
  # If no result is found, return None
781
781
  if not rows:
@@ -785,9 +785,10 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
785
785
  return cast(bytes, rows[0]["public_key"])
786
786
 
787
787
  def get_node_id_by_public_key(self, public_key: bytes) -> Optional[int]:
788
- """Get `node_id` for the specified `public_key`."""
788
+ """Get `node_id` for the specified `public_key` if it exists and is not
789
+ deleted."""
789
790
  query = "SELECT node_id FROM node WHERE public_key = ? AND status != ?;"
790
- rows = self.query(query, (public_key, NodeStatus.DELETED))
791
+ rows = self.query(query, (public_key, NodeStatus.UNREGISTERED))
791
792
 
792
793
  # If no result is found, return None
793
794
  if not rows:
@@ -1058,7 +1059,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
1058
1059
  # Check if node exists and not deleted
1059
1060
  query = "SELECT status FROM node WHERE node_id = ? AND status != ?"
1060
1061
  row = self.conn.execute(
1061
- query, (sint64_node_id, NodeStatus.DELETED)
1062
+ query, (sint64_node_id, NodeStatus.UNREGISTERED)
1062
1063
  ).fetchone()
1063
1064
  if row is None:
1064
1065
  return False
@@ -24,10 +24,10 @@ EXEC_PLUGIN_SECTION = "exec_plugin"
24
24
  class NodeStatus:
25
25
  """Event log writer types."""
26
26
 
27
- CREATED = "created"
27
+ REGISTERED = "registered"
28
28
  ONLINE = "online"
29
29
  OFFLINE = "offline"
30
- DELETED = "deleted"
30
+ UNREGISTERED = "unregistered"
31
31
 
32
32
  def __new__(cls) -> NodeStatus:
33
33
  """Prevent instantiation."""
@@ -15,41 +15,12 @@
15
15
  """Account auth plugin for ControlServicer."""
16
16
 
17
17
 
18
- from flwr.common.constant import AuthnType, AuthzType
19
-
20
18
  from .auth_plugin import ControlAuthnPlugin, ControlAuthzPlugin
21
19
  from .noop_auth_plugin import NoOpControlAuthnPlugin, NoOpControlAuthzPlugin
22
20
 
23
- try:
24
- from flwr.ee import get_control_authn_ee_plugins, get_control_authz_ee_plugins
25
- except ImportError:
26
-
27
- def get_control_authn_ee_plugins() -> dict[str, type[ControlAuthnPlugin]]:
28
- """Return all Control API authentication plugins for EE."""
29
- return {}
30
-
31
- def get_control_authz_ee_plugins() -> dict[str, type[ControlAuthzPlugin]]:
32
- """Return all Control API authorization plugins for EE."""
33
- return {}
34
-
35
-
36
- def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
37
- """Return all Control API authentication plugins."""
38
- ee_dict: dict[str, type[ControlAuthnPlugin]] = get_control_authn_ee_plugins()
39
- return ee_dict | {AuthnType.NOOP: NoOpControlAuthnPlugin}
40
-
41
-
42
- def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
43
- """Return all Control API authorization plugins."""
44
- ee_dict: dict[str, type[ControlAuthzPlugin]] = get_control_authz_ee_plugins()
45
- return ee_dict | {AuthzType.NOOP: NoOpControlAuthzPlugin}
46
-
47
-
48
21
  __all__ = [
49
22
  "ControlAuthnPlugin",
50
23
  "ControlAuthzPlugin",
51
24
  "NoOpControlAuthnPlugin",
52
25
  "NoOpControlAuthzPlugin",
53
- "get_control_authn_plugins",
54
- "get_control_authz_plugins",
55
26
  ]
@@ -49,29 +49,30 @@ from flwr.common.serde import (
49
49
  from flwr.common.typing import Fab, Run, RunStatus
50
50
  from flwr.proto import control_pb2_grpc # pylint: disable=E0611
51
51
  from flwr.proto.control_pb2 import ( # pylint: disable=E0611
52
- CreateNodeCliRequest,
53
- CreateNodeCliResponse,
54
- DeleteNodeCliRequest,
55
- DeleteNodeCliResponse,
56
52
  GetAuthTokensRequest,
57
53
  GetAuthTokensResponse,
58
54
  GetLoginDetailsRequest,
59
55
  GetLoginDetailsResponse,
60
- ListNodesCliRequest,
61
- ListNodesCliResponse,
56
+ ListNodesRequest,
57
+ ListNodesResponse,
62
58
  ListRunsRequest,
63
59
  ListRunsResponse,
64
60
  PullArtifactsRequest,
65
61
  PullArtifactsResponse,
62
+ RegisterNodeRequest,
63
+ RegisterNodeResponse,
66
64
  StartRunRequest,
67
65
  StartRunResponse,
68
66
  StopRunRequest,
69
67
  StopRunResponse,
70
68
  StreamLogsRequest,
71
69
  StreamLogsResponse,
70
+ UnregisterNodeRequest,
71
+ UnregisterNodeResponse,
72
72
  )
73
73
  from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
74
74
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
75
+ from flwr.supercore.constant import NodeStatus
75
76
  from flwr.supercore.ffs import FfsFactory
76
77
  from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory
77
78
  from flwr.supercore.primitives.asymmetric import bytes_to_public_key, uses_nist_ec_curve
@@ -389,11 +390,11 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
389
390
  download_url = self.artifact_provider.get_url(run_id)
390
391
  return PullArtifactsResponse(url=download_url)
391
392
 
392
- def CreateNodeCli(
393
- self, request: CreateNodeCliRequest, context: grpc.ServicerContext
394
- ) -> CreateNodeCliResponse:
393
+ def RegisterNode(
394
+ self, request: RegisterNodeRequest, context: grpc.ServicerContext
395
+ ) -> RegisterNodeResponse:
395
396
  """Add a SuperNode."""
396
- log(INFO, "ControlServicer.CreateNodeCli")
397
+ log(INFO, "ControlServicer.RegisterNode")
397
398
 
398
399
  # Verify public key
399
400
  try:
@@ -427,15 +428,15 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
427
428
  context.abort(
428
429
  grpc.StatusCode.FAILED_PRECONDITION, PUBLIC_KEY_ALREADY_IN_USE_MESSAGE
429
430
  )
430
- log(INFO, "[ControlServicer.CreateNodeCli] Created node_id=%s", node_id)
431
+ log(INFO, "[ControlServicer.RegisterNode] Created node_id=%s", node_id)
431
432
 
432
- return CreateNodeCliResponse(node_id=node_id)
433
+ return RegisterNodeResponse(node_id=node_id)
433
434
 
434
- def DeleteNodeCli(
435
- self, request: DeleteNodeCliRequest, context: grpc.ServicerContext
436
- ) -> DeleteNodeCliResponse:
435
+ def UnregisterNode(
436
+ self, request: UnregisterNodeRequest, context: grpc.ServicerContext
437
+ ) -> UnregisterNodeResponse:
437
438
  """Remove a SuperNode."""
438
- log(INFO, "ControlServicer.RemoveNode")
439
+ log(INFO, "ControlServicer.UnregisterNode")
439
440
 
440
441
  # Init link state
441
442
  state = self.linkstate_factory.state()
@@ -448,19 +449,19 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
448
449
  log(ERROR, NODE_NOT_FOUND_MESSAGE)
449
450
  context.abort(grpc.StatusCode.NOT_FOUND, NODE_NOT_FOUND_MESSAGE)
450
451
 
451
- return DeleteNodeCliResponse()
452
+ return UnregisterNodeResponse()
452
453
 
453
- def ListNodesCli(
454
- self, request: ListNodesCliRequest, context: grpc.ServicerContext
455
- ) -> ListNodesCliResponse:
454
+ def ListNodes(
455
+ self, request: ListNodesRequest, context: grpc.ServicerContext
456
+ ) -> ListNodesResponse:
456
457
  """List all SuperNodes."""
457
- log(INFO, "ControlServicer.ListNodesCli")
458
+ log(INFO, "ControlServicer.ListNodes")
458
459
 
459
460
  if self.is_simulation:
460
- log(ERROR, "ListNodesCli is not available in simulation mode.")
461
+ log(ERROR, "ListNodes is not available in simulation mode.")
461
462
  context.abort(
462
463
  grpc.StatusCode.UNIMPLEMENTED,
463
- "ListNodesCli is not available in simulation mode.",
464
+ "ListNodesis not available in simulation mode.",
464
465
  )
465
466
  raise grpc.RpcError() # This line is unreachable
466
467
 
@@ -478,61 +479,61 @@ class ControlServicer(control_pb2_grpc.ControlServicer):
478
479
  # Retrieve all nodes for the account
479
480
  nodes_info = state.get_node_info(owner_aids=[flwr_aid])
480
481
 
481
- return ListNodesCliResponse(nodes_info=nodes_info, now=now().isoformat())
482
+ return ListNodesResponse(nodes_info=nodes_info, now=now().isoformat())
482
483
 
483
484
 
484
485
  def _create_list_nodeif_for_dry_run() -> Sequence[NodeInfo]:
485
486
  """Create a list of NodeInfo for dry run testing."""
486
487
  nodes_info: list[NodeInfo] = []
487
- # A node created (but not connected)
488
+ # A node registered (but not connected)
488
489
  nodes_info.append(
489
490
  NodeInfo(
490
491
  node_id=15390646978706312628,
491
492
  owner_aid="owner_aid_1",
492
- status="created",
493
- created_at=(now()).isoformat(),
493
+ status=NodeStatus.REGISTERED,
494
+ registered_at=(now()).isoformat(),
494
495
  last_activated_at="",
495
496
  last_deactivated_at="",
496
- deleted_at="",
497
+ unregistered_at="",
497
498
  )
498
499
  )
499
500
 
500
- # A node created and connected
501
+ # A node registered and connected
501
502
  nodes_info.append(
502
503
  NodeInfo(
503
504
  node_id=2941141058168602545,
504
505
  owner_aid="owner_aid_2",
505
- status="online",
506
- created_at=(now()).isoformat(),
506
+ status=NodeStatus.ONLINE,
507
+ registered_at=(now()).isoformat(),
507
508
  last_activated_at=(now() + timedelta(hours=0.5)).isoformat(),
508
509
  last_deactivated_at="",
509
- deleted_at="",
510
+ unregistered_at="",
510
511
  )
511
512
  )
512
513
 
513
- # A node created and deleted (never connected)
514
+ # A node registered and unregistered (never connected)
514
515
  nodes_info.append(
515
516
  NodeInfo(
516
517
  node_id=906971720890549292,
517
518
  owner_aid="owner_aid_3",
518
- status="deleted",
519
- created_at=(now()).isoformat(),
519
+ status=NodeStatus.UNREGISTERED,
520
+ registered_at=(now()).isoformat(),
520
521
  last_activated_at="",
521
522
  last_deactivated_at="",
522
- deleted_at=(now() + timedelta(hours=1)).isoformat(),
523
+ unregistered_at=(now() + timedelta(hours=1)).isoformat(),
523
524
  )
524
525
  )
525
526
 
526
- # A node created, deactivate and then deleted
527
+ # A node registered, deactivate and then unregistered
527
528
  nodes_info.append(
528
529
  NodeInfo(
529
530
  node_id=1781174086018058152,
530
531
  owner_aid="owner_aid_4",
531
- status="offline",
532
- created_at=(now()).isoformat(),
532
+ status=NodeStatus.OFFLINE,
533
+ registered_at=(now()).isoformat(),
533
534
  last_activated_at=(now() + timedelta(hours=0.5)).isoformat(),
534
535
  last_deactivated_at=(now() + timedelta(hours=1)).isoformat(),
535
- deleted_at=(now() + timedelta(hours=1.5)).isoformat(),
536
+ unregistered_at=(now() + timedelta(hours=1.5)).isoformat(),
536
537
  )
537
538
  )
538
539
  return nodes_info