flwr 1.13.1__py3-none-any.whl → 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158):
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/auth_plugin/__init__.py +31 -0
  3. flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
  4. flwr/cli/build.py +1 -0
  5. flwr/cli/cli_user_auth_interceptor.py +90 -0
  6. flwr/cli/config_utils.py +43 -149
  7. flwr/cli/constant.py +27 -0
  8. flwr/cli/example.py +1 -0
  9. flwr/cli/install.py +2 -1
  10. flwr/cli/log.py +34 -37
  11. flwr/cli/login/__init__.py +22 -0
  12. flwr/cli/login/login.py +116 -0
  13. flwr/cli/ls.py +214 -106
  14. flwr/cli/new/__init__.py +1 -0
  15. flwr/cli/new/new.py +2 -1
  16. flwr/cli/new/templates/app/.gitignore.tpl +3 -0
  17. flwr/cli/new/templates/app/README.md.tpl +3 -2
  18. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  19. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
  20. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  21. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
  22. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +3 -4
  23. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
  24. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
  25. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
  26. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  27. flwr/cli/run/__init__.py +1 -0
  28. flwr/cli/run/run.py +103 -43
  29. flwr/cli/stop.py +139 -0
  30. flwr/cli/utils.py +186 -8
  31. flwr/client/app.py +49 -50
  32. flwr/client/client.py +1 -32
  33. flwr/client/clientapp/app.py +23 -26
  34. flwr/client/clientapp/utils.py +2 -1
  35. flwr/client/grpc_adapter_client/connection.py +1 -1
  36. flwr/client/grpc_client/connection.py +2 -13
  37. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  38. flwr/client/grpc_rere_client/connection.py +59 -43
  39. flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
  40. flwr/client/message_handler/message_handler.py +1 -2
  41. flwr/client/message_handler/task_handler.py +0 -17
  42. flwr/client/mod/comms_mods.py +1 -0
  43. flwr/client/mod/localdp_mod.py +1 -1
  44. flwr/client/nodestate/__init__.py +1 -0
  45. flwr/client/nodestate/nodestate.py +1 -0
  46. flwr/client/nodestate/nodestate_factory.py +1 -0
  47. flwr/client/numpy_client.py +0 -44
  48. flwr/client/rest_client/connection.py +37 -29
  49. flwr/client/supernode/app.py +20 -74
  50. flwr/common/address.py +1 -0
  51. flwr/common/args.py +26 -47
  52. flwr/common/auth_plugin/__init__.py +24 -0
  53. flwr/common/auth_plugin/auth_plugin.py +122 -0
  54. flwr/common/config.py +169 -17
  55. flwr/common/constant.py +38 -9
  56. flwr/common/differential_privacy.py +2 -1
  57. flwr/common/exit/__init__.py +24 -0
  58. flwr/common/exit/exit.py +99 -0
  59. flwr/common/exit/exit_code.py +93 -0
  60. flwr/common/exit_handlers.py +24 -10
  61. flwr/common/grpc.py +167 -4
  62. flwr/common/logger.py +66 -7
  63. flwr/common/message.py +1 -0
  64. flwr/common/object_ref.py +57 -54
  65. flwr/common/pyproject.py +1 -0
  66. flwr/common/record/__init__.py +1 -0
  67. flwr/common/record/parametersrecord.py +1 -0
  68. flwr/common/record/recordset.py +1 -1
  69. flwr/common/retry_invoker.py +77 -0
  70. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  71. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  72. flwr/common/serde.py +6 -4
  73. flwr/common/telemetry.py +15 -4
  74. flwr/common/typing.py +32 -0
  75. flwr/common/version.py +1 -0
  76. flwr/proto/clientappio_pb2.py +1 -1
  77. flwr/proto/error_pb2.py +1 -1
  78. flwr/proto/exec_pb2.py +27 -15
  79. flwr/proto/exec_pb2.pyi +80 -2
  80. flwr/proto/exec_pb2_grpc.py +102 -0
  81. flwr/proto/exec_pb2_grpc.pyi +39 -0
  82. flwr/proto/fab_pb2.py +5 -5
  83. flwr/proto/fab_pb2.pyi +4 -1
  84. flwr/proto/fleet_pb2.py +31 -31
  85. flwr/proto/fleet_pb2.pyi +23 -23
  86. flwr/proto/fleet_pb2_grpc.py +30 -30
  87. flwr/proto/fleet_pb2_grpc.pyi +20 -20
  88. flwr/proto/grpcadapter_pb2.py +1 -1
  89. flwr/proto/log_pb2.py +1 -1
  90. flwr/proto/message_pb2.py +1 -1
  91. flwr/proto/node_pb2.py +3 -3
  92. flwr/proto/node_pb2.pyi +1 -4
  93. flwr/proto/recordset_pb2.py +1 -1
  94. flwr/proto/run_pb2.py +1 -1
  95. flwr/proto/serverappio_pb2.py +24 -25
  96. flwr/proto/serverappio_pb2.pyi +32 -32
  97. flwr/proto/serverappio_pb2_grpc.py +62 -28
  98. flwr/proto/serverappio_pb2_grpc.pyi +29 -16
  99. flwr/proto/simulationio_pb2.py +3 -3
  100. flwr/proto/simulationio_pb2_grpc.py +34 -0
  101. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  102. flwr/proto/task_pb2.py +1 -1
  103. flwr/proto/transport_pb2.py +1 -1
  104. flwr/server/app.py +152 -112
  105. flwr/server/compat/app_utils.py +7 -2
  106. flwr/server/compat/driver_client_proxy.py +1 -2
  107. flwr/server/driver/grpc_driver.py +38 -85
  108. flwr/server/driver/inmemory_driver.py +7 -2
  109. flwr/server/run_serverapp.py +8 -9
  110. flwr/server/serverapp/app.py +37 -13
  111. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  112. flwr/server/superlink/driver/serverappio_grpc.py +2 -1
  113. flwr/server/superlink/driver/serverappio_servicer.py +148 -63
  114. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  115. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -87
  116. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  117. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  118. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +56 -35
  119. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +99 -169
  120. flwr/server/superlink/fleet/message_handler/message_handler.py +69 -29
  121. flwr/server/superlink/fleet/rest_rere/rest_api.py +20 -19
  122. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  123. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  124. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  125. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  126. flwr/server/superlink/linkstate/in_memory_linkstate.py +60 -99
  127. flwr/server/superlink/linkstate/linkstate.py +30 -36
  128. flwr/server/superlink/linkstate/sqlite_linkstate.py +105 -188
  129. flwr/server/superlink/linkstate/utils.py +18 -8
  130. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  131. flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
  132. flwr/server/superlink/utils.py +65 -0
  133. flwr/server/utils/validator.py +9 -34
  134. flwr/simulation/app.py +20 -10
  135. flwr/simulation/legacy_app.py +4 -2
  136. flwr/simulation/ray_transport/ray_actor.py +1 -0
  137. flwr/simulation/ray_transport/utils.py +1 -0
  138. flwr/simulation/run_simulation.py +36 -22
  139. flwr/simulation/simulationio_connection.py +5 -1
  140. flwr/superexec/app.py +1 -0
  141. flwr/superexec/deployment.py +1 -0
  142. flwr/superexec/exec_grpc.py +20 -2
  143. flwr/superexec/exec_servicer.py +97 -2
  144. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  145. flwr/superexec/executor.py +1 -0
  146. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/METADATA +14 -13
  147. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/RECORD +150 -144
  148. flwr/proto/common_pb2.py +0 -36
  149. flwr/proto/common_pb2.pyi +0 -121
  150. flwr/proto/common_pb2_grpc.py +0 -4
  151. flwr/proto/common_pb2_grpc.pyi +0 -4
  152. flwr/proto/control_pb2.py +0 -27
  153. flwr/proto/control_pb2.pyi +0 -7
  154. flwr/proto/control_pb2_grpc.py +0 -135
  155. flwr/proto/control_pb2_grpc.pyi +0 -53
  156. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/LICENSE +0 -0
  157. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/WHEEL +0 -0
  158. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/entry_points.txt +0 -0
@@ -17,13 +17,12 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import sys
21
20
  from collections.abc import Awaitable
22
21
  from typing import Callable, TypeVar, cast
23
22
 
24
23
  from google.protobuf.message import Message as GrpcMessage
25
24
 
26
- from flwr.common.constant import MISSING_EXTRA_REST
25
+ from flwr.common.exit import ExitCode, flwr_exit
27
26
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
28
27
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
29
28
  CreateNodeRequest,
@@ -32,10 +31,10 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
31
  DeleteNodeResponse,
33
32
  PingRequest,
34
33
  PingResponse,
35
- PullTaskInsRequest,
36
- PullTaskInsResponse,
37
- PushTaskResRequest,
38
- PushTaskResResponse,
34
+ PullMessagesRequest,
35
+ PullMessagesResponse,
36
+ PushMessagesRequest,
37
+ PushMessagesResponse,
39
38
  )
40
39
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
41
40
  from flwr.server.superlink.ffs.ffs import Ffs
@@ -51,7 +50,7 @@ try:
51
50
  from starlette.responses import Response
52
51
  from starlette.routing import Route
53
52
  except ModuleNotFoundError:
54
- sys.exit(MISSING_EXTRA_REST)
53
+ flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
55
54
 
56
55
 
57
56
  GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
@@ -107,25 +106,24 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
107
106
  return message_handler.delete_node(request=request, state=state)
108
107
 
109
108
 
110
- @rest_request_response(PullTaskInsRequest)
111
- async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
112
- """Pull TaskIns."""
109
+ @rest_request_response(PullMessagesRequest)
110
+ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
111
+ """Pull PullMessages."""
113
112
  # Get state from app
114
113
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
115
114
 
116
115
  # Handle message
117
- return message_handler.pull_task_ins(request=request, state=state)
116
+ return message_handler.pull_messages(request=request, state=state)
118
117
 
119
118
 
120
- # Check if token is needed here
121
- @rest_request_response(PushTaskResRequest)
122
- async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
123
- """Push TaskRes."""
119
+ @rest_request_response(PushMessagesRequest)
120
+ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
121
+ """Pull PushMessages."""
124
122
  # Get state from app
125
123
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
126
124
 
127
125
  # Handle message
128
- return message_handler.push_task_res(request=request, state=state)
126
+ return message_handler.push_messages(request=request, state=state)
129
127
 
130
128
 
131
129
  @rest_request_response(PingRequest)
@@ -154,15 +152,18 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
154
152
  # Get ffs from app
155
153
  ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
156
154
 
155
+ # Get state from app
156
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
157
+
157
158
  # Handle message
158
- return message_handler.get_fab(request=request, ffs=ffs)
159
+ return message_handler.get_fab(request=request, ffs=ffs, state=state)
159
160
 
160
161
 
161
162
  routes = [
162
163
  Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
163
164
  Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
164
- Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
165
- Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
165
+ Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
166
+ Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
166
167
  Route("/api/v0/fleet/ping", ping, methods=["POST"]),
167
168
  Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
168
169
  Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Fleet Simulation Engine side."""
16
16
 
17
+
17
18
  from .vce_api import start_vce
18
19
 
19
20
  __all__ = [
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Simulation Engine Backends."""
16
16
 
17
+
17
18
  import importlib
18
19
 
19
20
  from .backend import Backend, BackendConfig
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Ray backend for the Fleet API using the Simulation Engine."""
16
16
 
17
+
17
18
  import sys
18
19
  from logging import DEBUG, ERROR
19
20
  from typing import Callable, Optional, Union
@@ -182,8 +182,8 @@ def run_api(
182
182
  f_stop: threading.Event,
183
183
  ) -> None:
184
184
  """Run the VCE."""
185
- taskins_queue: "Queue[TaskIns]" = Queue()
186
- taskres_queue: "Queue[TaskRes]" = Queue()
185
+ taskins_queue: Queue[TaskIns] = Queue()
186
+ taskres_queue: Queue[TaskRes] = Queue()
187
187
 
188
188
  try:
189
189
 
@@ -28,6 +28,7 @@ from flwr.common.constant import (
28
28
  MESSAGE_TTL_TOLERANCE,
29
29
  NODE_ID_NUM_BYTES,
30
30
  RUN_ID_NUM_BYTES,
31
+ SUPERLINK_NODE_ID,
31
32
  Status,
32
33
  )
33
34
  from flwr.common.record import ConfigsRecord
@@ -62,6 +63,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
62
63
  # Map node_id to (online_until, ping_interval)
63
64
  self.node_ids: dict[int, tuple[float, float]] = {}
64
65
  self.public_key_to_node_id: dict[bytes, int] = {}
66
+ self.node_id_to_public_key: dict[int, bytes] = {}
65
67
 
66
68
  # Map run_id to RunRecord
67
69
  self.run_ids: dict[int, RunRecord] = {}
@@ -72,8 +74,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
72
74
  self.task_ins_id_to_task_res_id: dict[UUID, UUID] = {}
73
75
 
74
76
  self.node_public_keys: set[bytes] = set()
75
- self.server_public_key: Optional[bytes] = None
76
- self.server_private_key: Optional[bytes] = None
77
77
 
78
78
  self.lock = threading.RLock()
79
79
 
@@ -89,7 +89,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
89
89
  log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
90
90
  return None
91
91
  # Validate source node ID
92
- if task_ins.task.producer.node_id != 0:
92
+ if task_ins.task.producer.node_id != SUPERLINK_NODE_ID:
93
93
  log(
94
94
  ERROR,
95
95
  "Invalid source node ID for TaskIns: %s",
@@ -97,14 +97,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
97
97
  )
98
98
  return None
99
99
  # Validate destination node ID
100
- if not task_ins.task.consumer.anonymous:
101
- if task_ins.task.consumer.node_id not in self.node_ids:
102
- log(
103
- ERROR,
104
- "Invalid destination node ID for TaskIns: %s",
105
- task_ins.task.consumer.node_id,
106
- )
107
- return None
100
+ if task_ins.task.consumer.node_id not in self.node_ids:
101
+ log(
102
+ ERROR,
103
+ "Invalid destination node ID for TaskIns: %s",
104
+ task_ins.task.consumer.node_id,
105
+ )
106
+ return None
108
107
 
109
108
  # Create task_id
110
109
  task_id = uuid4()
@@ -117,9 +116,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
117
116
  # Return the new task_id
118
117
  return task_id
119
118
 
120
- def get_task_ins(
121
- self, node_id: Optional[int], limit: Optional[int]
122
- ) -> list[TaskIns]:
119
+ def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
123
120
  """Get all TaskIns that have not been delivered yet."""
124
121
  if limit is not None and limit < 1:
125
122
  raise AssertionError("`limit` must be >= 1")
@@ -129,17 +126,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
129
126
  current_time = time.time()
130
127
  with self.lock:
131
128
  for _, task_ins in self.task_ins_store.items():
132
- # pylint: disable=too-many-boolean-expressions
133
129
  if (
134
- node_id is not None # Not anonymous
135
- and task_ins.task.consumer.anonymous is False
136
- and task_ins.task.consumer.node_id == node_id
137
- and task_ins.task.delivered_at == ""
138
- and task_ins.task.created_at + task_ins.task.ttl > current_time
139
- ) or (
140
- node_id is None # Anonymous
141
- and task_ins.task.consumer.anonymous is True
142
- and task_ins.task.consumer.node_id == 0
130
+ task_ins.task.consumer.node_id == node_id
143
131
  and task_ins.task.delivered_at == ""
144
132
  and task_ins.task.created_at + task_ins.task.ttl > current_time
145
133
  ):
@@ -173,9 +161,6 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
173
161
  if (
174
162
  task_ins
175
163
  and task_res
176
- and not (
177
- task_ins.task.consumer.anonymous or task_res.task.producer.anonymous
178
- )
179
164
  and task_ins.task.consumer.node_id != task_res.task.producer.node_id
180
165
  ):
181
166
  return None
@@ -265,41 +250,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
265
250
  for task_res in task_res_found:
266
251
  task_res.task.delivered_at = delivered_at
267
252
 
268
- # Cleanup
269
- self._force_delete_tasks_by_ids(set(ret.keys()))
270
-
271
253
  return list(ret.values())
272
254
 
273
- def delete_tasks(self, task_ids: set[UUID]) -> None:
274
- """Delete all delivered TaskIns/TaskRes pairs."""
275
- task_ins_to_be_deleted: set[UUID] = set()
276
- task_res_to_be_deleted: set[UUID] = set()
277
-
278
- with self.lock:
279
- for task_ins_id in task_ids:
280
- # Find the task_id of the matching task_res
281
- for task_res_id, task_res in self.task_res_store.items():
282
- if UUID(task_res.task.ancestry[0]) != task_ins_id:
283
- continue
284
- if task_res.task.delivered_at == "":
285
- continue
286
-
287
- task_ins_to_be_deleted.add(task_ins_id)
288
- task_res_to_be_deleted.add(task_res_id)
289
-
290
- for task_id in task_ins_to_be_deleted:
291
- del self.task_ins_store[task_id]
292
- del self.task_ins_id_to_task_res_id[task_id]
293
- for task_id in task_res_to_be_deleted:
294
- del self.task_res_store[task_id]
295
-
296
- def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
297
- """Delete tasks based on a set of TaskIns IDs."""
298
- if not task_ids:
255
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
256
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
257
+ if not task_ins_ids:
299
258
  return
300
259
 
301
260
  with self.lock:
302
- for task_id in task_ids:
261
+ for task_id in task_ins_ids:
303
262
  # Delete TaskIns
304
263
  if task_id in self.task_ins_store:
305
264
  del self.task_ins_store[task_id]
@@ -308,6 +267,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
308
267
  task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
309
268
  del self.task_res_store[task_res_id]
310
269
 
270
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
271
+ """Get all TaskIns IDs for the given run_id."""
272
+ task_id_list: set[UUID] = set()
273
+ with self.lock:
274
+ for task_id, task_ins in self.task_ins_store.items():
275
+ if task_ins.run_id == run_id:
276
+ task_id_list.add(task_id)
277
+
278
+ return task_id_list
279
+
311
280
  def num_task_ins(self) -> int:
312
281
  """Calculate the number of task_ins in store.
313
282
 
@@ -322,45 +291,30 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
322
291
  """
323
292
  return len(self.task_res_store)
324
293
 
325
- def create_node(
326
- self, ping_interval: float, public_key: Optional[bytes] = None
327
- ) -> int:
294
+ def create_node(self, ping_interval: float) -> int:
328
295
  """Create, store in the link state, and return `node_id`."""
329
296
  # Sample a random int64 as node_id
330
- node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
297
+ node_id = generate_rand_int_from_bytes(
298
+ NODE_ID_NUM_BYTES, exclude=[SUPERLINK_NODE_ID, 0]
299
+ )
331
300
 
332
301
  with self.lock:
333
302
  if node_id in self.node_ids:
334
303
  log(ERROR, "Unexpected node registration failure.")
335
304
  return 0
336
305
 
337
- if public_key is not None:
338
- if (
339
- public_key in self.public_key_to_node_id
340
- or node_id in self.public_key_to_node_id.values()
341
- ):
342
- log(ERROR, "Unexpected node registration failure.")
343
- return 0
344
-
345
- self.public_key_to_node_id[public_key] = node_id
346
-
347
306
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
348
307
  return node_id
349
308
 
350
- def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
309
+ def delete_node(self, node_id: int) -> None:
351
310
  """Delete a node."""
352
311
  with self.lock:
353
312
  if node_id not in self.node_ids:
354
313
  raise ValueError(f"Node {node_id} not found")
355
314
 
356
- if public_key is not None:
357
- if (
358
- public_key not in self.public_key_to_node_id
359
- or node_id not in self.public_key_to_node_id.values()
360
- ):
361
- raise ValueError("Public key or node_id not found")
362
-
363
- del self.public_key_to_node_id[public_key]
315
+ # Remove node ID <> public key mappings
316
+ if pk := self.node_id_to_public_key.pop(node_id, None):
317
+ del self.public_key_to_node_id[pk]
364
318
 
365
319
  del self.node_ids[node_id]
366
320
 
@@ -382,6 +336,26 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
382
336
  if online_until > current_time
383
337
  }
384
338
 
339
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
340
+ """Set `public_key` for the specified `node_id`."""
341
+ with self.lock:
342
+ if node_id not in self.node_ids:
343
+ raise ValueError(f"Node {node_id} not found")
344
+
345
+ if public_key in self.public_key_to_node_id:
346
+ raise ValueError("Public key already in use")
347
+
348
+ self.public_key_to_node_id[public_key] = node_id
349
+ self.node_id_to_public_key[node_id] = public_key
350
+
351
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
352
+ """Get `public_key` for the specified `node_id`."""
353
+ with self.lock:
354
+ if node_id not in self.node_ids:
355
+ raise ValueError(f"Node {node_id} not found")
356
+
357
+ return self.node_id_to_public_key.get(node_id)
358
+
385
359
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
386
360
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
387
361
  return self.public_key_to_node_id.get(node_public_key)
@@ -427,29 +401,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
427
401
  log(ERROR, "Unexpected run creation failure.")
428
402
  return 0
429
403
 
430
- def store_server_private_public_key(
431
- self, private_key: bytes, public_key: bytes
432
- ) -> None:
433
- """Store `server_private_key` and `server_public_key` in the link state."""
404
+ def clear_supernode_auth_keys(self) -> None:
405
+ """Clear stored `node_public_keys` in the link state if any."""
434
406
  with self.lock:
435
- if self.server_private_key is None and self.server_public_key is None:
436
- self.server_private_key = private_key
437
- self.server_public_key = public_key
438
- else:
439
- raise RuntimeError("Server private and public key already set")
440
-
441
- def get_server_private_key(self) -> Optional[bytes]:
442
- """Retrieve `server_private_key` in urlsafe bytes."""
443
- return self.server_private_key
444
-
445
- def get_server_public_key(self) -> Optional[bytes]:
446
- """Retrieve `server_public_key` in urlsafe bytes."""
447
- return self.server_public_key
407
+ self.node_public_keys.clear()
448
408
 
449
409
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
450
410
  """Store a set of `node_public_keys` in the link state."""
451
411
  with self.lock:
452
- self.node_public_keys = public_keys
412
+ self.node_public_keys.update(public_keys)
453
413
 
454
414
  def store_node_public_key(self, public_key: bytes) -> None:
455
415
  """Store a `node_public_key` in the link state."""
@@ -458,7 +418,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
458
418
 
459
419
  def get_node_public_keys(self) -> set[bytes]:
460
420
  """Retrieve all currently stored `node_public_keys` as a set."""
461
- return self.node_public_keys
421
+ with self.lock:
422
+ return self.node_public_keys.copy()
462
423
 
463
424
  def get_run_ids(self) -> set[int]:
464
425
  """Retrieve all run IDs."""
@@ -40,20 +40,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
40
40
 
41
41
  Constraints
42
42
  -----------
43
- If `task_ins.task.consumer.anonymous` is `True`, then
44
- `task_ins.task.consumer.node_id` MUST NOT be set (equal 0).
45
-
46
- If `task_ins.task.consumer.anonymous` is `False`, then
47
- `task_ins.task.consumer.node_id` MUST be set (not 0)
43
+ `task_ins.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
48
44
 
49
45
  If `task_ins.run_id` is invalid, then
50
46
  storing the `task_ins` MUST fail.
51
47
  """
52
48
 
53
49
  @abc.abstractmethod
54
- def get_task_ins(
55
- self, node_id: Optional[int], limit: Optional[int]
56
- ) -> list[TaskIns]:
50
+ def get_task_ins(self, node_id: int, limit: Optional[int]) -> list[TaskIns]:
57
51
  """Get TaskIns optionally filtered by node_id.
58
52
 
59
53
  Usually, the Fleet API calls this for Nodes planning to work on one or more
@@ -61,15 +55,11 @@ class LinkState(abc.ABC): # pylint: disable=R0904
61
55
 
62
56
  Constraints
63
57
  -----------
64
- If `node_id` is not `None`, retrieve all TaskIns where
58
+ Retrieve all TaskIns where
65
59
 
66
60
  1. the `task_ins.task.consumer.node_id` equals `node_id` AND
67
- 2. the `task_ins.task.consumer.anonymous` equals `False` AND
68
- 3. the `task_ins.task.delivered_at` equals `""`.
61
+ 2. the `task_ins.task.delivered_at` equals `""`.
69
62
 
70
- If `node_id` is `None`, retrieve all TaskIns where the
71
- `task_ins.task.consumer.node_id` equals `0` and
72
- `task_ins.task.consumer.anonymous` is set to `True`.
73
63
 
74
64
  If `delivered_at` MUST BE set (not `""`) otherwise the TaskIns MUST not be in
75
65
  the result.
@@ -89,11 +79,8 @@ class LinkState(abc.ABC): # pylint: disable=R0904
89
79
 
90
80
  Constraints
91
81
  -----------
92
- If `task_res.task.consumer.anonymous` is `True`, then
93
- `task_res.task.consumer.node_id` MUST NOT be set (equal 0).
94
82
 
95
- If `task_res.task.consumer.anonymous` is `False`, then
96
- `task_res.task.consumer.node_id` MUST be set (not 0)
83
+ `task_res.task.consumer.node_id` MUST be set (not constant.DRIVER_NODE_ID)
97
84
 
98
85
  If `task_res.run_id` is invalid, then
99
86
  storing the `task_res` MUST fail.
@@ -139,17 +126,26 @@ class LinkState(abc.ABC): # pylint: disable=R0904
139
126
  """
140
127
 
141
128
  @abc.abstractmethod
142
- def delete_tasks(self, task_ids: set[UUID]) -> None:
143
- """Delete all delivered TaskIns/TaskRes pairs."""
129
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
130
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
131
+
132
+ Parameters
133
+ ----------
134
+ task_ins_ids : set[UUID]
135
+ A set of TaskIns IDs. For each ID in the set, the corresponding
136
+ TaskIns and its associated TaskRes will be deleted.
137
+ """
138
+
139
+ @abc.abstractmethod
140
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
141
+ """Get all TaskIns IDs for the given run_id."""
144
142
 
145
143
  @abc.abstractmethod
146
- def create_node(
147
- self, ping_interval: float, public_key: Optional[bytes] = None
148
- ) -> int:
144
+ def create_node(self, ping_interval: float) -> int:
149
145
  """Create, store in the link state, and return `node_id`."""
150
146
 
151
147
  @abc.abstractmethod
152
- def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
148
+ def delete_node(self, node_id: int) -> None:
153
149
  """Remove `node_id` from the link state."""
154
150
 
155
151
  @abc.abstractmethod
@@ -162,6 +158,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
162
158
  an empty `Set` MUST be returned.
163
159
  """
164
160
 
161
+ @abc.abstractmethod
162
+ def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
163
+ """Set `public_key` for the specified `node_id`."""
164
+
165
+ @abc.abstractmethod
166
+ def get_node_public_key(self, node_id: int) -> Optional[bytes]:
167
+ """Get `public_key` for the specified `node_id`."""
168
+
165
169
  @abc.abstractmethod
166
170
  def get_node_id(self, node_public_key: bytes) -> Optional[int]:
167
171
  """Retrieve stored `node_id` filtered by `node_public_keys`."""
@@ -260,18 +264,8 @@ class LinkState(abc.ABC): # pylint: disable=R0904
260
264
  """
261
265
 
262
266
  @abc.abstractmethod
263
- def store_server_private_public_key(
264
- self, private_key: bytes, public_key: bytes
265
- ) -> None:
266
- """Store `server_private_key` and `server_public_key` in the link state."""
267
-
268
- @abc.abstractmethod
269
- def get_server_private_key(self) -> Optional[bytes]:
270
- """Retrieve `server_private_key` in urlsafe bytes."""
271
-
272
- @abc.abstractmethod
273
- def get_server_public_key(self) -> Optional[bytes]:
274
- """Retrieve `server_public_key` in urlsafe bytes."""
267
+ def clear_supernode_auth_keys(self) -> None:
268
+ """Clear stored `node_public_keys` in the link state if any."""
275
269
 
276
270
  @abc.abstractmethod
277
271
  def store_node_public_keys(self, public_keys: set[bytes]) -> None: