flwr-nightly 1.15.0.dev20250107__py3-none-any.whl → 1.15.0.dev20250109__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (27):
  1. flwr/cli/cli_user_auth_interceptor.py +6 -2
  2. flwr/cli/login/login.py +11 -4
  3. flwr/cli/utils.py +4 -4
  4. flwr/client/grpc_rere_client/client_interceptor.py +6 -0
  5. flwr/client/grpc_rere_client/grpc_adapter.py +16 -0
  6. flwr/common/auth_plugin/auth_plugin.py +33 -23
  7. flwr/common/constant.py +2 -0
  8. flwr/common/typing.py +20 -0
  9. flwr/proto/exec_pb2.py +12 -24
  10. flwr/proto/exec_pb2.pyi +27 -54
  11. flwr/proto/fleet_pb2.py +40 -27
  12. flwr/proto/fleet_pb2.pyi +84 -0
  13. flwr/proto/fleet_pb2_grpc.py +66 -0
  14. flwr/proto/fleet_pb2_grpc.pyi +20 -0
  15. flwr/server/app.py +11 -13
  16. flwr/server/superlink/driver/serverappio_servicer.py +22 -8
  17. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +16 -0
  18. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -1
  19. flwr/server/superlink/linkstate/in_memory_linkstate.py +26 -22
  20. flwr/server/superlink/linkstate/linkstate.py +10 -4
  21. flwr/server/superlink/linkstate/sqlite_linkstate.py +50 -29
  22. flwr/superexec/exec_servicer.py +23 -2
  23. {flwr_nightly-1.15.0.dev20250107.dist-info → flwr_nightly-1.15.0.dev20250109.dist-info}/METADATA +4 -4
  24. {flwr_nightly-1.15.0.dev20250107.dist-info → flwr_nightly-1.15.0.dev20250109.dist-info}/RECORD +27 -27
  25. {flwr_nightly-1.15.0.dev20250107.dist-info → flwr_nightly-1.15.0.dev20250109.dist-info}/LICENSE +0 -0
  26. {flwr_nightly-1.15.0.dev20250107.dist-info → flwr_nightly-1.15.0.dev20250109.dist-info}/WHEEL +0 -0
  27. {flwr_nightly-1.15.0.dev20250107.dist-info → flwr_nightly-1.15.0.dev20250109.dist-info}/entry_points.txt +0 -0
flwr/proto/fleet_pb2.pyi CHANGED
@@ -3,6 +3,7 @@
3
3
  isort:skip_file
4
4
  """
5
5
  import builtins
6
+ import flwr.proto.message_pb2
6
7
  import flwr.proto.node_pb2
7
8
  import flwr.proto.task_pb2
8
9
  import google.protobuf.descriptor
@@ -169,6 +170,89 @@ class PushTaskResResponse(google.protobuf.message.Message):
169
170
  def ClearField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect","results",b"results"]) -> None: ...
170
171
  global___PushTaskResResponse = PushTaskResResponse
171
172
 
173
+ class PullMessagesRequest(google.protobuf.message.Message):
174
+ """PullMessages messages"""
175
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
176
+ NODE_FIELD_NUMBER: builtins.int
177
+ MESSAGE_IDS_FIELD_NUMBER: builtins.int
178
+ @property
179
+ def node(self) -> flwr.proto.node_pb2.Node: ...
180
+ @property
181
+ def message_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
182
+ def __init__(self,
183
+ *,
184
+ node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
185
+ message_ids: typing.Optional[typing.Iterable[typing.Text]] = ...,
186
+ ) -> None: ...
187
+ def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
188
+ def ClearField(self, field_name: typing_extensions.Literal["message_ids",b"message_ids","node",b"node"]) -> None: ...
189
+ global___PullMessagesRequest = PullMessagesRequest
190
+
191
+ class PullMessagesResponse(google.protobuf.message.Message):
192
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
193
+ RECONNECT_FIELD_NUMBER: builtins.int
194
+ MESSAGES_LIST_FIELD_NUMBER: builtins.int
195
+ @property
196
+ def reconnect(self) -> global___Reconnect: ...
197
+ @property
198
+ def messages_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.message_pb2.Message]: ...
199
+ def __init__(self,
200
+ *,
201
+ reconnect: typing.Optional[global___Reconnect] = ...,
202
+ messages_list: typing.Optional[typing.Iterable[flwr.proto.message_pb2.Message]] = ...,
203
+ ) -> None: ...
204
+ def HasField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect"]) -> builtins.bool: ...
205
+ def ClearField(self, field_name: typing_extensions.Literal["messages_list",b"messages_list","reconnect",b"reconnect"]) -> None: ...
206
+ global___PullMessagesResponse = PullMessagesResponse
207
+
208
+ class PushMessagesRequest(google.protobuf.message.Message):
209
+ """PushMessages messages"""
210
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
211
+ NODE_FIELD_NUMBER: builtins.int
212
+ MESSAGES_LIST_FIELD_NUMBER: builtins.int
213
+ @property
214
+ def node(self) -> flwr.proto.node_pb2.Node: ...
215
+ @property
216
+ def messages_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.message_pb2.Message]: ...
217
+ def __init__(self,
218
+ *,
219
+ node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
220
+ messages_list: typing.Optional[typing.Iterable[flwr.proto.message_pb2.Message]] = ...,
221
+ ) -> None: ...
222
+ def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
223
+ def ClearField(self, field_name: typing_extensions.Literal["messages_list",b"messages_list","node",b"node"]) -> None: ...
224
+ global___PushMessagesRequest = PushMessagesRequest
225
+
226
+ class PushMessagesResponse(google.protobuf.message.Message):
227
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
228
+ class ResultsEntry(google.protobuf.message.Message):
229
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
230
+ KEY_FIELD_NUMBER: builtins.int
231
+ VALUE_FIELD_NUMBER: builtins.int
232
+ key: typing.Text
233
+ value: builtins.int
234
+ def __init__(self,
235
+ *,
236
+ key: typing.Text = ...,
237
+ value: builtins.int = ...,
238
+ ) -> None: ...
239
+ def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
240
+
241
+ RECONNECT_FIELD_NUMBER: builtins.int
242
+ RESULTS_FIELD_NUMBER: builtins.int
243
+ @property
244
+ def reconnect(self) -> global___Reconnect: ...
245
+ @property
246
+ def results(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, builtins.int]: ...
247
+ def __init__(self,
248
+ *,
249
+ reconnect: typing.Optional[global___Reconnect] = ...,
250
+ results: typing.Optional[typing.Mapping[typing.Text, builtins.int]] = ...,
251
+ ) -> None: ...
252
+ def HasField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect"]) -> builtins.bool: ...
253
+ def ClearField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect","results",b"results"]) -> None: ...
254
+ global___PushMessagesResponse = PushMessagesResponse
255
+
172
256
  class Reconnect(google.protobuf.message.Message):
173
257
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
174
258
  RECONNECT_FIELD_NUMBER: builtins.int
@@ -36,11 +36,21 @@ class FleetStub(object):
36
36
  request_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.SerializeToString,
37
37
  response_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.FromString,
38
38
  )
39
+ self.PullMessages = channel.unary_unary(
40
+ '/flwr.proto.Fleet/PullMessages',
41
+ request_serializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesRequest.SerializeToString,
42
+ response_deserializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesResponse.FromString,
43
+ )
39
44
  self.PushTaskRes = channel.unary_unary(
40
45
  '/flwr.proto.Fleet/PushTaskRes',
41
46
  request_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.SerializeToString,
42
47
  response_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.FromString,
43
48
  )
49
+ self.PushMessages = channel.unary_unary(
50
+ '/flwr.proto.Fleet/PushMessages',
51
+ request_serializer=flwr_dot_proto_dot_fleet__pb2.PushMessagesRequest.SerializeToString,
52
+ response_deserializer=flwr_dot_proto_dot_fleet__pb2.PushMessagesResponse.FromString,
53
+ )
44
54
  self.GetRun = channel.unary_unary(
45
55
  '/flwr.proto.Fleet/GetRun',
46
56
  request_serializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString,
@@ -83,6 +93,12 @@ class FleetServicer(object):
83
93
  context.set_details('Method not implemented!')
84
94
  raise NotImplementedError('Method not implemented!')
85
95
 
96
+ def PullMessages(self, request, context):
97
+ """Missing associated documentation comment in .proto file."""
98
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
99
+ context.set_details('Method not implemented!')
100
+ raise NotImplementedError('Method not implemented!')
101
+
86
102
  def PushTaskRes(self, request, context):
87
103
  """Complete one or more tasks, if possible
88
104
 
@@ -92,6 +108,12 @@ class FleetServicer(object):
92
108
  context.set_details('Method not implemented!')
93
109
  raise NotImplementedError('Method not implemented!')
94
110
 
111
+ def PushMessages(self, request, context):
112
+ """Missing associated documentation comment in .proto file."""
113
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
114
+ context.set_details('Method not implemented!')
115
+ raise NotImplementedError('Method not implemented!')
116
+
95
117
  def GetRun(self, request, context):
96
118
  """Missing associated documentation comment in .proto file."""
97
119
  context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -128,11 +150,21 @@ def add_FleetServicer_to_server(servicer, server):
128
150
  request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.FromString,
129
151
  response_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.SerializeToString,
130
152
  ),
153
+ 'PullMessages': grpc.unary_unary_rpc_method_handler(
154
+ servicer.PullMessages,
155
+ request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesRequest.FromString,
156
+ response_serializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesResponse.SerializeToString,
157
+ ),
131
158
  'PushTaskRes': grpc.unary_unary_rpc_method_handler(
132
159
  servicer.PushTaskRes,
133
160
  request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.FromString,
134
161
  response_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.SerializeToString,
135
162
  ),
163
+ 'PushMessages': grpc.unary_unary_rpc_method_handler(
164
+ servicer.PushMessages,
165
+ request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushMessagesRequest.FromString,
166
+ response_serializer=flwr_dot_proto_dot_fleet__pb2.PushMessagesResponse.SerializeToString,
167
+ ),
136
168
  'GetRun': grpc.unary_unary_rpc_method_handler(
137
169
  servicer.GetRun,
138
170
  request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.FromString,
@@ -221,6 +253,23 @@ class Fleet(object):
221
253
  options, channel_credentials,
222
254
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
223
255
 
256
+ @staticmethod
257
+ def PullMessages(request,
258
+ target,
259
+ options=(),
260
+ channel_credentials=None,
261
+ call_credentials=None,
262
+ insecure=False,
263
+ compression=None,
264
+ wait_for_ready=None,
265
+ timeout=None,
266
+ metadata=None):
267
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.Fleet/PullMessages',
268
+ flwr_dot_proto_dot_fleet__pb2.PullMessagesRequest.SerializeToString,
269
+ flwr_dot_proto_dot_fleet__pb2.PullMessagesResponse.FromString,
270
+ options, channel_credentials,
271
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
272
+
224
273
  @staticmethod
225
274
  def PushTaskRes(request,
226
275
  target,
@@ -238,6 +287,23 @@ class Fleet(object):
238
287
  options, channel_credentials,
239
288
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
240
289
 
290
+ @staticmethod
291
+ def PushMessages(request,
292
+ target,
293
+ options=(),
294
+ channel_credentials=None,
295
+ call_credentials=None,
296
+ insecure=False,
297
+ compression=None,
298
+ wait_for_ready=None,
299
+ timeout=None,
300
+ metadata=None):
301
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.Fleet/PushMessages',
302
+ flwr_dot_proto_dot_fleet__pb2.PushMessagesRequest.SerializeToString,
303
+ flwr_dot_proto_dot_fleet__pb2.PushMessagesResponse.FromString,
304
+ options, channel_credentials,
305
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
306
+
241
307
  @staticmethod
242
308
  def GetRun(request,
243
309
  target,
@@ -30,6 +30,10 @@ class FleetStub:
30
30
  HTTP API path: /api/v1/fleet/pull-task-ins
31
31
  """
32
32
 
33
+ PullMessages: grpc.UnaryUnaryMultiCallable[
34
+ flwr.proto.fleet_pb2.PullMessagesRequest,
35
+ flwr.proto.fleet_pb2.PullMessagesResponse]
36
+
33
37
  PushTaskRes: grpc.UnaryUnaryMultiCallable[
34
38
  flwr.proto.fleet_pb2.PushTaskResRequest,
35
39
  flwr.proto.fleet_pb2.PushTaskResResponse]
@@ -38,6 +42,10 @@ class FleetStub:
38
42
  HTTP API path: /api/v1/fleet/push-task-res
39
43
  """
40
44
 
45
+ PushMessages: grpc.UnaryUnaryMultiCallable[
46
+ flwr.proto.fleet_pb2.PushMessagesRequest,
47
+ flwr.proto.fleet_pb2.PushMessagesResponse]
48
+
41
49
  GetRun: grpc.UnaryUnaryMultiCallable[
42
50
  flwr.proto.run_pb2.GetRunRequest,
43
51
  flwr.proto.run_pb2.GetRunResponse]
@@ -78,6 +86,12 @@ class FleetServicer(metaclass=abc.ABCMeta):
78
86
  """
79
87
  pass
80
88
 
89
+ @abc.abstractmethod
90
+ def PullMessages(self,
91
+ request: flwr.proto.fleet_pb2.PullMessagesRequest,
92
+ context: grpc.ServicerContext,
93
+ ) -> flwr.proto.fleet_pb2.PullMessagesResponse: ...
94
+
81
95
  @abc.abstractmethod
82
96
  def PushTaskRes(self,
83
97
  request: flwr.proto.fleet_pb2.PushTaskResRequest,
@@ -89,6 +103,12 @@ class FleetServicer(metaclass=abc.ABCMeta):
89
103
  """
90
104
  pass
91
105
 
106
+ @abc.abstractmethod
107
+ def PushMessages(self,
108
+ request: flwr.proto.fleet_pb2.PushMessagesRequest,
109
+ context: grpc.ServicerContext,
110
+ ) -> flwr.proto.fleet_pb2.PushMessagesResponse: ...
111
+
92
112
  @abc.abstractmethod
93
113
  def GetRun(self,
94
114
  request: flwr.proto.run_pb2.GetRunRequest,
flwr/server/app.py CHANGED
@@ -263,11 +263,10 @@ def run_superlink() -> None:
263
263
  # Obtain certificates
264
264
  certificates = try_obtain_server_certificates(args, args.fleet_api_type)
265
265
 
266
- user_auth_config = _try_obtain_user_auth_config(args)
267
266
  auth_plugin: Optional[ExecAuthPlugin] = None
268
- # user_auth_config is None only if the args.user_auth_config is not provided
269
- if user_auth_config is not None:
270
- auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)
267
+ # Load the auth plugin if the args.user_auth_config is provided
268
+ if cfg_path := getattr(args, "user_auth_config", None):
269
+ auth_plugin = _try_obtain_exec_auth_plugin(Path(cfg_path))
271
270
 
272
271
  # Initialize StateFactory
273
272
  state_factory = LinkStateFactory(args.database)
@@ -584,21 +583,20 @@ def _try_setup_node_authentication(
584
583
  )
585
584
 
586
585
 
587
- def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
588
- if getattr(args, "user_auth_config", None) is not None:
589
- with open(args.user_auth_config, encoding="utf-8") as file:
590
- config: dict[str, Any] = yaml.safe_load(file)
591
- return config
592
- return None
586
+ def _try_obtain_exec_auth_plugin(config_path: Path) -> Optional[ExecAuthPlugin]:
587
+ # Load YAML file
588
+ with config_path.open("r", encoding="utf-8") as file:
589
+ config: dict[str, Any] = yaml.safe_load(file)
593
590
 
594
-
595
- def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
591
+ # Load authentication configuration
596
592
  auth_config: dict[str, Any] = config.get("authentication", {})
597
593
  auth_type: str = auth_config.get(AUTH_TYPE, "")
594
+
595
+ # Load authentication plugin
598
596
  try:
599
597
  all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
600
598
  auth_plugin_class = all_plugins[auth_type]
601
- return auth_plugin_class(config=auth_config)
599
+ return auth_plugin_class(user_auth_config_path=config_path)
602
600
  except KeyError:
603
601
  if auth_type != "":
604
602
  sys.exit(
@@ -118,8 +118,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
118
118
  ffs: Ffs = self.ffs_factory.ffs()
119
119
  fab_hash = ffs.put(fab.content, {})
120
120
  _raise_if(
121
- fab_hash != fab.hash_str,
122
- f"FAB ({fab.hash_str}) hash from request doesn't match contents",
121
+ validation_error=fab_hash != fab.hash_str,
122
+ request_name="CreateRun",
123
+ detail=f"FAB ({fab.hash_str}) hash from request doesn't match contents",
123
124
  )
124
125
  else:
125
126
  fab_hash = ""
@@ -155,12 +156,22 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
155
156
  task_ins.task.pushed_at = pushed_at
156
157
 
157
158
  # Validate request
158
- _raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
159
+ _raise_if(
160
+ validation_error=len(request.task_ins_list) == 0,
161
+ request_name="PushTaskIns",
162
+ detail="`task_ins_list` must not be empty",
163
+ )
159
164
  for task_ins in request.task_ins_list:
160
165
  validation_errors = validate_task_ins_or_res(task_ins)
161
- _raise_if(bool(validation_errors), ", ".join(validation_errors))
162
166
  _raise_if(
163
- request.run_id != task_ins.run_id, "`task_ins` has mismatched `run_id`"
167
+ validation_error=bool(validation_errors),
168
+ request_name="PushTaskIns",
169
+ detail=", ".join(validation_errors),
170
+ )
171
+ _raise_if(
172
+ validation_error=request.run_id != task_ins.run_id,
173
+ request_name="PushTaskIns",
174
+ detail="`task_ins` has mismatched `run_id`",
164
175
  )
165
176
 
166
177
  # Store each TaskIns
@@ -199,7 +210,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
199
210
  # Validate request
200
211
  for task_res in task_res_list:
201
212
  _raise_if(
202
- request.run_id != task_res.run_id, "`task_res` has mismatched `run_id`"
213
+ validation_error=request.run_id != task_res.run_id,
214
+ request_name="PullTaskRes",
215
+ detail="`task_res` has mismatched `run_id`",
203
216
  )
204
217
 
205
218
  # Delete the TaskIns/TaskRes pairs if TaskRes is found
@@ -344,6 +357,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
344
357
  return GetRunStatusResponse(run_status_dict=run_status_dict)
345
358
 
346
359
 
347
- def _raise_if(validation_error: bool, detail: str) -> None:
360
+ def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
361
+ """Raise a `ValueError` with a detailed message if a validation error occurs."""
348
362
  if validation_error:
349
- raise ValueError(f"Malformed PushTaskInsRequest: {detail}")
363
+ raise ValueError(f"Malformed {request_name}: {detail}")
@@ -30,8 +30,12 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
30
30
  DeleteNodeResponse,
31
31
  PingRequest,
32
32
  PingResponse,
33
+ PullMessagesRequest,
34
+ PullMessagesResponse,
33
35
  PullTaskInsRequest,
34
36
  PullTaskInsResponse,
37
+ PushMessagesRequest,
38
+ PushMessagesResponse,
35
39
  PushTaskResRequest,
36
40
  PushTaskResResponse,
37
41
  )
@@ -95,6 +99,12 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
95
99
  state=self.state_factory.state(),
96
100
  )
97
101
 
102
+ def PullMessages(
103
+ self, request: PullMessagesRequest, context: grpc.ServicerContext
104
+ ) -> PullMessagesResponse:
105
+ """Pull Messages."""
106
+ return PullMessagesResponse()
107
+
98
108
  def PushTaskRes(
99
109
  self, request: PushTaskResRequest, context: grpc.ServicerContext
100
110
  ) -> PushTaskResResponse:
@@ -118,6 +128,12 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
118
128
 
119
129
  return res
120
130
 
131
+ def PushMessages(
132
+ self, request: PushMessagesRequest, context: grpc.ServicerContext
133
+ ) -> PushMessagesResponse:
134
+ """Push Messages."""
135
+ return PushMessagesResponse()
136
+
121
137
  def GetRun(
122
138
  self, request: GetRunRequest, context: grpc.ServicerContext
123
139
  ) -> GetRunResponse:
@@ -223,5 +223,6 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
223
223
  # No `node_id` exists for the provided `public_key`
224
224
  # Handle `CreateNode` here instead of calling the default method handler
225
225
  # Note: the innermost `CreateNode` method will never be called
226
- node_id = state.create_node(request.ping_interval, public_key_bytes)
226
+ node_id = state.create_node(request.ping_interval)
227
+ state.set_node_public_key(node_id, public_key_bytes)
227
228
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
@@ -62,6 +62,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
62
62
  # Map node_id to (online_until, ping_interval)
63
63
  self.node_ids: dict[int, tuple[float, float]] = {}
64
64
  self.public_key_to_node_id: dict[bytes, int] = {}
65
+ self.node_id_to_public_key: dict[int, bytes] = {}
65
66
 
66
67
  # Map run_id to RunRecord
67
68
  self.run_ids: dict[int, RunRecord] = {}
@@ -306,9 +307,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
306
307
  """
307
308
  return len(self.task_res_store)
308
309
 
309
- def create_node(
310
- self, ping_interval: float, public_key: Optional[bytes] = None
311
- ) -> int:
310
+ def create_node(self, ping_interval: float) -> int:
312
311
  """Create, store in the link state, and return `node_id`."""
313
312
  # Sample a random int64 as node_id
314
313
  node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
@@ -318,33 +317,18 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
318
317
  log(ERROR, "Unexpected node registration failure.")
319
318
  return 0
320
319
 
321
- if public_key is not None:
322
- if (
323
- public_key in self.public_key_to_node_id
324
- or node_id in self.public_key_to_node_id.values()
325
- ):
326
- log(ERROR, "Unexpected node registration failure.")
327
- return 0
328
-
329
- self.public_key_to_node_id[public_key] = node_id
330
-
331
320
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
332
321
  return node_id
333
322
 
334
- def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
323
+ def delete_node(self, node_id: int) -> None:
335
324
  """Delete a node."""
336
325
  with self.lock:
337
326
  if node_id not in self.node_ids:
338
327
  raise ValueError(f"Node {node_id} not found")
339
328
 
340
- if public_key is not None:
341
- if (
342
- public_key not in self.public_key_to_node_id
343
- or node_id not in self.public_key_to_node_id.values()
344
- ):
345
- raise ValueError("Public key or node_id not found")
346
-
347
- del self.public_key_to_node_id[public_key]
329
+ # Remove node ID <> public key mappings
330
+ if pk := self.node_id_to_public_key.pop(node_id, None):
331
+ del self.public_key_to_node_id[pk]
348
332
 
349
333
  del self.node_ids[node_id]
350
334
 
@@ -366,6 +350,26 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
366
350
  if online_until > current_time
367
351
  }
368
352
 
353
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
354
+ """Set `public_key` for the specified `node_id`."""
355
+ with self.lock:
356
+ if node_id not in self.node_ids:
357
+ raise ValueError(f"Node {node_id} not found")
358
+
359
+ if public_key in self.public_key_to_node_id:
360
+ raise ValueError("Public key already in use")
361
+
362
+ self.public_key_to_node_id[public_key] = node_id
363
+ self.node_id_to_public_key[node_id] = public_key
364
+
365
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
366
+ """Get `public_key` for the specified `node_id`."""
367
+ with self.lock:
368
+ if node_id not in self.node_ids:
369
+ raise ValueError(f"Node {node_id} not found")
370
+
371
+ return self.node_id_to_public_key.get(node_id)
372
+
369
373
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
370
374
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
371
375
  return self.public_key_to_node_id.get(node_public_key)
@@ -154,13 +154,11 @@ class LinkState(abc.ABC): # pylint: disable=R0904
154
154
  """Get all TaskIns IDs for the given run_id."""
155
155
 
156
156
  @abc.abstractmethod
157
- def create_node(
158
- self, ping_interval: float, public_key: Optional[bytes] = None
159
- ) -> int:
157
+ def create_node(self, ping_interval: float) -> int:
160
158
  """Create, store in the link state, and return `node_id`."""
161
159
 
162
160
  @abc.abstractmethod
163
- def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
161
+ def delete_node(self, node_id: int) -> None:
164
162
  """Remove `node_id` from the link state."""
165
163
 
166
164
  @abc.abstractmethod
@@ -173,6 +171,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
173
171
  an empty `Set` MUST be returned.
174
172
  """
175
173
 
174
+ @abc.abstractmethod
175
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
176
+ """Set `public_key` for the specified `node_id`."""
177
+
178
+ @abc.abstractmethod
179
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
180
+ """Get `public_key` for the specified `node_id`."""
181
+
176
182
  @abc.abstractmethod
177
183
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
178
184
  """Retrieve stored `node_id` filtered by `node_public_keys`."""