flwr-nightly 1.14.0.dev20241211__py3-none-any.whl → 1.14.0.dev20241213__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of flwr-nightly might be problematic.

Files changed (75)
  1. flwr/cli/app.py +1 -0
  2. flwr/cli/build.py +1 -0
  3. flwr/cli/config_utils.py +1 -0
  4. flwr/cli/example.py +1 -0
  5. flwr/cli/install.py +1 -0
  6. flwr/cli/log.py +1 -0
  7. flwr/cli/login/__init__.py +1 -0
  8. flwr/cli/login/login.py +1 -0
  9. flwr/cli/new/__init__.py +1 -0
  10. flwr/cli/new/new.py +2 -1
  11. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
  12. flwr/cli/run/__init__.py +1 -0
  13. flwr/cli/run/run.py +1 -0
  14. flwr/cli/utils.py +1 -0
  15. flwr/client/app.py +3 -2
  16. flwr/client/client.py +1 -0
  17. flwr/client/clientapp/app.py +1 -0
  18. flwr/client/clientapp/utils.py +1 -0
  19. flwr/client/grpc_adapter_client/connection.py +1 -1
  20. flwr/client/grpc_client/connection.py +1 -1
  21. flwr/client/grpc_rere_client/connection.py +3 -3
  22. flwr/client/message_handler/message_handler.py +1 -0
  23. flwr/client/mod/comms_mods.py +1 -0
  24. flwr/client/mod/localdp_mod.py +1 -1
  25. flwr/client/nodestate/__init__.py +1 -0
  26. flwr/client/nodestate/nodestate.py +1 -0
  27. flwr/client/nodestate/nodestate_factory.py +1 -0
  28. flwr/client/rest_client/connection.py +3 -3
  29. flwr/client/supernode/app.py +1 -0
  30. flwr/common/address.py +1 -0
  31. flwr/common/args.py +1 -0
  32. flwr/common/config.py +1 -0
  33. flwr/common/logger.py +1 -0
  34. flwr/common/message.py +1 -0
  35. flwr/common/object_ref.py +57 -54
  36. flwr/common/pyproject.py +1 -0
  37. flwr/common/record/__init__.py +1 -0
  38. flwr/common/record/parametersrecord.py +1 -0
  39. flwr/common/retry_invoker.py +75 -0
  40. flwr/common/typing.py +4 -0
  41. flwr/common/version.py +1 -0
  42. flwr/proto/fab_pb2.py +4 -4
  43. flwr/proto/fab_pb2.pyi +4 -1
  44. flwr/server/app.py +1 -0
  45. flwr/server/compat/app_utils.py +7 -1
  46. flwr/server/driver/grpc_driver.py +5 -61
  47. flwr/server/driver/inmemory_driver.py +5 -1
  48. flwr/server/serverapp/app.py +9 -2
  49. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  50. flwr/server/superlink/driver/serverappio_grpc.py +1 -0
  51. flwr/server/superlink/driver/serverappio_servicer.py +54 -22
  52. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  53. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  54. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  55. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  56. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  57. flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -30
  58. flwr/server/superlink/linkstate/linkstate.py +13 -2
  59. flwr/server/superlink/linkstate/sqlite_linkstate.py +24 -44
  60. flwr/server/superlink/simulation/simulationio_servicer.py +1 -0
  61. flwr/server/superlink/utils.py +65 -0
  62. flwr/simulation/app.py +1 -0
  63. flwr/simulation/ray_transport/ray_actor.py +1 -0
  64. flwr/simulation/ray_transport/utils.py +1 -0
  65. flwr/simulation/run_simulation.py +1 -0
  66. flwr/superexec/app.py +1 -0
  67. flwr/superexec/deployment.py +1 -0
  68. flwr/superexec/exec_grpc.py +1 -0
  69. flwr/superexec/exec_servicer.py +8 -0
  70. flwr/superexec/executor.py +1 -0
  71. {flwr_nightly-1.14.0.dev20241211.dist-info → flwr_nightly-1.14.0.dev20241213.dist-info}/METADATA +1 -1
  72. {flwr_nightly-1.14.0.dev20241211.dist-info → flwr_nightly-1.14.0.dev20241213.dist-info}/RECORD +75 -74
  73. {flwr_nightly-1.14.0.dev20241211.dist-info → flwr_nightly-1.14.0.dev20241213.dist-info}/LICENSE +0 -0
  74. {flwr_nightly-1.14.0.dev20241211.dist-info → flwr_nightly-1.14.0.dev20241213.dist-info}/WHEEL +0 -0
  75. {flwr_nightly-1.14.0.dev20241211.dist-info → flwr_nightly-1.14.0.dev20241213.dist-info}/entry_points.txt +0 -0
flwr/proto/fab_pb2.py CHANGED
@@ -15,7 +15,7 @@ _sym_db = _symbol_database.Default()
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
 
 
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"A\n\rGetFabRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08hash_str\x18\x02 \x01(\t\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3')
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"Q\n\rGetFabRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08hash_str\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3')
 
  _globals = globals()
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -25,7 +25,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
    _globals['_FAB']._serialized_start=59
    _globals['_FAB']._serialized_end=99
    _globals['_GETFABREQUEST']._serialized_start=101
-   _globals['_GETFABREQUEST']._serialized_end=166
-   _globals['_GETFABRESPONSE']._serialized_start=168
-   _globals['_GETFABRESPONSE']._serialized_end=214
+   _globals['_GETFABREQUEST']._serialized_end=182
+   _globals['_GETFABRESPONSE']._serialized_start=184
+   _globals['_GETFABRESPONSE']._serialized_end=230
  # @@protoc_insertion_point(module_scope)
flwr/proto/fab_pb2.pyi CHANGED
@@ -36,16 +36,19 @@ class GetFabRequest(google.protobuf.message.Message):
      DESCRIPTOR: google.protobuf.descriptor.Descriptor
      NODE_FIELD_NUMBER: builtins.int
      HASH_STR_FIELD_NUMBER: builtins.int
+     RUN_ID_FIELD_NUMBER: builtins.int
      @property
      def node(self) -> flwr.proto.node_pb2.Node: ...
      hash_str: typing.Text
+     run_id: builtins.int
      def __init__(self,
          *,
          node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
          hash_str: typing.Text = ...,
+         run_id: builtins.int = ...,
          ) -> None: ...
      def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
-     def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node"]) -> None: ...
+     def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node","run_id",b"run_id"]) -> None: ...
  global___GetFabRequest = GetFabRequest
 
  class GetFabResponse(google.protobuf.message.Message):
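The updated stub above adds a run_id field to GetFabRequest alongside node and hash_str. A minimal sketch of constructing the new message shape; the node ID, hash, and run ID values below are placeholders, not values taken from this release:

from flwr.proto.fab_pb2 import GetFabRequest
from flwr.proto.node_pb2 import Node

# Placeholder values, for illustration only
request = GetFabRequest(
    node=Node(node_id=123),   # requesting node
    hash_str="<fab-hash>",    # FAB hash, unchanged from the previous release
    run_id=456,               # new field: associates the FAB request with a run
)
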
flwr/server/app.py CHANGED
@@ -14,6 +14,7 @@
  # ==============================================================================
  """Flower server app."""
 
+
  import argparse
  import csv
  import importlib.util
flwr/server/compat/app_utils.py CHANGED
@@ -17,6 +17,8 @@
 
  import threading
 
+ from flwr.common.typing import RunNotRunningException
+
  from ..client_manager import ClientManager
  from ..compat.driver_client_proxy import DriverClientProxy
  from ..driver import Driver
@@ -74,7 +76,11 @@ def _update_client_manager(
      # Loop until the driver is disconnected
      registered_nodes: dict[int, DriverClientProxy] = {}
      while not f_stop.is_set():
-         all_node_ids = set(driver.get_node_ids())
+         try:
+             all_node_ids = set(driver.get_node_ids())
+         except RunNotRunningException:
+             f_stop.set()
+             break
          dead_nodes = set(registered_nodes).difference(all_node_ids)
          new_nodes = all_node_ids.difference(registered_nodes)
 
flwr/server/driver/grpc_driver.py CHANGED
@@ -14,19 +14,20 @@
  # ==============================================================================
  """Flower gRPC Driver."""
 
+
  import time
  import warnings
  from collections.abc import Iterable
- from logging import DEBUG, INFO, WARN, WARNING
- from typing import Any, Optional, cast
+ from logging import DEBUG, WARNING
+ from typing import Optional, cast
 
  import grpc
 
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
- from flwr.common.constant import MAX_RETRY_DELAY, SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
+ from flwr.common.constant import SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
  from flwr.common.grpc import create_channel
  from flwr.common.logger import log
- from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
+ from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
  from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto
  from flwr.common.typing import Run
  from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
@@ -262,60 +263,3 @@ class GrpcDriver(Driver):
              return
          # Disconnect
          self._disconnect()
-
-
- def _make_simple_grpc_retry_invoker() -> RetryInvoker:
-     """Create a simple gRPC retry invoker."""
-
-     def _on_sucess(retry_state: RetryState) -> None:
-         if retry_state.tries > 1:
-             log(
-                 INFO,
-                 "Connection successful after %.2f seconds and %s tries.",
-                 retry_state.elapsed_time,
-                 retry_state.tries,
-             )
-
-     def _on_backoff(retry_state: RetryState) -> None:
-         if retry_state.tries == 1:
-             log(WARN, "Connection attempt failed, retrying...")
-         else:
-             log(
-                 WARN,
-                 "Connection attempt failed, retrying in %.2f seconds",
-                 retry_state.actual_wait,
-             )
-
-     def _on_giveup(retry_state: RetryState) -> None:
-         if retry_state.tries > 1:
-             log(
-                 WARN,
-                 "Giving up reconnection after %.2f seconds and %s tries.",
-                 retry_state.elapsed_time,
-                 retry_state.tries,
-             )
-
-     return RetryInvoker(
-         wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY),
-         recoverable_exceptions=grpc.RpcError,
-         max_tries=None,
-         max_time=None,
-         on_success=_on_sucess,
-         on_backoff=_on_backoff,
-         on_giveup=_on_giveup,
-         should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE,  # type: ignore
-     )
-
-
- def _wrap_stub(stub: ServerAppIoStub, retry_invoker: RetryInvoker) -> None:
-     """Wrap the gRPC stub with a retry invoker."""
-
-     def make_lambda(original_method: Any) -> Any:
-         return lambda *args, **kwargs: retry_invoker.invoke(
-             original_method, *args, **kwargs
-         )
-
-     for method_name in vars(stub):
-         method = getattr(stub, method_name)
-         if callable(method):
-             setattr(stub, method_name, make_lambda(method))
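The two helpers deleted above were not dropped: they moved to flwr.common.retry_invoker (+75 lines in the file list) and are now imported as _make_simple_grpc_retry_invoker and _wrap_stub. A minimal sketch of the pattern they implement, based on the removed code: an exponential-backoff RetryInvoker that retries only on UNAVAILABLE, plus a wrapper that routes every stub method through it. The logging callbacks from the removed code are omitted here, and the wrap_stub_with_retries name is illustrative.

from typing import Any

import grpc

from flwr.common.retry_invoker import RetryInvoker, exponential

# Retry on transient gRPC errors with exponential backoff; the argument names
# mirror the code removed above.
retry_invoker = RetryInvoker(
    wait_gen_factory=exponential,
    recoverable_exceptions=grpc.RpcError,
    max_tries=None,
    max_time=None,
    should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE,  # type: ignore
)

def wrap_stub_with_retries(stub: Any, invoker: RetryInvoker) -> None:
    """Route every callable attribute of a gRPC stub through the invoker (sketch)."""
    for method_name in vars(stub):
        method = getattr(stub, method_name)
        if callable(method):
            setattr(
                stub,
                method_name,
                lambda *args, _m=method, **kwargs: invoker.invoke(_m, *args, **kwargs),
            )
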
flwr/server/driver/inmemory_driver.py CHANGED
@@ -142,7 +142,11 @@ class InMemoryDriver(Driver):
          # Pull TaskRes
          task_res_list = self.state.get_task_res(task_ids=msg_ids)
          # Delete tasks in state
-         self.state.delete_tasks(msg_ids)
+         # Delete the TaskIns/TaskRes pairs if TaskRes is found
+         task_ins_ids_to_delete = {
+             UUID(task_res.task.ancestry[0]) for task_res in task_res_list
+         }
+         self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
          # Convert TaskRes to Message
          msgs = [message_from_taskres(taskres) for taskres in task_res_list]
          return msgs
flwr/server/serverapp/app.py CHANGED
@@ -14,6 +14,7 @@
  # ==============================================================================
  """Flower ServerApp process."""
 
+
  import argparse
  import sys
  from logging import DEBUG, ERROR, INFO
@@ -50,7 +51,7 @@ from flwr.common.serde import (
      run_from_proto,
      run_status_to_proto,
  )
- from flwr.common.typing import RunStatus
+ from flwr.common.typing import RunNotRunningException, RunStatus
  from flwr.proto.run_pb2 import UpdateRunStatusRequest  # pylint: disable=E0611
  from flwr.proto.serverappio_pb2 import (  # pylint: disable=E0611
      PullServerAppInputsRequest,
@@ -96,7 +97,7 @@ def flwr_serverapp() -> None:
      restore_output()
 
 
- def run_serverapp(  # pylint: disable=R0914, disable=W0212
+ def run_serverapp(  # pylint: disable=R0914, disable=W0212, disable=R0915
      serverappio_api_address: str,
      log_queue: Queue[Optional[str]],
      run_once: bool,
@@ -187,6 +188,12 @@ def run_serverapp(  # pylint: disable=R0914, disable=W0212
 
              run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
 
+         except RunNotRunningException:
+             log(INFO, "")
+             log(INFO, "Run ID %s stopped.", run.run_id)
+             log(INFO, "")
+             run_status = None
+
          except Exception as ex:  # pylint: disable=broad-exception-caught
              exc_entity = "ServerApp"
              log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
flwr/server/strategy/dpfedavg_fixed.py CHANGED
@@ -17,6 +17,7 @@
  Paper: arxiv.org/pdf/1710.06963.pdf
  """
 
+
  from typing import Optional, Union
 
  from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
flwr/server/superlink/driver/serverappio_grpc.py CHANGED
@@ -14,6 +14,7 @@
  # ==============================================================================
  """ServerAppIo gRPC API."""
 
+
  from logging import INFO
  from typing import Optional
 
flwr/server/superlink/driver/serverappio_servicer.py CHANGED
@@ -70,6 +70,7 @@ from flwr.proto.task_pb2 import TaskRes  # pylint: disable=E0611
  from flwr.server.superlink.ffs.ffs import Ffs
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
+ from flwr.server.superlink.utils import abort_if
  from flwr.server.utils.validator import validate_task_ins_or_res
 
 
@@ -88,7 +89,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
      ) -> GetNodesResponse:
          """Get available nodes."""
          log(DEBUG, "ServerAppIoServicer.GetNodes")
+
+         # Init state
          state: LinkState = self.state_factory.state()
+
+         # Abort if the run is not running
+         abort_if(
+             request.run_id,
+             [Status.PENDING, Status.STARTING, Status.FINISHED],
+             state,
+             context,
+         )
+
          all_ids: set[int] = state.get_nodes(request.run_id)
          nodes: list[Node] = [
              Node(node_id=node_id, anonymous=False) for node_id in all_ids
@@ -126,6 +138,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
          """Push a set of TaskIns."""
          log(DEBUG, "ServerAppIoServicer.PushTaskIns")
 
+         # Init state
+         state: LinkState = self.state_factory.state()
+
+         # Abort if the run is not running
+         abort_if(
+             request.run_id,
+             [Status.PENDING, Status.STARTING, Status.FINISHED],
+             state,
+             context,
+         )
+
          # Set pushed_at (timestamp in seconds)
          pushed_at = time.time()
          for task_ins in request.task_ins_list:
@@ -137,9 +160,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
          validation_errors = validate_task_ins_or_res(task_ins)
          _raise_if(bool(validation_errors), ", ".join(validation_errors))
 
-         # Init state
-         state: LinkState = self.state_factory.state()
-
          # Store each TaskIns
          task_ids: list[Optional[UUID]] = []
          for task_ins in request.task_ins_list:
@@ -156,33 +176,29 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
          """Pull a set of TaskRes."""
          log(DEBUG, "ServerAppIoServicer.PullTaskRes")
 
-         # Convert each task_id str to UUID
-         task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
-
          # Init state
          state: LinkState = self.state_factory.state()
 
-         # Register callback
-         def on_rpc_done() -> None:
-             log(
-                 DEBUG,
-                 "ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes",
-             )
-
-             if context.is_active():
-                 return
-             if context.code() != grpc.StatusCode.OK:
-                 return
-
-             # Delete delivered TaskIns and TaskRes
-             state.delete_tasks(task_ids=task_ids)
+         # Abort if the run is not running
+         abort_if(
+             request.run_id,
+             [Status.PENDING, Status.STARTING, Status.FINISHED],
+             state,
+             context,
+         )
 
-         context.add_callback(on_rpc_done)
+         # Convert each task_id str to UUID
+         task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
 
          # Read from state
          task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
 
-         context.set_code(grpc.StatusCode.OK)
+         # Delete the TaskIns/TaskRes pairs if TaskRes is found
+         task_ins_ids_to_delete = {
+             UUID(task_res.task.ancestry[0]) for task_res in task_res_list
+         }
+         state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
+
          return PullTaskResResponse(task_res_list=task_res_list)
 
      def GetRun(
@@ -258,7 +274,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
      ) -> PushServerAppOutputsResponse:
          """Push ServerApp process outputs."""
          log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
+
+         # Init state
          state = self.state_factory.state()
+
+         # Abort if the run is not running
+         abort_if(
+             request.run_id,
+             [Status.PENDING, Status.STARTING, Status.FINISHED],
+             state,
+             context,
+         )
+
          state.set_serverapp_context(request.run_id, context_from_proto(request.context))
          return PushServerAppOutputsResponse()
 
@@ -267,8 +294,13 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
      ) -> UpdateRunStatusResponse:
          """Update the status of a run."""
          log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
+
+         # Init state
          state = self.state_factory.state()
 
+         # Abort if the run is finished
+         abort_if(request.run_id, [Status.FINISHED], state, context)
+
          # Update the run status
          state.update_run_status(
              run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
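Each handler above now guards itself with abort_if, imported from the new flwr.server.superlink.utils module (+65 lines in the file list, not shown in this diff). A hedged sketch of what such a guard might do, assuming it looks up the run's current status and aborts the RPC when that status is in the disallowed list; the get_run_status call, the chosen status code, and the message text are assumptions, not taken from the actual implementation:

import grpc

from flwr.server.superlink.linkstate import LinkState

def abort_if_sketch(
    run_id: int,
    abort_status_list: list,
    state: LinkState,
    context: grpc.ServicerContext,
) -> None:
    """Abort the current RPC if the run's status is in abort_status_list (sketch)."""
    # Assumption: LinkState offers a per-run status lookup; the real helper in
    # flwr/server/superlink/utils.py may use a different API, status code, and message.
    run_status = state.get_run_status({run_id}).get(run_id)
    if run_status is not None and run_status.status in abort_status_list:
        context.abort(
            grpc.StatusCode.PERMISSION_DENIED,
            f"Run {run_id} is in status '{run_status.status}'; request aborted.",
        )
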
flwr/server/superlink/ffs/disk_ffs.py CHANGED
@@ -14,6 +14,7 @@
  # ==============================================================================
  """Disk based Flower File Storage."""
 
+
  import hashlib
  import json
  from pathlib import Path
flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py CHANGED
@@ -18,6 +18,7 @@ Relevant knowledge for reading this modules code:
  - https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
  """
 
+
  import uuid
  from collections.abc import Iterator
  from typing import Callable
flwr/server/superlink/fleet/vce/__init__.py CHANGED
@@ -14,6 +14,7 @@
  # ==============================================================================
  """Fleet Simulation Engine side."""
 
+
  from .vce_api import start_vce
 
  __all__ = [
flwr/server/superlink/fleet/vce/backend/__init__.py CHANGED
@@ -14,6 +14,7 @@
  # ==============================================================================
  """Simulation Engine Backends."""
 
+
  import importlib
 
  from .backend import Backend, BackendConfig
flwr/server/superlink/fleet/vce/backend/raybackend.py CHANGED
@@ -14,6 +14,7 @@
  # ==============================================================================
  """Ray backend for the Fleet API using the Simulation Engine."""
 
+
  import sys
  from logging import DEBUG, ERROR
  from typing import Callable, Optional, Union
flwr/server/superlink/linkstate/in_memory_linkstate.py CHANGED
@@ -265,41 +265,15 @@ class InMemoryLinkState(LinkState):  # pylint: disable=R0902,R0904
          for task_res in task_res_found:
              task_res.task.delivered_at = delivered_at
 
-         # Cleanup
-         self._force_delete_tasks_by_ids(set(ret.keys()))
-
          return list(ret.values())
 
-     def delete_tasks(self, task_ids: set[UUID]) -> None:
-         """Delete all delivered TaskIns/TaskRes pairs."""
-         task_ins_to_be_deleted: set[UUID] = set()
-         task_res_to_be_deleted: set[UUID] = set()
-
-         with self.lock:
-             for task_ins_id in task_ids:
-                 # Find the task_id of the matching task_res
-                 for task_res_id, task_res in self.task_res_store.items():
-                     if UUID(task_res.task.ancestry[0]) != task_ins_id:
-                         continue
-                     if task_res.task.delivered_at == "":
-                         continue
-
-                     task_ins_to_be_deleted.add(task_ins_id)
-                     task_res_to_be_deleted.add(task_res_id)
-
-             for task_id in task_ins_to_be_deleted:
-                 del self.task_ins_store[task_id]
-                 del self.task_ins_id_to_task_res_id[task_id]
-             for task_id in task_res_to_be_deleted:
-                 del self.task_res_store[task_id]
-
-     def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
-         """Delete tasks based on a set of TaskIns IDs."""
-         if not task_ids:
+     def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
+         """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
+         if not task_ins_ids:
              return
 
          with self.lock:
-             for task_id in task_ids:
+             for task_id in task_ins_ids:
                  # Delete TaskIns
                  if task_id in self.task_ins_store:
                      del self.task_ins_store[task_id]
@@ -308,6 +282,16 @@ class InMemoryLinkState(LinkState):  # pylint: disable=R0902,R0904
                  task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
                  del self.task_res_store[task_res_id]
 
+     def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
+         """Get all TaskIns IDs for the given run_id."""
+         task_id_list: set[UUID] = set()
+         with self.lock:
+             for task_id, task_ins in self.task_ins_store.items():
+                 if task_ins.run_id == run_id:
+                     task_id_list.add(task_id)
+
+         return task_id_list
+
      def num_task_ins(self) -> int:
          """Calculate the number of task_ins in store.
 
flwr/server/superlink/linkstate/linkstate.py CHANGED
@@ -139,8 +139,19 @@ class LinkState(abc.ABC):  # pylint: disable=R0904
          """
 
      @abc.abstractmethod
-     def delete_tasks(self, task_ids: set[UUID]) -> None:
-         """Delete all delivered TaskIns/TaskRes pairs."""
+     def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
+         """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
+
+         Parameters
+         ----------
+         task_ins_ids : set[UUID]
+             A set of TaskIns IDs. For each ID in the set, the corresponding
+             TaskIns and its associated TaskRes will be deleted.
+         """
+
+     @abc.abstractmethod
+     def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
+         """Get all TaskIns IDs for the given run_id."""
 
      @abc.abstractmethod
      def create_node(
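Together, the renamed delete_tasks and the new get_task_ids_from_run_id give LinkState implementations a run-scoped cleanup path. A minimal usage sketch against any LinkState instance; the purge_run_tasks helper is illustrative, and whether SuperLink performs exactly this cleanup when a run stops is not shown in this diff:

from uuid import UUID

from flwr.server.superlink.linkstate import LinkState

def purge_run_tasks(state: LinkState, run_id: int) -> None:
    """Delete every TaskIns/TaskRes pair recorded for a run (illustrative)."""
    task_ins_ids: set[UUID] = state.get_task_ids_from_run_id(run_id)
    state.delete_tasks(task_ins_ids=task_ins_ids)
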
flwr/server/superlink/linkstate/sqlite_linkstate.py CHANGED
@@ -14,6 +14,7 @@
  # ==============================================================================
  """SQLite based implemenation of the link state."""
 
+
  # pylint: disable=too-many-lines
 
  import json
@@ -566,9 +567,6 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
          data: list[Any] = [delivered_at] + task_res_ids
          self.query(query, data)
 
-         # Cleanup
-         self._force_delete_tasks_by_ids(set(ret.keys()))
-
          return list(ret.values())
 
      def num_task_ins(self) -> int:
@@ -592,68 +590,50 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
          result: dict[str, int] = rows[0]
          return result["num"]
 
-     def delete_tasks(self, task_ids: set[UUID]) -> None:
-         """Delete all delivered TaskIns/TaskRes pairs."""
-         ids = list(task_ids)
-         if len(ids) == 0:
-             return None
+     def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
+         """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
+         if not task_ins_ids:
+             return
+         if self.conn is None:
+             raise AttributeError("LinkState not initialized")
 
-         placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
-         data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
+         placeholders = ",".join(["?"] * len(task_ins_ids))
+         data = tuple(str(task_id) for task_id in task_ins_ids)
 
-         # 1. Query: Delete task_ins which have a delivered task_res
+         # Delete task_ins
          query_1 = f"""
              DELETE FROM task_ins
-             WHERE delivered_at != ''
-             AND task_id IN (
-                 SELECT ancestry
-                 FROM task_res
-                 WHERE ancestry IN ({placeholders})
-                 AND delivered_at != ''
-             );
+             WHERE task_id IN ({placeholders});
          """
 
-         # 2. Query: Delete delivered task_res to be run after 1. Query
+         # Delete task_res
          query_2 = f"""
              DELETE FROM task_res
-             WHERE ancestry IN ({placeholders})
-             AND delivered_at != '';
+             WHERE ancestry IN ({placeholders});
          """
 
-         if self.conn is None:
-             raise AttributeError("LinkState not intitialized")
-
          with self.conn:
              self.conn.execute(query_1, data)
              self.conn.execute(query_2, data)
 
-         return None
-
-     def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
-         """Delete tasks based on a set of TaskIns IDs."""
-         if not task_ids:
-             return
+     def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
+         """Get all TaskIns IDs for the given run_id."""
          if self.conn is None:
              raise AttributeError("LinkState not initialized")
 
-         placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
-         data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
-
-         # Delete task_ins
-         query_1 = f"""
-             DELETE FROM task_ins
-             WHERE task_id IN ({placeholders});
+         query = """
+             SELECT task_id
+             FROM task_ins
+             WHERE run_id = :run_id;
          """
 
-         # Delete task_res
-         query_2 = f"""
-             DELETE FROM task_res
-             WHERE ancestry IN ({placeholders});
-         """
+         sint64_run_id = convert_uint64_to_sint64(run_id)
+         data = {"run_id": sint64_run_id}
 
          with self.conn:
-             self.conn.execute(query_1, data)
-             self.conn.execute(query_2, data)
+             rows = self.conn.execute(query, data).fetchall()
+
+         return {UUID(row["task_id"]) for row in rows}
 
      def create_node(
          self, ping_interval: float, public_key: Optional[bytes] = None
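The rewritten SQLite delete_tasks swaps the named :id_N placeholders for positional ? placeholders bound to a tuple. A small self-contained sqlite3 sketch of that pattern, using a throwaway table rather than the actual Flower schema:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE task_ins (task_id TEXT)")
conn.executemany("INSERT INTO task_ins VALUES (?)", [("a",), ("b",), ("c",)])

ids_to_delete = {"a", "c"}
placeholders = ",".join(["?"] * len(ids_to_delete))  # one ? per ID
data = tuple(ids_to_delete)

with conn:
    conn.execute(f"DELETE FROM task_ins WHERE task_id IN ({placeholders})", data)

print(conn.execute("SELECT task_id FROM task_ins").fetchall())  # [('b',)]
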
flwr/server/superlink/simulation/simulationio_servicer.py CHANGED
@@ -14,6 +14,7 @@
  # ==============================================================================
  """SimulationIo API servicer."""
 
+
  import threading
  from logging import DEBUG, INFO