flwr-nightly 1.14.0.dev20241204__py3-none-any.whl → 1.14.0.dev20241214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the registry's release page for more details.

Files changed (100)
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/build.py +1 -0
  3. flwr/cli/cli_user_auth_interceptor.py +86 -0
  4. flwr/cli/config_utils.py +19 -2
  5. flwr/cli/example.py +1 -0
  6. flwr/cli/install.py +1 -0
  7. flwr/cli/log.py +11 -31
  8. flwr/cli/login/__init__.py +22 -0
  9. flwr/cli/login/login.py +83 -0
  10. flwr/cli/ls.py +10 -40
  11. flwr/cli/new/__init__.py +1 -0
  12. flwr/cli/new/new.py +2 -1
  13. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  14. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
  15. flwr/cli/run/__init__.py +1 -0
  16. flwr/cli/run/run.py +15 -25
  17. flwr/cli/stop.py +91 -0
  18. flwr/cli/utils.py +109 -1
  19. flwr/client/app.py +3 -2
  20. flwr/client/client.py +1 -0
  21. flwr/client/clientapp/app.py +1 -0
  22. flwr/client/clientapp/utils.py +1 -0
  23. flwr/client/grpc_adapter_client/connection.py +1 -1
  24. flwr/client/grpc_client/connection.py +1 -1
  25. flwr/client/grpc_rere_client/connection.py +3 -3
  26. flwr/client/message_handler/message_handler.py +1 -0
  27. flwr/client/mod/comms_mods.py +1 -0
  28. flwr/client/mod/localdp_mod.py +1 -1
  29. flwr/client/nodestate/__init__.py +1 -0
  30. flwr/client/nodestate/nodestate.py +1 -0
  31. flwr/client/nodestate/nodestate_factory.py +1 -0
  32. flwr/client/rest_client/connection.py +3 -3
  33. flwr/client/supernode/app.py +1 -0
  34. flwr/common/address.py +1 -0
  35. flwr/common/args.py +1 -0
  36. flwr/common/auth_plugin/__init__.py +24 -0
  37. flwr/common/auth_plugin/auth_plugin.py +111 -0
  38. flwr/common/config.py +3 -1
  39. flwr/common/constant.py +6 -1
  40. flwr/common/logger.py +1 -0
  41. flwr/common/message.py +1 -0
  42. flwr/common/object_ref.py +57 -54
  43. flwr/common/pyproject.py +1 -0
  44. flwr/common/record/__init__.py +1 -0
  45. flwr/common/record/parametersrecord.py +1 -0
  46. flwr/common/retry_invoker.py +75 -0
  47. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  48. flwr/common/telemetry.py +2 -1
  49. flwr/common/typing.py +12 -0
  50. flwr/common/version.py +1 -0
  51. flwr/proto/exec_pb2.py +27 -3
  52. flwr/proto/exec_pb2.pyi +103 -0
  53. flwr/proto/exec_pb2_grpc.py +102 -0
  54. flwr/proto/exec_pb2_grpc.pyi +39 -0
  55. flwr/proto/fab_pb2.py +4 -4
  56. flwr/proto/fab_pb2.pyi +4 -1
  57. flwr/proto/serverappio_pb2.py +18 -18
  58. flwr/proto/serverappio_pb2.pyi +8 -2
  59. flwr/proto/serverappio_pb2_grpc.py +34 -0
  60. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  61. flwr/proto/simulationio_pb2.py +2 -2
  62. flwr/proto/simulationio_pb2_grpc.py +34 -0
  63. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  64. flwr/server/app.py +53 -1
  65. flwr/server/compat/app_utils.py +7 -1
  66. flwr/server/driver/grpc_driver.py +11 -63
  67. flwr/server/driver/inmemory_driver.py +5 -1
  68. flwr/server/serverapp/app.py +9 -2
  69. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  70. flwr/server/superlink/driver/serverappio_grpc.py +1 -0
  71. flwr/server/superlink/driver/serverappio_servicer.py +72 -22
  72. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  73. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
  74. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +31 -2
  77. flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
  78. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  79. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  80. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  81. flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -30
  82. flwr/server/superlink/linkstate/linkstate.py +13 -2
  83. flwr/server/superlink/linkstate/sqlite_linkstate.py +24 -44
  84. flwr/server/superlink/simulation/simulationio_servicer.py +20 -0
  85. flwr/server/superlink/utils.py +65 -0
  86. flwr/simulation/app.py +1 -0
  87. flwr/simulation/ray_transport/ray_actor.py +1 -0
  88. flwr/simulation/ray_transport/utils.py +1 -0
  89. flwr/simulation/run_simulation.py +1 -0
  90. flwr/superexec/app.py +1 -0
  91. flwr/superexec/deployment.py +1 -0
  92. flwr/superexec/exec_grpc.py +19 -1
  93. flwr/superexec/exec_servicer.py +76 -2
  94. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  95. flwr/superexec/executor.py +1 -0
  96. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/METADATA +8 -7
  97. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/RECORD +100 -92
  98. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/LICENSE +0 -0
  99. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/WHEEL +0 -0
  100. {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/entry_points.txt +0 -0
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower ServerApp process."""
16
16
 
17
+
17
18
  import argparse
18
19
  import sys
19
20
  from logging import DEBUG, ERROR, INFO
@@ -50,7 +51,7 @@ from flwr.common.serde import (
50
51
  run_from_proto,
51
52
  run_status_to_proto,
52
53
  )
53
- from flwr.common.typing import RunStatus
54
+ from flwr.common.typing import RunNotRunningException, RunStatus
54
55
  from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
55
56
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
56
57
  PullServerAppInputsRequest,
@@ -96,7 +97,7 @@ def flwr_serverapp() -> None:
96
97
  restore_output()
97
98
 
98
99
 
99
- def run_serverapp( # pylint: disable=R0914, disable=W0212
100
+ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
100
101
  serverappio_api_address: str,
101
102
  log_queue: Queue[Optional[str]],
102
103
  run_once: bool,
@@ -187,6 +188,12 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
187
188
 
188
189
  run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
189
190
 
191
+ except RunNotRunningException:
192
+ log(INFO, "")
193
+ log(INFO, "Run ID %s stopped.", run.run_id)
194
+ log(INFO, "")
195
+ run_status = None
196
+
190
197
  except Exception as ex: # pylint: disable=broad-exception-caught
191
198
  exc_entity = "ServerApp"
192
199
  log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
@@ -17,6 +17,7 @@
17
17
  Paper: arxiv.org/pdf/1710.06963.pdf
18
18
  """
19
19
 
20
+
20
21
  from typing import Optional, Union
21
22
 
22
23
  from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """ServerAppIo gRPC API."""
16
16
 
17
+
17
18
  from logging import INFO
18
19
  from typing import Optional
19
20
 
@@ -32,6 +32,7 @@ from flwr.common.serde import (
32
32
  fab_from_proto,
33
33
  fab_to_proto,
34
34
  run_status_from_proto,
35
+ run_status_to_proto,
35
36
  run_to_proto,
36
37
  user_config_from_proto,
37
38
  )
@@ -48,6 +49,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
48
49
  CreateRunResponse,
49
50
  GetRunRequest,
50
51
  GetRunResponse,
52
+ GetRunStatusRequest,
53
+ GetRunStatusResponse,
51
54
  UpdateRunStatusRequest,
52
55
  UpdateRunStatusResponse,
53
56
  )
@@ -67,6 +70,7 @@ from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
67
70
  from flwr.server.superlink.ffs.ffs import Ffs
68
71
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
69
72
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
73
+ from flwr.server.superlink.utils import abort_if
70
74
  from flwr.server.utils.validator import validate_task_ins_or_res
71
75
 
72
76
 
@@ -85,7 +89,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
85
89
  ) -> GetNodesResponse:
86
90
  """Get available nodes."""
87
91
  log(DEBUG, "ServerAppIoServicer.GetNodes")
92
+
93
+ # Init state
88
94
  state: LinkState = self.state_factory.state()
95
+
96
+ # Abort if the run is not running
97
+ abort_if(
98
+ request.run_id,
99
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
100
+ state,
101
+ context,
102
+ )
103
+
89
104
  all_ids: set[int] = state.get_nodes(request.run_id)
90
105
  nodes: list[Node] = [
91
106
  Node(node_id=node_id, anonymous=False) for node_id in all_ids
@@ -123,6 +138,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
123
138
  """Push a set of TaskIns."""
124
139
  log(DEBUG, "ServerAppIoServicer.PushTaskIns")
125
140
 
141
+ # Init state
142
+ state: LinkState = self.state_factory.state()
143
+
144
+ # Abort if the run is not running
145
+ abort_if(
146
+ request.run_id,
147
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
148
+ state,
149
+ context,
150
+ )
151
+
126
152
  # Set pushed_at (timestamp in seconds)
127
153
  pushed_at = time.time()
128
154
  for task_ins in request.task_ins_list:
@@ -134,9 +160,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
134
160
  validation_errors = validate_task_ins_or_res(task_ins)
135
161
  _raise_if(bool(validation_errors), ", ".join(validation_errors))
136
162
 
137
- # Init state
138
- state: LinkState = self.state_factory.state()
139
-
140
163
  # Store each TaskIns
141
164
  task_ids: list[Optional[UUID]] = []
142
165
  for task_ins in request.task_ins_list:
@@ -153,33 +176,29 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
153
176
  """Pull a set of TaskRes."""
154
177
  log(DEBUG, "ServerAppIoServicer.PullTaskRes")
155
178
 
156
- # Convert each task_id str to UUID
157
- task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
158
-
159
179
  # Init state
160
180
  state: LinkState = self.state_factory.state()
161
181
 
162
- # Register callback
163
- def on_rpc_done() -> None:
164
- log(
165
- DEBUG,
166
- "ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes",
167
- )
168
-
169
- if context.is_active():
170
- return
171
- if context.code() != grpc.StatusCode.OK:
172
- return
173
-
174
- # Delete delivered TaskIns and TaskRes
175
- state.delete_tasks(task_ids=task_ids)
182
+ # Abort if the run is not running
183
+ abort_if(
184
+ request.run_id,
185
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
186
+ state,
187
+ context,
188
+ )
176
189
 
177
- context.add_callback(on_rpc_done)
190
+ # Convert each task_id str to UUID
191
+ task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
178
192
 
179
193
  # Read from state
180
194
  task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
181
195
 
182
- context.set_code(grpc.StatusCode.OK)
196
+ # Delete the TaskIns/TaskRes pairs if TaskRes is found
197
+ task_ins_ids_to_delete = {
198
+ UUID(task_res.task.ancestry[0]) for task_res in task_res_list
199
+ }
200
+ state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
201
+
183
202
  return PullTaskResResponse(task_res_list=task_res_list)
184
203
 
185
204
  def GetRun(
@@ -255,7 +274,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
255
274
  ) -> PushServerAppOutputsResponse:
256
275
  """Push ServerApp process outputs."""
257
276
  log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
277
+
278
+ # Init state
258
279
  state = self.state_factory.state()
280
+
281
+ # Abort if the run is not running
282
+ abort_if(
283
+ request.run_id,
284
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
285
+ state,
286
+ context,
287
+ )
288
+
259
289
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
260
290
  return PushServerAppOutputsResponse()
261
291
 
@@ -264,8 +294,13 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
264
294
  ) -> UpdateRunStatusResponse:
265
295
  """Update the status of a run."""
266
296
  log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
297
+
298
+ # Init state
267
299
  state = self.state_factory.state()
268
300
 
301
+ # Abort if the run is finished
302
+ abort_if(request.run_id, [Status.FINISHED], state, context)
303
+
269
304
  # Update the run status
270
305
  state.update_run_status(
271
306
  run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
@@ -284,6 +319,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
284
319
  state.add_serverapp_log(request.run_id, merged_logs)
285
320
  return PushLogsResponse()
286
321
 
322
+ def GetRunStatus(
323
+ self, request: GetRunStatusRequest, context: grpc.ServicerContext
324
+ ) -> GetRunStatusResponse:
325
+ """Get the status of a run."""
326
+ log(DEBUG, "ServerAppIoServicer.GetRunStatus")
327
+ state = self.state_factory.state()
328
+
329
+ # Get run status from LinkState
330
+ run_statuses = state.get_run_status(set(request.run_ids))
331
+ run_status_dict = {
332
+ run_id: run_status_to_proto(run_status)
333
+ for run_id, run_status in run_statuses.items()
334
+ }
335
+ return GetRunStatusResponse(run_status_dict=run_status_dict)
336
+
287
337
 
288
338
  def _raise_if(validation_error: bool, detail: str) -> None:
289
339
  if validation_error:
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Disk based Flower File Storage."""
16
16
 
17
+
17
18
  import hashlib
18
19
  import json
19
20
  from pathlib import Path
@@ -158,4 +158,5 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
158
158
  return message_handler.get_fab(
159
159
  request=request,
160
160
  ffs=self.ffs_factory.ffs(),
161
+ state=self.state_factory.state(),
161
162
  )
@@ -18,6 +18,7 @@ Relevant knowledge for reading this modules code:
18
18
  - https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
19
19
  """
20
20
 
21
+
21
22
  import uuid
22
23
  from collections.abc import Iterator
23
24
  from typing import Callable
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO
20
20
  import grpc
21
21
 
22
22
  from flwr.common.logger import log
23
+ from flwr.common.typing import InvalidRunStatusException
23
24
  from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
24
25
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
25
26
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
@@ -38,6 +39,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
38
39
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
39
40
  from flwr.server.superlink.fleet.message_handler import message_handler
40
41
  from flwr.server.superlink.linkstate import LinkStateFactory
42
+ from flwr.server.superlink.utils import abort_grpc_context
41
43
 
42
44
 
43
45
  class FleetServicer(fleet_pb2_grpc.FleetServicer):
@@ -105,27 +107,45 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
105
107
  )
106
108
  else:
107
109
  log(INFO, "[Fleet.PushTaskRes] No task results to push")
108
- return message_handler.push_task_res(
109
- request=request,
110
- state=self.state_factory.state(),
111
- )
110
+
111
+ try:
112
+ res = message_handler.push_task_res(
113
+ request=request,
114
+ state=self.state_factory.state(),
115
+ )
116
+ except InvalidRunStatusException as e:
117
+ abort_grpc_context(e.message, context)
118
+
119
+ return res
112
120
 
113
121
  def GetRun(
114
122
  self, request: GetRunRequest, context: grpc.ServicerContext
115
123
  ) -> GetRunResponse:
116
124
  """Get run information."""
117
125
  log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
118
- return message_handler.get_run(
119
- request=request,
120
- state=self.state_factory.state(),
121
- )
126
+
127
+ try:
128
+ res = message_handler.get_run(
129
+ request=request,
130
+ state=self.state_factory.state(),
131
+ )
132
+ except InvalidRunStatusException as e:
133
+ abort_grpc_context(e.message, context)
134
+
135
+ return res
122
136
 
123
137
  def GetFab(
124
138
  self, request: GetFabRequest, context: grpc.ServicerContext
125
139
  ) -> GetFabResponse:
126
140
  """Get FAB."""
127
141
  log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
128
- return message_handler.get_fab(
129
- request=request,
130
- ffs=self.ffs_factory.ffs(),
131
- )
142
+ try:
143
+ res = message_handler.get_fab(
144
+ request=request,
145
+ ffs=self.ffs_factory.ffs(),
146
+ state=self.state_factory.state(),
147
+ )
148
+ except InvalidRunStatusException as e:
149
+ abort_grpc_context(e.message, context)
150
+
151
+ return res
@@ -19,8 +19,9 @@ import time
19
19
  from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
+ from flwr.common.constant import Status
22
23
  from flwr.common.serde import fab_to_proto, user_config_to_proto
23
- from flwr.common.typing import Fab
24
+ from flwr.common.typing import Fab, InvalidRunStatusException
24
25
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
25
26
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
26
27
  CreateNodeRequest,
@@ -44,6 +45,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
44
45
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
45
46
  from flwr.server.superlink.ffs.ffs import Ffs
46
47
  from flwr.server.superlink.linkstate import LinkState
48
+ from flwr.server.superlink.utils import check_abort
47
49
 
48
50
 
49
51
  def create_node(
@@ -98,6 +100,15 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
98
100
  task_res: TaskRes = request.task_res_list[0]
99
101
  # pylint: enable=no-member
100
102
 
103
+ # Abort if the run is not running
104
+ abort_msg = check_abort(
105
+ task_res.run_id,
106
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
107
+ state,
108
+ )
109
+ if abort_msg:
110
+ raise InvalidRunStatusException(abort_msg)
111
+
101
112
  # Set pushed_at (timestamp in seconds)
102
113
  task_res.task.pushed_at = time.time()
103
114
 
@@ -121,6 +132,15 @@ def get_run(
121
132
  if run is None:
122
133
  return GetRunResponse()
123
134
 
135
+ # Abort if the run is not running
136
+ abort_msg = check_abort(
137
+ request.run_id,
138
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
139
+ state,
140
+ )
141
+ if abort_msg:
142
+ raise InvalidRunStatusException(abort_msg)
143
+
124
144
  return GetRunResponse(
125
145
  run=Run(
126
146
  run_id=run.run_id,
@@ -133,9 +153,18 @@ def get_run(
133
153
 
134
154
 
135
155
  def get_fab(
136
- request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
156
+ request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
137
157
  ) -> GetFabResponse:
138
158
  """Get FAB."""
159
+ # Abort if the run is not running
160
+ abort_msg = check_abort(
161
+ request.run_id,
162
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
163
+ state,
164
+ )
165
+ if abort_msg:
166
+ raise InvalidRunStatusException(abort_msg)
167
+
139
168
  if result := ffs.get(request.hash_str):
140
169
  fab = Fab(request.hash_str, result[0])
141
170
  return GetFabResponse(fab=fab_to_proto(fab))
@@ -154,8 +154,11 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
154
154
  # Get ffs from app
155
155
  ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
156
156
 
157
+ # Get state from app
158
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
159
+
157
160
  # Handle message
158
- return message_handler.get_fab(request=request, ffs=ffs)
161
+ return message_handler.get_fab(request=request, ffs=ffs, state=state)
159
162
 
160
163
 
161
164
  routes = [
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Fleet Simulation Engine side."""
16
16
 
17
+
17
18
  from .vce_api import start_vce
18
19
 
19
20
  __all__ = [
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Simulation Engine Backends."""
16
16
 
17
+
17
18
  import importlib
18
19
 
19
20
  from .backend import Backend, BackendConfig
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Ray backend for the Fleet API using the Simulation Engine."""
16
16
 
17
+
17
18
  import sys
18
19
  from logging import DEBUG, ERROR
19
20
  from typing import Callable, Optional, Union
@@ -265,41 +265,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
265
265
  for task_res in task_res_found:
266
266
  task_res.task.delivered_at = delivered_at
267
267
 
268
- # Cleanup
269
- self._force_delete_tasks_by_ids(set(ret.keys()))
270
-
271
268
  return list(ret.values())
272
269
 
273
- def delete_tasks(self, task_ids: set[UUID]) -> None:
274
- """Delete all delivered TaskIns/TaskRes pairs."""
275
- task_ins_to_be_deleted: set[UUID] = set()
276
- task_res_to_be_deleted: set[UUID] = set()
277
-
278
- with self.lock:
279
- for task_ins_id in task_ids:
280
- # Find the task_id of the matching task_res
281
- for task_res_id, task_res in self.task_res_store.items():
282
- if UUID(task_res.task.ancestry[0]) != task_ins_id:
283
- continue
284
- if task_res.task.delivered_at == "":
285
- continue
286
-
287
- task_ins_to_be_deleted.add(task_ins_id)
288
- task_res_to_be_deleted.add(task_res_id)
289
-
290
- for task_id in task_ins_to_be_deleted:
291
- del self.task_ins_store[task_id]
292
- del self.task_ins_id_to_task_res_id[task_id]
293
- for task_id in task_res_to_be_deleted:
294
- del self.task_res_store[task_id]
295
-
296
- def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
297
- """Delete tasks based on a set of TaskIns IDs."""
298
- if not task_ids:
270
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
271
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
272
+ if not task_ins_ids:
299
273
  return
300
274
 
301
275
  with self.lock:
302
- for task_id in task_ids:
276
+ for task_id in task_ins_ids:
303
277
  # Delete TaskIns
304
278
  if task_id in self.task_ins_store:
305
279
  del self.task_ins_store[task_id]
@@ -308,6 +282,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
308
282
  task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
309
283
  del self.task_res_store[task_res_id]
310
284
 
285
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
286
+ """Get all TaskIns IDs for the given run_id."""
287
+ task_id_list: set[UUID] = set()
288
+ with self.lock:
289
+ for task_id, task_ins in self.task_ins_store.items():
290
+ if task_ins.run_id == run_id:
291
+ task_id_list.add(task_id)
292
+
293
+ return task_id_list
294
+
311
295
  def num_task_ins(self) -> int:
312
296
  """Calculate the number of task_ins in store.
313
297
 
@@ -139,8 +139,19 @@ class LinkState(abc.ABC): # pylint: disable=R0904
139
139
  """
140
140
 
141
141
  @abc.abstractmethod
142
- def delete_tasks(self, task_ids: set[UUID]) -> None:
143
- """Delete all delivered TaskIns/TaskRes pairs."""
142
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
143
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
144
+
145
+ Parameters
146
+ ----------
147
+ task_ins_ids : set[UUID]
148
+ A set of TaskIns IDs. For each ID in the set, the corresponding
149
+ TaskIns and its associated TaskRes will be deleted.
150
+ """
151
+
152
+ @abc.abstractmethod
153
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
154
+ """Get all TaskIns IDs for the given run_id."""
144
155
 
145
156
  @abc.abstractmethod
146
157
  def create_node(
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """SQLite based implemenation of the link state."""
16
16
 
17
+
17
18
  # pylint: disable=too-many-lines
18
19
 
19
20
  import json
@@ -566,9 +567,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
566
567
  data: list[Any] = [delivered_at] + task_res_ids
567
568
  self.query(query, data)
568
569
 
569
- # Cleanup
570
- self._force_delete_tasks_by_ids(set(ret.keys()))
571
-
572
570
  return list(ret.values())
573
571
 
574
572
  def num_task_ins(self) -> int:
@@ -592,68 +590,50 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
592
590
  result: dict[str, int] = rows[0]
593
591
  return result["num"]
594
592
 
595
- def delete_tasks(self, task_ids: set[UUID]) -> None:
596
- """Delete all delivered TaskIns/TaskRes pairs."""
597
- ids = list(task_ids)
598
- if len(ids) == 0:
599
- return None
593
+ def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
594
+ """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
595
+ if not task_ins_ids:
596
+ return
597
+ if self.conn is None:
598
+ raise AttributeError("LinkState not initialized")
600
599
 
601
- placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
602
- data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
600
+ placeholders = ",".join(["?"] * len(task_ins_ids))
601
+ data = tuple(str(task_id) for task_id in task_ins_ids)
603
602
 
604
- # 1. Query: Delete task_ins which have a delivered task_res
603
+ # Delete task_ins
605
604
  query_1 = f"""
606
605
  DELETE FROM task_ins
607
- WHERE delivered_at != ''
608
- AND task_id IN (
609
- SELECT ancestry
610
- FROM task_res
611
- WHERE ancestry IN ({placeholders})
612
- AND delivered_at != ''
613
- );
606
+ WHERE task_id IN ({placeholders});
614
607
  """
615
608
 
616
- # 2. Query: Delete delivered task_res to be run after 1. Query
609
+ # Delete task_res
617
610
  query_2 = f"""
618
611
  DELETE FROM task_res
619
- WHERE ancestry IN ({placeholders})
620
- AND delivered_at != '';
612
+ WHERE ancestry IN ({placeholders});
621
613
  """
622
614
 
623
- if self.conn is None:
624
- raise AttributeError("LinkState not intitialized")
625
-
626
615
  with self.conn:
627
616
  self.conn.execute(query_1, data)
628
617
  self.conn.execute(query_2, data)
629
618
 
630
- return None
631
-
632
- def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
633
- """Delete tasks based on a set of TaskIns IDs."""
634
- if not task_ids:
635
- return
619
+ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
620
+ """Get all TaskIns IDs for the given run_id."""
636
621
  if self.conn is None:
637
622
  raise AttributeError("LinkState not initialized")
638
623
 
639
- placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
640
- data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
641
-
642
- # Delete task_ins
643
- query_1 = f"""
644
- DELETE FROM task_ins
645
- WHERE task_id IN ({placeholders});
624
+ query = """
625
+ SELECT task_id
626
+ FROM task_ins
627
+ WHERE run_id = :run_id;
646
628
  """
647
629
 
648
- # Delete task_res
649
- query_2 = f"""
650
- DELETE FROM task_res
651
- WHERE ancestry IN ({placeholders});
652
- """
630
+ sint64_run_id = convert_uint64_to_sint64(run_id)
631
+ data = {"run_id": sint64_run_id}
653
632
 
654
633
  with self.conn:
655
- self.conn.execute(query_1, data)
656
- self.conn.execute(query_2, data)
634
+ rows = self.conn.execute(query, data).fetchall()
635
+
636
+ return {UUID(row["task_id"]) for row in rows}
657
637
 
658
638
  def create_node(
659
639
  self, ping_interval: float, public_key: Optional[bytes] = None