flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic.
- flwr/cli/build.py +2 -2
- flwr/cli/config_utils.py +97 -0
- flwr/cli/log.py +63 -97
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +34 -88
- flwr/client/app.py +23 -20
- flwr/client/clientapp/app.py +22 -18
- flwr/client/nodestate/__init__.py +25 -0
- flwr/client/nodestate/in_memory_nodestate.py +38 -0
- flwr/client/nodestate/nodestate.py +30 -0
- flwr/client/nodestate/nodestate_factory.py +37 -0
- flwr/client/{node_state.py → run_info_store.py} +4 -3
- flwr/client/supernode/app.py +6 -8
- flwr/common/args.py +83 -0
- flwr/common/config.py +10 -0
- flwr/common/constant.py +39 -5
- flwr/common/context.py +9 -4
- flwr/common/date.py +3 -3
- flwr/common/logger.py +108 -1
- flwr/common/object_ref.py +47 -16
- flwr/common/serde.py +24 -0
- flwr/common/telemetry.py +0 -6
- flwr/common/typing.py +10 -1
- flwr/proto/exec_pb2.py +14 -17
- flwr/proto/exec_pb2.pyi +14 -22
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +4 -1
- flwr/proto/run_pb2.py +32 -27
- flwr/proto/run_pb2.pyi +26 -0
- flwr/proto/serverappio_pb2.py +52 -0
- flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
- flwr/proto/serverappio_pb2_grpc.py +376 -0
- flwr/proto/serverappio_pb2_grpc.pyi +147 -0
- flwr/proto/simulationio_pb2.py +38 -0
- flwr/proto/simulationio_pb2.pyi +65 -0
- flwr/proto/simulationio_pb2_grpc.py +205 -0
- flwr/proto/simulationio_pb2_grpc.pyi +81 -0
- flwr/server/app.py +272 -105
- flwr/server/driver/driver.py +15 -1
- flwr/server/driver/grpc_driver.py +25 -36
- flwr/server/driver/inmemory_driver.py +6 -16
- flwr/server/run_serverapp.py +29 -23
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +214 -0
- flwr/server/strategy/aggregate.py +4 -4
- flwr/server/strategy/fedadam.py +11 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
- flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +184 -36
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +149 -19
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +306 -65
- flwr/server/superlink/{state → linkstate}/utils.py +81 -30
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +153 -0
- flwr/simulation/__init__.py +5 -1
- flwr/simulation/app.py +273 -345
- flwr/simulation/legacy_app.py +382 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +57 -131
- flwr/simulation/simulationio_connection.py +86 -0
- flwr/superexec/app.py +6 -134
- flwr/superexec/deployment.py +61 -66
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +36 -65
- flwr/superexec/executor.py +26 -7
- flwr/superexec/simulation.py +54 -107
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/METADATA +5 -4
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/RECORD +88 -69
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/entry_points.txt +2 -0
- flwr/client/node_state_tests.py +0 -66
- flwr/proto/driver_pb2.py +0 -42
- flwr/proto/driver_pb2_grpc.py +0 -239
- flwr/proto/driver_pb2_grpc.pyi +0 -94
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241111.dist-info}/WHEEL +0 -0

flwr/server/superlink/fleet/vce/vce_api.py

@@ -28,7 +28,7 @@ from typing import Callable, Optional
 
 from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
 from flwr.client.clientapp.utils import get_load_client_app_fn
-from flwr.client.
+from flwr.client.run_info_store import DeprecatedRunInfoStore
 from flwr.common.constant import (
     NUM_PARTITIONS_KEY,
     PARTITION_ID_KEY,
@@ -40,7 +40,7 @@ from flwr.common.message import Error
 from flwr.common.serde import message_from_taskins, message_to_taskres
 from flwr.common.typing import Run
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
-from flwr.server.superlink.
+from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
 
 from .backend import Backend, error_messages_backends, supported_backends
 
@@ -48,7 +48,7 @@ NodeToPartitionMapping = dict[int, int]
 
 
 def _register_nodes(
-    num_nodes: int, state_factory:
+    num_nodes: int, state_factory: LinkStateFactory
 ) -> NodeToPartitionMapping:
     """Register nodes with the StateFactory and create node-id:partition-id mapping."""
     nodes_mapping: NodeToPartitionMapping = {}
@@ -60,16 +60,16 @@ def _register_nodes(
     return nodes_mapping
 
 
-def
+def _register_node_info_stores(
     nodes_mapping: NodeToPartitionMapping,
     run: Run,
     app_dir: Optional[str] = None,
-) -> dict[int,
-    """Create
-
+) -> dict[int, DeprecatedRunInfoStore]:
+    """Create DeprecatedRunInfoStore objects and register the context for the run."""
+    node_info_store: dict[int, DeprecatedRunInfoStore] = {}
     num_partitions = len(set(nodes_mapping.values()))
     for node_id, partition_id in nodes_mapping.items():
-
+        node_info_store[node_id] = DeprecatedRunInfoStore(
             node_id=node_id,
             node_config={
                 PARTITION_ID_KEY: partition_id,
@@ -78,18 +78,18 @@ def _register_node_states(
         )
 
         # Pre-register Context objects
-
+        node_info_store[node_id].register_context(
             run_id=run.run_id, run=run, app_dir=app_dir
         )
 
-    return
+    return node_info_store
 
 
 # pylint: disable=too-many-arguments,too-many-locals
 def worker(
     taskins_queue: "Queue[TaskIns]",
     taskres_queue: "Queue[TaskRes]",
-
+    node_info_store: dict[int, DeprecatedRunInfoStore],
     backend: Backend,
     f_stop: threading.Event,
 ) -> None:
@@ -103,7 +103,7 @@ def worker(
             node_id = task_ins.task.consumer.node_id
 
             # Retrieve context
-            context =
+            context = node_info_store[node_id].retrieve_context(run_id=task_ins.run_id)
 
             # Convert TaskIns to Message
             message = message_from_taskins(task_ins)
@@ -112,7 +112,7 @@ def worker(
             out_mssg, updated_context = backend.process_message(message, context)
 
             # Update Context
-
+            node_info_store[node_id].update_context(
                 task_ins.run_id, context=updated_context
             )
         except Empty:
@@ -145,7 +145,7 @@ def worker(
 
 
 def add_taskins_to_queue(
-    state:
+    state: LinkState,
     queue: "Queue[TaskIns]",
     nodes_mapping: NodeToPartitionMapping,
     f_stop: threading.Event,
@@ -160,7 +160,7 @@ def add_taskins_to_queue(
 
 
 def put_taskres_into_state(
-    state:
+    state: LinkState, queue: "Queue[TaskRes]", f_stop: threading.Event
 ) -> None:
     """Put TaskRes into State from a queue."""
     while not f_stop.is_set():
@@ -177,8 +177,8 @@ def run_api(
     app_fn: Callable[[], ClientApp],
     backend_fn: Callable[[], Backend],
     nodes_mapping: NodeToPartitionMapping,
-    state_factory:
-
+    state_factory: LinkStateFactory,
+    node_info_stores: dict[int, DeprecatedRunInfoStore],
    f_stop: threading.Event,
 ) -> None:
     """Run the VCE."""
@@ -223,7 +223,7 @@ def run_api(
             worker,
             taskins_queue,
             taskres_queue,
-
+            node_info_stores,
             backend,
             f_stop,
         )
@@ -264,7 +264,7 @@ def start_vce(
     client_app: Optional[ClientApp] = None,
     client_app_attr: Optional[str] = None,
     num_supernodes: Optional[int] = None,
-    state_factory: Optional[
+    state_factory: Optional[LinkStateFactory] = None,
    existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
 ) -> None:
     """Start Fleet API with the Simulation Engine."""
@@ -303,7 +303,7 @@ def start_vce(
     if not state_factory:
         log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
         # Create an empty in-memory state factory
-        state_factory =
+        state_factory = LinkStateFactory(":flwr-in-memory-state:")
         log(INFO, "Created new %s.", state_factory.__class__.__name__)
 
     if num_supernodes:
@@ -312,8 +312,8 @@ def start_vce(
         num_nodes=num_supernodes, state_factory=state_factory
     )
 
-    # Construct mapping of
-
+    # Construct mapping of DeprecatedRunInfoStore
+    node_info_stores = _register_node_info_stores(
        nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None
     )
 
@@ -376,7 +376,7 @@ def start_vce(
             backend_fn,
             nodes_mapping,
             state_factory,
-
+            node_info_stores,
             f_stop,
         )
     except LoadClientAppError as loadapp_ex:
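
The vce_api.py hunks above are largely a mechanical rename (State/StateFactory become LinkState/LinkStateFactory, and the per-node state object becomes DeprecatedRunInfoStore), but the contract the worker loop relies on is easy to miss in the diff: every simulated node owns one store, the run's Context is registered once up front, then retrieved and written back around each processed message. The following is a minimal plain-Python stand-in for that contract, a sketch rather than the actual flwr class:

from typing import Any, Optional


class RunInfoStoreSketch:
    """Toy stand-in for DeprecatedRunInfoStore: one per node, contexts keyed by run_id."""

    def __init__(self, node_id: int, node_config: dict) -> None:
        self.node_id = node_id
        self.node_config = node_config
        self._contexts: dict[int, Any] = {}

    def register_context(self, run_id: int, run: Any, app_dir: Optional[str] = None) -> None:
        # Pre-register an (empty) context for this run, as _register_node_info_stores() does
        self._contexts[run_id] = {"run": run, "app_dir": app_dir, "state": {}}

    def retrieve_context(self, run_id: int) -> Any:
        return self._contexts[run_id]

    def update_context(self, run_id: int, context: Any) -> None:
        self._contexts[run_id] = context


# Usage mirroring worker(): retrieve -> process -> update, keyed by consumer node ID
stores = {123: RunInfoStoreSketch(node_id=123, node_config={"partition-id": 0})}
stores[123].register_context(run_id=1, run=None)
ctx = stores[123].retrieve_context(run_id=1)
ctx["state"]["num_messages"] = ctx["state"].get("num_messages", 0) + 1
stores[123].update_context(1, context=ctx)
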

flwr/server/superlink/linkstate/__init__.py

@@ -0,0 +1,28 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower LinkState."""
+
+
+from .in_memory_linkstate import InMemoryLinkState as InMemoryLinkState
+from .linkstate import LinkState as LinkState
+from .linkstate_factory import LinkStateFactory as LinkStateFactory
+from .sqlite_linkstate import SqliteLinkState as SqliteLinkState
+
+__all__ = [
+    "InMemoryLinkState",
+    "LinkState",
+    "LinkStateFactory",
+    "SqliteLinkState",
+]
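
The new linkstate package gives the renamed state layer a single import point and re-exports both backends. As a sketch of how a caller selects one through the factory: the ":flwr-in-memory-state:" sentinel is taken from the vce_api.py hunk above, while passing a file path for SQLite and the state() accessor are assumptions carried over from the old StateFactory and are not verified against this wheel.

from flwr.server.superlink.linkstate import LinkStateFactory

# In-memory backend, using the sentinel shown in vce_api.py above
ephemeral_factory = LinkStateFactory(":flwr-in-memory-state:")

# SQLite-backed LinkState; path-based selection is assumed to behave as it did
# with the old StateFactory (unverified against this release)
persistent_factory = LinkStateFactory("/tmp/flwr-linkstate.db")

link_state = ephemeral_factory.state()  # assumed accessor carried over from StateFactory
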

flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py}

@@ -12,31 +12,53 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""In-memory
+"""In-memory LinkState implementation."""
 
 
 import threading
 import time
+from bisect import bisect_right
+from dataclasses import dataclass, field
 from logging import ERROR, WARNING
 from typing import Optional
 from uuid import UUID, uuid4
 
-from flwr.common import log, now
+from flwr.common import Context, log, now
 from flwr.common.constant import (
     MESSAGE_TTL_TOLERANCE,
     NODE_ID_NUM_BYTES,
     RUN_ID_NUM_BYTES,
+    Status,
 )
-from flwr.common.
+from flwr.common.record import ConfigsRecord
+from flwr.common.typing import Run, RunStatus, UserConfig
 from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611
-from flwr.server.superlink.
+from flwr.server.superlink.linkstate.linkstate import LinkState
 from flwr.server.utils import validate_task_ins_or_res
 
-from .utils import
+from .utils import (
+    generate_rand_int_from_bytes,
+    has_valid_sub_status,
+    is_valid_transition,
+)
+
+
+@dataclass
+class RunRecord:  # pylint: disable=R0902
+    """The record of a specific run, including its status and timestamps."""
 
+    run: Run
+    status: RunStatus
+    pending_at: str = ""
+    starting_at: str = ""
+    running_at: str = ""
+    finished_at: str = ""
+    logs: list[tuple[float, str]] = field(default_factory=list)
+    log_lock: threading.Lock = field(default_factory=threading.Lock)
 
-
-
+
+class InMemoryLinkState(LinkState):  # pylint: disable=R0902,R0904
+    """In-memory LinkState implementation."""
 
     def __init__(self) -> None:
 
@@ -44,8 +66,10 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
         self.node_ids: dict[int, tuple[float, float]] = {}
         self.public_key_to_node_id: dict[bytes, int] = {}
 
-        # Map run_id to
-        self.run_ids: dict[int,
+        # Map run_id to RunRecord
+        self.run_ids: dict[int, RunRecord] = {}
+        self.contexts: dict[int, Context] = {}
+        self.federation_options: dict[int, ConfigsRecord] = {}
         self.task_ins_store: dict[UUID, TaskIns] = {}
         self.task_res_store: dict[UUID, TaskRes] = {}
 
@@ -64,8 +88,25 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
             return None
         # Validate run_id
         if task_ins.run_id not in self.run_ids:
-            log(ERROR, "
+            log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id)
             return None
+        # Validate source node ID
+        if task_ins.task.producer.node_id != 0:
+            log(
+                ERROR,
+                "Invalid source node ID for TaskIns: %s",
+                task_ins.task.producer.node_id,
+            )
+            return None
+        # Validate destination node ID
+        if not task_ins.task.consumer.anonymous:
+            if task_ins.task.consumer.node_id not in self.node_ids:
+                log(
+                    ERROR,
+                    "Invalid destination node ID for TaskIns: %s",
+                    task_ins.task.consumer.node_id,
+                )
+                return None
 
         # Create task_id
         task_id = uuid4()
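
The added checks in the TaskIns-storing path above tighten what the in-memory store accepts: the run must exist, the producer must be the server-side node (node ID 0), and a non-anonymous consumer must be a currently registered node. A small self-contained sketch of that acceptance rule, with tiny dataclasses standing in for the protobuf Node messages:

from dataclasses import dataclass


@dataclass
class FakeNode:
    node_id: int
    anonymous: bool = False


def is_acceptable(run_ids: set[int], node_ids: set[int], run_id: int,
                  producer: FakeNode, consumer: FakeNode) -> bool:
    """Mirror the validation order shown in the hunk above."""
    if run_id not in run_ids:          # invalid run ID
        return False
    if producer.node_id != 0:          # TaskIns must originate from the server side
        return False
    if not consumer.anonymous and consumer.node_id not in node_ids:
        return False                   # destination node is unknown
    return True


assert is_acceptable({1}, {42}, 1, FakeNode(0), FakeNode(42)) is True
assert is_acceptable({1}, {42}, 1, FakeNode(7), FakeNode(42)) is False   # bad producer
assert is_acceptable({1}, {42}, 1, FakeNode(0), FakeNode(99)) is False   # unknown consumer
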
@@ -215,21 +256,6 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
                 task_res_list.append(task_res)
                 replied_task_ids.add(reply_to)
 
-            # Check if the node is offline
-            for task_id in task_ids - replied_task_ids:
-                task_ins = self.task_ins_store.get(task_id)
-                if task_ins is None:
-                    continue
-                node_id = task_ins.task.consumer.node_id
-                online_until, _ = self.node_ids[node_id]
-                # Generate a TaskRes containing an error reply if the node is offline.
-                if online_until < time.time():
-                    err_taskres = make_node_unavailable_taskres(
-                        ref_taskins=task_ins,
-                    )
-                    self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
-                    task_res_list.append(err_taskres)
-
             # Mark all of them as delivered
             delivered_at = now().isoformat()
             for task_res in task_res_list:
@@ -277,7 +303,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
     def create_node(
         self, ping_interval: float, public_key: Optional[bytes] = None
     ) -> int:
-        """Create, store in state, and return `node_id`."""
+        """Create, store in the link state, and return `node_id`."""
         # Sample a random int64 as node_id
         node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
 
@@ -338,12 +364,14 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
         """Retrieve stored `node_id` filtered by `node_public_keys`."""
         return self.public_key_to_node_id.get(node_public_key)
 
+    # pylint: disable=too-many-arguments,too-many-positional-arguments
     def create_run(
         self,
         fab_id: Optional[str],
         fab_version: Optional[str],
         fab_hash: Optional[str],
         override_config: UserConfig,
+        federation_options: ConfigsRecord,
     ) -> int:
         """Create a new run for the specified `fab_hash`."""
         # Sample a random int64 as run_id
@@ -351,13 +379,25 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
         run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
 
         if run_id not in self.run_ids:
-
-
-
-
-
-
+            run_record = RunRecord(
+                run=Run(
+                    run_id=run_id,
+                    fab_id=fab_id if fab_id else "",
+                    fab_version=fab_version if fab_version else "",
+                    fab_hash=fab_hash if fab_hash else "",
+                    override_config=override_config,
+                ),
+                status=RunStatus(
+                    status=Status.PENDING,
+                    sub_status="",
+                    details="",
+                ),
+                pending_at=now().isoformat(),
             )
+            self.run_ids[run_id] = run_record
+
+            # Record federation options. Leave empty if not passed
+            self.federation_options[run_id] = federation_options
             return run_id
         log(ERROR, "Unexpected run creation failure.")
         return 0
@@ -365,7 +405,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
     def store_server_private_public_key(
         self, private_key: bytes, public_key: bytes
     ) -> None:
-        """Store `server_private_key` and `server_public_key` in state."""
+        """Store `server_private_key` and `server_public_key` in the link state."""
         with self.lock:
             if self.server_private_key is None and self.server_public_key is None:
                 self.server_private_key = private_key
@@ -382,12 +422,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
         return self.server_public_key
 
     def store_node_public_keys(self, public_keys: set[bytes]) -> None:
-        """Store a set of `node_public_keys` in state."""
+        """Store a set of `node_public_keys` in the link state."""
         with self.lock:
             self.node_public_keys = public_keys
 
     def store_node_public_key(self, public_key: bytes) -> None:
-        """Store a `node_public_key` in state."""
+        """Store a `node_public_key` in the link state."""
         with self.lock:
             self.node_public_keys.add(public_key)
 
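
With RunRecord in place, create_run now stores more than a bare Run: a RunStatus that starts at Status.PENDING, a pending_at timestamp, and the run's federation options. A self-contained sketch of the shape of a freshly created record, using stand-in dataclasses and lowercase status strings as assumptions in place of the real flwr types and Status constants:

from dataclasses import dataclass, field
from datetime import datetime, timezone


@dataclass
class RunStatusSketch:
    status: str
    sub_status: str = ""
    details: str = ""


@dataclass
class RunRecordSketch:
    run_id: int
    status: RunStatusSketch
    pending_at: str = ""
    starting_at: str = ""
    running_at: str = ""
    finished_at: str = ""
    logs: list[tuple[float, str]] = field(default_factory=list)


def create_run_sketch(run_id: int) -> RunRecordSketch:
    # Mirrors the hunk above: a new run starts as "pending" with only pending_at set
    return RunRecordSketch(
        run_id=run_id,
        status=RunStatusSketch(status="pending"),
        pending_at=datetime.now(timezone.utc).isoformat(),
    )


record = create_run_sketch(run_id=110)
assert record.status.status == "pending" and record.starting_at == ""
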
@@ -395,13 +435,88 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
         """Retrieve all currently stored `node_public_keys` as a set."""
         return self.node_public_keys
 
+    def get_run_ids(self) -> set[int]:
+        """Retrieve all run IDs."""
+        with self.lock:
+            return set(self.run_ids.keys())
+
     def get_run(self, run_id: int) -> Optional[Run]:
         """Retrieve information about the run with the specified `run_id`."""
         with self.lock:
             if run_id not in self.run_ids:
                 log(ERROR, "`run_id` is invalid")
                 return None
-            return self.run_ids[run_id]
+            return self.run_ids[run_id].run
+
+    def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
+        """Retrieve the statuses for the specified runs."""
+        with self.lock:
+            return {
+                run_id: self.run_ids[run_id].status
+                for run_id in set(run_ids)
+                if run_id in self.run_ids
+            }
+
+    def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
+        """Update the status of the run with the specified `run_id`."""
+        with self.lock:
+            # Check if the run_id exists
+            if run_id not in self.run_ids:
+                log(ERROR, "`run_id` is invalid")
+                return False
+
+            # Check if the status transition is valid
+            current_status = self.run_ids[run_id].status
+            if not is_valid_transition(current_status, new_status):
+                log(
+                    ERROR,
+                    'Invalid status transition: from "%s" to "%s"',
+                    current_status.status,
+                    new_status.status,
+                )
+                return False
+
+            # Check if the sub-status is valid
+            if not has_valid_sub_status(current_status):
+                log(
+                    ERROR,
+                    'Invalid sub-status "%s" for status "%s"',
+                    current_status.sub_status,
+                    current_status.status,
+                )
+                return False
+
+            # Update the status
+            run_record = self.run_ids[run_id]
+            if new_status.status == Status.STARTING:
+                run_record.starting_at = now().isoformat()
+            elif new_status.status == Status.RUNNING:
+                run_record.running_at = now().isoformat()
+            elif new_status.status == Status.FINISHED:
+                run_record.finished_at = now().isoformat()
+            run_record.status = new_status
+            return True
+
+    def get_pending_run_id(self) -> Optional[int]:
+        """Get the `run_id` of a run with `Status.PENDING` status, if any."""
+        pending_run_id = None
+
+        # Loop through all registered runs
+        for run_id, run_rec in self.run_ids.items():
+            # Break once a pending run is found
+            if run_rec.status.status == Status.PENDING:
+                pending_run_id = run_id
+                break
+
+        return pending_run_id
+
+    def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
+        """Retrieve the federation options for the specified `run_id`."""
+        with self.lock:
+            if run_id not in self.run_ids:
+                log(ERROR, "`run_id` is invalid")
+                return None
+            return self.federation_options[run_id]
 
     def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
         """Acknowledge a ping received from a node, serving as a heartbeat."""
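
update_run_status above only moves a run forward when is_valid_transition and has_valid_sub_status (imported from the linkstate utils, whose diff is not shown here) approve the change, and it stamps starting_at, running_at, or finished_at as the run advances. The sketch below illustrates the lifecycle this implies; the pending → starting → running → finished ordering, with finished reachable from any earlier state, is an assumption since the actual transition table lives in utils.py:

# Assumed lifecycle: pending -> starting -> running -> finished, with finished
# reachable from any earlier state. This is an illustration, not flwr's actual table.
VALID_TRANSITIONS = {
    ("pending", "starting"),
    ("pending", "finished"),
    ("starting", "running"),
    ("starting", "finished"),
    ("running", "finished"),
}


def is_valid_transition_sketch(current: str, new: str) -> bool:
    return (current, new) in VALID_TRANSITIONS


status = "pending"
for new_status in ("starting", "running", "finished"):
    assert is_valid_transition_sketch(status, new_status)
    status = new_status

# Going backwards is rejected, so a finished run can never be restarted
assert not is_valid_transition_sketch("finished", "running")
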
@@ -410,3 +525,36 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
             self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
             return True
         return False
+
+    def get_serverapp_context(self, run_id: int) -> Optional[Context]:
+        """Get the context for the specified `run_id`."""
+        return self.contexts.get(run_id)
+
+    def set_serverapp_context(self, run_id: int, context: Context) -> None:
+        """Set the context for the specified `run_id`."""
+        if run_id not in self.run_ids:
+            raise ValueError(f"Run {run_id} not found")
+        self.contexts[run_id] = context
+
+    def add_serverapp_log(self, run_id: int, log_message: str) -> None:
+        """Add a log entry to the serverapp logs for the specified `run_id`."""
+        if run_id not in self.run_ids:
+            raise ValueError(f"Run {run_id} not found")
+        run = self.run_ids[run_id]
+        with run.log_lock:
+            run.logs.append((now().timestamp(), log_message))
+
+    def get_serverapp_log(
+        self, run_id: int, after_timestamp: Optional[float]
+    ) -> tuple[str, float]:
+        """Get the serverapp logs for the specified `run_id`."""
+        if run_id not in self.run_ids:
+            raise ValueError(f"Run {run_id} not found")
+        run = self.run_ids[run_id]
+        if after_timestamp is None:
+            after_timestamp = 0.0
+        with run.log_lock:
+            # Find the index where the timestamp would be inserted
+            index = bisect_right(run.logs, (after_timestamp, ""))
+            latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
+            return "".join(log for _, log in run.logs[index:]), latest_timestamp