flwr-nightly 1.14.0.dev20241211__py3-none-any.whl → 1.14.0.dev20241213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of flwr-nightly might be problematic.
- flwr/cli/app.py +1 -0
- flwr/cli/build.py +1 -0
- flwr/cli/config_utils.py +1 -0
- flwr/cli/example.py +1 -0
- flwr/cli/install.py +1 -0
- flwr/cli/log.py +1 -0
- flwr/cli/login/__init__.py +1 -0
- flwr/cli/login/login.py +1 -0
- flwr/cli/new/__init__.py +1 -0
- flwr/cli/new/new.py +2 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -2
- flwr/cli/run/__init__.py +1 -0
- flwr/cli/run/run.py +1 -0
- flwr/cli/utils.py +1 -0
- flwr/client/app.py +3 -2
- flwr/client/client.py +1 -0
- flwr/client/clientapp/app.py +1 -0
- flwr/client/clientapp/utils.py +1 -0
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/connection.py +1 -1
- flwr/client/grpc_rere_client/connection.py +3 -3
- flwr/client/message_handler/message_handler.py +1 -0
- flwr/client/mod/comms_mods.py +1 -0
- flwr/client/mod/localdp_mod.py +1 -1
- flwr/client/nodestate/__init__.py +1 -0
- flwr/client/nodestate/nodestate.py +1 -0
- flwr/client/nodestate/nodestate_factory.py +1 -0
- flwr/client/rest_client/connection.py +3 -3
- flwr/client/supernode/app.py +1 -0
- flwr/common/address.py +1 -0
- flwr/common/args.py +1 -0
- flwr/common/config.py +1 -0
- flwr/common/logger.py +1 -0
- flwr/common/message.py +1 -0
- flwr/common/object_ref.py +57 -54
- flwr/common/pyproject.py +1 -0
- flwr/common/record/__init__.py +1 -0
- flwr/common/record/parametersrecord.py +1 -0
- flwr/common/retry_invoker.py +75 -0
- flwr/common/typing.py +4 -0
- flwr/common/version.py +1 -0
- flwr/proto/fab_pb2.py +4 -4
- flwr/proto/fab_pb2.pyi +4 -1
- flwr/server/app.py +1 -0
- flwr/server/compat/app_utils.py +7 -1
- flwr/server/driver/grpc_driver.py +5 -61
- flwr/server/driver/inmemory_driver.py +5 -1
- flwr/server/serverapp/app.py +9 -2
- flwr/server/strategy/dpfedavg_fixed.py +1 -0
- flwr/server/superlink/driver/serverappio_grpc.py +1 -0
- flwr/server/superlink/driver/serverappio_servicer.py +54 -22
- flwr/server/superlink/ffs/disk_ffs.py +1 -0
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
- flwr/server/superlink/fleet/vce/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
- flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -30
- flwr/server/superlink/linkstate/linkstate.py +13 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +24 -44
- flwr/server/superlink/simulation/simulationio_servicer.py +1 -0
- flwr/server/superlink/utils.py +65 -0
- flwr/simulation/app.py +1 -0
- flwr/simulation/ray_transport/ray_actor.py +1 -0
- flwr/simulation/ray_transport/utils.py +1 -0
- flwr/simulation/run_simulation.py +1 -0
- flwr/superexec/app.py +1 -0
- flwr/superexec/deployment.py +1 -0
- flwr/superexec/exec_grpc.py +1 -0
- flwr/superexec/exec_servicer.py +8 -0
- flwr/superexec/executor.py +1 -0
- {flwr_nightly-1.14.0.dev20241211.dist-info → flwr_nightly-1.14.0.dev20241213.dist-info}/METADATA +1 -1
- {flwr_nightly-1.14.0.dev20241211.dist-info → flwr_nightly-1.14.0.dev20241213.dist-info}/RECORD +75 -74
- {flwr_nightly-1.14.0.dev20241211.dist-info → flwr_nightly-1.14.0.dev20241213.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.14.0.dev20241211.dist-info → flwr_nightly-1.14.0.dev20241213.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.14.0.dev20241211.dist-info → flwr_nightly-1.14.0.dev20241213.dist-info}/entry_points.txt +0 -0
flwr/proto/fab_pb2.py
CHANGED
@@ -15,7 +15,7 @@ _sym_db = _symbol_database.Default()
 from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2


-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"Q\n\rGetFabRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08hash_str\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -25,7 +25,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _globals['_FAB']._serialized_start=59
   _globals['_FAB']._serialized_end=99
   _globals['_GETFABREQUEST']._serialized_start=101
-  _globals['_GETFABREQUEST']._serialized_end=
-  _globals['_GETFABRESPONSE']._serialized_start=
-  _globals['_GETFABRESPONSE']._serialized_end=
+  _globals['_GETFABREQUEST']._serialized_end=182
+  _globals['_GETFABRESPONSE']._serialized_start=184
+  _globals['_GETFABRESPONSE']._serialized_end=230
 # @@protoc_insertion_point(module_scope)
flwr/proto/fab_pb2.pyi
CHANGED
@@ -36,16 +36,19 @@ class GetFabRequest(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
     NODE_FIELD_NUMBER: builtins.int
     HASH_STR_FIELD_NUMBER: builtins.int
+    RUN_ID_FIELD_NUMBER: builtins.int
     @property
     def node(self) -> flwr.proto.node_pb2.Node: ...
     hash_str: typing.Text
+    run_id: builtins.int
     def __init__(self,
         *,
         node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
         hash_str: typing.Text = ...,
+        run_id: builtins.int = ...,
         ) -> None: ...
     def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
-    def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node"]) -> None: ...
+    def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node","run_id",b"run_id"]) -> None: ...
 global___GetFabRequest = GetFabRequest

 class GetFabResponse(google.protobuf.message.Message):
flwr/server/app.py
CHANGED
flwr/server/compat/app_utils.py
CHANGED
@@ -17,6 +17,8 @@

 import threading

+from flwr.common.typing import RunNotRunningException
+
 from ..client_manager import ClientManager
 from ..compat.driver_client_proxy import DriverClientProxy
 from ..driver import Driver
@@ -74,7 +76,11 @@ def _update_client_manager(
     # Loop until the driver is disconnected
     registered_nodes: dict[int, DriverClientProxy] = {}
     while not f_stop.is_set():
-
+        try:
+            all_node_ids = set(driver.get_node_ids())
+        except RunNotRunningException:
+            f_stop.set()
+            break
         dead_nodes = set(registered_nodes).difference(all_node_ids)
         new_nodes = all_node_ids.difference(registered_nodes)

flwr/server/driver/grpc_driver.py
CHANGED

@@ -14,19 +14,20 @@
 # ==============================================================================
 """Flower gRPC Driver."""

+
 import time
 import warnings
 from collections.abc import Iterable
-from logging import DEBUG,
-from typing import
+from logging import DEBUG, WARNING
+from typing import Optional, cast

 import grpc

 from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
-from flwr.common.constant import
+from flwr.common.constant import SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
 from flwr.common.grpc import create_channel
 from flwr.common.logger import log
-from flwr.common.retry_invoker import
+from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
 from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto
 from flwr.common.typing import Run
 from flwr.proto.node_pb2 import Node  # pylint: disable=E0611
@@ -262,60 +263,3 @@ class GrpcDriver(Driver):
             return
         # Disconnect
         self._disconnect()
-
-
-def _make_simple_grpc_retry_invoker() -> RetryInvoker:
-    """Create a simple gRPC retry invoker."""
-
-    def _on_sucess(retry_state: RetryState) -> None:
-        if retry_state.tries > 1:
-            log(
-                INFO,
-                "Connection successful after %.2f seconds and %s tries.",
-                retry_state.elapsed_time,
-                retry_state.tries,
-            )
-
-    def _on_backoff(retry_state: RetryState) -> None:
-        if retry_state.tries == 1:
-            log(WARN, "Connection attempt failed, retrying...")
-        else:
-            log(
-                WARN,
-                "Connection attempt failed, retrying in %.2f seconds",
-                retry_state.actual_wait,
-            )
-
-    def _on_giveup(retry_state: RetryState) -> None:
-        if retry_state.tries > 1:
-            log(
-                WARN,
-                "Giving up reconnection after %.2f seconds and %s tries.",
-                retry_state.elapsed_time,
-                retry_state.tries,
-            )
-
-    return RetryInvoker(
-        wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY),
-        recoverable_exceptions=grpc.RpcError,
-        max_tries=None,
-        max_time=None,
-        on_success=_on_sucess,
-        on_backoff=_on_backoff,
-        on_giveup=_on_giveup,
-        should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE,  # type: ignore
-    )
-
-
-def _wrap_stub(stub: ServerAppIoStub, retry_invoker: RetryInvoker) -> None:
-    """Wrap the gRPC stub with a retry invoker."""
-
-    def make_lambda(original_method: Any) -> Any:
-        return lambda *args, **kwargs: retry_invoker.invoke(
-            original_method, *args, **kwargs
-        )
-
-    for method_name in vars(stub):
-        method = getattr(stub, method_name)
-        if callable(method):
-            setattr(stub, method_name, make_lambda(method))
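Note: the helpers removed above are not dropped from the package. The file listing shows flwr/common/retry_invoker.py growing by 75 lines in this release, and the new import line in grpc_driver.py pulls _make_simple_grpc_retry_invoker and _wrap_stub from there. A minimal, illustrative sketch of how the two combine (the exact call site inside GrpcDriver is not shown in this diff, and the helper name below is hypothetical):

# Illustrative only: wrap a freshly created gRPC stub so that every RPC call
# goes through the shared retry policy (retry on grpc.RpcError, give up on
# status codes other than UNAVAILABLE, as defined in the removed code above).
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub

def make_retrying_stub(stub):  # hypothetical helper name
    retry_invoker = _make_simple_grpc_retry_invoker()
    _wrap_stub(stub, retry_invoker)  # re-binds each callable stub method to retry_invoker.invoke(...)
    return stub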
flwr/server/driver/inmemory_driver.py
CHANGED

@@ -142,7 +142,11 @@ class InMemoryDriver(Driver):
         # Pull TaskRes
         task_res_list = self.state.get_task_res(task_ids=msg_ids)
         # Delete tasks in state
-
+        # Delete the TaskIns/TaskRes pairs if TaskRes is found
+        task_ins_ids_to_delete = {
+            UUID(task_res.task.ancestry[0]) for task_res in task_res_list
+        }
+        self.state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
         # Convert TaskRes to Message
         msgs = [message_from_taskres(taskres) for taskres in task_res_list]
         return msgs
flwr/server/serverapp/app.py
CHANGED
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Flower ServerApp process."""

+
 import argparse
 import sys
 from logging import DEBUG, ERROR, INFO
@@ -50,7 +51,7 @@ from flwr.common.serde import (
     run_from_proto,
     run_status_to_proto,
 )
-from flwr.common.typing import RunStatus
+from flwr.common.typing import RunNotRunningException, RunStatus
 from flwr.proto.run_pb2 import UpdateRunStatusRequest  # pylint: disable=E0611
 from flwr.proto.serverappio_pb2 import (  # pylint: disable=E0611
     PullServerAppInputsRequest,
@@ -96,7 +97,7 @@ def flwr_serverapp() -> None:
     restore_output()


-def run_serverapp(  # pylint: disable=R0914, disable=W0212
+def run_serverapp(  # pylint: disable=R0914, disable=W0212, disable=R0915
     serverappio_api_address: str,
     log_queue: Queue[Optional[str]],
     run_once: bool,
@@ -187,6 +188,12 @@ def run_serverapp(  # pylint: disable=R0914, disable=W0212

         run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")

+    except RunNotRunningException:
+        log(INFO, "")
+        log(INFO, "Run ID %s stopped.", run.run_id)
+        log(INFO, "")
+        run_status = None
+
     except Exception as ex:  # pylint: disable=broad-exception-caught
         exc_entity = "ServerApp"
         log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
flwr/server/superlink/driver/serverappio_servicer.py
CHANGED

@@ -70,6 +70,7 @@ from flwr.proto.task_pb2 import TaskRes  # pylint: disable=E0611
 from flwr.server.superlink.ffs.ffs import Ffs
 from flwr.server.superlink.ffs.ffs_factory import FfsFactory
 from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
+from flwr.server.superlink.utils import abort_if
 from flwr.server.utils.validator import validate_task_ins_or_res


@@ -88,7 +89,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
     ) -> GetNodesResponse:
         """Get available nodes."""
         log(DEBUG, "ServerAppIoServicer.GetNodes")
+
+        # Init state
         state: LinkState = self.state_factory.state()
+
+        # Abort if the run is not running
+        abort_if(
+            request.run_id,
+            [Status.PENDING, Status.STARTING, Status.FINISHED],
+            state,
+            context,
+        )
+
         all_ids: set[int] = state.get_nodes(request.run_id)
         nodes: list[Node] = [
             Node(node_id=node_id, anonymous=False) for node_id in all_ids
@@ -126,6 +138,17 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
         """Push a set of TaskIns."""
         log(DEBUG, "ServerAppIoServicer.PushTaskIns")

+        # Init state
+        state: LinkState = self.state_factory.state()
+
+        # Abort if the run is not running
+        abort_if(
+            request.run_id,
+            [Status.PENDING, Status.STARTING, Status.FINISHED],
+            state,
+            context,
+        )
+
         # Set pushed_at (timestamp in seconds)
         pushed_at = time.time()
         for task_ins in request.task_ins_list:
@@ -137,9 +160,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
             validation_errors = validate_task_ins_or_res(task_ins)
             _raise_if(bool(validation_errors), ", ".join(validation_errors))

-        # Init state
-        state: LinkState = self.state_factory.state()
-
         # Store each TaskIns
         task_ids: list[Optional[UUID]] = []
         for task_ins in request.task_ins_list:
@@ -156,33 +176,29 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
         """Pull a set of TaskRes."""
         log(DEBUG, "ServerAppIoServicer.PullTaskRes")

-        # Convert each task_id str to UUID
-        task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
-
         # Init state
         state: LinkState = self.state_factory.state()

-        #
-
-
-
-
-
-
-        if context.is_active():
-            return
-        if context.code() != grpc.StatusCode.OK:
-            return
-
-        # Delete delivered TaskIns and TaskRes
-        state.delete_tasks(task_ids=task_ids)
+        # Abort if the run is not running
+        abort_if(
+            request.run_id,
+            [Status.PENDING, Status.STARTING, Status.FINISHED],
+            state,
+            context,
+        )

-
+        # Convert each task_id str to UUID
+        task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}

         # Read from state
         task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)

-
+        # Delete the TaskIns/TaskRes pairs if TaskRes is found
+        task_ins_ids_to_delete = {
+            UUID(task_res.task.ancestry[0]) for task_res in task_res_list
+        }
+        state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
+
         return PullTaskResResponse(task_res_list=task_res_list)

     def GetRun(
@@ -258,7 +274,18 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
     ) -> PushServerAppOutputsResponse:
         """Push ServerApp process outputs."""
         log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
+
+        # Init state
         state = self.state_factory.state()
+
+        # Abort if the run is not running
+        abort_if(
+            request.run_id,
+            [Status.PENDING, Status.STARTING, Status.FINISHED],
+            state,
+            context,
+        )
+
         state.set_serverapp_context(request.run_id, context_from_proto(request.context))
         return PushServerAppOutputsResponse()

@@ -267,8 +294,13 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
     ) -> UpdateRunStatusResponse:
         """Update the status of a run."""
         log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
+
+        # Init state
         state = self.state_factory.state()

+        # Abort if the run is finished
+        abort_if(request.run_id, [Status.FINISHED], state, context)
+
         # Update the run status
         state.update_run_status(
             run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
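The abort_if guard used throughout this servicer is imported from flwr/server/superlink/utils.py, a new module in this release (+65 lines) whose body is not part of this diff. From the call sites above, it receives the run ID, the list of run statuses that should abort the RPC, the LinkState, and the gRPC servicer context. A purely hypothetical sketch of such a guard, for orientation only; the real implementation may differ:

# Hypothetical sketch -- not the actual flwr.server.superlink.utils.abort_if.
# Assumes a LinkState accessor that returns the current status of a run.
import grpc

def abort_if(run_id, abort_status_list, state, context):
    run_status = state.get_run_status({run_id})[run_id]  # assumed accessor name
    if run_status.status in abort_status_list:
        # ServicerContext.abort raises, so the RPC handler never continues.
        context.abort(
            grpc.StatusCode.PERMISSION_DENIED,
            f"Run {run_id} is in status '{run_status.status}'",
        )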
flwr/server/superlink/linkstate/in_memory_linkstate.py
CHANGED

@@ -265,41 +265,15 @@ class InMemoryLinkState(LinkState):  # pylint: disable=R0902,R0904
             for task_res in task_res_found:
                 task_res.task.delivered_at = delivered_at

-        # Cleanup
-        self._force_delete_tasks_by_ids(set(ret.keys()))
-
         return list(ret.values())

-    def delete_tasks(self,
-        """Delete
-
-        task_res_to_be_deleted: set[UUID] = set()
-
-        with self.lock:
-            for task_ins_id in task_ids:
-                # Find the task_id of the matching task_res
-                for task_res_id, task_res in self.task_res_store.items():
-                    if UUID(task_res.task.ancestry[0]) != task_ins_id:
-                        continue
-                    if task_res.task.delivered_at == "":
-                        continue
-
-                    task_ins_to_be_deleted.add(task_ins_id)
-                    task_res_to_be_deleted.add(task_res_id)
-
-            for task_id in task_ins_to_be_deleted:
-                del self.task_ins_store[task_id]
-                del self.task_ins_id_to_task_res_id[task_id]
-            for task_id in task_res_to_be_deleted:
-                del self.task_res_store[task_id]
-
-    def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
-        """Delete tasks based on a set of TaskIns IDs."""
-        if not task_ids:
+    def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
+        """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
+        if not task_ins_ids:
             return

         with self.lock:
-            for task_id in
+            for task_id in task_ins_ids:
                 # Delete TaskIns
                 if task_id in self.task_ins_store:
                     del self.task_ins_store[task_id]
@@ -308,6 +282,16 @@ class InMemoryLinkState(LinkState):  # pylint: disable=R0902,R0904
                     task_res_id = self.task_ins_id_to_task_res_id.pop(task_id)
                     del self.task_res_store[task_res_id]

+    def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
+        """Get all TaskIns IDs for the given run_id."""
+        task_id_list: set[UUID] = set()
+        with self.lock:
+            for task_id, task_ins in self.task_ins_store.items():
+                if task_ins.run_id == run_id:
+                    task_id_list.add(task_id)
+
+        return task_id_list
+
     def num_task_ins(self) -> int:
         """Calculate the number of task_ins in store.

flwr/server/superlink/linkstate/linkstate.py
CHANGED

@@ -139,8 +139,19 @@ class LinkState(abc.ABC):  # pylint: disable=R0904
         """

     @abc.abstractmethod
-    def delete_tasks(self,
-        """Delete
+    def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
+        """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs.
+
+        Parameters
+        ----------
+        task_ins_ids : set[UUID]
+            A set of TaskIns IDs. For each ID in the set, the corresponding
+            TaskIns and its associated TaskRes will be deleted.
+        """
+
+    @abc.abstractmethod
+    def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
+        """Get all TaskIns IDs for the given run_id."""

     @abc.abstractmethod
     def create_node(
flwr/server/superlink/linkstate/sqlite_linkstate.py
CHANGED

@@ -14,6 +14,7 @@
 # ==============================================================================
 """SQLite based implemenation of the link state."""

+
 # pylint: disable=too-many-lines

 import json
@@ -566,9 +567,6 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
         data: list[Any] = [delivered_at] + task_res_ids
         self.query(query, data)

-        # Cleanup
-        self._force_delete_tasks_by_ids(set(ret.keys()))
-
         return list(ret.values())

     def num_task_ins(self) -> int:
@@ -592,68 +590,50 @@ class SqliteLinkState(LinkState):  # pylint: disable=R0904
         result: dict[str, int] = rows[0]
         return result["num"]

-    def delete_tasks(self,
-        """Delete
-
-
-
+    def delete_tasks(self, task_ins_ids: set[UUID]) -> None:
+        """Delete TaskIns/TaskRes pairs based on provided TaskIns IDs."""
+        if not task_ins_ids:
+            return
+        if self.conn is None:
+            raise AttributeError("LinkState not initialized")

-        placeholders = ",".join([
-        data =
+        placeholders = ",".join(["?"] * len(task_ins_ids))
+        data = tuple(str(task_id) for task_id in task_ins_ids)

-        #
+        # Delete task_ins
         query_1 = f"""
             DELETE FROM task_ins
-            WHERE
-            AND task_id IN (
-                SELECT ancestry
-                FROM task_res
-                WHERE ancestry IN ({placeholders})
-                AND delivered_at != ''
-            );
+            WHERE task_id IN ({placeholders});
         """

-        #
+        # Delete task_res
         query_2 = f"""
             DELETE FROM task_res
-            WHERE ancestry IN ({placeholders})
-            AND delivered_at != '';
+            WHERE ancestry IN ({placeholders});
         """

-        if self.conn is None:
-            raise AttributeError("LinkState not intitialized")
-
         with self.conn:
             self.conn.execute(query_1, data)
             self.conn.execute(query_2, data)

-
-
-    def _force_delete_tasks_by_ids(self, task_ids: set[UUID]) -> None:
-        """Delete tasks based on a set of TaskIns IDs."""
-        if not task_ids:
-            return
+    def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]:
+        """Get all TaskIns IDs for the given run_id."""
         if self.conn is None:
             raise AttributeError("LinkState not initialized")

-
-
-
-
-        query_1 = f"""
-            DELETE FROM task_ins
-            WHERE task_id IN ({placeholders});
+        query = """
+            SELECT task_id
+            FROM task_ins
+            WHERE run_id = :run_id;
         """

-
-
-            DELETE FROM task_res
-            WHERE ancestry IN ({placeholders});
-        """
+        sint64_run_id = convert_uint64_to_sint64(run_id)
+        data = {"run_id": sint64_run_id}

         with self.conn:
-            self.conn.execute(
-
+            rows = self.conn.execute(query, data).fetchall()
+
+        return {UUID(row["task_id"]) for row in rows}

     def create_node(
         self, ping_interval: float, public_key: Optional[bytes] = None