flwr-nightly 1.14.0.dev20241204__py3-none-any.whl → 1.14.0.dev20241214__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +5 -0
- flwr/cli/build.py +1 -0
- flwr/cli/cli_user_auth_interceptor.py +86 -0
- flwr/cli/config_utils.py +19 -2
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +1 -0
- flwr/cli/log.py +11 -31
- flwr/cli/login/__init__.py +22 -0
- flwr/cli/login/login.py +83 -0
- flwr/cli/ls.py +10 -40
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +2 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +15 -25
- flwr/cli/stop.py +91 -0
- flwr/cli/utils.py +109 -1
- flwr/client/app.py +3 -2
- flwr/client/client.py +1 -0
- flwr/client/clientapp/app.py +1 -0
- flwr/client/clientapp/utils.py +1 -0
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +3 -3
- flwr/client/message_handler/message_handler.py +1 -0
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/supernode/app.py +1 -0
- flwr/common/address.py +1 -0
- flwr/common/args.py +1 -0
- flwr/common/auth_plugin/__init__.py +24 -0
- flwr/common/auth_plugin/auth_plugin.py +111 -0
- flwr/common/config.py +3 -1
- flwr/common/constant.py +6 -1
- flwr/common/logger.py +1 -0
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/retry_invoker.py +75 -0
- flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
- flwr/common/telemetry.py +2 -1
- flwr/common/typing.py +12 -0
- flwr/common/version.py +1 -0
- flwr/proto/exec_pb2.py +27 -3
- flwr/proto/exec_pb2.pyi +103 -0
- flwr/proto/exec_pb2_grpc.py +102 -0
- flwr/proto/exec_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +4 -4
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +18 -18
- flwr/proto/serverappio_pb2.pyi +8 -2
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +53 -1
- flwr/server/compat/app_utils.py +7 -1
- flwr/server/driver/grpc_driver.py +11 -63
- flwr/server/driver/inmemory_driver.py +5 -1
- flwr/server/serverapp/app.py +9 -2
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +1 -0
- flwr/server/superlink/driver/serverappio_servicer.py +72 -22
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +32 -12
- flwr/server/superlink/fleet/message_handler/message_handler.py +31 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +4 -1
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -30
- flwr/server/superlink/linkstate/linkstate.py +13 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +24 -44
- flwr/server/superlink/simulation/simulationio_servicer.py +20 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/simulation/app.py +1 -0
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +1 -0
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +19 -1
- flwr/superexec/exec_servicer.py +76 -2
- flwr/superexec/exec_user_auth_interceptor.py +101 -0
- flwr/superexec/executor.py +1 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/METADATA +8 -7
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/RECORD +100 -92
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.14.0.dev20241204.dist-info → flwr_nightly-1.14.0.dev20241214.dist-info}/entry_points.txt +0 -0
flwr/server/serverapp/app.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower ServerApp process."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
import argparse
|
|
18
19
|
import sys
|
|
19
20
|
from logging import DEBUG, ERROR, INFO
|
|
@@ -50,7 +51,7 @@ from flwr.common.serde import (
|
|
|
50
51
|
run_from_proto,
|
|
51
52
|
run_status_to_proto,
|
|
52
53
|
)
|
|
53
|
-
from flwr.common.typing import RunStatus
|
|
54
|
+
from flwr.common.typing import RunNotRunningException, RunStatus
|
|
54
55
|
from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
|
|
55
56
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
56
57
|
PullServerAppInputsRequest,
|
|
@@ -96,7 +97,7 @@ def flwr_serverapp() -> None:
|
|
|
96
97
|
restore_output()
|
|
97
98
|
|
|
98
99
|
|
|
99
|
-
def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
100
|
+
def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
100
101
|
serverappio_api_address: str,
|
|
101
102
|
log_queue: Queue[Optional[str]],
|
|
102
103
|
run_once: bool,
|
|
@@ -187,6 +188,12 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
187
188
|
|
|
188
189
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
|
189
190
|
|
|
191
|
+
except RunNotRunningException:
|
|
192
|
+
log(INFO, "")
|
|
193
|
+
log(INFO, "Run ID %s stopped.", run.run_id)
|
|
194
|
+
log(INFO, "")
|
|
195
|
+
run_status = None
|
|
196
|
+
|
|
190
197
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
|
191
198
|
exc_entity = "ServerApp"
|
|
192
199
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
|
@@ -32,6 +32,7 @@ from flwr.common.serde import (
|
|
|
32
32
|
fab_from_proto,
|
|
33
33
|
fab_to_proto,
|
|
34
34
|
run_status_from_proto,
|
|
35
|
+
run_status_to_proto,
|
|
35
36
|
run_to_proto,
|
|
36
37
|
user_config_from_proto,
|
|
37
38
|
)
|
|
@@ -48,6 +49,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
48
49
|
CreateRunResponse,
|
|
49
50
|
GetRunRequest,
|
|
50
51
|
GetRunResponse,
|
|
52
|
+
GetRunStatusRequest,
|
|
53
|
+
GetRunStatusResponse,
|
|
51
54
|
UpdateRunStatusRequest,
|
|
52
55
|
UpdateRunStatusResponse,
|
|
53
56
|
)
|
|
@@ -67,6 +70,7 @@ from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
|
67
70
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
68
71
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
69
72
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
73
|
+
from flwr.server.superlink.utils import abort_if
|
|
70
74
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
71
75
|
|
|
72
76
|
|
|
@@ -85,7 +89,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
85
89
|
) -> GetNodesResponse:
|
|
86
90
|
"""Get available nodes."""
|
|
87
91
|
log(DEBUG, "ServerAppIoServicer.GetNodes")
|
|
92
|
+
|
|
93
|
+
# Init state
|
|
88
94
|
state: LinkState = self.state_factory.state()
|
|
95
|
+
|
|
96
|
+
# Abort if the run is not running
|
|
97
|
+
abort_if(
|
|
98
|
+
request.run_id,
|
|
99
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
100
|
+
state,
|
|
101
|
+
context,
|
|
102
|
+
)
|
|
103
|
+
|
|
89
104
|
all_ids: set[int] = state.get_nodes(request.run_id)
|
|
90
105
|
nodes: list[Node] = [
|
|
91
106
|
Node(node_id=node_id, anonymous=False) for node_id in all_ids
|
|
@@ -123,6 +138,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
123
138
|
"""Push a set of TaskIns."""
|
|
124
139
|
log(DEBUG, "ServerAppIoServicer.PushTaskIns")
|
|
125
140
|
|
|
141
|
+
# Init state
|
|
142
|
+
state: LinkState = self.state_factory.state()
|
|
143
|
+
|
|
144
|
+
# Abort if the run is not running
|
|
145
|
+
abort_if(
|
|
146
|
+
request.run_id,
|
|
147
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
148
|
+
state,
|
|
149
|
+
context,
|
|
150
|
+
)
|
|
151
|
+
|
|
126
152
|
# Set pushed_at (timestamp in seconds)
|
|
127
153
|
pushed_at = time.time()
|
|
128
154
|
for task_ins in request.task_ins_list:
|
|
@@ -134,9 +160,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
134
160
|
validation_errors = validate_task_ins_or_res(task_ins)
|
|
135
161
|
_raise_if(bool(validation_errors), ", ".join(validation_errors))
|
|
136
162
|
|
|
137
|
-
# Init state
|
|
138
|
-
state: LinkState = self.state_factory.state()
|
|
139
|
-
|
|
140
163
|
# Store each TaskIns
|
|
141
164
|
task_ids: list[Optional[UUID]] = []
|
|
142
165
|
for task_ins in request.task_ins_list:
|
|
@@ -153,33 +176,29 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
153
176
|
"""Pull a set of TaskRes."""
|
|
154
177
|
log(DEBUG, "ServerAppIoServicer.PullTaskRes")
|
|
155
178
|
|
|
156
|
-
# Convert each task_id str to UUID
|
|
157
|
-
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
158
|
-
|
|
159
179
|
# Init state
|
|
160
180
|
state: LinkState = self.state_factory.state()
|
|
161
181
|
|
|
162
|
-
#
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
if context.is_active():
|
|
170
|
-
return
|
|
171
|
-
if context.code() != grpc.StatusCode.OK:
|
|
172
|
-
return
|
|
173
|
-
|
|
174
|
-
# Delete delivered TaskIns and TaskRes
|
|
175
|
-
state.delete_tasks(task_ids=task_ids)
|
|
182
|
+
# Abort if the run is not running
|
|
183
|
+
abort_if(
|
|
184
|
+
request.run_id,
|
|
185
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
186
|
+
state,
|
|
187
|
+
context,
|
|
188
|
+
)
|
|
176
189
|
|
|
177
|
-
|
|
190
|
+
# Convert each task_id str to UUID
|
|
191
|
+
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
178
192
|
|
|
179
193
|
# Read from state
|
|
180
194
|
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
|
|
181
195
|
|
|
182
|
-
|
|
196
|
+
# Delete the TaskIns/TaskRes pairs if TaskRes is found
|
|
197
|
+
task_ins_ids_to_delete = {
|
|
198
|
+
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
|
199
|
+
}
|
|
200
|
+
state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
|
|
201
|
+
|
|
183
202
|
return PullTaskResResponse(task_res_list=task_res_list)
|
|
184
203
|
|
|
185
204
|
def GetRun(
|
|
@@ -255,7 +274,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
255
274
|
) -> PushServerAppOutputsResponse:
|
|
256
275
|
"""Push ServerApp process outputs."""
|
|
257
276
|
log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
|
|
277
|
+
|
|
278
|
+
# Init state
|
|
258
279
|
state = self.state_factory.state()
|
|
280
|
+
|
|
281
|
+
# Abort if the run is not running
|
|
282
|
+
abort_if(
|
|
283
|
+
request.run_id,
|
|
284
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
285
|
+
state,
|
|
286
|
+
context,
|
|
287
|
+
)
|
|
288
|
+
|
|
259
289
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
|
260
290
|
return PushServerAppOutputsResponse()
|
|
261
291
|
|
|
@@ -264,8 +294,13 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
264
294
|
) -> UpdateRunStatusResponse:
|
|
265
295
|
"""Update the status of a run."""
|
|
266
296
|
log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
|
|
297
|
+
|
|
298
|
+
# Init state
|
|
267
299
|
state = self.state_factory.state()
|
|
268
300
|
|
|
301
|
+
# Abort if the run is finished
|
|
302
|
+
abort_if(request.run_id, [Status.FINISHED], state, context)
|
|
303
|
+
|
|
269
304
|
# Update the run status
|
|
270
305
|
state.update_run_status(
|
|
271
306
|
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
|
|
@@ -284,6 +319,21 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
284
319
|
state.add_serverapp_log(request.run_id, merged_logs)
|
|
285
320
|
return PushLogsResponse()
|
|
286
321
|
|
|
322
|
+
def GetRunStatus(
|
|
323
|
+
self, request: GetRunStatusRequest, context: grpc.ServicerContext
|
|
324
|
+
) -> GetRunStatusResponse:
|
|
325
|
+
"""Get the status of a run."""
|
|
326
|
+
log(DEBUG, "ServerAppIoServicer.GetRunStatus")
|
|
327
|
+
state = self.state_factory.state()
|
|
328
|
+
|
|
329
|
+
# Get run status from LinkState
|
|
330
|
+
run_statuses = state.get_run_status(set(request.run_ids))
|
|
331
|
+
run_status_dict = {
|
|
332
|
+
run_id: run_status_to_proto(run_status)
|
|
333
|
+
for run_id, run_status in run_statuses.items()
|
|
334
|
+
}
|
|
335
|
+
return GetRunStatusResponse(run_status_dict=run_status_dict)
|
|
336
|
+
|
|
287
337
|
|
|
288
338
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
289
339
|
if validation_error:
|
|
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO
|
|
|
20
20
|
import grpc
|
|
21
21
|
|
|
22
22
|
from flwr.common.logger import log
|
|
23
|
+
from flwr.common.typing import InvalidRunStatusException
|
|
23
24
|
from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
|
|
24
25
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
25
26
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
@@ -38,6 +39,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
|
|
|
38
39
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
39
40
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
40
41
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
42
|
+
from flwr.server.superlink.utils import abort_grpc_context
|
|
41
43
|
|
|
42
44
|
|
|
43
45
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
@@ -105,27 +107,45 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
105
107
|
)
|
|
106
108
|
else:
|
|
107
109
|
log(INFO, "[Fleet.PushTaskRes] No task results to push")
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
110
|
+
|
|
111
|
+
try:
|
|
112
|
+
res = message_handler.push_task_res(
|
|
113
|
+
request=request,
|
|
114
|
+
state=self.state_factory.state(),
|
|
115
|
+
)
|
|
116
|
+
except InvalidRunStatusException as e:
|
|
117
|
+
abort_grpc_context(e.message, context)
|
|
118
|
+
|
|
119
|
+
return res
|
|
112
120
|
|
|
113
121
|
def GetRun(
|
|
114
122
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
115
123
|
) -> GetRunResponse:
|
|
116
124
|
"""Get run information."""
|
|
117
125
|
log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
126
|
+
|
|
127
|
+
try:
|
|
128
|
+
res = message_handler.get_run(
|
|
129
|
+
request=request,
|
|
130
|
+
state=self.state_factory.state(),
|
|
131
|
+
)
|
|
132
|
+
except InvalidRunStatusException as e:
|
|
133
|
+
abort_grpc_context(e.message, context)
|
|
134
|
+
|
|
135
|
+
return res
|
|
122
136
|
|
|
123
137
|
def GetFab(
|
|
124
138
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
125
139
|
) -> GetFabResponse:
|
|
126
140
|
"""Get FAB."""
|
|
127
141
|
log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
142
|
+
try:
|
|
143
|
+
res = message_handler.get_fab(
|
|
144
|
+
request=request,
|
|
145
|
+
ffs=self.ffs_factory.ffs(),
|
|
146
|
+
state=self.state_factory.state(),
|
|
147
|
+
)
|
|
148
|
+
except InvalidRunStatusException as e:
|
|
149
|
+
abort_grpc_context(e.message, context)
|
|
150
|
+
|
|
151
|
+
return res
|
|
@@ -19,8 +19,9 @@ import time
|
|
|
19
19
|
from typing import Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
+
from flwr.common.constant import Status
|
|
22
23
|
from flwr.common.serde import fab_to_proto, user_config_to_proto
|
|
23
|
-
from flwr.common.typing import Fab
|
|
24
|
+
from flwr.common.typing import Fab, InvalidRunStatusException
|
|
24
25
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
25
26
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
26
27
|
CreateNodeRequest,
|
|
@@ -44,6 +45,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
44
45
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
45
46
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
46
47
|
from flwr.server.superlink.linkstate import LinkState
|
|
48
|
+
from flwr.server.superlink.utils import check_abort
|
|
47
49
|
|
|
48
50
|
|
|
49
51
|
def create_node(
|
|
@@ -98,6 +100,15 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
|
|
|
98
100
|
task_res: TaskRes = request.task_res_list[0]
|
|
99
101
|
# pylint: enable=no-member
|
|
100
102
|
|
|
103
|
+
# Abort if the run is not running
|
|
104
|
+
abort_msg = check_abort(
|
|
105
|
+
task_res.run_id,
|
|
106
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
107
|
+
state,
|
|
108
|
+
)
|
|
109
|
+
if abort_msg:
|
|
110
|
+
raise InvalidRunStatusException(abort_msg)
|
|
111
|
+
|
|
101
112
|
# Set pushed_at (timestamp in seconds)
|
|
102
113
|
task_res.task.pushed_at = time.time()
|
|
103
114
|
|
|
@@ -121,6 +132,15 @@ def get_run(
|
|
|
121
132
|
if run is None:
|
|
122
133
|
return GetRunResponse()
|
|
123
134
|
|
|
135
|
+
# Abort if the run is not running
|
|
136
|
+
abort_msg = check_abort(
|
|
137
|
+
request.run_id,
|
|
138
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
139
|
+
state,
|
|
140
|
+
)
|
|
141
|
+
if abort_msg:
|
|
142
|
+
raise InvalidRunStatusException(abort_msg)
|
|
143
|
+
|
|
124
144
|
return GetRunResponse(
|
|
125
145
|
run=Run(
|
|
126
146
|
run_id=run.run_id,
|
|
@@ -133,9 +153,18 @@ def get_run(
|
|
|
133
153
|
|
|
134
154
|
|
|
135
155
|
def get_fab(
|
|
136
|
-
request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
|
|
156
|
+
request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
|
|
137
157
|
) -> GetFabResponse:
|
|
138
158
|
"""Get FAB."""
|
|
159
|
+
# Abort if the run is not running
|
|
160
|
+
abort_msg = check_abort(
|
|
161
|
+
request.run_id,
|
|
162
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
|
163
|
+
state,
|
|
164
|
+
)
|
|
165
|
+
if abort_msg:
|
|
166
|
+
raise InvalidRunStatusException(abort_msg)
|
|
167
|
+
|
|
139
168
|
if result := ffs.get(request.hash_str):
|
|
140
169
|
fab = Fab(request.hash_str, result[0])
|
|
141
170
|
return GetFabResponse(fab=fab_to_proto(fab))
|
|
@@ -154,8 +154,11 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
|
154
154
|
# Get ffs from app
|
|
155
155
|
ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
|
|
156
156
|
|
|
157
|
+
# Get state from app
|
|
158
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
159
|
+
|
|
157
160
|
# Handle message
|
|
158
|
-
return message_handler.get_fab(request=request, ffs=ffs)
|
|
161
|
+
return message_handler.get_fab(request=request, ffs=ffs, state=state)
|
|
159
162
|
|
|
160
163
|
|
|
161
164
|
routes = [
|
|
@@ -265,41 +265,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
265
265
|
for task_res in task_res_found:
|
|
266
266
|
task_res.task.delivered_at = delivered_at
|
|
267
267
|
|
|
268
|
-
# Cleanup
|
|
269
|
-
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
270
|
-
|
|
271
268
|
return list(ret.values())
|
|
272
269
|
|
|
273
|
-
def delete_tasks(self,
|
|
274
|
-
"""Delete
|
|
275
|
-
|
|
276
|
-
task_res_to_be_deleted: set[UUID] = set()
|
|
277
|
-
|
|
278
|
-
with self.lock:
|
|
279
|
-
for task_ins_id in task_ids:
|
|
280
|
-
# Find the task_id of the matching task_res
|
|
281
|
-
for task_res_id, task_res in self.task_res_store.items():
|
|
282
|
-
if UUID(task_res.task.ancestry[0]) != task_ins_id:
|
|
283
|
-
continue
|
|
284
|
-
if task_res.task.delivered_at == "":
|
|
285
|
-
continue
|
|
286
|
-
|
|
287
|
-
task_ins_to_be_deleted.add(task_ins_id)
|
|
288
|
-
task_res_to_be_deleted.add(task_res_id)
|
|
289
|
-
|
|
290
|
-
for task_id in task_ins_to_be_deleted:
|
|
291
|
-
del self.task_ins_store[task_id]
|
|
292
|
-
del self.task_ins_id_to_task_res_id[task_id]
|
|
293
|
-
for task_id in task_res_to_be_deleted:
|
|
294
|
-
del self.task_res_store[task_id]
|
|
295
|
-
|
|
296
|
-
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
297
|
-
"""Delete tasks based on a set of TaskIns IDs."""
|
|
298
|
-
if not task_ids:
|
|
270
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
271
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
|
272
|
+
if not task_ins_ids:
|
|
299
273
|
return
|
|
300
274
|
|
|
301
275
|
with self.lock:
|
|
302
|
-
for task_id in
|
|
276
|
+
for task_id in task_ins_ids:
|
|
303
277
|
# Delete TaskIns
|
|
304
278
|
if task_id in self.task_ins_store:
|
|
305
279
|
del self.task_ins_store[task_id]
|
|
@@ -308,6 +282,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
308
282
|
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
|
|
309
283
|
del self.task_res_store[task_res_id]
|
|
310
284
|
|
|
285
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
286
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
287
|
+
task_id_list: set[UUID] = set()
|
|
288
|
+
with self.lock:
|
|
289
|
+
for task_id, task_ins in self.task_ins_store.items():
|
|
290
|
+
if task_ins.run_id == run_id:
|
|
291
|
+
task_id_list.add(task_id)
|
|
292
|
+
|
|
293
|
+
return task_id_list
|
|
294
|
+
|
|
311
295
|
def num_task_ins(self) -> int:
|
|
312
296
|
"""Calculate the number of task_ins in store.
|
|
313
297
|
|
|
@@ -139,8 +139,19 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
139
139
|
"""
|
|
140
140
|
|
|
141
141
|
@abc.abstractmethod
|
|
142
|
-
def delete_tasks(self,
|
|
143
|
-
"""Delete
|
|
142
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
143
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
|
|
144
|
+
|
|
145
|
+
Parameters
|
|
146
|
+
----------
|
|
147
|
+
task_ins_ids : set[UUID]
|
|
148
|
+
A set of TaskIns IDs. For each ID in the set, the corresponding
|
|
149
|
+
TaskIns and its associated TaskRes will be deleted.
|
|
150
|
+
"""
|
|
151
|
+
|
|
152
|
+
@abc.abstractmethod
|
|
153
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
154
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
144
155
|
|
|
145
156
|
@abc.abstractmethod
|
|
146
157
|
def create_node(
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""SQLite based implemenation of the link state."""
|
|
16
16
|
|
|
17
|
+
|
|
17
18
|
# pylint: disable=too-many-lines
|
|
18
19
|
|
|
19
20
|
import json
|
|
@@ -566,9 +567,6 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
566
567
|
data: list[Any] = [delivered_at] + task_res_ids
|
|
567
568
|
self.query(query, data)
|
|
568
569
|
|
|
569
|
-
# Cleanup
|
|
570
|
-
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
571
|
-
|
|
572
570
|
return list(ret.values())
|
|
573
571
|
|
|
574
572
|
def num_task_ins(self) -> int:
|
|
@@ -592,68 +590,50 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
592
590
|
result: dict[str, int] = rows[0]
|
|
593
591
|
return result["num"]
|
|
594
592
|
|
|
595
|
-
def delete_tasks(self,
|
|
596
|
-
"""Delete
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
593
|
+
def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
|
|
594
|
+
"""Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
|
|
595
|
+
if not task_ins_ids:
|
|
596
|
+
return
|
|
597
|
+
if self.conn is None:
|
|
598
|
+
raise AttributeError("LinkState not initialized")
|
|
600
599
|
|
|
601
|
-
placeholders = ",".join([
|
|
602
|
-
data =
|
|
600
|
+
placeholders = ",".join(["?"] * len(task_ins_ids))
|
|
601
|
+
data = tuple(str(task_id) for task_id in task_ins_ids)
|
|
603
602
|
|
|
604
|
-
#
|
|
603
|
+
# Delete task_ins
|
|
605
604
|
query_1 = f"""
|
|
606
605
|
DELETE FROM task_ins
|
|
607
|
-
WHERE
|
|
608
|
-
AND task_id IN (
|
|
609
|
-
SELECT ancestry
|
|
610
|
-
FROM task_res
|
|
611
|
-
WHERE ancestry IN ({placeholders})
|
|
612
|
-
AND delivered_at != ''
|
|
613
|
-
);
|
|
606
|
+
WHERE task_id IN ({placeholders});
|
|
614
607
|
"""
|
|
615
608
|
|
|
616
|
-
#
|
|
609
|
+
# Delete task_res
|
|
617
610
|
query_2 = f"""
|
|
618
611
|
DELETE FROM task_res
|
|
619
|
-
WHERE ancestry IN ({placeholders})
|
|
620
|
-
AND delivered_at != '';
|
|
612
|
+
WHERE ancestry IN ({placeholders});
|
|
621
613
|
"""
|
|
622
614
|
|
|
623
|
-
if self.conn is None:
|
|
624
|
-
raise AttributeError("LinkState not intitialized")
|
|
625
|
-
|
|
626
615
|
with self.conn:
|
|
627
616
|
self.conn.execute(query_1, data)
|
|
628
617
|
self.conn.execute(query_2, data)
|
|
629
618
|
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
633
|
-
"""Delete tasks based on a set of TaskIns IDs."""
|
|
634
|
-
if not task_ids:
|
|
635
|
-
return
|
|
619
|
+
def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
|
|
620
|
+
"""Get all TaskIns IDs for the given run_id."""
|
|
636
621
|
if self.conn is None:
|
|
637
622
|
raise AttributeError("LinkState not initialized")
|
|
638
623
|
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
query_1 = f"""
|
|
644
|
-
DELETE FROM task_ins
|
|
645
|
-
WHERE task_id IN ({placeholders});
|
|
624
|
+
query = """
|
|
625
|
+
SELECT task_id
|
|
626
|
+
FROM task_ins
|
|
627
|
+
WHERE run_id = :run_id;
|
|
646
628
|
"""
|
|
647
629
|
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
DELETE FROM task_res
|
|
651
|
-
WHERE ancestry IN ({placeholders});
|
|
652
|
-
"""
|
|
630
|
+
sint64_run_id = convert_uint64_to_sint64(run_id)
|
|
631
|
+
data = {"run_id": sint64_run_id}
|
|
653
632
|
|
|
654
633
|
with self.conn:
|
|
655
|
-
self.conn.execute(
|
|
656
|
-
|
|
634
|
+
rows = self.conn.execute(query, data).fetchall()
|
|
635
|
+
|
|
636
|
+
return {UUID(row["task_id"]) for row in rows}
|
|
657
637
|
|
|
658
638
|
def create_node(
|
|
659
639
|
self, ping_interval: float, public_key: Optional[bytes] = None
|