flwr-nightly 1.13.0.dev20241111__py3-none-any.whl → 1.13.0.dev20241117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +37 -0
- flwr/cli/install.py +5 -3
- flwr/cli/ls.py +228 -0
- flwr/client/app.py +58 -13
- flwr/client/clientapp/app.py +34 -23
- flwr/client/grpc_rere_client/connection.py +2 -12
- flwr/client/rest_client/connection.py +4 -14
- flwr/client/supernode/app.py +57 -53
- flwr/common/args.py +72 -7
- flwr/common/constant.py +21 -6
- flwr/common/date.py +18 -0
- flwr/common/serde.py +10 -0
- flwr/common/typing.py +31 -10
- flwr/proto/exec_pb2.py +22 -13
- flwr/proto/exec_pb2.pyi +44 -0
- flwr/proto/exec_pb2_grpc.py +34 -0
- flwr/proto/exec_pb2_grpc.pyi +13 -0
- flwr/proto/run_pb2.py +30 -30
- flwr/proto/run_pb2.pyi +18 -1
- flwr/server/app.py +39 -68
- flwr/server/driver/grpc_driver.py +4 -14
- flwr/server/run_serverapp.py +8 -238
- flwr/server/serverapp/app.py +34 -23
- flwr/server/superlink/fleet/rest_rere/rest_api.py +10 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +71 -46
- flwr/server/superlink/linkstate/linkstate.py +19 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -113
- flwr/server/superlink/linkstate/utils.py +193 -3
- flwr/simulation/app.py +6 -41
- flwr/simulation/legacy_app.py +21 -1
- flwr/simulation/run_simulation.py +7 -18
- flwr/simulation/simulationio_connection.py +2 -2
- flwr/superexec/deployment.py +12 -6
- flwr/superexec/exec_servicer.py +31 -2
- flwr/superexec/simulation.py +11 -46
- {flwr_nightly-1.13.0.dev20241111.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/METADATA +6 -4
- {flwr_nightly-1.13.0.dev20241111.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/RECORD +41 -40
- {flwr_nightly-1.13.0.dev20241111.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241111.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.13.0.dev20241111.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/entry_points.txt +0 -0
flwr/server/serverapp/app.py
CHANGED
|
@@ -23,14 +23,18 @@ from typing import Optional
|
|
|
23
23
|
|
|
24
24
|
from flwr.cli.config_utils import get_fab_metadata
|
|
25
25
|
from flwr.cli.install import install_from_fab
|
|
26
|
-
from flwr.common.args import add_args_flwr_app_common,
|
|
26
|
+
from flwr.common.args import add_args_flwr_app_common, try_obtain_root_certificates
|
|
27
27
|
from flwr.common.config import (
|
|
28
28
|
get_flwr_dir,
|
|
29
29
|
get_fused_config_from_dir,
|
|
30
30
|
get_project_config,
|
|
31
31
|
get_project_dir,
|
|
32
32
|
)
|
|
33
|
-
from flwr.common.constant import
|
|
33
|
+
from flwr.common.constant import (
|
|
34
|
+
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
|
35
|
+
Status,
|
|
36
|
+
SubStatus,
|
|
37
|
+
)
|
|
34
38
|
from flwr.common.logger import (
|
|
35
39
|
log,
|
|
36
40
|
mirror_output_to_queue,
|
|
@@ -62,33 +66,18 @@ def flwr_serverapp() -> None:
|
|
|
62
66
|
log_queue: Queue[Optional[str]] = Queue()
|
|
63
67
|
mirror_output_to_queue(log_queue)
|
|
64
68
|
|
|
65
|
-
|
|
66
|
-
description="Run a Flower ServerApp",
|
|
67
|
-
)
|
|
68
|
-
parser.add_argument(
|
|
69
|
-
"--superlink",
|
|
70
|
-
type=str,
|
|
71
|
-
help="Address of SuperLink's ServerAppIo API",
|
|
72
|
-
)
|
|
73
|
-
parser.add_argument(
|
|
74
|
-
"--run-once",
|
|
75
|
-
action="store_true",
|
|
76
|
-
help="When set, this process will start a single ServerApp for a pending Run. "
|
|
77
|
-
"If there is no pending Run, the process will exit.",
|
|
78
|
-
)
|
|
79
|
-
add_args_flwr_app_common(parser=parser)
|
|
80
|
-
args = parser.parse_args()
|
|
69
|
+
args = _parse_args_run_flwr_serverapp().parse_args()
|
|
81
70
|
|
|
82
71
|
log(INFO, "Starting Flower ServerApp")
|
|
83
|
-
certificates =
|
|
72
|
+
certificates = try_obtain_root_certificates(args, args.serverappio_api_address)
|
|
84
73
|
|
|
85
74
|
log(
|
|
86
75
|
DEBUG,
|
|
87
76
|
"Starting isolated `ServerApp` connected to SuperLink's ServerAppIo API at %s",
|
|
88
|
-
args.
|
|
77
|
+
args.serverappio_api_address,
|
|
89
78
|
)
|
|
90
79
|
run_serverapp(
|
|
91
|
-
|
|
80
|
+
serverappio_api_address=args.serverappio_api_address,
|
|
92
81
|
log_queue=log_queue,
|
|
93
82
|
run_once=args.run_once,
|
|
94
83
|
flwr_dir=args.flwr_dir,
|
|
@@ -100,7 +89,7 @@ def flwr_serverapp() -> None:
|
|
|
100
89
|
|
|
101
90
|
|
|
102
91
|
def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
103
|
-
|
|
92
|
+
serverappio_api_address: str,
|
|
104
93
|
log_queue: Queue[Optional[str]],
|
|
105
94
|
run_once: bool,
|
|
106
95
|
flwr_dir: Optional[str] = None,
|
|
@@ -108,7 +97,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
108
97
|
) -> None:
|
|
109
98
|
"""Run Flower ServerApp process."""
|
|
110
99
|
driver = GrpcDriver(
|
|
111
|
-
serverappio_service_address=
|
|
100
|
+
serverappio_service_address=serverappio_api_address,
|
|
112
101
|
root_certificates=certificates,
|
|
113
102
|
)
|
|
114
103
|
|
|
@@ -212,3 +201,25 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
|
|
|
212
201
|
# Stop the loop if `flwr-serverapp` is expected to process a single run
|
|
213
202
|
if run_once:
|
|
214
203
|
break
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _parse_args_run_flwr_serverapp() -> argparse.ArgumentParser:
|
|
207
|
+
"""Parse flwr-serverapp command line arguments."""
|
|
208
|
+
parser = argparse.ArgumentParser(
|
|
209
|
+
description="Run a Flower ServerApp",
|
|
210
|
+
)
|
|
211
|
+
parser.add_argument(
|
|
212
|
+
"--serverappio-api-address",
|
|
213
|
+
default=SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
|
214
|
+
type=str,
|
|
215
|
+
help="Address of SuperLink's ServerAppIo API (IPv4, IPv6, or a domain name)."
|
|
216
|
+
f"By default, it is set to {SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS}.",
|
|
217
|
+
)
|
|
218
|
+
parser.add_argument(
|
|
219
|
+
"--run-once",
|
|
220
|
+
action="store_true",
|
|
221
|
+
help="When set, this process will start a single ServerApp for a pending Run. "
|
|
222
|
+
"If there is no pending Run, the process will exit.",
|
|
223
|
+
)
|
|
224
|
+
add_args_flwr_app_common(parser=parser)
|
|
225
|
+
return parser
|
|
@@ -19,7 +19,7 @@ from __future__ import annotations
|
|
|
19
19
|
|
|
20
20
|
import sys
|
|
21
21
|
from collections.abc import Awaitable
|
|
22
|
-
from typing import Callable, TypeVar
|
|
22
|
+
from typing import Callable, TypeVar, cast
|
|
23
23
|
|
|
24
24
|
from google.protobuf.message import Message as GrpcMessage
|
|
25
25
|
|
|
@@ -39,8 +39,9 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
39
39
|
)
|
|
40
40
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
41
41
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
42
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
42
43
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
43
|
-
from flwr.server.superlink.linkstate import LinkState
|
|
44
|
+
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
44
45
|
|
|
45
46
|
try:
|
|
46
47
|
from starlette.applications import Starlette
|
|
@@ -90,7 +91,7 @@ def rest_request_response(
|
|
|
90
91
|
async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
91
92
|
"""Create Node."""
|
|
92
93
|
# Get state from app
|
|
93
|
-
state: LinkState = app.state.STATE_FACTORY.state()
|
|
94
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
94
95
|
|
|
95
96
|
# Handle message
|
|
96
97
|
return message_handler.create_node(request=request, state=state)
|
|
@@ -100,7 +101,7 @@ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
|
|
|
100
101
|
async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
101
102
|
"""Delete Node Id."""
|
|
102
103
|
# Get state from app
|
|
103
|
-
state: LinkState = app.state.STATE_FACTORY.state()
|
|
104
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
104
105
|
|
|
105
106
|
# Handle message
|
|
106
107
|
return message_handler.delete_node(request=request, state=state)
|
|
@@ -110,7 +111,7 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
|
110
111
|
async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
111
112
|
"""Pull TaskIns."""
|
|
112
113
|
# Get state from app
|
|
113
|
-
state: LinkState = app.state.STATE_FACTORY.state()
|
|
114
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
114
115
|
|
|
115
116
|
# Handle message
|
|
116
117
|
return message_handler.pull_task_ins(request=request, state=state)
|
|
@@ -121,7 +122,7 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
|
121
122
|
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
122
123
|
"""Push TaskRes."""
|
|
123
124
|
# Get state from app
|
|
124
|
-
state: LinkState = app.state.STATE_FACTORY.state()
|
|
125
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
125
126
|
|
|
126
127
|
# Handle message
|
|
127
128
|
return message_handler.push_task_res(request=request, state=state)
|
|
@@ -131,7 +132,7 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
|
131
132
|
async def ping(request: PingRequest) -> PingResponse:
|
|
132
133
|
"""Ping."""
|
|
133
134
|
# Get state from app
|
|
134
|
-
state: LinkState = app.state.STATE_FACTORY.state()
|
|
135
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
135
136
|
|
|
136
137
|
# Handle message
|
|
137
138
|
return message_handler.ping(request=request, state=state)
|
|
@@ -141,7 +142,7 @@ async def ping(request: PingRequest) -> PingResponse:
|
|
|
141
142
|
async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
142
143
|
"""GetRun."""
|
|
143
144
|
# Get state from app
|
|
144
|
-
state: LinkState = app.state.STATE_FACTORY.state()
|
|
145
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
145
146
|
|
|
146
147
|
# Handle message
|
|
147
148
|
return message_handler.get_run(request=request, state=state)
|
|
@@ -151,7 +152,7 @@ async def get_run(request: GetRunRequest) -> GetRunResponse:
|
|
|
151
152
|
async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
152
153
|
"""GetRun."""
|
|
153
154
|
# Get ffs from app
|
|
154
|
-
ffs: Ffs = app.state.FFS_FACTORY.
|
|
155
|
+
ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()
|
|
155
156
|
|
|
156
157
|
# Handle message
|
|
157
158
|
return message_handler.get_fab(request=request, ffs=ffs)
|
|
@@ -40,6 +40,8 @@ from .utils import (
|
|
|
40
40
|
generate_rand_int_from_bytes,
|
|
41
41
|
has_valid_sub_status,
|
|
42
42
|
is_valid_transition,
|
|
43
|
+
verify_found_taskres,
|
|
44
|
+
verify_taskins_ids,
|
|
43
45
|
)
|
|
44
46
|
|
|
45
47
|
|
|
@@ -48,11 +50,6 @@ class RunRecord: # pylint: disable=R0902
|
|
|
48
50
|
"""The record of a specific run, including its status and timestamps."""
|
|
49
51
|
|
|
50
52
|
run: Run
|
|
51
|
-
status: RunStatus
|
|
52
|
-
pending_at: str = ""
|
|
53
|
-
starting_at: str = ""
|
|
54
|
-
running_at: str = ""
|
|
55
|
-
finished_at: str = ""
|
|
56
53
|
logs: list[tuple[float, str]] = field(default_factory=list)
|
|
57
54
|
log_lock: threading.Lock = field(default_factory=threading.Lock)
|
|
58
55
|
|
|
@@ -72,12 +69,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
72
69
|
self.federation_options: dict[int, ConfigsRecord] = {}
|
|
73
70
|
self.task_ins_store: dict[UUID, TaskIns] = {}
|
|
74
71
|
self.task_res_store: dict[UUID, TaskRes] = {}
|
|
72
|
+
self.task_ins_id_to_task_res_id: dict[UUID, UUID] = {}
|
|
75
73
|
|
|
76
74
|
self.node_public_keys: set[bytes] = set()
|
|
77
75
|
self.server_public_key: Optional[bytes] = None
|
|
78
76
|
self.server_private_key: Optional[bytes] = None
|
|
79
77
|
|
|
80
|
-
self.lock = threading.
|
|
78
|
+
self.lock = threading.RLock()
|
|
81
79
|
|
|
82
80
|
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
|
83
81
|
"""Store one TaskIns."""
|
|
@@ -227,42 +225,50 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
227
225
|
task_res.task_id = str(task_id)
|
|
228
226
|
with self.lock:
|
|
229
227
|
self.task_res_store[task_id] = task_res
|
|
228
|
+
self.task_ins_id_to_task_res_id[UUID(task_ins_id)] = task_id
|
|
230
229
|
|
|
231
230
|
# Return the new task_id
|
|
232
231
|
return task_id
|
|
233
232
|
|
|
234
233
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
235
|
-
"""Get
|
|
234
|
+
"""Get TaskRes for the given TaskIns IDs."""
|
|
235
|
+
ret: dict[UUID, TaskRes] = {}
|
|
236
|
+
|
|
236
237
|
with self.lock:
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
if
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
238
|
+
current = time.time()
|
|
239
|
+
|
|
240
|
+
# Verify TaskIns IDs
|
|
241
|
+
ret = verify_taskins_ids(
|
|
242
|
+
inquired_taskins_ids=task_ids,
|
|
243
|
+
found_taskins_dict=self.task_ins_store,
|
|
244
|
+
current_time=current,
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
# Find all TaskRes
|
|
248
|
+
task_res_found: list[TaskRes] = []
|
|
249
|
+
for task_id in task_ids:
|
|
250
|
+
# If TaskRes exists and is not delivered, add it to the list
|
|
251
|
+
if task_res_id := self.task_ins_id_to_task_res_id.get(task_id):
|
|
252
|
+
task_res = self.task_res_store[task_res_id]
|
|
253
|
+
if task_res.task.delivered_at == "":
|
|
254
|
+
task_res_found.append(task_res)
|
|
255
|
+
tmp_ret_dict = verify_found_taskres(
|
|
256
|
+
inquired_taskins_ids=task_ids,
|
|
257
|
+
found_taskins_dict=self.task_ins_store,
|
|
258
|
+
found_taskres_list=task_res_found,
|
|
259
|
+
current_time=current,
|
|
260
|
+
)
|
|
261
|
+
ret.update(tmp_ret_dict)
|
|
262
|
+
|
|
263
|
+
# Mark existing TaskRes to be returned as delivered
|
|
260
264
|
delivered_at = now().isoformat()
|
|
261
|
-
for task_res in
|
|
265
|
+
for task_res in task_res_found:
|
|
262
266
|
task_res.task.delivered_at = delivered_at
|
|
263
267
|
|
|
264
|
-
#
|
|
265
|
-
|
|
268
|
+
# Cleanup
|
|
269
|
+
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
270
|
+
|
|
271
|
+
return list(ret.values())
|
|
266
272
|
|
|
267
273
|
def delete_tasks(self, task_ids: set[UUID]) -> None:
|
|
268
274
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
|
@@ -283,9 +289,25 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
283
289
|
|
|
284
290
|
for task_id in task_ins_to_be_deleted:
|
|
285
291
|
del self.task_ins_store[task_id]
|
|
292
|
+
del self.task_ins_id_to_task_res_id[task_id]
|
|
286
293
|
for task_id in task_res_to_be_deleted:
|
|
287
294
|
del self.task_res_store[task_id]
|
|
288
295
|
|
|
296
|
+
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
297
|
+
"""Delete tasks based on a set of TaskIns IDs."""
|
|
298
|
+
if not task_ids:
|
|
299
|
+
return
|
|
300
|
+
|
|
301
|
+
with self.lock:
|
|
302
|
+
for task_id in task_ids:
|
|
303
|
+
# Delete TaskIns
|
|
304
|
+
if task_id in self.task_ins_store:
|
|
305
|
+
del self.task_ins_store[task_id]
|
|
306
|
+
# Delete TaskRes
|
|
307
|
+
if task_id in self.task_ins_id_to_task_res_id:
|
|
308
|
+
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
|
|
309
|
+
del self.task_res_store[task_res_id]
|
|
310
|
+
|
|
289
311
|
def num_task_ins(self) -> int:
|
|
290
312
|
"""Calculate the number of task_ins in store.
|
|
291
313
|
|
|
@@ -386,13 +408,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
386
408
|
fab_version=fab_version if fab_version else "",
|
|
387
409
|
fab_hash=fab_hash if fab_hash else "",
|
|
388
410
|
override_config=override_config,
|
|
411
|
+
pending_at=now().isoformat(),
|
|
412
|
+
starting_at="",
|
|
413
|
+
running_at="",
|
|
414
|
+
finished_at="",
|
|
415
|
+
status=RunStatus(
|
|
416
|
+
status=Status.PENDING,
|
|
417
|
+
sub_status="",
|
|
418
|
+
details="",
|
|
419
|
+
),
|
|
389
420
|
),
|
|
390
|
-
status=RunStatus(
|
|
391
|
-
status=Status.PENDING,
|
|
392
|
-
sub_status="",
|
|
393
|
-
details="",
|
|
394
|
-
),
|
|
395
|
-
pending_at=now().isoformat(),
|
|
396
421
|
)
|
|
397
422
|
self.run_ids[run_id] = run_record
|
|
398
423
|
|
|
@@ -452,7 +477,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
452
477
|
"""Retrieve the statuses for the specified runs."""
|
|
453
478
|
with self.lock:
|
|
454
479
|
return {
|
|
455
|
-
run_id: self.run_ids[run_id].status
|
|
480
|
+
run_id: self.run_ids[run_id].run.status
|
|
456
481
|
for run_id in set(run_ids)
|
|
457
482
|
if run_id in self.run_ids
|
|
458
483
|
}
|
|
@@ -466,7 +491,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
466
491
|
return False
|
|
467
492
|
|
|
468
493
|
# Check if the status transition is valid
|
|
469
|
-
current_status = self.run_ids[run_id].status
|
|
494
|
+
current_status = self.run_ids[run_id].run.status
|
|
470
495
|
if not is_valid_transition(current_status, new_status):
|
|
471
496
|
log(
|
|
472
497
|
ERROR,
|
|
@@ -489,12 +514,12 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
489
514
|
# Update the status
|
|
490
515
|
run_record = self.run_ids[run_id]
|
|
491
516
|
if new_status.status == Status.STARTING:
|
|
492
|
-
run_record.starting_at = now().isoformat()
|
|
517
|
+
run_record.run.starting_at = now().isoformat()
|
|
493
518
|
elif new_status.status == Status.RUNNING:
|
|
494
|
-
run_record.running_at = now().isoformat()
|
|
519
|
+
run_record.run.running_at = now().isoformat()
|
|
495
520
|
elif new_status.status == Status.FINISHED:
|
|
496
|
-
run_record.finished_at = now().isoformat()
|
|
497
|
-
run_record.status = new_status
|
|
521
|
+
run_record.run.finished_at = now().isoformat()
|
|
522
|
+
run_record.run.status = new_status
|
|
498
523
|
return True
|
|
499
524
|
|
|
500
525
|
def get_pending_run_id(self) -> Optional[int]:
|
|
@@ -504,7 +529,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
504
529
|
# Loop through all registered runs
|
|
505
530
|
for run_id, run_rec in self.run_ids.items():
|
|
506
531
|
# Break once a pending run is found
|
|
507
|
-
if run_rec.status.status == Status.PENDING:
|
|
532
|
+
if run_rec.run.status.status == Status.PENDING:
|
|
508
533
|
pending_run_id = run_id
|
|
509
534
|
break
|
|
510
535
|
|
|
@@ -101,13 +101,27 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
101
101
|
|
|
102
102
|
@abc.abstractmethod
|
|
103
103
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
104
|
-
"""Get TaskRes for
|
|
104
|
+
"""Get TaskRes for the given TaskIns IDs.
|
|
105
105
|
|
|
106
|
-
|
|
107
|
-
|
|
106
|
+
This method is typically called by the ServerAppIo API to obtain
|
|
107
|
+
results (TaskRes) for previously scheduled instructions (TaskIns).
|
|
108
|
+
For each task_id provided, this method returns one of the following responses:
|
|
108
109
|
|
|
109
|
-
|
|
110
|
-
|
|
110
|
+
- An error TaskRes if the corresponding TaskIns does not exist or has expired.
|
|
111
|
+
- An error TaskRes if the corresponding TaskRes exists but has expired.
|
|
112
|
+
- The valid TaskRes if the TaskIns has a corresponding valid TaskRes.
|
|
113
|
+
- Nothing if the TaskIns is still valid and waiting for a TaskRes.
|
|
114
|
+
|
|
115
|
+
Parameters
|
|
116
|
+
----------
|
|
117
|
+
task_ids : set[UUID]
|
|
118
|
+
A set of TaskIns IDs for which to retrieve results (TaskRes).
|
|
119
|
+
|
|
120
|
+
Returns
|
|
121
|
+
-------
|
|
122
|
+
list[TaskRes]
|
|
123
|
+
A list of TaskRes corresponding to the given task IDs. If no
|
|
124
|
+
TaskRes could be found for any of the task IDs, an empty list is returned.
|
|
111
125
|
"""
|
|
112
126
|
|
|
113
127
|
@abc.abstractmethod
|
|
@@ -57,6 +57,8 @@ from .utils import (
|
|
|
57
57
|
generate_rand_int_from_bytes,
|
|
58
58
|
has_valid_sub_status,
|
|
59
59
|
is_valid_transition,
|
|
60
|
+
verify_found_taskres,
|
|
61
|
+
verify_taskins_ids,
|
|
60
62
|
)
|
|
61
63
|
|
|
62
64
|
SQL_CREATE_TABLE_NODE = """
|
|
@@ -510,136 +512,67 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
510
512
|
|
|
511
513
|
# pylint: disable-next=R0912,R0915,R0914
|
|
512
514
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
513
|
-
"""Get TaskRes for
|
|
515
|
+
"""Get TaskRes for the given TaskIns IDs."""
|
|
516
|
+
ret: dict[UUID, TaskRes] = {}
|
|
514
517
|
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
Retrieves all TaskRes for the given `task_ids` and returns and empty list if
|
|
519
|
-
none could be found.
|
|
520
|
-
|
|
521
|
-
Constraints
|
|
522
|
-
-----------
|
|
523
|
-
If `limit` is not `None`, return, at most, `limit` number of TaskRes. The limit
|
|
524
|
-
will only take effect if enough task_ids are in the set AND are currently
|
|
525
|
-
available. If `limit` is set, it has to be greater than zero.
|
|
526
|
-
"""
|
|
527
|
-
# Check if corresponding TaskIns exists and is not expired
|
|
528
|
-
task_ids_placeholders = ",".join([f":id_{i}" for i in range(len(task_ids))])
|
|
518
|
+
# Verify TaskIns IDs
|
|
519
|
+
current = time.time()
|
|
529
520
|
query = f"""
|
|
530
521
|
SELECT *
|
|
531
522
|
FROM task_ins
|
|
532
|
-
WHERE task_id IN ({
|
|
533
|
-
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
523
|
+
WHERE task_id IN ({",".join(["?"] * len(task_ids))});
|
|
534
524
|
"""
|
|
535
|
-
query
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
for index, task_id in enumerate(task_ids):
|
|
539
|
-
task_ins_data[f"id_{index}"] = str(task_id)
|
|
540
|
-
|
|
541
|
-
task_ins_rows = self.query(query, task_ins_data)
|
|
542
|
-
|
|
543
|
-
if not task_ins_rows:
|
|
544
|
-
return []
|
|
545
|
-
|
|
546
|
-
for row in task_ins_rows:
|
|
547
|
-
# Convert values from sint64 to uint64
|
|
525
|
+
rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
|
|
526
|
+
found_task_ins_dict: dict[UUID, TaskIns] = {}
|
|
527
|
+
for row in rows:
|
|
548
528
|
convert_sint64_values_in_dict_to_uint64(
|
|
549
529
|
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
550
530
|
)
|
|
551
|
-
|
|
552
|
-
if task_ins.task.created_at + task_ins.task.ttl <= time.time():
|
|
553
|
-
log(WARNING, "TaskIns with task_id %s is expired.", task_ins.task_id)
|
|
554
|
-
task_ids.remove(UUID(task_ins.task_id))
|
|
531
|
+
found_task_ins_dict[UUID(row["task_id"])] = dict_to_task_ins(row)
|
|
555
532
|
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
533
|
+
ret = verify_taskins_ids(
|
|
534
|
+
inquired_taskins_ids=task_ids,
|
|
535
|
+
found_taskins_dict=found_task_ins_dict,
|
|
536
|
+
current_time=current,
|
|
537
|
+
)
|
|
559
538
|
|
|
560
|
-
|
|
539
|
+
# Find all TaskRes
|
|
561
540
|
query = f"""
|
|
562
541
|
SELECT *
|
|
563
542
|
FROM task_res
|
|
564
|
-
WHERE ancestry IN ({
|
|
565
|
-
AND delivered_at = ""
|
|
566
|
-
"""
|
|
567
|
-
|
|
568
|
-
data: dict[str, Union[str, float, int]] = {}
|
|
569
|
-
|
|
570
|
-
query += ";"
|
|
571
|
-
|
|
572
|
-
for index, task_id in enumerate(task_ids):
|
|
573
|
-
data[f"id_{index}"] = str(task_id)
|
|
574
|
-
|
|
575
|
-
rows = self.query(query, data)
|
|
576
|
-
|
|
577
|
-
if rows:
|
|
578
|
-
# Prepare query
|
|
579
|
-
found_task_ids = [row["task_id"] for row in rows]
|
|
580
|
-
placeholders = ",".join([f":id_{i}" for i in range(len(found_task_ids))])
|
|
581
|
-
query = f"""
|
|
582
|
-
UPDATE task_res
|
|
583
|
-
SET delivered_at = :delivered_at
|
|
584
|
-
WHERE task_id IN ({placeholders})
|
|
585
|
-
RETURNING *;
|
|
586
|
-
"""
|
|
587
|
-
|
|
588
|
-
# Prepare data for query
|
|
589
|
-
delivered_at = now().isoformat()
|
|
590
|
-
data = {"delivered_at": delivered_at}
|
|
591
|
-
for index, task_id in enumerate(found_task_ids):
|
|
592
|
-
data[f"id_{index}"] = str(task_id)
|
|
593
|
-
|
|
594
|
-
# Run query
|
|
595
|
-
rows = self.query(query, data)
|
|
596
|
-
|
|
597
|
-
for row in rows:
|
|
598
|
-
# Convert values from sint64 to uint64
|
|
599
|
-
convert_sint64_values_in_dict_to_uint64(
|
|
600
|
-
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
601
|
-
)
|
|
602
|
-
|
|
603
|
-
result = [dict_to_task_res(row) for row in rows]
|
|
604
|
-
|
|
605
|
-
# 1. Query: Fetch consumer_node_id of remaining task_ids
|
|
606
|
-
# Assume the ancestry field only contains one element
|
|
607
|
-
data.clear()
|
|
608
|
-
replied_task_ids: set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
|
|
609
|
-
remaining_task_ids = task_ids - replied_task_ids
|
|
610
|
-
placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
|
|
611
|
-
query = f"""
|
|
612
|
-
SELECT consumer_node_id
|
|
613
|
-
FROM task_ins
|
|
614
|
-
WHERE task_id IN ({placeholders});
|
|
543
|
+
WHERE ancestry IN ({",".join(["?"] * len(task_ids))})
|
|
544
|
+
AND delivered_at = "";
|
|
615
545
|
"""
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
546
|
+
rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
|
|
547
|
+
for row in rows:
|
|
548
|
+
convert_sint64_values_in_dict_to_uint64(
|
|
549
|
+
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
550
|
+
)
|
|
551
|
+
tmp_ret_dict = verify_found_taskres(
|
|
552
|
+
inquired_taskins_ids=task_ids,
|
|
553
|
+
found_taskins_dict=found_task_ins_dict,
|
|
554
|
+
found_taskres_list=[dict_to_task_res(row) for row in rows],
|
|
555
|
+
current_time=current,
|
|
556
|
+
)
|
|
557
|
+
ret.update(tmp_ret_dict)
|
|
619
558
|
|
|
620
|
-
#
|
|
621
|
-
|
|
559
|
+
# Mark existing TaskRes to be returned as delivered
|
|
560
|
+
delivered_at = now().isoformat()
|
|
561
|
+
for task_res in ret.values():
|
|
562
|
+
task_res.task.delivered_at = delivered_at
|
|
563
|
+
task_res_ids = [task_res.task_id for task_res in ret.values()]
|
|
622
564
|
query = f"""
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
WHERE
|
|
626
|
-
AND online_until < :time;
|
|
565
|
+
UPDATE task_res
|
|
566
|
+
SET delivered_at = ?
|
|
567
|
+
WHERE task_id IN ({",".join(["?"] * len(task_res_ids))});
|
|
627
568
|
"""
|
|
628
|
-
data
|
|
629
|
-
|
|
630
|
-
offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)]
|
|
569
|
+
data: list[Any] = [delivered_at] + task_res_ids
|
|
570
|
+
self.query(query, data)
|
|
631
571
|
|
|
632
|
-
#
|
|
633
|
-
|
|
634
|
-
query = f"""
|
|
635
|
-
SELECT *
|
|
636
|
-
FROM task_ins
|
|
637
|
-
WHERE consumer_node_id IN ({placeholders});
|
|
638
|
-
"""
|
|
639
|
-
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
|
|
640
|
-
task_ins_rows = self.query(query, data)
|
|
572
|
+
# Cleanup
|
|
573
|
+
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
641
574
|
|
|
642
|
-
return
|
|
575
|
+
return list(ret.values())
|
|
643
576
|
|
|
644
577
|
def num_task_ins(self) -> int:
|
|
645
578
|
"""Calculate the number of task_ins in store.
|
|
@@ -699,6 +632,32 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
699
632
|
|
|
700
633
|
return None
|
|
701
634
|
|
|
635
|
+
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
636
|
+
"""Delete tasks based on a set of TaskIns IDs."""
|
|
637
|
+
if not task_ids:
|
|
638
|
+
return
|
|
639
|
+
if self.conn is None:
|
|
640
|
+
raise AttributeError("LinkState not initialized")
|
|
641
|
+
|
|
642
|
+
placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
|
|
643
|
+
data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
|
|
644
|
+
|
|
645
|
+
# Delete task_ins
|
|
646
|
+
query_1 = f"""
|
|
647
|
+
DELETE FROM task_ins
|
|
648
|
+
WHERE task_id IN ({placeholders});
|
|
649
|
+
"""
|
|
650
|
+
|
|
651
|
+
# Delete task_res
|
|
652
|
+
query_2 = f"""
|
|
653
|
+
DELETE FROM task_res
|
|
654
|
+
WHERE ancestry IN ({placeholders});
|
|
655
|
+
"""
|
|
656
|
+
|
|
657
|
+
with self.conn:
|
|
658
|
+
self.conn.execute(query_1, data)
|
|
659
|
+
self.conn.execute(query_2, data)
|
|
660
|
+
|
|
702
661
|
def create_node(
|
|
703
662
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
704
663
|
) -> int:
|
|
@@ -922,6 +881,15 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
922
881
|
fab_version=row["fab_version"],
|
|
923
882
|
fab_hash=row["fab_hash"],
|
|
924
883
|
override_config=json.loads(row["override_config"]),
|
|
884
|
+
pending_at=row["pending_at"],
|
|
885
|
+
starting_at=row["starting_at"],
|
|
886
|
+
running_at=row["running_at"],
|
|
887
|
+
finished_at=row["finished_at"],
|
|
888
|
+
status=RunStatus(
|
|
889
|
+
status=determine_run_status(row),
|
|
890
|
+
sub_status=row["sub_status"],
|
|
891
|
+
details=row["details"],
|
|
892
|
+
),
|
|
925
893
|
)
|
|
926
894
|
log(ERROR, "`run_id` does not exist.")
|
|
927
895
|
return None
|
|
@@ -1255,10 +1223,10 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
|
1255
1223
|
def determine_run_status(row: dict[str, Any]) -> str:
|
|
1256
1224
|
"""Determine the status of the run based on timestamp fields."""
|
|
1257
1225
|
if row["pending_at"]:
|
|
1226
|
+
if row["finished_at"]:
|
|
1227
|
+
return Status.FINISHED
|
|
1258
1228
|
if row["starting_at"]:
|
|
1259
1229
|
if row["running_at"]:
|
|
1260
|
-
if row["finished_at"]:
|
|
1261
|
-
return Status.FINISHED
|
|
1262
1230
|
return Status.RUNNING
|
|
1263
1231
|
return Status.STARTING
|
|
1264
1232
|
return Status.PENDING
|