flwr-nightly 1.13.0.dev20241106__py3-none-any.whl → 1.13.0.dev20241117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +37 -0
- flwr/cli/install.py +5 -3
- flwr/cli/ls.py +228 -0
- flwr/cli/run/run.py +16 -5
- flwr/client/app.py +68 -19
- flwr/client/clientapp/app.py +51 -35
- flwr/client/grpc_rere_client/connection.py +2 -12
- flwr/client/nodestate/__init__.py +25 -0
- flwr/client/nodestate/in_memory_nodestate.py +38 -0
- flwr/client/nodestate/nodestate.py +30 -0
- flwr/client/nodestate/nodestate_factory.py +37 -0
- flwr/client/rest_client/connection.py +4 -14
- flwr/client/supernode/app.py +57 -53
- flwr/common/args.py +148 -0
- flwr/common/config.py +10 -0
- flwr/common/constant.py +21 -7
- flwr/common/date.py +18 -0
- flwr/common/logger.py +6 -2
- flwr/common/object_ref.py +47 -16
- flwr/common/serde.py +10 -0
- flwr/common/typing.py +32 -11
- flwr/proto/exec_pb2.py +23 -17
- flwr/proto/exec_pb2.pyi +50 -20
- flwr/proto/exec_pb2_grpc.py +34 -0
- flwr/proto/exec_pb2_grpc.pyi +13 -0
- flwr/proto/run_pb2.py +32 -27
- flwr/proto/run_pb2.pyi +44 -1
- flwr/proto/simulationio_pb2.py +2 -2
- flwr/proto/simulationio_pb2_grpc.py +34 -0
- flwr/proto/simulationio_pb2_grpc.pyi +13 -0
- flwr/server/app.py +83 -87
- flwr/server/driver/driver.py +1 -1
- flwr/server/driver/grpc_driver.py +6 -20
- flwr/server/driver/inmemory_driver.py +1 -3
- flwr/server/run_serverapp.py +8 -238
- flwr/server/serverapp/app.py +44 -89
- flwr/server/strategy/aggregate.py +4 -4
- flwr/server/superlink/fleet/rest_rere/rest_api.py +10 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +76 -62
- flwr/server/superlink/linkstate/linkstate.py +24 -9
- flwr/server/superlink/linkstate/sqlite_linkstate.py +87 -128
- flwr/server/superlink/linkstate/utils.py +191 -32
- flwr/server/superlink/simulation/simulationio_servicer.py +22 -1
- flwr/simulation/__init__.py +3 -1
- flwr/simulation/app.py +245 -352
- flwr/simulation/legacy_app.py +402 -0
- flwr/simulation/run_simulation.py +8 -19
- flwr/simulation/simulationio_connection.py +2 -2
- flwr/superexec/deployment.py +13 -7
- flwr/superexec/exec_servicer.py +32 -3
- flwr/superexec/executor.py +4 -3
- flwr/superexec/simulation.py +52 -145
- {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/METADATA +10 -7
- {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/RECORD +58 -51
- {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241106.dist-info → flwr_nightly-1.13.0.dev20241117.dist-info}/WHEEL +0 -0
|
@@ -40,7 +40,8 @@ from .utils import (
|
|
|
40
40
|
generate_rand_int_from_bytes,
|
|
41
41
|
has_valid_sub_status,
|
|
42
42
|
is_valid_transition,
|
|
43
|
-
|
|
43
|
+
verify_found_taskres,
|
|
44
|
+
verify_taskins_ids,
|
|
44
45
|
)
|
|
45
46
|
|
|
46
47
|
|
|
@@ -49,11 +50,6 @@ class RunRecord: # pylint: disable=R0902
|
|
|
49
50
|
"""The record of a specific run, including its status and timestamps."""
|
|
50
51
|
|
|
51
52
|
run: Run
|
|
52
|
-
status: RunStatus
|
|
53
|
-
pending_at: str = ""
|
|
54
|
-
starting_at: str = ""
|
|
55
|
-
running_at: str = ""
|
|
56
|
-
finished_at: str = ""
|
|
57
53
|
logs: list[tuple[float, str]] = field(default_factory=list)
|
|
58
54
|
log_lock: threading.Lock = field(default_factory=threading.Lock)
|
|
59
55
|
|
|
@@ -73,12 +69,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
73
69
|
self.federation_options: dict[int, ConfigsRecord] = {}
|
|
74
70
|
self.task_ins_store: dict[UUID, TaskIns] = {}
|
|
75
71
|
self.task_res_store: dict[UUID, TaskRes] = {}
|
|
72
|
+
self.task_ins_id_to_task_res_id: dict[UUID, UUID] = {}
|
|
76
73
|
|
|
77
74
|
self.node_public_keys: set[bytes] = set()
|
|
78
75
|
self.server_public_key: Optional[bytes] = None
|
|
79
76
|
self.server_private_key: Optional[bytes] = None
|
|
80
77
|
|
|
81
|
-
self.lock = threading.
|
|
78
|
+
self.lock = threading.RLock()
|
|
82
79
|
|
|
83
80
|
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
|
84
81
|
"""Store one TaskIns."""
|
|
@@ -228,57 +225,50 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
228
225
|
task_res.task_id = str(task_id)
|
|
229
226
|
with self.lock:
|
|
230
227
|
self.task_res_store[task_id] = task_res
|
|
228
|
+
self.task_ins_id_to_task_res_id[UUID(task_ins_id)] = task_id
|
|
231
229
|
|
|
232
230
|
# Return the new task_id
|
|
233
231
|
return task_id
|
|
234
232
|
|
|
235
233
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
236
|
-
"""Get
|
|
234
|
+
"""Get TaskRes for the given TaskIns IDs."""
|
|
235
|
+
ret: dict[UUID, TaskRes] = {}
|
|
236
|
+
|
|
237
237
|
with self.lock:
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
if
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
continue
|
|
265
|
-
node_id = task_ins.task.consumer.node_id
|
|
266
|
-
online_until, _ = self.node_ids[node_id]
|
|
267
|
-
# Generate a TaskRes containing an error reply if the node is offline.
|
|
268
|
-
if online_until < time.time():
|
|
269
|
-
err_taskres = make_node_unavailable_taskres(
|
|
270
|
-
ref_taskins=task_ins,
|
|
271
|
-
)
|
|
272
|
-
self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
|
|
273
|
-
task_res_list.append(err_taskres)
|
|
274
|
-
|
|
275
|
-
# Mark all of them as delivered
|
|
238
|
+
current = time.time()
|
|
239
|
+
|
|
240
|
+
# Verify TaskIns IDs
|
|
241
|
+
ret = verify_taskins_ids(
|
|
242
|
+
inquired_taskins_ids=task_ids,
|
|
243
|
+
found_taskins_dict=self.task_ins_store,
|
|
244
|
+
current_time=current,
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
# Find all TaskRes
|
|
248
|
+
task_res_found: list[TaskRes] = []
|
|
249
|
+
for task_id in task_ids:
|
|
250
|
+
# If TaskRes exists and is not delivered, add it to the list
|
|
251
|
+
if task_res_id := self.task_ins_id_to_task_res_id.get(task_id):
|
|
252
|
+
task_res = self.task_res_store[task_res_id]
|
|
253
|
+
if task_res.task.delivered_at == "":
|
|
254
|
+
task_res_found.append(task_res)
|
|
255
|
+
tmp_ret_dict = verify_found_taskres(
|
|
256
|
+
inquired_taskins_ids=task_ids,
|
|
257
|
+
found_taskins_dict=self.task_ins_store,
|
|
258
|
+
found_taskres_list=task_res_found,
|
|
259
|
+
current_time=current,
|
|
260
|
+
)
|
|
261
|
+
ret.update(tmp_ret_dict)
|
|
262
|
+
|
|
263
|
+
# Mark existing TaskRes to be returned as delivered
|
|
276
264
|
delivered_at = now().isoformat()
|
|
277
|
-
for task_res in
|
|
265
|
+
for task_res in task_res_found:
|
|
278
266
|
task_res.task.delivered_at = delivered_at
|
|
279
267
|
|
|
280
|
-
#
|
|
281
|
-
|
|
268
|
+
# Cleanup
|
|
269
|
+
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
270
|
+
|
|
271
|
+
return list(ret.values())
|
|
282
272
|
|
|
283
273
|
def delete_tasks(self, task_ids: set[UUID]) -> None:
|
|
284
274
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
|
@@ -299,9 +289,25 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
299
289
|
|
|
300
290
|
for task_id in task_ins_to_be_deleted:
|
|
301
291
|
del self.task_ins_store[task_id]
|
|
292
|
+
del self.task_ins_id_to_task_res_id[task_id]
|
|
302
293
|
for task_id in task_res_to_be_deleted:
|
|
303
294
|
del self.task_res_store[task_id]
|
|
304
295
|
|
|
296
|
+
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
297
|
+
"""Delete tasks based on a set of TaskIns IDs."""
|
|
298
|
+
if not task_ids:
|
|
299
|
+
return
|
|
300
|
+
|
|
301
|
+
with self.lock:
|
|
302
|
+
for task_id in task_ids:
|
|
303
|
+
# Delete TaskIns
|
|
304
|
+
if task_id in self.task_ins_store:
|
|
305
|
+
del self.task_ins_store[task_id]
|
|
306
|
+
# Delete TaskRes
|
|
307
|
+
if task_id in self.task_ins_id_to_task_res_id:
|
|
308
|
+
task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
|
|
309
|
+
del self.task_res_store[task_res_id]
|
|
310
|
+
|
|
305
311
|
def num_task_ins(self) -> int:
|
|
306
312
|
"""Calculate the number of task_ins in store.
|
|
307
313
|
|
|
@@ -402,13 +408,16 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
402
408
|
fab_version=fab_version if fab_version else "",
|
|
403
409
|
fab_hash=fab_hash if fab_hash else "",
|
|
404
410
|
override_config=override_config,
|
|
411
|
+
pending_at=now().isoformat(),
|
|
412
|
+
starting_at="",
|
|
413
|
+
running_at="",
|
|
414
|
+
finished_at="",
|
|
415
|
+
status=RunStatus(
|
|
416
|
+
status=Status.PENDING,
|
|
417
|
+
sub_status="",
|
|
418
|
+
details="",
|
|
419
|
+
),
|
|
405
420
|
),
|
|
406
|
-
status=RunStatus(
|
|
407
|
-
status=Status.PENDING,
|
|
408
|
-
sub_status="",
|
|
409
|
-
details="",
|
|
410
|
-
),
|
|
411
|
-
pending_at=now().isoformat(),
|
|
412
421
|
)
|
|
413
422
|
self.run_ids[run_id] = run_record
|
|
414
423
|
|
|
@@ -451,6 +460,11 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
451
460
|
"""Retrieve all currently stored `node_public_keys` as a set."""
|
|
452
461
|
return self.node_public_keys
|
|
453
462
|
|
|
463
|
+
def get_run_ids(self) -> set[int]:
|
|
464
|
+
"""Retrieve all run IDs."""
|
|
465
|
+
with self.lock:
|
|
466
|
+
return set(self.run_ids.keys())
|
|
467
|
+
|
|
454
468
|
def get_run(self, run_id: int) -> Optional[Run]:
|
|
455
469
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
456
470
|
with self.lock:
|
|
@@ -463,7 +477,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
463
477
|
"""Retrieve the statuses for the specified runs."""
|
|
464
478
|
with self.lock:
|
|
465
479
|
return {
|
|
466
|
-
run_id: self.run_ids[run_id].status
|
|
480
|
+
run_id: self.run_ids[run_id].run.status
|
|
467
481
|
for run_id in set(run_ids)
|
|
468
482
|
if run_id in self.run_ids
|
|
469
483
|
}
|
|
@@ -477,7 +491,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
477
491
|
return False
|
|
478
492
|
|
|
479
493
|
# Check if the status transition is valid
|
|
480
|
-
current_status = self.run_ids[run_id].status
|
|
494
|
+
current_status = self.run_ids[run_id].run.status
|
|
481
495
|
if not is_valid_transition(current_status, new_status):
|
|
482
496
|
log(
|
|
483
497
|
ERROR,
|
|
@@ -500,12 +514,12 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
500
514
|
# Update the status
|
|
501
515
|
run_record = self.run_ids[run_id]
|
|
502
516
|
if new_status.status == Status.STARTING:
|
|
503
|
-
run_record.starting_at = now().isoformat()
|
|
517
|
+
run_record.run.starting_at = now().isoformat()
|
|
504
518
|
elif new_status.status == Status.RUNNING:
|
|
505
|
-
run_record.running_at = now().isoformat()
|
|
519
|
+
run_record.run.running_at = now().isoformat()
|
|
506
520
|
elif new_status.status == Status.FINISHED:
|
|
507
|
-
run_record.finished_at = now().isoformat()
|
|
508
|
-
run_record.status = new_status
|
|
521
|
+
run_record.run.finished_at = now().isoformat()
|
|
522
|
+
run_record.run.status = new_status
|
|
509
523
|
return True
|
|
510
524
|
|
|
511
525
|
def get_pending_run_id(self) -> Optional[int]:
|
|
@@ -515,7 +529,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
|
515
529
|
# Loop through all registered runs
|
|
516
530
|
for run_id, run_rec in self.run_ids.items():
|
|
517
531
|
# Break once a pending run is found
|
|
518
|
-
if run_rec.status.status == Status.PENDING:
|
|
532
|
+
if run_rec.run.status.status == Status.PENDING:
|
|
519
533
|
pending_run_id = run_id
|
|
520
534
|
break
|
|
521
535
|
|
|
@@ -101,13 +101,27 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
101
101
|
|
|
102
102
|
@abc.abstractmethod
|
|
103
103
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
104
|
-
"""Get TaskRes for
|
|
104
|
+
"""Get TaskRes for the given TaskIns IDs.
|
|
105
105
|
|
|
106
|
-
|
|
107
|
-
|
|
106
|
+
This method is typically called by the ServerAppIo API to obtain
|
|
107
|
+
results (TaskRes) for previously scheduled instructions (TaskIns).
|
|
108
|
+
For each task_id provided, this method returns one of the following responses:
|
|
108
109
|
|
|
109
|
-
|
|
110
|
-
|
|
110
|
+
- An error TaskRes if the corresponding TaskIns does not exist or has expired.
|
|
111
|
+
- An error TaskRes if the corresponding TaskRes exists but has expired.
|
|
112
|
+
- The valid TaskRes if the TaskIns has a corresponding valid TaskRes.
|
|
113
|
+
- Nothing if the TaskIns is still valid and waiting for a TaskRes.
|
|
114
|
+
|
|
115
|
+
Parameters
|
|
116
|
+
----------
|
|
117
|
+
task_ids : set[UUID]
|
|
118
|
+
A set of TaskIns IDs for which to retrieve results (TaskRes).
|
|
119
|
+
|
|
120
|
+
Returns
|
|
121
|
+
-------
|
|
122
|
+
list[TaskRes]
|
|
123
|
+
A list of TaskRes corresponding to the given task IDs. If no
|
|
124
|
+
TaskRes could be found for any of the task IDs, an empty list is returned.
|
|
111
125
|
"""
|
|
112
126
|
|
|
113
127
|
@abc.abstractmethod
|
|
@@ -163,6 +177,10 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
163
177
|
) -> int:
|
|
164
178
|
"""Create a new run for the specified `fab_hash`."""
|
|
165
179
|
|
|
180
|
+
@abc.abstractmethod
|
|
181
|
+
def get_run_ids(self) -> set[int]:
|
|
182
|
+
"""Retrieve all run IDs."""
|
|
183
|
+
|
|
166
184
|
@abc.abstractmethod
|
|
167
185
|
def get_run(self, run_id: int) -> Optional[Run]:
|
|
168
186
|
"""Retrieve information about the run with the specified `run_id`.
|
|
@@ -175,10 +193,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
|
175
193
|
Returns
|
|
176
194
|
-------
|
|
177
195
|
Optional[Run]
|
|
178
|
-
|
|
179
|
-
- `run_id`: The identifier of the run, same as the specified `run_id`.
|
|
180
|
-
- `fab_id`: The identifier of the FAB used in the specified run.
|
|
181
|
-
- `fab_version`: The version of the FAB used in the specified run.
|
|
196
|
+
The `Run` instance if found; otherwise, `None`.
|
|
182
197
|
"""
|
|
183
198
|
|
|
184
199
|
@abc.abstractmethod
|
|
@@ -57,7 +57,8 @@ from .utils import (
|
|
|
57
57
|
generate_rand_int_from_bytes,
|
|
58
58
|
has_valid_sub_status,
|
|
59
59
|
is_valid_transition,
|
|
60
|
-
|
|
60
|
+
verify_found_taskres,
|
|
61
|
+
verify_taskins_ids,
|
|
61
62
|
)
|
|
62
63
|
|
|
63
64
|
SQL_CREATE_TABLE_NODE = """
|
|
@@ -511,150 +512,67 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
511
512
|
|
|
512
513
|
# pylint: disable-next=R0912,R0915,R0914
|
|
513
514
|
def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
|
|
514
|
-
"""Get TaskRes for
|
|
515
|
+
"""Get TaskRes for the given TaskIns IDs."""
|
|
516
|
+
ret: dict[UUID, TaskRes] = {}
|
|
515
517
|
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
Retrieves all TaskRes for the given `task_ids` and returns and empty list if
|
|
520
|
-
none could be found.
|
|
521
|
-
|
|
522
|
-
Constraints
|
|
523
|
-
-----------
|
|
524
|
-
If `limit` is not `None`, return, at most, `limit` number of TaskRes. The limit
|
|
525
|
-
will only take effect if enough task_ids are in the set AND are currently
|
|
526
|
-
available. If `limit` is set, it has to be greater than zero.
|
|
527
|
-
"""
|
|
528
|
-
# Check if corresponding TaskIns exists and is not expired
|
|
529
|
-
task_ids_placeholders = ",".join([f":id_{i}" for i in range(len(task_ids))])
|
|
518
|
+
# Verify TaskIns IDs
|
|
519
|
+
current = time.time()
|
|
530
520
|
query = f"""
|
|
531
521
|
SELECT *
|
|
532
522
|
FROM task_ins
|
|
533
|
-
WHERE task_id IN ({
|
|
534
|
-
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
|
|
523
|
+
WHERE task_id IN ({",".join(["?"] * len(task_ids))});
|
|
535
524
|
"""
|
|
536
|
-
query
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
for index, task_id in enumerate(task_ids):
|
|
540
|
-
task_ins_data[f"id_{index}"] = str(task_id)
|
|
541
|
-
|
|
542
|
-
task_ins_rows = self.query(query, task_ins_data)
|
|
543
|
-
|
|
544
|
-
if not task_ins_rows:
|
|
545
|
-
return []
|
|
546
|
-
|
|
547
|
-
for row in task_ins_rows:
|
|
548
|
-
# Convert values from sint64 to uint64
|
|
525
|
+
rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
|
|
526
|
+
found_task_ins_dict: dict[UUID, TaskIns] = {}
|
|
527
|
+
for row in rows:
|
|
549
528
|
convert_sint64_values_in_dict_to_uint64(
|
|
550
529
|
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
551
530
|
)
|
|
552
|
-
|
|
553
|
-
if task_ins.task.created_at + task_ins.task.ttl <= time.time():
|
|
554
|
-
log(WARNING, "TaskIns with task_id %s is expired.", task_ins.task_id)
|
|
555
|
-
task_ids.remove(UUID(task_ins.task_id))
|
|
531
|
+
found_task_ins_dict[UUID(row["task_id"])] = dict_to_task_ins(row)
|
|
556
532
|
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
533
|
+
ret = verify_taskins_ids(
|
|
534
|
+
inquired_taskins_ids=task_ids,
|
|
535
|
+
found_taskins_dict=found_task_ins_dict,
|
|
536
|
+
current_time=current,
|
|
537
|
+
)
|
|
560
538
|
|
|
561
|
-
|
|
539
|
+
# Find all TaskRes
|
|
562
540
|
query = f"""
|
|
563
541
|
SELECT *
|
|
564
542
|
FROM task_res
|
|
565
|
-
WHERE ancestry IN ({
|
|
566
|
-
AND delivered_at = ""
|
|
543
|
+
WHERE ancestry IN ({",".join(["?"] * len(task_ids))})
|
|
544
|
+
AND delivered_at = "";
|
|
567
545
|
"""
|
|
546
|
+
rows = self.query(query, tuple(str(task_id) for task_id in task_ids))
|
|
547
|
+
for row in rows:
|
|
548
|
+
convert_sint64_values_in_dict_to_uint64(
|
|
549
|
+
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
550
|
+
)
|
|
551
|
+
tmp_ret_dict = verify_found_taskres(
|
|
552
|
+
inquired_taskins_ids=task_ids,
|
|
553
|
+
found_taskins_dict=found_task_ins_dict,
|
|
554
|
+
found_taskres_list=[dict_to_task_res(row) for row in rows],
|
|
555
|
+
current_time=current,
|
|
556
|
+
)
|
|
557
|
+
ret.update(tmp_ret_dict)
|
|
568
558
|
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
data[f"id_{index}"] = str(task_id)
|
|
575
|
-
|
|
576
|
-
rows = self.query(query, data)
|
|
577
|
-
|
|
578
|
-
if rows:
|
|
579
|
-
# Prepare query
|
|
580
|
-
found_task_ids = [row["task_id"] for row in rows]
|
|
581
|
-
placeholders = ",".join([f":id_{i}" for i in range(len(found_task_ids))])
|
|
582
|
-
query = f"""
|
|
583
|
-
UPDATE task_res
|
|
584
|
-
SET delivered_at = :delivered_at
|
|
585
|
-
WHERE task_id IN ({placeholders})
|
|
586
|
-
RETURNING *;
|
|
587
|
-
"""
|
|
588
|
-
|
|
589
|
-
# Prepare data for query
|
|
590
|
-
delivered_at = now().isoformat()
|
|
591
|
-
data = {"delivered_at": delivered_at}
|
|
592
|
-
for index, task_id in enumerate(found_task_ids):
|
|
593
|
-
data[f"id_{index}"] = str(task_id)
|
|
594
|
-
|
|
595
|
-
# Run query
|
|
596
|
-
rows = self.query(query, data)
|
|
597
|
-
|
|
598
|
-
for row in rows:
|
|
599
|
-
# Convert values from sint64 to uint64
|
|
600
|
-
convert_sint64_values_in_dict_to_uint64(
|
|
601
|
-
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
602
|
-
)
|
|
603
|
-
|
|
604
|
-
result = [dict_to_task_res(row) for row in rows]
|
|
605
|
-
|
|
606
|
-
# 1. Query: Fetch consumer_node_id of remaining task_ids
|
|
607
|
-
# Assume the ancestry field only contains one element
|
|
608
|
-
data.clear()
|
|
609
|
-
replied_task_ids: set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
|
|
610
|
-
remaining_task_ids = task_ids - replied_task_ids
|
|
611
|
-
placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
|
|
612
|
-
query = f"""
|
|
613
|
-
SELECT consumer_node_id
|
|
614
|
-
FROM task_ins
|
|
615
|
-
WHERE task_id IN ({placeholders});
|
|
616
|
-
"""
|
|
617
|
-
for index, task_id in enumerate(remaining_task_ids):
|
|
618
|
-
data[f"id_{index}"] = str(task_id)
|
|
619
|
-
node_ids = [int(row["consumer_node_id"]) for row in self.query(query, data)]
|
|
620
|
-
|
|
621
|
-
# 2. Query: Select offline nodes
|
|
622
|
-
placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))])
|
|
623
|
-
query = f"""
|
|
624
|
-
SELECT node_id
|
|
625
|
-
FROM node
|
|
626
|
-
WHERE node_id IN ({placeholders})
|
|
627
|
-
AND online_until < :time;
|
|
628
|
-
"""
|
|
629
|
-
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(node_ids)}
|
|
630
|
-
data["time"] = time.time()
|
|
631
|
-
offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)]
|
|
632
|
-
|
|
633
|
-
# 3. Query: Select TaskIns for offline nodes
|
|
634
|
-
placeholders = ",".join([f":id_{i}" for i in range(len(offline_node_ids))])
|
|
559
|
+
# Mark existing TaskRes to be returned as delivered
|
|
560
|
+
delivered_at = now().isoformat()
|
|
561
|
+
for task_res in ret.values():
|
|
562
|
+
task_res.task.delivered_at = delivered_at
|
|
563
|
+
task_res_ids = [task_res.task_id for task_res in ret.values()]
|
|
635
564
|
query = f"""
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
WHERE
|
|
565
|
+
UPDATE task_res
|
|
566
|
+
SET delivered_at = ?
|
|
567
|
+
WHERE task_id IN ({",".join(["?"] * len(task_res_ids))});
|
|
639
568
|
"""
|
|
640
|
-
data
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
# Make TaskRes containing node unavailabe error
|
|
644
|
-
for row in task_ins_rows:
|
|
645
|
-
for row in rows:
|
|
646
|
-
# Convert values from sint64 to uint64
|
|
647
|
-
convert_sint64_values_in_dict_to_uint64(
|
|
648
|
-
row, ["run_id", "producer_node_id", "consumer_node_id"]
|
|
649
|
-
)
|
|
569
|
+
data: list[Any] = [delivered_at] + task_res_ids
|
|
570
|
+
self.query(query, data)
|
|
650
571
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
ref_taskins=task_ins,
|
|
654
|
-
)
|
|
655
|
-
result.append(err_taskres)
|
|
572
|
+
# Cleanup
|
|
573
|
+
self._force_delete_tasks_by_ids(set(ret.keys()))
|
|
656
574
|
|
|
657
|
-
return
|
|
575
|
+
return list(ret.values())
|
|
658
576
|
|
|
659
577
|
def num_task_ins(self) -> int:
|
|
660
578
|
"""Calculate the number of task_ins in store.
|
|
@@ -714,6 +632,32 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
714
632
|
|
|
715
633
|
return None
|
|
716
634
|
|
|
635
|
+
def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
|
|
636
|
+
"""Delete tasks based on a set of TaskIns IDs."""
|
|
637
|
+
if not task_ids:
|
|
638
|
+
return
|
|
639
|
+
if self.conn is None:
|
|
640
|
+
raise AttributeError("LinkState not initialized")
|
|
641
|
+
|
|
642
|
+
placeholders = ",".join([f":id_{index}" for index in range(len(task_ids))])
|
|
643
|
+
data = {f"id_{index}": str(task_id) for index, task_id in enumerate(task_ids)}
|
|
644
|
+
|
|
645
|
+
# Delete task_ins
|
|
646
|
+
query_1 = f"""
|
|
647
|
+
DELETE FROM task_ins
|
|
648
|
+
WHERE task_id IN ({placeholders});
|
|
649
|
+
"""
|
|
650
|
+
|
|
651
|
+
# Delete task_res
|
|
652
|
+
query_2 = f"""
|
|
653
|
+
DELETE FROM task_res
|
|
654
|
+
WHERE ancestry IN ({placeholders});
|
|
655
|
+
"""
|
|
656
|
+
|
|
657
|
+
with self.conn:
|
|
658
|
+
self.conn.execute(query_1, data)
|
|
659
|
+
self.conn.execute(query_2, data)
|
|
660
|
+
|
|
717
661
|
def create_node(
|
|
718
662
|
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
719
663
|
) -> int:
|
|
@@ -917,6 +861,12 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
917
861
|
result: set[bytes] = {row["public_key"] for row in rows}
|
|
918
862
|
return result
|
|
919
863
|
|
|
864
|
+
def get_run_ids(self) -> set[int]:
|
|
865
|
+
"""Retrieve all run IDs."""
|
|
866
|
+
query = "SELECT run_id FROM run;"
|
|
867
|
+
rows = self.query(query)
|
|
868
|
+
return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
|
|
869
|
+
|
|
920
870
|
def get_run(self, run_id: int) -> Optional[Run]:
|
|
921
871
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
922
872
|
# Convert the uint64 value to sint64 for SQLite
|
|
@@ -931,6 +881,15 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
|
|
|
931
881
|
fab_version=row["fab_version"],
|
|
932
882
|
fab_hash=row["fab_hash"],
|
|
933
883
|
override_config=json.loads(row["override_config"]),
|
|
884
|
+
pending_at=row["pending_at"],
|
|
885
|
+
starting_at=row["starting_at"],
|
|
886
|
+
running_at=row["running_at"],
|
|
887
|
+
finished_at=row["finished_at"],
|
|
888
|
+
status=RunStatus(
|
|
889
|
+
status=determine_run_status(row),
|
|
890
|
+
sub_status=row["sub_status"],
|
|
891
|
+
details=row["details"],
|
|
892
|
+
),
|
|
934
893
|
)
|
|
935
894
|
log(ERROR, "`run_id` does not exist.")
|
|
936
895
|
return None
|
|
@@ -1264,10 +1223,10 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes:
|
|
|
1264
1223
|
def determine_run_status(row: dict[str, Any]) -> str:
|
|
1265
1224
|
"""Determine the status of the run based on timestamp fields."""
|
|
1266
1225
|
if row["pending_at"]:
|
|
1226
|
+
if row["finished_at"]:
|
|
1227
|
+
return Status.FINISHED
|
|
1267
1228
|
if row["starting_at"]:
|
|
1268
1229
|
if row["running_at"]:
|
|
1269
|
-
if row["finished_at"]:
|
|
1270
|
-
return Status.FINISHED
|
|
1271
1230
|
return Status.RUNNING
|
|
1272
1231
|
return Status.STARTING
|
|
1273
1232
|
return Status.PENDING
|