flwr 1.13.0__py3-none-any.whl → 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/build.py +1 -37
  3. flwr/cli/cli_user_auth_interceptor.py +86 -0
  4. flwr/cli/config_utils.py +19 -2
  5. flwr/cli/example.py +1 -0
  6. flwr/cli/install.py +2 -19
  7. flwr/cli/log.py +18 -36
  8. flwr/cli/login/__init__.py +22 -0
  9. flwr/cli/login/login.py +81 -0
  10. flwr/cli/ls.py +205 -106
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +25 -14
  13. flwr/cli/new/templates/app/.gitignore.tpl +3 -0
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +3 -3
  16. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  17. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  18. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -3
  19. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  20. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  21. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  23. flwr/cli/run/__init__.py +1 -0
  24. flwr/cli/run/run.py +89 -39
  25. flwr/cli/stop.py +130 -0
  26. flwr/cli/utils.py +172 -8
  27. flwr/client/app.py +14 -3
  28. flwr/client/client.py +1 -32
  29. flwr/client/clientapp/app.py +4 -8
  30. flwr/client/clientapp/utils.py +1 -0
  31. flwr/client/grpc_adapter_client/connection.py +1 -1
  32. flwr/client/grpc_client/connection.py +1 -1
  33. flwr/client/grpc_rere_client/connection.py +13 -7
  34. flwr/client/message_handler/message_handler.py +1 -2
  35. flwr/client/mod/comms_mods.py +1 -0
  36. flwr/client/mod/localdp_mod.py +1 -1
  37. flwr/client/nodestate/__init__.py +1 -0
  38. flwr/client/nodestate/nodestate.py +1 -0
  39. flwr/client/nodestate/nodestate_factory.py +1 -0
  40. flwr/client/numpy_client.py +0 -44
  41. flwr/client/rest_client/connection.py +3 -3
  42. flwr/client/supernode/app.py +2 -2
  43. flwr/common/address.py +1 -0
  44. flwr/common/args.py +1 -0
  45. flwr/common/auth_plugin/__init__.py +24 -0
  46. flwr/common/auth_plugin/auth_plugin.py +111 -0
  47. flwr/common/config.py +3 -1
  48. flwr/common/constant.py +17 -1
  49. flwr/common/logger.py +40 -0
  50. flwr/common/message.py +1 -0
  51. flwr/common/object_ref.py +57 -54
  52. flwr/common/pyproject.py +1 -0
  53. flwr/common/record/__init__.py +1 -0
  54. flwr/common/record/parametersrecord.py +1 -0
  55. flwr/common/retry_invoker.py +77 -0
  56. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  57. flwr/common/telemetry.py +15 -4
  58. flwr/common/typing.py +12 -0
  59. flwr/common/version.py +1 -0
  60. flwr/proto/exec_pb2.py +38 -14
  61. flwr/proto/exec_pb2.pyi +107 -2
  62. flwr/proto/exec_pb2_grpc.py +102 -0
  63. flwr/proto/exec_pb2_grpc.pyi +39 -0
  64. flwr/proto/fab_pb2.py +4 -4
  65. flwr/proto/fab_pb2.pyi +4 -1
  66. flwr/proto/serverappio_pb2.py +18 -18
  67. flwr/proto/serverappio_pb2.pyi +8 -2
  68. flwr/proto/serverappio_pb2_grpc.py +34 -0
  69. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  70. flwr/proto/simulationio_pb2.py +2 -2
  71. flwr/proto/simulationio_pb2_grpc.py +34 -0
  72. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  73. flwr/server/app.py +62 -7
  74. flwr/server/compat/app_utils.py +7 -1
  75. flwr/server/driver/grpc_driver.py +11 -63
  76. flwr/server/driver/inmemory_driver.py +5 -1
  77. flwr/server/run_serverapp.py +8 -9
  78. flwr/server/serverapp/app.py +25 -10
  79. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  80. flwr/server/superlink/driver/serverappio_grpc.py +1 -0
  81. flwr/server/superlink/driver/serverappio_servicer.py +82 -23
  82. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  83. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
  84. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  85. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
  86. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +12 -11
  87. flwr/server/superlink/fleet/message_handler/message_handler.py +32 -5
  88. flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
  89. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  90. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  91. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  92. flwr/server/superlink/linkstate/in_memory_linkstate.py +21 -30
  93. flwr/server/superlink/linkstate/linkstate.py +17 -2
  94. flwr/server/superlink/linkstate/sqlite_linkstate.py +30 -49
  95. flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
  96. flwr/server/superlink/utils.py +65 -0
  97. flwr/simulation/app.py +59 -52
  98. flwr/simulation/ray_transport/ray_actor.py +1 -0
  99. flwr/simulation/ray_transport/utils.py +1 -0
  100. flwr/simulation/run_simulation.py +36 -22
  101. flwr/simulation/simulationio_connection.py +3 -0
  102. flwr/superexec/app.py +1 -0
  103. flwr/superexec/deployment.py +1 -0
  104. flwr/superexec/exec_grpc.py +19 -1
  105. flwr/superexec/exec_servicer.py +76 -2
  106. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  107. flwr/superexec/executor.py +1 -0
  108. {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/METADATA +8 -8
  109. {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/RECORD +112 -112
  110. flwr/proto/common_pb2.py +0 -36
  111. flwr/proto/common_pb2.pyi +0 -121
  112. flwr/proto/common_pb2_grpc.py +0 -4
  113. flwr/proto/common_pb2_grpc.pyi +0 -4
  114. flwr/proto/control_pb2.py +0 -27
  115. flwr/proto/control_pb2.pyi +0 -7
  116. flwr/proto/control_pb2_grpc.py +0 -135
  117. flwr/proto/control_pb2_grpc.pyi +0 -53
  118. {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/LICENSE +0 -0
  119. {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/WHEEL +0 -0
  120. {flwr-1.13.0.dist-info → flwr-1.14.0.dist-info}/entry_points.txt +0 -0
@@ -18,6 +18,7 @@ Relevant knowledge for reading this modules code:
18
18
  - https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
19
19
  """
20
20
 
21
+
21
22
  import uuid
22
23
  from collections.abc import Iterator
23
24
  from typing import Callable
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO
20
20
  import grpc
21
21
 
22
22
  from flwr.common.logger import log
23
+ from flwr.common.typing import InvalidRunStatusException
23
24
  from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
24
25
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
25
26
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
@@ -38,6 +39,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
38
39
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
39
40
  from flwr.server.superlink.fleet.message_handler import message_handler
40
41
  from flwr.server.superlink.linkstate import LinkStateFactory
42
+ from flwr.server.superlink.utils import abort_grpc_context
41
43
 
42
44
 
43
45
  class FleetServicer(fleet_pb2_grpc.FleetServicer):
@@ -105,27 +107,45 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
105
107
  )
106
108
  else:
107
109
  log(INFO, "[Fleet.PushTaskRes] No task results to push")
108
- return message_handler.push_task_res(
109
- request=request,
110
- state=self.state_factory.state(),
111
- )
110
+
111
+ try:
112
+ res = message_handler.push_task_res(
113
+ request=request,
114
+ state=self.state_factory.state(),
115
+ )
116
+ except InvalidRunStatusException as e:
117
+ abort_grpc_context(e.message, context)
118
+
119
+ return res
112
120
 
113
121
  def GetRun(
114
122
  self, request: GetRunRequest, context: grpc.ServicerContext
115
123
  ) -> GetRunResponse:
116
124
  """Get run information."""
117
125
  log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
118
- return message_handler.get_run(
119
- request=request,
120
- state=self.state_factory.state(),
121
- )
126
+
127
+ try:
128
+ res = message_handler.get_run(
129
+ request=request,
130
+ state=self.state_factory.state(),
131
+ )
132
+ except InvalidRunStatusException as e:
133
+ abort_grpc_context(e.message, context)
134
+
135
+ return res
122
136
 
123
137
  def GetFab(
124
138
  self, request: GetFabRequest, context: grpc.ServicerContext
125
139
  ) -> GetFabResponse:
126
140
  """Get FAB."""
127
141
  log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
128
- return message_handler.get_fab(
129
- request=request,
130
- ffs=self.ffs_factory.ffs(),
131
- )
142
+ try:
143
+ res = message_handler.get_fab(
144
+ request=request,
145
+ ffs=self.ffs_factory.ffs(),
146
+ state=self.state_factory.state(),
147
+ )
148
+ except InvalidRunStatusException as e:
149
+ abort_grpc_context(e.message, context)
150
+
151
+ return res
@@ -45,7 +45,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
45
  )
46
46
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
47
47
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
48
- from flwr.server.superlink.linkstate import LinkState
48
+ from flwr.server.superlink.linkstate import LinkStateFactory
49
49
 
50
50
  _PUBLIC_KEY_HEADER = "public-key"
51
51
  _AUTH_TOKEN_HEADER = "auth-token"
@@ -84,15 +84,16 @@ def _get_value_from_tuples(
84
84
  class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
85
85
  """Server interceptor for node authentication."""
86
86
 
87
- def __init__(self, state: LinkState):
88
- self.state = state
87
+ def __init__(self, state_factory: LinkStateFactory):
88
+ self.state_factory = state_factory
89
+ state = self.state_factory.state()
89
90
 
90
91
  self.node_public_keys = state.get_node_public_keys()
91
92
  if len(self.node_public_keys) == 0:
92
93
  log(WARNING, "Authentication enabled, but no known public keys configured")
93
94
 
94
- private_key = self.state.get_server_private_key()
95
- public_key = self.state.get_server_public_key()
95
+ private_key = state.get_server_private_key()
96
+ public_key = state.get_server_public_key()
96
97
 
97
98
  if private_key is None or public_key is None:
98
99
  raise ValueError("Error loading authentication keys")
@@ -154,7 +155,7 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
154
155
  context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
155
156
 
156
157
  # Verify node_id
157
- node_id = self.state.get_node_id(node_public_key_bytes)
158
+ node_id = self.state_factory.state().get_node_id(node_public_key_bytes)
158
159
 
159
160
  if not self._verify_node_id(node_id, request):
160
161
  context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
@@ -186,7 +187,7 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
186
187
  return False
187
188
  return request.task_res_list[0].task.producer.node_id == node_id
188
189
  if isinstance(request, GetRunRequest):
189
- return node_id in self.state.get_nodes(request.run_id)
190
+ return node_id in self.state_factory.state().get_nodes(request.run_id)
190
191
  return request.node.node_id == node_id
191
192
 
192
193
  def _verify_hmac(
@@ -210,17 +211,17 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
210
211
  ),
211
212
  )
212
213
  )
213
-
214
- node_id = self.state.get_node_id(public_key_bytes)
214
+ state = self.state_factory.state()
215
+ node_id = state.get_node_id(public_key_bytes)
215
216
 
216
217
  # Handle `CreateNode` here instead of calling the default method handler
217
218
  # Return previously assigned `node_id` for the provided `public_key`
218
219
  if node_id is not None:
219
- self.state.acknowledge_ping(node_id, request.ping_interval)
220
+ state.acknowledge_ping(node_id, request.ping_interval)
220
221
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
221
222
 
222
223
  # No `node_id` exists for the provided `public_key`
223
224
  # Handle `CreateNode` here instead of calling the default method handler
224
225
  # Note: the innermost `CreateNode` method will never be called
225
- node_id = self.state.create_node(request.ping_interval, public_key_bytes)
226
+ node_id = state.create_node(request.ping_interval, public_key_bytes)
226
227
  return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
@@ -19,8 +19,9 @@ import time
19
19
  from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
+ from flwr.common.constant import Status
22
23
  from flwr.common.serde import fab_to_proto, user_config_to_proto
23
- from flwr.common.typing import Fab
24
+ from flwr.common.typing import Fab, InvalidRunStatusException
24
25
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
25
26
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
26
27
  CreateNodeRequest,
@@ -44,6 +45,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
44
45
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
45
46
  from flwr.server.superlink.ffs.ffs import Ffs
46
47
  from flwr.server.superlink.linkstate import LinkState
48
+ from flwr.server.superlink.utils import check_abort
47
49
 
48
50
 
49
51
  def create_node(
@@ -98,6 +100,15 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
98
100
  task_res: TaskRes = request.task_res_list[0]
99
101
  # pylint: enable=no-member
100
102
 
103
+ # Abort if the run is not running
104
+ abort_msg = check_abort(
105
+ task_res.run_id,
106
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
107
+ state,
108
+ )
109
+ if abort_msg:
110
+ raise InvalidRunStatusException(abort_msg)
111
+
101
112
  # Set pushed_at (timestamp in seconds)
102
113
  task_res.task.pushed_at = time.time()
103
114
 
@@ -112,15 +123,22 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
112
123
  return response
113
124
 
114
125
 
115
- def get_run(
116
- request: GetRunRequest, state: LinkState # pylint: disable=W0613
117
- ) -> GetRunResponse:
126
+ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
118
127
  """Get run information."""
119
128
  run = state.get_run(request.run_id)
120
129
 
121
130
  if run is None:
122
131
  return GetRunResponse()
123
132
 
133
+ # Abort if the run is not running
134
+ abort_msg = check_abort(
135
+ request.run_id,
136
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
137
+ state,
138
+ )
139
+ if abort_msg:
140
+ raise InvalidRunStatusException(abort_msg)
141
+
124
142
  return GetRunResponse(
125
143
  run=Run(
126
144
  run_id=run.run_id,
@@ -133,9 +151,18 @@ def get_run(
133
151
 
134
152
 
135
153
  def get_fab(
136
- request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
154
+ request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
137
155
  ) -> GetFabResponse:
138
156
  """Get FAB."""
157
+ # Abort if the run is not running
158
+ abort_msg = check_abort(
159
+ request.run_id,
160
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
161
+ state,
162
+ )
163
+ if abort_msg:
164
+ raise InvalidRunStatusException(abort_msg)
165
+
139
166
  if result := ffs.get(request.hash_str):
140
167
  fab = Fab(request.hash_str, result[0])
141
168
  return GetFabResponse(fab=fab_to_proto(fab))
@@ -154,8 +154,11 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
154
154
  # Get ffs from app
155
155
  ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
156
156
 
157
+ # Get state from app
158
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
159
+
157
160
  # Handle message
158
- return message_handler.get_fab(request=request, ffs=ffs)
161
+ return message_handler.get_fab(request=request, ffs=ffs, state=state)
159
162
 
160
163
 
161
164
  routes = [
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Fleet Simulation Engine side."""
16
16
 
17
+
17
18
  from .vce_api import start_vce
18
19
 
19
20
  __all__ = [
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Simulation Engine Backends."""
16
16
 
17
+
17
18
  import importlib
18
19
 
19
20
  from .backend import Backend, BackendConfig
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Ray backend for the Fleet API using the Simulation Engine."""
16
16
 
17
+
17
18
  import sys
18
19
  from logging import DEBUG, ERROR
19
20
  from typing import Callable, Optional, Union
@@ -265,41 +265,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
265
265
  for task_res in task_res_found:
266
266
  task_res.task.delivered_at = delivered_at
267
267
 
268
- # Cleanup
269
- self._force_delete_tasks_by_ids(set(ret.keys()))
270
-
271
268
  return list(ret.values())
272
269
 
273
- def delete_tasks(self, task_ids: set[UUID]) -> None:
274
- """Delete all delivered TaskIns/TaskRes pairs."""
275
- task_ins_to_be_deleted: set[UUID] = set()
276
- task_res_to_be_deleted: set[UUID] = set()
277
-
278
- with self.lock:
279
- for task_ins_id in task_ids:
280
- # Find the task_id of the matching task_res
281
- for task_res_id, task_res in self.task_res_store.items():
282
- if UUID(task_res.task.ancestry[0]) != task_ins_id:
283
- continue
284
- if task_res.task.delivered_at == "":
285
- continue
286
-
287
- task_ins_to_be_deleted.add(task_ins_id)
288
- task_res_to_be_deleted.add(task_res_id)
289
-
290
- for task_id in task_ins_to_be_deleted:
291
- del self.task_ins_store[task_id]
292
- del self.task_ins_id_to_task_res_id[task_id]
293
- for task_id in task_res_to_be_deleted:
294
- del self.task_res_store[task_id]
295
-
296
- def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
297
- """Delete tasks based on a set of TaskIns IDs."""
298
- if not task_ids:
270
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
271
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
272
+ if not task_ins_ids:
299
273
  return
300
274
 
301
275
  with self.lock:
302
- for task_id in task_ids:
276
+ for task_id in task_ins_ids:
303
277
  # Delete TaskIns
304
278
  if task_id in self.task_ins_store:
305
279
  del self.task_ins_store[task_id]
@@ -308,6 +282,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
308
282
  task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
309
283
  del self.task_res_store[task_res_id]
310
284
 
285
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
286
+ """Get all TaskIns IDs for the given run_id."""
287
+ task_id_list: set[UUID] = set()
288
+ with self.lock:
289
+ for task_id, task_ins in self.task_ins_store.items():
290
+ if task_ins.run_id == run_id:
291
+ task_id_list.add(task_id)
292
+
293
+ return task_id_list
294
+
311
295
  def num_task_ins(self) -> int:
312
296
  """Calculate the number of task_ins in store.
313
297
 
@@ -446,6 +430,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
446
430
  """Retrieve `server_public_key` in urlsafe bytes."""
447
431
  return self.server_public_key
448
432
 
433
+ def clear_supernode_auth_keys_and_credentials(self) -> None:
434
+ """Clear stored `node_public_keys` and credentials in the link state if any."""
435
+ with self.lock:
436
+ self.server_private_key = None
437
+ self.server_public_key = None
438
+ self.node_public_keys.clear()
439
+
449
440
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
450
441
  """Store a set of `node_public_keys` in the link state."""
451
442
  with self.lock:
@@ -139,8 +139,19 @@ class LinkState(abc.ABC): # pylint: disable=R0904
139
139
  """
140
140
 
141
141
  @abc.abstractmethod
142
- def delete_tasks(self, task_ids: set[UUID]) -> None:
143
- """Delete all delivered TaskIns/TaskRes pairs."""
142
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
143
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
144
+
145
+ Parameters
146
+ ----------
147
+ task_ins_ids : set[UUID]
148
+ A set of TaskIns IDs. For each ID in the set, the corresponding
149
+ TaskIns and its associated TaskRes will be deleted.
150
+ """
151
+
152
+ @abc.abstractmethod
153
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
154
+ """Get all TaskIns IDs for the given run_id."""
144
155
 
145
156
  @abc.abstractmethod
146
157
  def create_node(
@@ -273,6 +284,10 @@ class LinkState(abc.ABC): # pylint: disable=R0904
273
284
  def get_server_public_key(self) -> Optional[bytes]:
274
285
  """Retrieve `server_public_key` in urlsafe bytes."""
275
286
 
287
+ @abc.abstractmethod
288
+ def clear_supernode_auth_keys_and_credentials(self) -> None:
289
+ """Clear stored `node_public_keys` and credentials in the link state if any."""
290
+
276
291
  @abc.abstractmethod
277
292
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
278
293
  """Store a set of `node_public_keys` in the link state."""
@@ -14,12 +14,12 @@
14
14
  # ==============================================================================
15
15
  """SQLite based implemenation of the link state."""
16
16
 
17
+
17
18
  # pylint: disable=too-many-lines
18
19
 
19
20
  import json
20
21
  import re
21
22
  import sqlite3
22
- import threading
23
23
  import time
24
24
  from collections.abc import Sequence
25
25
  from logging import DEBUG, ERROR, WARNING
@@ -183,7 +183,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
183
183
  """
184
184
  self.database_path = database_path
185
185
  self.conn: Optional[sqlite3.Connection] = None
186
- self.lock = threading.RLock()
187
186
 
188
187
  def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
189
188
  """Create tables if they don't exist yet.
@@ -216,7 +215,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
216
215
  cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY)
217
216
  cur.execute(SQL_CREATE_INDEX_ONLINE_UNTIL)
218
217
  res = cur.execute("SELECT name FROM sqlite_schema;")
219
-
220
218
  return res.fetchall()
221
219
 
222
220
  def query(
@@ -569,9 +567,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
569
567
  data: list[Any] = [delivered_at] + task_res_ids
570
568
  self.query(query, data)
571
569
 
572
- # Cleanup
573
- self._force_delete_tasks_by_ids(set(ret.keys()))
574
-
575
570
  return list(ret.values())
576
571
 
577
572
  def num_task_ins(self) -> int:
@@ -595,68 +590,50 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
595
590
  result: dict[str, int] = rows[0]
596
591
  return result["num"]
597
592
 
598
- def delete_tasks(self, task_ids: set[UUID]) -> None:
599
- """Delete all delivered TaskIns/TaskRes pairs."""
600
- ids = list(task_ids)
601
- if len(ids) == 0:
602
- return None
593
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
594
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
595
+ if not task_ins_ids:
596
+ return
597
+ if self.conn is None:
598
+ raise AttributeError("LinkState not initialized")
603
599
 
604
- placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
605
- data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
600
+ placeholders = ",".join(["?"] * len(task_ins_ids))
601
+ data = tuple(str(task_id) for task_id in task_ins_ids)
606
602
 
607
- # 1. Query: Delete task_ins which have a delivered task_res
603
+ # Delete task_ins
608
604
  query_1 = f"""
609
605
  DELETE FROM task_ins
610
- WHERE delivered_at != ''
611
- AND task_id IN (
612
- SELECT ancestry
613
- FROM task_res
614
- WHERE ancestry IN ({placeholders})
615
- AND delivered_at != ''
616
- );
606
+ WHERE task_id IN ({placeholders});
617
607
  """
618
608
 
619
- # 2. Query: Delete delivered task_res to be run after 1. Query
609
+ # Delete task_res
620
610
  query_2 = f"""
621
611
  DELETE FROM task_res
622
- WHERE ancestry IN ({placeholders})
623
- AND delivered_at != '';
612
+ WHERE ancestry IN ({placeholders});
624
613
  """
625
614
 
626
- if self.conn is None:
627
- raise AttributeError("LinkState not intitialized")
628
-
629
615
  with self.conn:
630
616
  self.conn.execute(query_1, data)
631
617
  self.conn.execute(query_2, data)
632
618
 
633
- return None
634
-
635
- def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
636
- """Delete tasks based on a set of TaskIns IDs."""
637
- if not task_ids:
638
- return
619
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
620
+ """Get all TaskIns IDs for the given run_id."""
639
621
  if self.conn is None:
640
622
  raise AttributeError("LinkState not initialized")
641
623
 
642
- placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
643
- data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
644
-
645
- # Delete task_ins
646
- query_1 = f"""
647
- DELETE FROM task_ins
648
- WHERE task_id IN ({placeholders});
624
+ query = """
625
+ SELECT task_id
626
+ FROM task_ins
627
+ WHERE run_id = :run_id;
649
628
  """
650
629
 
651
- # Delete task_res
652
- query_2 = f"""
653
- DELETE FROM task_res
654
- WHERE ancestry IN ({placeholders});
655
- """
630
+ sint64_run_id = convert_uint64_to_sint64(run_id)
631
+ data = {"run_id": sint64_run_id}
656
632
 
657
633
  with self.conn:
658
- self.conn.execute(query_1, data)
659
- self.conn.execute(query_2, data)
634
+ rows = self.conn.execute(query, data).fetchall()
635
+
636
+ return {UUID(row["task_id"]) for row in rows}
660
637
 
661
638
  def create_node(
662
639
  self, ping_interval: float, public_key: Optional[bytes] = None
@@ -784,8 +761,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
784
761
  "federation_options, pending_at, starting_at, running_at, finished_at, "
785
762
  "sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
786
763
  )
787
- if fab_hash:
788
- fab_id, fab_version = "", ""
789
764
  override_config_json = json.dumps(override_config)
790
765
  data = [
791
766
  sint64_run_id,
@@ -843,6 +818,12 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
843
818
  public_key = None
844
819
  return public_key
845
820
 
821
+ def clear_supernode_auth_keys_and_credentials(self) -> None:
822
+ """Clear stored `node_public_keys` and credentials in the link state if any."""
823
+ queries = ["DELETE FROM public_key;", "DELETE FROM credential;"]
824
+ for query in queries:
825
+ self.query(query)
826
+
846
827
  def store_node_public_keys(self, public_keys: set[bytes]) -> None:
847
828
  """Store a set of `node_public_keys` in the link state."""
848
829
  query = "INSERT INTO public_key (public_key) VALUES (?)"
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """SimulationIo API servicer."""
16
16
 
17
+
17
18
  import threading
18
19
  from logging import DEBUG, INFO
19
20
 
@@ -28,6 +29,7 @@ from flwr.common.serde import (
28
29
  context_to_proto,
29
30
  fab_to_proto,
30
31
  run_status_from_proto,
32
+ run_status_to_proto,
31
33
  run_to_proto,
32
34
  )
33
35
  from flwr.common.typing import Fab, RunStatus
@@ -39,6 +41,8 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
39
41
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
40
42
  GetFederationOptionsRequest,
41
43
  GetFederationOptionsResponse,
44
+ GetRunStatusRequest,
45
+ GetRunStatusResponse,
42
46
  UpdateRunStatusRequest,
43
47
  UpdateRunStatusResponse,
44
48
  )
@@ -50,6 +54,7 @@ from flwr.proto.simulationio_pb2 import ( # pylint: disable=E0611
50
54
  )
51
55
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
52
56
  from flwr.server.superlink.linkstate import LinkStateFactory
57
+ from flwr.server.superlink.utils import abort_if
53
58
 
54
59
 
55
60
  class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
@@ -106,6 +111,15 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
106
111
  """Push Simulation process outputs."""
107
112
  log(DEBUG, "SimultionIoServicer.PushSimulationOutputs")
108
113
  state = self.state_factory.state()
114
+
115
+ # Abort if the run is not running
116
+ abort_if(
117
+ request.run_id,
118
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
119
+ state,
120
+ context,
121
+ )
122
+
109
123
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
110
124
  return PushSimulationOutputsResponse()
111
125
 
@@ -116,12 +130,31 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
116
130
  log(DEBUG, "SimultionIoServicer.UpdateRunStatus")
117
131
  state = self.state_factory.state()
118
132
 
133
+ # Abort if the run is finished
134
+ abort_if(request.run_id, [Status.FINISHED], state, context)
135
+
119
136
  # Update the run status
120
137
  state.update_run_status(
121
138
  run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
122
139
  )
123
140
  return UpdateRunStatusResponse()
124
141
 
142
+ def GetRunStatus(
143
+ self, request: GetRunStatusRequest, context: ServicerContext
144
+ ) -> GetRunStatusResponse:
145
+ """Get status of requested runs."""
146
+ log(DEBUG, "SimultionIoServicer.GetRunStatus")
147
+ state = self.state_factory.state()
148
+
149
+ statuses = state.get_run_status(set(request.run_ids))
150
+
151
+ return GetRunStatusResponse(
152
+ run_status_dict={
153
+ run_id: run_status_to_proto(status)
154
+ for run_id, status in statuses.items()
155
+ }
156
+ )
157
+
125
158
  def PushLogs(
126
159
  self, request: PushLogsRequest, context: grpc.ServicerContext
127
160
  ) -> PushLogsResponse: